GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
kron.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2002-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "dMatrix.h"
31 #include "fMatrix.h"
32 #include "CMatrix.h"
33 #include "fCMatrix.h"
34 
35 #include "dSparse.h"
36 #include "CSparse.h"
37 
38 #include "dDiagMatrix.h"
39 #include "fDiagMatrix.h"
40 #include "CDiagMatrix.h"
41 #include "fCDiagMatrix.h"
42 
43 #include "PermMatrix.h"
44 
45 #include "mx-inlines.cc"
46 #include "quit.h"
47 
48 #include "defun.h"
49 #include "error.h"
50 #include "ovl.h"
51 
53 
54 template <typename R, typename T>
55 static MArray<T>
56 kron (const MArray<R>& a, const MArray<T>& b)
57 {
58  error_unless (a.ndims () == 2);
59  error_unless (b.ndims () == 2);
60 
61  octave_idx_type nra = a.rows ();
62  octave_idx_type nrb = b.rows ();
63  octave_idx_type nca = a.cols ();
64  octave_idx_type ncb = b.cols ();
65 
66  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
67  T *cv = c.fortran_vec ();
68 
69  for (octave_idx_type ja = 0; ja < nca; ja++)
70  {
71  octave_quit ();
72  for (octave_idx_type jb = 0; jb < ncb; jb++)
73  {
74  for (octave_idx_type ia = 0; ia < nra; ia++)
75  {
76  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
77  cv += nrb;
78  }
79  }
80  }
81 
82  return c;
83 }
84 
85 template <typename R, typename T>
86 static MArray<T>
87 kron (const MDiagArray2<R>& a, const MArray<T>& b)
88 {
89  error_unless (b.ndims () == 2);
90 
91  octave_idx_type nra = a.rows ();
92  octave_idx_type nrb = b.rows ();
93  octave_idx_type dla = a.diag_length ();
94  octave_idx_type nca = a.cols ();
95  octave_idx_type ncb = b.cols ();
96 
97  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
98 
99  for (octave_idx_type ja = 0; ja < dla; ja++)
100  {
101  octave_quit ();
102  for (octave_idx_type jb = 0; jb < ncb; jb++)
103  {
104  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
105  b.data () + nrb*jb);
106  }
107  }
108 
109  return c;
110 }
111 
112 template <typename T>
113 static MSparse<T>
114 kron (const MSparse<T>& A, const MSparse<T>& B)
115 {
116  octave_idx_type idx = 0;
117  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
118  A.nnz () * B.nnz ());
119 
120  C.cidx (0) = 0;
121 
122  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
123  {
124  octave_quit ();
125  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
126  {
127  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
128  {
129  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
130  const T v = A.data (Ai);
131 
132  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
133  {
134  C.data (idx) = v * B.data (Bi);
135  C.ridx (idx++) = Ci + B.ridx (Bi);
136  }
137  }
138  C.cidx (Aj * B.columns () + Bj + 1) = idx;
139  }
140  }
141 
142  return C;
143 }
144 
145 static PermMatrix
146 kron (const PermMatrix& a, const PermMatrix& b)
147 {
148  octave_idx_type na = a.rows ();
149  octave_idx_type nb = b.rows ();
150  const Array<octave_idx_type>& pa = a.col_perm_vec ();
151  const Array<octave_idx_type>& pb = b.col_perm_vec ();
152  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
153  octave_idx_type rescol = 0;
154  for (octave_idx_type i = 0; i < na; i++)
155  {
156  octave_idx_type a_add = pa(i) * nb;
157  for (octave_idx_type j = 0; j < nb; j++)
158  res_perm.xelem (rescol++) = a_add + pb(j);
159  }
160 
161  return PermMatrix (res_perm, true);
162 }
163 
164 template <typename MTA, typename MTB>
166 do_kron (const octave_value& a, const octave_value& b)
167 {
168  MTA am = octave_value_extract<MTA> (a);
169  MTB bm = octave_value_extract<MTB> (b);
170 
171  return octave_value (kron (am, bm));
172 }
173 
176 {
177  octave_value retval;
178  if (a.is_perm_matrix () && b.is_perm_matrix ())
179  retval = do_kron<PermMatrix, PermMatrix> (a, b);
180  else if (a.issparse () || b.issparse ())
181  {
182  if (a.iscomplex () || b.iscomplex ())
183  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
184  else
185  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
186  }
187  else if (a.is_diag_matrix ())
188  {
189  if (b.is_diag_matrix () && a.rows () == a.columns ()
190  && b.rows () == b.columns ())
191  {
192  // We have two diagonal matrices, the product of those will be
193  // another diagonal matrix. To do that efficiently, extract
194  // the diagonals as vectors and compute the product. That
195  // will be another vector, which we then use to construct a
196  // diagonal matrix object. Note that this will fail if our
197  // digaonal matrix object is modified to allow the nonzero
198  // values to be stored off of the principal diagonal (i.e., if
199  // diag ([1,2], 3) is modified to return a diagonal matrix
200  // object instead of a full matrix object).
201 
202  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
203  retval = tmp.diag ();
204  }
205  else if (a.is_single_type () || b.is_single_type ())
206  {
207  if (a.iscomplex ())
208  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
209  else if (b.iscomplex ())
210  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
211  else
212  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
213  }
214  else
215  {
216  if (a.iscomplex ())
217  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
218  else if (b.iscomplex ())
219  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
220  else
221  retval = do_kron<DiagMatrix, Matrix> (a, b);
222  }
223  }
224  else if (a.is_single_type () || b.is_single_type ())
225  {
226  if (a.iscomplex ())
227  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
228  else if (b.iscomplex ())
229  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
230  else
231  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
232  }
233  else
234  {
235  if (a.iscomplex ())
236  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
237  else if (b.iscomplex ())
238  retval = do_kron<Matrix, ComplexMatrix> (a, b);
239  else
240  retval = do_kron<Matrix, Matrix> (a, b);
241  }
242  return retval;
243 }
244 
245 
246 DEFUN (kron, args, ,
247  doc: /* -*- texinfo -*-
248 @deftypefn {} {@var{C} =} kron (@var{A}, @var{B})
249 @deftypefnx {} {@var{C} =} kron (@var{A1}, @var{A2}, @dots{})
250 Form the Kronecker product of two or more matrices.
251 
252 This is defined block by block as
253 
254 @example
255 c = [ a(i,j)*b ]
256 @end example
257 
258 For example:
259 
260 @example
261 @group
262 kron (1:4, ones (3, 1))
263  @result{} 1 2 3 4
264  1 2 3 4
265  1 2 3 4
266 @end group
267 @end example
268 
269 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
270 @var{An} the Kronecker product is computed as
271 
272 @example
273 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
274 @end example
275 
276 @noindent
277 Since the Kronecker product is associative, this is well-defined.
278 @seealso{tensorprod}
279 @end deftypefn */)
280 {
281  int nargin = args.length ();
282 
283  if (nargin < 2)
284  print_usage ();
285 
286  octave_value retval;
287 
288  octave_value a = args(0);
289  octave_value b = args(1);
290 
291  retval = dispatch_kron (a, b);
292 
293  for (octave_idx_type i = 2; i < nargin; i++)
294  retval = dispatch_kron (retval, args(i));
295 
296  return retval;
297 }
298 
299 /*
300 %!test
301 %! x = ones (2);
302 %! assert (kron (x, x), ones (4));
303 
304 %!shared x, y, z, p1, p2, d1, d2
305 %! x = [1, 2];
306 %! y = [-1, -2];
307 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
308 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
309 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
310 %! d1 = diag ([1 2 3]); ## Diag type matrix
311 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
312 %!assert (kron (1:4, ones (3, 1)), z)
313 %!assert (kron (single (1:4), ones (3, 1)), single (z))
314 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
315 %!assert (kron (complex (1:4), ones (3, 1)), z)
316 %!assert (kron (complex (single (1:4)), ones (3, 1)), single (z))
317 %!assert (kron (x, y, z), kron (kron (x, y), z))
318 %!assert (kron (x, y, z), kron (x, kron (y, z)))
319 %!assert (kron (p1, p1), kron (p2, p2))
320 %!assert (kron (p1, p2), kron (p2, p1))
321 %!assert (kron (d1, d1), kron (d2, d2))
322 %!assert (kron (d1, d2), kron (d2, d1))
323 
324 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
325 
326 ## Test for two diag matrices.
327 ## See the comments above in dispatch_kron for this case.
328 %!test
329 %! expected = zeros (16, 16);
330 %! expected (1, 11) = 3;
331 %! expected (2, 12) = 4;
332 %! expected (5, 15) = 6;
333 %! expected (6, 16) = 8;
334 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
335 */
336 
337 OCTAVE_END_NAMESPACE(octave)
#define C(a, b)
Definition: Faddeeva.cc:259
int ndims() const
Number of dimensions.
Definition: Array.h:671
octave_idx_type rows() const
Definition: Array.h:459
const T * data() const
Read-only pointer to the element data.
Definition: Array.h:663
octave_idx_type cols() const
Definition: Array.h:469
octave_idx_type rows() const
Definition: DiagArray2.h:89
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:124
octave_idx_type cols() const
Definition: DiagArray2.h:90
octave_idx_type diag_length() const
Definition: DiagArray2.h:93
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:56
octave_idx_type rows() const
Definition: PermMatrix.h:62
const Array< octave_idx_type > & col_perm_vec() const
Definition: PermMatrix.h:83
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
bool is_diag_matrix() const
Definition: ov.h:631
octave_idx_type rows() const
Definition: ov.h:545
bool is_perm_matrix() const
Definition: ov.h:634
bool is_single_type() const
Definition: ov.h:698
bool issparse() const
Definition: ov.h:753
bool iscomplex() const
Definition: ov.h:741
octave_idx_type columns() const
Definition: ov.h:547
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1410
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
#define error_unless(cond)
Definition: error.h:530
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:175
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:166
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
F77_RET_T const F77_INT F77_CMPLX * A
void mx_inline_mul(std::size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:110
return octave_value(v1.char_array_value() . concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string()) ? '\'' :'"'))