00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #if !defined (octave_sparse_diag_op_defs_h)
00024 #define octave_sparse_diag_op_defs_h 1
00025
00026
00027
00028 template <typename RT, typename DM, typename SM>
00029 RT do_mul_dm_sm (const DM& d, const SM& a)
00030 {
00031 const octave_idx_type nr = d.rows ();
00032 const octave_idx_type nc = d.cols ();
00033
00034 const octave_idx_type a_nr = a.rows ();
00035 const octave_idx_type a_nc = a.cols ();
00036
00037 if (nc != a_nr)
00038 {
00039 gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
00040 return RT ();
00041 }
00042 else
00043 {
00044 RT r (nr, a_nc, a.nnz ());
00045
00046 octave_idx_type l = 0;
00047
00048 for (octave_idx_type j = 0; j < a_nc; j++)
00049 {
00050 r.xcidx (j) = l;
00051 const octave_idx_type colend = a.cidx (j+1);
00052 for (octave_idx_type k = a.cidx (j); k < colend; k++)
00053 {
00054 const octave_idx_type i = a.ridx (k);
00055 if (i >= nr) break;
00056 r.xdata (l) = d.dgelem (i) * a.data (k);
00057 r.xridx (l) = i;
00058 l++;
00059 }
00060 }
00061
00062 r.xcidx (a_nc) = l;
00063
00064 r.maybe_compress (true);
00065 return r;
00066 }
00067 }
00068
00069 template <typename RT, typename SM, typename DM>
00070 RT do_mul_sm_dm (const SM& a, const DM& d)
00071 {
00072 const octave_idx_type nr = d.rows ();
00073 const octave_idx_type nc = d.cols ();
00074
00075 const octave_idx_type a_nr = a.rows ();
00076 const octave_idx_type a_nc = a.cols ();
00077
00078 if (nr != a_nc)
00079 {
00080 gripe_nonconformant ("operator *", a_nr, a_nc, nr, nc);
00081 return RT ();
00082 }
00083 else
00084 {
00085
00086 const octave_idx_type mnc = nc < a_nc ? nc: a_nc;
00087 RT r (a_nr, nc, a.cidx (mnc));
00088
00089 for (octave_idx_type j = 0; j < mnc; ++j)
00090 {
00091 const typename DM::element_type s = d.dgelem (j);
00092 const octave_idx_type colend = a.cidx (j+1);
00093 r.xcidx (j) = a.cidx (j);
00094 for (octave_idx_type k = a.cidx (j); k < colend; ++k)
00095 {
00096 r.xdata (k) = s * a.data (k);
00097 r.xridx (k) = a.ridx (k);
00098 }
00099 }
00100 for (octave_idx_type j = mnc; j <= nc; ++j)
00101 r.xcidx (j) = a.cidx (mnc);
00102
00103 r.maybe_compress (true);
00104 return r;
00105 }
00106 }
00107
00108
// Function object that returns its argument unchanged.  It is used as the
// "no-op" element transform for operands of the sparse/diagonal add and
// subtract kernels.
//
// The argument_type / result_type typedefs are declared directly.  They used
// to be inherited from std::unary_function, which was deprecated in C++11 and
// removed in C++17.  operator() is const so the functor is also callable
// through a const object or reference.
template <typename T>
struct identity_val
{
  typedef T argument_type;
  typedef T result_type;

  T operator () (const T& x) const { return x; }
};
00115
00116
00117
00118 template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
00119 RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
00120 {
00121 using std::min;
00122 const octave_idx_type nr = d.rows ();
00123 const octave_idx_type nc = d.cols ();
00124 const octave_idx_type n = min (nr, nc);
00125
00126 const octave_idx_type a_nr = a.rows ();
00127 const octave_idx_type a_nc = a.cols ();
00128
00129 const octave_idx_type nz = a.nnz ();
00130 RT r (a_nr, a_nc, nz + n);
00131 octave_idx_type k = 0;
00132
00133 for (octave_idx_type j = 0; j < nc; ++j)
00134 {
00135 OCTAVE_QUIT;
00136 const octave_idx_type colend = a.cidx (j+1);
00137 r.xcidx (j) = k;
00138 octave_idx_type k_src = a.cidx (j), k_split;
00139
00140 for (k_split = k_src; k_split < colend; k_split++)
00141 if (a.ridx (k_split) >= j)
00142 break;
00143
00144 for (; k_src < k_split; k_src++, k++)
00145 {
00146 r.xridx (k) = a.ridx (k_src);
00147 r.xdata (k) = opa (a.data (k_src));
00148 }
00149
00150 if (k_src < colend && a.ridx (k_src) == j)
00151 {
00152 r.xridx (k) = j;
00153 r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
00154 k++; k_src++;
00155 }
00156 else
00157 {
00158 r.xridx (k) = j;
00159 r.xdata (k) = opd (d.dgelem (j));
00160 k++;
00161 }
00162
00163 for (; k_src < colend; k_src++, k++)
00164 {
00165 r.xridx (k) = a.ridx (k_src);
00166 r.xdata (k) = opa (a.data (k_src));
00167 }
00168
00169 }
00170 r.xcidx (nc) = k;
00171
00172 r.maybe_compress (true);
00173 return r;
00174 }
00175
00176 template <typename RT, typename DM, typename SM>
00177 RT do_commutative_add_dm_sm (const DM& d, const SM& a)
00178 {
00179
00180 return inner_do_add_sm_dm<RT> (a, d,
00181 identity_val<typename SM::element_type> (),
00182 identity_val<typename DM::element_type> ());
00183 }
00184
00185 template <typename RT, typename DM, typename SM>
00186 RT do_add_dm_sm (const DM& d, const SM& a)
00187 {
00188 if (a.rows () != d.rows () || a.cols () != d.cols ())
00189 {
00190 gripe_nonconformant ("operator +", d.rows (), d.cols (), a.rows (), a.cols ());
00191 return RT ();
00192 }
00193 else
00194 return do_commutative_add_dm_sm<RT> (d, a);
00195 }
00196
00197 template <typename RT, typename DM, typename SM>
00198 RT do_sub_dm_sm (const DM& d, const SM& a)
00199 {
00200 if (a.rows () != d.rows () || a.cols () != d.cols ())
00201 {
00202 gripe_nonconformant ("operator -", d.rows (), d.cols (), a.rows (), a.cols ());
00203 return RT ();
00204 }
00205 else
00206 return inner_do_add_sm_dm<RT> (a, d, std::negate<typename SM::element_type> (),
00207 identity_val<typename DM::element_type> ());
00208 }
00209
00210 template <typename RT, typename SM, typename DM>
00211 RT do_add_sm_dm (const SM& a, const DM& d)
00212 {
00213 if (a.rows () != d.rows () || a.cols () != d.cols ())
00214 {
00215 gripe_nonconformant ("operator +", a.rows (), a.cols (), d.rows (), d.cols ());
00216 return RT ();
00217 }
00218 else
00219 return do_commutative_add_dm_sm<RT> (d, a);
00220 }
00221
00222 template <typename RT, typename SM, typename DM>
00223 RT do_sub_sm_dm (const SM& a, const DM& d)
00224 {
00225 if (a.rows () != d.rows () || a.cols () != d.cols ())
00226 {
00227 gripe_nonconformant ("operator -", a.rows (), a.cols (), d.rows (), d.cols ());
00228 return RT ();
00229 }
00230 else
00231 return inner_do_add_sm_dm<RT> (a, d,
00232 identity_val<typename SM::element_type> (),
00233 std::negate<typename DM::element_type> ());
00234 }
00235
00236 #endif // octave_sparse_diag_op_defs_h