GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
kron.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2002-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "dMatrix.h"
31 #include "fMatrix.h"
32 #include "CMatrix.h"
33 #include "fCMatrix.h"
34 
35 #include "dSparse.h"
36 #include "CSparse.h"
37 
38 #include "dDiagMatrix.h"
39 #include "fDiagMatrix.h"
40 #include "CDiagMatrix.h"
41 #include "fCDiagMatrix.h"
42 
43 #include "PermMatrix.h"
44 
45 #include "mx-inlines.cc"
46 #include "quit.h"
47 
48 #include "defun.h"
49 #include "error.h"
50 #include "ovl.h"
51 
53 
54 template <typename R, typename T>
55 static MArray<T>
56 kron (const MArray<R>& a, const MArray<T>& b)
57 {
58  error_unless (a.ndims () == 2);
59  error_unless (b.ndims () == 2);
60 
61  octave_idx_type nra = a.rows ();
62  octave_idx_type nrb = b.rows ();
63  octave_idx_type nca = a.cols ();
64  octave_idx_type ncb = b.cols ();
65 
66  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
67  T *cv = c.fortran_vec ();
68 
69  for (octave_idx_type ja = 0; ja < nca; ja++)
70  {
71  octave_quit ();
72  for (octave_idx_type jb = 0; jb < ncb; jb++)
73  {
74  for (octave_idx_type ia = 0; ia < nra; ia++)
75  {
76  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
77  cv += nrb;
78  }
79  }
80  }
81 
82  return c;
83 }
84 
85 template <typename R, typename T>
86 static MArray<T>
87 kron (const MDiagArray2<R>& a, const MArray<T>& b)
88 {
89  error_unless (b.ndims () == 2);
90 
91  octave_idx_type nra = a.rows ();
92  octave_idx_type nrb = b.rows ();
93  octave_idx_type dla = a.diag_length ();
94  octave_idx_type nca = a.cols ();
95  octave_idx_type ncb = b.cols ();
96 
97  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
98 
99  for (octave_idx_type ja = 0; ja < dla; ja++)
100  {
101  octave_quit ();
102  for (octave_idx_type jb = 0; jb < ncb; jb++)
103  {
104  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
105  b.data () + nrb*jb);
106  }
107  }
108 
109  return c;
110 }
111 
112 template <typename T>
113 static MSparse<T>
114 kron (const MSparse<T>& A, const MSparse<T>& B)
115 {
116  octave_idx_type idx = 0;
117  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
118  A.nnz () * B.nnz ());
119 
120  C.cidx (0) = 0;
121 
122  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
123  {
124  octave_quit ();
125  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
126  {
127  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
128  {
129  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
130  const T v = A.data (Ai);
131 
132  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
133  {
134  C.data (idx) = v * B.data (Bi);
135  C.ridx (idx++) = Ci + B.ridx (Bi);
136  }
137  }
138  C.cidx (Aj * B.columns () + Bj + 1) = idx;
139  }
140  }
141 
142  return C;
143 }
144 
145 static PermMatrix
146 kron (const PermMatrix& a, const PermMatrix& b)
147 {
148  octave_idx_type na = a.rows ();
149  octave_idx_type nb = b.rows ();
150  const Array<octave_idx_type>& pa = a.col_perm_vec ();
151  const Array<octave_idx_type>& pb = b.col_perm_vec ();
152  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
153  octave_idx_type rescol = 0;
154  for (octave_idx_type i = 0; i < na; i++)
155  {
156  octave_idx_type a_add = pa(i) * nb;
157  for (octave_idx_type j = 0; j < nb; j++)
158  res_perm.xelem (rescol++) = a_add + pb(j);
159  }
160 
161  return PermMatrix (res_perm, true);
162 }
163 
// Helper for dispatch_kron.  It converts both operands to the requested
// matrix types MTA and MTB with octave_value_extract, calls the kron
// overload selected by those types, and wraps the result in an
// octave_value.
// NOTE(review): the return-type line of this definition (source line 165)
// is missing from this listing.  The cross-reference index at the end of
// this page gives the signature as
// `octave_value do_kron (const octave_value &a, const octave_value &b)`.
164 template <typename MTA, typename MTB>
166 do_kron (const octave_value& a, const octave_value& b)
167 {
// Each operand is converted here to its target type.  dispatch_kron
// relies on this, e.g. to turn a diagonal or permutation operand into a
// full matrix.  Presumably octave_value_extract raises an Octave error
// when the conversion is impossible (TODO confirm).
168  MTA am = octave_value_extract<MTA> (a);
169  MTB bm = octave_value_extract<MTB> (b);
170 
171  return octave_value (kron (am, bm));
172 }
173 
// dispatch_kron: selects the kron overload that matches the storage
// classes of A and B, and returns the Kronecker product as an
// octave_value.
// NOTE(review): the signature lines of this definition (source lines
// 174-175) are missing from this listing.  The cross-reference index at
// the end of this page gives the signature as
// `octave_value dispatch_kron (const octave_value &a, const octave_value &b)`.
//
// Dispatch order (the first match wins):
//   1. both operands are permutation matrices -> PermMatrix result
//   2. either operand is sparse               -> sparse result, complex if
//                                                either operand is complex
//   3. A is a diagonal matrix                 -> diagonal-aware kron
//   4. anything else                          -> dense result
// In branches 3 and 4 the result is single precision if either operand
// is single, and complex if either operand is complex.  Branch 2 always
// uses the double-precision types SparseMatrix / SparseComplexMatrix,
// regardless of is_single_type ().  A diagonal B combined with a
// non-diagonal A takes branch 4, where B is extracted as a full matrix.
176 {
177  octave_value retval;
178  if (a.is_perm_matrix () && b.is_perm_matrix ())
179  retval = do_kron<PermMatrix, PermMatrix> (a, b);
180  else if (a.issparse () || b.issparse ())
181  {
182  if (a.iscomplex () || b.iscomplex ())
183  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
184  else
185  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
186  }
187  else if (a.is_diag_matrix ())
188  {
189  if (b.is_diag_matrix () && a.rows () == a.columns ()
190  && b.rows () == b.columns ())
191  {
192  // We have two diagonal matrices, the product of those will be
193  // another diagonal matrix. To do that efficiently, extract
194  // the diagonals as vectors and compute the product. That
195  // will be another vector, which we then use to construct a
196  // diagonal matrix object. Note that this will fail if our
197  // diagonal matrix object is modified to allow the nonzero
198  // values to be stored off of the principal diagonal (i.e., if
199  // diag ([1,2], 3) is modified to return a diagonal matrix
200  // object instead of a full matrix object).
201 
202  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
203  retval = tmp.diag ();
204  }
// A is diagonal but B is not, or one of them is non-square: only A's
// diagonal blocks are computed (kron (MDiagArray2, MArray) overload),
// and B is extracted as a full matrix.
205  else if (a.is_single_type () || b.is_single_type ())
206  {
207  if (a.iscomplex ())
208  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
209  else if (b.iscomplex ())
210  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
211  else
212  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
213  }
214  else
215  {
216  if (a.iscomplex ())
217  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
218  else if (b.iscomplex ())
219  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
220  else
221  retval = do_kron<DiagMatrix, Matrix> (a, b);
222  }
223  }
// General case: both operands are extracted as full (dense) matrices.
224  else if (a.is_single_type () || b.is_single_type ())
225  {
226  if (a.iscomplex ())
227  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
228  else if (b.iscomplex ())
229  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
230  else
231  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
232  }
233  else
234  {
235  if (a.iscomplex ())
236  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
237  else if (b.iscomplex ())
238  retval = do_kron<Matrix, ComplexMatrix> (a, b);
239  else
240  retval = do_kron<Matrix, Matrix> (a, b);
241  }
242  return retval;
243 }
244 
245 
// Builtin kron (A, B, ...): the Kronecker product of two or more
// arguments.  The arguments are combined left to right with
// dispatch_kron, so kron (A1, A2, A3) is evaluated as
// kron (kron (A1, A2), A3).
246 DEFUN (kron, args, ,
247  doc: /* -*- texinfo -*-
248 @deftypefn {} {@var{C} =} kron (@var{A}, @var{B})
249 @deftypefnx {} {@var{C} =} kron (@var{A1}, @var{A2}, @dots{})
250 Form the Kronecker product of two or more matrices.
251 
252 This is defined block by block as
253 
254 @example
255 c = [ a(i,j)*b ]
256 @end example
257 
258 For example:
259 
260 @example
261 @group
262 kron (1:4, ones (3, 1))
263  @result{} 1 2 3 4
264  1 2 3 4
265  1 2 3 4
266 @end group
267 @end example
268 
269 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
270 @var{An} the Kronecker product is computed as
271 
272 @example
273 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
274 @end example
275 
276 @noindent
277 Since the Kronecker product is associative, this is well-defined.
278 @seealso{tensorprod}
279 @end deftypefn */)
280 {
281  int nargin = args.length ();
282 
// At least two operands are required.
283  if (nargin < 2)
284  print_usage ();
285 
286  octave_value retval;
287 
288  octave_value a = args(0);
289  octave_value b = args(1);
290 
291  retval = dispatch_kron (a, b);
292 
// Fold each remaining argument into the running product, left to right.
293  for (octave_idx_type i = 2; i < nargin; i++)
294  retval = dispatch_kron (retval, args(i));
295 
296  return retval;
297 }
298 
299 /*
300 %!test
301 %! x = ones (2);
302 %! assert (kron (x, x), ones (4));
303 
304 %!shared x, y, z, p1, p2, d1, d2
305 %! x = [1, 2];
306 %! y = [-1, -2];
307 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
308 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
309 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
310 %! d1 = diag ([1 2 3]); ## Diag type matrix
311 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
312 %!assert (kron (1:4, ones (3, 1)), z)
313 %!assert (kron (single (1:4), ones (3, 1)), single (z))
314 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
315 %!assert (kron (complex (1:4), ones (3, 1)), z)
316 %!assert (kron (complex (single (1:4)), ones (3, 1)), single (z))
317 %!assert (kron (x, y, z), kron (kron (x, y), z))
318 %!assert (kron (x, y, z), kron (x, kron (y, z)))
319 %!assert (kron (p1, p1), kron (p2, p2))
320 %!assert (kron (p1, p2), kron (p2, p1))
321 %!assert (kron (d1, d1), kron (d2, d2))
322 %!assert (kron (d1, d2), kron (d2, d1))
323 
324 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
325 
326 ## Test for two diag matrices.
327 ## See the comments above in dispatch_kron for this case.
328 %!test
329 %! expected = zeros (16, 16);
330 %! expected (1, 11) = 3;
331 %! expected (2, 12) = 4;
332 %! expected (5, 15) = 6;
333 %! expected (6, 16) = 8;
334 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
335 */
336 
337 OCTAVE_END_NAMESPACE(octave)
#define C(a, b)
Definition: Faddeeva.cc:259
int ndims() const
Number of dimensions of the array.
Definition: Array.h:671
octave_idx_type rows() const
Definition: Array.h:459
const T * data() const
Read-only pointer to the array's element data.
Definition: Array.h:663
octave_idx_type cols() const
Definition: Array.h:469
octave_idx_type rows() const
Definition: DiagArray2.h:89
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:124
octave_idx_type cols() const
Definition: DiagArray2.h:90
octave_idx_type diag_length() const
Definition: DiagArray2.h:93
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:56
octave_idx_type rows() const
Definition: PermMatrix.h:62
const Array< octave_idx_type > & col_perm_vec() const
Definition: PermMatrix.h:83
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
bool is_diag_matrix() const
Definition: ov.h:631
octave_idx_type rows() const
Definition: ov.h:545
bool is_perm_matrix() const
Definition: ov.h:634
bool is_single_type() const
Definition: ov.h:698
bool issparse() const
Definition: ov.h:753
bool iscomplex() const
Definition: ov.h:741
octave_idx_type columns() const
Definition: ov.h:547
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1410
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
#define error_unless(cond)
Definition: error.h:530
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:175
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:166
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
F77_RET_T const F77_INT F77_CMPLX * A
void mx_inline_mul(std::size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:110
return octave_value(v1.char_array_value() . concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string()) ? '\'' :'"'))