00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <limits>

#include "quit.h"

#include "defun-dld.h"
#include "error.h"
#include "gripes.h"
#include "oct-obj.h"
00033
00034
00035
00036
00037
00038 template <typename T>
00039 octave_value_list
00040 find_nonzero_elem_idx (const Array<T>& nda, int nargout,
00041 octave_idx_type n_to_find, int direction)
00042 {
00043 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
00044
00045 Array<octave_idx_type> idx;
00046 if (n_to_find >= 0)
00047 idx = nda.find (n_to_find, direction == -1);
00048 else
00049 idx = nda.find ();
00050
00051
00052 octave_idx_type iext = idx.is_empty () ? 0 : idx.xelem (idx.numel () - 1) + 1;
00053
00054 switch (nargout)
00055 {
00056 default:
00057 case 3:
00058 retval(2) = Array<T> (nda.index (idx_vector (idx)));
00059
00060
00061 case 2:
00062 {
00063 Array<octave_idx_type> jdx (idx.dims ());
00064 octave_idx_type n = idx.length (), nr = nda.rows ();
00065 for (octave_idx_type i = 0; i < n; i++)
00066 {
00067 jdx.xelem (i) = idx.xelem (i) / nr;
00068 idx.xelem (i) %= nr;
00069 }
00070 iext = -1;
00071 retval(1) = idx_vector (jdx, -1);
00072 }
00073
00074
00075 case 1:
00076 case 0:
00077 retval(0) = idx_vector (idx, iext);
00078 break;
00079 }
00080
00081 return retval;
00082 }
00083
00084 template <typename T>
00085 octave_value_list
00086 find_nonzero_elem_idx (const Sparse<T>& v, int nargout,
00087 octave_idx_type n_to_find, int direction)
00088 {
00089 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
00090
00091
00092 octave_idx_type nc = v.cols();
00093 octave_idx_type nr = v.rows();
00094 octave_idx_type nz = v.nnz();
00095
00096
00097 octave_idx_type start_nc = -1;
00098 octave_idx_type end_nc = -1;
00099 octave_idx_type count;
00100
00101
00102 if (n_to_find < 0)
00103 {
00104 start_nc = 0;
00105 end_nc = nc;
00106 n_to_find = nz;
00107 count = nz;
00108 }
00109 else if (direction > 0)
00110 {
00111 for (octave_idx_type j = 0; j < nc; j++)
00112 {
00113 OCTAVE_QUIT;
00114 if (v.cidx(j) == 0 && v.cidx(j+1) != 0)
00115 start_nc = j;
00116 if (v.cidx(j+1) >= n_to_find)
00117 {
00118 end_nc = j + 1;
00119 break;
00120 }
00121 }
00122 }
00123 else
00124 {
00125 for (octave_idx_type j = nc; j > 0; j--)
00126 {
00127 OCTAVE_QUIT;
00128 if (v.cidx(j) == nz && v.cidx(j-1) != nz)
00129 end_nc = j;
00130 if (nz - v.cidx(j-1) >= n_to_find)
00131 {
00132 start_nc = j - 1;
00133 break;
00134 }
00135 }
00136 }
00137
00138 count = (n_to_find > v.cidx(end_nc) - v.cidx(start_nc) ?
00139 v.cidx(end_nc) - v.cidx(start_nc) : n_to_find);
00140
00141
00142
00143
00144
00145 octave_idx_type result_nr = count;
00146 octave_idx_type result_nc = 1;
00147
00148 bool scalar_arg = false;
00149
00150 if (v.rows () == 1)
00151 {
00152 result_nr = 1;
00153 result_nc = count;
00154
00155 scalar_arg = (v.columns () == 1);
00156 }
00157
00158 Matrix idx (result_nr, result_nc);
00159
00160 Matrix i_idx (result_nr, result_nc);
00161 Matrix j_idx (result_nr, result_nc);
00162
00163 Array<T> val (dim_vector (result_nr, result_nc));
00164
00165 if (count > 0)
00166 {
00167
00168
00169
00170 for (octave_idx_type j = start_nc, cx = 0; j < end_nc; j++)
00171 for (octave_idx_type i = v.cidx(j); i < v.cidx(j+1); i++ )
00172 {
00173 OCTAVE_QUIT;
00174 if (direction < 0 && i < nz - count)
00175 continue;
00176 i_idx(cx) = static_cast<double> (v.ridx(i) + 1);
00177 j_idx(cx) = static_cast<double> (j + 1);
00178 idx(cx) = j * nr + v.ridx(i) + 1;
00179 val(cx) = v.data(i);
00180 cx++;
00181 if (cx == count)
00182 break;
00183 }
00184 }
00185 else if (scalar_arg)
00186 {
00187 idx.resize (0, 0);
00188
00189 i_idx.resize (0, 0);
00190 j_idx.resize (0, 0);
00191
00192 val.resize (dim_vector (0, 0));
00193 }
00194
00195 switch (nargout)
00196 {
00197 case 0:
00198 case 1:
00199 retval(0) = idx;
00200 break;
00201
00202 case 5:
00203 retval(4) = nc;
00204
00205
00206 case 4:
00207 retval(3) = nr;
00208
00209
00210 case 3:
00211 retval(2) = val;
00212
00213
00214 case 2:
00215 retval(1) = j_idx;
00216 retval(0) = i_idx;
00217 break;
00218
00219 default:
00220 panic_impossible ();
00221 break;
00222 }
00223
00224 return retval;
00225 }
00226
00227 octave_value_list
00228 find_nonzero_elem_idx (const PermMatrix& v, int nargout,
00229 octave_idx_type n_to_find, int direction)
00230 {
00231
00232 octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ());
00233
00234 octave_idx_type nc = v.cols();
00235 octave_idx_type start_nc, count;
00236
00237
00238 if (n_to_find < 0 || n_to_find >= nc)
00239 {
00240 start_nc = 0;
00241 n_to_find = nc;
00242 count = nc;
00243 }
00244 else if (direction > 0)
00245 {
00246 start_nc = 0;
00247 count = n_to_find;
00248 }
00249 else
00250 {
00251 start_nc = nc - n_to_find;
00252 count = n_to_find;
00253 }
00254
00255 bool scalar_arg = (v.rows () == 1 && v.cols () == 1);
00256
00257 Matrix idx (count, 1);
00258 Matrix i_idx (count, 1);
00259 Matrix j_idx (count, 1);
00260
00261 Array<double> val (dim_vector (count, 1), 1.0);
00262
00263 if (count > 0)
00264 {
00265 const octave_idx_type* p = v.data ();
00266 if (v.is_col_perm ())
00267 {
00268 for (octave_idx_type k = 0; k < count; k++)
00269 {
00270 OCTAVE_QUIT;
00271 const octave_idx_type j = start_nc + k;
00272 const octave_idx_type i = p[j];
00273 i_idx(k) = static_cast<double> (1+i);
00274 j_idx(k) = static_cast<double> (1+j);
00275 idx(k) = j * nc + i + 1;
00276 }
00277 }
00278 else
00279 {
00280 for (octave_idx_type k = 0; k < count; k++)
00281 {
00282 OCTAVE_QUIT;
00283 const octave_idx_type i = start_nc + k;
00284 const octave_idx_type j = p[i];
00285
00286
00287 const octave_idx_type koff = j - start_nc;
00288 i_idx(koff) = static_cast<double> (1+i);
00289 j_idx(koff) = static_cast<double> (1+j);
00290 idx(koff) = j * nc + i + 1;
00291 }
00292 }
00293 }
00294 else if (scalar_arg)
00295 {
00296
00297 idx.resize (0, 0);
00298 i_idx.resize (0, 0);
00299 j_idx.resize (0, 0);
00300 val.resize (dim_vector (0, 0));
00301 }
00302
00303 switch (nargout)
00304 {
00305 case 0:
00306 case 1:
00307 retval(0) = idx;
00308 break;
00309
00310 case 5:
00311 retval(4) = nc;
00312
00313
00314 case 4:
00315 retval(3) = nc;
00316
00317
00318 case 3:
00319 retval(2) = val;
00320
00321
00322 case 2:
00323 retval(1) = j_idx;
00324 retval(0) = i_idx;
00325 break;
00326
00327 default:
00328 panic_impossible ();
00329 break;
00330 }
00331
00332 return retval;
00333 }
00334
00335 DEFUN_DLD (find, args, nargout,
00336 "-*- texinfo -*-\n\
00337 @deftypefn {Loadable Function} {@var{idx} =} find (@var{x})\n\
00338 @deftypefnx {Loadable Function} {@var{idx} =} find (@var{x}, @var{n})\n\
00339 @deftypefnx {Loadable Function} {@var{idx} =} find (@var{x}, @var{n}, @var{direction})\n\
00340 @deftypefnx {Loadable Function} {[i, j] =} find (@dots{})\n\
00341 @deftypefnx {Loadable Function} {[i, j, v] =} find (@dots{})\n\
00342 Return a vector of indices of nonzero elements of a matrix, as a row if\n\
00343 @var{x} is a row vector or as a column otherwise. To obtain a single index\n\
00344 for each matrix element, Octave pretends that the columns of a matrix form\n\
00345 one long vector (like Fortran arrays are stored). For example:\n\
00346 \n\
00347 @example\n\
00348 @group\n\
00349 find (eye (2))\n\
00350 @result{} [ 1; 4 ]\n\
00351 @end group\n\
00352 @end example\n\
00353 \n\
00354 If two outputs are requested, @code{find} returns the row and column\n\
00355 indices of nonzero elements of a matrix. For example:\n\
00356 \n\
00357 @example\n\
00358 @group\n\
00359 [i, j] = find (2 * eye (2))\n\
00360 @result{} i = [ 1; 2 ]\n\
00361 @result{} j = [ 1; 2 ]\n\
00362 @end group\n\
00363 @end example\n\
00364 \n\
00365 If three outputs are requested, @code{find} also returns a vector\n\
00366 containing the nonzero values. For example:\n\
00367 \n\
00368 @example\n\
00369 @group\n\
00370 [i, j, v] = find (3 * eye (2))\n\
00371 @result{} i = [ 1; 2 ]\n\
00372 @result{} j = [ 1; 2 ]\n\
00373 @result{} v = [ 3; 3 ]\n\
00374 @end group\n\
00375 @end example\n\
00376 \n\
00377 If two inputs are given, @var{n} indicates the maximum number of\n\
00378 elements to find from the beginning of the matrix or vector.\n\
00379 \n\
00380 If three inputs are given, @var{direction} should be one of \"first\" or\n\
00381 \"last\", requesting only the first or last @var{n} indices, respectively.\n\
00382 However, the indices are always returned in ascending order.\n\
00383 \n\
00384 Note that this function is particularly useful for sparse matrices, as\n\
00385 it extracts the non-zero elements as vectors, which can then be used to\n\
00386 create the original matrix. For example:\n\
00387 \n\
00388 @example\n\
00389 @group\n\
00390 sz = size (a);\n\
00391 [i, j, v] = find (a);\n\
00392 b = sparse (i, j, v, sz(1), sz(2));\n\
00393 @end group\n\
00394 @end example\n\
00395 @seealso{nonzeros}\n\
00396 @end deftypefn")
00397 {
00398 octave_value_list retval;
00399
00400 int nargin = args.length ();
00401
00402 if (nargin > 3 || nargin < 1)
00403 {
00404 print_usage ();
00405 return retval;
00406 }
00407
00408
00409 octave_idx_type n_to_find = -1;
00410 if (nargin > 1)
00411 {
00412 double val = args(1).scalar_value ();
00413
00414 if (error_state || (val < 0 || (! xisinf (val) && val != xround (val))))
00415 {
00416 error ("find: N must be a non-negative integer");
00417 return retval;
00418 }
00419 else if (! xisinf (val))
00420 n_to_find = val;
00421 }
00422
00423
00424 int direction = 1;
00425 if (nargin > 2)
00426 {
00427 direction = 0;
00428
00429 std::string s_arg = args(2).string_value ();
00430
00431 if (! error_state)
00432 {
00433 if (s_arg == "first")
00434 direction = 1;
00435 else if (s_arg == "last")
00436 direction = -1;
00437 }
00438
00439 if (direction == 0)
00440 {
00441 error ("find: DIRECTION must be \"first\" or \"last\"");
00442 return retval;
00443 }
00444 }
00445
00446 octave_value arg = args(0);
00447
00448 if (arg.is_bool_type ())
00449 {
00450 if (arg.is_sparse_type ())
00451 {
00452 SparseBoolMatrix v = arg.sparse_bool_matrix_value ();
00453
00454 if (! error_state)
00455 retval = find_nonzero_elem_idx (v, nargout,
00456 n_to_find, direction);
00457 }
00458 else if (nargout <= 1 && n_to_find == -1 && direction == 1)
00459 {
00460
00461
00462 retval(0) = arg.index_vector ().unmask ();
00463 }
00464 else
00465 {
00466 boolNDArray v = arg.bool_array_value ();
00467
00468 if (! error_state)
00469 retval = find_nonzero_elem_idx (v, nargout,
00470 n_to_find, direction);
00471 }
00472 }
00473 else if (arg.is_integer_type ())
00474 {
00475 #define DO_INT_BRANCH(INTT) \
00476 else if (arg.is_ ## INTT ## _type ()) \
00477 { \
00478 INTT ## NDArray v = arg.INTT ## _array_value (); \
00479 \
00480 if (! error_state) \
00481 retval = find_nonzero_elem_idx (v, nargout, \
00482 n_to_find, direction);\
00483 }
00484
00485 if (false)
00486 ;
00487 DO_INT_BRANCH (int8)
00488 DO_INT_BRANCH (int16)
00489 DO_INT_BRANCH (int32)
00490 DO_INT_BRANCH (int64)
00491 DO_INT_BRANCH (uint8)
00492 DO_INT_BRANCH (uint16)
00493 DO_INT_BRANCH (uint32)
00494 DO_INT_BRANCH (uint64)
00495 else
00496 panic_impossible ();
00497 }
00498 else if (arg.is_sparse_type ())
00499 {
00500 if (arg.is_real_type ())
00501 {
00502 SparseMatrix v = arg.sparse_matrix_value ();
00503
00504 if (! error_state)
00505 retval = find_nonzero_elem_idx (v, nargout,
00506 n_to_find, direction);
00507 }
00508 else if (arg.is_complex_type ())
00509 {
00510 SparseComplexMatrix v = arg.sparse_complex_matrix_value ();
00511
00512 if (! error_state)
00513 retval = find_nonzero_elem_idx (v, nargout,
00514 n_to_find, direction);
00515 }
00516 else
00517 gripe_wrong_type_arg ("find", arg);
00518 }
00519 else if (arg.is_perm_matrix ())
00520 {
00521 PermMatrix P = arg.perm_matrix_value ();
00522
00523 if (! error_state)
00524 retval = find_nonzero_elem_idx (P, nargout, n_to_find, direction);
00525 }
00526 else if (arg.is_string ())
00527 {
00528 charNDArray chnda = arg.char_array_value ();
00529
00530 if (! error_state)
00531 retval = find_nonzero_elem_idx (chnda, nargout, n_to_find, direction);
00532 }
00533 else if (arg.is_single_type ())
00534 {
00535 if (arg.is_real_type ())
00536 {
00537 FloatNDArray nda = arg.float_array_value ();
00538
00539 if (! error_state)
00540 retval = find_nonzero_elem_idx (nda, nargout, n_to_find,
00541 direction);
00542 }
00543 else if (arg.is_complex_type ())
00544 {
00545 FloatComplexNDArray cnda = arg.float_complex_array_value ();
00546
00547 if (! error_state)
00548 retval = find_nonzero_elem_idx (cnda, nargout, n_to_find,
00549 direction);
00550 }
00551 }
00552 else if (arg.is_real_type ())
00553 {
00554 NDArray nda = arg.array_value ();
00555
00556 if (! error_state)
00557 retval = find_nonzero_elem_idx (nda, nargout, n_to_find, direction);
00558 }
00559 else if (arg.is_complex_type ())
00560 {
00561 ComplexNDArray cnda = arg.complex_array_value ();
00562
00563 if (! error_state)
00564 retval = find_nonzero_elem_idx (cnda, nargout, n_to_find, direction);
00565 }
00566 else
00567 gripe_wrong_type_arg ("find", arg);
00568
00569 return retval;
00570 }
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603
00604
00605
00606
00607
00608
00609
00610
00611
00612
00613
00614
00615
00616
00617
00618
00619
00620
00621
00622