00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027
00028 #include <algorithm>
00029 #include "Array.h"
00030 #include "Sparse.h"
00031 #include "mx-base.h"
00032
00033 #include "ov.h"
00034 #include "Cell.h"
00035
00036 #include "defun-dld.h"
00037 #include "error.h"
00038 #include "oct-obj.h"
00039
00040
00041 template <class T>
00042 static Array<T>
00043 do_tril (const Array<T>& a, octave_idx_type k, bool pack)
00044 {
00045 octave_idx_type nr = a.rows (), nc = a.columns ();
00046 const T *avec = a.fortran_vec ();
00047 octave_idx_type zero = 0;
00048
00049 if (pack)
00050 {
00051 octave_idx_type j1 = std::min (std::max (zero, k), nc);
00052 octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
00053 octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2;
00054 Array<T> r (dim_vector (n, 1));
00055 T *rvec = r.fortran_vec ();
00056 for (octave_idx_type j = 0; j < nc; j++)
00057 {
00058 octave_idx_type ii = std::min (std::max (zero, j - k), nr);
00059 rvec = std::copy (avec + ii, avec + nr, rvec);
00060 avec += nr;
00061 }
00062
00063 return r;
00064 }
00065 else
00066 {
00067 Array<T> r (a.dims ());
00068 T *rvec = r.fortran_vec ();
00069 for (octave_idx_type j = 0; j < nc; j++)
00070 {
00071 octave_idx_type ii = std::min (std::max (zero, j - k), nr);
00072 std::fill (rvec, rvec + ii, T());
00073 std::copy (avec + ii, avec + nr, rvec + ii);
00074 avec += nr;
00075 rvec += nr;
00076 }
00077
00078 return r;
00079 }
00080 }
00081
00082 template <class T>
00083 static Array<T>
00084 do_triu (const Array<T>& a, octave_idx_type k, bool pack)
00085 {
00086 octave_idx_type nr = a.rows (), nc = a.columns ();
00087 const T *avec = a.fortran_vec ();
00088 octave_idx_type zero = 0;
00089
00090 if (pack)
00091 {
00092 octave_idx_type j1 = std::min (std::max (zero, k), nc);
00093 octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
00094 octave_idx_type n = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr;
00095 Array<T> r (dim_vector (n, 1));
00096 T *rvec = r.fortran_vec ();
00097 for (octave_idx_type j = 0; j < nc; j++)
00098 {
00099 octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
00100 rvec = std::copy (avec, avec + ii, rvec);
00101 avec += nr;
00102 }
00103
00104 return r;
00105 }
00106 else
00107 {
00108 NoAlias<Array<T> > r (a.dims ());
00109 T *rvec = r.fortran_vec ();
00110 for (octave_idx_type j = 0; j < nc; j++)
00111 {
00112 octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
00113 std::copy (avec, avec + ii, rvec);
00114 std::fill (rvec + ii, rvec + nr, T());
00115 avec += nr;
00116 rvec += nr;
00117 }
00118
00119 return r;
00120 }
00121 }
00122
00123
00124
00125
00126 template <class T>
00127 static Sparse<T>
00128 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack)
00129 {
00130 if (pack)
00131 {
00132 error ("tril: \"pack\" not implemented for sparse matrices");
00133 return Sparse<T> ();
00134 }
00135
00136 Sparse<T> m = a;
00137 octave_idx_type nc = m.cols();
00138
00139 for (octave_idx_type j = 0; j < nc; j++)
00140 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
00141 if (m.ridx(i) < j-k)
00142 m.data(i) = 0.;
00143
00144 m.maybe_compress (true);
00145 return m;
00146 }
00147
00148 template <class T>
00149 static Sparse<T>
00150 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack)
00151 {
00152 if (pack)
00153 {
00154 error ("triu: \"pack\" not implemented for sparse matrices");
00155 return Sparse<T> ();
00156 }
00157
00158 Sparse<T> m = a;
00159 octave_idx_type nc = m.cols();
00160
00161 for (octave_idx_type j = 0; j < nc; j++)
00162 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
00163 if (m.ridx(i) > j-k)
00164 m.data(i) = 0.;
00165
00166 m.maybe_compress (true);
00167 return m;
00168 }
00169
00170
00171 template <class T>
00172 static Array<T>
00173 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack)
00174 {
00175 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
00176 }
00177
00178 template <class T>
00179 static Sparse<T>
00180 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack)
00181 {
00182 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
00183 }
00184
00185 static octave_value
00186 do_trilu (const std::string& name,
00187 const octave_value_list& args)
00188 {
00189 bool lower = name == "tril";
00190
00191 octave_value retval;
00192 int nargin = args.length ();
00193 octave_idx_type k = 0;
00194 bool pack = false;
00195 if (nargin >= 2 && args(nargin-1).is_string ())
00196 {
00197 pack = args(nargin-1).string_value () == "pack";
00198 nargin--;
00199 }
00200
00201 if (nargin == 2)
00202 {
00203 k = args(1).int_value (true);
00204
00205 if (error_state)
00206 return retval;
00207 }
00208
00209 if (nargin < 1 || nargin > 2)
00210 print_usage ();
00211 else
00212 {
00213 octave_value arg = args (0);
00214
00215 dim_vector dims = arg.dims ();
00216 if (dims.length () != 2)
00217 error ("%s: need a 2-D matrix", name.c_str ());
00218 else if (k < -dims (0) || k > dims(1))
00219 error ("%s: requested diagonal out of range", name.c_str ());
00220 else
00221 {
00222 switch (arg.builtin_type ())
00223 {
00224 case btyp_double:
00225 if (arg.is_sparse_type ())
00226 retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack);
00227 else
00228 retval = do_trilu (arg.array_value (), k, lower, pack);
00229 break;
00230 case btyp_complex:
00231 if (arg.is_sparse_type ())
00232 retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower, pack);
00233 else
00234 retval = do_trilu (arg.complex_array_value (), k, lower, pack);
00235 break;
00236 case btyp_bool:
00237 if (arg.is_sparse_type ())
00238 retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower, pack);
00239 else
00240 retval = do_trilu (arg.bool_array_value (), k, lower, pack);
00241 break;
00242 #define ARRAYCASE(TYP) \
00243 case btyp_ ## TYP: \
00244 retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \
00245 break
00246 ARRAYCASE (float);
00247 ARRAYCASE (float_complex);
00248 ARRAYCASE (int8);
00249 ARRAYCASE (int16);
00250 ARRAYCASE (int32);
00251 ARRAYCASE (int64);
00252 ARRAYCASE (uint8);
00253 ARRAYCASE (uint16);
00254 ARRAYCASE (uint32);
00255 ARRAYCASE (uint64);
00256 ARRAYCASE (char);
00257 #undef ARRAYCASE
00258 default:
00259 {
00260
00261
00262
00263 if (pack)
00264 {
00265 error ("%s: \"pack\" not implemented for class %s",
00266 name.c_str (), arg.class_name ().c_str ());
00267 return octave_value ();
00268 }
00269
00270 octave_value tmp = arg;
00271 if (arg.numel () == 0)
00272 return arg;
00273
00274 octave_idx_type nr = dims(0), nc = dims (1);
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286 octave_value_list ov_idx;
00287 std::list<octave_value_list> idx_tmp;
00288 ov_idx(1) = static_cast<double> (nc+1);
00289 ov_idx(0) = Range (1, nr);
00290 idx_tmp.push_back (ov_idx);
00291 ov_idx(1) = static_cast<double> (nc);
00292 tmp = tmp.resize (dim_vector (0,0));
00293 tmp = tmp.subsasgn("(",idx_tmp, arg.do_index_op (ov_idx));
00294 tmp = tmp.resize(dims);
00295
00296 if (lower)
00297 {
00298 octave_idx_type st = nc < nr + k ? nc : nr + k;
00299
00300 for (octave_idx_type j = 1; j <= st; j++)
00301 {
00302 octave_idx_type nr_limit = 1 > j - k ? 1 : j - k;
00303 ov_idx(1) = static_cast<double> (j);
00304 ov_idx(0) = Range (nr_limit, nr);
00305 std::list<octave_value_list> idx;
00306 idx.push_back (ov_idx);
00307
00308 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
00309
00310 if (error_state)
00311 return retval;
00312 }
00313 }
00314 else
00315 {
00316 octave_idx_type st = k + 1 > 1 ? k + 1 : 1;
00317
00318 for (octave_idx_type j = st; j <= nc; j++)
00319 {
00320 octave_idx_type nr_limit = nr < j - k ? nr : j - k;
00321 ov_idx(1) = static_cast<double> (j);
00322 ov_idx(0) = Range (1, nr_limit);
00323 std::list<octave_value_list> idx;
00324 idx.push_back (ov_idx);
00325
00326 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
00327
00328 if (error_state)
00329 return retval;
00330 }
00331 }
00332
00333 retval = tmp;
00334 }
00335 }
00336 }
00337 }
00338
00339 return retval;
00340 }
00341
00342 DEFUN_DLD (tril, args, ,
00343 "-*- texinfo -*-\n\
00344 @deftypefn {Function File} {} tril (@var{A})\n\
00345 @deftypefnx {Function File} {} tril (@var{A}, @var{k})\n\
00346 @deftypefnx {Function File} {} tril (@var{A}, @var{k}, @var{pack})\n\
00347 @deftypefnx {Function File} {} triu (@var{A})\n\
00348 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
00349 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
00350 Return a new matrix formed by extracting the lower (@code{tril})\n\
00351 or upper (@code{triu}) triangular part of the matrix @var{A}, and\n\
00352 setting all other elements to zero. The second argument is optional,\n\
00353 and specifies how many diagonals above or below the main diagonal should\n\
00354 also be set to zero.\n\
00355 \n\
00356 The default value of @var{k} is zero, so that @code{triu} and\n\
00357 @code{tril} normally include the main diagonal as part of the result.\n\
00358 \n\
00359 If the value of @var{k} is nonzero integer, the selection of elements\
00360 starts at an offset of @var{k} diagonals above or below the main\
00361 diagonal; above for positive @var{k} and below for negative @var{k}.\
00362 \n\
00363 The absolute value of @var{k} must not be greater than the number of\n\
00364 sub-diagonals or super-diagonals.\n\
00365 \n\
00366 For example:\n\
00367 \n\
00368 @example\n\
00369 @group\n\
00370 tril (ones (3), -1)\n\
00371 @result{} 0 0 0\n\
00372 1 0 0\n\
00373 1 1 0\n\
00374 @end group\n\
00375 @end example\n\
00376 \n\
00377 @noindent\n\
00378 and\n\
00379 \n\
00380 @example\n\
00381 @group\n\
00382 tril (ones (3), 1)\n\
00383 @result{} 1 1 0\n\
00384 1 1 1\n\
00385 1 1 1\n\
00386 @end group\n\
00387 @end example\n\
00388 \n\
00389 If the option \"pack\" is given as third argument, the extracted elements\n\
00390 are not inserted into a matrix, but rather stacked column-wise one above\n\
00391 other.\n\
00392 @seealso{diag}\n\
00393 @end deftypefn")
00394 {
00395 return do_trilu ("tril", args);
00396 }
00397
00398 DEFUN_DLD (triu, args, ,
00399 "-*- texinfo -*-\n\
00400 @deftypefn {Function File} {} triu (@var{A})\n\
00401 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
00402 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
00403 See the documentation for the @code{tril} function (@pxref{tril}).\n\
00404 @end deftypefn")
00405 {
00406 return do_trilu ("triu", args);
00407 }
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428