GNU Octave 7.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
dot.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 2009-2022 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if defined (HAVE_CONFIG_H)
27# include "config.h"
28#endif
29
30#include "lo-blas-proto.h"
31#include "mx-base.h"
32
33#include "builtin-defun-decls.h"
34#include "defun.h"
35#include "error.h"
36#include "parse.h"
37
38OCTAVE_NAMESPACE_BEGIN
39
40static void
41get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
42 dim_vector& z, F77_INT& m, F77_INT& n, F77_INT& k)
43{
44 int nd = x.ndims ();
45 assert (nd == y.ndims ());
46 z = dim_vector::alloc (nd);
47 octave_idx_type tmp_m = 1;
48 octave_idx_type tmp_n = 1;
49 octave_idx_type tmp_k = 1;
50 for (int i = 0; i < nd; i++)
51 {
52 if (i < dim)
53 {
54 z(i) = x(i);
55 tmp_m *= x(i);
56 }
57 else if (i > dim)
58 {
59 z(i) = x(i);
60 tmp_n *= x(i);
61 }
62 else
63 {
64 z(i) = 1;
65 tmp_k = x(i);
66 }
67 }
68
69 m = to_f77_int (tmp_m);
70 n = to_f77_int (tmp_n);
71 k = to_f77_int (tmp_k);
72}
73
74DEFUN (dot, args, ,
75 doc: /* -*- texinfo -*-
76@deftypefn {} {} dot (@var{x}, @var{y}, @var{dim})
77Compute the dot product of two vectors.
78
79If @var{x} and @var{y} are matrices, calculate the dot products along the
80first non-singleton dimension.
81
82If the optional argument @var{dim} is given, calculate the dot products
83along this dimension.
84
85Implementation Note: This is equivalent to
86@code{sum (conj (@var{X}) .* @var{Y}, @var{dim})}, but avoids forming a
87temporary array and is faster. When @var{X} and @var{Y} are column vectors,
88the result is equivalent to @code{@var{X}' * @var{Y}}. Although, @code{dot}
89is defined for integer arrays, the output may differ from the expected result
90due to the limited range of integer objects.
91@seealso{cross, divergence}
92@end deftypefn */)
93{
94 int nargin = args.length ();
95
96 if (nargin < 2 || nargin > 3)
97 print_usage ();
98
99 octave_value retval;
100 octave_value argx = args(0);
101 octave_value argy = args(1);
102
103 if (! argx.isnumeric () || ! argy.isnumeric ())
104 error ("dot: X and Y must be numeric");
105
106 dim_vector dimx = argx.dims ();
107 dim_vector dimy = argy.dims ();
108 bool match = dimx == dimy;
109 if (! match && nargin == 2 && dimx.isvector () && dimy.isvector ())
110 {
111 // Change to column vectors.
112 dimx = dimx.redim (1);
113 argx = argx.reshape (dimx);
114 dimy = dimy.redim (1);
115 argy = argy.reshape (dimy);
116 match = dimx == dimy;
117 }
118
119 if (! match)
120 error ("dot: sizes of X and Y must match");
121
122 int dim;
123 if (nargin == 2)
124 dim = dimx.first_non_singleton ();
125 else
126 dim = args(2).int_value (true) - 1;
127
128 if (dim < 0)
129 error ("dot: DIM must be a valid dimension");
130
131 F77_INT m, n, k;
132 dim_vector dimz;
133 if (argx.iscomplex () || argy.iscomplex ())
134 {
135 if (argx.is_single_type () || argy.is_single_type ())
136 {
139 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
140 FloatComplexNDArray z (dimz);
141
142 F77_XFCN (cdotc3, CDOTC3, (m, n, k,
144 F77_CMPLX_ARG (z.fortran_vec ())));
145 retval = z;
146 }
147 else
148 {
151 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
152 ComplexNDArray z (dimz);
153
154 F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
157 retval = z;
158 }
159 }
160 else if (argx.isfloat () && argy.isfloat ())
161 {
162 if (argx.is_single_type () || argy.is_single_type ())
163 {
166 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
167 FloatNDArray z (dimz);
168
169 F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
170 z.fortran_vec ()));
171 retval = z;
172 }
173 else
174 {
175 NDArray x = argx.array_value ();
176 NDArray y = argy.array_value ();
177 get_red_dims (dimx, dimy, dim, dimz, m, n, k);
178 NDArray z (dimz);
179
180 F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
181 z.fortran_vec ()));
182 retval = z;
183 }
184 }
185 else
186 {
187 // Non-optimized evaluation.
188 // FIXME: This may *not* do what the user expects.
189 // It might be more useful to issue a warning, or even an error, instead
190 // of calculating possibly garbage results.
191 // Think of the dot product of two int8 vectors where the multiplications
192 // exceed intmax.
194 tmp(1) = dim + 1;
195 tmp(0) = binary_op (octave_value::op_el_mul, argx, argy);
196
197 tmp = Fsum (tmp, 1);
198 if (! tmp.empty ())
199 retval = tmp(0);
200 }
201
202 return retval;
203}
204
205/*
206%!assert (dot ([1, 2], [2, 3]), 8)
207
208%!test
209%! x = [2, 1; 2, 1];
210%! y = [-0.5, 2; 0.5, -2];
211%! assert (dot (x, y), [0 0]);
212%! assert (dot (single (x), single (y)), single ([0 0]));
213
214%!test
215%! x = [1+i, 3-i; 1-i, 3-i];
216%! assert (dot (x, x), [4, 20]);
217%! assert (dot (single (x), single (x)), single ([4, 20]));
218
219%!test
220%! x = int8 ([1, 2]);
221%! y = int8 ([2, 3]);
222%! assert (dot (x, y), 8);
223
224%!test
225%! x = int8 ([1, 2; 3, 4]);
226%! y = int8 ([5, 6; 7, 8]);
227%! assert (dot (x, y), [26 44]);
228%! assert (dot (x, y, 2), [17; 53]);
229%! assert (dot (x, y, 3), [5 12; 21 32]);
230
## This is, perhaps, surprising.  Integer maximums and saturation mechanics
## prevent an accurate value from being calculated.
233%!test
234%! x = int8 ([127]);
235%! assert (dot (x, x), 127);
236
237## Test input validation
238%!error dot ()
239%!error dot (1)
240%!error dot (1,2,3,4)
241%!error <X and Y must be numeric> dot ({1,2}, [3,4])
242%!error <X and Y must be numeric> dot ([1,2], {3,4})
243%!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
244%!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
245%!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
246%!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
247*/
248
249template <typename T>
250static void
251blkmm_internal (const T& x, const T& y, T& z,
252 F77_INT m, F77_INT n, F77_INT k, F77_INT np);
253
254template <>
255void
258 F77_INT m, F77_INT n, F77_INT k, F77_INT np)
259{
260 F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
261 F77_CONST_CMPLX_ARG (x.data ()),
263 F77_CMPLX_ARG (z.fortran_vec ())));
264}
265
266template <>
267void
270 F77_INT m, F77_INT n, F77_INT k, F77_INT np)
271{
272 F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
273 F77_CONST_DBLE_CMPLX_ARG (x.data ()),
276}
277
278template <>
279void
281 F77_INT m, F77_INT n, F77_INT k, F77_INT np)
282{
283 F77_XFCN (smatm3, SMATM3, (m, n, k, np,
284 x.data (), y.data (),
285 z.fortran_vec ()));
286}
287
288template <>
289void
290blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
291 F77_INT m, F77_INT n, F77_INT k, F77_INT np)
292{
293 F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
294 x.data (), y.data (),
295 z.fortran_vec ()));
296}
297
298static void
299get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
300 F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
301 dim_vector& dimz)
302{
303 int nd = dimx.ndims ();
304
305 m = to_f77_int (dimx(0));
306 k = to_f77_int (dimx(1));
307 n = to_f77_int (dimy(1));
308
309 octave_idx_type tmp_np = 1;
310
311 bool match = ((dimy(0) == k) && (nd == dimy.ndims ()));
312
313 dimz = dim_vector::alloc (nd);
314
315 dimz(0) = m;
316 dimz(1) = n;
317 for (int i = 2; match && i < nd; i++)
318 {
319 match = (dimx(i) == dimy(i));
320 dimz(i) = dimx(i);
321 tmp_np *= dimz(i);
322 }
323
324 np = to_f77_int (tmp_np);
325
326 if (! match)
327 error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
328 dimx.str ().c_str (), dimy.str ().c_str ());
329}
330
331template <typename T>
332T
333do_blkmm (const octave_value& xov, const octave_value& yov)
334{
335 const T x = octave_value_extract<T> (xov);
336 const T y = octave_value_extract<T> (yov);
337 F77_INT m, n, k, np;
338 dim_vector dimz;
339
340 get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
341
342 T z (dimz);
343
344 if (n != 0 && m != 0)
345 blkmm_internal<T> (x, y, z, m, n, k, np);
346
347 return z;
348}
349
350DEFUN (blkmm, args, ,
351 doc: /* -*- texinfo -*-
352@deftypefn {} {} blkmm (@var{A}, @var{B})
353Compute products of matrix blocks.
354
355The blocks are given as 2-dimensional subarrays of the arrays @var{A},
356@var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
357size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
358@code{[m,n,@dots{}]} and is computed as follows:
359
360@example
361@group
362for i = 1:prod (size (@var{A})(3:end))
363 @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
364endfor
365@end group
366@end example
367@end deftypefn */)
368{
369 if (args.length () != 2)
370 print_usage ();
371
372 octave_value retval;
373
374 octave_value argx = args(0);
375 octave_value argy = args(1);
376
377 if (! argx.isnumeric () || ! argy.isnumeric ())
378 error ("blkmm: A and B must be numeric");
379
380 if (argx.iscomplex () || argy.iscomplex ())
381 {
382 if (argx.is_single_type () || argy.is_single_type ())
383 retval = do_blkmm<FloatComplexNDArray> (argx, argy);
384 else
385 retval = do_blkmm<ComplexNDArray> (argx, argy);
386 }
387 else
388 {
389 if (argx.is_single_type () || argy.is_single_type ())
390 retval = do_blkmm<FloatNDArray> (argx, argy);
391 else
392 retval = do_blkmm<NDArray> (argx, argy);
393 }
394
395 return retval;
396}
397
398/*
399%!test
400%! x(:,:,1) = [1 2; 3 4];
401%! x(:,:,2) = [1 1; 1 1];
402%! z(:,:,1) = [7 10; 15 22];
403%! z(:,:,2) = [2 2; 2 2];
404%! assert (blkmm (x,x), z);
405%! assert (blkmm (single (x), single (x)), single (z));
406%! assert (blkmm (x, single (x)), single (z));
407
408%!test
409%! x(:,:,1) = [1 2; 3 4];
410%! x(:,:,2) = [1i 1i; 1i 1i];
411%! z(:,:,1) = [7 10; 15 22];
412%! z(:,:,2) = [-2 -2; -2 -2];
413%! assert (blkmm (x,x), z);
414%! assert (blkmm (single (x), single (x)), single (z));
415%! assert (blkmm (x, single (x)), single (z));
416
417%!test <*54261>
418%! x = ones (0, 3, 3);
419%! y = ones (3, 5, 3);
420%! z = blkmm (x,y);
421%! assert (size (z), [0, 5, 3]);
422%! x = ones (1, 3, 3);
423%! y = ones (3, 0, 3);
424%! z = blkmm (x,y);
425%! assert (size (z), [1, 0, 3]);
426
427## Test input validation
428%!error blkmm ()
429%!error blkmm (1)
430%!error blkmm (1,2,3)
431%!error <A and B must be numeric> blkmm ({1,2}, [3,4])
432%!error <A and B must be numeric> blkmm ([3,4], {1,2})
433%!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
434*/
435
436OCTAVE_NAMESPACE_END
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:23
const T * data(void) const
Size of the specified dimension.
Definition: Array.h:616
OCTARRAY_API T * fortran_vec(void)
Size of the specified dimension.
Definition: Array.cc:1744
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
OCTAVE_API std::string str(char sep='x') const
Definition: dim-vector.cc:68
static dim_vector alloc(int n)
Definition: dim-vector.h:202
bool isvector(void) const
Definition: dim-vector.h:395
int first_non_singleton(int def=0) const
Definition: dim-vector.h:444
octave_idx_type ndims(void) const
Number of dimensions.
Definition: dim-vector.h:257
OCTAVE_API dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
Definition: dim-vector.cc:226
bool empty(void) const
Definition: ovl.h:115
bool isnumeric(void) const
Definition: ov.h:795
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:923
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:616
@ op_el_mul
Definition: ov.h:103
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:904
bool is_single_type(void) const
Definition: ov.h:743
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:927
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:907
bool iscomplex(void) const
Definition: ov.h:786
bool isfloat(void) const
Definition: ov.h:746
dim_vector dims(void) const
Definition: ov.h:586
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
OCTAVE_EXPORT octave_value_list Fsum(const octave_value_list &args, int)
Definition: data.cc:3039
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:23
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:23
T do_blkmm(const octave_value &xov, const octave_value &yov)
Definition: dot.cc:333
static OCTAVE_NAMESPACE_BEGIN void get_red_dims(const dim_vector &x, const dim_vector &y, int dim, dim_vector &z, F77_INT &m, F77_INT &n, F77_INT &k)
Definition: dot.cc:41
static void blkmm_internal(const T &x, const T &y, T &z, F77_INT m, F77_INT n, F77_INT k, F77_INT np)
static void get_blkmm_dims(const dim_vector &dimx, const dim_vector &dimy, F77_INT &m, F77_INT &n, F77_INT &k, F77_INT &np, dim_vector &dimz)
Definition: dot.cc:299
void error(const char *fmt,...)
Definition: error.cc:980
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:313
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:45
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:319
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5934
F77_RET_T const F77_DBLE * x
OCTINTERP_API octave_value binary_op(type_info &ti, octave_value::binary_op op, const octave_value &a, const octave_value &b)
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:23
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:23
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:23
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:23