GNU Octave 10.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 
Loading...
Searching...
No Matches
Sparse-diag-op-defs.h
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 2009-2025 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if ! defined (octave_Sparse_diag_op_defs_h)
27#define octave_Sparse_diag_op_defs_h 1
28
29#include "octave-config.h"
30
31#include "lo-array-errwarn.h"
32
33// Matrix multiplication
34
35template <typename RT, typename DM, typename SM>
36RT do_mul_dm_sm (const DM& d, const SM& a)
37{
38 const octave_idx_type nr = d.rows ();
39 const octave_idx_type nc = d.cols ();
40
41 const octave_idx_type a_nr = a.rows ();
42 const octave_idx_type a_nc = a.cols ();
43
44 if (nc != a_nr)
45 octave::err_nonconformant ("operator *", nr, nc, a_nr, a_nc);
46
47 RT r (nr, a_nc, a.nnz ());
48
49 octave_idx_type l = 0;
50
51 for (octave_idx_type j = 0; j < a_nc; j++)
52 {
53 r.xcidx (j) = l;
54 const octave_idx_type colend = a.cidx (j+1);
55 for (octave_idx_type k = a.cidx (j); k < colend; k++)
56 {
57 const octave_idx_type i = a.ridx (k);
58 if (i >= nr) break;
59 r.xdata (l) = d.dgelem (i) * a.data (k);
60 r.xridx (l) = i;
61 l++;
62 }
63 }
64
65 r.xcidx (a_nc) = l;
66
67 r.maybe_compress (true);
68 return r;
69}
70
71template <typename RT, typename SM, typename DM>
72RT do_mul_sm_dm (const SM& a, const DM& d)
73{
74 const octave_idx_type nr = d.rows ();
75 const octave_idx_type nc = d.cols ();
76
77 const octave_idx_type a_nr = a.rows ();
78 const octave_idx_type a_nc = a.cols ();
79
80 if (nr != a_nc)
81 octave::err_nonconformant ("operator *", a_nr, a_nc, nr, nc);
82
83 const octave_idx_type mnc = (nc < a_nc ? nc: a_nc);
84 RT r (a_nr, nc, a.cidx (mnc));
85
86 for (octave_idx_type j = 0; j < mnc; ++j)
87 {
88 const typename DM::element_type s = d.dgelem (j);
89 const octave_idx_type colend = a.cidx (j+1);
90 r.xcidx (j) = a.cidx (j);
91 for (octave_idx_type k = a.cidx (j); k < colend; ++k)
92 {
93 r.xdata (k) = s * a.data (k);
94 r.xridx (k) = a.ridx (k);
95 }
96 }
97 for (octave_idx_type j = mnc; j <= nc; ++j)
98 r.xcidx (j) = a.cidx (mnc);
99
100 r.maybe_compress (true);
101 return r;
102}
103
// FIXME: functors such as this should be gathered somewhere

// Function object that returns its argument unchanged.
//
// The argument_type / result_type typedefs are kept for code that inspects
// them.
template <typename T>
struct identity_val
{
public:
  typedef T argument_type;
  typedef T result_type;

  // const-qualified so the functor can be invoked through a const object or
  // const reference.
  T operator () (const T x) const { return x; }
};
113
114// Matrix addition
115
116template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
117RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
118{
119 using std::min;
120 const octave_idx_type nr = d.rows ();
121 const octave_idx_type nc = d.cols ();
122 const octave_idx_type n = min (nr, nc);
123
124 const octave_idx_type a_nr = a.rows ();
125 const octave_idx_type a_nc = a.cols ();
126
127 const octave_idx_type nz = a.nnz ();
128 RT r (a_nr, a_nc, nz + n);
129 octave_idx_type k = 0;
130
131 for (octave_idx_type j = 0; j < nc; ++j)
132 {
133 octave_quit ();
134 const octave_idx_type colend = a.cidx (j+1);
135 r.xcidx (j) = k;
136 octave_idx_type k_src = a.cidx (j), k_split;
137
138 for (k_split = k_src; k_split < colend; k_split++)
139 if (a.ridx (k_split) >= j)
140 break;
141
142 for (; k_src < k_split; k_src++, k++)
143 {
144 r.xridx (k) = a.ridx (k_src);
145 r.xdata (k) = opa (a.data (k_src));
146 }
147
148 if (k_src < colend && a.ridx (k_src) == j)
149 {
150 r.xridx (k) = j;
151 r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
152 k++; k_src++;
153 }
154 else
155 {
156 r.xridx (k) = j;
157 r.xdata (k) = opd (d.dgelem (j));
158 k++;
159 }
160
161 for (; k_src < colend; k_src++, k++)
162 {
163 r.xridx (k) = a.ridx (k_src);
164 r.xdata (k) = opa (a.data (k_src));
165 }
166
167 }
168 r.xcidx (nc) = k;
169
170 r.maybe_compress (true);
171 return r;
172}
173
174template <typename RT, typename DM, typename SM>
175RT do_commutative_add_dm_sm (const DM& d, const SM& a)
176{
177 // Extra function to ensure this is only emitted once.
178 return inner_do_add_sm_dm<RT> (a, d,
181}
182
183template <typename RT, typename DM, typename SM>
184RT do_add_dm_sm (const DM& d, const SM& a)
185{
186 if (a.rows () != d.rows () || a.cols () != d.cols ())
187 octave::err_nonconformant ("operator +",
188 d.rows (), d.cols (), a.rows (), a.cols ());
189 else
190 return do_commutative_add_dm_sm<RT> (d, a);
191}
192
193template <typename RT, typename DM, typename SM>
194RT do_sub_dm_sm (const DM& d, const SM& a)
195{
196 if (a.rows () != d.rows () || a.cols () != d.cols ())
197 octave::err_nonconformant ("operator -",
198 d.rows (), d.cols (), a.rows (), a.cols ());
199
200 return inner_do_add_sm_dm<RT> (a, d,
201 std::negate<typename SM::element_type> (),
203}
204
205template <typename RT, typename SM, typename DM>
206RT do_add_sm_dm (const SM& a, const DM& d)
207{
208 if (a.rows () != d.rows () || a.cols () != d.cols ())
209 octave::err_nonconformant ("operator +",
210 a.rows (), a.cols (), d.rows (), d.cols ());
211
212 return do_commutative_add_dm_sm<RT> (d, a);
213}
214
215template <typename RT, typename SM, typename DM>
216RT do_sub_sm_dm (const SM& a, const DM& d)
217{
218 if (a.rows () != d.rows () || a.cols () != d.cols ())
219 octave::err_nonconformant ("operator -",
220 a.rows (), a.cols (), d.rows (), d.cols ());
221
222 return inner_do_add_sm_dm<RT> (a, d,
224 std::negate<typename DM::element_type> ());
225}
226
227#endif
RT inner_do_add_sm_dm(const SM &a, const DM &d, OpA opa, OpD opd)
RT do_sub_sm_dm(const SM &a, const DM &d)
RT do_add_sm_dm(const SM &a, const DM &d)
RT do_commutative_add_dm_sm(const DM &d, const SM &a)
RT do_add_dm_sm(const DM &d, const SM &a)
RT do_sub_dm_sm(const DM &d, const SM &a)
RT do_mul_sm_dm(const SM &a, const DM &d)
RT do_mul_dm_sm(const DM &d, const SM &a)
charNDArray min(char d, const charNDArray &m)
Definition chNDArray.cc:207
F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE * d
F77_RET_T const F77_DBLE * x
T operator()(const T x)