GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
MSparse.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1998-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 // sparse array with math ops.
27 
28 // Element by element MSparse by MSparse ops.
29 
30 template <typename T, typename OP>
32 plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char *op_name)
33 {
34  MSparse<T> r;
35 
36  octave_idx_type a_nr = a.rows ();
37  octave_idx_type a_nc = a.cols ();
38 
39  octave_idx_type b_nr = b.rows ();
40  octave_idx_type b_nc = b.cols ();
41 
42  if (a_nr != b_nr || a_nc != b_nc)
43  octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
44 
45  r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
46 
47  octave_idx_type jx = 0;
48  for (octave_idx_type i = 0 ; i < a_nc ; i++)
49  {
50  octave_idx_type ja = a.cidx (i);
51  octave_idx_type ja_max = a.cidx (i+1);
52  bool ja_lt_max = ja < ja_max;
53 
54  octave_idx_type jb = b.cidx (i);
55  octave_idx_type jb_max = b.cidx (i+1);
56  bool jb_lt_max = jb < jb_max;
57 
58  while (ja_lt_max || jb_lt_max)
59  {
60  octave_quit ();
61  if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
62  {
63  r.ridx (jx) = a.ridx (ja);
64  r.data (jx) = op (a.data (ja), 0.);
65  jx++;
66  ja++;
67  ja_lt_max= ja < ja_max;
68  }
69  else if ((! ja_lt_max)
70  || (b.ridx (jb) < a.ridx (ja)))
71  {
72  r.ridx (jx) = b.ridx (jb);
73  r.data (jx) = op (0., b.data (jb));
74  jx++;
75  jb++;
76  jb_lt_max= jb < jb_max;
77  }
78  else
79  {
80  if (op (a.data (ja), b.data (jb)) != 0.)
81  {
82  r.data (jx) = op (a.data (ja), b.data (jb));
83  r.ridx (jx) = a.ridx (ja);
84  jx++;
85  }
86  ja++;
87  ja_lt_max= ja < ja_max;
88  jb++;
89  jb_lt_max= jb < jb_max;
90  }
91  }
92  r.cidx (i+1) = jx;
93  }
94 
95  a = r.maybe_compress ();
96 
97  return a;
98 }
99 
100 template <typename T>
101 MSparse<T>&
103 {
104  return plus_or_minus (a, b, std::plus<T> (), "operator +=");
105 }
106 
107 template <typename T>
108 MSparse<T>&
110 {
111  return plus_or_minus (a, b, std::minus<T> (), "operator -=");
112 }
113 
114 // Element by element MSparse by scalar ops.
115 
116 template <typename T, typename OP>
117 MArray<T>
118 plus_or_minus (const MSparse<T>& a, const T& s, OP op)
119 {
120  octave_idx_type nr = a.rows ();
121  octave_idx_type nc = a.cols ();
122 
123  MArray<T> r (dim_vector (nr, nc), op (0.0, s));
124 
125  for (octave_idx_type j = 0; j < nc; j++)
126  for (octave_idx_type i = a.cidx (j); i < a.cidx (j+1); i++)
127  r.elem (a.ridx (i), j) = op (a.data (i), s);
128  return r;
129 }
130 
131 template <typename T>
132 MArray<T>
133 operator + (const MSparse<T>& a, const T& s)
134 {
135  return plus_or_minus (a, s, std::plus<T> ());
136 }
137 
138 template <typename T>
139 MArray<T>
140 operator - (const MSparse<T>& a, const T& s)
141 {
142  return plus_or_minus (a, s, std::minus<T> ());
143 }
144 
145 template <typename T, typename OP>
147 times_or_divide (const MSparse<T>& a, const T& s, OP op)
148 {
149  octave_idx_type nr = a.rows ();
150  octave_idx_type nc = a.cols ();
151  octave_idx_type nz = a.nnz ();
152 
153  MSparse<T> r (nr, nc, nz);
154 
155  for (octave_idx_type i = 0; i < nz; i++)
156  {
157  r.data (i) = op (a.data (i), s);
158  r.ridx (i) = a.ridx (i);
159  }
160  for (octave_idx_type i = 0; i < nc + 1; i++)
161  r.cidx (i) = a.cidx (i);
162  r.maybe_compress (true);
163  return r;
164 }
165 
166 template <typename T>
168 operator * (const MSparse<T>& a, const T& s)
169 {
170  return times_or_divide (a, s, std::multiplies<T> ());
171 }
172 
173 template <typename T>
175 operator / (const MSparse<T>& a, const T& s)
176 {
177  return times_or_divide (a, s, std::divides<T> ());
178 }
179 
180 // Element by element scalar by MSparse ops.
181 
182 template <typename T, typename OP>
183 MArray<T>
184 plus_or_minus (const T& s, const MSparse<T>& a, OP op)
185 {
186  octave_idx_type nr = a.rows ();
187  octave_idx_type nc = a.cols ();
188 
189  MArray<T> r (dim_vector (nr, nc), op (s, 0.0));
190 
191  for (octave_idx_type j = 0; j < nc; j++)
192  for (octave_idx_type i = a.cidx (j); i < a.cidx (j+1); i++)
193  r.elem (a.ridx (i), j) = op (s, a.data (i));
194  return r;
195 }
196 
197 template <typename T>
198 MArray<T>
199 operator + (const T& s, const MSparse<T>& a)
200 {
201  return plus_or_minus (s, a, std::plus<T> ());
202 }
203 
204 template <typename T>
205 MArray<T>
206 operator - (const T& s, const MSparse<T>& a)
207 {
208  return plus_or_minus (s, a, std::minus<T> ());
209 }
210 
211 template <typename T, typename OP>
213 times_or_divides (const T& s, const MSparse<T>& a, OP op)
214 {
215  octave_idx_type nr = a.rows ();
216  octave_idx_type nc = a.cols ();
217  octave_idx_type nz = a.nnz ();
218 
219  MSparse<T> r (nr, nc, nz);
220 
221  for (octave_idx_type i = 0; i < nz; i++)
222  {
223  r.data (i) = op (s, a.data (i));
224  r.ridx (i) = a.ridx (i);
225  }
226  for (octave_idx_type i = 0; i < nc + 1; i++)
227  r.cidx (i) = a.cidx (i);
228  r.maybe_compress (true);
229  return r;
230 }
231 
232 template <typename T>
234 operator * (const T& s, const MSparse<T>& a)
235 {
236  return times_or_divides (s, a, std::multiplies<T> ());
237 }
238 
239 template <typename T>
241 operator / (const T& s, const MSparse<T>& a)
242 {
243  return times_or_divides (s, a, std::divides<T> ());
244 }
245 
246 // Element by element MSparse by MSparse ops.
247 
248 template <typename T, typename OP>
250 plus_or_minus (const MSparse<T>& a, const MSparse<T>& b, OP op,
251  const char *op_name, bool negate)
252 {
253  MSparse<T> r;
254 
255  octave_idx_type a_nr = a.rows ();
256  octave_idx_type a_nc = a.cols ();
257 
258  octave_idx_type b_nr = b.rows ();
259  octave_idx_type b_nc = b.cols ();
260 
261  if (a_nr == 1 && a_nc == 1)
262  {
263  if (a.elem (0, 0) == 0.)
264  if (negate)
265  r = -MSparse<T> (b);
266  else
267  r = MSparse<T> (b);
268  else
269  {
270  r = MSparse<T> (b_nr, b_nc, op (a.data (0), 0.));
271 
272  for (octave_idx_type j = 0 ; j < b_nc ; j++)
273  {
274  octave_quit ();
275  octave_idx_type idxj = j * b_nr;
276  for (octave_idx_type i = b.cidx (j) ; i < b.cidx (j+1) ; i++)
277  {
278  octave_quit ();
279  r.data (idxj + b.ridx (i)) = op (a.data (0), b.data (i));
280  }
281  }
282  r.maybe_compress ();
283  }
284  }
285  else if (b_nr == 1 && b_nc == 1)
286  {
287  if (b.elem (0, 0) == 0.)
288  r = MSparse<T> (a);
289  else
290  {
291  r = MSparse<T> (a_nr, a_nc, op (0.0, b.data (0)));
292 
293  for (octave_idx_type j = 0 ; j < a_nc ; j++)
294  {
295  octave_quit ();
296  octave_idx_type idxj = j * a_nr;
297  for (octave_idx_type i = a.cidx (j) ; i < a.cidx (j+1) ; i++)
298  {
299  octave_quit ();
300  r.data (idxj + a.ridx (i)) = op (a.data (i), b.data (0));
301  }
302  }
303  r.maybe_compress ();
304  }
305  }
306  else if (a_nr != b_nr || a_nc != b_nc)
307  octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
308  else
309  {
310  r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
311 
312  octave_idx_type jx = 0;
313  r.cidx (0) = 0;
314  for (octave_idx_type i = 0 ; i < a_nc ; i++)
315  {
316  octave_idx_type ja = a.cidx (i);
317  octave_idx_type ja_max = a.cidx (i+1);
318  bool ja_lt_max = ja < ja_max;
319 
320  octave_idx_type jb = b.cidx (i);
321  octave_idx_type jb_max = b.cidx (i+1);
322  bool jb_lt_max = jb < jb_max;
323 
324  while (ja_lt_max || jb_lt_max)
325  {
326  octave_quit ();
327  if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
328  {
329  r.ridx (jx) = a.ridx (ja);
330  r.data (jx) = op (a.data (ja), 0.);
331  jx++;
332  ja++;
333  ja_lt_max= ja < ja_max;
334  }
335  else if ((! ja_lt_max)
336  || (b.ridx (jb) < a.ridx (ja)))
337  {
338  r.ridx (jx) = b.ridx (jb);
339  r.data (jx) = op (0., b.data (jb));
340  jx++;
341  jb++;
342  jb_lt_max= jb < jb_max;
343  }
344  else
345  {
346  if (op (a.data (ja), b.data (jb)) != 0.)
347  {
348  r.data (jx) = op (a.data (ja), b.data (jb));
349  r.ridx (jx) = a.ridx (ja);
350  jx++;
351  }
352  ja++;
353  ja_lt_max= ja < ja_max;
354  jb++;
355  jb_lt_max= jb < jb_max;
356  }
357  }
358  r.cidx (i+1) = jx;
359  }
360 
361  r.maybe_compress ();
362  }
363 
364  return r;
365 }
366 
367 template <typename T>
369 operator + (const MSparse<T>& a, const MSparse<T>& b)
370 {
371  return plus_or_minus (a, b, std::plus<T> (), "operator +", false);
372 }
373 
374 template <typename T>
376 operator - (const MSparse<T>& a, const MSparse<T>& b)
377 {
378  return plus_or_minus (a, b, std::minus<T> (), "operator -", true);
379 }
380 
381 template <typename T>
383 product (const MSparse<T>& a, const MSparse<T>& b)
384 {
385  MSparse<T> r;
386 
387  octave_idx_type a_nr = a.rows ();
388  octave_idx_type a_nc = a.cols ();
389 
390  octave_idx_type b_nr = b.rows ();
391  octave_idx_type b_nc = b.cols ();
392 
393  if (a_nr == 1 && a_nc == 1)
394  {
395  if (a.elem (0, 0) == 0.)
396  r = MSparse<T> (b_nr, b_nc);
397  else
398  {
399  r = MSparse<T> (b);
400  octave_idx_type b_nnz = b.nnz ();
401 
402  for (octave_idx_type i = 0 ; i < b_nnz ; i++)
403  {
404  octave_quit ();
405  r.data (i) = a.data (0) * r.data (i);
406  }
407  r.maybe_compress ();
408  }
409  }
410  else if (b_nr == 1 && b_nc == 1)
411  {
412  if (b.elem (0, 0) == 0.)
413  r = MSparse<T> (a_nr, a_nc);
414  else
415  {
416  r = MSparse<T> (a);
417  octave_idx_type a_nnz = a.nnz ();
418 
419  for (octave_idx_type i = 0 ; i < a_nnz ; i++)
420  {
421  octave_quit ();
422  r.data (i) = r.data (i) * b.data (0);
423  }
424  r.maybe_compress ();
425  }
426  }
427  else if (a_nr != b_nr || a_nc != b_nc)
428  octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
429  else
430  {
431  r = MSparse<T> (a_nr, a_nc, (a.nnz () > b.nnz () ? a.nnz () : b.nnz ()));
432 
433  octave_idx_type jx = 0;
434  r.cidx (0) = 0;
435  for (octave_idx_type i = 0 ; i < a_nc ; i++)
436  {
437  octave_idx_type ja = a.cidx (i);
438  octave_idx_type ja_max = a.cidx (i+1);
439  bool ja_lt_max = ja < ja_max;
440 
441  octave_idx_type jb = b.cidx (i);
442  octave_idx_type jb_max = b.cidx (i+1);
443  bool jb_lt_max = jb < jb_max;
444 
445  while (ja_lt_max || jb_lt_max)
446  {
447  octave_quit ();
448  if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
449  {
450  ja++; ja_lt_max= ja < ja_max;
451  }
452  else if ((! ja_lt_max)
453  || (b.ridx (jb) < a.ridx (ja)))
454  {
455  jb++; jb_lt_max= jb < jb_max;
456  }
457  else
458  {
459  if ((a.data (ja) * b.data (jb)) != 0.)
460  {
461  r.data (jx) = a.data (ja) * b.data (jb);
462  r.ridx (jx) = a.ridx (ja);
463  jx++;
464  }
465  ja++; ja_lt_max= ja < ja_max;
466  jb++; jb_lt_max= jb < jb_max;
467  }
468  }
469  r.cidx (i+1) = jx;
470  }
471 
472  r.maybe_compress ();
473  }
474 
475  return r;
476 }
477 
478 template <typename T>
480 quotient (const MSparse<T>& a, const MSparse<T>& b)
481 {
482  MSparse<T> r;
483  T Zero = T ();
484 
485  octave_idx_type a_nr = a.rows ();
486  octave_idx_type a_nc = a.cols ();
487 
488  octave_idx_type b_nr = b.rows ();
489  octave_idx_type b_nc = b.cols ();
490 
491  if (a_nr == 1 && a_nc == 1)
492  {
493  T val = a.elem (0, 0);
494  T fill = val / T ();
495  if (fill == T ())
496  {
497  octave_idx_type b_nnz = b.nnz ();
498  r = MSparse<T> (b);
499  for (octave_idx_type i = 0 ; i < b_nnz ; i++)
500  r.data (i) = val / r.data (i);
501  r.maybe_compress ();
502  }
503  else
504  {
505  r = MSparse<T> (b_nr, b_nc, fill);
506  for (octave_idx_type j = 0 ; j < b_nc ; j++)
507  {
508  octave_quit ();
509  octave_idx_type idxj = j * b_nr;
510  for (octave_idx_type i = b.cidx (j) ; i < b.cidx (j+1) ; i++)
511  {
512  octave_quit ();
513  r.data (idxj + b.ridx (i)) = val / b.data (i);
514  }
515  }
516  r.maybe_compress ();
517  }
518  }
519  else if (b_nr == 1 && b_nc == 1)
520  {
521  T val = b.elem (0, 0);
522  T fill = T () / val;
523  if (fill == T ())
524  {
525  octave_idx_type a_nnz = a.nnz ();
526  r = MSparse<T> (a);
527  for (octave_idx_type i = 0 ; i < a_nnz ; i++)
528  r.data (i) = r.data (i) / val;
529  r.maybe_compress ();
530  }
531  else
532  {
533  r = MSparse<T> (a_nr, a_nc, fill);
534  for (octave_idx_type j = 0 ; j < a_nc ; j++)
535  {
536  octave_quit ();
537  octave_idx_type idxj = j * a_nr;
538  for (octave_idx_type i = a.cidx (j) ; i < a.cidx (j+1) ; i++)
539  {
540  octave_quit ();
541  r.data (idxj + a.ridx (i)) = a.data (i) / val;
542  }
543  }
544  r.maybe_compress ();
545  }
546  }
547  else if (a_nr != b_nr || a_nc != b_nc)
548  octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
549  else
550  {
551  r = MSparse<T> (a_nr, a_nc, (Zero / Zero));
552 
553  for (octave_idx_type i = 0 ; i < a_nc ; i++)
554  {
555  octave_idx_type ja = a.cidx (i);
556  octave_idx_type ja_max = a.cidx (i+1);
557  bool ja_lt_max = ja < ja_max;
558 
559  octave_idx_type jb = b.cidx (i);
560  octave_idx_type jb_max = b.cidx (i+1);
561  bool jb_lt_max = jb < jb_max;
562 
563  while (ja_lt_max || jb_lt_max)
564  {
565  octave_quit ();
566  if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
567  {
568  r.elem (a.ridx (ja), i) = a.data (ja) / Zero;
569  ja++; ja_lt_max= ja < ja_max;
570  }
571  else if ((! ja_lt_max)
572  || (b.ridx (jb) < a.ridx (ja)))
573  {
574  r.elem (b.ridx (jb), i) = Zero / b.data (jb);
575  jb++; jb_lt_max= jb < jb_max;
576  }
577  else
578  {
579  r.elem (a.ridx (ja), i) = a.data (ja) / b.data (jb);
580  ja++; ja_lt_max= ja < ja_max;
581  jb++; jb_lt_max= jb < jb_max;
582  }
583  }
584  }
585 
586  r.maybe_compress (true);
587  }
588 
589  return r;
590 }
591 
592 // Unary MSparse ops.
593 
594 template <typename T>
597 {
598  return a;
599 }
600 
601 template <typename T>
604 {
605  MSparse<T> retval (a);
606  octave_idx_type nz = a.nnz ();
607  for (octave_idx_type i = 0; i < nz; i++)
608  retval.data (i) = - retval.data (i);
609  return retval;
610 }
MSparse< T > & operator-=(MSparse< T > &a, const MSparse< T > &b)
Definition: MSparse.cc:109
MSparse< T > times_or_divides(const T &s, const MSparse< T > &a, OP op)
Definition: MSparse.cc:213
MSparse< T > & plus_or_minus(MSparse< T > &a, const MSparse< T > &b, OP op, const char *op_name)
Definition: MSparse.cc:32
MArray< T > operator-(const MSparse< T > &a, const T &s)
Definition: MSparse.cc:140
MArray< T > operator+(const MSparse< T > &a, const T &s)
Definition: MSparse.cc:133
MSparse< T > times_or_divide(const MSparse< T > &a, const T &s, OP op)
Definition: MSparse.cc:147
MSparse< T > product(const MSparse< T > &a, const MSparse< T > &b)
Definition: MSparse.cc:383
MSparse< T > quotient(const MSparse< T > &a, const MSparse< T > &b)
Definition: MSparse.cc:480
MSparse< T > & operator+=(MSparse< T > &a, const MSparse< T > &b)
Definition: MSparse.cc:102
MSparse< T > operator/(const MSparse< T > &a, const T &s)
Definition: MSparse.cc:175
MSparse< T > operator*(const MSparse< T > &a, const T &s)
Definition: MSparse.cc:168
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:63
octave_idx_type cols() const
Definition: Sparse.h:352
octave_idx_type * cidx()
Definition: Sparse.h:596
T * data()
Definition: Sparse.h:574
octave_idx_type * ridx()
Definition: Sparse.h:583
T & elem(octave_idx_type n)
Definition: Sparse.h:456
octave_idx_type nnz() const
Actual number of nonzero terms.
Definition: Sparse.h:339
octave_idx_type rows() const
Definition: Sparse.h:351
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
void err_nonconformant(const char *op, octave_idx_type op1_len, octave_idx_type op2_len)
T * r
Definition: mx-inlines.cc:781