GNU Octave  8.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
kron.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2002-2023 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "dMatrix.h"
31 #include "fMatrix.h"
32 #include "CMatrix.h"
33 #include "fCMatrix.h"
34 
35 #include "dSparse.h"
36 #include "CSparse.h"
37 
38 #include "dDiagMatrix.h"
39 #include "fDiagMatrix.h"
40 #include "CDiagMatrix.h"
41 #include "fCDiagMatrix.h"
42 
43 #include "PermMatrix.h"
44 
45 #include "mx-inlines.cc"
46 #include "quit.h"
47 
48 #include "defun.h"
49 #include "error.h"
50 #include "ovl.h"
51 
53 
54 template <typename R, typename T>
55 static MArray<T>
56 kron (const MArray<R>& a, const MArray<T>& b)
57 {
58  error_unless (a.ndims () == 2);
59  error_unless (b.ndims () == 2);
60 
61  octave_idx_type nra = a.rows ();
62  octave_idx_type nrb = b.rows ();
63  octave_idx_type nca = a.cols ();
64  octave_idx_type ncb = b.cols ();
65 
66  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
67  T *cv = c.fortran_vec ();
68 
69  for (octave_idx_type ja = 0; ja < nca; ja++)
70  {
71  octave_quit ();
72  for (octave_idx_type jb = 0; jb < ncb; jb++)
73  {
74  for (octave_idx_type ia = 0; ia < nra; ia++)
75  {
76  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
77  cv += nrb;
78  }
79  }
80  }
81 
82  return c;
83 }
84 
85 template <typename R, typename T>
86 static MArray<T>
87 kron (const MDiagArray2<R>& a, const MArray<T>& b)
88 {
89  error_unless (b.ndims () == 2);
90 
91  octave_idx_type nra = a.rows ();
92  octave_idx_type nrb = b.rows ();
93  octave_idx_type dla = a.diag_length ();
94  octave_idx_type nca = a.cols ();
95  octave_idx_type ncb = b.cols ();
96 
97  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
98 
99  for (octave_idx_type ja = 0; ja < dla; ja++)
100  {
101  octave_quit ();
102  for (octave_idx_type jb = 0; jb < ncb; jb++)
103  {
104  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
105  b.data () + nrb*jb);
106  }
107  }
108 
109  return c;
110 }
111 
112 template <typename T>
113 static MSparse<T>
114 kron (const MSparse<T>& A, const MSparse<T>& B)
115 {
116  octave_idx_type idx = 0;
117  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
118  A.nnz () * B.nnz ());
119 
120  C.cidx (0) = 0;
121 
122  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
123  {
124  octave_quit ();
125  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
126  {
127  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
128  {
129  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
130  const T v = A.data (Ai);
131 
132  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
133  {
134  C.data (idx) = v * B.data (Bi);
135  C.ridx (idx++) = Ci + B.ridx (Bi);
136  }
137  }
138  C.cidx (Aj * B.columns () + Bj + 1) = idx;
139  }
140  }
141 
142  return C;
143 }
144 
145 static PermMatrix
146 kron (const PermMatrix& a, const PermMatrix& b)
147 {
148  octave_idx_type na = a.rows ();
149  octave_idx_type nb = b.rows ();
150  const Array<octave_idx_type>& pa = a.col_perm_vec ();
151  const Array<octave_idx_type>& pb = b.col_perm_vec ();
152  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
153  octave_idx_type rescol = 0;
154  for (octave_idx_type i = 0; i < na; i++)
155  {
156  octave_idx_type a_add = pa(i) * nb;
157  for (octave_idx_type j = 0; j < nb; j++)
158  res_perm.xelem (rescol++) = a_add + pb(j);
159  }
160 
161  return PermMatrix (res_perm, true);
162 }
163 
164 template <typename MTA, typename MTB>
166 do_kron (const octave_value& a, const octave_value& b)
167 {
168  MTA am = octave_value_extract<MTA> (a);
169  MTB bm = octave_value_extract<MTB> (b);
170 
171  return octave_value (kron (am, bm));
172 }
173 
176 {
177  octave_value retval;
178  if (a.is_perm_matrix () && b.is_perm_matrix ())
179  retval = do_kron<PermMatrix, PermMatrix> (a, b);
180  else if (a.issparse () || b.issparse ())
181  {
182  if (a.iscomplex () || b.iscomplex ())
183  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
184  else
185  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
186  }
187  else if (a.is_diag_matrix ())
188  {
189  if (b.is_diag_matrix () && a.rows () == a.columns ()
190  && b.rows () == b.columns ())
191  {
192  // We have two diagonal matrices, the product of those will be
193  // another diagonal matrix. To do that efficiently, extract
194  // the diagonals as vectors and compute the product. That
195  // will be another vector, which we then use to construct a
196  // diagonal matrix object. Note that this will fail if our
197  // digaonal matrix object is modified to allow the nonzero
198  // values to be stored off of the principal diagonal (i.e., if
199  // diag ([1,2], 3) is modified to return a diagonal matrix
200  // object instead of a full matrix object).
201 
202  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
203  retval = tmp.diag ();
204  }
205  else if (a.is_single_type () || b.is_single_type ())
206  {
207  if (a.iscomplex ())
208  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
209  else if (b.iscomplex ())
210  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
211  else
212  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
213  }
214  else
215  {
216  if (a.iscomplex ())
217  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
218  else if (b.iscomplex ())
219  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
220  else
221  retval = do_kron<DiagMatrix, Matrix> (a, b);
222  }
223  }
224  else if (a.is_single_type () || b.is_single_type ())
225  {
226  if (a.iscomplex ())
227  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
228  else if (b.iscomplex ())
229  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
230  else
231  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
232  }
233  else
234  {
235  if (a.iscomplex ())
236  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
237  else if (b.iscomplex ())
238  retval = do_kron<Matrix, ComplexMatrix> (a, b);
239  else
240  retval = do_kron<Matrix, Matrix> (a, b);
241  }
242  return retval;
243 }
244 
245 
246 DEFUN (kron, args, ,
247  doc: /* -*- texinfo -*-
248 @deftypefn {} {@var{C} =} kron (@var{A}, @var{B})
249 @deftypefnx {} {@var{C} =} kron (@var{A1}, @var{A2}, @dots{})
250 Form the Kronecker product of two or more matrices.
251 
252 This is defined block by block as
253 
254 @example
255 c = [ a(i,j)*b ]
256 @end example
257 
258 For example:
259 
260 @example
261 @group
262 kron (1:4, ones (3, 1))
263  @result{} 1 2 3 4
264  1 2 3 4
265  1 2 3 4
266 @end group
267 @end example
268 
269 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
270 @var{An} the Kronecker product is computed as
271 
272 @example
273 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
274 @end example
275 
276 @noindent
277 Since the Kronecker product is associative, this is well-defined.
278 @end deftypefn */)
279 {
280  int nargin = args.length ();
281 
282  if (nargin < 2)
283  print_usage ();
284 
285  octave_value retval;
286 
287  octave_value a = args(0);
288  octave_value b = args(1);
289 
290  retval = dispatch_kron (a, b);
291 
292  for (octave_idx_type i = 2; i < nargin; i++)
293  retval = dispatch_kron (retval, args(i));
294 
295  return retval;
296 }
297 
298 /*
299 %!test
300 %! x = ones (2);
301 %! assert (kron (x, x), ones (4));
302 
303 %!shared x, y, z, p1, p2, d1, d2
304 %! x = [1, 2];
305 %! y = [-1, -2];
306 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
307 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
308 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
309 %! d1 = diag ([1 2 3]); ## Diag type matrix
310 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
311 %!assert (kron (1:4, ones (3, 1)), z)
312 %!assert (kron (single (1:4), ones (3, 1)), single (z))
313 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
314 %!assert (kron (complex (1:4), ones (3, 1)), z)
315 %!assert (kron (complex (single (1:4)), ones (3, 1)), single (z))
316 %!assert (kron (x, y, z), kron (kron (x, y), z))
317 %!assert (kron (x, y, z), kron (x, kron (y, z)))
318 %!assert (kron (p1, p1), kron (p2, p2))
319 %!assert (kron (p1, p2), kron (p2, p1))
320 %!assert (kron (d1, d1), kron (d2, d2))
321 %!assert (kron (d1, d2), kron (d2, d1))
322 
323 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
324 
325 ## Test for two diag matrices.
326 ## See the comments above in dispatch_kron for this case.
327 %!test
328 %! expected = zeros (16, 16);
329 %! expected (1, 11) = 3;
330 %! expected (2, 12) = 4;
331 %! expected (5, 15) = 6;
332 %! expected (6, 16) = 8;
333 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
334 */
335 
OCTAVE_END_NAMESPACE(octave)
#define C(a, b)
Definition: Faddeeva.cc:259
OCTARRAY_OVERRIDABLE_FUNC_API const T * data(void) const
Size of the specified dimension.
Definition: Array.h:663
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type rows(void) const
Definition: Array.h:459
OCTARRAY_API T * fortran_vec(void)
Size of the specified dimension.
Definition: Array-base.cc:1766
OCTARRAY_OVERRIDABLE_FUNC_API int ndims(void) const
Size of the specified dimension.
Definition: Array.h:677
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type cols(void) const
Definition: Array.h:469
OCTARRAY_OVERRIDABLE_FUNC_API T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:524
octave_idx_type diag_length(void) const
Definition: DiagArray2.h:93
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:124
octave_idx_type cols(void) const
Definition: DiagArray2.h:90
octave_idx_type rows(void) const
Definition: DiagArray2.h:89
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:56
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:83
octave_idx_type rows(void) const
Definition: PermMatrix.h:62
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
bool issparse(void) const
Definition: ov.h:798
octave_idx_type rows(void) const
Definition: ov.h:590
bool is_diag_matrix(void) const
Definition: ov.h:676
octave_idx_type columns(void) const
Definition: ov.h:592
bool is_single_type(void) const
Definition: ov.h:743
bool is_perm_matrix(void) const
Definition: ov.h:679
bool iscomplex(void) const
Definition: ov.h:786
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1534
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void error_unless(bool cond)
Definition: error.h:549
static MArray< T > kron(const MArray< R > &a, const MArray< T > &b)
Definition: kron.cc:56
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:175
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:166
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
F77_RET_T const F77_INT F77_CMPLX * A
class OCTAVE_API PermMatrix
Definition: mx-fwd.h:64
void mx_inline_mul(std::size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:109
return octave_value(v1.char_array_value() . concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string()) ? '\'' :'"'))