GNU Octave 11.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 
Loading...
Searching...
No Matches
oct-convn.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 2010-2026 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if defined (HAVE_CONFIG_H)
27# include "config.h"
28#endif
29
30#include <algorithm>
31
32#include "Array-oct.h"
33#include "CColVector.h"
34#include "CMatrix.h"
35#include "CNDArray.h"
36#include "CRowVector.h"
37#include "MArray.h"
38#include "blas-proto.h"
39#include "dColVector.h"
40#include "dMatrix.h"
41#include "dNDArray.h"
42#include "dRowVector.h"
43#include "f77-fcn.h"
44#include "fCColVector.h"
45#include "fCMatrix.h"
46#include "fCNDArray.h"
47#include "fCRowVector.h"
48#include "fColVector.h"
49#include "fMatrix.h"
50#include "fNDArray.h"
51#include "fRowVector.h"
52#include "oct-convn.h"
53
55
56// Overload function "blas_axpy" to wrap BLAS ?axpy
57
58static inline void
59blas_axpy (const F77_INT& n, const double& alpha, const double *x,
60 const F77_INT& incx, double *y, const F77_INT& incy)
61{
62 F77_FUNC (daxpy, DAXPY) (n, alpha, x, incx, y, incy);
63}
64
65static inline void
66blas_axpy (const F77_INT& n, const float& alpha, const float *x,
67 const F77_INT& incx, float *y, const F77_INT& incy)
68{
69 F77_FUNC (saxpy, SAXPY) (n, alpha, x, incx, y, incy);
70}
71
72static inline void
73blas_axpy (const F77_INT& n, const Complex& alpha,
74 const Complex *x, const F77_INT& incx,
75 Complex *y, const F77_INT& incy)
76{
77 F77_FUNC (zaxpy, ZAXPY) (n, *F77_CONST_DBLE_CMPLX_ARG (&alpha),
79 F77_DBLE_CMPLX_ARG (y), incy);
80}
81
82static inline void
83blas_axpy (const F77_INT& n, const FloatComplex& alpha,
84 const FloatComplex *x, const F77_INT& incx,
85 FloatComplex *y, const F77_INT& incy)
86{
87 F77_FUNC (caxpy, CAXPY) (n, *F77_CONST_CMPLX_ARG (&alpha),
88 F77_CONST_CMPLX_ARG (x), incx,
89 F77_CMPLX_ARG (y), incy);
90}
91
92
93// 2-D convolution with a matrix kernel.
94template <typename T, typename R>
95static void
96convolve_2d (const T *a, F77_INT ma, F77_INT na,
97 const R *b, F77_INT mb, F77_INT nb,
98 T *c, bool inner)
99{
100 if (inner)
101 {
102 // Inner convolution ("valid")
103 const F77_INT len = ma - mb + 1; // Pre-calculate length
104 for (F77_INT k = 0; k < na - nb + 1; k++)
105 for (F77_INT j = 0; j < nb; j++)
106 for (F77_INT i = 0; i < mb; i++)
107 {
108 // Create a T value from R
109 T b_val = static_cast<T>(b[i + j*mb]);
110
111 // Call the appropriate blas_axpy function based on type T
112 blas_axpy (len, b_val, &a[mb-i-1 + (k+nb-j-1)*ma], 1,
113 &c[k*len], 1);
114 }
115 }
116 else
117 {
118 // Outer convolution ("full")
119 const F77_INT len = ma + mb - 1; // Pre-calculate length
120 for (F77_INT k = 0; k < na; k++)
121 for (F77_INT j = 0; j < nb; j++)
122 for (F77_INT i = 0; i < mb; i++)
123 {
124 // Create a T value from R
125 T b_val = static_cast<T>(b[i + j*mb]);
126
127 // Call the appropriate blas_axpy function based on type T
128 blas_axpy (ma, b_val, &a[k*ma], 1, &c[i + (j+k)*len], 1);
129 }
130 }
131}
132
133template <typename T, typename R>
134void
135convolve_nd (const T *a, const dim_vector& ad, const dim_vector& acd,
136 const R *b, const dim_vector& bd, const dim_vector& bcd,
137 T *c, const dim_vector& ccd, int nd, bool inner)
138{
139 if (nd == 2)
140 {
141 F77_INT ad0 = to_f77_int (ad(0));
142 F77_INT ad1 = to_f77_int (ad(1));
143
144 F77_INT bd0 = to_f77_int (bd(0));
145 F77_INT bd1 = to_f77_int (bd(1));
146
147 convolve_2d<T, R> (a, ad0, ad1, b, bd0, bd1, c, inner);
148 }
149 else
150 {
151 octave_idx_type ma = acd(nd-2);
152 octave_idx_type na = ad(nd-1);
153 octave_idx_type mb = bcd(nd-2);
154 octave_idx_type nb = bd(nd-1);
155 octave_idx_type ldc = ccd(nd-2);
156
157 if (inner)
158 {
159 for (octave_idx_type ja = 0; ja < na - nb + 1; ja++)
160 for (octave_idx_type jb = 0; jb < nb; jb++)
161 convolve_nd<T, R> (a + ma*(ja+jb), ad, acd,
162 b + mb*(nb-jb-1), bd, bcd,
163 c + ldc*ja, ccd, nd-1, inner);
164 }
165 else
166 {
167 for (octave_idx_type ja = 0; ja < na; ja++)
168 for (octave_idx_type jb = 0; jb < nb; jb++)
169 convolve_nd<T, R> (a + ma*ja, ad, acd, b + mb*jb, bd, bcd,
170 c + ldc*(ja+jb), ccd, nd-1, inner);
171 }
172 }
173}
174
175// Arbitrary convolutor.
176template <typename T, typename R>
177static MArray<T>
178convolve (const MArray<T>& a, const MArray<R>& b, convn_type ct)
179{
180 if (a.isempty () || b.isempty ())
181 return MArray<T> ();
182
183 const int nd = std::max (a.ndims (), b.ndims ());
184 const dim_vector adims = a.dims ().redim (nd);
185 dim_vector apdims = a.dims ().redim (nd); // permuted adims
186 const dim_vector bdims = b.dims ().redim (nd);
187 dim_vector cdims = dim_vector::alloc (nd);
188
189 for (int i = 0; i < nd; i++)
190 {
191 if (ct == convn_valid)
192 cdims(i) = std::max (adims(i) - bdims(i) + 1,
193 static_cast<octave_idx_type> (0));
194 else
195 cdims(i) = std::max (adims(i) + bdims(i) - 1,
196 static_cast<octave_idx_type> (0));
197 }
198
199 // "valid" shape can sometimes result in empty matrices which must avoid
200 // calling Fortran code which does not expect this (bug #52067)
201 if (cdims.numel () == 0)
202 return MArray<T> (cdims);
203
204 // Permute dimensions of a/b/c such that the dimensions of a are ordered
205 // by decreasing number of elements (for efficiency in Fortran loops).
206 Array<octave_idx_type> order (dim_vector (1, nd));
207 for (int i = 0; i < nd; i++)
208 order(i) = i;
209
210 // Since the number of dimensions is nearly always small, it is faster
211 // to sort them inline instead of calling octave_sort::sort ().
212 bool reordered = false;
213 for (int i = 0; i < nd; i++)
214 for (int j = (i+1); j < nd; j++)
215 if (apdims(i) < apdims(j))
216 {
217 std::swap (apdims(i), apdims(j));
218 std::swap (cdims(i), cdims(j));
219 std::swap (order(i), order(j));
220 reordered = true;
221 }
222
223 // Initialize output based on the current order of cdims.
224 MArray<T> c (cdims, T ());
225
226 if (reordered) // adims was reordered, so the inputs must be as well.
227 {
228 // Permute the inputs
229 const MArray<T> ap = a.permute (order);
230 const MArray<R> bp = b.permute (order);
231 const dim_vector bpdims = bp.dims ().redim (nd);
232
233 // Do convolution on the permuted arrays.
234 convolve_nd<T, R> (ap.data (), apdims, apdims.cumulative (),
235 bp.data (), bpdims, bpdims.cumulative (),
236 c.rwdata (), cdims.cumulative (),
237 nd, ct == convn_valid);
238
239 // Permute back to original order.
240 c = c.ipermute (order);
241 }
242 else // No reordering ==> no need to create permuted arrays.
243 {
244 // Do convolution on the original arrays.
245 convolve_nd<T, R> (a.data (), adims, adims.cumulative (),
246 b.data (), bdims, bdims.cumulative (),
247 c.rwdata (), cdims.cumulative (),
248 nd, ct == convn_valid);
249 }
250
251 if (ct == convn_same)
252 {
253 // Pick the relevant part.
254 Array<idx_vector> sidx (dim_vector (nd, 1));
255
256 for (int i = 0; i < nd; i++)
257 sidx(i) = idx_vector::make_range (bdims(i)/2, 1, adims(i));
258 c = c.index (sidx);
259 }
260
261 return c;
262}
263
// CONV_DEFS (TPREF, RPREF) defines the three public convn overloads for
// data arrays of class TPREF##NDArray / TPREF##Matrix and kernels of
// class RPREF##NDArray / RPREF##Matrix.  The two-vector form convolves
// with the kernel formed by the column-times-row product c * r.  All
// overloads delegate to the convolve template.
#define CONV_DEFS(TPREF, RPREF)                                         \
  TPREF ## NDArray                                                      \
  convn (const TPREF ## NDArray& a,                                     \
         const RPREF ## NDArray& b,                                     \
         convn_type ct)                                                 \
  {                                                                     \
    return convolve (a, b, ct);                                         \
  }                                                                     \
                                                                        \
  TPREF ## Matrix                                                       \
  convn (const TPREF ## Matrix& a,                                      \
         const RPREF ## Matrix& b,                                      \
         convn_type ct)                                                 \
  {                                                                     \
    return convolve (a, b, ct);                                         \
  }                                                                     \
                                                                        \
  TPREF ## Matrix                                                       \
  convn (const TPREF ## Matrix& a,                                      \
         const RPREF ## ColumnVector& c,                                \
         const RPREF ## RowVector& r,                                   \
         convn_type ct)                                                 \
  {                                                                     \
    return convolve (a, c * r, ct);                                     \
  }
283
287CONV_DEFS (Float, Float)
290
291OCTAVE_END_NAMESPACE(octave)
N Dimensional Array with copy-on-write semantics.
Definition Array-base.h:130
const dim_vector & dims() const
Return a const-reference so that dims ()(i) works efficiently.
Definition Array-base.h:529
int ndims() const
Number of dimensions of the array.
Definition Array-base.h:701
bool isempty() const
True if the array has no elements.
Definition Array-base.h:674
const T * data() const
Pointer to the array's element data (read-only).
Definition Array-base.h:687
Template for N-dimensional array classes with like-type math operators.
Definition MArray.h:61
MArray< T > permute(const Array< octave_idx_type > &vec, bool inv=false) const
Definition MArray.h:95
Vector representing the dimensions (size) of an Array.
Definition dim-vector.h:92
octave_idx_type numel(int n=0) const
Number of elements that a matrix with this dimensions would have.
Definition dim-vector.h:341
static dim_vector alloc(int n)
Definition dim-vector.h:208
dim_vector cumulative() const
Return cumulative dimensions.
Definition dim-vector.h:525
dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
static idx_vector make_range(octave_idx_type start, octave_idx_type step, octave_idx_type len)
Definition idx-vector.h:450
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
#define F77_CONST_CMPLX_ARG(x)
Definition f77-fcn.h:313
#define F77_DBLE_CMPLX_ARG(x)
Definition f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition f77-fcn.h:310
octave_f77_int_type F77_INT
Definition f77-fcn.h:306
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition f77-fcn.h:319
std::complex< double > Complex
Definition oct-cmplx.h:33
std::complex< float > FloatComplex
Definition oct-cmplx.h:34
void convolve_nd(const T *a, const dim_vector &ad, const dim_vector &acd, const R *b, const dim_vector &bd, const dim_vector &bcd, T *c, const dim_vector &ccd, int nd, bool inner)
Definition oct-convn.cc:135
#define CONV_DEFS(TPREF, RPREF)
Definition oct-convn.cc:264
convn_type
Definition oct-convn.h:52
@ convn_same
Definition oct-convn.h:54
@ convn_valid
Definition oct-convn.h:55
F77_RET_T const F77_DBLE * x
F77_RET_T F77_FUNC(xerbla, XERBLA)(F77_CONST_CHAR_ARG_DEF(s_arg
F77_RET_T len
Definition xerbla.cc:61