00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027
00028 #include <vector>
00029
00030 #include "fRowVector.h"
00031 #include "floatCHOL.h"
00032 #include "f77-fcn.h"
00033 #include "lo-error.h"
00034 #include "oct-locbuf.h"
00035 #include "oct-norm.h"
00036 #ifndef HAVE_QRUPDATE
00037 #include "dbleQR.h"
00038 #endif
00039
00040 extern "C"
00041 {
00042 F77_RET_T
00043 F77_FUNC (spotrf, SPOTRF) (F77_CONST_CHAR_ARG_DECL,
00044 const octave_idx_type&, float*,
00045 const octave_idx_type&, octave_idx_type&
00046 F77_CHAR_ARG_LEN_DECL);
00047
00048 F77_RET_T
00049 F77_FUNC (spotri, SPOTRI) (F77_CONST_CHAR_ARG_DECL,
00050 const octave_idx_type&, float*,
00051 const octave_idx_type&, octave_idx_type&
00052 F77_CHAR_ARG_LEN_DECL);
00053
00054 F77_RET_T
00055 F77_FUNC (spocon, SPOCON) (F77_CONST_CHAR_ARG_DECL,
00056 const octave_idx_type&, float*,
00057 const octave_idx_type&, const float&,
00058 float&, float*, octave_idx_type*,
00059 octave_idx_type&
00060 F77_CHAR_ARG_LEN_DECL);
00061 #ifdef HAVE_QRUPDATE
00062
00063 F77_RET_T
00064 F77_FUNC (sch1up, SCH1UP) (const octave_idx_type&, float*,
00065 const octave_idx_type&, float*, float*);
00066
00067 F77_RET_T
00068 F77_FUNC (sch1dn, SCH1DN) (const octave_idx_type&, float*,
00069 const octave_idx_type&, float*, float*,
00070 octave_idx_type&);
00071
00072 F77_RET_T
00073 F77_FUNC (schinx, SCHINX) (const octave_idx_type&, float*,
00074 const octave_idx_type&, const octave_idx_type&,
00075 float*, float*, octave_idx_type&);
00076
00077 F77_RET_T
00078 F77_FUNC (schdex, SCHDEX) (const octave_idx_type&, float*,
00079 const octave_idx_type&, const octave_idx_type&,
00080 float*);
00081
00082 F77_RET_T
00083 F77_FUNC (schshx, SCHSHX) (const octave_idx_type&, float*,
00084 const octave_idx_type&, const octave_idx_type&,
00085 const octave_idx_type&, float*);
00086 #endif
00087 }
00088
00089 octave_idx_type
00090 FloatCHOL::init (const FloatMatrix& a, bool calc_cond)
00091 {
00092 octave_idx_type a_nr = a.rows ();
00093 octave_idx_type a_nc = a.cols ();
00094
00095 if (a_nr != a_nc)
00096 {
00097 (*current_liboctave_error_handler) ("FloatCHOL requires square matrix");
00098 return -1;
00099 }
00100
00101 octave_idx_type n = a_nc;
00102 octave_idx_type info;
00103
00104 chol_mat.clear (n, n);
00105 for (octave_idx_type j = 0; j < n; j++)
00106 {
00107 for (octave_idx_type i = 0; i <= j; i++)
00108 chol_mat.xelem (i, j) = a(i, j);
00109 for (octave_idx_type i = j+1; i < n; i++)
00110 chol_mat.xelem (i, j) = 0.0f;
00111 }
00112 float *h = chol_mat.fortran_vec ();
00113
00114
00115 float anorm = 0;
00116 if (calc_cond)
00117 anorm = xnorm (a, 1);
00118
00119 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
00120 n, h, n, info
00121 F77_CHAR_ARG_LEN (1)));
00122
00123 xrcond = 0.0;
00124 if (info > 0)
00125 chol_mat.resize (info - 1, info - 1);
00126 else if (calc_cond)
00127 {
00128 octave_idx_type spocon_info = 0;
00129
00130
00131 Array<float> z (dim_vector (3*n, 1));
00132 float *pz = z.fortran_vec ();
00133 Array<octave_idx_type> iz (dim_vector (n, 1));
00134 octave_idx_type *piz = iz.fortran_vec ();
00135 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
00136 n, anorm, xrcond, pz, piz, spocon_info
00137 F77_CHAR_ARG_LEN (1)));
00138
00139 if (spocon_info != 0)
00140 info = -1;
00141 }
00142
00143 return info;
00144 }
00145
00146 static FloatMatrix
00147 chol2inv_internal (const FloatMatrix& r)
00148 {
00149 FloatMatrix retval;
00150
00151 octave_idx_type r_nr = r.rows ();
00152 octave_idx_type r_nc = r.cols ();
00153
00154 if (r_nr == r_nc)
00155 {
00156 octave_idx_type n = r_nc;
00157 octave_idx_type info = 0;
00158
00159 FloatMatrix tmp = r;
00160 float *v = tmp.fortran_vec();
00161
00162 if (info == 0)
00163 {
00164 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
00165 v, n, info
00166 F77_CHAR_ARG_LEN (1)));
00167
00168
00169
00170
00171 if (n > 1)
00172 for (octave_idx_type j = 0; j < r_nc; j++)
00173 for (octave_idx_type i = j+1; i < r_nr; i++)
00174 tmp.xelem (i, j) = tmp.xelem (j, i);
00175
00176 retval = tmp;
00177 }
00178 }
00179 else
00180 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
00181
00182 return retval;
00183 }
00184
00185
00186 FloatMatrix
00187 FloatCHOL::inverse (void) const
00188 {
00189 return chol2inv_internal (chol_mat);
00190 }
00191
00192 void
00193 FloatCHOL::set (const FloatMatrix& R)
00194 {
00195 if (R.is_square ())
00196 chol_mat = R;
00197 else
00198 (*current_liboctave_error_handler) ("FloatCHOL requires square matrix");
00199 }
00200
00201 #ifdef HAVE_QRUPDATE
00202
00203 void
00204 FloatCHOL::update (const FloatColumnVector& u)
00205 {
00206 octave_idx_type n = chol_mat.rows ();
00207
00208 if (u.length () == n)
00209 {
00210 FloatColumnVector utmp = u;
00211
00212 OCTAVE_LOCAL_BUFFER (float, w, n);
00213
00214 F77_XFCN (sch1up, SCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00215 utmp.fortran_vec (), w));
00216 }
00217 else
00218 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00219 }
00220
00221 octave_idx_type
00222 FloatCHOL::downdate (const FloatColumnVector& u)
00223 {
00224 octave_idx_type info = -1;
00225
00226 octave_idx_type n = chol_mat.rows ();
00227
00228 if (u.length () == n)
00229 {
00230 FloatColumnVector utmp = u;
00231
00232 OCTAVE_LOCAL_BUFFER (float, w, n);
00233
00234 F77_XFCN (sch1dn, SCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00235 utmp.fortran_vec (), w, info));
00236 }
00237 else
00238 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00239
00240 return info;
00241 }
00242
00243 octave_idx_type
00244 FloatCHOL::insert_sym (const FloatColumnVector& u, octave_idx_type j)
00245 {
00246 octave_idx_type info = -1;
00247
00248 octave_idx_type n = chol_mat.rows ();
00249
00250 if (u.length () != n + 1)
00251 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
00252 else if (j < 0 || j > n)
00253 (*current_liboctave_error_handler) ("cholinsert: index out of range");
00254 else
00255 {
00256 FloatColumnVector utmp = u;
00257
00258 OCTAVE_LOCAL_BUFFER (float, w, n);
00259
00260 chol_mat.resize (n+1, n+1);
00261
00262 F77_XFCN (schinx, SCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00263 j + 1, utmp.fortran_vec (), w, info));
00264 }
00265
00266 return info;
00267 }
00268
00269 void
00270 FloatCHOL::delete_sym (octave_idx_type j)
00271 {
00272 octave_idx_type n = chol_mat.rows ();
00273
00274 if (j < 0 || j > n-1)
00275 (*current_liboctave_error_handler) ("choldelete: index out of range");
00276 else
00277 {
00278 OCTAVE_LOCAL_BUFFER (float, w, n);
00279
00280 F77_XFCN (schdex, SCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00281 j + 1, w));
00282
00283 chol_mat.resize (n-1, n-1);
00284 }
00285 }
00286
00287 void
00288 FloatCHOL::shift_sym (octave_idx_type i, octave_idx_type j)
00289 {
00290 octave_idx_type n = chol_mat.rows ();
00291
00292 if (i < 0 || i > n-1 || j < 0 || j > n-1)
00293 (*current_liboctave_error_handler) ("cholshift: index out of range");
00294 else
00295 {
00296 OCTAVE_LOCAL_BUFFER (float, w, 2*n);
00297
00298 F77_XFCN (schshx, SCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00299 i + 1, j + 1, w));
00300 }
00301 }
00302
00303 #else
00304
00305 void
00306 FloatCHOL::update (const FloatColumnVector& u)
00307 {
00308 warn_qrupdate_once ();
00309
00310 octave_idx_type n = chol_mat.rows ();
00311
00312 if (u.length () == n)
00313 {
00314 init (chol_mat.transpose () * chol_mat
00315 + FloatMatrix (u) * FloatMatrix (u).transpose (), false);
00316 }
00317 else
00318 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00319 }
00320
00321 static bool
00322 singular (const FloatMatrix& a)
00323 {
00324 for (octave_idx_type i = 0; i < a.rows (); i++)
00325 if (a(i,i) == 0.0f) return true;
00326 return false;
00327 }
00328
00329 octave_idx_type
00330 FloatCHOL::downdate (const FloatColumnVector& u)
00331 {
00332 warn_qrupdate_once ();
00333
00334 octave_idx_type info = -1;
00335
00336 octave_idx_type n = chol_mat.rows ();
00337
00338 if (u.length () == n)
00339 {
00340 if (singular (chol_mat))
00341 info = 2;
00342 else
00343 {
00344 info = init (chol_mat.transpose () * chol_mat
00345 - FloatMatrix (u) * FloatMatrix (u).transpose (), false);
00346 if (info) info = 1;
00347 }
00348 }
00349 else
00350 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00351
00352 return info;
00353 }
00354
00355 octave_idx_type
00356 FloatCHOL::insert_sym (const FloatColumnVector& u, octave_idx_type j)
00357 {
00358 warn_qrupdate_once ();
00359
00360 octave_idx_type info = -1;
00361
00362 octave_idx_type n = chol_mat.rows ();
00363
00364 if (u.length () != n + 1)
00365 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
00366 else if (j < 0 || j > n)
00367 (*current_liboctave_error_handler) ("cholinsert: index out of range");
00368 else
00369 {
00370 if (singular (chol_mat))
00371 info = 2;
00372 else
00373 {
00374 FloatMatrix a = chol_mat.transpose () * chol_mat;
00375 FloatMatrix a1 (n+1, n+1);
00376 for (octave_idx_type k = 0; k < n+1; k++)
00377 for (octave_idx_type l = 0; l < n+1; l++)
00378 {
00379 if (l == j)
00380 a1(k, l) = u(k);
00381 else if (k == j)
00382 a1(k, l) = u(l);
00383 else
00384 a1(k, l) = a(k < j ? k : k-1, l < j ? l : l-1);
00385 }
00386 info = init (a1, false);
00387 if (info) info = 1;
00388 }
00389 }
00390
00391 return info;
00392 }
00393
00394 void
00395 FloatCHOL::delete_sym (octave_idx_type j)
00396 {
00397 warn_qrupdate_once ();
00398
00399 octave_idx_type n = chol_mat.rows ();
00400
00401 if (j < 0 || j > n-1)
00402 (*current_liboctave_error_handler) ("choldelete: index out of range");
00403 else
00404 {
00405 FloatMatrix a = chol_mat.transpose () * chol_mat;
00406 a.delete_elements (1, idx_vector (j));
00407 a.delete_elements (0, idx_vector (j));
00408 init (a, false);
00409 }
00410 }
00411
00412 void
00413 FloatCHOL::shift_sym (octave_idx_type i, octave_idx_type j)
00414 {
00415 warn_qrupdate_once ();
00416
00417 octave_idx_type n = chol_mat.rows ();
00418
00419 if (i < 0 || i > n-1 || j < 0 || j > n-1)
00420 (*current_liboctave_error_handler) ("cholshift: index out of range");
00421 else
00422 {
00423 FloatMatrix a = chol_mat.transpose () * chol_mat;
00424 Array<octave_idx_type> p (dim_vector (n, 1));
00425 for (octave_idx_type k = 0; k < n; k++) p(k) = k;
00426 if (i < j)
00427 {
00428 for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
00429 p(j) = i;
00430 }
00431 else if (j < i)
00432 {
00433 p(j) = i;
00434 for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
00435 }
00436
00437 init (a.index (idx_vector (p), idx_vector (p)), false);
00438 }
00439 }
00440
00441 #endif
00442
00443 FloatMatrix
00444 chol2inv (const FloatMatrix& r)
00445 {
00446 return chol2inv_internal (r);
00447 }