floatCHOL.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 1994-2012 John W. Eaton
00004 Copyright (C) 2008-2009 Jaroslav Hajek
00005 
00006 This file is part of Octave.
00007 
00008 Octave is free software; you can redistribute it and/or modify it
00009 under the terms of the GNU General Public License as published by the
00010 Free Software Foundation; either version 3 of the License, or (at your
00011 option) any later version.
00012 
00013 Octave is distributed in the hope that it will be useful, but WITHOUT
00014 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00015 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00016 for more details.
00017 
00018 You should have received a copy of the GNU General Public License
00019 along with Octave; see the file COPYING.  If not, see
00020 <http://www.gnu.org/licenses/>.
00021 
00022 */
00023 
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027 
00028 #include <vector>
00029 
00030 #include "fRowVector.h"
00031 #include "floatCHOL.h"
00032 #include "f77-fcn.h"
00033 #include "lo-error.h"
00034 #include "oct-locbuf.h"
00035 #include "oct-norm.h"
00036 #ifndef HAVE_QRUPDATE
00037 #include "dbleQR.h"
00038 #endif
00039 
00040 extern "C"
00041 {
00042   F77_RET_T
00043   F77_FUNC (spotrf, SPOTRF) (F77_CONST_CHAR_ARG_DECL,
00044                              const octave_idx_type&, float*,
00045                              const octave_idx_type&, octave_idx_type&
00046                              F77_CHAR_ARG_LEN_DECL);
00047 
00048   F77_RET_T
00049   F77_FUNC (spotri, SPOTRI) (F77_CONST_CHAR_ARG_DECL,
00050                              const octave_idx_type&, float*,
00051                              const octave_idx_type&, octave_idx_type&
00052                              F77_CHAR_ARG_LEN_DECL);
00053 
00054   F77_RET_T
00055   F77_FUNC (spocon, SPOCON) (F77_CONST_CHAR_ARG_DECL,
00056                              const octave_idx_type&, float*,
00057                              const octave_idx_type&, const float&,
00058                              float&, float*, octave_idx_type*,
00059                              octave_idx_type&
00060                              F77_CHAR_ARG_LEN_DECL);
00061 #ifdef HAVE_QRUPDATE
00062 
00063   F77_RET_T
00064   F77_FUNC (sch1up, SCH1UP) (const octave_idx_type&, float*,
00065                              const octave_idx_type&, float*, float*);
00066 
00067   F77_RET_T
00068   F77_FUNC (sch1dn, SCH1DN) (const octave_idx_type&, float*,
00069                              const octave_idx_type&, float*, float*,
00070                              octave_idx_type&);
00071 
00072   F77_RET_T
00073   F77_FUNC (schinx, SCHINX) (const octave_idx_type&, float*,
00074                              const octave_idx_type&, const octave_idx_type&,
00075                              float*, float*, octave_idx_type&);
00076 
00077   F77_RET_T
00078   F77_FUNC (schdex, SCHDEX) (const octave_idx_type&, float*,
00079                              const octave_idx_type&, const octave_idx_type&,
00080                              float*);
00081 
00082   F77_RET_T
00083   F77_FUNC (schshx, SCHSHX) (const octave_idx_type&, float*,
00084                              const octave_idx_type&, const octave_idx_type&,
00085                              const octave_idx_type&, float*);
00086 #endif
00087 }
00088 
00089 octave_idx_type
00090 FloatCHOL::init (const FloatMatrix& a, bool calc_cond)
00091 {
00092   octave_idx_type a_nr = a.rows ();
00093   octave_idx_type a_nc = a.cols ();
00094 
00095   if (a_nr != a_nc)
00096     {
00097       (*current_liboctave_error_handler) ("FloatCHOL requires square matrix");
00098       return -1;
00099     }
00100 
00101   octave_idx_type n = a_nc;
00102   octave_idx_type info;
00103 
00104   chol_mat.clear (n, n);
00105   for (octave_idx_type j = 0; j < n; j++)
00106     {
00107       for (octave_idx_type i = 0; i <= j; i++)
00108         chol_mat.xelem (i, j) = a(i, j);
00109       for (octave_idx_type i = j+1; i < n; i++)
00110         chol_mat.xelem (i, j) = 0.0f;
00111     }
00112   float *h = chol_mat.fortran_vec ();
00113 
00114   // Calculate the norm of the matrix, for later use.
00115   float anorm = 0;
00116   if (calc_cond)
00117     anorm = xnorm (a, 1);
00118 
00119   F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
00120                              n, h, n, info
00121                              F77_CHAR_ARG_LEN (1)));
00122 
00123   xrcond = 0.0;
00124   if (info > 0)
00125     chol_mat.resize (info - 1, info - 1);
00126   else if (calc_cond)
00127     {
00128       octave_idx_type spocon_info = 0;
00129 
00130       // Now calculate the condition number for non-singular matrix.
00131       Array<float> z (dim_vector (3*n, 1));
00132       float *pz = z.fortran_vec ();
00133       Array<octave_idx_type> iz (dim_vector (n, 1));
00134       octave_idx_type *piz = iz.fortran_vec ();
00135       F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
00136                                  n, anorm, xrcond, pz, piz, spocon_info
00137                                  F77_CHAR_ARG_LEN (1)));
00138 
00139       if (spocon_info != 0)
00140         info = -1;
00141     }
00142 
00143   return info;
00144 }
00145 
00146 static FloatMatrix
00147 chol2inv_internal (const FloatMatrix& r)
00148 {
00149   FloatMatrix retval;
00150 
00151   octave_idx_type r_nr = r.rows ();
00152   octave_idx_type r_nc = r.cols ();
00153 
00154   if (r_nr == r_nc)
00155     {
00156       octave_idx_type n = r_nc;
00157       octave_idx_type info = 0;
00158 
00159       FloatMatrix tmp = r;
00160       float *v = tmp.fortran_vec();
00161 
00162       if (info == 0)
00163         {
00164           F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
00165                                      v, n, info
00166                                      F77_CHAR_ARG_LEN (1)));
00167 
00168           // If someone thinks of a more graceful way of doing this (or
00169           // faster for that matter :-)), please let me know!
00170 
00171           if (n > 1)
00172             for (octave_idx_type j = 0; j < r_nc; j++)
00173               for (octave_idx_type i = j+1; i < r_nr; i++)
00174                 tmp.xelem (i, j) = tmp.xelem (j, i);
00175 
00176           retval = tmp;
00177         }
00178     }
00179   else
00180     (*current_liboctave_error_handler) ("chol2inv requires square matrix");
00181 
00182   return retval;
00183 }
00184 
00185 // Compute the inverse of a matrix using the Cholesky factorization.
00186 FloatMatrix
00187 FloatCHOL::inverse (void) const
00188 {
00189   return chol2inv_internal (chol_mat);
00190 }
00191 
00192 void
00193 FloatCHOL::set (const FloatMatrix& R)
00194 {
00195   if (R.is_square ())
00196     chol_mat = R;
00197   else
00198     (*current_liboctave_error_handler) ("FloatCHOL requires square matrix");
00199 }
00200 
00201 #ifdef HAVE_QRUPDATE
00202 
00203 void
00204 FloatCHOL::update (const FloatColumnVector& u)
00205 {
00206   octave_idx_type n = chol_mat.rows ();
00207 
00208   if (u.length () == n)
00209     {
00210       FloatColumnVector utmp = u;
00211 
00212       OCTAVE_LOCAL_BUFFER (float, w, n);
00213 
00214       F77_XFCN (sch1up, SCH1UP, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00215                                  utmp.fortran_vec (), w));
00216     }
00217   else
00218     (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00219 }
00220 
00221 octave_idx_type
00222 FloatCHOL::downdate (const FloatColumnVector& u)
00223 {
00224   octave_idx_type info = -1;
00225 
00226   octave_idx_type n = chol_mat.rows ();
00227 
00228   if (u.length () == n)
00229     {
00230       FloatColumnVector utmp = u;
00231 
00232       OCTAVE_LOCAL_BUFFER (float, w, n);
00233 
00234       F77_XFCN (sch1dn, SCH1DN, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00235                                  utmp.fortran_vec (), w, info));
00236     }
00237   else
00238     (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00239 
00240   return info;
00241 }
00242 
00243 octave_idx_type
00244 FloatCHOL::insert_sym (const FloatColumnVector& u, octave_idx_type j)
00245 {
00246   octave_idx_type info = -1;
00247 
00248   octave_idx_type n = chol_mat.rows ();
00249 
00250   if (u.length () != n + 1)
00251     (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
00252   else if (j < 0 || j > n)
00253     (*current_liboctave_error_handler) ("cholinsert: index out of range");
00254   else
00255     {
00256       FloatColumnVector utmp = u;
00257 
00258       OCTAVE_LOCAL_BUFFER (float, w, n);
00259 
00260       chol_mat.resize (n+1, n+1);
00261 
00262       F77_XFCN (schinx, SCHINX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00263                                  j + 1, utmp.fortran_vec (), w, info));
00264     }
00265 
00266   return info;
00267 }
00268 
00269 void
00270 FloatCHOL::delete_sym (octave_idx_type j)
00271 {
00272   octave_idx_type n = chol_mat.rows ();
00273 
00274   if (j < 0 || j > n-1)
00275     (*current_liboctave_error_handler) ("choldelete: index out of range");
00276   else
00277     {
00278       OCTAVE_LOCAL_BUFFER (float, w, n);
00279 
00280       F77_XFCN (schdex, SCHDEX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00281                                  j + 1, w));
00282 
00283       chol_mat.resize (n-1, n-1);
00284     }
00285 }
00286 
00287 void
00288 FloatCHOL::shift_sym (octave_idx_type i, octave_idx_type j)
00289 {
00290   octave_idx_type n = chol_mat.rows ();
00291 
00292   if (i < 0 || i > n-1 || j < 0 || j > n-1)
00293     (*current_liboctave_error_handler) ("cholshift: index out of range");
00294   else
00295     {
00296       OCTAVE_LOCAL_BUFFER (float, w, 2*n);
00297 
00298       F77_XFCN (schshx, SCHSHX, (n, chol_mat.fortran_vec (), chol_mat.rows (),
00299                                  i + 1, j + 1, w));
00300     }
00301 }
00302 
00303 #else
00304 
00305 void
00306 FloatCHOL::update (const FloatColumnVector& u)
00307 {
00308   warn_qrupdate_once ();
00309 
00310   octave_idx_type n = chol_mat.rows ();
00311 
00312   if (u.length () == n)
00313     {
00314       init (chol_mat.transpose () * chol_mat
00315             + FloatMatrix (u) * FloatMatrix (u).transpose (), false);
00316     }
00317   else
00318     (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00319 }
00320 
00321 static bool
00322 singular (const FloatMatrix& a)
00323 {
00324   for (octave_idx_type i = 0; i < a.rows (); i++)
00325     if (a(i,i) == 0.0f) return true;
00326   return false;
00327 }
00328 
00329 octave_idx_type
00330 FloatCHOL::downdate (const FloatColumnVector& u)
00331 {
00332   warn_qrupdate_once ();
00333 
00334   octave_idx_type info = -1;
00335 
00336   octave_idx_type n = chol_mat.rows ();
00337 
00338   if (u.length () == n)
00339     {
00340       if (singular (chol_mat))
00341         info = 2;
00342       else
00343         {
00344           info = init (chol_mat.transpose () * chol_mat
00345                 - FloatMatrix (u) * FloatMatrix (u).transpose (), false);
00346           if (info) info = 1;
00347         }
00348     }
00349   else
00350     (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
00351 
00352   return info;
00353 }
00354 
00355 octave_idx_type
00356 FloatCHOL::insert_sym (const FloatColumnVector& u, octave_idx_type j)
00357 {
00358   warn_qrupdate_once ();
00359 
00360   octave_idx_type info = -1;
00361 
00362   octave_idx_type n = chol_mat.rows ();
00363 
00364   if (u.length () != n + 1)
00365     (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
00366   else if (j < 0 || j > n)
00367     (*current_liboctave_error_handler) ("cholinsert: index out of range");
00368   else
00369     {
00370       if (singular (chol_mat))
00371         info = 2;
00372       else
00373         {
00374           FloatMatrix a = chol_mat.transpose () * chol_mat;
00375           FloatMatrix a1 (n+1, n+1);
00376           for (octave_idx_type k = 0; k < n+1; k++)
00377             for (octave_idx_type l = 0; l < n+1; l++)
00378               {
00379                 if (l == j)
00380                   a1(k, l) = u(k);
00381                 else if (k == j)
00382                   a1(k, l) = u(l);
00383                 else
00384                   a1(k, l) = a(k < j ? k : k-1, l < j ? l : l-1);
00385               }
00386           info = init (a1, false);
00387           if (info) info = 1;
00388         }
00389     }
00390 
00391   return info;
00392 }
00393 
00394 void
00395 FloatCHOL::delete_sym (octave_idx_type j)
00396 {
00397   warn_qrupdate_once ();
00398 
00399   octave_idx_type n = chol_mat.rows ();
00400 
00401   if (j < 0 || j > n-1)
00402     (*current_liboctave_error_handler) ("choldelete: index out of range");
00403   else
00404     {
00405       FloatMatrix a = chol_mat.transpose () * chol_mat;
00406       a.delete_elements (1, idx_vector (j));
00407       a.delete_elements (0, idx_vector (j));
00408       init (a, false);
00409     }
00410 }
00411 
00412 void
00413 FloatCHOL::shift_sym (octave_idx_type i, octave_idx_type j)
00414 {
00415   warn_qrupdate_once ();
00416 
00417   octave_idx_type n = chol_mat.rows ();
00418 
00419   if (i < 0 || i > n-1 || j < 0 || j > n-1)
00420     (*current_liboctave_error_handler) ("cholshift: index out of range");
00421   else
00422     {
00423       FloatMatrix a = chol_mat.transpose () * chol_mat;
00424       Array<octave_idx_type> p (dim_vector (n, 1));
00425       for (octave_idx_type k = 0; k < n; k++) p(k) = k;
00426       if (i < j)
00427         {
00428           for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
00429           p(j) = i;
00430         }
00431       else if (j < i)
00432         {
00433           p(j) = i;
00434           for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
00435         }
00436 
00437       init (a.index (idx_vector (p), idx_vector (p)), false);
00438     }
00439 }
00440 
00441 #endif
00442 
00443 FloatMatrix
00444 chol2inv (const FloatMatrix& r)
00445 {
00446   return chol2inv_internal (r);
00447 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines