GNU Octave  6.2.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
dot.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "lo-blas-proto.h"
31 #include "mx-base.h"
32 
33 #include "builtin-defun-decls.h"
34 #include "defun.h"
35 #include "error.h"
36 #include "parse.h"
37 
38 static void
39 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
40  dim_vector& z, F77_INT& m, F77_INT& n, F77_INT& k)
41 {
42  int nd = x.ndims ();
43  assert (nd == y.ndims ());
44  z = dim_vector::alloc (nd);
45  octave_idx_type tmp_m = 1;
46  octave_idx_type tmp_n = 1;
47  octave_idx_type tmp_k = 1;
48  for (int i = 0; i < nd; i++)
49  {
50  if (i < dim)
51  {
52  z(i) = x(i);
53  tmp_m *= x(i);
54  }
55  else if (i > dim)
56  {
57  z(i) = x(i);
58  tmp_n *= x(i);
59  }
60  else
61  {
62  z(i) = 1;
63  tmp_k = x(i);
64  }
65  }
66 
67  m = octave::to_f77_int (tmp_m);
68  n = octave::to_f77_int (tmp_n);
69  k = octave::to_f77_int (tmp_k);
70 }
71 
72 DEFUN (dot, args, ,
73  doc: /* -*- texinfo -*-
74 @deftypefn {} {} dot (@var{x}, @var{y}, @var{dim})
75 Compute the dot product of two vectors.
76 
77 If @var{x} and @var{y} are matrices, calculate the dot products along the
78 first non-singleton dimension.
79 
80 If the optional argument @var{dim} is given, calculate the dot products
81 along this dimension.
82 
83 Implementation Note: This is equivalent to
84 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})}, but avoids forming a
85 temporary array and is faster. When @var{X} and @var{Y} are column vectors,
86 the result is equivalent to @code{@var{X}' * @var{Y}}. Although, @code{dot}
87 is defined for integer arrays, the output may differ from the expected result
88 due to the limited range of integer objects.
89 @seealso{cross, divergence}
90 @end deftypefn */)
91 {
92  int nargin = args.length ();
93 
94  if (nargin < 2 || nargin > 3)
95  print_usage ();
96 
98  octave_value argx = args(0);
99  octave_value argy = args(1);
100 
101  if (! argx.isnumeric () || ! argy.isnumeric ())
102  error ("dot: X and Y must be numeric");
103 
104  dim_vector dimx = argx.dims ();
105  dim_vector dimy = argy.dims ();
106  bool match = dimx == dimy;
107  if (! match && nargin == 2 && dimx.isvector () && dimy.isvector ())
108  {
109  // Change to column vectors.
110  dimx = dimx.redim (1);
111  argx = argx.reshape (dimx);
112  dimy = dimy.redim (1);
113  argy = argy.reshape (dimy);
114  match = dimx == dimy;
115  }
116 
117  if (! match)
118  error ("dot: sizes of X and Y must match");
119 
120  int dim;
121  if (nargin == 2)
122  dim = dimx.first_non_singleton ();
123  else
124  dim = args(2).int_value (true) - 1;
125 
126  if (dim < 0)
127  error ("dot: DIM must be a valid dimension");
128 
129  F77_INT m, n, k;
130  dim_vector dimz;
131  if (argx.iscomplex () || argy.iscomplex ())
132  {
133  if (argx.is_single_type () || argy.is_single_type ())
134  {
137  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
138  FloatComplexNDArray z (dimz);
139 
140  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
141  F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()),
142  F77_CMPLX_ARG (z.fortran_vec ())));
143  retval = z;
144  }
145  else
146  {
149  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
150  ComplexNDArray z (dimz);
151 
152  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
155  retval = z;
156  }
157  }
158  else if (argx.isfloat () && argy.isfloat ())
159  {
160  if (argx.is_single_type () || argy.is_single_type ())
161  {
162  FloatNDArray x = argx.float_array_value ();
163  FloatNDArray y = argy.float_array_value ();
164  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
165  FloatNDArray z (dimz);
166 
167  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
168  z.fortran_vec ()));
169  retval = z;
170  }
171  else
172  {
173  NDArray x = argx.array_value ();
174  NDArray y = argy.array_value ();
175  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
176  NDArray z (dimz);
177 
178  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
179  z.fortran_vec ()));
180  retval = z;
181  }
182  }
183  else
184  {
185  // Non-optimized evaluation.
186  // FIXME: This may *not* do what the user expects.
187  // It might be more useful to issue a warning, or even an error, instead
188  // of calculating possibly garbage results.
189  // Think of the dot product of two int8 vectors where the multiplications
190  // exceed intmax.
191  octave_value_list tmp;
192  tmp(1) = dim + 1;
193  tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
194 
195  tmp = Fsum (tmp, 1);
196  if (! tmp.empty ())
197  retval = tmp(0);
198  }
199 
200  return retval;
201 }
202 
203 /*
204 %!assert (dot ([1, 2], [2, 3]), 8)
205 
206 %!test
207 %! x = [2, 1; 2, 1];
208 %! y = [-0.5, 2; 0.5, -2];
209 %! assert (dot (x, y), [0 0]);
210 %! assert (dot (single (x), single (y)), single ([0 0]));
211 
212 %!test
213 %! x = [1+i, 3-i; 1-i, 3-i];
214 %! assert (dot (x, x), [4, 20]);
215 %! assert (dot (single (x), single (x)), single ([4, 20]));
216 
217 %!test
218 %! x = int8 ([1, 2]);
219 %! y = int8 ([2, 3]);
220 %! assert (dot (x, y), 8);
221 
222 %!test
223 %! x = int8 ([1, 2; 3, 4]);
224 %! y = int8 ([5, 6; 7, 8]);
225 %! assert (dot (x, y), [26 44]);
226 %! assert (dot (x, y, 2), [17; 53]);
227 %! assert (dot (x, y, 3), [5 12; 21 32]);
228 
229 ## This is, perhaps, surprising. The int8 maximum (127) and saturating
230 ## integer arithmetic prevent the exact value (127^2 = 16129) from being calculated.
231 %!test
232 %! x = int8 ([127]);
233 %! assert (dot (x, x), 127);
234 
235 ## Test input validation
236 %!error dot ()
237 %!error dot (1)
238 %!error dot (1,2,3,4)
239 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
240 %!error <X and Y must be numeric> dot ([1,2], {3,4})
241 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
242 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
243 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
244 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
245 */
246 
247 template <typename T>
248 void
249 blkmm_internal (const T& x, const T& y, T& z,
250  F77_INT m, F77_INT n, F77_INT k, F77_INT np);
251 
252 template <>
253 void
256  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
257 {
258  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
259  F77_CONST_CMPLX_ARG (x.data ()),
260  F77_CONST_CMPLX_ARG (y.data ()),
261  F77_CMPLX_ARG (z.fortran_vec ())));
262 }
263 
264 template <>
265 void
267  ComplexNDArray& z,
268  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
269 {
270  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
271  F77_CONST_DBLE_CMPLX_ARG (x.data ()),
274 }
275 
276 template <>
277 void
279  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
280 {
281  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
282  x.data (), y.data (),
283  z.fortran_vec ()));
284 }
285 
286 template <>
287 void
288 blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
289  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
290 {
291  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
292  x.data (), y.data (),
293  z.fortran_vec ()));
294 }
295 
296 static void
297 get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
298  F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
299  dim_vector& dimz)
300 {
301  int nd = dimx.ndims ();
302 
303  m = octave::to_f77_int (dimx(0));
304  k = octave::to_f77_int (dimx(1));
305  n = octave::to_f77_int (dimy(1));
306 
307  octave_idx_type tmp_np = 1;
308 
309  bool match = ((dimy(0) == k) && (nd == dimy.ndims ()));
310 
311  dimz = dim_vector::alloc (nd);
312 
313  dimz(0) = m;
314  dimz(1) = n;
315  for (int i = 2; match && i < nd; i++)
316  {
317  match = (dimx(i) == dimy(i));
318  dimz(i) = dimx(i);
319  tmp_np *= dimz(i);
320  }
321 
322  np = octave::to_f77_int (tmp_np);
323 
324  if (! match)
325  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
326  dimx.str ().c_str (), dimy.str ().c_str ());
327 }
328 
329 template <typename T>
330 T
331 do_blkmm (const octave_value& xov, const octave_value& yov)
332 {
333  const T x = octave_value_extract<T> (xov);
334  const T y = octave_value_extract<T> (yov);
335  F77_INT m, n, k, np;
336  dim_vector dimz;
337 
338  get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
339 
340  T z (dimz);
341 
342  if (n != 0 && m != 0)
343  blkmm_internal<T> (x, y, z, m, n, k, np);
344 
345  return z;
346 }
347 
348 DEFUN (blkmm, args, ,
349  doc: /* -*- texinfo -*-
350 @deftypefn {} {} blkmm (@var{A}, @var{B})
351 Compute products of matrix blocks.
352 
353 The blocks are given as 2-dimensional subarrays of the arrays @var{A},
354 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
355 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
356 @code{[m,n,@dots{}]} and is computed as follows:
357 
358 @example
359 @group
360 for i = 1:prod (size (@var{A})(3:end))
361  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
362 endfor
363 @end group
364 @end example
365 @end deftypefn */)
366 {
367  if (args.length () != 2)
368  print_usage ();
369 
371 
372  octave_value argx = args(0);
373  octave_value argy = args(1);
374 
375  if (! argx.isnumeric () || ! argy.isnumeric ())
376  error ("blkmm: A and B must be numeric");
377 
378  if (argx.iscomplex () || argy.iscomplex ())
379  {
380  if (argx.is_single_type () || argy.is_single_type ())
381  retval = do_blkmm<FloatComplexNDArray> (argx, argy);
382  else
383  retval = do_blkmm<ComplexNDArray> (argx, argy);
384  }
385  else
386  {
387  if (argx.is_single_type () || argy.is_single_type ())
388  retval = do_blkmm<FloatNDArray> (argx, argy);
389  else
390  retval = do_blkmm<NDArray> (argx, argy);
391  }
392 
393  return retval;
394 }
395 
396 /*
397 %!test
398 %! x(:,:,1) = [1 2; 3 4];
399 %! x(:,:,2) = [1 1; 1 1];
400 %! z(:,:,1) = [7 10; 15 22];
401 %! z(:,:,2) = [2 2; 2 2];
402 %! assert (blkmm (x,x), z);
403 %! assert (blkmm (single (x), single (x)), single (z));
404 %! assert (blkmm (x, single (x)), single (z));
405 
406 %!test
407 %! x(:,:,1) = [1 2; 3 4];
408 %! x(:,:,2) = [1i 1i; 1i 1i];
409 %! z(:,:,1) = [7 10; 15 22];
410 %! z(:,:,2) = [-2 -2; -2 -2];
411 %! assert (blkmm (x,x), z);
412 %! assert (blkmm (single (x), single (x)), single (z));
413 %! assert (blkmm (x, single (x)), single (z));
414 
415 %!test <*54261>
416 %! x = ones (0, 3, 3);
417 %! y = ones (3, 5, 3);
418 %! z = blkmm (x,y);
419 %! assert (size (z), [0, 5, 3]);
420 %! x = ones (1, 3, 3);
421 %! y = ones (3, 0, 3);
422 %! z = blkmm (x,y);
423 %! assert (size (z), [1, 0, 3]);
424 
425 ## Test input validation
426 %!error blkmm ()
427 %!error blkmm (1)
428 %!error blkmm (1,2,3)
429 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
430 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
431 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
432 */
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:23
const T * data(void) const
Return a read-only pointer to the array's elements.
Definition: Array.h:581
const T * fortran_vec(void) const
Return a pointer to the array's element storage (as passed to the Fortran kernels).
Definition: Array.h:583
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:95
std::string str(char sep='x') const
Definition: dim-vector.cc:85
static dim_vector alloc(int n)
Definition: dim-vector.h:281
bool isvector(void) const
Definition: dim-vector.h:461
int first_non_singleton(int def=0) const
Definition: dim-vector.h:510
octave_idx_type ndims(void) const
Number of dimensions.
Definition: dim-vector.h:334
dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
Definition: dim-vector.cc:245
bool empty(void) const
Definition: ovl.h:115
bool isnumeric(void) const
Definition: ov.h:703
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:831
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:530
@ op_el_mul
Definition: ov.h:109
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:812
bool is_single_type(void) const
Definition: ov.h:651
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:835
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:815
bool iscomplex(void) const
Definition: ov.h:694
bool isfloat(void) const
Definition: ov.h:654
dim_vector dims(void) const
Definition: ov.h:500
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
OCTAVE_EXPORT octave_value_list Fsum(const octave_value_list &args, int)
Definition: data.cc:2926
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:23
OCTINTERP_API void print_usage(void)
Definition: defun.cc:53
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:23
T do_blkmm(const octave_value &xov, const octave_value &yov)
Definition: dot.cc:331
static void get_red_dims(const dim_vector &x, const dim_vector &y, int dim, dim_vector &z, F77_INT &m, F77_INT &n, F77_INT &k)
Definition: dot.cc:39
static void get_blkmm_dims(const dim_vector &dimx, const dim_vector &dimy, F77_INT &m, F77_INT &n, F77_INT &k, F77_INT &np, dim_vector &dimz)
Definition: dot.cc:297
void blkmm_internal(const T &x, const T &y, T &z, F77_INT m, F77_INT n, F77_INT k, F77_INT np)
void error(const char *fmt,...)
Definition: error.cc:968
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:312
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:315
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:309
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:44
octave_f77_int_type F77_INT
Definition: f77-fcn.h:305
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:318
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5887
F77_RET_T const F77_DBLE * x
T octave_idx_type m
Definition: mx-inlines.cc:773
octave_idx_type n
Definition: mx-inlines.cc:753
octave_value::octave_value(const Array< char > &chm, char type) return retval
Definition: ov.cc:811
OCTINTERP_API octave_value do_binary_op(octave::type_info &ti, octave_value::binary_op op, const octave_value &a, const octave_value &b)
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:23
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:23
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:23
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:23