GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
Sparse-diag-op-defs.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if ! defined (octave_Sparse_diag_op_defs_h)
27 #define octave_Sparse_diag_op_defs_h 1
28 
29 #include "octave-config.h"
30 
31 #include "lo-array-errwarn.h"
32 
33 // Matrix multiplication
34 
35 template <typename RT, typename DM, typename SM>
36 RT do_mul_dm_sm (const DM& d, const SM& a)
37 {
38  const octave_idx_type nr = d.rows ();
39  const octave_idx_type nc = d.cols ();
40 
41  const octave_idx_type a_nr = a.rows ();
42  const octave_idx_type a_nc = a.cols ();
43 
44  if (nc != a_nr)
45  octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc);
46 
47  RT r (nr, a_nc, a.nnz ());
48 
49  octave_idx_type l = 0;
50 
51  for (octave_idx_type j = 0; j < a_nc; j++)
52  {
53  r.xcidx (j) = l;
54  const octave_idx_type colend = a.cidx (j+1);
55  for (octave_idx_type k = a.cidx (j); k < colend; k++)
56  {
57  const octave_idx_type i = a.ridx (k);
58  if (i >= nr) break;
59  r.xdata (l) = d.dgelem (i) * a.data (k);
60  r.xridx (l) = i;
61  l++;
62  }
63  }
64 
65  r.xcidx (a_nc) = l;
66 
67  r.maybe_compress (true);
68  return r;
69 }
70 
71 template <typename RT, typename SM, typename DM>
72 RT do_mul_sm_dm (const SM& a, const DM& d)
73 {
74  const octave_idx_type nr = d.rows ();
75  const octave_idx_type nc = d.cols ();
76 
77  const octave_idx_type a_nr = a.rows ();
78  const octave_idx_type a_nc = a.cols ();
79 
80  if (nr != a_nc)
81  octave::err_nonconformant ("operator *", a_nr, a_nc, nr, nc);
82 
83  const octave_idx_type mnc = (nc < a_nc ? nc: a_nc);
84  RT r (a_nr, nc, a.cidx (mnc));
85 
86  for (octave_idx_type j = 0; j < mnc; ++j)
87  {
88  const typename DM::element_type s = d.dgelem (j);
89  const octave_idx_type colend = a.cidx (j+1);
90  r.xcidx (j) = a.cidx (j);
91  for (octave_idx_type k = a.cidx (j); k < colend; ++k)
92  {
93  r.xdata (k) = s * a.data (k);
94  r.xridx (k) = a.ridx (k);
95  }
96  }
97  for (octave_idx_type j = mnc; j <= nc; ++j)
98  r.xcidx (j) = a.cidx (mnc);
99 
100  r.maybe_compress (true);
101  return r;
102 }
103 
104 // FIXME: functors such as this should be gathered somewhere
// Function object that returns its argument unchanged.  It can be passed
// where a unary operation on matrix elements is expected but no change
// of the value is wanted.
//
// argument_type and result_type name the element type T.
template <typename T>
class identity_val
{
public:
  typedef T argument_type;
  typedef T result_type;

  // Return X itself.  Declared const so that const functor objects can
  // be called as well.
  T operator () (const T x) const { return x; }
};
113 
114 // Matrix addition
115 
116 template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
117 RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
118 {
119  using std::min;
120  const octave_idx_type nr = d.rows ();
121  const octave_idx_type nc = d.cols ();
122  const octave_idx_type n = min (nr, nc);
123 
124  const octave_idx_type a_nr = a.rows ();
125  const octave_idx_type a_nc = a.cols ();
126 
127  const octave_idx_type nz = a.nnz ();
128  RT r (a_nr, a_nc, nz + n);
129  octave_idx_type k = 0;
130 
131  for (octave_idx_type j = 0; j < nc; ++j)
132  {
133  octave_quit ();
134  const octave_idx_type colend = a.cidx (j+1);
135  r.xcidx (j) = k;
136  octave_idx_type k_src = a.cidx (j), k_split;
137 
138  for (k_split = k_src; k_split < colend; k_split++)
139  if (a.ridx (k_split) >= j)
140  break;
141 
142  for (; k_src < k_split; k_src++, k++)
143  {
144  r.xridx (k) = a.ridx (k_src);
145  r.xdata (k) = opa (a.data (k_src));
146  }
147 
148  if (k_src < colend && a.ridx (k_src) == j)
149  {
150  r.xridx (k) = j;
151  r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
152  k++; k_src++;
153  }
154  else
155  {
156  r.xridx (k) = j;
157  r.xdata (k) = opd (d.dgelem (j));
158  k++;
159  }
160 
161  for (; k_src < colend; k_src++, k++)
162  {
163  r.xridx (k) = a.ridx (k_src);
164  r.xdata (k) = opa (a.data (k_src));
165  }
166 
167  }
168  r.xcidx (nc) = k;
169 
170  r.maybe_compress (true);
171  return r;
172 }
173 
174 template <typename RT, typename DM, typename SM>
175 RT do_commutative_add_dm_sm (const DM& d, const SM& a)
176 {
177  // Extra function to ensure this is only emitted once.
178  return inner_do_add_sm_dm<RT> (a, d,
181 }
182 
183 template <typename RT, typename DM, typename SM>
184 RT do_add_dm_sm (const DM& d, const SM& a)
185 {
186  if (a.rows () != d.rows () || a.cols () != d.cols ())
187  octave::err_nonconformant ("operator +",
188  d.rows (), d.cols (), a.rows (), a.cols ());
189  else
190  return do_commutative_add_dm_sm<RT> (d, a);
191 }
192 
193 template <typename RT, typename DM, typename SM>
194 RT do_sub_dm_sm (const DM& d, const SM& a)
195 {
196  if (a.rows () != d.rows () || a.cols () != d.cols ())
197  octave::err_nonconformant ("operator -",
198  d.rows (), d.cols (), a.rows (), a.cols ());
199 
200  return inner_do_add_sm_dm<RT> (a, d,
201  std::negate<typename SM::element_type> (),
203 }
204 
205 template <typename RT, typename SM, typename DM>
206 RT do_add_sm_dm (const SM& a, const DM& d)
207 {
208  if (a.rows () != d.rows () || a.cols () != d.cols ())
209  octave::err_nonconformant ("operator +",
210  a.rows (), a.cols (), d.rows (), d.cols ());
211 
212  return do_commutative_add_dm_sm<RT> (d, a);
213 }
214 
215 template <typename RT, typename SM, typename DM>
216 RT do_sub_sm_dm (const SM& a, const DM& d)
217 {
218  if (a.rows () != d.rows () || a.cols () != d.cols ())
219  octave::err_nonconformant ("operator -",
220  a.rows (), a.cols (), d.rows (), d.cols ());
221 
222  return inner_do_add_sm_dm<RT> (a, d,
224  std::negate<typename DM::element_type> ());
225 }
226 
227 #endif
RT inner_do_add_sm_dm(const SM &a, const DM &d, OpA opa, OpD opd)
RT do_sub_sm_dm(const SM &a, const DM &d)
RT do_add_sm_dm(const SM &a, const DM &d)
RT do_commutative_add_dm_sm(const DM &d, const SM &a)
RT do_add_dm_sm(const DM &d, const SM &a)
RT do_sub_dm_sm(const DM &d, const SM &a)
RT do_mul_sm_dm(const SM &a, const DM &d)
RT do_mul_dm_sm(const DM &d, const SM &a)
charNDArray min(char d, const charNDArray &m)
Definition: chNDArray.cc:207
void err_nonconformant(const char *op, octave_idx_type op1_len, octave_idx_type op2_len)
F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE * d
F77_RET_T const F77_DBLE * x
octave_idx_type n
Definition: mx-inlines.cc:761
T * r
Definition: mx-inlines.cc:781
T operator()(const T x)