GNU Octave 11.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 
Loading...
Searching...
No Matches
MSparse.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 1998-2026 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26// sparse array with math ops.
27// Element by element MSparse by MSparse ops.
28
29template <typename T, typename OP>
31plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char *op_name)
32{
33 MSparse<T> r;
34
35 octave_idx_type a_nr = a.rows ();
36 octave_idx_type a_nc = a.cols ();
37
38 octave_idx_type b_nr = b.rows ();
39 octave_idx_type b_nc = b.cols ();
40
41 if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
42 {
43 if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
44 r.resize (a_nr, std::max (a_nc, b_nc));
45 else
46 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
47 }
48 else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
49 {
50 if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
51 r.resize (std::max (a_nr, b_nr), a_nc);
52 else
53 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
54 }
55 else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
56 {
57 if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
58 r.resize (b_nr, std::max (a_nc, b_nc));
59 else
60 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
61 }
62 else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
63 {
64 if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
65 r.resize (std::max (a_nr, b_nr), b_nc);
66 else
67 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
68 }
69 else if (a_nr != b_nr || a_nc != b_nc)
70 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
71 else
72 {
73 r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
74
75 octave_idx_type jx = 0;
76 for (octave_idx_type i = 0 ; i < a_nc ; i++)
77 {
78 octave_idx_type ja = a.cidx (i);
79 octave_idx_type ja_max = a.cidx (i+1);
80 bool ja_lt_max = ja < ja_max;
81
82 octave_idx_type jb = b.cidx (i);
83 octave_idx_type jb_max = b.cidx (i+1);
84 bool jb_lt_max = jb < jb_max;
85
86 while (ja_lt_max || jb_lt_max)
87 {
88 octave_quit ();
89 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
90 {
91 r.ridx (jx) = a.ridx (ja);
92 r.data (jx) = op (a.data (ja), 0.);
93 jx++;
94 ja++;
95 ja_lt_max= ja < ja_max;
96 }
97 else if ((! ja_lt_max)
98 || (b.ridx (jb) < a.ridx (ja)))
99 {
100 r.ridx (jx) = b.ridx (jb);
101 r.data (jx) = op (0., b.data (jb));
102 jx++;
103 jb++;
104 jb_lt_max= jb < jb_max;
105 }
106 else
107 {
108 if (op (a.data (ja), b.data (jb)) != 0.)
109 {
110 r.data (jx) = op (a.data (ja), b.data (jb));
111 r.ridx (jx) = a.ridx (ja);
112 jx++;
113 }
114 ja++;
115 ja_lt_max= ja < ja_max;
116 jb++;
117 jb_lt_max= jb < jb_max;
118 }
119 }
120 r.cidx (i+1) = jx;
121 }
122 }
123 a = r.maybe_compress ();
124 return a;
125}
126
127template <typename T>
130{
131 return plus_or_minus (a, b, std::plus<T> (), "operator +=");
132}
133
134template <typename T>
137{
138 return plus_or_minus (a, b, std::minus<T> (), "operator -=");
139}
140
141// Element by element MSparse by scalar ops.
142
143template <typename T, typename OP>
145plus_or_minus (const MSparse<T>& a, const T& s, OP op)
146{
147 octave_idx_type nr = a.rows ();
148 octave_idx_type nc = a.cols ();
149
150 MArray<T> r;
151
152 if (octave::math::isnan (s))
153 r = MArray<T> (dim_vector (nr, nc), octave::numeric_limits<T>::NaN ());
154 else
155 {
156 r = MArray<T> (dim_vector (nr, nc), op (0.0, s));
157
158 for (octave_idx_type j = 0; j < nc; j++)
159 for (octave_idx_type i = a.cidx (j); i < a.cidx (j+1); i++)
160 r.elem (a.ridx (i), j) = op (a.data (i), s);
161 }
162 return r;
163}
164
165template <typename T>
167operator + (const MSparse<T>& a, const T& s)
168{
169 return plus_or_minus (a, s, std::plus<T> ());
170}
171
172template <typename T>
174operator - (const MSparse<T>& a, const T& s)
175{
176 return plus_or_minus (a, s, std::minus<T> ());
177}
178
179template <typename T, typename OP>
181times_or_divide (const MSparse<T>& a, const T& s, OP op)
182{
183 octave_idx_type nr = a.rows ();
184 octave_idx_type nc = a.cols ();
185 octave_idx_type nz = a.nnz ();
186
187 MSparse<T> r;
188
189 // Handle scalars that affect zero elements in sparse array
190 bool non_zero = op (0.0, s) != 0.0;
191
192 if (non_zero)
193 {
194 r = MSparse<T> (nr, nc, nr * nc);
195 for (octave_idx_type j = 0; j < nc; j++)
196 for (octave_idx_type i = 0; i < nr; i++)
197 r.elem (i, j) = op (0.0, s);
198
199 for (octave_idx_type j = 0; j < nc; j++)
200 for (octave_idx_type i = a.cidx(j); i < a.cidx (j+1); i++)
201 r.elem (a.ridx (i), j) = op (a.data (i), s);
202 r.maybe_compress (true);
203 }
204 else
205 {
206 r = MSparse<T> (nr, nc, nz);
207
208 for (octave_idx_type i = 0; i < nz; i++)
209 {
210 r.data (i) = op (a.data (i), s);
211 r.ridx (i) = a.ridx (i);
212 }
213 for (octave_idx_type i = 0; i < nc + 1; i++)
214 r.cidx (i) = a.cidx (i);
215 r.maybe_compress (true);
216 }
217 return r;
218}
219
220template <typename T>
222operator * (const MSparse<T>& a, const T& s)
223{
224 return times_or_divide (a, s, std::multiplies<T> ());
225}
226
227template <typename T>
229operator / (const MSparse<T>& a, const T& s)
230{
231 return times_or_divide (a, s, std::divides<T> ());
232}
233
234// Element by element scalar by MSparse ops.
235
236template <typename T, typename OP>
238plus_or_minus (const T& s, const MSparse<T>& a, OP op)
239{
240 octave_idx_type nr = a.rows ();
241 octave_idx_type nc = a.cols ();
242
243 MArray<T> r;
244
245 if (octave::math::isnan (s))
246 r = MArray<T> (dim_vector (nr, nc), octave::numeric_limits<T>::NaN ());
247 else
248 {
249 r = MArray<T> (dim_vector (nr, nc), op (s, 0.0));
250
251 for (octave_idx_type j = 0; j < nc; j++)
252 for (octave_idx_type i = a.cidx (j); i < a.cidx (j+1); i++)
253 r.elem (a.ridx (i), j) = op (s, a.data (i));
254 }
255 return r;
256}
257
258template <typename T>
260operator + (const T& s, const MSparse<T>& a)
261{
262 return plus_or_minus (s, a, std::plus<T> ());
263}
264
265template <typename T>
267operator - (const T& s, const MSparse<T>& a)
268{
269 return plus_or_minus (s, a, std::minus<T> ());
270}
271
272template <typename T, typename OP>
274times_or_divide (const T& s, const MSparse<T>& a, OP op)
275{
276 octave_idx_type nr = a.rows ();
277 octave_idx_type nc = a.cols ();
278 octave_idx_type nz = a.nnz ();
279
280 MSparse<T> r;
281
282 // Handle scalars that affect zero elements in sparse array
283 bool non_zero = op (s, 0.0) != 0.0;
284
285 if (non_zero)
286 {
287 r = MSparse<T> (nr, nc, nr * nc);
288 for (octave_idx_type j = 0; j < nc; j++)
289 for (octave_idx_type i = 0; i < nr; i++)
290 r.elem (i, j) = op (s, 0.0);
291
292 for (octave_idx_type j = 0; j < nc; j++)
293 for (octave_idx_type i = a.cidx(j); i < a.cidx (j+1); i++)
294 r.elem (a.ridx (i), j) = op (s, a.data (i));
295 r.maybe_compress (true);
296 }
297 else
298 {
299 r = MSparse<T> (nr, nc, nz);
300
301 for (octave_idx_type i = 0; i < nz; i++)
302 {
303 r.data (i) = op (s, a.data (i));
304 r.ridx (i) = a.ridx (i);
305 }
306 for (octave_idx_type i = 0; i < nc + 1; i++)
307 r.cidx (i) = a.cidx (i);
308 r.maybe_compress (true);
309 }
310 return r;
311}
312
313template <typename T>
315operator * (const T& s, const MSparse<T>& a)
316{
317 return times_or_divide (s, a, std::multiplies<T> ());
318}
319
320template <typename T>
322operator / (const T& s, const MSparse<T>& a)
323{
324 return times_or_divide (s, a, std::divides<T> ());
325}
326
327// Element by element MSparse by MSparse ops.
328
329template <typename T, typename OP>
331plus_or_minus (const MSparse<T>& a, const MSparse<T>& b, OP op,
332 const char *op_name, bool negate)
333{
334 MSparse<T> r;
335
336 octave_idx_type a_nr = a.rows ();
337 octave_idx_type a_nc = a.cols ();
338
339 octave_idx_type b_nr = b.rows ();
340 octave_idx_type b_nc = b.cols ();
341
342 if (a_nr == 1 && a_nc == 1)
343 {
344 if (a.elem (0, 0) == 0.)
345 if (negate)
346 r = -MSparse<T> (b);
347 else
348 r = MSparse<T> (b);
349 else
350 {
351 r = MSparse<T> (b_nr, b_nc, op (a.data (0), 0.));
352
353 for (octave_idx_type j = 0 ; j < b_nc ; j++)
354 {
355 octave_quit ();
356 octave_idx_type idxj = j * b_nr;
357 for (octave_idx_type i = b.cidx (j) ; i < b.cidx (j+1) ; i++)
358 {
359 octave_quit ();
360 r.data (idxj + b.ridx (i)) = op (a.data (0), b.data (i));
361 }
362 }
363 r.maybe_compress ();
364 }
365 }
366 else if (b_nr == 1 && b_nc == 1)
367 {
368 if (b.elem (0, 0) == 0.)
369 r = MSparse<T> (a);
370 else
371 {
372 r = MSparse<T> (a_nr, a_nc, op (0.0, b.data (0)));
373
374 for (octave_idx_type j = 0 ; j < a_nc ; j++)
375 {
376 octave_quit ();
377 octave_idx_type idxj = j * a_nr;
378 for (octave_idx_type i = a.cidx (j) ; i < a.cidx (j+1) ; i++)
379 {
380 octave_quit ();
381 r.data (idxj + a.ridx (i)) = op (a.data (i), b.data (0));
382 }
383 }
384 r.maybe_compress ();
385 }
386 }
387 else if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
388 {
389 if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
390 r.resize (a_nr, std::max (a_nc, b_nc));
391 else
392 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
393 }
394 else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
395 {
396 if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
397 r.resize (std::max (a_nr, b_nr), a_nc);
398 else
399 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
400 }
401 else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
402 {
403 if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
404 r.resize (b_nr, std::max (a_nc, b_nc));
405 else
406 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
407 }
408 else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
409 {
410 if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
411 r.resize (std::max (a_nr, b_nr), b_nc);
412 else
413 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
414 }
415 else if (a_nr == b_nr && a_nc == b_nc)
416 {
417 r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
418
419 octave_idx_type jx = 0;
420 r.cidx (0) = 0;
421 for (octave_idx_type i = 0 ; i < a_nc ; i++)
422 {
423 octave_idx_type ja = a.cidx (i);
424 octave_idx_type ja_max = a.cidx (i+1);
425 bool ja_lt_max = ja < ja_max;
426
427 octave_idx_type jb = b.cidx (i);
428 octave_idx_type jb_max = b.cidx (i+1);
429 bool jb_lt_max = jb < jb_max;
430
431 while (ja_lt_max || jb_lt_max)
432 {
433 octave_quit ();
434 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
435 {
436 r.ridx (jx) = a.ridx (ja);
437 r.data (jx) = op (a.data (ja), 0.);
438 jx++;
439 ja++;
440 ja_lt_max = ja < ja_max;
441 }
442 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
443 {
444 r.ridx (jx) = b.ridx (jb);
445 r.data (jx) = op (0., b.data (jb));
446 jx++;
447 jb++;
448 jb_lt_max = jb < jb_max;
449 }
450 else
451 {
452 if (op (a.data (ja), b.data (jb)) != 0.)
453 {
454 r.data (jx) = op (a.data (ja), b.data (jb));
455 r.ridx (jx) = a.ridx (ja);
456 jx++;
457 }
458 ja++;
459 ja_lt_max = ja < ja_max;
460 jb++;
461 jb_lt_max = jb < jb_max;
462 }
463 }
464 r.cidx (i+1) = jx;
465 }
466 r.maybe_compress ();
467 }
468 else if (a_nr == b_nr && (a_nc == 1 || b_nc == 1))
469 // (a_nc == b_nc && (a_nr == 1 || b_nr == 1)) is handled
470 // by double transpose in the caller functions
471 {
472 octave_idx_type r_nc = (a_nc < b_nc ? b_nc : a_nc);
473 octave_idx_type rnnz = (a_nc < b_nc ? a.nnz () * b_nc + b.nnz () :
474 a.nnz () + a_nc * b.nnz ());
475 r = MSparse<T> (a_nr, r_nc, rnnz);
476
477 octave_idx_type jx = 0;
478 r.cidx (0) = 0;
479 for (octave_idx_type i = 0 ; i < r_nc ; i++)
480 {
482 octave_idx_type ja_max;
484 octave_idx_type jb_max;
485 if (a_nc == 1)
486 {
487 ja = a.cidx(0);
488 ja_max = a.cidx(1);
489 jb = b.cidx (i);
490 jb_max = b.cidx (i+1);
491 }
492 else
493 {
494 ja = a.cidx(i);
495 ja_max = a.cidx(i+1);
496 jb = b.cidx (0);
497 jb_max = b.cidx (1);
498 }
499 bool ja_lt_max = ja < ja_max;
500 bool jb_lt_max = jb < jb_max;
501
502 while (ja_lt_max || jb_lt_max)
503 {
504 octave_quit ();
505 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
506 {
507 r.ridx (jx) = a.ridx (ja);
508 r.data (jx) = op (a.data (ja), 0.);
509 jx++;
510 ja++;
511 ja_lt_max = ja < ja_max;
512 }
513 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
514 {
515 r.ridx (jx) = b.ridx (jb);
516 r.data (jx) = op (0., b.data (jb));
517 jx++;
518 jb++;
519 jb_lt_max = jb < jb_max;
520 }
521 else
522 {
523 if (op (a.data (ja), b.data (jb)) != 0.)
524 {
525 r.data (jx) = op (a.data (ja), b.data (jb));
526 r.ridx (jx) = a.ridx (ja);
527 jx++;
528 }
529 ja++;
530 ja_lt_max = ja < ja_max;
531 jb++;
532 jb_lt_max = jb < jb_max;
533 }
534 }
535 r.cidx (i+1) = jx;
536 }
537 r.maybe_compress ();
538 }
539 else if (a_nr == 1 && b_nc == 1)
540 // (a_nc == 1 && b_nr == 1) is handled
541 // by double transpose in the caller functions
542 {
543 // a: 1 x a_nc (row vector)
544 // b: b_nr x 1 (column vector)
545 //
546 // Result: outer broadcast
547 // r(i,j) = op(a(0,j), b(i,0))
548
549 r = MSparse<T> (b_nr, a_nc, (a.nnz () * b_nr + b.nnz () * a_nc));
550
551 for (octave_idx_type j = 0; j < a_nc; j++)
552 {
553 const T a_val = a.elem (0, j);
554
555 for (octave_idx_type i = 0; i < b_nr; i++)
556 {
557 octave_quit ();
558
559 const T b_val = b.elem (i, 0);
560 const T val = op (a_val, b_val);
561
562 if (val != T ())
563 r.elem (i, j) = val;
564 }
565 }
566 r.maybe_compress ();
567 }
568 else
569 octave::err_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
570
571 return r;
572}
573
574template <typename T>
577{
578 if ((a.cols () != 1 && a.cols () == b.cols () &&
579 (a.rows () == 1 || b.rows () == 1)) ||
580 (a.cols () == 1 && b.rows () == 1 && a.rows () > 1 && b.cols () > 1))
581 return plus_or_minus (a.transpose (), b.transpose (), std::plus<T> (),
582 "operator +", false).transpose ();
583 else
584 return plus_or_minus (a, b, std::plus<T> (), "operator +", false);
585}
586
587template <typename T>
590{
591 if ((a.cols () != 1 && a.cols () == b.cols () &&
592 (a.rows () == 1 || b.rows () == 1)) ||
593 (a.cols () == 1 && b.rows () == 1 && a.rows () > 1 && b.cols () > 1))
594 return plus_or_minus (a.transpose (), b.transpose (), std::minus<T> (),
595 "operator -", true).transpose ();
596 else
597 return plus_or_minus (a, b, std::minus<T> (), "operator -", true);
598}
599
600template <typename T>
602product (const MSparse<T>& a, const MSparse<T>& b)
603{
604 MSparse<T> r;
605
606 octave_idx_type a_nr = a.rows ();
607 octave_idx_type a_nc = a.cols ();
608
609 octave_idx_type b_nr = b.rows ();
610 octave_idx_type b_nc = b.cols ();
611
612 if (a_nr == 1 && a_nc == 1)
613 {
614 if (a.elem (0, 0) == 0.)
615 r = MSparse<T> (b_nr, b_nc);
616 else if (octave::math::isnan (a.elem (0, 0)))
617 {
618 r = MSparse<T> (b_nr, b_nc, b_nr * b_nc);
619 for (octave_idx_type i = 0 ; i < r.numel () ; i++)
620 r.elem(i) = octave::numeric_limits<T>::NaN ();
621 }
622 else if (octave::math::isinf (a.elem (0, 0)))
623 {
624 r = MSparse<T> (b_nr, b_nc, b_nr * b_nc);
625
626 for (octave_idx_type j = 0 ; j < b_nc ; j++)
627 {
628 octave_quit ();
629 for (octave_idx_type i = 0 ; i < b_nr ; i++)
630 {
631 if (b.elem (i, j) == 0.0)
632 r.elem (i, j) = octave::numeric_limits<T>::NaN ();
633 else
634 r.elem (i, j) = a.elem (0, 0) * b.elem (i, j);
635 }
636 }
637 r.maybe_compress (true);
638 }
639 else
640 {
641 r = MSparse<T> (b);
642 octave_idx_type b_nnz = b.nnz ();
643
644 for (octave_idx_type i = 0 ; i < b_nnz ; i++)
645 {
646 octave_quit ();
647 r.data (i) = a.data (0) * r.data (i);
648 }
649 r.maybe_compress ();
650 }
651 }
652 else if (b_nr == 1 && b_nc == 1)
653 {
654 if (b.elem (0, 0) == 0.)
655 r = MSparse<T> (a_nr, a_nc);
656 else if (octave::math::isnan (b.elem (0, 0)))
657 {
658 r = MSparse<T> (a_nr, a_nc, a_nr * a_nc);
659 for (octave_idx_type i = 0 ; i < r.numel () ; i++)
660 r.elem(i) = octave::numeric_limits<T>::NaN ();
661 }
662 else if (octave::math::isinf (b.elem (0, 0)))
663 {
664 r = MSparse<T> (a_nr, a_nc, a_nr * a_nc);
665
666 for (octave_idx_type j = 0 ; j < a_nc ; j++)
667 {
668 octave_quit ();
669 for (octave_idx_type i = 0 ; i < a_nr ; i++)
670 {
671 if (a.elem (i, j) == 0.0)
672 r.elem (i, j) = octave::numeric_limits<T>::NaN ();
673 else
674 r.elem (i, j) = a.elem (i, j) * b.elem (0, 0);
675 }
676 }
677 r.maybe_compress (true);
678 }
679 else
680 {
681 r = MSparse<T> (a);
682 octave_idx_type a_nnz = a.nnz ();
683
684 for (octave_idx_type i = 0 ; i < a_nnz ; i++)
685 {
686 octave_quit ();
687 r.data (i) = r.data (i) * b.data (0);
688 }
689 r.maybe_compress ();
690 }
691 }
692 else if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
693 {
694 if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
695 r.resize (a_nr, std::max (a_nc, b_nc));
696 else
697 octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
698 }
699 else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
700 {
701 if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
702 r.resize (std::max (a_nr, b_nr), a_nc);
703 else
704 octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
705 }
706 else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
707 {
708 if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
709 r.resize (b_nr, std::max (a_nc, b_nc));
710 else
711 octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
712 }
713 else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
714 {
715 if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
716 r.resize (std::max (a_nr, b_nr), b_nc);
717 else
718 octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
719 }
720 else if (a_nr == b_nr && a_nc == b_nc &&
722 {
723 r = MSparse<T> (a_nr, a_nc, a_nr * a_nc);
724
725 for (octave_idx_type j = 0 ; j < a_nc ; j++)
726 {
727 octave_quit ();
728 for (octave_idx_type i = 0 ; i < a_nr ; i++)
729 r.elem (i, j) = a.elem (i, j) * b.elem (i, j);
730 }
731 r.maybe_compress (true);
732 }
733 else if (a_nr == b_nr && a_nc == b_nc)
734 {
735 r = MSparse<T> (a_nr, a_nc, (a.nnz () > b.nnz () ? a.nnz () : b.nnz ()));
736
737 octave_idx_type jx = 0;
738 r.cidx (0) = 0;
739 for (octave_idx_type i = 0 ; i < a_nc ; i++)
740 {
741 octave_idx_type ja = a.cidx (i);
742 octave_idx_type ja_max = a.cidx (i+1);
743 bool ja_lt_max = ja < ja_max;
744
745 octave_idx_type jb = b.cidx (i);
746 octave_idx_type jb_max = b.cidx (i+1);
747 bool jb_lt_max = jb < jb_max;
748
749 while (ja_lt_max || jb_lt_max)
750 {
751 octave_quit ();
752 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
753 {
754 ja++;
755 ja_lt_max = ja < ja_max;
756 }
757 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
758 {
759 jb++;
760 jb_lt_max = jb < jb_max;
761 }
762 else
763 {
764 if ((a.data (ja) * b.data (jb)) != 0.)
765 {
766 r.data (jx) = a.data (ja) * b.data (jb);
767 r.ridx (jx) = a.ridx (ja);
768 jx++;
769 }
770 ja++;
771 ja_lt_max = ja < ja_max;
772 jb++;
773 jb_lt_max = jb < jb_max;
774 }
775 }
776 r.cidx (i+1) = jx;
777 }
778 r.maybe_compress ();
779 }
780 else if (a_nr == b_nr && (a_nc == 1 || b_nc == 1) &&
782 {
783 octave_idx_type r_nc = (a_nc < b_nc ? b_nc : a_nc);
784 r = MSparse<T> (a_nr, r_nc, a_nr * r_nc);
785
786 if (a_nc == 1)
787 {
788 for (octave_idx_type j = 0 ; j < r_nc ; j++)
789 {
790 octave_quit ();
791 for (octave_idx_type i = 0 ; i < a_nr ; i++)
792 r.elem (i, j) = a.elem (i, 0) * b.elem (i, j);
793 }
794 }
795 else
796 {
797 for (octave_idx_type j = 0 ; j < r_nc ; j++)
798 {
799 octave_quit ();
800 for (octave_idx_type i = 0 ; i < a_nr ; i++)
801 r.elem (i, j) = a.elem (i, j) * b.elem (i, 0);
802 }
803 }
804 r.maybe_compress (true);
805 }
806 else if (a_nr == b_nr && (a_nc == 1 || b_nc == 1))
807 {
808 octave_idx_type r_nc = (a_nc < b_nc ? b_nc : a_nc);
809 octave_idx_type rnnz = (a_nc < b_nc ? a.nnz () * b_nc + b.nnz () :
810 a.nnz () + a_nc * b.nnz ());
811 r = MSparse<T> (a_nr, r_nc, rnnz);
812
813 octave_idx_type jx = 0;
814 r.cidx (0) = 0;
815 for (octave_idx_type i = 0 ; i < r_nc ; i++)
816 {
818 octave_idx_type ja_max;
820 octave_idx_type jb_max;
821 if (a_nc == 1)
822 {
823 ja = a.cidx(0);
824 ja_max = a.cidx(1);
825 jb = b.cidx (i);
826 jb_max = b.cidx (i+1);
827 }
828 else
829 {
830 ja = a.cidx(i);
831 ja_max = a.cidx(i+1);
832 jb = b.cidx (0);
833 jb_max = b.cidx (1);
834 }
835 bool ja_lt_max = ja < ja_max;
836 bool jb_lt_max = jb < jb_max;
837
838 while (ja_lt_max || jb_lt_max)
839 {
840 octave_quit ();
841 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
842 {
843 ja++;
844 ja_lt_max = ja < ja_max;
845 }
846 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
847 {
848 jb++;
849 jb_lt_max = jb < jb_max;
850 }
851 else
852 {
853 if ((a.data (ja) * b.data (jb)) != 0.)
854 {
855 r.data (jx) = a.data (ja) * b.data (jb);
856 r.ridx (jx) = a.ridx (ja);
857 jx++;
858 }
859 ja++;
860 ja_lt_max = ja < ja_max;
861 jb++;
862 jb_lt_max = jb < jb_max;
863 }
864 }
865 r.cidx (i+1) = jx;
866 }
867 r.maybe_compress ();
868 }
869 else if (a_nc == b_nc && (a_nr == 1 || b_nr == 1))
870 r = product (a.transpose (), b.transpose ()).transpose ();
871 else if (a_nr == 1 && b_nc == 1 &&
873 {
874 r = MSparse<T> (b_nr, a_nc, b_nr * a_nc);
875
876 for (octave_idx_type j = 0 ; j < a_nc ; j++)
877 {
878 octave_quit ();
879 for (octave_idx_type i = 0 ; i < b_nr ; i++)
880 r.elem (i, j) = a.elem (0, j) * b.elem (i, 0);
881 }
882 r.maybe_compress (true);
883 }
884 else if (a_nr == 1 && b_nc == 1)
885 {
886 r = MSparse<T> (b_nr, a_nc, (a.nnz () * b_nr + b.nnz () * a_nc));
887
888 octave_idx_type jx = 0;
889 r.cidx (0) = 0;
890 for (octave_idx_type i = 0 ; i < a_nc ; i++)
891 {
892 octave_idx_type ja = a.cidx (i);
893 octave_idx_type ja_max = a.cidx (i+1);
894 bool ja_lt_max = ja < ja_max;
895
896 octave_idx_type jb = b.cidx (0);
897 octave_idx_type jb_max = b.cidx (1);
898 bool jb_lt_max = jb < jb_max;
899
900 while (ja_lt_max || jb_lt_max)
901 {
902 octave_quit ();
903 if (! ja_lt_max && jb_lt_max)
904 {
905 jb++;
906 jb_lt_max = jb < jb_max;
907 }
908 else if (ja_lt_max && ! jb_lt_max)
909 {
910 ja_lt_max = false;
911 }
912 else // (ja_lt_max && jb_lt_max)
913 {
914 if ((a.data (ja) * b.data (jb)) != 0.)
915 {
916 r.data (jx) = a.data (ja) * b.data (jb);
917 r.ridx (jx) = b.ridx (jb);
918 jx++;
919 }
920 jb++;
921 jb_lt_max = jb < jb_max;
922 ja_lt_max = jb_lt_max;
923 }
924 }
925 r.cidx (i+1) = jx;
926 }
927 r.maybe_compress ();
928 }
929 else if (a_nc == 1 && b_nr == 1)
930 r = product (a.transpose (), b.transpose ()).transpose ();
931 else
932 octave::err_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
933
934 return r;
935}
936
937template <typename T>
939quotient (const MSparse<T>& a, const MSparse<T>& b)
940{
941 MSparse<T> r;
942 T Zero = T ();
943
944 octave_idx_type a_nr = a.rows ();
945 octave_idx_type a_nc = a.cols ();
946
947 octave_idx_type b_nr = b.rows ();
948 octave_idx_type b_nc = b.cols ();
949
950 if (a_nr == 1 && a_nc == 1)
951 {
952 T val = a.elem (0, 0);
953 T fill = val / T ();
954 if (fill == T ())
955 {
956 octave_idx_type b_nnz = b.nnz ();
957 r = MSparse<T> (b);
958 for (octave_idx_type i = 0 ; i < b_nnz ; i++)
959 r.data (i) = val / r.data (i);
960 r.maybe_compress (true);
961 }
962 else
963 {
964 r = MSparse<T> (b_nr, b_nc, fill);
965 for (octave_idx_type j = 0 ; j < b_nc ; j++)
966 {
967 octave_quit ();
968 octave_idx_type idxj = j * b_nr;
969 for (octave_idx_type i = b.cidx (j) ; i < b.cidx (j+1) ; i++)
970 {
971 octave_quit ();
972 r.data (idxj + b.ridx (i)) = val / b.data (i);
973 }
974 }
975 r.maybe_compress (true);
976 }
977 }
978 else if (b_nr == 1 && b_nc == 1)
979 {
980 T val = b.elem (0, 0);
981 T fill = T () / val;
982 if (fill == T ())
983 {
984 octave_idx_type a_nnz = a.nnz ();
985 r = MSparse<T> (a);
986 for (octave_idx_type i = 0 ; i < a_nnz ; i++)
987 r.data (i) = r.data (i) / val;
988 r.maybe_compress (true);
989 }
990 else
991 {
992 r = MSparse<T> (a_nr, a_nc, fill);
993 for (octave_idx_type j = 0 ; j < a_nc ; j++)
994 {
995 octave_quit ();
996 octave_idx_type idxj = j * a_nr;
997 for (octave_idx_type i = a.cidx (j) ; i < a.cidx (j+1) ; i++)
998 {
999 octave_quit ();
1000 r.data (idxj + a.ridx (i)) = a.data (i) / val;
1001 }
1002 }
1003 r.maybe_compress (true);
1004 }
1005 }
1006 else if (a_nr == 0 && (b_nr == 0 || b_nr == 1))
1007 {
1008 if (a_nc == 1 || b_nc == 1 || a_nc == b_nc)
1009 r.resize (a_nr, std::max (a_nc, b_nc));
1010 else
1011 octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
1012 }
1013 else if (a_nc == 0 && (b_nc == 0 || b_nc == 1))
1014 {
1015 if (a_nr == 1 || b_nr == 1 || a_nr == b_nr)
1016 r.resize (std::max (a_nr, b_nr), a_nc);
1017 else
1018 octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
1019 }
1020 else if (b_nr == 0 && (a_nr == 0 || a_nr == 1))
1021 {
1022 if (b_nc == 1 || a_nc == 1 || b_nc == a_nc)
1023 r.resize (b_nr, std::max (a_nc, b_nc));
1024 else
1025 octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
1026 }
1027 else if (b_nc == 0 && (a_nc == 0 || a_nc == 1))
1028 {
1029 if (b_nr == 1 || a_nr == 1 || b_nr == a_nr)
1030 r.resize (std::max (a_nr, b_nr), b_nc);
1031 else
1032 octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
1033 }
1034 else if (a_nr == b_nr && a_nc == b_nc)
1035 {
1036 r = MSparse<T> (a_nr, a_nc, (Zero / Zero));
1037
1038 for (octave_idx_type i = 0 ; i < a_nc ; i++)
1039 {
1040 octave_idx_type ja = a.cidx (i);
1041 octave_idx_type ja_max = a.cidx (i+1);
1042 bool ja_lt_max = ja < ja_max;
1043
1044 octave_idx_type jb = b.cidx (i);
1045 octave_idx_type jb_max = b.cidx (i+1);
1046 bool jb_lt_max = jb < jb_max;
1047
1048 while (ja_lt_max || jb_lt_max)
1049 {
1050 octave_quit ();
1051 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
1052 {
1053 r.elem (a.ridx (ja), i) = a.data (ja) / Zero;
1054 ja++;
1055 ja_lt_max = ja < ja_max;
1056 }
1057 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
1058 {
1059 r.elem (b.ridx (jb), i) = Zero / b.data (jb);
1060 jb++;
1061 jb_lt_max = jb < jb_max;
1062 }
1063 else
1064 {
1065 r.elem (a.ridx (ja), i) = a.data (ja) / b.data (jb);
1066 ja++;
1067 ja_lt_max = ja < ja_max;
1068 jb++;
1069 jb_lt_max = jb < jb_max;
1070 }
1071 }
1072 }
1073 r.maybe_compress (true);
1074 }
1075 else if (a_nr == b_nr && (a_nc == 1 || b_nc == 1))
1076 {
1077 octave_idx_type r_nc = (a_nc < b_nc ? b_nc : a_nc);
1078 r = MSparse<T> (a_nr, r_nc, (Zero / Zero));
1079
1080 for (octave_idx_type i = 0 ; i < r_nc ; i++)
1081 {
1082 octave_idx_type ja;
1083 octave_idx_type ja_max;
1084 octave_idx_type jb;
1085 octave_idx_type jb_max;
1086 if (a_nc == 1)
1087 {
1088 ja = a.cidx(0);
1089 ja_max = a.cidx(1);
1090 jb = b.cidx (i);
1091 jb_max = b.cidx (i+1);
1092 }
1093 else
1094 {
1095 ja = a.cidx(i);
1096 ja_max = a.cidx(i+1);
1097 jb = b.cidx (0);
1098 jb_max = b.cidx (1);
1099 }
1100 bool ja_lt_max = ja < ja_max;
1101 bool jb_lt_max = jb < jb_max;
1102
1103 while (ja_lt_max || jb_lt_max)
1104 {
1105 octave_quit ();
1106 if ((! jb_lt_max) || (ja_lt_max && (a.ridx (ja) < b.ridx (jb))))
1107 {
1108 r.elem (a.ridx (ja), i) = a.data (ja) / Zero;
1109 ja++;
1110 ja_lt_max = ja < ja_max;
1111 }
1112 else if ((! ja_lt_max) || (b.ridx (jb) < a.ridx (ja)))
1113 {
1114 r.elem (b.ridx (jb), i) = Zero / b.data (jb);
1115 jb++;
1116 jb_lt_max = jb < jb_max;
1117 }
1118 else
1119 {
1120 r.elem (a.ridx (ja), i) = a.data (ja) / b.data (jb);
1121 ja++;
1122 ja_lt_max = ja < ja_max;
1123 jb++;
1124 jb_lt_max = jb < jb_max;
1125 }
1126 }
1127 }
1128 r.maybe_compress (true);
1129 }
1130 else if (a_nc == b_nc && (a_nr == 1 || b_nr == 1))
1131 r = quotient (a.transpose (), b.transpose ()).transpose ();
1132 else if (a_nr == 1 && b_nc == 1)
1133 {
1134 r = MSparse<T> (b_nr, a_nc, (a_nc * b_nr));
1135
1136 for (octave_idx_type j = 0; j < a_nc; j++)
1137 {
1138 const T a_val = a.elem (0, j);
1139
1140 for (octave_idx_type i = 0; i < b_nr; i++)
1141 {
1142 octave_quit ();
1143
1144 const T b_val = b.elem (i, 0);
1145 const T val = a_val / b_val;
1146
1147 if (val != Zero || val != val)
1148 r.elem (i, j) = val;
1149 }
1150 }
1151 r.maybe_compress (true);
1152 }
1153 else if (a_nc == 1 && b_nr == 1)
1154 r = quotient (a.transpose (), b.transpose ()).transpose ();
1155 else
1156 octave::err_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
1157
1158 return r;
1159}
1160
1161// Unary MSparse ops.
1162
1163template <typename T>
1166{
1167 return a;
1168}
1169
1170template <typename T>
1173{
1174 MSparse<T> retval (a);
1175 octave_idx_type nz = a.nnz ();
1176 for (octave_idx_type i = 0; i < nz; i++)
1177 retval.data (i) = - retval.data (i);
1178 return retval;
1179}
MSparse< T > times_or_divide(const MSparse< T > &a, const T &s, OP op)
Definition MSparse.cc:181
MSparse< T > operator/(const MSparse< T > &a, const T &s)
Definition MSparse.cc:229
MSparse< T > operator*(const MSparse< T > &a, const T &s)
Definition MSparse.cc:222
MArray< T > operator-(const MSparse< T > &a, const T &s)
Definition MSparse.cc:174
MSparse< T > & plus_or_minus(MSparse< T > &a, const MSparse< T > &b, OP op, const char *op_name)
Definition MSparse.cc:31
MSparse< T > & operator-=(MSparse< T > &a, const MSparse< T > &b)
Definition MSparse.cc:136
MArray< T > operator+(const MSparse< T > &a, const T &s)
Definition MSparse.cc:167
MSparse< T > & operator+=(MSparse< T > &a, const MSparse< T > &b)
Definition MSparse.cc:129
MSparse< T > product(const MSparse< T > &a, const MSparse< T > &b)
Definition MSparse.cc:602
MSparse< T > quotient(const MSparse< T > &a, const MSparse< T > &b)
Definition MSparse.cc:939
T & elem(octave_idx_type n)
Reference to the element at linear index n.
Definition Array-base.h:585
Template for N-dimensional array classes with like-type math operators.
Definition MArray.h:61
MSparse< T > transpose() const
Definition MSparse.h:96
octave_idx_type cols() const
Definition Sparse.h:351
octave_idx_type * cidx()
Definition Sparse.h:595
T * data()
Definition Sparse.h:573
void resize(octave_idx_type r, octave_idx_type c)
Definition Sparse.cc:1006
T & elem(octave_idx_type n)
Definition Sparse.h:455
octave_idx_type * ridx()
Definition Sparse.h:582
octave_idx_type numel() const
Definition Sparse.h:342
bool any_element_is_inf_or_nan() const
Definition Sparse.h:753
Sparse< T, Alloc > maybe_compress(bool remove_zeros=false)
Definition Sparse.h:530
octave_idx_type nnz() const
Actual number of nonzero terms.
Definition Sparse.h:338
octave_idx_type rows() const
Definition Sparse.h:350
Vector representing the dimensions (size) of an Array.
Definition dim-vector.h:92