GNU Octave  6.2.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
Sparse-diag-op-defs.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if ! defined (octave_Sparse_diag_op_defs_h)
27 #define octave_Sparse_diag_op_defs_h 1
28 
29 #include "octave-config.h"
30 
31 #include "lo-array-errwarn.h"
32 
33 // Matrix multiplication
34 
35 template <typename RT, typename DM, typename SM>
36 RT do_mul_dm_sm (const DM& d, const SM& a)
37 {
38  const octave_idx_type nr = d.rows ();
39  const octave_idx_type nc = d.cols ();
40 
41  const octave_idx_type a_nr = a.rows ();
42  const octave_idx_type a_nc = a.cols ();
43 
44  if (nc != a_nr)
45  octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc);
46 
47  RT r (nr, a_nc, a.nnz ());
48 
49  octave_idx_type l = 0;
50 
51  for (octave_idx_type j = 0; j < a_nc; j++)
52  {
53  r.xcidx (j) = l;
54  const octave_idx_type colend = a.cidx (j+1);
55  for (octave_idx_type k = a.cidx (j); k < colend; k++)
56  {
57  const octave_idx_type i = a.ridx (k);
58  if (i >= nr) break;
59  r.xdata (l) = d.dgelem (i) * a.data (k);
60  r.xridx (l) = i;
61  l++;
62  }
63  }
64 
65  r.xcidx (a_nc) = l;
66 
67  r.maybe_compress (true);
68  return r;
69 }
70 
71 template <typename RT, typename SM, typename DM>
72 RT do_mul_sm_dm (const SM& a, const DM& d)
73 {
74  const octave_idx_type nr = d.rows ();
75  const octave_idx_type nc = d.cols ();
76 
77  const octave_idx_type a_nr = a.rows ();
78  const octave_idx_type a_nc = a.cols ();
79 
80  if (nr != a_nc)
81  octave::err_nonconformant ("operator *", a_nr, a_nc, nr, nc);
82 
83  const octave_idx_type mnc = (nc < a_nc ? nc: a_nc);
84  RT r (a_nr, nc, a.cidx (mnc));
85 
86  for (octave_idx_type j = 0; j < mnc; ++j)
87  {
88  const typename DM::element_type s = d.dgelem (j);
89  const octave_idx_type colend = a.cidx (j+1);
90  r.xcidx (j) = a.cidx (j);
91  for (octave_idx_type k = a.cidx (j); k < colend; ++k)
92  {
93  r.xdata (k) = s * a.data (k);
94  r.xridx (k) = a.ridx (k);
95  }
96  }
97  for (octave_idx_type j = mnc; j <= nc; ++j)
98  r.xcidx (j) = a.cidx (mnc);
99 
100  r.maybe_compress (true);
101  return r;
102 }
103 
// FIXME: functors such as this should be gathered somewhere

// Function object that returns its argument unchanged.  It can be used
// wherever an element-wise transform (such as the OpA/OpD parameters of
// inner_do_add_sm_dm) should leave values as they are.
//
// std::unary_function was deprecated in C++11 and removed in C++17.  The
// argument_type/result_type typedefs it used to supply are declared
// directly, so code that relied on them keeps compiling.
template <typename T>
struct identity_val
{
  typedef T argument_type;
  typedef T result_type;

  // const: the functor is stateless, so it can also be called through a
  // const object or reference.
  T operator () (const T x) const { return x; }
};
111 
112 // Matrix addition
113 
114 template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
115 RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
116 {
117  using std::min;
118  const octave_idx_type nr = d.rows ();
119  const octave_idx_type nc = d.cols ();
120  const octave_idx_type n = min (nr, nc);
121 
122  const octave_idx_type a_nr = a.rows ();
123  const octave_idx_type a_nc = a.cols ();
124 
125  const octave_idx_type nz = a.nnz ();
126  RT r (a_nr, a_nc, nz + n);
127  octave_idx_type k = 0;
128 
129  for (octave_idx_type j = 0; j < nc; ++j)
130  {
131  octave_quit ();
132  const octave_idx_type colend = a.cidx (j+1);
133  r.xcidx (j) = k;
134  octave_idx_type k_src = a.cidx (j), k_split;
135 
136  for (k_split = k_src; k_split < colend; k_split++)
137  if (a.ridx (k_split) >= j)
138  break;
139 
140  for (; k_src < k_split; k_src++, k++)
141  {
142  r.xridx (k) = a.ridx (k_src);
143  r.xdata (k) = opa (a.data (k_src));
144  }
145 
146  if (k_src < colend && a.ridx (k_src) == j)
147  {
148  r.xridx (k) = j;
149  r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
150  k++; k_src++;
151  }
152  else
153  {
154  r.xridx (k) = j;
155  r.xdata (k) = opd (d.dgelem (j));
156  k++;
157  }
158 
159  for (; k_src < colend; k_src++, k++)
160  {
161  r.xridx (k) = a.ridx (k_src);
162  r.xdata (k) = opa (a.data (k_src));
163  }
164 
165  }
166  r.xcidx (nc) = k;
167 
168  r.maybe_compress (true);
169  return r;
170 }
171 
172 template <typename RT, typename DM, typename SM>
173 RT do_commutative_add_dm_sm (const DM& d, const SM& a)
174 {
175  // Extra function to ensure this is only emitted once.
176  return inner_do_add_sm_dm<RT> (a, d,
179 }
180 
181 template <typename RT, typename DM, typename SM>
182 RT do_add_dm_sm (const DM& d, const SM& a)
183 {
184  if (a.rows () != d.rows () || a.cols () != d.cols ())
185  octave::err_nonconformant ("operator +",
186  d.rows (), d.cols (), a.rows (), a.cols ());
187  else
188  return do_commutative_add_dm_sm<RT> (d, a);
189 }
190 
191 template <typename RT, typename DM, typename SM>
192 RT do_sub_dm_sm (const DM& d, const SM& a)
193 {
194  if (a.rows () != d.rows () || a.cols () != d.cols ())
195  octave::err_nonconformant ("operator -",
196  d.rows (), d.cols (), a.rows (), a.cols ());
197 
198  return inner_do_add_sm_dm<RT> (a, d,
199  std::negate<typename SM::element_type> (),
201 }
202 
203 template <typename RT, typename SM, typename DM>
204 RT do_add_sm_dm (const SM& a, const DM& d)
205 {
206  if (a.rows () != d.rows () || a.cols () != d.cols ())
207  octave::err_nonconformant ("operator +",
208  a.rows (), a.cols (), d.rows (), d.cols ());
209 
210  return do_commutative_add_dm_sm<RT> (d, a);
211 }
212 
213 template <typename RT, typename SM, typename DM>
214 RT do_sub_sm_dm (const SM& a, const DM& d)
215 {
216  if (a.rows () != d.rows () || a.cols () != d.cols ())
217  octave::err_nonconformant ("operator -",
218  a.rows (), a.cols (), d.rows (), d.cols ());
219 
220  return inner_do_add_sm_dm<RT> (a, d,
222  std::negate<typename DM::element_type> ());
223 }
224 
225 #endif
RT inner_do_add_sm_dm(const SM &a, const DM &d, OpA opa, OpD opd)
RT do_sub_sm_dm(const SM &a, const DM &d)
RT do_add_sm_dm(const SM &a, const DM &d)
RT do_commutative_add_dm_sm(const DM &d, const SM &a)
RT do_add_dm_sm(const DM &d, const SM &a)
RT do_sub_dm_sm(const DM &d, const SM &a)
RT do_mul_sm_dm(const SM &a, const DM &d)
RT do_mul_dm_sm(const DM &d, const SM &a)
charNDArray min(char d, const charNDArray &m)
Definition: chNDArray.cc:207
F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE * d
F77_RET_T const F77_DBLE * x
octave_idx_type n
Definition: mx-inlines.cc:753
T * r
Definition: mx-inlines.cc:773
void err_nonconformant(const char *op, octave_idx_type op1_len, octave_idx_type op2_len)
T operator()(const T x)