GNU Octave 7.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
find.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 1996-2022 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if defined (HAVE_CONFIG_H)
27# include "config.h"
28#endif
29
30#include "quit.h"
31
32#include "defun.h"
33#include "error.h"
34#include "errwarn.h"
35#include "ovl.h"
36
37OCTAVE_NAMESPACE_BEGIN
38
39// Find at most N_TO_FIND nonzero elements in NDA. Search forward if
40// DIRECTION is 1, backward if it is -1. NARGOUT is the number of
41// output arguments. If N_TO_FIND is -1, find all nonzero elements.
42
43template <typename T>
45find_nonzero_elem_idx (const Array<T>& nda, int nargout,
46 octave_idx_type n_to_find, int direction)
47{
48 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
49
51 if (n_to_find >= 0)
52 idx = nda.find (n_to_find, direction == -1);
53 else
54 idx = nda.find ();
55
56 // The maximum element is always at the end.
57 octave_idx_type iext = (idx.isempty () ? 0 : idx.xelem (idx.numel () - 1) + 1);
58
59 switch (nargout)
60 {
61 default:
62 case 3:
63 retval(2) = Array<T> (nda.index (idx_vector (idx)));
64 OCTAVE_FALLTHROUGH;
65
66 case 2:
67 {
68 Array<octave_idx_type> jdx (idx.dims ());
69 octave_idx_type n = idx.numel ();
70 octave_idx_type nr = nda.rows ();
71 for (octave_idx_type i = 0; i < n; i++)
72 {
73 jdx.xelem (i) = idx.xelem (i) / nr;
74 idx.xelem (i) %= nr;
75 }
76 iext = -1;
77 retval(1) = idx_vector (jdx, -1);
78 }
79 OCTAVE_FALLTHROUGH;
80
81 case 1:
82 case 0:
83 retval(0) = idx_vector (idx, iext);
84 break;
85 }
86
87 return retval;
88}
89
90template <typename T>
92find_nonzero_elem_idx (const Sparse<T>& v, int nargout,
93 octave_idx_type n_to_find, int direction)
94{
95 nargout = std::min (nargout, 5);
96 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
97
98 octave_idx_type nr = v.rows ();
99 octave_idx_type nc = v.cols ();
100 octave_idx_type nz = v.nnz ();
101
102 // Search in the default range.
103 octave_idx_type start_nc = -1;
104 octave_idx_type end_nc = -1;
105 octave_idx_type count;
106
107 // Search for the range to search
108 if (n_to_find < 0)
109 {
110 start_nc = 0;
111 end_nc = nc;
112 n_to_find = nz;
113 }
114 else if (direction > 0)
115 {
116 for (octave_idx_type j = 0; j < nc; j++)
117 {
118 octave_quit ();
119
120 if (v.cidx (j) == 0 && v.cidx (j+1) != 0)
121 start_nc = j;
122 if (v.cidx (j+1) >= n_to_find)
123 {
124 end_nc = j + 1;
125 break;
126 }
127 }
128 }
129 else
130 {
131 for (octave_idx_type j = nc; j > 0; j--)
132 {
133 octave_quit ();
134
135 if (v.cidx (j) == nz && v.cidx (j-1) != nz)
136 end_nc = j;
137 if (nz - v.cidx (j-1) >= n_to_find)
138 {
139 start_nc = j - 1;
140 break;
141 }
142 }
143 }
144
145 count = (n_to_find > v.cidx (end_nc) - v.cidx (start_nc) ?
146 v.cidx (end_nc) - v.cidx (start_nc) : n_to_find);
147
148 octave_idx_type result_nr;
149 octave_idx_type result_nc;
150
151 // Default case is to return a column vector, however, if the original
152 // argument was a row vector, then force return of a row vector.
153 if (nr == 1)
154 {
155 result_nr = 1;
156 result_nc = count;
157 }
158 else
159 {
160 result_nr = count;
161 result_nc = 1;
162 }
163
164 Matrix idx (result_nr, result_nc);
165
166 Matrix i_idx (result_nr, result_nc);
167 Matrix j_idx (result_nr, result_nc);
168
169 Array<T> val (dim_vector (result_nr, result_nc));
170
171 if (count > 0)
172 {
173 // Search for elements to return. Only search the region where there
174 // are elements to be found using the count that we want to find.
175 for (octave_idx_type j = start_nc, cx = 0; j < end_nc; j++)
176 for (octave_idx_type i = v.cidx (j); i < v.cidx (j+1); i++)
177 {
178 octave_quit ();
179
180 if (direction < 0 && i < nz - count)
181 continue;
182 i_idx(cx) = static_cast<double> (v.ridx (i) + 1);
183 j_idx(cx) = static_cast<double> (j + 1);
184 idx(cx) = j * nr + v.ridx (i) + 1;
185 val(cx) = v.data(i);
186 cx++;
187 if (cx == count)
188 break;
189 }
190 }
191 else
192 {
193 // No items found. Fixup return dimensions for Matlab compatibility.
194 // The behavior to match is documented in Array.cc (Array<T>::find).
195 if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1))
196 {
197 idx.resize (0, 0);
198
199 i_idx.resize (0, 0);
200 j_idx.resize (0, 0);
201
202 val.resize (dim_vector (0, 0));
203 }
204 }
205
206 switch (nargout)
207 {
208 case 0:
209 case 1:
210 retval(0) = idx;
211 break;
212
213 case 5:
214 retval(4) = nc;
215 OCTAVE_FALLTHROUGH;
216
217 case 4:
218 retval(3) = nr;
219 OCTAVE_FALLTHROUGH;
220
221 case 3:
222 retval(2) = val;
223 OCTAVE_FALLTHROUGH;
224
225 case 2:
226 retval(1) = j_idx;
227 retval(0) = i_idx;
228 }
229
230 return retval;
231}
232
234find_nonzero_elem_idx (const PermMatrix& v, int nargout,
235 octave_idx_type n_to_find, int direction)
236{
237 // There are far fewer special cases to handle for a PermMatrix.
238 nargout = std::min (nargout, 5);
239 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
240
241 octave_idx_type nr = v.rows ();
242 octave_idx_type nc = v.cols ();
243 octave_idx_type start_nc, count;
244
245 // Determine the range to search.
246 if (n_to_find < 0 || n_to_find >= nc)
247 {
248 start_nc = 0;
249 count = nc;
250 }
251 else if (direction > 0)
252 {
253 start_nc = 0;
254 count = n_to_find;
255 }
256 else
257 {
258 start_nc = nc - n_to_find;
259 count = n_to_find;
260 }
261
262 Matrix idx (count, 1);
263 Matrix i_idx (count, 1);
264 Matrix j_idx (count, 1);
265 // Every value is 1.
266 Array<double> val (dim_vector (count, 1), 1.0);
267
268 if (count > 0)
269 {
270 const Array<octave_idx_type>& p = v.col_perm_vec ();
271 for (octave_idx_type k = 0; k < count; k++)
272 {
273 octave_quit ();
274
275 const octave_idx_type j = start_nc + k;
276 const octave_idx_type i = p(j);
277 i_idx(k) = static_cast<double> (1+i);
278 j_idx(k) = static_cast<double> (1+j);
279 idx(k) = j * nc + i + 1;
280 }
281 }
282 else
283 {
284 // FIXME: Is this case even possible? A scalar permutation matrix seems
285 // to devolve to a scalar full matrix, at least from the Octave command
286 // line. Perhaps this function could be called internally from C++ with
287 // such a matrix.
288 // No items found. Fixup return dimensions for Matlab compatibility.
289 // The behavior to match is documented in Array.cc (Array<T>::find).
290 if ((nr == 0 && nc == 0) || (nr == 1 && nc == 1))
291 {
292 idx.resize (0, 0);
293
294 i_idx.resize (0, 0);
295 j_idx.resize (0, 0);
296
297 val.resize (dim_vector (0, 0));
298 }
299 }
300
301 switch (nargout)
302 {
303 case 0:
304 case 1:
305 retval(0) = idx;
306 break;
307
308 case 5:
309 retval(4) = nc;
310 OCTAVE_FALLTHROUGH;
311
312 case 4:
313 retval(3) = nc;
314 OCTAVE_FALLTHROUGH;
315
316 case 3:
317 retval(2) = val;
318 OCTAVE_FALLTHROUGH;
319
320 case 2:
321 retval(1) = j_idx;
322 retval(0) = i_idx;
323 }
324
325 return retval;
326}
327
328DEFUN (find, args, nargout,
329 doc: /* -*- texinfo -*-
330@deftypefn {} {@var{idx} =} find (@var{x})
331@deftypefnx {} {@var{idx} =} find (@var{x}, @var{n})
332@deftypefnx {} {@var{idx} =} find (@var{x}, @var{n}, @var{direction})
333@deftypefnx {} {[i, j] =} find (@dots{})
334@deftypefnx {} {[i, j, v] =} find (@dots{})
335Return a vector of indices of nonzero elements of a matrix, as a row if
336@var{x} is a row vector or as a column otherwise.
337
338To obtain a single index for each matrix element, Octave pretends that the
339columns of a matrix form one long vector (like Fortran arrays are stored).
340For example:
341
342@example
343@group
344find (eye (2))
345 @result{} [ 1; 4 ]
346@end group
347@end example
348
349If two inputs are given, @var{n} indicates the maximum number of elements to
350find from the beginning of the matrix or vector.
351
352If three inputs are given, @var{direction} should be one of
353@qcode{"first"} or @qcode{"last"}, requesting only the first or last
354@var{n} indices, respectively. However, the indices are always returned in
355ascending order.
356
357If two outputs are requested, @code{find} returns the row and column
358indices of nonzero elements of a matrix. For example:
359
360@example
361@group
362[i, j] = find (2 * eye (2))
363 @result{} i = [ 1; 2 ]
364 @result{} j = [ 1; 2 ]
365@end group
366@end example
367
368If three outputs are requested, @code{find} also returns a vector
369containing the nonzero values. For example:
370
371@example
372@group
373[i, j, v] = find (3 * eye (2))
374 @result{} i = [ 1; 2 ]
375 @result{} j = [ 1; 2 ]
376 @result{} v = [ 3; 3 ]
377@end group
378@end example
379
380If @var{x} is a multi-dimensional array of size m x n x p x @dots{}, @var{j}
381contains the column locations as if @var{x} was flattened into a
382two-dimensional matrix of size m x (n + p + @dots{}).
383
384Note that this function is particularly useful for sparse matrices, as
385it extracts the nonzero elements as vectors, which can then be used to
386create the original matrix. For example:
387
388@example
389@group
390sz = size (a);
391[i, j, v] = find (a);
392b = sparse (i, j, v, sz(1), sz(2));
393@end group
394@end example
395@seealso{nonzeros}
396@end deftypefn */)
397{
398 int nargin = args.length ();
399
400 if (nargin < 1 || nargin > 3)
401 print_usage ();
402
403 // Setup the default options.
404 octave_idx_type n_to_find = -1;
405 if (nargin > 1)
406 {
407 double val = args(1).xscalar_value ("find: N must be an integer");
408
409 if (val < 0 || (! math::isinf (val)
410 && val != math::fix (val)))
411 error ("find: N must be a non-negative integer");
412 else if (! math::isinf (val))
413 n_to_find = val;
414 }
415
416 // Direction to do the searching (1 == forward, -1 == reverse).
417 int direction = 1;
418 if (nargin > 2)
419 {
420 std::string s_arg = args(2).string_value ();
421
422 if (s_arg == "first")
423 direction = 1;
424 else if (s_arg == "last")
425 direction = -1;
426 else
427 error (R"(find: DIRECTION must be "first" or "last")");
428 }
429
430 octave_value_list retval;
431
432 octave_value arg = args(0);
433
434 if (arg.islogical ())
435 {
436 if (arg.issparse ())
437 {
439
440 retval = find_nonzero_elem_idx (v, nargout, n_to_find, direction);
441 }
442 else if (nargout <= 1 && n_to_find == -1 && direction == 1)
443 {
444 // This case is equivalent to extracting indices from a logical
445 // matrix. Try to reuse the possibly cached index vector.
446
447 // No need to catch index_exception, since arg is bool.
448 // Out-of-range errors have already set pos, and will be
449 // caught later.
450
451 octave_value result = arg.index_vector ().unmask ();
452
453 dim_vector dv = result.dims ();
454
455 retval(0) = (dv.all_zero () || dv.isvector ()
456 ? result : result.reshape (dv.as_column ()));
457 }
458 else
459 {
460 boolNDArray v = arg.bool_array_value ();
461
462 retval = find_nonzero_elem_idx (v, nargout, n_to_find, direction);
463 }
464 }
465 else if (arg.isinteger ())
466 {
467#define DO_INT_BRANCH(INTT) \
468 else if (arg.is_ ## INTT ## _type ()) \
469 { \
470 INTT ## NDArray v = arg.INTT ## _array_value (); \
471 \
472 retval = find_nonzero_elem_idx (v, nargout, n_to_find, direction); \
473 }
474
475 if (false)
476 ;
477 DO_INT_BRANCH (int8)
478 DO_INT_BRANCH (int16)
479 DO_INT_BRANCH (int32)
480 DO_INT_BRANCH (int64)
481 DO_INT_BRANCH (uint8)
482 DO_INT_BRANCH (uint16)
483 DO_INT_BRANCH (uint32)
484 DO_INT_BRANCH (uint64)
485 else
487 }
488 else if (arg.issparse ())
489 {
490 if (arg.isreal ())
491 {
493
494 retval = find_nonzero_elem_idx (v, nargout, n_to_find, direction);
495 }
496 else if (arg.iscomplex ())
497 {
499
500 retval = find_nonzero_elem_idx (v, nargout, n_to_find, direction);
501 }
502 else
503 err_wrong_type_arg ("find", arg);
504 }
505 else if (arg.is_perm_matrix ())
506 {
507 PermMatrix P = arg.perm_matrix_value ();
508
509 retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
510 }
511 else if (arg.is_string ())
512 {
513 charNDArray chnda = arg.char_array_value ();
514
515 retval = find_nonzero_elem_idx (chnda, nargout, n_to_find, direction);
516 }
517 else if (arg.is_single_type ())
518 {
519 if (arg.isreal ())
520 {
521 FloatNDArray nda = arg.float_array_value ();
522
523 retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction);
524 }
525 else if (arg.iscomplex ())
526 {
528
529 retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction);
530 }
531 }
532 else if (arg.isreal ())
533 {
534 NDArray nda = arg.array_value ();
535
536 retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction);
537 }
538 else if (arg.iscomplex ())
539 {
541
542 retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction);
543 }
544 else
545 err_wrong_type_arg ("find", arg);
546
547 return retval;
548}
549
/*
%!assert (find (char ([0, 97])), 2)
%!assert (find ([1, 0, 1, 0, 1]), [1, 3, 5])
%!assert (find ([1; 0; 3; 0; 1]), [1; 3; 5])
%!assert (find ([0, 0, 2; 0, 3, 0; -1, 0, 0]), [3; 5; 7])

%!assert <*53603> (find (ones (1,1,2) > 0), [1;2])
%!assert <*53603> (find (ones (1,1,1,3) > 0), [1;2;3])

%!test
%! [i, j, v] = find ([0, 0, 2; 0, 3, 0; -1, 0, 0]);
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, [-1; 3; 2]);

%!assert (find (single ([1, 0, 1, 0, 1])), [1, 3, 5])
%!assert (find (single ([1; 0; 3; 0; 1])), [1; 3; 5])
%!assert (find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0])), [3; 5; 7])

%!test
%! [i, j, v] = find (single ([0, 0, 2; 0, 3, 0; -1, 0, 0]));
%!
%! assert (i, [3; 2; 1]);
%! assert (j, [1; 2; 3]);
%! assert (v, single ([-1; 3; 2]));

%!test
%! pcol = [5 1 4 3 2];
%! P = eye (5) (:, pcol);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!test
%! prow = [5 1 4 3 2];
%! P = eye (5) (prow, :);
%! [i, j, v] = find (P);
%! [ifull, jfull, vfull] = find (full (P));
%! assert (i, ifull);
%! assert (j, jfull);
%! assert (all (v == 1));

%!test <*61986>
%! P = cat (3, eye(3), eye(3));
%! loc = find (P);
%! [i, j, v] = find(P);
%! assert (loc, [1, 5, 9, 10, 14, 18]');
%! assert (i, [1, 2, 3, 1, 2, 3]');
%! assert (j, [1, 2, 3, 4, 5, 6]');
%! assert (v, [1, 1, 1, 1, 1, 1]');

%!assert <*53655> (find (false), zeros (0, 0))
%!assert <*53655> (find ([false, false]), zeros (1, 0))
%!assert <*53655> (find ([false; false]), zeros (0, 1))
%!assert <*53655> (find ([false, false; false, false]), zeros (0, 1))

%!assert (find ([2 0 1 0 5 0], 1), 1)
%!assert (find ([2 0 1 0 5 0], 2, "last"), [3, 5])

%!assert (find ([2 0 1 0 5 0], Inf), [1, 3, 5])
%!assert (find ([2 0 1 0 5 0], Inf, "last"), [1, 3, 5])

%!error find ()
*/
617
618OCTAVE_NAMESPACE_END
charNDArray min(char d, const charNDArray &m)
Definition: chNDArray.cc:207
T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:504
OCTARRAY_API Array< T, Alloc > index(const octave::idx_vector &i) const
Indexing without resizing.
Definition: Array.cc:697
octave_idx_type numel(void) const
Number of elements in the array.
Definition: Array.h:411
const dim_vector & dims(void) const
Return a const-reference so that dims ()(i) works efficiently.
Definition: Array.h:487
octave_idx_type rows(void) const
Definition: Array.h:449
bool isempty(void) const
Size of the specified dimension.
Definition: Array.h:607
OCTARRAY_API void resize(const dim_vector &dv, const T &rfv)
Size of the specified dimension.
Definition: Array.cc:1010
OCTARRAY_API Array< octave_idx_type > find(octave_idx_type n=-1, bool backward=false) const
Find indices of (at most n) nonzero elements.
Definition: Array.cc:2218
Definition: dMatrix.h:42
void resize(octave_idx_type nr, octave_idx_type nc, double rfv=0)
Definition: dMatrix.h:158
octave_idx_type cols(void) const
Definition: PermMatrix.h:63
octave_idx_type rows(void) const
Definition: PermMatrix.h:62
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:83
octave_idx_type rows(void) const
Definition: Sparse.h:351
T * data(void)
Definition: Sparse.h:574
octave_idx_type nnz(void) const
Actual number of nonzero terms.
Definition: Sparse.h:339
octave_idx_type * ridx(void)
Definition: Sparse.h:583
octave_idx_type * cidx(void)
Definition: Sparse.h:596
octave_idx_type cols(void) const
Definition: Sparse.h:352
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
bool isvector(void) const
Definition: dim-vector.h:395
dim_vector as_column(void) const
Definition: dim-vector.h:379
bool all_zero(void) const
Definition: dim-vector.h:300
OCTAVE_API idx_vector unmask(void) const
Definition: idx-vector.cc:1158
boolNDArray bool_array_value(bool warn=false) const
Definition: ov.h:936
SparseMatrix sparse_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:945
bool isreal(void) const
Definition: ov.h:783
bool issparse(void) const
Definition: ov.h:798
octave::idx_vector index_vector(bool require_integers=false) const
Definition: ov.h:579
bool is_string(void) const
Definition: ov.h:682
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:923
charNDArray char_array_value(bool frc_str_conv=false) const
Definition: ov.h:942
bool isinteger(void) const
Definition: ov.h:775
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:616
PermMatrix perm_matrix_value(void) const
Definition: ov.h:968
SparseBoolMatrix sparse_bool_matrix_value(bool warn=false) const
Definition: ov.h:952
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:904
bool is_single_type(void) const
Definition: ov.h:743
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:927
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:907
bool is_perm_matrix(void) const
Definition: ov.h:679
bool iscomplex(void) const
Definition: ov.h:786
bool islogical(void) const
Definition: ov.h:780
dim_vector dims(void) const
Definition: ov.h:586
SparseComplexMatrix sparse_complex_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:949
static octave_idx_type find(octave_idx_type i, octave_idx_type *pp)
Definition: colamd.cc:106
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void error(const char *fmt,...)
Definition: error.cc:980
#define panic_impossible()
Definition: error.h:411
void err_wrong_type_arg(const char *name, const char *s)
Definition: errwarn.cc:166
#define DO_INT_BRANCH(INTT)
OCTAVE_NAMESPACE_BEGIN octave_value_list find_nonzero_elem_idx(const Array< T > &nda, int nargout, octave_idx_type n_to_find, int direction)
Definition: find.cc:45
octave::idx_vector idx_vector
Definition: idx-vector.h:1037
double fix(double x)
Definition: lo-mappers.h:118
bool isinf(double x)
Definition: lo-mappers.h:203