GNU Octave  6.2.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
kron.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2002-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "dMatrix.h"
31 #include "fMatrix.h"
32 #include "CMatrix.h"
33 #include "fCMatrix.h"
34 
35 #include "dSparse.h"
36 #include "CSparse.h"
37 
38 #include "dDiagMatrix.h"
39 #include "fDiagMatrix.h"
40 #include "CDiagMatrix.h"
41 #include "fCDiagMatrix.h"
42 
43 #include "PermMatrix.h"
44 
45 #include "mx-inlines.cc"
46 #include "quit.h"
47 
48 #include "defun.h"
49 #include "error.h"
50 #include "ovl.h"
51 
52 template <typename R, typename T>
53 static MArray<T>
54 kron (const MArray<R>& a, const MArray<T>& b)
55 {
56  assert (a.ndims () == 2);
57  assert (b.ndims () == 2);
58 
59  octave_idx_type nra = a.rows ();
60  octave_idx_type nrb = b.rows ();
61  octave_idx_type nca = a.cols ();
62  octave_idx_type ncb = b.cols ();
63 
64  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
65  T *cv = c.fortran_vec ();
66 
67  for (octave_idx_type ja = 0; ja < nca; ja++)
68  {
69  octave_quit ();
70  for (octave_idx_type jb = 0; jb < ncb; jb++)
71  {
72  for (octave_idx_type ia = 0; ia < nra; ia++)
73  {
74  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
75  cv += nrb;
76  }
77  }
78  }
79 
80  return c;
81 }
82 
83 template <typename R, typename T>
84 static MArray<T>
85 kron (const MDiagArray2<R>& a, const MArray<T>& b)
86 {
87  assert (b.ndims () == 2);
88 
89  octave_idx_type nra = a.rows ();
90  octave_idx_type nrb = b.rows ();
91  octave_idx_type dla = a.diag_length ();
92  octave_idx_type nca = a.cols ();
93  octave_idx_type ncb = b.cols ();
94 
95  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
96 
97  for (octave_idx_type ja = 0; ja < dla; ja++)
98  {
99  octave_quit ();
100  for (octave_idx_type jb = 0; jb < ncb; jb++)
101  {
102  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
103  b.data () + nrb*jb);
104  }
105  }
106 
107  return c;
108 }
109 
110 template <typename T>
111 static MSparse<T>
112 kron (const MSparse<T>& A, const MSparse<T>& B)
113 {
114  octave_idx_type idx = 0;
115  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
116  A.nnz () * B.nnz ());
117 
118  C.cidx (0) = 0;
119 
120  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
121  {
122  octave_quit ();
123  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
124  {
125  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
126  {
127  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
128  const T v = A.data (Ai);
129 
130  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
131  {
132  C.data (idx) = v * B.data (Bi);
133  C.ridx (idx++) = Ci + B.ridx (Bi);
134  }
135  }
136  C.cidx (Aj * B.columns () + Bj + 1) = idx;
137  }
138  }
139 
140  return C;
141 }
142 
143 static PermMatrix
144 kron (const PermMatrix& a, const PermMatrix& b)
145 {
146  octave_idx_type na = a.rows ();
147  octave_idx_type nb = b.rows ();
148  const Array<octave_idx_type>& pa = a.col_perm_vec ();
149  const Array<octave_idx_type>& pb = b.col_perm_vec ();
150  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
151  octave_idx_type rescol = 0;
152  for (octave_idx_type i = 0; i < na; i++)
153  {
154  octave_idx_type a_add = pa(i) * nb;
155  for (octave_idx_type j = 0; j < nb; j++)
156  res_perm.xelem (rescol++) = a_add + pb(j);
157  }
158 
159  return PermMatrix (res_perm, true);
160 }
161 
162 template <typename MTA, typename MTB>
164 do_kron (const octave_value& a, const octave_value& b)
165 {
166  MTA am = octave_value_extract<MTA> (a);
167  MTB bm = octave_value_extract<MTB> (b);
168 
169  return octave_value (kron (am, bm));
170 }
171 
174 {
176  if (a.is_perm_matrix () && b.is_perm_matrix ())
177  retval = do_kron<PermMatrix, PermMatrix> (a, b);
178  else if (a.issparse () || b.issparse ())
179  {
180  if (a.iscomplex () || b.iscomplex ())
181  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
182  else
183  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
184  }
185  else if (a.is_diag_matrix ())
186  {
187  if (b.is_diag_matrix () && a.rows () == a.columns ()
188  && b.rows () == b.columns ())
189  {
190  // We have two diagonal matrices, the product of those will be
191  // another diagonal matrix. To do that efficiently, extract
192  // the diagonals as vectors and compute the product. That
193  // will be another vector, which we then use to construct a
194  // diagonal matrix object. Note that this will fail if our
195  // digaonal matrix object is modified to allow the nonzero
196  // values to be stored off of the principal diagonal (i.e., if
197  // diag ([1,2], 3) is modified to return a diagonal matrix
198  // object instead of a full matrix object).
199 
200  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
201  retval = tmp.diag ();
202  }
203  else if (a.is_single_type () || b.is_single_type ())
204  {
205  if (a.iscomplex ())
206  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
207  else if (b.iscomplex ())
208  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
209  else
210  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
211  }
212  else
213  {
214  if (a.iscomplex ())
215  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
216  else if (b.iscomplex ())
217  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
218  else
219  retval = do_kron<DiagMatrix, Matrix> (a, b);
220  }
221  }
222  else if (a.is_single_type () || b.is_single_type ())
223  {
224  if (a.iscomplex ())
225  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
226  else if (b.iscomplex ())
227  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
228  else
229  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
230  }
231  else
232  {
233  if (a.iscomplex ())
234  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
235  else if (b.iscomplex ())
236  retval = do_kron<Matrix, ComplexMatrix> (a, b);
237  else
238  retval = do_kron<Matrix, Matrix> (a, b);
239  }
240  return retval;
241 }
242 
243 
244 DEFUN (kron, args, ,
245  doc: /* -*- texinfo -*-
246 @deftypefn {} {} kron (@var{A}, @var{B})
247 @deftypefnx {} {} kron (@var{A1}, @var{A2}, @dots{})
248 Form the Kronecker product of two or more matrices.
249 
250 This is defined block by block as
251 
252 @example
253 x = [ a(i,j)*b ]
254 @end example
255 
256 For example:
257 
258 @example
259 @group
260 kron (1:4, ones (3, 1))
261  @result{} 1 2 3 4
262  1 2 3 4
263  1 2 3 4
264 @end group
265 @end example
266 
267 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
268 @var{An} the Kronecker product is computed as
269 
270 @example
271 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
272 @end example
273 
274 @noindent
275 Since the Kronecker product is associative, this is well-defined.
276 @end deftypefn */)
277 {
278  int nargin = args.length ();
279 
280  if (nargin < 2)
281  print_usage ();
282 
284 
285  octave_value a = args(0);
286  octave_value b = args(1);
287 
288  retval = dispatch_kron (a, b);
289 
290  for (octave_idx_type i = 2; i < nargin; i++)
291  retval = dispatch_kron (retval, args(i));
292 
293  return retval;
294 }
295 
296 /*
297 %!test
298 %! x = ones (2);
299 %! assert (kron (x, x), ones (4));
300 
301 %!shared x, y, z, p1, p2, d1, d2
302 %! x = [1, 2];
303 %! y = [-1, -2];
304 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
305 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
306 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
307 %! d1 = diag ([1 2 3]); ## Diag type matrix
308 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
309 %!assert (kron (1:4, ones (3, 1)), z)
310 %!assert (kron (single (1:4), ones (3, 1)), single (z))
311 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
312 %!assert (kron (complex (1:4), ones (3, 1)), z)
313 %!assert (kron (complex (single(1:4)), ones (3, 1)), single(z))
314 %!assert (kron (x, y, z), kron (kron (x, y), z))
315 %!assert (kron (x, y, z), kron (x, kron (y, z)))
316 %!assert (kron (p1, p1), kron (p2, p2))
317 %!assert (kron (p1, p2), kron (p2, p1))
318 %!assert (kron (d1, d1), kron (d2, d2))
319 %!assert (kron (d1, d2), kron (d2, d1))
320 
321 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
322 
323 ## Test for two diag matrices.
324 ## See the comments above in dispatch_kron for this case.
325 %!test
326 %! expected = zeros (16, 16);
327 %! expected (1, 11) = 3;
328 %! expected (2, 12) = 4;
329 %! expected (5, 15) = 6;
330 %! expected (6, 16) = 8;
331 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
332 */
#define C(a, b)
Definition: Faddeeva.cc:246
T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:469
const T * data(void) const
Size of the specified dimension.
Definition: Array.h:581
octave_idx_type cols(void) const
Definition: Array.h:423
octave_idx_type rows(void) const
Definition: Array.h:415
int ndims(void) const
Size of the specified dimension.
Definition: Array.h:589
const T * fortran_vec(void) const
Size of the specified dimension.
Definition: Array.h:583
octave_idx_type diag_length(void) const
Definition: DiagArray2.h:92
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:123
octave_idx_type cols(void) const
Definition: DiagArray2.h:89
octave_idx_type rows(void) const
Definition: DiagArray2.h:88
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:55
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:81
octave_idx_type rows(void) const
Definition: PermMatrix.h:60
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:95
bool issparse(void) const
Definition: ov.h:706
octave_idx_type rows(void) const
Definition: ov.h:504
bool is_diag_matrix(void) const
Definition: ov.h:587
octave_idx_type columns(void) const
Definition: ov.h:506
bool is_single_type(void) const
Definition: ov.h:651
bool is_perm_matrix(void) const
Definition: ov.h:590
bool iscomplex(void) const
Definition: ov.h:694
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1333
OCTINTERP_API void print_usage(void)
Definition: defun.cc:53
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
static MArray< T > kron(const MArray< R > &a, const MArray< T > &b)
Definition: kron.cc:54
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:173
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:164
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
F77_RET_T const F77_INT F77_CMPLX * A
void mx_inline_mul(size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:109
return octave_value(v1.char_array_value() . concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string()) ? '\'' :'"'))
octave_value::octave_value(const Array< char > &chm, char type) return retval
Definition: ov.cc:811