00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifdef HAVE_CONFIG_H
00026 #include <config.h>
00027 #endif
00028
00029 #include "dMatrix.h"
00030 #include "fMatrix.h"
00031 #include "CMatrix.h"
00032 #include "fCMatrix.h"
00033
00034 #include "dSparse.h"
00035 #include "CSparse.h"
00036
00037 #include "dDiagMatrix.h"
00038 #include "fDiagMatrix.h"
00039 #include "CDiagMatrix.h"
00040 #include "fCDiagMatrix.h"
00041
00042 #include "PermMatrix.h"
00043
00044 #include "mx-inlines.cc"
00045 #include "quit.h"
00046
00047 #include "defun-dld.h"
00048 #include "error.h"
00049 #include "oct-obj.h"
00050
00051 template <class R, class T>
00052 static MArray<T>
00053 kron (const MArray<R>& a, const MArray<T>& b)
00054 {
00055 assert (a.ndims () == 2);
00056 assert (b.ndims () == 2);
00057
00058 octave_idx_type nra = a.rows (), nrb = b.rows ();
00059 octave_idx_type nca = a.cols (), ncb = b.cols ();
00060
00061 MArray<T> c (dim_vector (nra*nrb, nca*ncb));
00062 T *cv = c.fortran_vec ();
00063
00064 for (octave_idx_type ja = 0; ja < nca; ja++)
00065 for (octave_idx_type jb = 0; jb < ncb; jb++)
00066 for (octave_idx_type ia = 0; ia < nra; ia++)
00067 {
00068 octave_quit ();
00069 mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
00070 cv += nrb;
00071 }
00072
00073 return c;
00074 }
00075
00076 template <class R, class T>
00077 static MArray<T>
00078 kron (const MDiagArray2<R>& a, const MArray<T>& b)
00079 {
00080 assert (b.ndims () == 2);
00081
00082 octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length ();
00083 octave_idx_type nca = a.cols (), ncb = b.cols ();
00084
00085 MArray<T> c (dim_vector (nra*nrb, nca*ncb), T());
00086
00087 for (octave_idx_type ja = 0; ja < dla; ja++)
00088 for (octave_idx_type jb = 0; jb < ncb; jb++)
00089 {
00090 octave_quit ();
00091 mx_inline_mul (nrb, &c.xelem(ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb);
00092 }
00093
00094 return c;
00095 }
00096
00097 template <class T>
00098 static MSparse<T>
00099 kron (const MSparse<T>& A, const MSparse<T>& B)
00100 {
00101 octave_idx_type idx = 0;
00102 MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
00103 A.nnz () * B.nnz ());
00104
00105 C.cidx (0) = 0;
00106
00107 for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
00108 for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
00109 {
00110 octave_quit ();
00111 for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
00112 {
00113 octave_idx_type Ci = A.ridx(Ai) * B.rows ();
00114 const T v = A.data (Ai);
00115
00116 for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
00117 {
00118 C.data (idx) = v * B.data (Bi);
00119 C.ridx (idx++) = Ci + B.ridx (Bi);
00120 }
00121 }
00122 C.cidx (Aj * B.columns () + Bj + 1) = idx;
00123 }
00124
00125 return C;
00126 }
00127
00128 static PermMatrix
00129 kron (const PermMatrix& a, const PermMatrix& b)
00130 {
00131 octave_idx_type na = a.rows (), nb = b.rows ();
00132 const octave_idx_type *pa = a.data (), *pb = b.data ();
00133 PermMatrix c(na*nb);
00134 octave_idx_type *pc = c.fortran_vec ();
00135
00136 bool cola = a.is_col_perm (), colb = b.is_col_perm ();
00137 if (cola && colb)
00138 {
00139 for (octave_idx_type i = 0; i < na; i++)
00140 for (octave_idx_type j = 0; j < nb; j++)
00141 pc[pa[i]*nb+pb[j]] = i*nb+j;
00142 }
00143 else if (cola)
00144 {
00145 for (octave_idx_type i = 0; i < na; i++)
00146 for (octave_idx_type j = 0; j < nb; j++)
00147 pc[pa[i]*nb+j] = i*nb+pb[j];
00148 }
00149 else if (colb)
00150 {
00151 for (octave_idx_type i = 0; i < na; i++)
00152 for (octave_idx_type j = 0; j < nb; j++)
00153 pc[i*nb+pb[j]] = pa[i]*nb+j;
00154 }
00155 else
00156 {
00157 for (octave_idx_type i = 0; i < na; i++)
00158 for (octave_idx_type j = 0; j < nb; j++)
00159 pc[i*nb+j] = pa[i]*nb+pb[j];
00160 }
00161
00162 return c;
00163 }
00164
00165 template <class MTA, class MTB>
00166 octave_value
00167 do_kron (const octave_value& a, const octave_value& b)
00168 {
00169 MTA am = octave_value_extract<MTA> (a);
00170 MTB bm = octave_value_extract<MTB> (b);
00171 return octave_value (kron (am, bm));
00172 }
00173
00174 octave_value
00175 dispatch_kron (const octave_value& a, const octave_value& b)
00176 {
00177 octave_value retval;
00178 if (a.is_perm_matrix () && b.is_perm_matrix ())
00179 retval = do_kron<PermMatrix, PermMatrix> (a, b);
00180 else if (a.is_diag_matrix ())
00181 {
00182 if (b.is_diag_matrix () && a.rows () == a.columns ()
00183 && b.rows () == b.columns ())
00184 {
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195 octave_value tmp = dispatch_kron (a.diag (), b.diag ());
00196 retval = tmp.diag ();
00197 }
00198 else if (a.is_single_type () || b.is_single_type ())
00199 {
00200 if (a.is_complex_type ())
00201 retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
00202 else if (b.is_complex_type ())
00203 retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
00204 else
00205 retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
00206 }
00207 else
00208 {
00209 if (a.is_complex_type ())
00210 retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
00211 else if (b.is_complex_type ())
00212 retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
00213 else
00214 retval = do_kron<DiagMatrix, Matrix> (a, b);
00215 }
00216 }
00217 else if (a.is_sparse_type () || b.is_sparse_type ())
00218 {
00219 if (a.is_complex_type () || b.is_complex_type ())
00220 retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
00221 else
00222 retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
00223 }
00224 else if (a.is_single_type () || b.is_single_type ())
00225 {
00226 if (a.is_complex_type ())
00227 retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
00228 else if (b.is_complex_type ())
00229 retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
00230 else
00231 retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
00232 }
00233 else
00234 {
00235 if (a.is_complex_type ())
00236 retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
00237 else if (b.is_complex_type ())
00238 retval = do_kron<Matrix, ComplexMatrix> (a, b);
00239 else
00240 retval = do_kron<Matrix, Matrix> (a, b);
00241 }
00242 return retval;
00243 }
00244
00245
00246 DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\
00247 @deftypefn {Loadable Function} {} kron (@var{A}, @var{B})\n\
00248 @deftypefnx {Loadable Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
00249 Form the Kronecker product of two or more matrices, defined block by \n\
00250 block as\n\
00251 \n\
00252 @example\n\
00253 x = [a(i, j) b]\n\
00254 @end example\n\
00255 \n\
00256 For example:\n\
00257 \n\
00258 @example\n\
00259 @group\n\
00260 kron (1:4, ones (3, 1))\n\
00261 @result{} 1 2 3 4\n\
00262 1 2 3 4\n\
00263 1 2 3 4\n\
00264 @end group\n\
00265 @end example\n\
00266 \n\
00267 If there are more than two input arguments @var{A1}, @var{A2}, @dots{}, \n\
00268 @var{An} the Kronecker product is computed as\n\
00269 \n\
00270 @example\n\
00271 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
00272 @end example\n\
00273 \n\
00274 @noindent\n\
00275 Since the Kronecker product is associative, this is well-defined.\n\
00276 @end deftypefn")
00277 {
00278 octave_value retval;
00279
00280 int nargin = args.length ();
00281
00282 if (nargin >= 2)
00283 {
00284 octave_value a = args(0), b = args(1);
00285 retval = dispatch_kron (a, b);
00286 for (octave_idx_type i = 2; i < nargin; i++)
00287 retval = dispatch_kron (retval, args(i));
00288 }
00289 else
00290 print_usage ();
00291
00292 return retval;
00293 }
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322