GNU Octave  6.2.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
sqrtm.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2001-2021 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "schur.h"
31 #include "lo-ieee.h"
32 #include "lo-mappers.h"
33 #include "oct-norm.h"
34 
35 #include "defun.h"
36 #include "error.h"
37 #include "errwarn.h"
38 #include "utils.h"
39 #include "xnorm.h"
40 
41 template <typename Matrix>
42 static void
44 {
45  typedef typename Matrix::element_type element_type;
46 
47  const element_type zero = element_type ();
48 
49  bool singular = false;
50 
51  // The following code is equivalent to this triple loop:
52  //
53  // n = rows (T);
54  // for j = 1:n
55  // T(j,j) = sqrt (T(j,j));
56  // for i = j-1:-1:1
57  // T(i,j) /= (T(i,i) + T(j,j));
58  // k = 1:i-1;
59  // T(k,j) -= T(k,i) * T(i,j);
60  // endfor
61  // endfor
62  //
63  // this is an in-place, cache-aligned variant of the code
64  // given in Higham's paper.
65 
66  const octave_idx_type n = T.rows ();
67  element_type *Tp = T.fortran_vec ();
68  for (octave_idx_type j = 0; j < n; j++)
69  {
70  element_type *colj = Tp + n*j;
71  if (colj[j] != zero)
72  colj[j] = sqrt (colj[j]);
73  else
74  singular = true;
75 
76  for (octave_idx_type i = j-1; i >= 0; i--)
77  {
78  const element_type *coli = Tp + n*i;
79  const element_type colji = colj[i] /= (coli[i] + colj[j]);
80  for (octave_idx_type k = 0; k < i; k++)
81  colj[k] -= coli[k] * colji;
82  }
83  }
84 
85  if (singular)
86  warning_with_id ("Octave:sqrtm:SingularMatrix",
87  "sqrtm: matrix is singular, may not have a square root");
88 }
89 
90 template <typename Matrix, typename ComplexMatrix, typename ComplexSCHUR>
91 static octave_value
92 do_sqrtm (const octave_value& arg)
93 {
94 
96 
97  MatrixType mt = arg.matrix_type ();
98 
99  bool iscomplex = arg.iscomplex ();
100 
101  typedef typename Matrix::element_type real_type;
102 
103  real_type cutoff = 0;
104  real_type one = 1;
105  real_type eps = std::numeric_limits<real_type>::epsilon ();
106 
107  if (! iscomplex)
108  {
110 
111  if (mt.is_unknown ()) // if type is not known, compute it now.
112  arg.matrix_type (mt = MatrixType (x));
113 
114  switch (mt.type ())
115  {
116  case MatrixType::Upper:
118  if (! x.diag ().any_element_is_negative ())
119  {
120  // Do it in real arithmetic.
122  retval = x;
123  retval.matrix_type (mt);
124  }
125  else
126  iscomplex = true;
127  break;
128 
129  case MatrixType::Lower:
130  if (! x.diag ().any_element_is_negative ())
131  {
132  x = x.transpose ();
134  retval = x.transpose ();
135  retval.matrix_type (mt);
136  }
137  else
138  iscomplex = true;
139  break;
140 
141  default:
142  iscomplex = true;
143  break;
144  }
145 
146  if (iscomplex)
147  cutoff = 10 * x.rows () * eps * xnorm (x, one);
148  }
149 
150  if (iscomplex)
151  {
153 
154  if (mt.is_unknown ()) // if type is not known, compute it now.
155  arg.matrix_type (mt = MatrixType (x));
156 
157  switch (mt.type ())
158  {
159  case MatrixType::Upper:
162  retval = x;
163  retval.matrix_type (mt);
164  break;
165 
166  case MatrixType::Lower:
167  x = x.transpose ();
169  retval = x.transpose ();
170  retval.matrix_type (mt);
171  break;
172 
173  default:
174  {
175  ComplexMatrix u;
176 
177  do
178  {
179  ComplexSCHUR schur_fact (x, "", true);
180  x = schur_fact.schur_matrix ();
181  u = schur_fact.unitary_matrix ();
182  }
183  while (0); // schur no longer needed.
184 
186 
187  x = u * x; // original x no longer needed.
189 
190  if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
191  retval = real (res);
192  else
193  retval = res;
194  }
195  break;
196  }
197  }
198 
199  return retval;
200 }
201 
202 DEFUN (sqrtm, args, nargout,
203  doc: /* -*- texinfo -*-
204 @deftypefn {} {@var{s} =} sqrtm (@var{A})
205 @deftypefnx {} {[@var{s}, @var{error_estimate}] =} sqrtm (@var{A})
206 Compute the matrix square root of the square matrix @var{A}.
207 
208 Ref: @nospell{N.J. Higham}. @cite{A New sqrtm for @sc{matlab}}. Numerical
209 Analysis Report No.@: 336, Manchester @nospell{Centre} for Computational
210 Mathematics, Manchester, England, January 1999.
211 @seealso{expm, logm}
212 @end deftypefn */)
213 {
214  if (args.length () != 1)
215  print_usage ();
216 
217  octave_value arg = args(0);
218 
219  octave_idx_type n = arg.rows ();
220  octave_idx_type nc = arg.columns ();
221 
222  if (n != nc || arg.ndims () > 2)
223  err_square_matrix_required ("sqrtm", "A");
224 
225  octave_value_list retval (nargout > 1 ? 3 : 1);
226 
227  if (nargout > 1)
228  {
229  // FIXME: Octave does not calculate a condition number with respect to
230  // sqrtm. Should this return NaN instead of -1?
231  retval(2) = -1.0;
232  }
233 
234  if (arg.is_diag_matrix ())
235  // sqrtm of a diagonal matrix is just sqrt.
236  retval(0) = arg.sqrt ();
237  else if (arg.is_single_type ())
240  else if (arg.isnumeric ())
243 
244  if (nargout > 1)
245  {
246  // This corresponds to generic code
247  //
248  // norm (s*s - x, "fro") / norm (x, "fro");
249 
250  octave_value s = retval(0);
251  retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
252  }
253 
254  return retval;
255 }
256 
257 /*
258 %!assert (sqrtm (2*ones (2)), ones (2), 3*eps)
259 
260 ## The following two tests are from the reference in the docstring above.
261 %!test
262 %! warning ("off", "Octave:sqrtm:SingularMatrix", "local");
263 %! x = [0 1; 0 0];
264 %! assert (any (isnan (sqrtm (x))(:)));
265 
266 %!test
267 %! x = eye (4); x(2,2) = x(3,3) = 2^-26; x(1,4) = 1;
268 %! z = eye (4); z(2,2) = z(3,3) = 2^-13; z(1,4) = 0.5;
269 %! [y, err] = sqrtm (x);
270 %! assert (y, z);
271 %! assert (err, 0); # Yes, this one has to hold exactly
272 */
ComplexMatrix xgemm(const ComplexMatrix &a, const ComplexMatrix &b, blas_trans_type transa, blas_trans_type transb)
Definition: CMatrix.cc:3322
octave_idx_type rows(void) const
Definition: Array.h:415
double element_type
Definition: Array.h:202
const T * fortran_vec(void) const
Size of the specified dimension.
Definition: Array.h:583
bool is_unknown(void) const
Definition: MatrixType.h:137
int type(bool quiet=true)
Definition: MatrixType.cc:653
Definition: dMatrix.h:42
MatrixType matrix_type(void) const
Definition: ov.h:542
octave_idx_type rows(void) const
Definition: ov.h:504
bool isnumeric(void) const
Definition: ov.h:703
bool is_diag_matrix(void) const
Definition: ov.h:587
octave_idx_type columns(void) const
Definition: ov.h:506
int ndims(void) const
Definition: ov.h:510
octave_value sqrt(void) const
Definition: ov.h:1412
bool is_single_type(void) const
Definition: ov.h:651
bool iscomplex(void) const
Definition: ov.h:694
ColumnVector real(const ComplexColumnVector &a)
Definition: dColVector.cc:137
ColumnVector imag(const ComplexColumnVector &a)
Definition: dColVector.cc:143
T eps(const T &x)
Definition: data.cc:4578
OCTINTERP_API void print_usage(void)
Definition: defun.cc:53
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:56
void warning_with_id(const char *id, const char *fmt,...)
Definition: error.cc:1065
void err_square_matrix_required(const char *fcn, const char *name)
Definition: errwarn.cc:122
F77_RET_T const F77_DBLE * x
@ blas_no_trans
Definition: mx-defs.h:110
@ blas_conj_trans
Definition: mx-defs.h:112
octave_idx_type n
Definition: mx-inlines.cc:753
OCTAVE_API double xfrobnorm(const Matrix &x)
Definition: oct-norm.cc:551
OCTAVE_API double xnorm(const ColumnVector &x, double p)
Definition: oct-norm.cc:551
octave_value::octave_value(const Array< char > &chm, char type) return retval
Definition: ov.cc:811
ComplexMatrix octave_value_extract< ComplexMatrix >(const octave_value &v)
Definition: ov.h:1649
Matrix octave_value_extract< Matrix >(const octave_value &v)
Definition: ov.h:1647
static octave_value do_sqrtm(const octave_value &arg)
Definition: sqrtm.cc:92
static void sqrtm_utri_inplace(Matrix &T)
Definition: sqrtm.cc:43