GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dot.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "lo-blas-proto.h"
31 #include "mx-base.h"
32 
33 #include "builtin-defun-decls.h"
34 #include "defun.h"
35 #include "error.h"
36 #include "parse.h"
37 
39 
40 // FIXME: input 'y' is no longer necessary (2/5/2022).
41 // At some point it would be better to change all occurrences of
42 // get_red_dims to eliminate this input parameter.
43 static void
44 get_red_dims (const dim_vector& x, const dim_vector& /* y */, int dim,
45  dim_vector& z, F77_INT& m, F77_INT& n, F77_INT& k)
46 {
47  int nd = x.ndims ();
48  z = dim_vector::alloc (nd);
49  octave_idx_type tmp_m = 1;
50  octave_idx_type tmp_n = 1;
51  octave_idx_type tmp_k = 1;
52  for (int i = 0; i < nd; i++)
53  {
54  if (i < dim)
55  {
56  z(i) = x(i);
57  tmp_m *= x(i);
58  }
59  else if (i > dim)
60  {
61  z(i) = x(i);
62  tmp_n *= x(i);
63  }
64  else
65  {
66  z(i) = 1;
67  tmp_k = x(i);
68  }
69  }
70 
71  m = to_f77_int (tmp_m);
72  n = to_f77_int (tmp_n);
73  k = to_f77_int (tmp_k);
74 }
75 
76 DEFUN (dot, args, ,
77  doc: /* -*- texinfo -*-
78 @deftypefn {} {@var{z} =} dot (@var{x}, @var{y})
79 @deftypefnx {} {@var{z} =} dot (@var{x}, @var{y}, @var{dim})
80 Compute the dot product of two vectors.
81 
82 If @var{x} and @var{y} are matrices, calculate the dot products along the
83 first non-singleton dimension.
84 
85 If the optional argument @var{dim} is given, calculate the dot products
86 along this dimension.
87 
88 Implementation Note: This is equivalent to
89 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})}, but avoids forming a
90 temporary array and is faster. When @var{X} and @var{Y} are column vectors,
91 the result is equivalent to @code{@var{X}' * @var{Y}}. Although, @code{dot}
92 is defined for integer arrays, the output may differ from the expected result
93 due to the limited range of integer objects.
94 @seealso{cross, divergence, tensorprod}
95 @end deftypefn */)
96 {
97  int nargin = args.length ();
98 
99  if (nargin < 2 || nargin > 3)
100  print_usage ();
101 
102  octave_value retval;
103  octave_value argx = args(0);
104  octave_value argy = args(1);
105 
106  if (! argx.isnumeric () || ! argy.isnumeric ())
107  error ("dot: X and Y must be numeric");
108 
109  dim_vector dimx = argx.dims ();
110  dim_vector dimy = argy.dims ();
111  bool match = dimx == dimy;
112  if (! match && nargin == 2 && dimx.isvector () && dimy.isvector ())
113  {
114  // Change to column vectors.
115  dimx = dimx.redim (1);
116  argx = argx.reshape (dimx);
117  dimy = dimy.redim (1);
118  argy = argy.reshape (dimy);
119  match = dimx == dimy;
120  }
121 
122  if (! match)
123  error ("dot: sizes of X and Y must match");
124 
125  int dim;
126  if (nargin == 2)
127  dim = dimx.first_non_singleton ();
128  else
129  dim = args(2).int_value (true) - 1;
130 
131  if (dim < 0)
132  error ("dot: DIM must be a valid dimension");
133 
134  F77_INT m, n, k;
135  dim_vector dimz;
136  if (argx.iscomplex () || argy.iscomplex ())
137  {
138  if (argx.is_single_type () || argy.is_single_type ())
139  {
142  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
143  FloatComplexNDArray z (dimz);
144 
145  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
146  F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()),
147  F77_CMPLX_ARG (z.fortran_vec ())));
148  retval = z;
149  }
150  else
151  {
154  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
155  ComplexNDArray z (dimz);
156 
157  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
160  retval = z;
161  }
162  }
163  else if (argx.isfloat () && argy.isfloat ())
164  {
165  if (argx.is_single_type () || argy.is_single_type ())
166  {
167  FloatNDArray x = argx.float_array_value ();
168  FloatNDArray y = argy.float_array_value ();
169  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
170  FloatNDArray z (dimz);
171 
172  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
173  z.fortran_vec ()));
174  retval = z;
175  }
176  else
177  {
178  NDArray x = argx.array_value ();
179  NDArray y = argy.array_value ();
180  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
181  NDArray z (dimz);
182 
183  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
184  z.fortran_vec ()));
185  retval = z;
186  }
187  }
188  else
189  {
190  // Non-optimized evaluation.
191  // FIXME: This may *not* do what the user expects.
192  // It might be more useful to issue a warning, or even an error, instead
193  // of calculating possibly garbage results.
194  // Think of the dot product of two int8 vectors where the multiplications
195  // exceed intmax.
196  octave_value_list tmp (2);
197  tmp(0) = binary_op (octave_value::op_el_mul, argx, argy);
198  tmp(1) = dim + 1;
199 
200  tmp = Fsum (tmp, 1);
201  if (! tmp.empty ())
202  retval = tmp(0);
203  }
204 
205  return retval;
206 }
207 
208 /*
209 %!assert (dot ([1, 2], [2, 3]), 8)
210 
211 %!test
212 %! x = [2, 1; 2, 1];
213 %! y = [-0.5, 2; 0.5, -2];
214 %! assert (dot (x, y), [0 0]);
215 %! assert (dot (single (x), single (y)), single ([0 0]));
216 
217 %!test
218 %! x = [1+i, 3-i; 1-i, 3-i];
219 %! assert (dot (x, x), [4, 20]);
220 %! assert (dot (single (x), single (x)), single ([4, 20]));
221 
222 %!test
223 %! x = int8 ([1, 2]);
224 %! y = int8 ([2, 3]);
225 %! assert (dot (x, y), 8);
226 
227 %!test
228 %! x = int8 ([1, 2; 3, 4]);
229 %! y = int8 ([5, 6; 7, 8]);
230 %! assert (dot (x, y), [26 44]);
231 %! assert (dot (x, y, 2), [17; 53]);
232 %! assert (dot (x, y, 3), [5 12; 21 32]);
233 
234 ## This is, perhaps, surprising. Integer maximums and saturation mechanics
235 ## prevent the exact value (127 * 127 = 16129) from being calculated.
236 %!test
237 %! x = int8 ([127]);
238 %! assert (dot (x, x), 127);
239 
240 ## Test input validation
241 %!error dot ()
242 %!error dot (1)
243 %!error dot (1,2,3,4)
244 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
245 %!error <X and Y must be numeric> dot ([1,2], {3,4})
246 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
247 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
248 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
249 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
250 */
251 
252 template <typename T>
253 static void
254 blkmm_internal (const T& x, const T& y, T& z,
255  F77_INT m, F77_INT n, F77_INT k, F77_INT np);
256 
257 template <>
258 void
261  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
262 {
263  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
264  F77_CONST_CMPLX_ARG (x.data ()),
265  F77_CONST_CMPLX_ARG (y.data ()),
266  F77_CMPLX_ARG (z.fortran_vec ())));
267 }
268 
269 template <>
270 void
272  ComplexNDArray& z,
273  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
274 {
275  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
276  F77_CONST_DBLE_CMPLX_ARG (x.data ()),
279 }
280 
281 template <>
282 void
284  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
285 {
286  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
287  x.data (), y.data (),
288  z.fortran_vec ()));
289 }
290 
291 template <>
292 void
293 blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
294  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
295 {
296  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
297  x.data (), y.data (),
298  z.fortran_vec ()));
299 }
300 
301 static void
302 get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
303  F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
304  dim_vector& dimz)
305 {
306  int nd = dimx.ndims ();
307 
308  m = to_f77_int (dimx(0));
309  k = to_f77_int (dimx(1));
310  n = to_f77_int (dimy(1));
311 
312  octave_idx_type tmp_np = 1;
313 
314  bool match = ((dimy(0) == k) && (nd == dimy.ndims ()));
315 
316  dimz = dim_vector::alloc (nd);
317 
318  dimz(0) = m;
319  dimz(1) = n;
320  for (int i = 2; match && i < nd; i++)
321  {
322  match = (dimx(i) == dimy(i));
323  dimz(i) = dimx(i);
324  tmp_np *= dimz(i);
325  }
326 
327  np = to_f77_int (tmp_np);
328 
329  if (! match)
330  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
331  dimx.str ().c_str (), dimy.str ().c_str ());
332 }
333 
334 template <typename T>
335 T
336 do_blkmm (const octave_value& xov, const octave_value& yov)
337 {
338  const T x = octave_value_extract<T> (xov);
339  const T y = octave_value_extract<T> (yov);
340  F77_INT m, n, k, np;
341  dim_vector dimz;
342 
343  get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
344 
345  T z (dimz);
346 
347  if (n != 0 && m != 0)
348  blkmm_internal<T> (x, y, z, m, n, k, np);
349 
350  return z;
351 }
352 
353 DEFUN (blkmm, args, ,
354  doc: /* -*- texinfo -*-
355 @deftypefn {} {@var{C} =} blkmm (@var{A}, @var{B})
356 Compute products of matrix blocks.
357 
358 The blocks are given as 2-dimensional subarrays of the arrays @var{A},
359 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
360 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
361 @code{[m,n,@dots{}]} and is computed as follows:
362 
363 @example
364 @group
365 for i = 1:prod (size (@var{A})(3:end))
366  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
367 endfor
368 @end group
369 @end example
370 @end deftypefn */)
371 {
372  if (args.length () != 2)
373  print_usage ();
374 
375  octave_value retval;
376 
377  octave_value argx = args(0);
378  octave_value argy = args(1);
379 
380  if (! argx.isnumeric () || ! argy.isnumeric ())
381  error ("blkmm: A and B must be numeric");
382 
383  if (argx.iscomplex () || argy.iscomplex ())
384  {
385  if (argx.is_single_type () || argy.is_single_type ())
386  retval = do_blkmm<FloatComplexNDArray> (argx, argy);
387  else
388  retval = do_blkmm<ComplexNDArray> (argx, argy);
389  }
390  else
391  {
392  if (argx.is_single_type () || argy.is_single_type ())
393  retval = do_blkmm<FloatNDArray> (argx, argy);
394  else
395  retval = do_blkmm<NDArray> (argx, argy);
396  }
397 
398  return retval;
399 }
400 
401 /*
402 %!test
403 %! x(:,:,1) = [1 2; 3 4];
404 %! x(:,:,2) = [1 1; 1 1];
405 %! z(:,:,1) = [7 10; 15 22];
406 %! z(:,:,2) = [2 2; 2 2];
407 %! assert (blkmm (x,x), z);
408 %! assert (blkmm (single (x), single (x)), single (z));
409 %! assert (blkmm (x, single (x)), single (z));
410 
411 %!test
412 %! x(:,:,1) = [1 2; 3 4];
413 %! x(:,:,2) = [1i 1i; 1i 1i];
414 %! z(:,:,1) = [7 10; 15 22];
415 %! z(:,:,2) = [-2 -2; -2 -2];
416 %! assert (blkmm (x,x), z);
417 %! assert (blkmm (single (x), single (x)), single (z));
418 %! assert (blkmm (x, single (x)), single (z));
419 
420 %!test <*54261>
421 %! x = ones (0, 3, 3);
422 %! y = ones (3, 5, 3);
423 %! z = blkmm (x,y);
424 %! assert (size (z), [0, 5, 3]);
425 %! x = ones (1, 3, 3);
426 %! y = ones (3, 0, 3);
427 %! z = blkmm (x,y);
428 %! assert (size (z), [1, 0, 3]);
429 
430 ## Test input validation
431 %!error blkmm ()
432 %!error blkmm (1)
433 %!error blkmm (1,2,3)
434 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
435 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
436 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
437 */
438 
439 OCTAVE_END_NAMESPACE(octave)
octave_value_list Fsum(const octave_value_list &=octave_value_list(), int=0)
Definition: data.cc:3104
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:23
T * fortran_vec()
Size of the specified dimension.
Definition: Array-base.cc:1764
const T * data() const
Size of the specified dimension.
Definition: Array.h:663
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
std::string str(char sep='x') const
Definition: dim-vector.cc:68
static dim_vector alloc(int n)
Definition: dim-vector.h:202
octave_idx_type ndims() const
Number of dimensions.
Definition: dim-vector.h:257
bool isvector() const
Definition: dim-vector.h:395
int first_non_singleton(int def=0) const
Definition: dim-vector.h:444
dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
Definition: dim-vector.cc:226
bool empty() const
Definition: ovl.h:115
bool isfloat() const
Definition: ov.h:701
bool isnumeric() const
Definition: ov.h:750
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:878
bool is_single_type() const
Definition: ov.h:698
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:571
bool iscomplex() const
Definition: ov.h:741
@ op_el_mul
Definition: ov.h:105
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:859
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:882
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:862
dim_vector dims() const
Definition: ov.h:541
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:23
void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:23
T do_blkmm(const octave_value &xov, const octave_value &yov)
Definition: dot.cc:336
void blkmm_internal(const FloatComplexNDArray &x, const FloatComplexNDArray &y, FloatComplexNDArray &z, F77_INT m, F77_INT n, F77_INT k, F77_INT np)
Definition: dot.cc:259
void error(const char *fmt,...)
Definition: error.cc:988
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:313
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:45
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:319
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5541
F77_RET_T const F77_DBLE * x
T octave_idx_type m
Definition: mx-inlines.cc:781
octave_idx_type n
Definition: mx-inlines.cc:761
octave_value binary_op(type_info &ti, octave_value::binary_op op, const octave_value &a, const octave_value &b)
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:23
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:23
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:23
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:23