GNU Octave 7.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
kron.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 2002-2022 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if defined (HAVE_CONFIG_H)
27# include "config.h"
28#endif
29
30#include "dMatrix.h"
31#include "fMatrix.h"
32#include "CMatrix.h"
33#include "fCMatrix.h"
34
35#include "dSparse.h"
36#include "CSparse.h"
37
38#include "dDiagMatrix.h"
39#include "fDiagMatrix.h"
40#include "CDiagMatrix.h"
41#include "fCDiagMatrix.h"
42
43#include "PermMatrix.h"
44
45#include "mx-inlines.cc"
46#include "quit.h"
47
48#include "defun.h"
49#include "error.h"
50#include "ovl.h"
51
52OCTAVE_NAMESPACE_BEGIN
53
54template <typename R, typename T>
55static MArray<T>
56kron (const MArray<R>& a, const MArray<T>& b)
57{
58 assert (a.ndims () == 2);
59 assert (b.ndims () == 2);
60
61 octave_idx_type nra = a.rows ();
62 octave_idx_type nrb = b.rows ();
63 octave_idx_type nca = a.cols ();
64 octave_idx_type ncb = b.cols ();
65
66 MArray<T> c (dim_vector (nra*nrb, nca*ncb));
67 T *cv = c.fortran_vec ();
68
69 for (octave_idx_type ja = 0; ja < nca; ja++)
70 {
71 octave_quit ();
72 for (octave_idx_type jb = 0; jb < ncb; jb++)
73 {
74 for (octave_idx_type ia = 0; ia < nra; ia++)
75 {
76 mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
77 cv += nrb;
78 }
79 }
80 }
81
82 return c;
83}
84
85template <typename R, typename T>
86static MArray<T>
87kron (const MDiagArray2<R>& a, const MArray<T>& b)
88{
89 assert (b.ndims () == 2);
90
91 octave_idx_type nra = a.rows ();
92 octave_idx_type nrb = b.rows ();
94 octave_idx_type nca = a.cols ();
95 octave_idx_type ncb = b.cols ();
96
97 MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
98
99 for (octave_idx_type ja = 0; ja < dla; ja++)
100 {
101 octave_quit ();
102 for (octave_idx_type jb = 0; jb < ncb; jb++)
103 {
104 mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
105 b.data () + nrb*jb);
106 }
107 }
108
109 return c;
110}
111
112template <typename T>
113static MSparse<T>
114kron (const MSparse<T>& A, const MSparse<T>& B)
115{
116 octave_idx_type idx = 0;
117 MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
118 A.nnz () * B.nnz ());
119
120 C.cidx (0) = 0;
121
122 for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
123 {
124 octave_quit ();
125 for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
126 {
127 for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
128 {
129 octave_idx_type Ci = A.ridx (Ai) * B.rows ();
130 const T v = A.data (Ai);
131
132 for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
133 {
134 C.data (idx) = v * B.data (Bi);
135 C.ridx (idx++) = Ci + B.ridx (Bi);
136 }
137 }
138 C.cidx (Aj * B.columns () + Bj + 1) = idx;
139 }
140 }
141
142 return C;
143}
144
145static PermMatrix
146kron (const PermMatrix& a, const PermMatrix& b)
147{
148 octave_idx_type na = a.rows ();
149 octave_idx_type nb = b.rows ();
150 const Array<octave_idx_type>& pa = a.col_perm_vec ();
151 const Array<octave_idx_type>& pb = b.col_perm_vec ();
152 Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
153 octave_idx_type rescol = 0;
154 for (octave_idx_type i = 0; i < na; i++)
155 {
156 octave_idx_type a_add = pa(i) * nb;
157 for (octave_idx_type j = 0; j < nb; j++)
158 res_perm.xelem (rescol++) = a_add + pb(j);
159 }
160
161 return PermMatrix (res_perm, true);
162}
163
164template <typename MTA, typename MTB>
167{
168 MTA am = octave_value_extract<MTA> (a);
169 MTB bm = octave_value_extract<MTB> (b);
170
171 return octave_value (kron (am, bm));
172}
173
176{
177 octave_value retval;
178 if (a.is_perm_matrix () && b.is_perm_matrix ())
179 retval = do_kron<PermMatrix, PermMatrix> (a, b);
180 else if (a.issparse () || b.issparse ())
181 {
182 if (a.iscomplex () || b.iscomplex ())
183 retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
184 else
185 retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
186 }
187 else if (a.is_diag_matrix ())
188 {
189 if (b.is_diag_matrix () && a.rows () == a.columns ()
190 && b.rows () == b.columns ())
191 {
192 // We have two diagonal matrices, the product of those will be
193 // another diagonal matrix. To do that efficiently, extract
194 // the diagonals as vectors and compute the product. That
195 // will be another vector, which we then use to construct a
196 // diagonal matrix object. Note that this will fail if our
197 // digaonal matrix object is modified to allow the nonzero
198 // values to be stored off of the principal diagonal (i.e., if
199 // diag ([1,2], 3) is modified to return a diagonal matrix
200 // object instead of a full matrix object).
201
202 octave_value tmp = dispatch_kron (a.diag (), b.diag ());
203 retval = tmp.diag ();
204 }
205 else if (a.is_single_type () || b.is_single_type ())
206 {
207 if (a.iscomplex ())
208 retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
209 else if (b.iscomplex ())
210 retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
211 else
212 retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
213 }
214 else
215 {
216 if (a.iscomplex ())
217 retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
218 else if (b.iscomplex ())
219 retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
220 else
221 retval = do_kron<DiagMatrix, Matrix> (a, b);
222 }
223 }
224 else if (a.is_single_type () || b.is_single_type ())
225 {
226 if (a.iscomplex ())
227 retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
228 else if (b.iscomplex ())
229 retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
230 else
231 retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
232 }
233 else
234 {
235 if (a.iscomplex ())
236 retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
237 else if (b.iscomplex ())
238 retval = do_kron<Matrix, ComplexMatrix> (a, b);
239 else
240 retval = do_kron<Matrix, Matrix> (a, b);
241 }
242 return retval;
243}
244
245
246DEFUN (kron, args, ,
247 doc: /* -*- texinfo -*-
248@deftypefn {} {} kron (@var{A}, @var{B})
249@deftypefnx {} {} kron (@var{A1}, @var{A2}, @dots{})
250Form the Kronecker product of two or more matrices.
251
252This is defined block by block as
253
254@example
255x = [ a(i,j)*b ]
256@end example
257
258For example:
259
260@example
261@group
262kron (1:4, ones (3, 1))
263 @result{} 1 2 3 4
264 1 2 3 4
265 1 2 3 4
266@end group
267@end example
268
269If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
270@var{An} the Kronecker product is computed as
271
272@example
273kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
274@end example
275
276@noindent
277Since the Kronecker product is associative, this is well-defined.
278@end deftypefn */)
279{
280 int nargin = args.length ();
281
282 if (nargin < 2)
283 print_usage ();
284
285 octave_value retval;
286
287 octave_value a = args(0);
288 octave_value b = args(1);
289
290 retval = dispatch_kron (a, b);
291
292 for (octave_idx_type i = 2; i < nargin; i++)
293 retval = dispatch_kron (retval, args(i));
294
295 return retval;
296}
297
298/*
299%!test
300%! x = ones (2);
301%! assert (kron (x, x), ones (4));
302
303%!shared x, y, z, p1, p2, d1, d2
304%! x = [1, 2];
305%! y = [-1, -2];
306%! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
307%! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
308%! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
309%! d1 = diag ([1 2 3]); ## Diag type matrix
310%! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
311%!assert (kron (1:4, ones (3, 1)), z)
312%!assert (kron (single (1:4), ones (3, 1)), single (z))
313%!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
314%!assert (kron (complex (1:4), ones (3, 1)), z)
315%!assert (kron (complex (single (1:4)), ones (3, 1)), single (z))
316%!assert (kron (x, y, z), kron (kron (x, y), z))
317%!assert (kron (x, y, z), kron (x, kron (y, z)))
318%!assert (kron (p1, p1), kron (p2, p2))
319%!assert (kron (p1, p2), kron (p2, p1))
320%!assert (kron (d1, d1), kron (d2, d2))
321%!assert (kron (d1, d2), kron (d2, d1))
322
323%!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
324
325## Test for two diag matrices.
326## See the comments above in dispatch_kron for this case.
327%!test
328%! expected = zeros (16, 16);
329%! expected (1, 11) = 3;
330%! expected (2, 12) = 4;
331%! expected (5, 15) = 6;
332%! expected (6, 16) = 8;
333%! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
334*/
335
336OCTAVE_NAMESPACE_END
#define C(a, b)
Definition: Faddeeva.cc:259
T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:504
octave_idx_type cols(void) const
Definition: Array.h:457
octave_idx_type rows(void) const
Definition: Array.h:449
const T * data(void) const
Size of the specified dimension.
Definition: Array.h:616
OCTARRAY_API T * fortran_vec(void)
Size of the specified dimension.
Definition: Array.cc:1744
int ndims(void) const
Size of the specified dimension.
Definition: Array.h:627
octave_idx_type diag_length(void) const
Definition: DiagArray2.h:93
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:124
octave_idx_type cols(void) const
Definition: DiagArray2.h:90
octave_idx_type rows(void) const
Definition: DiagArray2.h:89
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:56
octave_idx_type rows(void) const
Definition: PermMatrix.h:62
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:83
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
bool issparse(void) const
Definition: ov.h:798
octave_idx_type rows(void) const
Definition: ov.h:590
bool is_diag_matrix(void) const
Definition: ov.h:676
octave_idx_type columns(void) const
Definition: ov.h:592
bool is_single_type(void) const
Definition: ov.h:743
bool is_perm_matrix(void) const
Definition: ov.h:679
bool iscomplex(void) const
Definition: ov.h:786
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1531
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:175
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:166
static OCTAVE_NAMESPACE_BEGIN MArray< T > kron(const MArray< R > &a, const MArray< T > &b)
Definition: kron.cc:56
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
F77_RET_T const F77_INT F77_CMPLX * A
class OCTAVE_API PermMatrix
Definition: mx-fwd.h:64
void mx_inline_mul(std::size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:109
return octave_value(v1.char_array_value() . concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string()) ? '\'' :'"'))