GNU Octave  8.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
tril.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2004-2023 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include <algorithm>
31 #include "Array.h"
32 #include "Sparse.h"
33 #include "mx-base.h"
34 
35 #include "ov.h"
36 #include "Cell.h"
37 
38 #include "defun.h"
39 #include "error.h"
40 #include "ovl.h"
41 
43 
44 // The bulk of the work.
45 template <typename T>
46 static Array<T>
47 do_tril (const Array<T>& a, octave_idx_type k, bool pack)
48 {
49  octave_idx_type nr = a.rows ();
50  octave_idx_type nc = a.columns ();
51  const T *avec = a.data ();
52  octave_idx_type zero = 0;
53 
54  if (pack)
55  {
56  octave_idx_type j1 = std::min (std::max (zero, k), nc);
57  octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
58  octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2;
59  Array<T> r (dim_vector (n, 1));
60  T *rvec = r.fortran_vec ();
61  for (octave_idx_type j = 0; j < nc; j++)
62  {
63  octave_idx_type ii = std::min (std::max (zero, j - k), nr);
64  rvec = std::copy (avec + ii, avec + nr, rvec);
65  avec += nr;
66  }
67 
68  return r;
69  }
70  else
71  {
72  Array<T> r (a.dims ());
73  T *rvec = r.fortran_vec ();
74  for (octave_idx_type j = 0; j < nc; j++)
75  {
76  octave_idx_type ii = std::min (std::max (zero, j - k), nr);
77  std::fill (rvec, rvec + ii, T ());
78  std::copy (avec + ii, avec + nr, rvec + ii);
79  avec += nr;
80  rvec += nr;
81  }
82 
83  return r;
84  }
85 }
86 
87 template <typename T>
88 static Array<T>
89 do_triu (const Array<T>& a, octave_idx_type k, bool pack)
90 {
91  octave_idx_type nr = a.rows ();
92  octave_idx_type nc = a.columns ();
93  const T *avec = a.data ();
94  octave_idx_type zero = 0;
95 
96  if (pack)
97  {
98  octave_idx_type j1 = std::min (std::max (zero, k), nc);
99  octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
101  = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr;
102  Array<T> r (dim_vector (n, 1));
103  T *rvec = r.fortran_vec ();
104  for (octave_idx_type j = 0; j < nc; j++)
105  {
106  octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
107  rvec = std::copy (avec, avec + ii, rvec);
108  avec += nr;
109  }
110 
111  return r;
112  }
113  else
114  {
115  Array<T> r (a.dims ());
116  T *rvec = r.fortran_vec ();
117  for (octave_idx_type j = 0; j < nc; j++)
118  {
119  octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
120  std::copy (avec, avec + ii, rvec);
121  std::fill (rvec + ii, rvec + nr, T ());
122  avec += nr;
123  rvec += nr;
124  }
125 
126  return r;
127  }
128 }
129 
130 // These two are by David Bateman.
131 // FIXME: optimizations possible. "pack" support missing.
132 
133 template <typename T>
134 static Sparse<T>
135 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack)
136 {
137  if (pack) // FIXME
138  error (R"(tril: "pack" not implemented for sparse matrices)");
139 
140  Sparse<T> m = a;
141  octave_idx_type nc = m.cols ();
142 
143  for (octave_idx_type j = 0; j < nc; j++)
144  for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++)
145  if (m.ridx (i) < j-k)
146  m.data(i) = 0.;
147 
148  m.maybe_compress (true);
149 
150  return m;
151 }
152 
153 template <typename T>
154 static Sparse<T>
155 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack)
156 {
157  if (pack) // FIXME
158  error (R"(triu: "pack" not implemented for sparse matrices)");
159 
160  Sparse<T> m = a;
161  octave_idx_type nc = m.cols ();
162 
163  for (octave_idx_type j = 0; j < nc; j++)
164  for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++)
165  if (m.ridx (i) > j-k)
166  m.data(i) = 0.;
167 
168  m.maybe_compress (true);
169  return m;
170 }
171 
172 // Convenience dispatchers.
173 template <typename T>
174 static Array<T>
175 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack)
176 {
177  return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
178 }
179 
180 template <typename T>
181 static Sparse<T>
182 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack)
183 {
184  return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
185 }
186 
187 static octave_value
188 do_trilu (const std::string& name,
189  const octave_value_list& args)
190 {
191  bool lower = (name == "tril");
192 
193  int nargin = args.length ();
194  bool pack = false;
195 
196  if (nargin >= 2 && args(nargin-1).is_string ())
197  {
198  pack = (args(nargin-1).string_value () == "pack");
199  nargin--;
200  }
201 
202  if (nargin < 1 || nargin > 2)
203  print_usage ();
204 
205  octave_idx_type k = 0;
206  if (nargin == 2)
207  k = args(1).idx_type_value (true);
208 
209  octave_value arg = args(0);
210 
211  dim_vector dims = arg.dims ();
212  if (dims.ndims () != 2)
213  error ("%s: need a 2-D matrix", name.c_str ());
214  else if (k < -dims(0))
215  k = -dims(0);
216  else if (k > dims(1))
217  k = dims(1);
218 
219  octave_value retval;
220 
221  switch (arg.builtin_type ())
222  {
223  case btyp_double:
224  if (arg.issparse ())
225  retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack);
226  else
227  retval = do_trilu (arg.array_value (), k, lower, pack);
228  break;
229 
230  case btyp_complex:
231  if (arg.issparse ())
232  retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower,
233  pack);
234  else
235  retval = do_trilu (arg.complex_array_value (), k, lower, pack);
236  break;
237 
238  case btyp_bool:
239  if (arg.issparse ())
240  retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower,
241  pack);
242  else
243  retval = do_trilu (arg.bool_array_value (), k, lower, pack);
244  break;
245 
246 #define ARRAYCASE(TYP) \
247  case btyp_ ## TYP: \
248  retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \
249  break
250 
251  ARRAYCASE (float);
252  ARRAYCASE (float_complex);
253  ARRAYCASE (int8);
254  ARRAYCASE (int16);
255  ARRAYCASE (int32);
256  ARRAYCASE (int64);
257  ARRAYCASE (uint8);
258  ARRAYCASE (uint16);
259  ARRAYCASE (uint32);
260  ARRAYCASE (uint64);
261  ARRAYCASE (char);
262 
263 #undef ARRAYCASE
264 
265  default:
266  {
267  // Generic code that works on octave-values, that is slow
268  // but will also work on arbitrary user types
269  if (pack) // FIXME
270  error (R"(%s: "pack" not implemented for class %s)",
271  name.c_str (), arg.class_name ().c_str ());
272 
273  octave_value tmp = arg;
274  if (arg.isempty ())
275  return arg;
276 
277  octave_idx_type nr = dims(0);
278  octave_idx_type nc = dims(1);
279 
280  // The sole purpose of this code is to force the correct matrix size.
281  // This would not be necessary if the octave_value resize function
282  // allowed a fill_value. It also allows odd attributes in some user
283  // types to be handled. With a fill_value, it should be replaced with
284  //
285  // octave_value_list ov_idx;
286  // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value);
287 
288  octave_value_list ov_idx;
289  std::list<octave_value_list> idx_tmp;
290  ov_idx(1) = static_cast<double> (nc+1);
291  ov_idx(0) = range<double> (1, nr);
292  idx_tmp.push_back (ov_idx);
293  ov_idx(1) = static_cast<double> (nc);
294  tmp = tmp.resize (dim_vector (0, 0));
295  tmp = tmp.subsasgn ("(", idx_tmp, arg.index_op (ov_idx));
296  tmp = tmp.resize (dims);
297 
298  octave_idx_type one = 1;
299 
300  if (lower)
301  {
302  octave_idx_type st = std::min (nc, nr + k);
303 
304  for (octave_idx_type j = 1; j <= st; j++)
305  {
306  octave_idx_type nr_limit = std::max (one, j - k);
307  ov_idx(1) = static_cast<double> (j);
308  ov_idx(0) = range<double> (nr_limit, nr);
309  std::list<octave_value_list> idx;
310  idx.push_back (ov_idx);
311 
312  tmp = tmp.subsasgn ("(", idx, arg.index_op (ov_idx));
313  }
314  }
315  else
316  {
317  octave_idx_type st = std::max (k + 1, one);
318 
319  for (octave_idx_type j = st; j <= nc; j++)
320  {
321  octave_idx_type nr_limit = std::min (nr, j - k);
322  ov_idx(1) = static_cast<double> (j);
323  ov_idx(0) = range<double> (1, nr_limit);
324  std::list<octave_value_list> idx;
325  idx.push_back (ov_idx);
326 
327  tmp = tmp.subsasgn ("(", idx, arg.index_op (ov_idx));
328  }
329  }
330 
331  retval = tmp;
332  }
333  }
334 
335  return retval;
336 }
337 
338 DEFUN (tril, args, ,
339  doc: /* -*- texinfo -*-
340 @deftypefn {} {@var{A_LO} =} tril (@var{A})
341 @deftypefnx {} {@var{A_LO} =} tril (@var{A}, @var{k})
342 @deftypefnx {} {@var{A_LO} =} tril (@var{A}, @var{k}, @var{pack})
343 Return a new matrix formed by extracting the lower triangular part of the
344 matrix @var{A}, and setting all other elements to zero.
345 
346 The optional second argument specifies how many diagonals above or below the
347 main diagonal should also be set to zero. The default value of @var{k} is
348 zero which includes the main diagonal as part of the result. If the value of
349 @var{k} is a nonzero integer then the selection of elements starts at an offset
350 of @var{k} diagonals above the main diagonal for positive @var{k} or below the
351 main diagonal for negative @var{k}. The absolute value of @var{k} may not be
352 greater than the number of subdiagonals or superdiagonals.
353 
354 Example 1 : exclude main diagonal
355 
356 @example
357 @group
358 tril (ones (3), -1)
359  @result{} 0 0 0
360  1 0 0
361  1 1 0
362 @end group
363 @end example
364 
365 @noindent
366 
367 Example 2 : include first superdiagonal
368 
369 @example
370 @group
371 tril (ones (3), 1)
372  @result{} 1 1 0
373  1 1 1
374  1 1 1
375 @end group
376 @end example
377 
378 If the optional third argument @qcode{"pack"} is given then the extracted
379 elements are not inserted into a matrix, but instead stacked column-wise one
380 above another, and returned as a column vector.
381 @seealso{triu, istril, diag}
382 @end deftypefn */)
383 {
384  return do_trilu ("tril", args);
385 }
386 
387 DEFUN (triu, args, ,
388  doc: /* -*- texinfo -*-
389 @deftypefn {} {@var{A_UP} =} triu (@var{A})
390 @deftypefnx {} {@var{A_UP} =} triu (@var{A}, @var{k})
391 @deftypefnx {} {@var{A_UP} =} triu (@var{A}, @var{k}, @var{pack})
392 Return a new matrix formed by extracting the upper triangular part of the
393 matrix @var{A}, and setting all other elements to zero.
394 
395 The optional second argument specifies how many diagonals above or below the
396 main diagonal should also be set to zero. The default value of @var{k} is
397 zero which includes the main diagonal as part of the result. If the value of
398 @var{k} is a nonzero integer then the selection of elements starts at an offset
399 of @var{k} diagonals above the main diagonal for positive @var{k} or below the
400 main diagonal for negative @var{k}. The absolute value of @var{k} may not be
401 greater than the number of subdiagonals or superdiagonals.
402 
403 Example 1 : exclude main diagonal
404 
405 @example
406 @group
407 triu (ones (3), 1)
408  @result{} 0 1 1
409  0 0 1
410  0 0 0
411 @end group
412 @end example
413 
414 @noindent
415 
416 Example 2 : include first subdiagonal
417 
418 @example
419 @group
420 triu (ones (3), -1)
421  @result{} 1 1 1
422  1 1 1
423  0 1 1
424 @end group
425 @end example
426 
427 If the optional third argument @qcode{"pack"} is given then the extracted
428 elements are not inserted into a matrix, but instead stacked column-wise one
429 above another, and returned as a column vector.
430 @seealso{tril, istriu, diag}
431 @end deftypefn */)
432 {
433  return do_trilu ("triu", args);
434 }
435 
436 /*
437 %!shared a, l2, l1, l0, lm1, lm2, lm3, lm4
438 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
439 %!
440 %! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
441 %! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12];
442 %! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12];
443 %! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12];
444 %! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0];
445 %! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0];
446 %! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0];
447 %!
448 %!assert (tril (a, 3), l2)
449 %!assert (tril (a, 2), l2)
450 %!assert (tril (a, 1), l1)
451 %!assert (tril (a, 0), l0)
452 %!assert (tril (a), l0)
453 %!assert (tril (a, -1), lm1)
454 %!assert (tril (a, -2), lm2)
455 %!assert (tril (a, -3), lm3)
456 %!assert (tril (a, -4), lm4)
457 %!assert (tril (a, -5), lm4)
458 
459 %!shared a, u3, u2, u1, u0, um1, um2, um3
460 %!
461 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
462 %!
463 %! u3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0];
464 %! u2 = [0, 0, 3; 0, 0, 0; 0, 0, 0; 0, 0, 0];
465 %! u1 = [0, 2, 3; 0, 0, 6; 0, 0, 0; 0, 0, 0];
466 %! u0 = [1, 2, 3; 0, 5, 6; 0, 0, 9; 0, 0, 0];
467 %! um1 = [1, 2, 3; 4, 5, 6; 0, 8, 9; 0, 0, 12];
468 %! um2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 0, 11, 12];
469 %! um3 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
470 %!
471 %!assert (triu (a, 4), u3)
472 %!assert (triu (a, 3), u3)
473 %!assert (triu (a, 2), u2)
474 %!assert (triu (a, 1), u1)
475 %!assert (triu (a, 0), u0)
476 %!assert (triu (a), u0)
477 %!assert (triu (a, -1), um1)
478 %!assert (triu (a, -2), um2)
479 %!assert (triu (a, -3), um3)
480 %!assert (triu (a, -4), um3)
481 
482 %!error tril ()
483 %!error triu ()
484 */
485 
OCTAVE_END_NAMESPACE(octave)
charNDArray max(char d, const charNDArray &m)
Definition: chNDArray.cc:230
charNDArray min(char d, const charNDArray &m)
Definition: chNDArray.cc:207
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type columns(void) const
Definition: Array.h:471
OCTARRAY_OVERRIDABLE_FUNC_API const T * data(void) const
Return a read-only pointer to the array's element data.
Definition: Array.h:663
OCTARRAY_OVERRIDABLE_FUNC_API const dim_vector & dims(void) const
Return a const-reference so that dims ()(i) works efficiently.
Definition: Array.h:503
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type rows(void) const
Definition: Array.h:459
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
octave_idx_type ndims(void) const
Number of dimensions.
Definition: dim-vector.h:257
octave_idx_type length(void) const
Definition: ovl.h:113
boolNDArray bool_array_value(bool warn=false) const
Definition: ov.h:936
SparseMatrix sparse_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:945
bool issparse(void) const
Definition: ov.h:798
octave_value index_op(const octave_value_list &idx, bool resize_ok=false)
Definition: ov.h:550
builtin_type_t builtin_type(void) const
Definition: ov.h:735
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:923
std::string class_name(void) const
Definition: ov.h:1454
bool isempty(void) const
Definition: ov.h:646
SparseBoolMatrix sparse_bool_matrix_value(bool warn=false) const
Definition: ov.h:952
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:904
octave_value resize(const dim_vector &dv, bool fill=false) const
Definition: ov.h:625
OCTINTERP_API octave_value subsasgn(const std::string &type, const std::list< octave_value_list > &idx, const octave_value &rhs)
dim_vector dims(void) const
Definition: ov.h:586
SparseComplexMatrix sparse_complex_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:949
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void error(const char *fmt,...)
Definition: error.cc:979
T octave_idx_type m
Definition: mx-inlines.cc:773
octave_idx_type n
Definition: mx-inlines.cc:753
T * r
Definition: mx-inlines.cc:773
@ btyp_double
Definition: ov-base.h:84
@ btyp_bool
Definition: ov-base.h:96
@ btyp_complex
Definition: ov-base.h:86
static Array< T > do_triu(const Array< T > &a, octave_idx_type k, bool pack)
Definition: tril.cc:89
#define ARRAYCASE(TYP)
static Array< T > do_trilu(const Array< T > &a, octave_idx_type k, bool lower, bool pack)
Definition: tril.cc:175
static Array< T > do_tril(const Array< T > &a, octave_idx_type k, bool pack)
Definition: tril.cc:47