GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
svd.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1996-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "svd.h"
31 
32 #include "defun.h"
33 #include "error.h"
34 #include "errwarn.h"
35 #include "ovl.h"
36 #include "pr-output.h"
37 #include "utils.h"
38 #include "variables.h"
39 
41 
42 static std::string Vsvd_driver = "gesvd";
43 
44 template <typename T>
45 static typename math::svd<T>::Type
46 svd_type (int nargin, int nargout, const octave_value_list& args, const T& A)
47 {
48  if (nargout == 0 || nargout == 1)
49  return math::svd<T>::Type::sigma_only;
50  else if (nargin == 1)
51  return math::svd<T>::Type::std;
52  else if (! args(1).is_real_scalar ())
53  return math::svd<T>::Type::economy;
54  else
55  {
56  if (A.rows () > A.columns ())
57  return math::svd<T>::Type::economy;
58  else
59  return math::svd<T>::Type::std;
60  }
61 }
62 
63 template <typename T>
64 static typename math::svd<T>::Driver
65 svd_driver ()
66 {
67  if (Vsvd_driver == "gejsv")
68  return math::svd<T>::Driver::GEJSV;
69  else if (Vsvd_driver == "gesdd")
70  return math::svd<T>::Driver::GESDD;
71  else
72  return math::svd<T>::Driver::GESVD; // default
73 }
74 
75 DEFUN (svd, args, nargout,
76  classes: double single
77  doc: /* -*- texinfo -*-
78 @deftypefn {} {@var{s} =} svd (@var{A})
79 @deftypefnx {} {[@var{U}, @var{S}, @var{V}] =} svd (@var{A})
80 @deftypefnx {} {[@var{U}, @var{S}, @var{V}] =} svd (@var{A}, "econ")
81 @deftypefnx {} {[@var{U}, @var{S}, @var{V}] =} svd (@var{A}, 0)
82 @cindex singular value decomposition
83 Compute the singular value decomposition of @var{A}.
84 
85 The singular value decomposition is defined by the relation
86 
87 @tex
88 $$
89  A = U S V^{\dagger}
90 $$
91 @end tex
92 @ifnottex
93 
94 @example
95 A = U*S*V'
96 @end example
97 
98 @end ifnottex
99 
100 The function @code{svd} normally returns only the vector of singular values.
101 When called with three return values, it computes
102 @tex
103 $U$, $S$, and $V$.
104 @end tex
105 @ifnottex
106 @var{U}, @var{S}, and @var{V}.
107 @end ifnottex
108 For example,
109 
110 @example
111 svd (hilb (3))
112 @end example
113 
114 @noindent
115 returns
116 
117 @example
118 @group
119 ans =
120 
121  1.4083189
122  0.1223271
123  0.0026873
124 @end group
125 @end example
126 
127 @noindent
128 and
129 
130 @example
131 [u, s, v] = svd (hilb (3))
132 @end example
133 
134 @noindent
135 returns
136 
137 @example
138 @group
139 u =
140 
141  -0.82704 0.54745 0.12766
142  -0.45986 -0.52829 -0.71375
143  -0.32330 -0.64901 0.68867
144 
145 s =
146 
147  1.40832 0.00000 0.00000
148  0.00000 0.12233 0.00000
149  0.00000 0.00000 0.00269
150 
151 v =
152 
153  -0.82704 0.54745 0.12766
154  -0.45986 -0.52829 -0.71375
155  -0.32330 -0.64901 0.68867
156 @end group
157 @end example
158 
159 When given a second argument that is not 0, @code{svd} returns an economy-sized
160 decomposition, eliminating the unnecessary rows or columns of @var{U} or
161 @var{V}.
162 
163 If the second argument is exactly 0, then the choice of decomposition is based
164 on the matrix @var{A}. If @var{A} has more rows than columns then an
165 economy-sized decomposition is returned, otherwise a regular decomposition
166 is calculated.
167 
168 Algorithm Notes: When calculating the full decomposition (left and right
169 singular matrices in addition to singular values) there is a choice of two
170 routines in @sc{lapack}. The default routine used by Octave is @code{gesvd}.
171 The alternative is @code{gesdd} which is 5X faster, but may use more memory
172 and may be inaccurate for some input matrices. There is a third routine
173 @code{gejsv}, suitable for better accuracy at extreme scale. See the
174 documentation for @code{svd_driver} for more information on choosing a driver.
175 @seealso{svd_driver, svds, eig, lu, chol, hess, qr, qz}
176 @end deftypefn */)
177 {
178  int nargin = args.length ();
179 
180  if (nargin < 1 || nargin > 2 || nargout > 3)
181  print_usage ();
182 
183  octave_value arg = args(0);
184 
185  if (arg.ndims () != 2)
186  error ("svd: A must be a 2-D matrix");
187 
188  octave_value_list retval;
189 
190  bool isfloat = arg.is_single_type ();
191 
192  if (isfloat)
193  {
194  if (arg.isreal ())
195  {
196  FloatMatrix tmp = arg.float_matrix_value ();
197 
198  if (tmp.any_element_is_inf_or_nan ())
199  error ("svd: cannot take SVD of matrix containing Inf or NaN values");
200 
201  math::svd<FloatMatrix> result
202  (tmp,
203  svd_type<FloatMatrix> (nargin, nargout, args, tmp),
204  svd_driver<FloatMatrix> ());
205 
206  FloatDiagMatrix sigma = result.singular_values ();
207 
208  if (nargout == 0 || nargout == 1)
209  retval(0) = sigma.extract_diag ();
210  else if (nargout == 2)
211  retval = ovl (result.left_singular_matrix (),
212  sigma);
213  else
214  retval = ovl (result.left_singular_matrix (),
215  sigma,
216  result.right_singular_matrix ());
217  }
218  else if (arg.iscomplex ())
219  {
221 
222  if (ctmp.any_element_is_inf_or_nan ())
223  error ("svd: cannot take SVD of matrix containing Inf or NaN values");
224 
225  math::svd<FloatComplexMatrix> result
226  (ctmp,
227  svd_type<FloatComplexMatrix> (nargin, nargout, args, ctmp),
228  svd_driver<FloatComplexMatrix> ());
229 
230  FloatDiagMatrix sigma = result.singular_values ();
231 
232  if (nargout == 0 || nargout == 1)
233  retval(0) = sigma.extract_diag ();
234  else if (nargout == 2)
235  retval = ovl (result.left_singular_matrix (),
236  sigma);
237  else
238  retval = ovl (result.left_singular_matrix (),
239  sigma,
240  result.right_singular_matrix ());
241  }
242  }
243  else
244  {
245  if (arg.isreal ())
246  {
247  Matrix tmp = arg.matrix_value ();
248 
249  if (tmp.any_element_is_inf_or_nan ())
250  error ("svd: cannot take SVD of matrix containing Inf or NaN values");
251 
252  math::svd<Matrix> result
253  (tmp,
254  svd_type<Matrix> (nargin, nargout, args, tmp),
255  svd_driver<Matrix> ());
256 
257  DiagMatrix sigma = result.singular_values ();
258 
259  if (nargout == 0 || nargout == 1)
260  retval(0) = sigma.extract_diag ();
261  else if (nargout == 2)
262  retval = ovl (result.left_singular_matrix (),
263  sigma);
264  else
265  retval = ovl (result.left_singular_matrix (),
266  sigma,
267  result.right_singular_matrix ());
268  }
269  else if (arg.iscomplex ())
270  {
271  ComplexMatrix ctmp = arg.complex_matrix_value ();
272 
273  if (ctmp.any_element_is_inf_or_nan ())
274  error ("svd: cannot take SVD of matrix containing Inf or NaN values");
275 
276  math::svd<ComplexMatrix> result
277  (ctmp,
278  svd_type<ComplexMatrix> (nargin, nargout, args, ctmp),
279  svd_driver<ComplexMatrix> ());
280 
281  DiagMatrix sigma = result.singular_values ();
282 
283  if (nargout == 0 || nargout == 1)
284  retval(0) = sigma.extract_diag ();
285  else if (nargout == 2)
286  retval = ovl (result.left_singular_matrix (),
287  sigma);
288  else
289  retval = ovl (result.left_singular_matrix (),
290  sigma,
291  result.right_singular_matrix ());
292  }
293  else
294  err_wrong_type_arg ("svd", arg);
295  }
296 
297  return retval;
298 }
299 
300 /*
301 %!assert (svd ([1, 2; 2, 1]), [3; 1], sqrt (eps))
302 
303 %!test
304 %! a = [1, 2; 3, 4] + [5, 6; 7, 8]*i;
305 %! [u,s,v] = svd (a);
306 %! assert (a, u * s * v', 128 * eps);
307 
308 %!test
309 %! [u, s, v] = svd ([1, 2; 2, 1]);
310 %! x = 1 / sqrt (2);
311 %! assert (u, [-x, -x; -x, x], sqrt (eps));
312 %! assert (s, [3, 0; 0, 1], sqrt (eps));
313 %! assert (v, [-x, x; -x, -x], sqrt (eps));
314 
315 %!test
316 %! a = [1, 2, 3; 4, 5, 6];
317 %! [u, s, v] = svd (a);
318 %! assert (u * s * v', a, sqrt (eps));
319 
320 %!test
321 %! a = [1, 2; 3, 4; 5, 6];
322 %! [u, s, v] = svd (a);
323 %! assert (u * s * v', a, sqrt (eps));
324 
325 %!test
326 %! a = [1, 2, 3; 4, 5, 6];
327 %! [u, s, v] = svd (a, 1);
328 %! assert (u * s * v', a, sqrt (eps));
329 
330 %!test
331 %! a = [1, 2; 3, 4; 5, 6];
332 %! [u, s, v] = svd (a, 1);
333 %! assert (u * s * v', a, sqrt (eps));
334 
335 %!assert (svd (single ([1, 2; 2, 1])), single ([3; 1]), sqrt (eps ("single")))
336 
337 %!test
338 %! [u, s, v] = svd (single ([1, 2; 2, 1]));
339 %! x = single (1 / sqrt (2));
340 %! assert (u, [-x, -x; -x, x], sqrt (eps ("single")));
341 %! assert (s, single ([3, 0; 0, 1]), sqrt (eps ("single")));
342 %! assert (v, [-x, x; -x, -x], sqrt (eps ("single")));
343 
344 %!test
345 %! a = single ([1, 2, 3; 4, 5, 6]);
346 %! [u, s, v] = svd (a);
347 %! assert (u * s * v', a, sqrt (eps ("single")));
348 
349 %!test
350 %! a = single ([1, 2; 3, 4; 5, 6]);
351 %! [u, s, v] = svd (a);
352 %! assert (u * s * v', a, sqrt (eps ("single")));
353 
354 %!test
355 %! a = single ([1, 2, 3; 4, 5, 6]);
356 %! [u, s, v] = svd (a, 1);
357 %! assert (u * s * v', a, sqrt (eps ("single")));
358 
359 %!test
360 %! a = single ([1, 2; 3, 4; 5, 6]);
361 %! [u, s, v] = svd (a, 1);
362 %! assert (u * s * v', a, sqrt (eps ("single")));
363 
364 %!test
365 %! a = zeros (0, 5);
366 %! [u, s, v] = svd (a);
367 %! assert (size (u), [0, 0]);
368 %! assert (size (s), [0, 5]);
369 %! assert (size (v), [5, 5]);
370 
371 %!test
372 %! a = zeros (5, 0);
373 %! [u, s, v] = svd (a, 1);
374 %! assert (size (u), [5, 0]);
375 %! assert (size (s), [0, 0]);
376 %! assert (size (v), [0, 0]);
377 
378 %!test <*49309>
379 %! [~,~,v] = svd ([1, 1, 1], 0);
380 %! assert (size (v), [3 3]);
381 %! [~,~,v] = svd ([1, 1, 1], "econ");
382 %! assert (size (v), [3 1]);
383 
384 %!assert <*55710> (1 / svd (-0), Inf)
385 
386 %!test
387 %! old_driver = svd_driver ("gejsv");
388 %! s0 = [1e-20; 1e-10; 1]; # only gejsv can pass
389 %! q = sqrt (0.5);
390 %! a = s0 .* [q, 0, -q; -0.5, q, -0.5; 0.5, q, 0.5];
391 %! s1 = svd (a);
392 %! svd_driver (old_driver);
393 %! assert (sort (s1), s0, -10 * eps);
394 
395 %!error svd ()
396 %!error svd ([1, 2; 4, 5], 2, 3)
397 */
398 
399 DEFUN (svd_driver, args, nargout,
400  doc: /* -*- texinfo -*-
401 @deftypefn {} {@var{val} =} svd_driver ()
402 @deftypefnx {} {@var{old_val} =} svd_driver (@var{new_val})
403 @deftypefnx {} {@var{old_val} =} svd_driver (@var{new_val}, "local")
404 Query or set the underlying @sc{lapack} driver used by @code{svd}.
405 
406 Currently recognized values are @qcode{"gesdd"}, @qcode{"gesvd"}, and
407 @qcode{"gejsv"}. The default is @qcode{"gesvd"}.
408 
409 When called from inside a function with the @qcode{"local"} option, the
410 variable is changed locally for the function and any subroutines it calls.
411 The original variable value is restored when exiting the function.
412 
413 Algorithm Notes: The @sc{lapack} library routines @code{gesvd} and @code{gesdd}
414 are different only when calculating the full singular value decomposition (left
415 and right singular matrices as well as singular values). When calculating just
416 the singular values the following discussion is not relevant.
417 
418 The newer @code{gesdd} routine is based on a Divide-and-Conquer algorithm that
419 is 5X faster than the alternative @code{gesvd}, which is based on QR
420 factorization. However, the new algorithm can use significantly more memory.
421 For an @nospell{MxN} input matrix the memory usage is of order O(min(M,N) ^ 2),
422 whereas the alternative is of order O(max(M,N)).
423 
424 The routine @code{gejsv} uses a preconditioned Jacobi SVD algorithm. Unlike
425 @code{gesvd} and @code{gesdd}, in @code{gejsv}, there is no bidiagonalization
426 step that could contaminate accuracy in some extreme cases. Also, @code{gejsv}
427 is known to be optimally accurate in some sense. However, the speed is slower
428 (single threaded at its core) and uses more memory (O(min(M,N) ^ 2 + M + N)).
429 
430 Beyond speed and memory issues, there have been instances where some input
431 matrices were not accurately decomposed by @code{gesdd}. See currently active
432 bug @url{https://savannah.gnu.org/bugs/?55564}. Until these accuracy issues
433 are resolved in a new version of the @sc{lapack} library, the default driver
434 in Octave has been set to @qcode{"gesvd"}.
435 
436 @seealso{svd}
437 @end deftypefn */)
438 {
439  static const char *driver_names[] = { "gesvd", "gesdd", "gejsv", nullptr };
440 
441  return set_internal_variable (Vsvd_driver, args, nargout,
442  "svd_driver", driver_names);
443 }
444 
445 /*
446 %!test
447 %! A = [1+1i, 1-1i, 0; 0, 2, 0; 1i, 1i, 1+2i];
448 %! old_driver = svd_driver ("gesvd");
449 %! [U1, S1, V1] = svd (A);
450 %! svd_driver ("gesdd");
451 %! [U2, S2, V2] = svd (A);
452 %! svd_driver ("gejsv");
453 %! [U3, S3, V3] = svd (A);
454 %! assert (svd_driver (), "gejsv");
455 %! svd_driver (old_driver);
456 %! assert (U1, U2, 6*eps);
457 %! assert (S1, S2, 6*eps);
458 %! assert (V1, V2, 6*eps);
459 %! z = U1(1,:) ./ U3(1,:);
460 %! assert (U1, U3 .* z, 100*eps);
461 %! assert (S1, S3, 6*eps);
462 %! assert (V1, V3 .* z, 100*eps);
463 */
464 
465 OCTAVE_END_NAMESPACE(octave)
bool any_element_is_inf_or_nan() const
Definition: CNDArray.cc:271
ColumnVector extract_diag(octave_idx_type k=0) const
Definition: dDiagMatrix.h:107
bool any_element_is_inf_or_nan() const
Definition: fCNDArray.cc:271
FloatColumnVector extract_diag(octave_idx_type k=0) const
Definition: fDiagMatrix.h:110
bool any_element_is_inf_or_nan() const
Definition: fNDArray.cc:281
Definition: dMatrix.h:42
bool any_element_is_inf_or_nan() const
Definition: dNDArray.cc:324
octave_idx_type length() const
Definition: ovl.h:113
int ndims() const
Definition: ov.h:551
bool isreal() const
Definition: ov.h:738
ComplexMatrix complex_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:871
bool is_single_type() const
Definition: ov.h:698
FloatMatrix float_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:856
bool iscomplex() const
Definition: ov.h:741
Matrix matrix_value(bool frc_str_conv=false) const
Definition: ov.h:853
FloatComplexMatrix float_complex_matrix_value(bool frc_str_conv=false) const
Definition: ov.h:875
Definition: svd.h:41
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void() error(const char *fmt,...)
Definition: error.cc:988
void err_wrong_type_arg(const char *name, const char *s)
Definition: errwarn.cc:166
F77_RET_T const F77_INT F77_CMPLX * A
octave_value set_internal_variable(bool &var, const octave_value_list &args, int nargout, const char *nm)
Definition: variables.cc:583
octave_value_list ovl(const OV_Args &... args)
Construct an octave_value_list with less typing.
Definition: ovl.h:219