GNU Octave  8.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
sqrtm.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2001-2023 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
#if defined (HAVE_CONFIG_H)
#  include "config.h"
#endif

#include <limits>

#include "schur.h"
#include "lo-ieee.h"
#include "lo-mappers.h"
#include "oct-norm.h"

#include "defun.h"
#include "error.h"
#include "errwarn.h"
#include "utils.h"
#include "xnorm.h"

OCTAVE_BEGIN_NAMESPACE(octave)

43 template <typename Matrix>
44 static void
46 {
47  typedef typename Matrix::element_type element_type;
48 
49  const element_type zero = element_type ();
50 
51  bool singular = false;
52 
53  // The following code is equivalent to this triple loop:
54  //
55  // n = rows (T);
56  // for j = 1:n
57  // T(j,j) = sqrt (T(j,j));
58  // for i = j-1:-1:1
59  // if T(i,j) != 0
60  // T(i,j) /= (T(i,i) + T(j,j));
61  // endif
62  // k = 1:i-1;
63  // T(k,j) -= T(k,i) * T(i,j);
64  // endfor
65  // endfor
66  //
67  // this is an in-place, cache-aligned variant of the code
68  // given in Higham's paper.
69 
70  const octave_idx_type n = T.rows ();
71  element_type *Tp = T.fortran_vec ();
72  for (octave_idx_type j = 0; j < n; j++)
73  {
74  element_type *colj = Tp + n*j;
75  if (colj[j] != zero)
76  colj[j] = sqrt (colj[j]);
77  else
78  singular = true;
79 
80  for (octave_idx_type i = j-1; i >= 0; i--)
81  {
82  const element_type *coli = Tp + n*i;
83  if (colj[i] != zero)
84  colj[i] /= (coli[i] + colj[j]);
85  const element_type colji = colj[i];
86  for (octave_idx_type k = 0; k < i; k++)
87  colj[k] -= coli[k] * colji;
88  }
89  }
90 
91  if (singular)
92  warning_with_id ("Octave:sqrtm:SingularMatrix",
93  "sqrtm: matrix is singular, may not have a square root");
94 }
95 
96 template <typename Matrix, typename ComplexMatrix, typename ComplexSCHUR>
97 static octave_value
98 do_sqrtm (const octave_value& arg)
99 {
100 
101  octave_value retval;
102 
103  MatrixType mt = arg.matrix_type ();
104 
105  bool iscomplex = arg.iscomplex ();
106 
107  typedef typename Matrix::element_type real_type;
108 
109  real_type cutoff = 0;
110  real_type one = 1;
111  real_type eps = std::numeric_limits<real_type>::epsilon ();
112 
113  if (! iscomplex)
114  {
116 
117  if (mt.is_unknown ()) // if type is not known, compute it now.
118  arg.matrix_type (mt = MatrixType (x));
119 
120  switch (mt.type ())
121  {
122  case MatrixType::Upper:
124  if (! x.diag ().any_element_is_negative ())
125  {
126  // Do it in real arithmetic.
128  retval = x;
129  retval.matrix_type (mt);
130  }
131  else
132  iscomplex = true;
133  break;
134 
135  case MatrixType::Lower:
136  if (! x.diag ().any_element_is_negative ())
137  {
138  x = x.transpose ();
140  retval = x.transpose ();
141  retval.matrix_type (mt);
142  }
143  else
144  iscomplex = true;
145  break;
146 
147  default:
148  iscomplex = true;
149  break;
150  }
151 
152  if (iscomplex)
153  cutoff = 10 * x.rows () * eps * xnorm (x, one);
154  }
155 
156  if (iscomplex)
157  {
159 
160  if (mt.is_unknown ()) // if type is not known, compute it now.
161  arg.matrix_type (mt = MatrixType (x));
162 
163  switch (mt.type ())
164  {
165  case MatrixType::Upper:
168  retval = x;
169  retval.matrix_type (mt);
170  break;
171 
172  case MatrixType::Lower:
173  x = x.transpose ();
175  retval = x.transpose ();
176  retval.matrix_type (mt);
177  break;
178 
179  default:
180  {
181  ComplexMatrix u;
182 
183  do
184  {
185  ComplexSCHUR schur_fact (x, "", true);
186  x = schur_fact.schur_matrix ();
187  u = schur_fact.unitary_schur_matrix ();
188  }
189  while (0); // schur no longer needed.
190 
192 
193  x = u * x; // original x no longer needed.
195 
196  if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
197  retval = real (res);
198  else
199  retval = res;
200  }
201  break;
202  }
203  }
204 
205  return retval;
206 }
207 
208 DEFUN (sqrtm, args, nargout,
209  doc: /* -*- texinfo -*-
210 @deftypefn {} {@var{s} =} sqrtm (@var{A})
211 @deftypefnx {} {[@var{s}, @var{error_estimate}] =} sqrtm (@var{A})
212 Compute the matrix square root of the square matrix @var{A}.
213 
214 Ref: @nospell{N.J. Higham}. @cite{A New sqrtm for @sc{matlab}}. Numerical
215 Analysis Report No.@: 336, Manchester @nospell{Centre} for Computational
216 Mathematics, Manchester, England, January 1999.
217 @seealso{expm, logm}
218 @end deftypefn */)
219 {
220  if (args.length () != 1)
221  print_usage ();
222 
223  octave_value arg = args(0);
224 
225  octave_idx_type n = arg.rows ();
226  octave_idx_type nc = arg.columns ();
227 
228  if (n != nc || arg.ndims () > 2)
229  err_square_matrix_required ("sqrtm", "A");
230 
231  octave_value_list retval (nargout > 1 ? 3 : 1);
232 
233  if (nargout > 1)
234  {
235  // FIXME: Octave does not calculate a condition number with respect to
236  // sqrtm. Should this return NaN instead of -1?
237  retval(2) = -1.0;
238  }
239 
240  if (arg.is_diag_matrix ())
241  // sqrtm of a diagonal matrix is just sqrt.
242  retval(0) = arg.sqrt ();
243  else if (arg.is_single_type ())
245  math::schur<FloatComplexMatrix>> (arg);
246  else if (arg.isnumeric ())
247  retval(0) = do_sqrtm<Matrix, ComplexMatrix,
248  math::schur<ComplexMatrix>> (arg);
249 
250  if (nargout > 1)
251  {
252  // This corresponds to generic code
253  //
254  // norm (s*s - x, "fro") / norm (x, "fro");
255 
256  octave_value s = retval(0);
257  retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
258  }
259 
260  return retval;
261 }
262 
263 /*
264 %!assert (sqrtm (2* ones (2)), ones (2), 3*eps)
265 %!assert <*60797> (sqrtm (ones (4))^2, ones (4), 5*eps)
266 
267 ## The following two tests are from the reference in the docstring above.
268 %!test
269 %! warning ("off", "Octave:sqrtm:SingularMatrix", "local");
270 %! x = [0 1; 0 0];
271 %! assert (any (isnan (sqrtm (x))(:)));
272 
273 %!test
274 %! x = eye (4); x(2,2) = x(3,3) = 2^-26; x(1,4) = 1;
275 %! z = eye (4); z(2,2) = z(3,3) = 2^-13; z(1,4) = 0.5;
276 %! [y, err] = sqrtm (x);
277 %! assert (y, z);
278 %! assert (err, 0); # Yes, this one has to hold exactly
279 */
280 
OCTAVE_END_NAMESPACE(octave)
ComplexMatrix xgemm(const ComplexMatrix &a, const ComplexMatrix &b, blas_trans_type transa, blas_trans_type transb)
Definition: CMatrix.cc:3346
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type rows(void) const
Definition: Array.h:459
OCTARRAY_API T * fortran_vec(void)
Return a pointer to the array's element data (stored in column-major order).
Definition: Array-base.cc:1766
double element_type
Definition: Array.h:230
bool is_unknown(void) const
Definition: MatrixType.h:132
OCTAVE_API int type(bool quiet=true)
Definition: MatrixType.cc:656
Definition: dMatrix.h:42
MatrixType matrix_type(void) const
Definition: ov.h:628
octave_idx_type rows(void) const
Definition: ov.h:590
bool isnumeric(void) const
Definition: ov.h:795
bool is_diag_matrix(void) const
Definition: ov.h:676
octave_idx_type columns(void) const
Definition: ov.h:592
int ndims(void) const
Definition: ov.h:596
octave_value sqrt(void) const
Definition: ov.h:1613
bool is_single_type(void) const
Definition: ov.h:743
bool iscomplex(void) const
Definition: ov.h:786
ColumnVector real(const ComplexColumnVector &a)
Definition: dColVector.cc:137
ColumnVector imag(const ComplexColumnVector &a)
Definition: dColVector.cc:143
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
T eps(const T &x)
Definition: data.cc:4942
OCTINTERP_API void print_usage(void)
Definition: defun-int.h:72
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void warning_with_id(const char *id, const char *fmt,...)
Definition: error.cc:1069
void err_square_matrix_required(const char *fcn, const char *name)
Definition: errwarn.cc:122
F77_RET_T const F77_DBLE * x
@ blas_no_trans
Definition: mx-defs.h:81
@ blas_conj_trans
Definition: mx-defs.h:83
class OCTAVE_API Matrix
Definition: mx-fwd.h:31
class OCTAVE_API ComplexMatrix
Definition: mx-fwd.h:32
class OCTAVE_API FloatComplexMatrix
Definition: mx-fwd.h:34
class OCTAVE_API FloatMatrix
Definition: mx-fwd.h:33
octave_idx_type n
Definition: mx-inlines.cc:753
double xfrobnorm(const Matrix &x)
Definition: oct-norm.cc:585
double xnorm(const ColumnVector &x, double p)
Definition: oct-norm.cc:585
ComplexMatrix octave_value_extract< ComplexMatrix >(const octave_value &v)
Definition: ov.h:1953
Matrix octave_value_extract< Matrix >(const octave_value &v)
Definition: ov.h:1951
static octave_value do_sqrtm(const octave_value &arg)
Definition: sqrtm.cc:98
static void sqrtm_utri_inplace(Matrix &T)
Definition: sqrtm.cc:45