00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027
00028 #include <functional>
00029
00030 #include "quit.h"
00031 #include "lo-error.h"
00032 #include "MArray.h"
00033 #include "Array-util.h"
00034
00035 #include "MSparse.h"
00036 #include "MSparse-defs.h"
00037
00038
00039
00040
00041
00042 template <class T, class OP>
00043 MSparse<T>&
00044 plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char* op_name)
00045 {
00046 MSparse<T> r;
00047
00048 octave_idx_type a_nr = a.rows ();
00049 octave_idx_type a_nc = a.cols ();
00050
00051 octave_idx_type b_nr = b.rows ();
00052 octave_idx_type b_nc = b.cols ();
00053
00054 if (a_nr != b_nr || a_nc != b_nc)
00055 gripe_nonconformant (op_name , a_nr, a_nc, b_nr, b_nc);
00056 else
00057 {
00058 r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
00059
00060 octave_idx_type jx = 0;
00061 for (octave_idx_type i = 0 ; i < a_nc ; i++)
00062 {
00063 octave_idx_type ja = a.cidx(i);
00064 octave_idx_type ja_max = a.cidx(i+1);
00065 bool ja_lt_max= ja < ja_max;
00066
00067 octave_idx_type jb = b.cidx(i);
00068 octave_idx_type jb_max = b.cidx(i+1);
00069 bool jb_lt_max = jb < jb_max;
00070
00071 while (ja_lt_max || jb_lt_max )
00072 {
00073 octave_quit ();
00074 if ((! jb_lt_max) ||
00075 (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00076 {
00077 r.ridx(jx) = a.ridx(ja);
00078 r.data(jx) = op (a.data(ja), 0.);
00079 jx++;
00080 ja++;
00081 ja_lt_max= ja < ja_max;
00082 }
00083 else if (( !ja_lt_max ) ||
00084 (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00085 {
00086 r.ridx(jx) = b.ridx(jb);
00087 r.data(jx) = op (0., b.data(jb));
00088 jx++;
00089 jb++;
00090 jb_lt_max= jb < jb_max;
00091 }
00092 else
00093 {
00094 if (op (a.data(ja), b.data(jb)) != 0.)
00095 {
00096 r.data(jx) = op (a.data(ja), b.data(jb));
00097 r.ridx(jx) = a.ridx(ja);
00098 jx++;
00099 }
00100 ja++;
00101 ja_lt_max= ja < ja_max;
00102 jb++;
00103 jb_lt_max= jb < jb_max;
00104 }
00105 }
00106 r.cidx(i+1) = jx;
00107 }
00108
00109 a = r.maybe_compress ();
00110 }
00111
00112 return a;
00113 }
00114
00115 template <typename T>
00116 MSparse<T>&
00117 operator += (MSparse<T>& a, const MSparse<T>& b)
00118 {
00119 return plus_or_minus (a, b, std::plus<T> (), "operator +=");
00120 }
00121
00122 template <typename T>
00123 MSparse<T>&
00124 operator -= (MSparse<T>& a, const MSparse<T>& b)
00125 {
00126 return plus_or_minus (a, b, std::minus<T> (), "operator -=");
00127 }
00128
00129
00130
00131
00132 template <class T, class OP>
00133 MArray<T>
00134 plus_or_minus (const MSparse<T>& a, const T& s, OP op)
00135 {
00136 octave_idx_type nr = a.rows ();
00137 octave_idx_type nc = a.cols ();
00138
00139 MArray<T> r (dim_vector (nr, nc), op (0.0, s));
00140
00141 for (octave_idx_type j = 0; j < nc; j++)
00142 for (octave_idx_type i = a.cidx(j); i < a.cidx(j+1); i++)
00143 r.elem (a.ridx (i), j) = op (a.data (i), s);
00144 return r;
00145 }
00146
00147 template <typename T>
00148 MArray<T>
00149 operator + (const MSparse<T>& a, const T& s)
00150 {
00151 return plus_or_minus (a, s, std::plus<T> ());
00152 }
00153
00154 template <typename T>
00155 MArray<T>
00156 operator - (const MSparse<T>& a, const T& s)
00157 {
00158 return plus_or_minus (a, s, std::minus<T> ());
00159 }
00160
00161
00162 template <class T, class OP>
00163 MSparse<T>
00164 times_or_divide (const MSparse<T>& a, const T& s, OP op)
00165 {
00166 octave_idx_type nr = a.rows ();
00167 octave_idx_type nc = a.cols ();
00168 octave_idx_type nz = a.nnz ();
00169
00170 MSparse<T> r (nr, nc, nz);
00171
00172 for (octave_idx_type i = 0; i < nz; i++)
00173 {
00174 r.data(i) = op (a.data(i), s);
00175 r.ridx(i) = a.ridx(i);
00176 }
00177 for (octave_idx_type i = 0; i < nc + 1; i++)
00178 r.cidx(i) = a.cidx(i);
00179 r.maybe_compress (true);
00180 return r;
00181 }
00182
00183 template <typename T>
00184 MSparse<T>
00185 operator * (const MSparse<T>& a, const T& s)
00186 {
00187 return times_or_divide (a, s, std::multiplies<T> ());
00188 }
00189
00190 template <typename T>
00191 MSparse<T>
00192 operator / (const MSparse<T>& a, const T& s)
00193 {
00194 return times_or_divide (a, s, std::divides<T> ());
00195 }
00196
00197
00198
00199
00200 template <class T, class OP>
00201 MArray<T>
00202 plus_or_minus (const T& s, const MSparse<T>& a, OP op)
00203 {
00204 octave_idx_type nr = a.rows ();
00205 octave_idx_type nc = a.cols ();
00206
00207 MArray<T> r (dim_vector (nr, nc), op (s, 0.0));
00208
00209 for (octave_idx_type j = 0; j < nc; j++)
00210 for (octave_idx_type i = a.cidx(j); i < a.cidx(j+1); i++)
00211 r.elem (a.ridx (i), j) = op (s, a.data (i));
00212 return r;
00213 }
00214
00215 template <typename T>
00216 MArray<T>
00217 operator + (const T& s, const MSparse<T>& a)
00218 {
00219 return plus_or_minus (s, a, std::plus<T> ());
00220 }
00221
00222 template <typename T>
00223 MArray<T>
00224 operator - (const T& s, const MSparse<T>& a)
00225 {
00226 return plus_or_minus (s, a, std::minus<T> ());
00227 }
00228
00229 template <class T, class OP>
00230 MSparse<T>
00231 times_or_divides (const T& s, const MSparse<T>& a, OP op)
00232 {
00233 octave_idx_type nr = a.rows ();
00234 octave_idx_type nc = a.cols ();
00235 octave_idx_type nz = a.nnz ();
00236
00237 MSparse<T> r (nr, nc, nz);
00238
00239 for (octave_idx_type i = 0; i < nz; i++)
00240 {
00241 r.data(i) = op (s, a.data(i));
00242 r.ridx(i) = a.ridx(i);
00243 }
00244 for (octave_idx_type i = 0; i < nc + 1; i++)
00245 r.cidx(i) = a.cidx(i);
00246 r.maybe_compress (true);
00247 return r;
00248 }
00249
00250 template <class T>
00251 MSparse<T>
00252 operator * (const T& s, const MSparse<T>& a)
00253 {
00254 return times_or_divides (s, a, std::multiplies<T> ());
00255 }
00256
00257 template <class T>
00258 MSparse<T>
00259 operator / (const T& s, const MSparse<T>& a)
00260 {
00261 return times_or_divides (s, a, std::divides<T> ());
00262 }
00263
00264
00265
00266
00267 template <class T, class OP>
00268 MSparse<T>
00269 plus_or_minus (const MSparse<T>& a, const MSparse<T>& b, OP op,
00270 const char* op_name, bool negate)
00271 {
00272 MSparse<T> r;
00273
00274 octave_idx_type a_nr = a.rows ();
00275 octave_idx_type a_nc = a.cols ();
00276
00277 octave_idx_type b_nr = b.rows ();
00278 octave_idx_type b_nc = b.cols ();
00279
00280 if (a_nr == 1 && a_nc == 1)
00281 {
00282 if (a.elem(0,0) == 0.)
00283 if (negate)
00284 r = -MSparse<T> (b);
00285 else
00286 r = MSparse<T> (b);
00287 else
00288 {
00289 r = MSparse<T> (b_nr, b_nc, op (a.data(0), 0.));
00290
00291 for (octave_idx_type j = 0 ; j < b_nc ; j++)
00292 {
00293 octave_quit ();
00294 octave_idx_type idxj = j * b_nr;
00295 for (octave_idx_type i = b.cidx(j) ; i < b.cidx(j+1) ; i++)
00296 {
00297 octave_quit ();
00298 r.data(idxj + b.ridx(i)) = op (a.data(0), b.data(i));
00299 }
00300 }
00301 r.maybe_compress ();
00302 }
00303 }
00304 else if (b_nr == 1 && b_nc == 1)
00305 {
00306 if (b.elem(0,0) == 0.)
00307 r = MSparse<T> (a);
00308 else
00309 {
00310 r = MSparse<T> (a_nr, a_nc, op (0.0, b.data(0)));
00311
00312 for (octave_idx_type j = 0 ; j < a_nc ; j++)
00313 {
00314 octave_quit ();
00315 octave_idx_type idxj = j * a_nr;
00316 for (octave_idx_type i = a.cidx(j) ; i < a.cidx(j+1) ; i++)
00317 {
00318 octave_quit ();
00319 r.data(idxj + a.ridx(i)) = op (a.data(i), b.data(0));
00320 }
00321 }
00322 r.maybe_compress ();
00323 }
00324 }
00325 else if (a_nr != b_nr || a_nc != b_nc)
00326 gripe_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
00327 else
00328 {
00329 r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
00330
00331 octave_idx_type jx = 0;
00332 r.cidx (0) = 0;
00333 for (octave_idx_type i = 0 ; i < a_nc ; i++)
00334 {
00335 octave_idx_type ja = a.cidx(i);
00336 octave_idx_type ja_max = a.cidx(i+1);
00337 bool ja_lt_max= ja < ja_max;
00338
00339 octave_idx_type jb = b.cidx(i);
00340 octave_idx_type jb_max = b.cidx(i+1);
00341 bool jb_lt_max = jb < jb_max;
00342
00343 while (ja_lt_max || jb_lt_max )
00344 {
00345 octave_quit ();
00346 if ((! jb_lt_max) ||
00347 (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00348 {
00349 r.ridx(jx) = a.ridx(ja);
00350 r.data(jx) = op (a.data(ja), 0.);
00351 jx++;
00352 ja++;
00353 ja_lt_max= ja < ja_max;
00354 }
00355 else if (( !ja_lt_max ) ||
00356 (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00357 {
00358 r.ridx(jx) = b.ridx(jb);
00359 r.data(jx) = op (0., b.data(jb));
00360 jx++;
00361 jb++;
00362 jb_lt_max= jb < jb_max;
00363 }
00364 else
00365 {
00366 if (op (a.data(ja), b.data(jb)) != 0.)
00367 {
00368 r.data(jx) = op (a.data(ja), b.data(jb));
00369 r.ridx(jx) = a.ridx(ja);
00370 jx++;
00371 }
00372 ja++;
00373 ja_lt_max= ja < ja_max;
00374 jb++;
00375 jb_lt_max= jb < jb_max;
00376 }
00377 }
00378 r.cidx(i+1) = jx;
00379 }
00380
00381 r.maybe_compress ();
00382 }
00383
00384 return r;
00385 }
00386
00387 template <class T>
00388 MSparse<T>
00389 operator+ (const MSparse<T>& a, const MSparse<T>& b)
00390 {
00391 return plus_or_minus (a, b, std::plus<T> (), "operator +", false);
00392 }
00393
00394 template <class T>
00395 MSparse<T>
00396 operator- (const MSparse<T>& a, const MSparse<T>& b)
00397 {
00398 return plus_or_minus (a, b, std::minus<T> (), "operator -", true);
00399 }
00400
00401 template <class T>
00402 MSparse<T>
00403 product (const MSparse<T>& a, const MSparse<T>& b)
00404 {
00405 MSparse<T> r;
00406
00407 octave_idx_type a_nr = a.rows ();
00408 octave_idx_type a_nc = a.cols ();
00409
00410 octave_idx_type b_nr = b.rows ();
00411 octave_idx_type b_nc = b.cols ();
00412
00413 if (a_nr == 1 && a_nc == 1)
00414 {
00415 if (a.elem(0,0) == 0.)
00416 r = MSparse<T> (b_nr, b_nc);
00417 else
00418 {
00419 r = MSparse<T> (b);
00420 octave_idx_type b_nnz = b.nnz();
00421
00422 for (octave_idx_type i = 0 ; i < b_nnz ; i++)
00423 {
00424 octave_quit ();
00425 r.data (i) = a.data(0) * r.data(i);
00426 }
00427 r.maybe_compress ();
00428 }
00429 }
00430 else if (b_nr == 1 && b_nc == 1)
00431 {
00432 if (b.elem(0,0) == 0.)
00433 r = MSparse<T> (a_nr, a_nc);
00434 else
00435 {
00436 r = MSparse<T> (a);
00437 octave_idx_type a_nnz = a.nnz();
00438
00439 for (octave_idx_type i = 0 ; i < a_nnz ; i++)
00440 {
00441 octave_quit ();
00442 r.data (i) = r.data(i) * b.data(0);
00443 }
00444 r.maybe_compress ();
00445 }
00446 }
00447 else if (a_nr != b_nr || a_nc != b_nc)
00448 gripe_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
00449 else
00450 {
00451 r = MSparse<T> (a_nr, a_nc, (a.nnz () > b.nnz () ? a.nnz () : b.nnz ()));
00452
00453 octave_idx_type jx = 0;
00454 r.cidx (0) = 0;
00455 for (octave_idx_type i = 0 ; i < a_nc ; i++)
00456 {
00457 octave_idx_type ja = a.cidx(i);
00458 octave_idx_type ja_max = a.cidx(i+1);
00459 bool ja_lt_max= ja < ja_max;
00460
00461 octave_idx_type jb = b.cidx(i);
00462 octave_idx_type jb_max = b.cidx(i+1);
00463 bool jb_lt_max = jb < jb_max;
00464
00465 while (ja_lt_max || jb_lt_max )
00466 {
00467 octave_quit ();
00468 if ((! jb_lt_max) ||
00469 (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00470 {
00471 ja++; ja_lt_max= ja < ja_max;
00472 }
00473 else if (( !ja_lt_max ) ||
00474 (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00475 {
00476 jb++; jb_lt_max= jb < jb_max;
00477 }
00478 else
00479 {
00480 if ((a.data(ja) * b.data(jb)) != 0.)
00481 {
00482 r.data(jx) = a.data(ja) * b.data(jb);
00483 r.ridx(jx) = a.ridx(ja);
00484 jx++;
00485 }
00486 ja++; ja_lt_max= ja < ja_max;
00487 jb++; jb_lt_max= jb < jb_max;
00488 }
00489 }
00490 r.cidx(i+1) = jx;
00491 }
00492
00493 r.maybe_compress ();
00494 }
00495
00496 return r;
00497 }
00498
00499 template <class T>
00500 MSparse<T>
00501 quotient (const MSparse<T>& a, const MSparse<T>& b)
00502 {
00503 MSparse<T> r;
00504 T Zero = T ();
00505
00506 octave_idx_type a_nr = a.rows ();
00507 octave_idx_type a_nc = a.cols ();
00508
00509 octave_idx_type b_nr = b.rows ();
00510 octave_idx_type b_nc = b.cols ();
00511
00512 if (a_nr == 1 && a_nc == 1)
00513 {
00514 T val = a.elem (0,0);
00515 T fill = val / T();
00516 if (fill == T())
00517 {
00518 octave_idx_type b_nnz = b.nnz();
00519 r = MSparse<T> (b);
00520 for (octave_idx_type i = 0 ; i < b_nnz ; i++)
00521 r.data (i) = val / r.data(i);
00522 r.maybe_compress ();
00523 }
00524 else
00525 {
00526 r = MSparse<T> (b_nr, b_nc, fill);
00527 for (octave_idx_type j = 0 ; j < b_nc ; j++)
00528 {
00529 octave_quit ();
00530 octave_idx_type idxj = j * b_nr;
00531 for (octave_idx_type i = b.cidx(j) ; i < b.cidx(j+1) ; i++)
00532 {
00533 octave_quit ();
00534 r.data(idxj + b.ridx(i)) = val / b.data(i);
00535 }
00536 }
00537 r.maybe_compress ();
00538 }
00539 }
00540 else if (b_nr == 1 && b_nc == 1)
00541 {
00542 T val = b.elem (0,0);
00543 T fill = T() / val;
00544 if (fill == T())
00545 {
00546 octave_idx_type a_nnz = a.nnz();
00547 r = MSparse<T> (a);
00548 for (octave_idx_type i = 0 ; i < a_nnz ; i++)
00549 r.data (i) = r.data(i) / val;
00550 r.maybe_compress ();
00551 }
00552 else
00553 {
00554 r = MSparse<T> (a_nr, a_nc, fill);
00555 for (octave_idx_type j = 0 ; j < a_nc ; j++)
00556 {
00557 octave_quit ();
00558 octave_idx_type idxj = j * a_nr;
00559 for (octave_idx_type i = a.cidx(j) ; i < a.cidx(j+1) ; i++)
00560 {
00561 octave_quit ();
00562 r.data(idxj + a.ridx(i)) = a.data(i) / val;
00563 }
00564 }
00565 r.maybe_compress ();
00566 }
00567 }
00568 else if (a_nr != b_nr || a_nc != b_nc)
00569 gripe_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
00570 else
00571 {
00572 r = MSparse<T>( a_nr, a_nc, (Zero / Zero));
00573
00574 for (octave_idx_type i = 0 ; i < a_nc ; i++)
00575 {
00576 octave_idx_type ja = a.cidx(i);
00577 octave_idx_type ja_max = a.cidx(i+1);
00578 bool ja_lt_max= ja < ja_max;
00579
00580 octave_idx_type jb = b.cidx(i);
00581 octave_idx_type jb_max = b.cidx(i+1);
00582 bool jb_lt_max = jb < jb_max;
00583
00584 while (ja_lt_max || jb_lt_max )
00585 {
00586 octave_quit ();
00587 if ((! jb_lt_max) ||
00588 (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00589 {
00590 r.elem (a.ridx(ja),i) = a.data(ja) / Zero;
00591 ja++; ja_lt_max= ja < ja_max;
00592 }
00593 else if (( !ja_lt_max ) ||
00594 (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00595 {
00596 r.elem (b.ridx(jb),i) = Zero / b.data(jb);
00597 jb++; jb_lt_max= jb < jb_max;
00598 }
00599 else
00600 {
00601 r.elem (a.ridx(ja),i) = a.data(ja) / b.data(jb);
00602 ja++; ja_lt_max= ja < ja_max;
00603 jb++; jb_lt_max= jb < jb_max;
00604 }
00605 }
00606 }
00607
00608 r.maybe_compress (true);
00609 }
00610
00611 return r;
00612 }
00613
00614
00615
00616
00617
00618 template <class T>
00619 MSparse<T>
00620 operator + (const MSparse<T>& a)
00621 {
00622 return a;
00623 }
00624
00625 template <class T>
00626 MSparse<T>
00627 operator - (const MSparse<T>& a)
00628 {
00629 MSparse<T> retval (a);
00630 octave_idx_type nz = a.nnz ();
00631 for (octave_idx_type i = 0; i < nz; i++)
00632 retval.data(i) = - retval.data(i);
00633 return retval;
00634 }