00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifdef HAVE_CONFIG_H
00024 #include <config.h>
00025 #endif
00026
00027 #include "CmplxSVD.h"
00028 #include "f77-fcn.h"
00029 #include "lo-error.h"
00030 #include "oct-locbuf.h"
00031
00032 extern "C"
00033 {
00034 F77_RET_T
00035 F77_FUNC (zgesvd, ZGESVD) (F77_CONST_CHAR_ARG_DECL,
00036 F77_CONST_CHAR_ARG_DECL,
00037 const octave_idx_type&, const octave_idx_type&,
00038 Complex*, const octave_idx_type&,
00039 double*, Complex*, const octave_idx_type&,
00040 Complex*, const octave_idx_type&, Complex*,
00041 const octave_idx_type&, double*, octave_idx_type&
00042 F77_CHAR_ARG_LEN_DECL
00043 F77_CHAR_ARG_LEN_DECL);
00044
00045 F77_RET_T
00046 F77_FUNC (zgesdd, ZGESDD) (F77_CONST_CHAR_ARG_DECL,
00047 const octave_idx_type&, const octave_idx_type&,
00048 Complex*, const octave_idx_type&,
00049 double*, Complex*, const octave_idx_type&,
00050 Complex*, const octave_idx_type&, Complex*,
00051 const octave_idx_type&, double*,
00052 octave_idx_type *, octave_idx_type&
00053 F77_CHAR_ARG_LEN_DECL);
00054 }
00055
00056 ComplexMatrix
00057 ComplexSVD::left_singular_matrix (void) const
00058 {
00059 if (type_computed == SVD::sigma_only)
00060 {
00061 (*current_liboctave_error_handler)
00062 ("ComplexSVD: U not computed because type == SVD::sigma_only");
00063 return ComplexMatrix ();
00064 }
00065 else
00066 return left_sm;
00067 }
00068
00069 ComplexMatrix
00070 ComplexSVD::right_singular_matrix (void) const
00071 {
00072 if (type_computed == SVD::sigma_only)
00073 {
00074 (*current_liboctave_error_handler)
00075 ("ComplexSVD: V not computed because type == SVD::sigma_only");
00076 return ComplexMatrix ();
00077 }
00078 else
00079 return right_sm;
00080 }
00081
00082 octave_idx_type
00083 ComplexSVD::init (const ComplexMatrix& a, SVD::type svd_type, SVD::driver svd_driver)
00084 {
00085 octave_idx_type info;
00086
00087 octave_idx_type m = a.rows ();
00088 octave_idx_type n = a.cols ();
00089
00090 ComplexMatrix atmp = a;
00091 Complex *tmp_data = atmp.fortran_vec ();
00092
00093 octave_idx_type min_mn = m < n ? m : n;
00094 octave_idx_type max_mn = m > n ? m : n;
00095
00096 char jobu = 'A';
00097 char jobv = 'A';
00098
00099 octave_idx_type ncol_u = m;
00100 octave_idx_type nrow_vt = n;
00101 octave_idx_type nrow_s = m;
00102 octave_idx_type ncol_s = n;
00103
00104 switch (svd_type)
00105 {
00106 case SVD::economy:
00107 jobu = jobv = 'S';
00108 ncol_u = nrow_vt = nrow_s = ncol_s = min_mn;
00109 break;
00110
00111 case SVD::sigma_only:
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121 jobu = jobv = 'N';
00122 ncol_u = nrow_vt = 1;
00123 break;
00124
00125 default:
00126 break;
00127 }
00128
00129 type_computed = svd_type;
00130
00131 if (! (jobu == 'N' || jobu == 'O'))
00132 left_sm.resize (m, ncol_u);
00133
00134 Complex *u = left_sm.fortran_vec ();
00135
00136 sigma.resize (nrow_s, ncol_s);
00137 double *s_vec = sigma.fortran_vec ();
00138
00139 if (! (jobv == 'N' || jobv == 'O'))
00140 right_sm.resize (nrow_vt, n);
00141
00142 Complex *vt = right_sm.fortran_vec ();
00143
00144
00145
00146 octave_idx_type lwork = -1;
00147
00148 Array<Complex> work (dim_vector (1, 1));
00149
00150 octave_idx_type one = 1;
00151 octave_idx_type m1 = std::max (m, one);
00152 octave_idx_type nrow_vt1 = std::max (nrow_vt, one);
00153
00154 if (svd_driver == SVD::GESVD)
00155 {
00156 octave_idx_type lrwork = 5*max_mn;
00157 Array<double> rwork (dim_vector (lrwork, 1));
00158
00159 F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00160 F77_CONST_CHAR_ARG2 (&jobv, 1),
00161 m, n, tmp_data, m1, s_vec, u, m1, vt,
00162 nrow_vt1, work.fortran_vec (), lwork,
00163 rwork.fortran_vec (), info
00164 F77_CHAR_ARG_LEN (1)
00165 F77_CHAR_ARG_LEN (1)));
00166
00167 lwork = static_cast<octave_idx_type> (work(0).real ());
00168 work.resize (dim_vector (lwork, 1));
00169
00170 F77_XFCN (zgesvd, ZGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00171 F77_CONST_CHAR_ARG2 (&jobv, 1),
00172 m, n, tmp_data, m1, s_vec, u, m1, vt,
00173 nrow_vt1, work.fortran_vec (), lwork,
00174 rwork.fortran_vec (), info
00175 F77_CHAR_ARG_LEN (1)
00176 F77_CHAR_ARG_LEN (1)));
00177 }
00178 else if (svd_driver == SVD::GESDD)
00179 {
00180 assert (jobu == jobv);
00181 char jobz = jobu;
00182
00183 octave_idx_type lrwork;
00184 if (jobz == 'N')
00185 lrwork = 7*min_mn;
00186 else
00187 lrwork = 5*min_mn*min_mn + 5*min_mn;
00188 Array<double> rwork (dim_vector (lrwork, 1));
00189
00190 OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn);
00191
00192 F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00193 m, n, tmp_data, m1, s_vec, u, m1, vt,
00194 nrow_vt1, work.fortran_vec (), lwork,
00195 rwork.fortran_vec (), iwork, info
00196 F77_CHAR_ARG_LEN (1)));
00197
00198 lwork = static_cast<octave_idx_type> (work(0).real ());
00199 work.resize (dim_vector (lwork, 1));
00200
00201 F77_XFCN (zgesdd, ZGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00202 m, n, tmp_data, m1, s_vec, u, m1, vt,
00203 nrow_vt1, work.fortran_vec (), lwork,
00204 rwork.fortran_vec (), iwork, info
00205 F77_CHAR_ARG_LEN (1)));
00206 }
00207 else
00208 assert (0);
00209
00210 if (! (jobv == 'N' || jobv == 'O'))
00211 right_sm = right_sm.hermitian ();
00212
00213 return info;
00214 }