sqrtm.cc

Go to the documentation of this file.
/*

Copyright (C) 2001-2012 Ross Lippert and Paul Kienzle
Copyright (C) 2010 VZLU Prague

This file is part of Octave.

Octave is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 3 of the License, or (at your
option) any later version.

Octave is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License
along with Octave; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>.

*/
00023 
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <float.h>

#include <limits>
#include <string>

#include "CmplxSCHUR.h"
#include "fCmplxSCHUR.h"
#include "lo-ieee.h"
#include "lo-mappers.h"
#include "oct-norm.h"

#include "defun-dld.h"
#include "error.h"
#include "gripes.h"
#include "utils.h"
#include "xnorm.h"
00041 
00042 template <class Matrix>
00043 static void
00044 sqrtm_utri_inplace (Matrix& T)
00045 {
00046   typedef typename Matrix::element_type element_type;
00047 
00048   const element_type zero = element_type ();
00049 
00050   bool singular = false;
00051 
00052   // The following code is equivalent to this triple loop:
00053   //
00054   //   n = rows (T);
00055   //   for j = 1:n
00056   //     T(j,j) = sqrt (T(j,j));
00057   //     for i = j-1:-1:1
00058   //       T(i,j) /= (T(i,i) + T(j,j));
00059   //       k = 1:i-1;
00060   //       T(k,j) -= T(k,i) * T(i,j);
00061   //     endfor
00062   //   endfor
00063   //
00064   // this is an in-place, cache-aligned variant of the code
00065   // given in Higham's paper.
00066 
00067   const octave_idx_type n = T.rows ();
00068   element_type *Tp = T.fortran_vec ();
00069   for (octave_idx_type j = 0; j < n; j++)
00070     {
00071       element_type *colj = Tp + n*j;
00072       if (colj[j] != zero)
00073         colj[j] = sqrt (colj[j]);
00074       else
00075         singular = true;
00076 
00077       for (octave_idx_type i = j-1; i >= 0; i--)
00078         {
00079           const element_type *coli = Tp + n*i;
00080           const element_type colji = colj[i] /= (coli[i] + colj[j]);
00081           for (octave_idx_type k = 0; k < i; k++)
00082             colj[k] -= coli[k] * colji;
00083         }
00084     }
00085 
00086   if (singular)
00087     warning_with_id ("Octave:sqrtm:SingularMatrix",
00088                      "sqrtm: matrix is singular, may not have a square root");
00089 }
00090 
00091 template <class Matrix, class ComplexMatrix, class ComplexSCHUR>
00092 static octave_value
00093 do_sqrtm (const octave_value& arg)
00094 {
00095 
00096   octave_value retval;
00097 
00098   MatrixType mt = arg.matrix_type ();
00099 
00100   bool iscomplex = arg.is_complex_type ();
00101 
00102   typedef typename Matrix::element_type real_type;
00103 
00104   real_type cutoff = 0, one = 1;
00105   real_type eps = std::numeric_limits<real_type>::epsilon ();
00106 
00107   if (! iscomplex)
00108     {
00109       Matrix x = octave_value_extract<Matrix> (arg);
00110 
00111       if (mt.is_unknown ()) // if type is not known, compute it now.
00112         arg.matrix_type (mt = MatrixType (x));
00113 
00114       switch (mt.type ())
00115         {
00116         case MatrixType::Upper:
00117         case MatrixType::Diagonal:
00118           if (! x.diag ().any_element_is_negative ())
00119             {
00120               // Do it in real arithmetic.
00121               sqrtm_utri_inplace (x);
00122               retval = x;
00123               retval.matrix_type (mt);
00124             }
00125           else
00126             iscomplex = true;
00127           break;
00128 
00129         case MatrixType::Lower:
00130           if (! x.diag ().any_element_is_negative ())
00131             {
00132               x = x.transpose ();
00133               sqrtm_utri_inplace (x);
00134               retval = x.transpose ();
00135               retval.matrix_type (mt);
00136             }
00137           else
00138             iscomplex = true;
00139           break;
00140 
00141         default:
00142           iscomplex = true;
00143           break;
00144         }
00145 
00146       if (iscomplex)
00147         cutoff = 10 * x.rows () * eps * xnorm (x, one);
00148     }
00149 
00150   if (iscomplex)
00151     {
00152       ComplexMatrix x = octave_value_extract<ComplexMatrix> (arg);
00153 
00154       if (mt.is_unknown ()) // if type is not known, compute it now.
00155         arg.matrix_type (mt = MatrixType (x));
00156 
00157       switch (mt.type ())
00158         {
00159         case MatrixType::Upper:
00160         case MatrixType::Diagonal:
00161           sqrtm_utri_inplace (x);
00162           retval = x;
00163           retval.matrix_type (mt);
00164           break;
00165 
00166         case MatrixType::Lower:
00167           x = x.transpose ();
00168           sqrtm_utri_inplace (x);
00169           retval = x.transpose ();
00170           retval.matrix_type (mt);
00171           break;
00172 
00173         default:
00174           {
00175             ComplexMatrix u;
00176 
00177             do
00178               {
00179                 ComplexSCHUR schur (x, std::string (), true);
00180                 x = schur.schur_matrix ();
00181                 u = schur.unitary_matrix ();
00182               }
00183             while (0); // schur no longer needed.
00184 
00185             sqrtm_utri_inplace (x);
00186 
00187             x = u * x; // original x no longer needed.
00188             ComplexMatrix res = xgemm (x, u, blas_no_trans, blas_conj_trans);
00189 
00190             if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
00191               retval = real (res);
00192             else
00193               retval = res;
00194           }
00195           break;
00196         }
00197     }
00198 
00199   return retval;
00200 }
00201 
00202 DEFUN_DLD (sqrtm, args, nargout,
00203  "-*- texinfo -*-\n\
00204 @deftypefn  {Loadable Function} {@var{s} =} sqrtm (@var{A})\n\
00205 @deftypefnx {Loadable Function} {[@var{s}, @var{error_estimate}] =} sqrtm (@var{A})\n\
00206 Compute the matrix square root of the square matrix @var{A}.\n\
00207 \n\
00208 Ref: N.J. Higham.  @cite{A New sqrtm for @sc{matlab}}.  Numerical\n\
00209 Analysis Report No. 336, Manchester @nospell{Centre} for Computational\n\
00210 Mathematics, Manchester, England, January 1999.\n\
00211 @seealso{expm, logm}\n\
00212 @end deftypefn")
00213 {
00214   octave_value_list retval;
00215 
00216   int nargin = args.length ();
00217 
00218   if (nargin != 1)
00219     {
00220       print_usage ();
00221       return retval;
00222     }
00223 
00224   octave_value arg = args(0);
00225 
00226   octave_idx_type n = arg.rows ();
00227   octave_idx_type nc = arg.columns ();
00228 
00229   if (n != nc || arg.ndims () > 2)
00230     {
00231       gripe_square_matrix_required ("sqrtm");
00232       return retval;
00233     }
00234 
00235   if (nargout > 1)
00236     {
00237       retval.resize (1, 2);
00238       retval(2) = -1.0;
00239     }
00240 
00241   if (arg.is_diag_matrix ())
00242     // sqrtm of a diagonal matrix is just sqrt.
00243     retval(0) = arg.sqrt ();
00244   else if (arg.is_single_type ())
00245     retval(0) = do_sqrtm<FloatMatrix, FloatComplexMatrix, FloatComplexSCHUR> (arg);
00246   else if (arg.is_numeric_type ())
00247     retval(0) = do_sqrtm<Matrix, ComplexMatrix, ComplexSCHUR> (arg);
00248 
00249   if (nargout > 1 && ! error_state)
00250     {
00251       // This corresponds to generic code
00252       //
00253       //   norm (s*s - x, "fro") / norm (x, "fro");
00254 
00255       octave_value s = retval(0);
00256       retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
00257     }
00258 
00259   return retval;
00260 }
00261 
/*

%!assert (sqrtm (2*ones (2)), ones (2), 3*eps)

## The following two tests are from the reference in the docstring above.

%!test
%! x = [0 1; 0 0];
%! assert (any (isnan (sqrtm (x))(:) ))

%!test
%! x = eye (4); x(2,2) = x(3,3) = 2^-26; x(1,4) = 1;
%! z = eye (4); z(2,2) = z(3,3) = 2^-13; z(1,4) = 0.5;
%! [y, err] = sqrtm (x);
%! assert (y, z)
%! assert (err, 0)   ## Yes, this one has to hold exactly

*/
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines