GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
dot.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "lo-blas-proto.h"
31 #include "mx-base.h"
32 
33 #include "builtin-defun-decls.h"
34 #include "defun.h"
35 #include "error.h"
36 #include "parse.h"
37 
39 
40 // FIXME: input 'y' is no longer necessary (2/5/2022).
41 // At some point it would be better to change all occurrences of
42 // get_red_dims to eliminate this input parameter.
43 static void
44 get_red_dims (const dim_vector& x, const dim_vector& /* y */, int dim,
45  dim_vector& z, F77_INT& m, F77_INT& n, F77_INT& k)
46 {
47  int nd = x.ndims ();
48  z = dim_vector::alloc (nd);
49  octave_idx_type tmp_m = 1;
50  octave_idx_type tmp_n = 1;
51  octave_idx_type tmp_k = 1;
52  for (int i = 0; i < nd; i++)
53  {
54  if (i < dim)
55  {
56  z(i) = x(i);
57  tmp_m *= x(i);
58  }
59  else if (i > dim)
60  {
61  z(i) = x(i);
62  tmp_n *= x(i);
63  }
64  else
65  {
66  z(i) = 1;
67  tmp_k = x(i);
68  }
69  }
70 
71  m = to_f77_int (tmp_m);
72  n = to_f77_int (tmp_n);
73  k = to_f77_int (tmp_k);
74 }
75 
76 DEFUN (dot, args, ,
77  doc: /* -*- texinfo -*-
78 @deftypefn {} {@var{z} =} dot (@var{x}, @var{y})
79 @deftypefnx {} {@var{z} =} dot (@var{x}, @var{y}, @var{dim})
80 Compute the dot product of two vectors.
81 
82 If @var{x} and @var{y} are matrices, calculate the dot products along the
83 first non-singleton dimension.
84 
85 If the optional argument @var{dim} is given, calculate the dot products
86 along this dimension.
87 
88 Implementation Note: This is equivalent to
89 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})}, but avoids forming a
90 temporary array and is faster. When @var{X} and @var{Y} are column vectors,
91 the result is equivalent to @code{@var{X}' * @var{Y}}. Although, @code{dot}
92 is defined for integer arrays, the output may differ from the expected result
93 due to the limited range of integer objects.
94 @seealso{cross, divergence, tensorprod}
95 @end deftypefn */)
96 {
97  int nargin = args.length ();
98 
99  if (nargin < 2 || nargin > 3)
100  print_usage ();
101 
102  octave_value retval;
103  octave_value argx = args(0);
104  octave_value argy = args(1);
105 
106  if (! argx.isnumeric () || ! argy.isnumeric ())
107  error ("dot: X and Y must be numeric");
108 
109  dim_vector dimx = argx.dims ();
110  dim_vector dimy = argy.dims ();
111  bool match = dimx == dimy;
112  if (! match && nargin == 2 && dimx.isvector () && dimy.isvector ())
113  {
114  // Change to column vectors.
115  dimx = dimx.redim (1);
116  argx = argx.reshape (dimx);
117  dimy = dimy.redim (1);
118  argy = argy.reshape (dimy);
119  match = dimx == dimy;
120  }
121 
122  if (! match)
123  error ("dot: sizes of X and Y must match");
124 
125  int dim;
126  if (nargin == 2)
127  dim = dimx.first_non_singleton ();
128  else
129  dim = args(2).int_value (true) - 1;
130 
131  if (dim < 0)
132  error ("dot: DIM must be a valid dimension");
133 
134  F77_INT m, n, k;
135  dim_vector dimz;
136  if (argx.iscomplex () || argy.iscomplex ())
137  {
138  if (argx.is_single_type () || argy.is_single_type ())
139  {
142  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
143  FloatComplexNDArray z (dimz);
144 
145  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
146  F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()),
147  F77_CMPLX_ARG (z.fortran_vec ())));
148  retval = z;
149  }
150  else
151  {
154  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
155  ComplexNDArray z (dimz);
156 
157  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
160  retval = z;
161  }
162  }
163  else if (argx.isfloat () && argy.isfloat ())
164  {
165  if (argx.is_single_type () || argy.is_single_type ())
166  {
167  FloatNDArray x = argx.float_array_value ();
168  FloatNDArray y = argy.float_array_value ();
169  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
170  FloatNDArray z (dimz);
171 
172  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
173  z.fortran_vec ()));
174  retval = z;
175  }
176  else
177  {
178  NDArray x = argx.array_value ();
179  NDArray y = argy.array_value ();
180  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
181  NDArray z (dimz);
182 
183  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
184  z.fortran_vec ()));
185  retval = z;
186  }
187  }
188  else
189  {
190  // Non-optimized evaluation.
191  // FIXME: This may *not* do what the user expects.
192  // It might be more useful to issue a warning, or even an error, instead
193  // of calculating possibly garbage results.
194  // Think of the dot product of two int8 vectors where the multiplications
195  // exceed intmax.
196  octave_value_list tmp (2);
197  tmp(0) = binary_op (octave_value::op_el_mul, argx, argy);
198  tmp(1) = dim + 1;
199 
200  tmp = Fsum (tmp, 1);
201  if (! tmp.empty ())
202  retval = tmp(0);
203  }
204 
205  return retval;
206 }
207 
208 /*
209 %!assert (dot ([1, 2], [2, 3]), 8)
210 
211 %!test
212 %! x = [2, 1; 2, 1];
213 %! y = [-0.5, 2; 0.5, -2];
214 %! assert (dot (x, y), [0 0]);
215 %! assert (dot (single (x), single (y)), single ([0 0]));
216 
217 %!test
218 %! x = [1+i, 3-i; 1-i, 3-i];
219 %! assert (dot (x, x), [4, 20]);
220 %! assert (dot (single (x), single (x)), single ([4, 20]));
221 
222 %!test
223 %! x = int8 ([1, 2]);
224 %! y = int8 ([2, 3]);
225 %! assert (dot (x, y), 8);
226 
227 %!test
228 %! x = int8 ([1, 2; 3, 4]);
229 %! y = int8 ([5, 6; 7, 8]);
230 %! assert (dot (x, y), [26 44]);
231 %! assert (dot (x, y, 2), [17; 53]);
232 %! assert (dot (x, y, 3), [5 12; 21 32]);
233 
## This is, perhaps, surprising. Integer maximums and saturation mechanics
## prevent the accurate value from being calculated.
236 %!test
237 %! x = int8 ([127]);
238 %! assert (dot (x, x), 127);
239 
240 ## Test input validation
241 %!error dot ()
242 %!error dot (1)
243 %!error dot (1,2,3,4)
244 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
245 %!error <X and Y must be numeric> dot ([1,2], {3,4})
246 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
247 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
248 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
249 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
250 */
251 
252 template <typename T>
253 static void
254 blkmm_internal (const T& x, const T& y, T& z,
255  F77_INT m, F77_INT n, F77_INT k, F77_INT np);
256 
257 template <>
258 void
261  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
262 {
263  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
264  F77_CONST_CMPLX_ARG (x.data ()),
265  F77_CONST_CMPLX_ARG (y.data ()),
266  F77_CMPLX_ARG (z.fortran_vec ())));
267 }
268 
269 template <>
270 void
272  ComplexNDArray& z,
273  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
274 {
275  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
276  F77_CONST_DBLE_CMPLX_ARG (x.data ()),
279 }
280 
281 template <>
282 void
284  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
285 {
286  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
287  x.data (), y.data (),
288  z.fortran_vec ()));
289 }
290 
291 template <>
292 void
293 blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
294  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
295 {
296  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
297  x.data (), y.data (),
298  z.fortran_vec ()));
299 }
300 
301 static void
302 get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
303  F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
304  dim_vector& dimz)
305 {
306  int nd = dimx.ndims ();
307 
308  m = to_f77_int (dimx(0));
309  k = to_f77_int (dimx(1));
310  n = to_f77_int (dimy(1));
311 
312  octave_idx_type tmp_np = 1;
313 
314  bool match = ((dimy(0) == k) && (nd == dimy.ndims ()));
315 
316  dimz = dim_vector::alloc (nd);
317 
318  dimz(0) = m;
319  dimz(1) = n;
320  for (int i = 2; match && i < nd; i++)
321  {
322  match = (dimx(i) == dimy(i));
323  dimz(i) = dimx(i);
324  tmp_np *= dimz(i);
325  }
326 
327  np = to_f77_int (tmp_np);
328 
329  if (! match)
330  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
331  dimx.str ().c_str (), dimy.str ().c_str ());
332 }
333 
334 template <typename T>
335 T
336 do_blkmm (const octave_value& xov, const octave_value& yov)
337 {
338  const T x = octave_value_extract<T> (xov);
339  const T y = octave_value_extract<T> (yov);
340  F77_INT m, n, k, np;
341  dim_vector dimz;
342 
343  get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
344 
345  T z (dimz);
346 
347  if (n != 0 && m != 0)
348  blkmm_internal<T> (x, y, z, m, n, k, np);
349 
350  return z;
351 }
352 
353 DEFUN (blkmm, args, ,
354  doc: /* -*- texinfo -*-
355 @deftypefn {} {@var{C} =} blkmm (@var{A}, @var{B})
356 Compute products of matrix blocks.
357 
358 The blocks are given as 2-dimensional subarrays of the arrays @var{A},
359 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
360 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
361 @code{[m,n,@dots{}]} and is computed as follows:
362 
363 @example
364 @group
365 for i = 1:prod (size (@var{A})(3:end))
366  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
367 endfor
368 @end group
369 @end example
370 @end deftypefn */)
371 {
372  if (args.length () != 2)
373  print_usage ();
374 
375  octave_value retval;
376 
377  octave_value argx = args(0);
378  octave_value argy = args(1);
379 
380  if (! argx.isnumeric () || ! argy.isnumeric ())
381  error ("blkmm: A and B must be numeric");
382 
383  if (argx.iscomplex () || argy.iscomplex ())
384  {
385  if (argx.is_single_type () || argy.is_single_type ())
386  retval = do_blkmm<FloatComplexNDArray> (argx, argy);
387  else
388  retval = do_blkmm<ComplexNDArray> (argx, argy);
389  }
390  else
391  {
392  if (argx.is_single_type () || argy.is_single_type ())
393  retval = do_blkmm<FloatNDArray> (argx, argy);
394  else
395  retval = do_blkmm<NDArray> (argx, argy);
396  }
397 
398  return retval;
399 }
400 
401 /*
402 %!test
403 %! x(:,:,1) = [1 2; 3 4];
404 %! x(:,:,2) = [1 1; 1 1];
405 %! z(:,:,1) = [7 10; 15 22];
406 %! z(:,:,2) = [2 2; 2 2];
407 %! assert (blkmm (x,x), z);
408 %! assert (blkmm (single (x), single (x)), single (z));
409 %! assert (blkmm (x, single (x)), single (z));
410 
411 %!test
412 %! x(:,:,1) = [1 2; 3 4];
413 %! x(:,:,2) = [1i 1i; 1i 1i];
414 %! z(:,:,1) = [7 10; 15 22];
415 %! z(:,:,2) = [-2 -2; -2 -2];
416 %! assert (blkmm (x,x), z);
417 %! assert (blkmm (single (x), single (x)), single (z));
418 %! assert (blkmm (x, single (x)), single (z));
419 
420 %!test <*54261>
421 %! x = ones (0, 3, 3);
422 %! y = ones (3, 5, 3);
423 %! z = blkmm (x,y);
424 %! assert (size (z), [0, 5, 3]);
425 %! x = ones (1, 3, 3);
426 %! y = ones (3, 0, 3);
427 %! z = blkmm (x,y);
428 %! assert (size (z), [1, 0, 3]);
429 
430 ## Test input validation
431 %!error blkmm ()
432 %!error blkmm (1)
433 %!error blkmm (1,2,3)
434 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
435 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
436 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
437 */
438 
439 OCTAVE_END_NAMESPACE(octave)
octave_value_list Fsum(const octave_value_list &=octave_value_list(), int=0)
Definition: data.cc:3104
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:23
T * fortran_vec()
Size of the specified dimension.
Definition: Array-base.cc:1764
const T * data() const
Size of the specified dimension.
Definition: Array.h:663
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
std::string str(char sep='x') const
Definition: dim-vector.cc:68
static dim_vector alloc(int n)
Definition: dim-vector.h:202
octave_idx_type ndims() const
Number of dimensions.
Definition: dim-vector.h:257
bool isvector() const
Definition: dim-vector.h:395
int first_non_singleton(int def=0) const
Definition: dim-vector.h:444
dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
Definition: dim-vector.cc:226
bool empty() const
Definition: ovl.h:115
bool isfloat() const
Definition: ov.h:701
bool isnumeric() const
Definition: ov.h:750
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:878
bool is_single_type() const
Definition: ov.h:698
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:571
bool iscomplex() const
Definition: ov.h:741
@ op_el_mul
Definition: ov.h:105
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:859
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:882
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:862
dim_vector dims() const
Definition: ov.h:541
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:23
void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:23
T do_blkmm(const octave_value &xov, const octave_value &yov)
Definition: dot.cc:336
void blkmm_internal(const FloatComplexNDArray &x, const FloatComplexNDArray &y, FloatComplexNDArray &z, F77_INT m, F77_INT n, F77_INT k, F77_INT np)
Definition: dot.cc:259
void() error(const char *fmt,...)
Definition: error.cc:988
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:313
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:45
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:319
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5541
F77_RET_T const F77_DBLE * x
T octave_idx_type m
Definition: mx-inlines.cc:781
octave_idx_type n
Definition: mx-inlines.cc:761
octave_value binary_op(type_info &ti, octave_value::binary_op op, const octave_value &a, const octave_value &b)
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:23
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:23
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:23
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:23