00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cmath>
#include <iostream>

#include "floatSVD.h"
#include "f77-fcn.h"
#include "oct-locbuf.h"
00032
00033 extern "C"
00034 {
00035 F77_RET_T
00036 F77_FUNC (sgesvd, SGESVD) (F77_CONST_CHAR_ARG_DECL,
00037 F77_CONST_CHAR_ARG_DECL,
00038 const octave_idx_type&, const octave_idx_type&,
00039 float*, const octave_idx_type&, float*,
00040 float*, const octave_idx_type&, float*,
00041 const octave_idx_type&, float*,
00042 const octave_idx_type&, octave_idx_type&
00043 F77_CHAR_ARG_LEN_DECL
00044 F77_CHAR_ARG_LEN_DECL);
00045
00046 F77_RET_T
00047 F77_FUNC (sgesdd, SGESDD) (F77_CONST_CHAR_ARG_DECL,
00048 const octave_idx_type&, const octave_idx_type&,
00049 float*, const octave_idx_type&, float*,
00050 float*, const octave_idx_type&, float*,
00051 const octave_idx_type&, float*,
00052 const octave_idx_type&, octave_idx_type *,
00053 octave_idx_type&
00054 F77_CHAR_ARG_LEN_DECL);
00055 }
00056
00057 FloatMatrix
00058 FloatSVD::left_singular_matrix (void) const
00059 {
00060 if (type_computed == SVD::sigma_only)
00061 {
00062 (*current_liboctave_error_handler)
00063 ("FloatSVD: U not computed because type == SVD::sigma_only");
00064 return FloatMatrix ();
00065 }
00066 else
00067 return left_sm;
00068 }
00069
00070 FloatMatrix
00071 FloatSVD::right_singular_matrix (void) const
00072 {
00073 if (type_computed == SVD::sigma_only)
00074 {
00075 (*current_liboctave_error_handler)
00076 ("FloatSVD: V not computed because type == SVD::sigma_only");
00077 return FloatMatrix ();
00078 }
00079 else
00080 return right_sm;
00081 }
00082
00083 octave_idx_type
00084 FloatSVD::init (const FloatMatrix& a, SVD::type svd_type, SVD::driver svd_driver)
00085 {
00086 octave_idx_type info;
00087
00088 octave_idx_type m = a.rows ();
00089 octave_idx_type n = a.cols ();
00090
00091 FloatMatrix atmp = a;
00092 float *tmp_data = atmp.fortran_vec ();
00093
00094 octave_idx_type min_mn = m < n ? m : n;
00095
00096 char jobu = 'A';
00097 char jobv = 'A';
00098
00099 octave_idx_type ncol_u = m;
00100 octave_idx_type nrow_vt = n;
00101 octave_idx_type nrow_s = m;
00102 octave_idx_type ncol_s = n;
00103
00104 switch (svd_type)
00105 {
00106 case SVD::economy:
00107 jobu = jobv = 'S';
00108 ncol_u = nrow_vt = nrow_s = ncol_s = min_mn;
00109 break;
00110
00111 case SVD::sigma_only:
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121 jobu = jobv = 'N';
00122 ncol_u = nrow_vt = 1;
00123 break;
00124
00125 default:
00126 break;
00127 }
00128
00129 type_computed = svd_type;
00130
00131 if (! (jobu == 'N' || jobu == 'O'))
00132 left_sm.resize (m, ncol_u);
00133
00134 float *u = left_sm.fortran_vec ();
00135
00136 sigma.resize (nrow_s, ncol_s);
00137 float *s_vec = sigma.fortran_vec ();
00138
00139 if (! (jobv == 'N' || jobv == 'O'))
00140 right_sm.resize (nrow_vt, n);
00141
00142 float *vt = right_sm.fortran_vec ();
00143
00144
00145
00146 octave_idx_type lwork = -1;
00147
00148 Array<float> work (dim_vector (1, 1));
00149
00150 octave_idx_type one = 1;
00151 octave_idx_type m1 = std::max (m, one);
00152 octave_idx_type nrow_vt1 = std::max (nrow_vt, one);
00153
00154 if (svd_driver == SVD::GESVD)
00155 {
00156 F77_XFCN (sgesvd, SGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00157 F77_CONST_CHAR_ARG2 (&jobv, 1),
00158 m, n, tmp_data, m1, s_vec, u, m1, vt,
00159 nrow_vt1, work.fortran_vec (), lwork, info
00160 F77_CHAR_ARG_LEN (1)
00161 F77_CHAR_ARG_LEN (1)));
00162
00163 lwork = static_cast<octave_idx_type> (work(0));
00164 work.resize (dim_vector (lwork, 1));
00165
00166 F77_XFCN (sgesvd, SGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00167 F77_CONST_CHAR_ARG2 (&jobv, 1),
00168 m, n, tmp_data, m1, s_vec, u, m1, vt,
00169 nrow_vt1, work.fortran_vec (), lwork, info
00170 F77_CHAR_ARG_LEN (1)
00171 F77_CHAR_ARG_LEN (1)));
00172
00173 }
00174 else if (svd_driver == SVD::GESDD)
00175 {
00176 assert (jobu == jobv);
00177 char jobz = jobu;
00178 OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn);
00179
00180 F77_XFCN (sgesdd, SGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00181 m, n, tmp_data, m1, s_vec, u, m1, vt,
00182 nrow_vt1, work.fortran_vec (), lwork, iwork, info
00183 F77_CHAR_ARG_LEN (1)));
00184
00185 lwork = static_cast<octave_idx_type> (work(0));
00186 work.resize (dim_vector (lwork, 1));
00187
00188 F77_XFCN (sgesdd, SGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00189 m, n, tmp_data, m1, s_vec, u, m1, vt,
00190 nrow_vt1, work.fortran_vec (), lwork, iwork, info
00191 F77_CHAR_ARG_LEN (1)));
00192
00193 }
00194 else
00195 assert (0);
00196
00197 if (! (jobv == 'N' || jobv == 'O'))
00198 right_sm = right_sm.transpose ();
00199
00200 return info;
00201 }
00202
00203 std::ostream&
00204 operator << (std::ostream& os, const FloatSVD& a)
00205 {
00206 os << a.left_singular_matrix () << "\n";
00207 os << a.singular_values () << "\n";
00208 os << a.right_singular_matrix () << "\n";
00209
00210 return os;
00211 }