GNU Octave  8.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
svd.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2023 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include <cassert>
31 
32 #include <algorithm>
33 #include <unordered_map>
34 
35 #include "CMatrix.h"
36 #include "dDiagMatrix.h"
37 #include "dMatrix.h"
38 #include "fCMatrix.h"
39 #include "fDiagMatrix.h"
40 #include "fMatrix.h"
41 #include "lo-error.h"
42 #include "lo-lapack-proto.h"
43 #include "svd.h"
44 
45 // class to compute optimal work space size (lwork) for DGEJSV and SGEJSV
// The class is only used through its static member optimal (); the
// deleted default constructor prevents creating instances.  Each private
// helper wraps one LAPACK routine and returns the value left in work[0];
// optimal () calls them with lwork == -1, i.e. as LAPACK workspace queries.
// Specializations for Matrix and FloatMatrix are defined below.
// NOTE(review): listing lines 48, 68, 73, 78, 84 and 90 are missing from
// this extracted listing (presumably the class name and the leading part
// of several helper declarations) -- verify against the real svd.cc.
46 template<typename T>
47 class
49 {
50 public:
51  gejsv_lwork () = delete;
52
53  // Unfortunately, dgejsv and sgejsv do not provide estimation of 'lwork'.
54  // Thus, we have to estimate it according to corresponding LAPACK
55  // documentation and related source codes (e.g. cgejsv).
56  // In LAPACKE (C interface to LAPACK), the memory handling code in
57  // LAPACKE_dgejsv() (lapacke_dgejsv.c, last visit 2019-02-17) uses
58  // the minimum required working space. In contrast, here the optimal
59  // working space size is computed, at the cost of much longer code.
60
// Returns the LWORK to pass to xGEJSV for the given job options and an
// m x n input matrix.
61  static F77_INT optimal (char& joba, char& jobu, char& jobv,
62  F77_INT m, F77_INT n);
63
64 private:
// Scalar element type of the matrix class T.
65  typedef typename T::element_type P;
66
67  // functions could be called from GEJSV
69  P *a, F77_INT lda,
70  F77_INT *jpvt, P *tau, P *work,
71  F77_INT lwork, F77_INT& info);
72
74  P *a, F77_INT lda,
75  P *tau, P *work,
76  F77_INT lwork, F77_INT& info);
77
79  P *a, F77_INT lda,
80  P *tau, P *work,
81  F77_INT lwork, F77_INT& info);
82
83  static F77_INT ormlq_lwork (char& side, char& trans,
85  P *a, F77_INT lda,
86  P *tau, P *c, F77_INT ldc,
87  P *work, F77_INT lwork, F77_INT& info);
88
89  static F77_INT ormqr_lwork (char& side, char& trans,
91  P *a, F77_INT lda,
92  P *tau, P *c, F77_INT ldc,
93  P *work, F77_INT lwork, F77_INT& info);
94 };
95 
// Each macro below expands to one F77_XFCN call of the LAPACK routine f/F.
// The call arguments (m, n, a, lda, jpvt, tau, work, lwork, info, and for
// the ORM form side, trans, k, c, ldc) are taken by name from the scope in
// which the macro is expanded, i.e. the *_lwork helper's parameters.
96 #define GEJSV_REAL_QP3_LWORK(f, F) \
97  F77_XFCN (f, F, (m, n, a, lda, jpvt, tau, work, lwork, info))
98
// QR-type call without pivoting (used for xGEQRF and xGELQF).
99 #define GEJSV_REAL_QR_LWORK(f, F) \
100  F77_XFCN (f, F, (m, n, a, lda, tau, work, lwork, info))
101
// Orthogonal-multiply call (xORMLQ / xORMQR); side and trans are passed as
// Fortran character arguments, one F77_CHAR_ARG_LEN (1) per character.
102 #define GEJSV_REAL_ORM_LWORK(f, F) \
103  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&side, 1), \
104  F77_CONST_CHAR_ARG2 (&trans, 1), \
105  m, n, k, a, lda, tau, \
106  c, ldc, work, lwork, info \
107  F77_CHAR_ARG_LEN (1) \
108  F77_CHAR_ARG_LEN (1)))
109 
110 // For Matrix
// Double-precision workspace helpers: each calls the named LAPACK routine
// (DGEQP3, DGEQRF, DGELQF) and returns work[0] converted to F77_INT.
// optimal () invokes them with lwork == -1, so per the LAPACK workspace
// query convention work[0] then holds the size LAPACK asks for.
// NOTE(review): the first declarator line of each definition (listing
// lines 113, 124, 135) is missing from this extracted listing.
111 template<>
112 F77_INT
114  P *a, F77_INT lda,
115  F77_INT *jpvt, P *tau, P *work,
116  F77_INT lwork, F77_INT& info)
117 {
118  GEJSV_REAL_QP3_LWORK (dgeqp3, DGEQP3);
119  return static_cast<F77_INT> (work[0]);
120 }
121
122 template<>
123 F77_INT
125  P *a, F77_INT lda,
126  P *tau, P *work,
127  F77_INT lwork, F77_INT& info)
128 {
129  GEJSV_REAL_QR_LWORK (dgeqrf, DGEQRF);
130  return static_cast<F77_INT> (work[0]);
131 }
132
133 template<>
134 F77_INT
136  P *a, F77_INT lda,
137  P *tau, P *work,
138  F77_INT lwork, F77_INT& info)
139 {
140  GEJSV_REAL_QR_LWORK (dgelqf, DGELQF);
141  return static_cast<F77_INT> (work[0]);
142 }
143 
144 template<>
145 F77_INT
146 gejsv_lwork<Matrix>::ormlq_lwork (char& side, char& trans,
147  F77_INT m, F77_INT n, F77_INT k,
148  P *a, F77_INT lda,
149  P *tau, P *c, F77_INT ldc,
150  P *work, F77_INT lwork, F77_INT& info)
151 {
152  GEJSV_REAL_ORM_LWORK (dormlq, DORMLQ);
153  return static_cast<F77_INT> (work[0]);
154 }
155 
156 template<>
157 F77_INT
158 gejsv_lwork<Matrix>::ormqr_lwork (char& side, char& trans,
159  F77_INT m, F77_INT n, F77_INT k,
160  P *a, F77_INT lda,
161  P *tau, P *c, F77_INT ldc,
162  P *work, F77_INT lwork, F77_INT& info)
163 {
164  GEJSV_REAL_ORM_LWORK (dormqr, DORMQR);
165  return static_cast<F77_INT> (work[0]);
166 }
167 
168 // For FloatMatrix
// Single-precision workspace helpers: each calls the named LAPACK routine
// (SGEQP3, SGEQRF, SGELQF) and returns work[0] converted to F77_INT.
// optimal () invokes them with lwork == -1 (LAPACK workspace query).
// NOTE(review): because work[0] is a float, very large sizes may not be
// represented exactly -- confirm the truncating cast is acceptable.
// NOTE(review): the first declarator line of each definition (listing
// lines 171, 182, 193) is missing from this extracted listing.
169 template<>
170 F77_INT
172  P *a, F77_INT lda,
173  F77_INT *jpvt, P *tau, P *work,
174  F77_INT lwork, F77_INT& info)
175 {
176  GEJSV_REAL_QP3_LWORK (sgeqp3, SGEQP3);
177  return static_cast<F77_INT> (work[0]);
178 }
179
180 template<>
181 F77_INT
183  P *a, F77_INT lda,
184  P *tau, P *work,
185  F77_INT lwork, F77_INT& info)
186 {
187  GEJSV_REAL_QR_LWORK (sgeqrf, SGEQRF);
188  return static_cast<F77_INT> (work[0]);
189 }
190
191 template<>
192 F77_INT
194  P *a, F77_INT lda,
195  P *tau, P *work,
196  F77_INT lwork, F77_INT& info)
197 {
198  GEJSV_REAL_QR_LWORK (sgelqf, SGELQF);
199  return static_cast<F77_INT> (work[0]);
200 }
201 
202 template<>
203 F77_INT
204 gejsv_lwork<FloatMatrix>::ormlq_lwork (char& side, char& trans,
205  F77_INT m, F77_INT n, F77_INT k,
206  P *a, F77_INT lda,
207  P *tau, P *c, F77_INT ldc,
208  P *work, F77_INT lwork, F77_INT& info)
209 {
210  GEJSV_REAL_ORM_LWORK (sormlq, SORMLQ);
211  return static_cast<F77_INT> (work[0]);
212 }
213 
214 template<>
215 F77_INT
216 gejsv_lwork<FloatMatrix>::ormqr_lwork (char& side, char& trans,
217  F77_INT m, F77_INT n, F77_INT k,
218  P *a, F77_INT lda,
219  P *tau, P *c, F77_INT ldc,
220  P *work, F77_INT lwork, F77_INT& info)
221 {
222  GEJSV_REAL_ORM_LWORK (sormqr, SORMQR);
223  return static_cast<F77_INT> (work[0]);
224 }
225 
226 #undef GEJSV_REAL_QP3_LWORK
227 #undef GEJSV_REAL_QR_LWORK
228 #undef GEJSV_REAL_ORM_LWORK
229 
230 template<typename T>
231 F77_INT
232 gejsv_lwork<T>::optimal (char& joba, char& jobu, char& jobv,
233  F77_INT m, F77_INT n)
234 {
235  F77_INT lwork = -1;
236  std::vector<P> work (2); // dummy work space
237 
238  // variables that mimic running environment of gejsv
239  F77_INT lda = std::max<F77_INT> (m, 1);
240  F77_INT ierr = 0;
241  char side = 'L';
242  char trans = 'N';
243  std::vector<P> mat_a (1);
244  P *a = mat_a.data (); // dummy input matrix
245  std::vector<F77_INT> vec_jpvt = {0};
246  P *tau = work.data ();
247  P *u = work.data ();
248  P *v = work.data ();
249 
250  bool need_lsvec = jobu == 'U' || jobu == 'F';
251  bool need_rsvec = jobv == 'V' || jobv == 'J';
252 
253  F77_INT lw_pocon = 3 * n; // for [s,d]pocon
254  F77_INT lw_geqp3 = geqp3_lwork (m, n, a, lda, vec_jpvt.data (),
255  tau, work.data (), -1, ierr);
256  F77_INT lw_geqrf = geqrf_lwork (m, n, a, lda,
257  tau, work.data (), -1, ierr);
258 
259  if (! (need_lsvec || need_rsvec) )
260  {
261  // only SIGMA is needed
262  if (! (joba == 'E' || joba == 'G') )
263  lwork = std::max<F77_INT> ({2*m + n, n + lw_geqp3, n + lw_geqrf, 7});
264  else
265  lwork = std::max<F77_INT> ({2*m + n, n + lw_geqp3, n + lw_geqrf,
266  n + n*n + lw_pocon, 7});
267  }
268  else if (need_rsvec && ! need_lsvec)
269  {
270  // SIGMA and the right singular vectors are needed
271  F77_INT lw_gelqf = gelqf_lwork (n, n, a, lda,
272  tau, work.data (), -1, ierr);
273  trans = 'T';
274  F77_INT lw_ormlq = ormlq_lwork (side, trans, n, n, n, a, lda,
275  tau, v, n, work.data (), -1, ierr);
276  lwork = std::max<F77_INT> ({2*m + n, n + lw_geqp3, n + lw_pocon,
277  n + lw_gelqf, 2*n + lw_geqrf, n + lw_ormlq});
278  }
279  else if (need_lsvec && ! need_rsvec)
280  {
281  // SIGMA and the left singular vectors are needed
282  F77_INT n1 = (jobu == 'U') ? n : m; // size of U is m x n1
283  F77_INT lw_ormqr = ormqr_lwork (side, trans, m, n1, n, a, lda,
284  tau, u, m, work.data (), -1, ierr);
285  lwork = std::max<F77_INT> ({2*m + n, n + lw_geqp3, n + lw_pocon,
286  2*n + lw_geqrf, n + lw_ormqr});
287  }
288  else // full SVD is needed
289  {
290  if (jobv == 'V')
291  lwork = std::max (2*m + n, 6*n + 2*n*n);
292  else if (jobv == 'J')
293  lwork = std::max<F77_INT> ({2*m + n, 4*n + n*n, 2*n + n*n + 6});
294 
295  F77_INT n1 = (jobu == 'U') ? n : m; // size of U is m x n1
296  F77_INT lw_ormqr = ormqr_lwork (side, trans, m, n1, n, a, lda,
297  tau, u, m, work.data (), -1, ierr);
298  lwork = std::max (lwork, n + lw_ormqr);
299  }
300 
301  return lwork;
302 }
303 
305 
307 
// left_singular_matrix (): return U (m_left_sm).  When the decomposition
// was requested as svd::Type::sigma_only, U was not computed and the
// (non-returning) liboctave error handler is invoked instead.
// NOTE(review): the declarator line (listing line 310) is missing from
// this extracted listing.
308 template <typename T>
309 T
311 {
312  if (m_type == svd::Type::sigma_only)
313  (*current_liboctave_error_handler)
314  ("svd: U not computed because type == svd::sigma_only");
315
316  return m_left_sm;
317 }
318
// right_singular_matrix (): return V (m_right_sm); errors out for
// svd::Type::sigma_only just like left_singular_matrix ().
// NOTE(review): the declarator line (listing line 321) is missing from
// this extracted listing.
319 template <typename T>
320 T
322 {
323  if (m_type == svd::Type::sigma_only)
324  (*current_liboctave_error_handler)
325  ("svd: V not computed because type == svd::sigma_only");
326
327  return m_right_sm;
328 }
329 
330 // GESVD specializations
331
// GESVD_REAL_STEP / GESVD_COMPLEX_STEP each expand to one F77_XFCN call of
// the named xGESVD routine.  jobu, jobv, m, n, tmp_data, m1, s_vec, u, vt,
// nrow_vt1, work, lwork and info (plus rwork for the complex form) are
// taken by name from the expanding scope; m1 is passed as both LDA and
// LDU.  CMPLX_ARG is the pointer-conversion macro supplied at the call
// site (F77_DBLE_CMPLX_ARG or F77_CMPLX_ARG).
332 #define GESVD_REAL_STEP(f, F) \
333  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&jobu, 1), \
334  F77_CONST_CHAR_ARG2 (&jobv, 1), \
335  m, n, tmp_data, m1, s_vec, u, m1, vt, \
336  nrow_vt1, work.data (), lwork, info \
337  F77_CHAR_ARG_LEN (1) \
338  F77_CHAR_ARG_LEN (1)))
339
340 #define GESVD_COMPLEX_STEP(f, F, CMPLX_ARG) \
341  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&jobu, 1), \
342  F77_CONST_CHAR_ARG2 (&jobv, 1), \
343  m, n, CMPLX_ARG (tmp_data), \
344  m1, s_vec, CMPLX_ARG (u), m1, \
345  CMPLX_ARG (vt), nrow_vt1, \
346  CMPLX_ARG (work.data ()), \
347  lwork, rwork.data (), info \
348  F77_CHAR_ARG_LEN (1) \
349  F77_CHAR_ARG_LEN (1)))
350 
351 // DGESVD
352 template<>
353 OCTAVE_API void
354 svd<Matrix>::gesvd (char& jobu, char& jobv, F77_INT m, F77_INT n,
355  double *tmp_data, F77_INT m1, double *s_vec,
356  double *u, double *vt, F77_INT nrow_vt1,
357  std::vector<double>& work, F77_INT& lwork,
358  F77_INT& info)
359 {
360  GESVD_REAL_STEP (dgesvd, DGESVD);
361 
362  lwork = static_cast<F77_INT> (work[0]);
363  work.reserve (lwork);
364 
365  GESVD_REAL_STEP (dgesvd, DGESVD);
366 }
367 
368 // SGESVD
369 template<>
370 OCTAVE_API void
371 svd<FloatMatrix>::gesvd (char& jobu, char& jobv, F77_INT m, F77_INT n,
372  float *tmp_data, F77_INT m1, float *s_vec,
373  float *u, float *vt, F77_INT nrow_vt1,
374  std::vector<float>& work, F77_INT& lwork,
375  F77_INT& info)
376 {
377  GESVD_REAL_STEP (sgesvd, SGESVD);
378 
379  lwork = static_cast<F77_INT> (work[0]);
380  work.reserve (lwork);
381 
382  GESVD_REAL_STEP (sgesvd, SGESVD);
383 }
384 
385 // ZGESVD
386 template<>
387 OCTAVE_API void
388 svd<ComplexMatrix>::gesvd (char& jobu, char& jobv, F77_INT m, F77_INT n,
389  Complex *tmp_data, F77_INT m1, double *s_vec,
390  Complex *u, Complex *vt, F77_INT nrow_vt1,
391  std::vector<Complex>& work, F77_INT& lwork,
392  F77_INT& info)
393 {
394  std::vector<double> rwork (5 * std::max (m, n));
395 
396  GESVD_COMPLEX_STEP (zgesvd, ZGESVD, F77_DBLE_CMPLX_ARG);
397 
398  lwork = static_cast<F77_INT> (work[0].real ());
399  work.reserve (lwork);
400 
401  GESVD_COMPLEX_STEP (zgesvd, ZGESVD, F77_DBLE_CMPLX_ARG);
402 }
403 
404 // CGESVD
405 template<>
406 OCTAVE_API void
407 svd<FloatComplexMatrix>::gesvd (char& jobu, char& jobv, F77_INT m,
408  F77_INT n, FloatComplex *tmp_data,
409  F77_INT m1, float *s_vec, FloatComplex *u,
410  FloatComplex *vt, F77_INT nrow_vt1,
411  std::vector<FloatComplex>& work,
412  F77_INT& lwork, F77_INT& info)
413 {
414  std::vector<float> rwork (5 * std::max (m, n));
415 
416  GESVD_COMPLEX_STEP (cgesvd, CGESVD, F77_CMPLX_ARG);
417 
418  lwork = static_cast<F77_INT> (work[0].real ());
419  work.reserve (lwork);
420 
421  GESVD_COMPLEX_STEP (cgesvd, CGESVD, F77_CMPLX_ARG);
422 }
423 
424 #undef GESVD_REAL_STEP
425 #undef GESVD_COMPLEX_STEP
426 
427 // GESDD specializations
428
// GESDD_REAL_STEP / GESDD_COMPLEX_STEP each expand to one F77_XFCN call of
// the named xGESDD routine.  jobz, m, n, tmp_data, m1, s_vec, u, vt,
// nrow_vt1, work, lwork, iwork and info (plus rwork for the complex form)
// are taken by name from the expanding scope; iwork is used as a raw
// pointer here, unlike the work/rwork vectors.
429 #define GESDD_REAL_STEP(f, F) \
430  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&jobz, 1), \
431  m, n, tmp_data, m1, s_vec, u, m1, vt, nrow_vt1, \
432  work.data (), lwork, iwork, info \
433  F77_CHAR_ARG_LEN (1)))
434
435 #define GESDD_COMPLEX_STEP(f, F, CMPLX_ARG) \
436  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&jobz, 1), m, n, \
437  CMPLX_ARG (tmp_data), m1, \
438  s_vec, CMPLX_ARG (u), m1, \
439  CMPLX_ARG (vt), nrow_vt1, \
440  CMPLX_ARG (work.data ()), lwork, \
441  rwork.data (), iwork, info \
442  F77_CHAR_ARG_LEN (1)))
443 
444 // DGESDD
445 template<>
446 OCTAVE_API void
447 svd<Matrix>::gesdd (char& jobz, F77_INT m, F77_INT n, double *tmp_data,
448  F77_INT m1, double *s_vec, double *u, double *vt,
449  F77_INT nrow_vt1, std::vector<double>& work,
450  F77_INT& lwork, F77_INT *iwork, F77_INT& info)
451 {
452  GESDD_REAL_STEP (dgesdd, DGESDD);
453 
454  lwork = static_cast<F77_INT> (work[0]);
455  work.reserve (lwork);
456 
457  GESDD_REAL_STEP (dgesdd, DGESDD);
458 }
459 
460 // SGESDD
461 template<>
462 OCTAVE_API void
463 svd<FloatMatrix>::gesdd (char& jobz, F77_INT m, F77_INT n, float *tmp_data,
464  F77_INT m1, float *s_vec, float *u, float *vt,
465  F77_INT nrow_vt1, std::vector<float>& work,
466  F77_INT& lwork, F77_INT *iwork, F77_INT& info)
467 {
468  GESDD_REAL_STEP (sgesdd, SGESDD);
469 
470  lwork = static_cast<F77_INT> (work[0]);
471  work.reserve (lwork);
472 
473  GESDD_REAL_STEP (sgesdd, SGESDD);
474 }
475 
476 // ZGESDD
// Two-pass driver for LAPACK ZGESDD: allocate the real workspace RWORK
// (7*min(m,n) for jobz == 'N', otherwise
// min(m,n)*max(5*min(m,n)+5, 2*max(m,n)+2*min(m,n)+1)), run the workspace
// query, then the decomposition.
// NOTE(review): the declarator line (listing line 479) is missing from
// this extracted listing.
477 template<>
478 OCTAVE_API void
480  Complex *tmp_data, F77_INT m1, double *s_vec,
481  Complex *u, Complex *vt, F77_INT nrow_vt1,
482  std::vector<Complex>& work, F77_INT& lwork,
483  F77_INT *iwork, F77_INT& info)
484 {
485
486  F77_INT min_mn = std::min (m, n);
487  F77_INT max_mn = std::max (m, n);
488
489  F77_INT lrwork;
490  if (jobz == 'N')
491  lrwork = 7*min_mn;
492  else
493  lrwork = min_mn * std::max (5*min_mn+5, 2*max_mn+2*min_mn+1);
494
495  std::vector<double> rwork (lrwork);
496
497  GESDD_COMPLEX_STEP (zgesdd, ZGESDD, F77_DBLE_CMPLX_ARG);
498
499  lwork = static_cast<F77_INT> (work[0].real ());
// NOTE(review): reserve () does not change work.size (), yet the next call
// writes LWORK entries through work.data (); this should be resize (lwork).
500  work.reserve (lwork);
501
502  GESDD_COMPLEX_STEP (zgesdd, ZGESDD, F77_DBLE_CMPLX_ARG);
503 }
504
505 // CGESDD
// Single-precision complex counterpart of the ZGESDD driver above, with
// the same RWORK sizing.
// NOTE(review): the declarator line (listing line 508) is missing from
// this extracted listing.
506 template<>
507 OCTAVE_API void
509  FloatComplex *tmp_data, F77_INT m1,
510  float *s_vec, FloatComplex *u,
511  FloatComplex *vt, F77_INT nrow_vt1,
512  std::vector<FloatComplex>& work,
513  F77_INT& lwork, F77_INT *iwork,
514  F77_INT& info)
515 {
516  F77_INT min_mn = std::min (m, n);
517  F77_INT max_mn = std::max (m, n);
518
519  F77_INT lrwork;
520  if (jobz == 'N')
521  lrwork = 7*min_mn;
522  else
523  lrwork = min_mn * std::max (5*min_mn+5, 2*max_mn+2*min_mn+1);
524  std::vector<float> rwork (lrwork);
525
526  GESDD_COMPLEX_STEP (cgesdd, CGESDD, F77_CMPLX_ARG);
527
528  lwork = static_cast<F77_INT> (work[0].real ());
// NOTE(review): as above, this should be resize (lwork), not reserve.
529  work.reserve (lwork);
530
531  GESDD_COMPLEX_STEP (cgesdd, CGESDD, F77_CMPLX_ARG);
532 }
533
534 #undef GESDD_REAL_STEP
535 #undef GESDD_COMPLEX_STEP
536 
537 // GEJSV specializations
538
// GEJSV_REAL_STEP / GEJSV_COMPLEX_STEP each expand to one F77_XFCN call of
// the named xGEJSV routine.  The six job characters (joba, jobu, jobv,
// jobr, jobt, jobp) and m, n, tmp_data, m1, s_vec, u, v, nrow_v1, work,
// lwork, iwork, info (plus rwork and lrwork for the complex form) are
// taken by name from the expanding scope.  The trailing six
// F77_CHAR_ARG_LEN (1) entries match the six character arguments.
539 #define GEJSV_REAL_STEP(f, F) \
540  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&joba, 1), \
541  F77_CONST_CHAR_ARG2 (&jobu, 1), \
542  F77_CONST_CHAR_ARG2 (&jobv, 1), \
543  F77_CONST_CHAR_ARG2 (&jobr, 1), \
544  F77_CONST_CHAR_ARG2 (&jobt, 1), \
545  F77_CONST_CHAR_ARG2 (&jobp, 1), \
546  m, n, tmp_data, m1, s_vec, u, m1, v, nrow_v1, \
547  work.data (), lwork, iwork.data (), info \
548  F77_CHAR_ARG_LEN (1) \
549  F77_CHAR_ARG_LEN (1) \
550  F77_CHAR_ARG_LEN (1) \
551  F77_CHAR_ARG_LEN (1) \
552  F77_CHAR_ARG_LEN (1) \
553  F77_CHAR_ARG_LEN (1)))
554
555 #define GEJSV_COMPLEX_STEP(f, F, CMPLX_ARG) \
556  F77_XFCN (f, F, (F77_CONST_CHAR_ARG2 (&joba, 1), \
557  F77_CONST_CHAR_ARG2 (&jobu, 1), \
558  F77_CONST_CHAR_ARG2 (&jobv, 1), \
559  F77_CONST_CHAR_ARG2 (&jobr, 1), \
560  F77_CONST_CHAR_ARG2 (&jobt, 1), \
561  F77_CONST_CHAR_ARG2 (&jobp, 1), \
562  m, n, CMPLX_ARG (tmp_data), m1, \
563  s_vec, CMPLX_ARG (u), m1, \
564  CMPLX_ARG (v), nrow_v1, \
565  CMPLX_ARG (work.data ()), lwork, \
566  rwork.data (), lrwork, iwork.data (), info \
567  F77_CHAR_ARG_LEN (1) \
568  F77_CHAR_ARG_LEN (1) \
569  F77_CHAR_ARG_LEN (1) \
570  F77_CHAR_ARG_LEN (1) \
571  F77_CHAR_ARG_LEN (1) \
572  F77_CHAR_ARG_LEN (1)))
573 
574 // DGEJSV
575 template<>
576 void
577 svd<Matrix>::gejsv (char& joba, char& jobu, char& jobv,
578  char& jobr, char& jobt, char& jobp,
579  F77_INT m, F77_INT n,
580  P *tmp_data, F77_INT m1, DM_P *s_vec, P *u,
581  P *v, F77_INT nrow_v1, std::vector<P>& work,
582  F77_INT& lwork, std::vector<F77_INT>& iwork,
583  F77_INT& info)
584 {
585  lwork = gejsv_lwork<Matrix>::optimal (joba, jobu, jobv, m, n);
586  work.reserve (lwork);
587 
588  GEJSV_REAL_STEP (dgejsv, DGEJSV);
589 }
590 
591 // SGEJSV
592 template<>
593 void
594 svd<FloatMatrix>::gejsv (char& joba, char& jobu, char& jobv,
595  char& jobr, char& jobt, char& jobp,
596  F77_INT m, F77_INT n,
597  P *tmp_data, F77_INT m1, DM_P *s_vec, P *u,
598  P *v, F77_INT nrow_v1, std::vector<P>& work,
599  F77_INT& lwork, std::vector<F77_INT>& iwork,
600  F77_INT& info)
601 {
602  lwork = gejsv_lwork<FloatMatrix>::optimal (joba, jobu, jobv, m, n);
603  work.reserve (lwork);
604 
605  GEJSV_REAL_STEP (sgejsv, SGEJSV);
606 }
607 
608 // ZGEJSV
609 template<>
610 void
611 svd<ComplexMatrix>::gejsv (char& joba, char& jobu, char& jobv,
612  char& jobr, char& jobt, char& jobp,
613  F77_INT m, F77_INT n,
614  P *tmp_data, F77_INT m1, DM_P *s_vec, P *u,
615  P *v, F77_INT nrow_v1, std::vector<P>& work,
616  F77_INT& lwork, std::vector<F77_INT>& iwork,
617  F77_INT& info)
618 {
619  F77_INT lrwork = -1; // work space size query
620  std::vector<double> rwork (1);
621  work.reserve (2);
622 
623  GEJSV_COMPLEX_STEP (zgejsv, ZGEJSV, F77_DBLE_CMPLX_ARG);
624 
625  lwork = static_cast<F77_INT> (work[0].real ());
626  work.reserve (lwork);
627 
628  lrwork = static_cast<F77_INT> (rwork[0]);
629  rwork.reserve (lrwork);
630 
631  F77_INT liwork = static_cast<F77_INT> (iwork[0]);
632  iwork.reserve (liwork);
633 
634  GEJSV_COMPLEX_STEP (zgejsv, ZGEJSV, F77_DBLE_CMPLX_ARG);
635 }
636 
637 // CGEJSV
638 template<>
639 void
640 svd<FloatComplexMatrix>::gejsv (char& joba, char& jobu, char& jobv,
641  char& jobr, char& jobt, char& jobp,
642  F77_INT m, F77_INT n, P *tmp_data,
643  F77_INT m1, DM_P *s_vec, P *u, P *v,
644  F77_INT nrow_v1, std::vector<P>& work,
645  F77_INT& lwork,
646  std::vector<F77_INT>& iwork, F77_INT& info)
647 {
648  F77_INT lrwork = -1; // work space size query
649  std::vector<float> rwork (1);
650  work.reserve (2);
651 
652  GEJSV_COMPLEX_STEP (cgejsv, CGEJSV, F77_CMPLX_ARG);
653 
654  lwork = static_cast<F77_INT> (work[0].real ());
655  work.reserve (lwork);
656 
657  lrwork = static_cast<F77_INT> (rwork[0]);
658  rwork.reserve (lrwork);
659 
660  F77_INT liwork = static_cast<F77_INT> (iwork[0]);
661  iwork.reserve (liwork);
662 
663  GEJSV_COMPLEX_STEP (cgejsv, CGEJSV, F77_CMPLX_ARG);
664 }
665 
666 #undef GEJSV_REAL_STEP
667 #undef GEJSV_COMPLEX_STEP
668 
// Construct the SVD of A for the requested decomposition type (std,
// economy, sigma_only) using the requested LAPACK driver (GESVD, GESDD or
// GEJSV).  U ends up in m_left_sm, the singular values in m_sigma and V
// (not V') in m_right_sm.
669 template<typename T>
670 svd<T>::svd (const T& a, svd::Type type, svd::Driver driver)
671  : m_type (type), m_driver (driver), m_left_sm (), m_sigma (),
672  m_right_sm ()
673 {
674  F77_INT info;
675
676  F77_INT m = to_f77_int (a.rows ());
677  F77_INT n = to_f77_int (a.cols ());
678
// Empty input: build the trivial results directly, without calling LAPACK.
679  if (m == 0 || n == 0)
680  {
681  switch (m_type)
682  {
683  case svd::Type::std:
684  m_left_sm = T (m, m, 0);
685  for (F77_INT i = 0; i < m; i++)
686  m_left_sm.xelem (i, i) = 1;
687  m_sigma = DM_T (m, n);
688  m_right_sm = T (n, n, 0);
689  for (F77_INT i = 0; i < n; i++)
690  m_right_sm.xelem (i, i) = 1;
691  break;
692
693  case svd::Type::economy:
694  m_left_sm = T (m, 0, 0);
695  m_sigma = DM_T (0, 0);
696  m_right_sm = T (n, 0, 0);
697  break;
698
// NOTE(review): listing line 699 (presumably a case label) is missing here.
700  default:
701  m_sigma = DM_T (0, 1);
702  break;
703  }
704  return;
705  }
706
// The drivers receive a pointer into this copy, so A itself is untouched.
707  T atmp = a;
708  P *tmp_data = atmp.fortran_vec ();
709
710  F77_INT min_mn = (m < n ? m : n);
711
// Default: full U and V ('A').  The switch below adjusts the job options
// and result sizes for the other decomposition types.
712  char jobu = 'A';
713  char jobv = 'A';
714
715  F77_INT ncol_u = m;
716  F77_INT nrow_vt = n;
717  F77_INT nrow_s = m;
718  F77_INT ncol_s = n;
719
720  switch (m_type)
721  {
722  case svd::Type::economy:
723  jobu = jobv = 'S';
724  ncol_u = nrow_vt = nrow_s = ncol_s = min_mn;
725  break;
726
// NOTE(review): listing line 727 (presumably the sigma_only case label) is
// missing from this extracted listing.
728
729  // Note: for this case, both jobu and jobv should be 'N', but there
730  // seems to be a bug in dgesvd from Lapack V2.0. To demonstrate the
731  // bug, set both jobu and jobv to 'N' and find the singular values of
732  // [eye(3), eye(3)]. The result is [-sqrt(2), -sqrt(2), -sqrt(2)].
733  //
734  // For Lapack 3.0, this problem seems to be fixed.
735
736  jobu = jobv = 'N';
737  ncol_u = nrow_vt = 1;
738  break;
739
740  default:
741  break;
742  }
743
744  if (! (jobu == 'N' || jobu == 'O'))
745  m_left_sm.resize (m, ncol_u);
746
747  P *u = m_left_sm.fortran_vec ();
748
749  m_sigma.resize (nrow_s, ncol_s);
750  DM_P *s_vec = m_sigma.fortran_vec ();
751
// GEJSV returns V while GESVD/GESDD return V' (see the end of this
// function), hence the two possible shapes of m_right_sm.
// NOTE(review): listing line 754 (the condition choosing between them) is
// missing from this extracted listing.
752  if (! (jobv == 'N' || jobv == 'O'))
753  {
755  m_right_sm.resize (n, nrow_vt);
756  else
757  m_right_sm.resize (nrow_vt, n);
758  }
759
760  P *vt = m_right_sm.fortran_vec ();
761
762  // Query _GESVD for the correct dimension of WORK.
763
// lwork == -1 asks the driver helpers to run a workspace query first.
764  F77_INT lwork = -1;
765
766  std::vector<P> work (1);
767
768  const F77_INT f77_int_one = static_cast<F77_INT> (1);
769  F77_INT m1 = std::max (m, f77_int_one);
770  F77_INT nrow_vt1 = std::max (nrow_vt, f77_int_one);
771
// NOTE(review): listing line 772 (presumably the GESVD driver test) is
// missing here.  Also note that info is not inspected after the GESVD or
// GESDD calls, unlike the GEJSV branch -- confirm this is intended.
773  gesvd (jobu, jobv, m, n, tmp_data, m1, s_vec, u, vt, nrow_vt1,
774  work, lwork, info);
775  else if (m_driver == svd::Driver::GESDD)
776  {
777  assert (jobu == jobv);
778  char jobz = jobu;
779
780  std::vector<F77_INT> iwork (8 * std::min (m, n));
781
782  gesdd (jobz, m, n, tmp_data, m1, s_vec, u, vt, nrow_vt1,
783  work, lwork, iwork.data (), info);
784  }
785  else if (m_driver == svd::Driver::GEJSV)
786  {
787  bool transposed = false;
788  if (n > m)
789  {
790  // GEJSV only accepts m >= n, thus we need to transpose here
791  transposed = true;
792
793  std::swap (m, n);
794  m1 = std::max (m, f77_int_one);
795  nrow_vt1 = std::max (n, f77_int_one); // we have m > n
// NOTE(review): listing line 796 (presumably a condition guarding the
// next assignment) is missing from this extracted listing.
797  nrow_vt1 = 1;
798  std::swap (jobu, jobv);
799
800  atmp = atmp.hermitian ();
801  tmp_data = atmp.fortran_vec ();
802
803  // Swap pointers of U and V.
804  u = m_right_sm.fortran_vec ();
805  vt = m_left_sm.fortran_vec ();
806  }
807
808  // translate jobu and jobv from gesvd to gejsv.
809  std::unordered_map<char, std::string> job_svd2jsv;
810  job_svd2jsv['A'] = "FJ";
811  job_svd2jsv['S'] = "UV";
812  job_svd2jsv['O'] = "WW";
813  job_svd2jsv['N'] = "NN";
814  jobu = job_svd2jsv[jobu][0];
815  jobv = job_svd2jsv[jobv][1];
816
817  char joba = 'F'; // 'F': most conservative
818  char jobr = 'R'; // 'R' is recommended.
819  char jobt = 'N'; // or 'T', but that requires U and V appear together
820  char jobp = 'N'; // use 'P' if denormal is poorly implemented.
821
822  std::vector<F77_INT> iwork (std::max<F77_INT> (m + 3*n, 1));
823
824  gejsv (joba, jobu, jobv, jobr, jobt, jobp, m, n, tmp_data, m1,
825  s_vec, u, vt, nrow_vt1, work, lwork, iwork, info);
826
// On return iwork[2] == 1 is used as the denormal warning flag.
827  if (iwork[2] == 1)
828  (*current_liboctave_warning_with_id_handler)
829  ("Octave:convergence", "svd: (driver: GEJSV) "
830  "Denormal occurred, possible loss of accuracy.");
831
832  if (info < 0)
833  (*current_liboctave_error_handler)
834  ("svd: (driver: GEJSV) Illegal argument at #%d",
835  static_cast<int> (-info));
836  else if (info > 0)
837  (*current_liboctave_warning_with_id_handler)
838  ("Octave:convergence", "svd: (driver: GEJSV) "
839  "Fail to converge within max sweeps, "
840  "possible inaccurate result.");
841
842  if (transposed) // put things that need to transpose back here
843  std::swap (m, n);
844  }
845  else
846  (*current_liboctave_error_handler) ("svd: unknown driver");
847
848  // LAPACK can return -0 which is a small problem (bug #55710).
849  for (octave_idx_type i = 0; i < m_sigma.diag_length (); i++)
850  {
851  if (! m_sigma.dgxelem (i))
852  m_sigma.dgxelem (i) = DM_P (0);
853  }
854
855  // GESVD and GESDD return VT instead of V, GEJSV return V.
856  if (! (jobv == 'N' || jobv == 'O') && (m_driver != svd::Driver::GEJSV))
857  m_right_sm = m_right_sm.hermitian ();
858 }
859 
860 // Instantiations we need.
// The member definitions above live in this file only, so every matrix
// type that uses svd<T> must be explicitly instantiated here.
861
862 template class svd<Matrix>;
863
864 template class svd<FloatMatrix>;
865
866 template class svd<ComplexMatrix>;
867
868 template class svd<FloatComplexMatrix>;
869 
OCTAVE_END_NAMESPACE(octave)
charNDArray max(char d, const charNDArray &m)
Definition: chNDArray.cc:230
charNDArray min(char d, const charNDArray &m)
Definition: chNDArray.cc:207
static F77_INT gelqf_lwork(F77_INT m, F77_INT n, P *a, F77_INT lda, P *tau, P *work, F77_INT lwork, F77_INT &info)
static F77_INT ormlq_lwork(char &side, char &trans, F77_INT m, F77_INT n, F77_INT k, P *a, F77_INT lda, P *tau, P *c, F77_INT ldc, P *work, F77_INT lwork, F77_INT &info)
static F77_INT geqp3_lwork(F77_INT m, F77_INT n, P *a, F77_INT lda, F77_INT *jpvt, P *tau, P *work, F77_INT lwork, F77_INT &info)
static F77_INT optimal(char &joba, char &jobu, char &jobv, F77_INT m, F77_INT n)
Definition: svd.cc:232
static F77_INT ormqr_lwork(char &side, char &trans, F77_INT m, F77_INT n, F77_INT k, P *a, F77_INT lda, P *tau, P *c, F77_INT ldc, P *work, F77_INT lwork, F77_INT &info)
static F77_INT geqrf_lwork(F77_INT m, F77_INT n, P *a, F77_INT lda, P *tau, P *work, F77_INT lwork, F77_INT &info)
T::element_type P
Definition: svd.cc:65
gejsv_lwork()=delete
Definition: svd.h:41
T m_left_sm
Definition: svd.h:102
T m_right_sm
Definition: svd.h:104
void gesdd(char &jobz, octave_f77_int_type m, octave_f77_int_type n, P *tmp_data, octave_f77_int_type m1, DM_P *s_vec, P *u, P *vt, octave_f77_int_type nrow_vt1, std::vector< P > &work, octave_f77_int_type &lwork, octave_f77_int_type *iwork, octave_f77_int_type &info)
svd(void)
Definition: svd.h:60
void gejsv(char &joba, char &jobu, char &jobv, char &jobr, char &jobt, char &jobp, octave_f77_int_type m, octave_f77_int_type n, P *tmp_data, octave_f77_int_type m1, DM_P *s_vec, P *u, P *v, octave_f77_int_type nrow_v1, std::vector< P > &work, octave_f77_int_type &lwork, std::vector< octave_f77_int_type > &iwork, octave_f77_int_type &info)
T::element_type P
Definition: svd.h:96
T::real_diag_matrix_type DM_T
Definition: svd.h:44
void gesvd(char &jobu, char &jobv, octave_f77_int_type m, octave_f77_int_type n, P *tmp_data, octave_f77_int_type m1, DM_P *s_vec, P *u, P *vt, octave_f77_int_type nrow_vt1, std::vector< P > &work, octave_f77_int_type &lwork, octave_f77_int_type &info)
DM_T::element_type DM_P
Definition: svd.h:97
Type
Definition: svd.h:47
T right_singular_matrix(void) const
Definition: svd.cc:321
DM_T m_sigma
Definition: svd.h:103
T left_singular_matrix(void) const
Definition: svd.cc:310
Driver
Definition: svd.h:54
svd::Type m_type
Definition: svd.h:99
svd::Driver m_driver
Definition: svd.h:100
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
#define GESVD_REAL_STEP(f, F)
Definition: svd.cc:332
#define GESDD_COMPLEX_STEP(f, F, CMPLX_ARG)
Definition: svd.cc:435
#define GEJSV_REAL_STEP(f, F)
Definition: svd.cc:539
#define GEJSV_REAL_ORM_LWORK(f, F)
Definition: svd.cc:102
#define GEJSV_REAL_QR_LWORK(f, F)
Definition: svd.cc:99
#define GEJSV_COMPLEX_STEP(f, F, CMPLX_ARG)
Definition: svd.cc:555
#define GEJSV_REAL_QP3_LWORK(f, F)
Definition: svd.cc:96
#define GESDD_REAL_STEP(f, F)
Definition: svd.cc:429
#define GESVD_COMPLEX_STEP(f, F, CMPLX_ARG)
Definition: svd.cc:340
OCTAVE_NORETURN liboctave_error_handler current_liboctave_error_handler
Definition: lo-error.c:41
F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE const F77_INT F77_INT & ierr
#define OCTAVE_API
Definition: main.in.cc:55
T octave_idx_type m
Definition: mx-inlines.cc:773
octave_idx_type n
Definition: mx-inlines.cc:753
std::complex< double > Complex
Definition: oct-cmplx.h:33
std::complex< float > FloatComplex
Definition: oct-cmplx.h:34