00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifdef HAVE_CONFIG_H
00024 #include <config.h>
00025 #endif
00026
00027 #include "f77-fcn.h"
00028 #include "mx-base.h"
00029 #include "error.h"
00030 #include "defun-dld.h"
00031 #include "parse.h"
00032
00033 extern "C"
00034 {
00035 F77_RET_T
00036 F77_FUNC (ddot3, DDOT3) (const octave_idx_type&, const octave_idx_type&,
00037 const octave_idx_type&, const double*,
00038 const double*, double*);
00039
00040 F77_RET_T
00041 F77_FUNC (sdot3, SDOT3) (const octave_idx_type&, const octave_idx_type&,
00042 const octave_idx_type&, const float*,
00043 const float*, float*);
00044
00045 F77_RET_T
00046 F77_FUNC (zdotc3, ZDOTC3) (const octave_idx_type&, const octave_idx_type&,
00047 const octave_idx_type&, const Complex*,
00048 const Complex*, Complex*);
00049
00050 F77_RET_T
00051 F77_FUNC (cdotc3, CDOTC3) (const octave_idx_type&, const octave_idx_type&,
00052 const octave_idx_type&, const FloatComplex*,
00053 const FloatComplex*, FloatComplex*);
00054
00055 F77_RET_T
00056 F77_FUNC (dmatm3, DMATM3) (const octave_idx_type&, const octave_idx_type&,
00057 const octave_idx_type&, const octave_idx_type&,
00058 const double*, const double*, double*);
00059
00060 F77_RET_T
00061 F77_FUNC (smatm3, SMATM3) (const octave_idx_type&, const octave_idx_type&,
00062 const octave_idx_type&, const octave_idx_type&,
00063 const float*, const float*, float*);
00064
00065 F77_RET_T
00066 F77_FUNC (zmatm3, ZMATM3) (const octave_idx_type&, const octave_idx_type&,
00067 const octave_idx_type&, const octave_idx_type&,
00068 const Complex*, const Complex*, Complex*);
00069
00070 F77_RET_T
00071 F77_FUNC (cmatm3, CMATM3) (const octave_idx_type&, const octave_idx_type&,
00072 const octave_idx_type&, const octave_idx_type&,
00073 const FloatComplex*, const FloatComplex*,
00074 FloatComplex*);
00075 }
00076
00077 static void
00078 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
00079 dim_vector& z, octave_idx_type& m, octave_idx_type& n,
00080 octave_idx_type& k)
00081 {
00082 int nd = x.length ();
00083 assert (nd == y.length ());
00084 z = dim_vector::alloc (nd);
00085 m = 1, n = 1, k = 1;
00086 for (int i = 0; i < nd; i++)
00087 {
00088 if (i < dim)
00089 {
00090 z(i) = x(i);
00091 m *= x(i);
00092 }
00093 else if (i > dim)
00094 {
00095 z(i) = x(i);
00096 n *= x(i);
00097 }
00098 else
00099 {
00100 k = x(i);
00101 z(i) = 1;
00102 }
00103 }
00104 }
00105
00106 DEFUN_DLD (dot, args, ,
00107 "-*- texinfo -*-\n\
00108 @deftypefn {Loadable Function} {} dot (@var{x}, @var{y}, @var{dim})\n\
00109 Compute the dot product of two vectors. If @var{x} and @var{y}\n\
00110 are matrices, calculate the dot products along the first\n\
00111 non-singleton dimension. If the optional argument @var{dim} is\n\
00112 given, calculate the dot products along this dimension.\n\
00113 \n\
00114 This is equivalent to\n\
00115 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},\n\
00116 but avoids forming a temporary array and is faster. When @var{X} and\n\
00117 @var{Y} are column vectors, the result is equivalent to\n\
00118 @code{@var{X}' * @var{Y}}.\n\
00119 @seealso{cross, divergence}\n\
00120 @end deftypefn")
00121 {
00122 octave_value retval;
00123 int nargin = args.length ();
00124
00125 if (nargin < 2 || nargin > 3)
00126 {
00127 print_usage ();
00128 return retval;
00129 }
00130
00131 octave_value argx = args(0), argy = args(1);
00132
00133 if (argx.is_numeric_type () && argy.is_numeric_type ())
00134 {
00135 dim_vector dimx = argx.dims (), dimy = argy.dims ();
00136 bool match = dimx == dimy;
00137 if (! match && nargin == 2
00138 && dimx.is_vector () && dimy.is_vector ())
00139 {
00140
00141 dimx = dimx.redim (1);
00142 argx = argx.reshape (dimx);
00143 dimy = dimy.redim (1);
00144 argy = argy.reshape (dimy);
00145 match = ! error_state;
00146 }
00147
00148 if (match)
00149 {
00150 int dim;
00151 if (nargin == 2)
00152 dim = dimx.first_non_singleton ();
00153 else
00154 dim = args(2).int_value (true) - 1;
00155
00156 if (error_state)
00157 ;
00158 else if (dim < 0)
00159 error ("dot: DIM must be a valid dimension");
00160 else
00161 {
00162 octave_idx_type m, n, k;
00163 dim_vector dimz;
00164 if (argx.is_complex_type () || argy.is_complex_type ())
00165 {
00166 if (argx.is_single_type () || argy.is_single_type ())
00167 {
00168 FloatComplexNDArray x = argx.float_complex_array_value ();
00169 FloatComplexNDArray y = argy.float_complex_array_value ();
00170 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00171 FloatComplexNDArray z(dimz);
00172 if (! error_state)
00173 F77_XFCN (cdotc3, CDOTC3, (m, n, k, x.data (), y.data (),
00174 z.fortran_vec ()));
00175 retval = z;
00176 }
00177 else
00178 {
00179 ComplexNDArray x = argx.complex_array_value ();
00180 ComplexNDArray y = argy.complex_array_value ();
00181 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00182 ComplexNDArray z(dimz);
00183 if (! error_state)
00184 F77_XFCN (zdotc3, ZDOTC3, (m, n, k, x.data (), y.data (),
00185 z.fortran_vec ()));
00186 retval = z;
00187 }
00188 }
00189 else if (argx.is_float_type () && argy.is_float_type ())
00190 {
00191 if (argx.is_single_type () || argy.is_single_type ())
00192 {
00193 FloatNDArray x = argx.float_array_value ();
00194 FloatNDArray y = argy.float_array_value ();
00195 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00196 FloatNDArray z(dimz);
00197 if (! error_state)
00198 F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
00199 z.fortran_vec ()));
00200 retval = z;
00201 }
00202 else
00203 {
00204 NDArray x = argx.array_value ();
00205 NDArray y = argy.array_value ();
00206 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00207 NDArray z(dimz);
00208 if (! error_state)
00209 F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
00210 z.fortran_vec ()));
00211 retval = z;
00212 }
00213 }
00214 else
00215 {
00216
00217 octave_value_list tmp;
00218 tmp(1) = args(2);
00219 tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
00220 if (! error_state)
00221 {
00222 tmp = feval ("sum", tmp, 1);
00223 if (! tmp.empty ())
00224 retval = tmp(0);
00225 }
00226 }
00227 }
00228 }
00229 else
00230 error ("dot: sizes of X and Y must match");
00231
00232 }
00233 else
00234 error ("dot: X and Y must be numeric");
00235
00236 return retval;
00237 }
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254 DEFUN_DLD (blkmm, args, ,
00255 "-*- texinfo -*-\n\
00256 @deftypefn {Loadable Function} {} blkmm (@var{A}, @var{B})\n\
00257 Compute products of matrix blocks. The blocks are given as\n\
00258 2-dimensional subarrays of the arrays @var{A}, @var{B}.\n\
00259 The size of @var{A} must have the form @code{[m,k,@dots{}]} and\n\
00260 size of @var{B} must be @code{[k,n,@dots{}]}. The result is\n\
00261 then of size @code{[m,n,@dots{}]} and is computed as follows:\n\
00262 \n\
00263 @example\n\
00264 @group\n\
00265 for i = 1:prod (size (@var{A})(3:end))\n\
00266 @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)\n\
00267 endfor\n\
00268 @end group\n\
00269 @end example\n\
00270 @end deftypefn")
00271 {
00272 octave_value retval;
00273 int nargin = args.length ();
00274
00275 if (nargin != 2)
00276 {
00277 print_usage ();
00278 return retval;
00279 }
00280
00281 octave_value argx = args(0), argy = args(1);
00282
00283 if (argx.is_numeric_type () && argy.is_numeric_type ())
00284 {
00285 const dim_vector dimx = argx.dims (), dimy = argy.dims ();
00286 int nd = dimx.length ();
00287 octave_idx_type m = dimx(0), k = dimx(1), n = dimy(1), np = 1;
00288 bool match = dimy(0) == k && nd == dimy.length ();
00289 dim_vector dimz = dim_vector::alloc (nd);
00290 dimz(0) = m;
00291 dimz(1) = n;
00292 for (int i = 2; match && i < nd; i++)
00293 {
00294 match = match && dimx(i) == dimy(i);
00295 dimz(i) = dimx(i);
00296 np *= dimz(i);
00297 }
00298
00299 if (match)
00300 {
00301 if (argx.is_complex_type () || argy.is_complex_type ())
00302 {
00303 if (argx.is_single_type () || argy.is_single_type ())
00304 {
00305 FloatComplexNDArray x = argx.float_complex_array_value ();
00306 FloatComplexNDArray y = argy.float_complex_array_value ();
00307 FloatComplexNDArray z(dimz);
00308 if (! error_state)
00309 F77_XFCN (cmatm3, CMATM3, (m, n, k, np, x.data (), y.data (),
00310 z.fortran_vec ()));
00311 retval = z;
00312 }
00313 else
00314 {
00315 ComplexNDArray x = argx.complex_array_value ();
00316 ComplexNDArray y = argy.complex_array_value ();
00317 ComplexNDArray z(dimz);
00318 if (! error_state)
00319 F77_XFCN (zmatm3, ZMATM3, (m, n, k, np, x.data (), y.data (),
00320 z.fortran_vec ()));
00321 retval = z;
00322 }
00323 }
00324 else
00325 {
00326 if (argx.is_single_type () || argy.is_single_type ())
00327 {
00328 FloatNDArray x = argx.float_array_value ();
00329 FloatNDArray y = argy.float_array_value ();
00330 FloatNDArray z(dimz);
00331 if (! error_state)
00332 F77_XFCN (smatm3, SMATM3, (m, n, k, np, x.data (), y.data (),
00333 z.fortran_vec ()));
00334 retval = z;
00335 }
00336 else
00337 {
00338 NDArray x = argx.array_value ();
00339 NDArray y = argy.array_value ();
00340 NDArray z(dimz);
00341 if (! error_state)
00342 F77_XFCN (dmatm3, DMATM3, (m, n, k, np, x.data (), y.data (),
00343 z.fortran_vec ()));
00344 retval = z;
00345 }
00346 }
00347 }
00348 else
00349 error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
00350 dimx.str ().c_str (), dimy.str ().c_str ());
00351
00352 }
00353 else
00354 error ("blkmm: A and B must be numeric");
00355
00356 return retval;
00357 }
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368