GNU Octave  8.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
chol.cc
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 1994-2023 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if defined (HAVE_CONFIG_H)
27 # include "config.h"
28 #endif
29 
30 #include "Array.h"
31 #include "CColVector.h"
32 #include "CMatrix.h"
33 #include "chol.h"
34 #include "dColVector.h"
35 #include "dMatrix.h"
36 #include "fCColVector.h"
37 #include "fCMatrix.h"
38 #include "fColVector.h"
39 #include "fMatrix.h"
40 #include "lo-error.h"
41 #include "lo-lapack-proto.h"
42 #include "lo-qrupdate-proto.h"
43 #include "oct-locbuf.h"
44 #include "oct-norm.h"
45 
46 #if ! defined (HAVE_QRUPDATE)
47 # include "qr.h"
48 #endif
49 
51 
52 static Matrix
53 chol2inv_internal (const Matrix& r, bool is_upper = true)
54 {
55  Matrix retval;
56 
57  octave_idx_type r_nr = r.rows ();
58  octave_idx_type r_nc = r.cols ();
59 
60  if (r_nr != r_nc)
61  (*current_liboctave_error_handler) ("chol2inv requires square matrix");
62 
63  F77_INT n = to_f77_int (r_nc);
64  F77_INT info;
65 
66  Matrix tmp = r;
67  double *v = tmp.fortran_vec ();
68 
69  if (is_upper)
70  F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
71  v, n, info
72  F77_CHAR_ARG_LEN (1)));
73  else
74  F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
75  v, n, info
76  F77_CHAR_ARG_LEN (1)));
77 
78  // FIXME: Should we check info exit value and possibly report an error?
79 
80  // If someone thinks of a more graceful way of doing this
81  // (or faster for that matter :-)), please let me know!
82 
83  if (n > 1)
84  {
85  if (is_upper)
86  for (octave_idx_type j = 0; j < r_nc; j++)
87  for (octave_idx_type i = j+1; i < r_nr; i++)
88  tmp.xelem (i, j) = tmp.xelem (j, i);
89  else
90  for (octave_idx_type j = 0; j < r_nc; j++)
91  for (octave_idx_type i = j+1; i < r_nr; i++)
92  tmp.xelem (j, i) = tmp.xelem (i, j);
93  }
94 
95  retval = tmp;
96 
97  return retval;
98 }
99 
100 static FloatMatrix
101 chol2inv_internal (const FloatMatrix& r, bool is_upper = true)
102 {
103  FloatMatrix retval;
104 
105  octave_idx_type r_nr = r.rows ();
106  octave_idx_type r_nc = r.cols ();
107 
108  if (r_nr != r_nc)
109  (*current_liboctave_error_handler) ("chol2inv requires square matrix");
110 
111  F77_INT n = to_f77_int (r_nc);
112  F77_INT info;
113 
114  FloatMatrix tmp = r;
115  float *v = tmp.fortran_vec ();
116 
117  if (is_upper)
118  F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
119  v, n, info
120  F77_CHAR_ARG_LEN (1)));
121  else
122  F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
123  v, n, info
124  F77_CHAR_ARG_LEN (1)));
125 
126  // FIXME: Should we check info exit value and possibly report an error?
127 
128  // If someone thinks of a more graceful way of doing this (or
129  // faster for that matter :-)), please let me know!
130 
131  if (n > 1)
132  {
133  if (is_upper)
134  for (octave_idx_type j = 0; j < r_nc; j++)
135  for (octave_idx_type i = j+1; i < r_nr; i++)
136  tmp.xelem (i, j) = tmp.xelem (j, i);
137  else
138  for (octave_idx_type j = 0; j < r_nc; j++)
139  for (octave_idx_type i = j+1; i < r_nr; i++)
140  tmp.xelem (j, i) = tmp.xelem (i, j);
141  }
142 
143  retval = tmp;
144 
145  return retval;
146 }
147 
148 static ComplexMatrix
149 chol2inv_internal (const ComplexMatrix& r, bool is_upper = true)
150 {
151  ComplexMatrix retval;
152 
153  octave_idx_type r_nr = r.rows ();
154  octave_idx_type r_nc = r.cols ();
155 
156  if (r_nr != r_nc)
157  (*current_liboctave_error_handler) ("chol2inv requires square matrix");
158 
159  F77_INT n = to_f77_int (r_nc);
160  F77_INT info;
161 
162  ComplexMatrix tmp = r;
163 
164  if (is_upper)
165  F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
166  F77_DBLE_CMPLX_ARG (tmp.fortran_vec ()), n, info
167  F77_CHAR_ARG_LEN (1)));
168  else
169  F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
170  F77_DBLE_CMPLX_ARG (tmp.fortran_vec ()), n, info
171  F77_CHAR_ARG_LEN (1)));
172 
173  // If someone thinks of a more graceful way of doing this (or
174  // faster for that matter :-)), please let me know!
175 
176  if (n > 1)
177  {
178  if (is_upper)
179  for (octave_idx_type j = 0; j < r_nc; j++)
180  for (octave_idx_type i = j+1; i < r_nr; i++)
181  tmp.xelem (i, j) = std::conj (tmp.xelem (j, i));
182  else
183  for (octave_idx_type j = 0; j < r_nc; j++)
184  for (octave_idx_type i = j+1; i < r_nr; i++)
185  tmp.xelem (j, i) = std::conj (tmp.xelem (i, j));
186  }
187 
188  retval = tmp;
189 
190  return retval;
191 }
192 
193 static FloatComplexMatrix
194 chol2inv_internal (const FloatComplexMatrix& r, bool is_upper = true)
195 {
196  FloatComplexMatrix retval;
197 
198  octave_idx_type r_nr = r.rows ();
199  octave_idx_type r_nc = r.cols ();
200 
201  if (r_nr != r_nc)
202  (*current_liboctave_error_handler) ("chol2inv requires square matrix");
203 
204  F77_INT n = to_f77_int (r_nc);
205  F77_INT info;
206 
207  FloatComplexMatrix tmp = r;
208 
209  if (is_upper)
210  F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
211  F77_CMPLX_ARG (tmp.fortran_vec ()), n, info
212  F77_CHAR_ARG_LEN (1)));
213  else
214  F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
215  F77_CMPLX_ARG (tmp.fortran_vec ()), n, info
216  F77_CHAR_ARG_LEN (1)));
217 
218  // If someone thinks of a more graceful way of doing this (or
219  // faster for that matter :-)), please let me know!
220 
221  if (n > 1)
222  {
223  if (is_upper)
224  for (octave_idx_type j = 0; j < r_nc; j++)
225  for (octave_idx_type i = j+1; i < r_nr; i++)
226  tmp.xelem (i, j) = std::conj (tmp.xelem (j, i));
227  else
228  for (octave_idx_type j = 0; j < r_nc; j++)
229  for (octave_idx_type i = j+1; i < r_nr; i++)
230  tmp.xelem (j, i) = std::conj (tmp.xelem (i, j));
231  }
232 
233  retval = tmp;
234 
235  return retval;
236 }
237 
239 
// Return inv(A) given the upper-triangular Cholesky factor R of A
// (A = R'*R).  Dispatches to the chol2inv_internal overload for T.
template <typename T>
T
chol2inv (const T& r)
{
  const bool factor_is_upper = true;
  return chol2inv_internal (r, factor_is_upper);
}
246 
247 // Compute the inverse of a matrix using the Cholesky factorization.
248 template <typename T>
249 T
250 chol<T>::inverse (void) const
251 {
252  return chol2inv_internal (m_chol_mat, m_is_upper);
253 }
254 
255 template <typename T>
256 void
257 chol<T>::set (const T& R)
258 {
259  if (! R.issquare ())
260  (*current_liboctave_error_handler) ("chol: requires square matrix");
261 
262  m_chol_mat = R;
263 }
264 
265 #if ! defined (HAVE_QRUPDATE)
266 
// Fallback rank-1 update used when qrupdate is unavailable: rebuilds
// A = R'*R + u*u' from the stored factor R and refactors it from scratch
// with init (..., upper = true, calc_cond = false).  u must have as many
// elements as R has rows.  The INFO value returned by init is discarded.
// NOTE(review): original line 271 is absent here; presumably it calls
// warn_qrupdate_once () — confirm against the real file.
267 template <typename T>
268 void
269 chol<T>::update (const VT& u)
270 {
272 
273  octave_idx_type n = m_chol_mat.rows ();
274 
275  if (u.numel () != n)
276  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
277 
// T (u) treats the vector as an n-by-1 matrix, so T (u) * T (u).hermitian ()
// is the rank-1 term u*u'.
278  init (m_chol_mat.hermitian () * m_chol_mat + T (u) * T (u).hermitian (),
279  true, false);
280 }
281 
282 template <typename T>
283 bool
284 singular (const T& a)
285 {
286  static typename T::element_type zero (0);
287  for (octave_idx_type i = 0; i < a.rows (); i++)
288  if (a(i, i) == zero) return true;
289  return false;
290 }
291 
// Fallback rank-1 downdate used when qrupdate is unavailable: refactors
// A = R'*R - u*u' from scratch.
// Returns 0 on success, 2 if the current factor has a zero diagonal entry
// (see singular), or 1 if init reports that A - u*u' could not be factored.
// NOTE(review): the return-type line (original 293, presumably
// octave_idx_type) and original line 296 (presumably warn_qrupdate_once ())
// are absent here — confirm.
292 template <typename T>
294 chol<T>::downdate (const VT& u)
295 {
297 
298  octave_idx_type info = -1;
299 
300  octave_idx_type n = m_chol_mat.rows ();
301 
302  if (u.numel () != n)
303  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
304 
305  if (singular (m_chol_mat))
306  info = 2;
307  else
308  {
309  info = init (m_chol_mat.hermitian () * m_chol_mat
310  - T (u) * T (u).hermitian (), true, false);
// Collapse any nonzero init status to the single code 1.
311  if (info) info = 1;
312  }
313 
314  return info;
315 }
316 
// Fallback symmetric insertion used when qrupdate is unavailable: builds the
// (n+1)-by-(n+1) matrix A1 whose row/column j is u (row j conjugated) and
// whose remaining entries come from A = R'*R, then refactors A1.
// u must have n+1 elements and 0 <= j <= n.
// Returns 0 on success, 2 if R has a zero diagonal entry, 3 if the new
// diagonal entry u(j) has a nonzero imaginary part, or 1 if init fails.
// NOTE(review): original lines 318 (return type) and 323 (presumably
// warn_qrupdate_once ()) are absent here — confirm.
317 template <typename T>
319 chol<T>::insert_sym (const VT& u, octave_idx_type j)
320 {
321  static typename T::element_type zero (0);
322 
324 
325  octave_idx_type info = -1;
326 
327  octave_idx_type n = m_chol_mat.rows ();
328 
329  if (u.numel () != n + 1)
330  (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
331  if (j < 0 || j > n)
332  (*current_liboctave_error_handler) ("cholinsert: index out of range");
333 
334  if (singular (m_chol_mat))
335  info = 2;
// The new diagonal element must be real for A1 to be Hermitian.
336  else if (std::imag (u(j)) != zero)
337  info = 3;
338  else
339  {
340  T a = m_chol_mat.hermitian () * m_chol_mat;
341  T a1 (n+1, n+1);
342  for (octave_idx_type k = 0; k < n+1; k++)
343  for (octave_idx_type l = 0; l < n+1; l++)
344  {
// Column j takes u (this also sets the diagonal a1(j,j) = u(j)); row j
// takes conj(u); everything else is copied from A, skipping index j.
345  if (l == j)
346  a1(k, l) = u(k);
347  else if (k == j)
348  a1(k, l) = math::conj (u(l));
349  else
350  a1(k, l) = a(k < j ? k : k-1, l < j ? l : l-1);
351  }
352  info = init (a1, true, false);
353  if (info) info = 1;
354  }
355 
356  return info;
357 }
358 
// Fallback symmetric deletion used when qrupdate is unavailable: forms
// A = R'*R, removes row and column j, and refactors the result as upper.
// j must satisfy 0 <= j <= n-1.
// NOTE(review): the declarator line (original 361, presumably
// chol<T>::delete_sym (octave_idx_type j)) and original line 363
// (presumably warn_qrupdate_once ()) are absent here — confirm.
359 template <typename T>
360 void
362 {
364 
365  octave_idx_type n = m_chol_mat.rows ();
366 
367  if (j < 0 || j > n-1)
368  (*current_liboctave_error_handler) ("choldelete: index out of range");
369 
370  T a = m_chol_mat.hermitian () * m_chol_mat;
371  a.delete_elements (1, idx_vector (j));
372  a.delete_elements (0, idx_vector (j));
373  init (a, true, false);
374 }
375 
// Fallback symmetric shift used when qrupdate is unavailable: forms
// A = R'*R, moves row/column i to position j (the rows/columns in between
// shift by one toward i's old slot), and refactors A(p,p).
// i and j must both lie in [0, n-1].
// NOTE(review): the declarator line (original 378), original line 380
// (presumably warn_qrupdate_once ()) and the declaration of the n-element
// permutation array `p` (original 388) are absent here — confirm.
376 template <typename T>
377 void
379 {
381 
382  octave_idx_type n = m_chol_mat.rows ();
383 
384  if (i < 0 || i > n-1 || j < 0 || j > n-1)
385  (*current_liboctave_error_handler) ("cholshift: index out of range");
386 
387  T a = m_chol_mat.hermitian () * m_chol_mat;
// Start from the identity permutation.
389  for (octave_idx_type k = 0; k < n; k++) p(k) = k;
390  if (i < j)
391  {
// Left circular shift of positions i..j: old i lands at position j.
392  for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
393  p(j) = i;
394  }
395  else if (j < i)
396  {
// Right circular shift of positions j..i: old i lands at position j.
397  p(j) = i;
398  for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
399  }
400 
401  init (a.index (idx_vector (p), idx_vector (p)), true, false);
402 }
403 
404 #endif
405 
406 // Specializations.
407 
// Factor the square matrix A with LAPACK DPOTRF and store the factor in
// m_chol_mat: A = R'*R when UPPER is true, otherwise A = L*L'.  Only the
// triangle of A selected by UPPER is copied, and the other triangle of the
// stored factor is zeroed.
//
// If CALC_COND is true and DPOTRF succeeds, DPOCON estimates the reciprocal
// 1-norm condition number into m_rcond.  Otherwise m_rcond stays 0.
//
// Returns INFO: 0 on success; > 0 when DPOTRF stopped at that pivot (the
// stored factor is then truncated to order INFO-1); -1 if DPOCON fails.
// NOTE(review): the return-type line (original 409) and the declaration of
// the integer workspace `iz` passed to DPOCON (original 464) are absent
// here — confirm against the real file.
408 template <>
410 chol<Matrix>::init (const Matrix& a, bool upper, bool calc_cond)
411 {
412  octave_idx_type a_nr = a.rows ();
413  octave_idx_type a_nc = a.cols ();
414 
415  if (a_nr != a_nc)
416  (*current_liboctave_error_handler) ("chol: requires square matrix");
417 
418  F77_INT n = to_f77_int (a_nc);
419  F77_INT info;
420 
421  m_is_upper = upper;
422 
// Copy the requested triangle of A and zero the opposite one, because
// DPOTRF leaves the unused triangle untouched.
423  m_chol_mat.clear (n, n);
424  if (m_is_upper)
425  for (octave_idx_type j = 0; j < n; j++)
426  {
427  for (octave_idx_type i = 0; i <= j; i++)
428  m_chol_mat.xelem (i, j) = a(i, j);
429  for (octave_idx_type i = j+1; i < n; i++)
430  m_chol_mat.xelem (i, j) = 0.0;
431  }
432  else
433  for (octave_idx_type j = 0; j < n; j++)
434  {
435  for (octave_idx_type i = 0; i < j; i++)
436  m_chol_mat.xelem (i, j) = 0.0;
437  for (octave_idx_type i = j; i < n; i++)
438  m_chol_mat.xelem (i, j) = a(i, j);
439  }
440  double *h = m_chol_mat.fortran_vec ();
441 
442  // Calculate the norm of the matrix, for later use.
443  double anorm = 0;
444  if (calc_cond)
445  anorm = octave::xnorm (a, 1);
446 
447  if (m_is_upper)
448  F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
449  F77_CHAR_ARG_LEN (1)));
450  else
451  F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
452  F77_CHAR_ARG_LEN (1)));
453 
454  m_rcond = 0.0;
// Keep only the leading block preceding the failing pivot.
455  if (info > 0)
456  m_chol_mat.resize (info - 1, info - 1);
457  else if (calc_cond)
458  {
459  F77_INT dpocon_info = 0;
460 
461  // Now calculate the condition number for non-singular matrix.
462  Array<double> z (dim_vector (3*n, 1));
463  double *pz = z.fortran_vec ();
465  if (m_is_upper)
466  F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
467  n, anorm, m_rcond, pz, iz, dpocon_info
468  F77_CHAR_ARG_LEN (1)));
469  else
470  F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
471  n, anorm, m_rcond, pz, iz, dpocon_info
472  F77_CHAR_ARG_LEN (1)));
473 
474  if (dpocon_info != 0)
475  info = -1;
476  }
477 
478  return info;
479 }
480 
481 #if defined (HAVE_QRUPDATE)
482 
// qrupdate-backed rank-1 update for Matrix: DCH1UP modifies the stored
// factor R in place so that R'*R becomes R'*R + u*u' (the operation the
// generic fallback performs by refactoring).  u must have n elements.
// NOTE(review): qrupdate routines presumably expect an upper-triangular R;
// the declarator line (original 485) is absent here — confirm.
483 template <>
484 OCTAVE_API void
486 {
487  F77_INT n = to_f77_int (m_chol_mat.rows ());
488 
489  if (u.numel () != n)
490  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
491 
// Writable copy: u is const but DCH1UP takes a non-const vector.
492  ColumnVector utmp = u;
493 
// Workspace of n doubles passed to DCH1UP.
494  OCTAVE_LOCAL_BUFFER (double, w, n);
495 
496  F77_XFCN (dch1up, DCH1UP, (n, m_chol_mat.fortran_vec (), n,
497  utmp.fortran_vec (), w));
498 }
499 
// qrupdate-backed rank-1 downdate for Matrix (R'*R -> R'*R - u*u') via
// DCH1DN.  Returns DCH1DN's INFO unchanged; it presumably follows the same
// 0 / 1 (not positive definite) / 2 (singular R) convention as the generic
// fallback — confirm against the qrupdate documentation.
// NOTE(review): the return-type and declarator lines (original 501-502)
// are absent here.
500 template <>
503 {
504  F77_INT info = -1;
505 
506  F77_INT n = to_f77_int (m_chol_mat.rows ());
507 
508  if (u.numel () != n)
509  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
510 
511  ColumnVector utmp = u;
512 
513  OCTAVE_LOCAL_BUFFER (double, w, n);
514 
515  F77_XFCN (dch1dn, DCH1DN, (n, m_chol_mat.fortran_vec (), n,
516  utmp.fortran_vec (), w, info));
517 
518  return info;
519 }
520 
// qrupdate-backed symmetric insertion for Matrix via DCHINX: inserts u as
// row/column j (0-based, 0 <= j <= n) of the factored matrix, growing the
// factor to (n+1)-by-(n+1).  Returns DCHINX's INFO unchanged.
// NOTE(review): the return-type and declarator lines (original 522-523)
// are absent here.
521 template <>
524 {
525  F77_INT info = -1;
526 
527  F77_INT n = to_f77_int (m_chol_mat.rows ());
528  F77_INT j = to_f77_int (j_arg);
529 
530  if (u.numel () != n + 1)
531  (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
532  if (j < 0 || j > n)
533  (*current_liboctave_error_handler) ("cholinsert: index out of range");
534 
535  ColumnVector utmp = u;
536 
537  OCTAVE_LOCAL_BUFFER (double, w, n);
538 
// Grow storage first; DCHINX receives the old order n but leading
// dimension n+1.
539  m_chol_mat.resize (n+1, n+1);
540  F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
541 
// j + 1 converts to Fortran's 1-based index.
542  F77_XFCN (dchinx, DCHINX, (n, m_chol_mat.fortran_vec (), ldcm,
543  j + 1, utmp.fortran_vec (), w, info));
544 
545  return info;
546 }
547 
// qrupdate-backed symmetric deletion for Matrix via DCHDEX: removes
// row/column j (0-based, 0 <= j <= n-1) and shrinks the stored factor to
// (n-1)-by-(n-1).
// NOTE(review): the declarator line (original 550) is absent here.
548 template <>
549 OCTAVE_API void
551 {
552  F77_INT n = to_f77_int (m_chol_mat.rows ());
553  F77_INT j = to_f77_int (j_arg);
554 
555  if (j < 0 || j > n-1)
556  (*current_liboctave_error_handler) ("choldelete: index out of range");
557 
558  OCTAVE_LOCAL_BUFFER (double, w, n);
559 
// j + 1 converts to Fortran's 1-based index.
560  F77_XFCN (dchdex, DCHDEX, (n, m_chol_mat.fortran_vec (), n, j + 1, w));
561 
562  m_chol_mat.resize (n-1, n-1);
563 }
564 
// qrupdate-backed symmetric shift for Matrix via DCHSHX: moves row/column i
// to position j (both 0-based, in [0, n-1]) and updates the factor in place.
// NOTE(review): the declarator line (original 567) is absent here.
565 template <>
566 OCTAVE_API void
568 {
569  F77_INT n = to_f77_int (m_chol_mat.rows ());
570  F77_INT i = to_f77_int (i_arg);
571  F77_INT j = to_f77_int (j_arg);
572 
573  if (i < 0 || i > n-1 || j < 0 || j > n-1)
574  (*current_liboctave_error_handler) ("cholshift: index out of range");
575 
// DCHSHX is given a workspace of 2*n doubles.
576  OCTAVE_LOCAL_BUFFER (double, w, 2*n);
577 
578  F77_XFCN (dchshx, DCHSHX, (n, m_chol_mat.fortran_vec (), n,
579  i + 1, j + 1, w));
580 }
581 
582 #endif
583 
// Single-precision factorization: factor A with LAPACK SPOTRF and store the
// factor in m_chol_mat (A = R'*R when UPPER, otherwise A = L*L').  The
// opposite triangle of the stored factor is zeroed.
//
// If CALC_COND is true and SPOTRF succeeds, SPOCON estimates the reciprocal
// 1-norm condition number into m_rcond.  Otherwise m_rcond stays 0.
//
// Returns INFO: 0 on success; > 0 when SPOTRF stopped at that pivot (the
// stored factor is truncated to order INFO-1); -1 if SPOCON fails.
// NOTE(review): the return-type line (original 585) and the declaration of
// the integer workspace `iz` (original 640) are absent here — confirm.
584 template <>
586 chol<FloatMatrix>::init (const FloatMatrix& a, bool upper, bool calc_cond)
587 {
588  octave_idx_type a_nr = a.rows ();
589  octave_idx_type a_nc = a.cols ();
590 
591  if (a_nr != a_nc)
592  (*current_liboctave_error_handler) ("chol: requires square matrix");
593 
594  F77_INT n = to_f77_int (a_nc);
595  F77_INT info;
596 
597  m_is_upper = upper;
598 
// Copy the requested triangle of A and zero the opposite one.
599  m_chol_mat.clear (n, n);
600  if (m_is_upper)
601  for (octave_idx_type j = 0; j < n; j++)
602  {
603  for (octave_idx_type i = 0; i <= j; i++)
604  m_chol_mat.xelem (i, j) = a(i, j);
605  for (octave_idx_type i = j+1; i < n; i++)
606  m_chol_mat.xelem (i, j) = 0.0f;
607  }
608  else
609  for (octave_idx_type j = 0; j < n; j++)
610  {
611  for (octave_idx_type i = 0; i < j; i++)
612  m_chol_mat.xelem (i, j) = 0.0f;
613  for (octave_idx_type i = j; i < n; i++)
614  m_chol_mat.xelem (i, j) = a(i, j);
615  }
616  float *h = m_chol_mat.fortran_vec ();
617 
618  // Calculate the norm of the matrix, for later use.
619  float anorm = 0;
620  if (calc_cond)
621  anorm = octave::xnorm (a, 1);
622 
623  if (m_is_upper)
624  F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
625  F77_CHAR_ARG_LEN (1)));
626  else
627  F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
628  F77_CHAR_ARG_LEN (1)));
629 
630  m_rcond = 0.0;
// Keep only the leading block preceding the failing pivot.
631  if (info > 0)
632  m_chol_mat.resize (info - 1, info - 1);
633  else if (calc_cond)
634  {
635  F77_INT spocon_info = 0;
636 
637  // Now calculate the condition number for non-singular matrix.
638  Array<float> z (dim_vector (3*n, 1));
639  float *pz = z.fortran_vec ();
641  if (m_is_upper)
642  F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
643  n, anorm, m_rcond, pz, iz, spocon_info
644  F77_CHAR_ARG_LEN (1)));
645  else
646  F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
647  n, anorm, m_rcond, pz, iz, spocon_info
648  F77_CHAR_ARG_LEN (1)));
649 
650  if (spocon_info != 0)
651  info = -1;
652  }
653 
654  return info;
655 }
656 
657 #if defined (HAVE_QRUPDATE)
658 
// qrupdate-backed rank-1 update for FloatMatrix (R'*R -> R'*R + u*u') via
// SCH1UP.  u must have n elements.
// NOTE(review): the declarator line (original 661) is absent here.
659 template <>
660 OCTAVE_API void
662 {
663  F77_INT n = to_f77_int (m_chol_mat.rows ());
664 
665  if (u.numel () != n)
666  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
667 
// Writable copy for the non-const Fortran argument.
668  FloatColumnVector utmp = u;
669 
670  OCTAVE_LOCAL_BUFFER (float, w, n);
671 
672  F77_XFCN (sch1up, SCH1UP, (n, m_chol_mat.fortran_vec (), n,
673  utmp.fortran_vec (), w));
674 }
675 
// qrupdate-backed rank-1 downdate for FloatMatrix (R'*R -> R'*R - u*u') via
// SCH1DN.  Returns SCH1DN's INFO unchanged.
// NOTE(review): the return-type and declarator lines (original 677-678)
// are absent here.
676 template <>
679 {
680  F77_INT info = -1;
681 
682  F77_INT n = to_f77_int (m_chol_mat.rows ());
683 
684  if (u.numel () != n)
685  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
686 
687  FloatColumnVector utmp = u;
688 
689  OCTAVE_LOCAL_BUFFER (float, w, n);
690 
691  F77_XFCN (sch1dn, SCH1DN, (n, m_chol_mat.fortran_vec (), n,
692  utmp.fortran_vec (), w, info));
693 
694  return info;
695 }
696 
// qrupdate-backed symmetric insertion for FloatMatrix via SCHINX: inserts u
// as row/column j (0-based, 0 <= j <= n), growing the factor to
// (n+1)-by-(n+1).  Returns SCHINX's INFO unchanged.
// NOTE(review): original lines 698-699 (return type and the first part of
// the declarator) are absent here.
697 template <>
700  octave_idx_type j_arg)
701 {
702  F77_INT info = -1;
703 
704  F77_INT n = to_f77_int (m_chol_mat.rows ());
705  F77_INT j = to_f77_int (j_arg);
706 
707  if (u.numel () != n + 1)
708  (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
709  if (j < 0 || j > n)
710  (*current_liboctave_error_handler) ("cholinsert: index out of range");
711 
712  FloatColumnVector utmp = u;
713 
714  OCTAVE_LOCAL_BUFFER (float, w, n);
715 
// Grow storage first; SCHINX receives the old order n with leading
// dimension n+1.
716  m_chol_mat.resize (n+1, n+1);
717  F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
718 
719  F77_XFCN (schinx, SCHINX, (n, m_chol_mat.fortran_vec (), ldcm,
720  j + 1, utmp.fortran_vec (), w, info));
721 
722  return info;
723 }
724 
// qrupdate-backed symmetric deletion for FloatMatrix via SCHDEX: removes
// row/column j (0-based) and shrinks the factor to (n-1)-by-(n-1).
// NOTE(review): the declarator line (original 727) is absent here.
725 template <>
726 OCTAVE_API void
728 {
729  F77_INT n = to_f77_int (m_chol_mat.rows ());
730  F77_INT j = to_f77_int (j_arg);
731 
732  if (j < 0 || j > n-1)
733  (*current_liboctave_error_handler) ("choldelete: index out of range");
734 
735  OCTAVE_LOCAL_BUFFER (float, w, n);
736 
737  F77_XFCN (schdex, SCHDEX, (n, m_chol_mat.fortran_vec (), n,
738  j + 1, w));
739 
740  m_chol_mat.resize (n-1, n-1);
741 }
742 
// qrupdate-backed symmetric shift for FloatMatrix via SCHSHX: moves
// row/column i to position j (both 0-based, in [0, n-1]).
// NOTE(review): the declarator line (original 745) is absent here.
743 template <>
744 OCTAVE_API void
746 {
747  F77_INT n = to_f77_int (m_chol_mat.rows ());
748  F77_INT i = to_f77_int (i_arg);
749  F77_INT j = to_f77_int (j_arg);
750 
751  if (i < 0 || i > n-1 || j < 0 || j > n-1)
752  (*current_liboctave_error_handler) ("cholshift: index out of range");
753 
// SCHSHX is given a workspace of 2*n floats.
754  OCTAVE_LOCAL_BUFFER (float, w, 2*n);
755 
756  F77_XFCN (schshx, SCHSHX, (n, m_chol_mat.fortran_vec (), n,
757  i + 1, j + 1, w));
758 }
759 
760 #endif
761 
762 template <>
764 chol<ComplexMatrix>::init (const ComplexMatrix& a, bool upper, bool calc_cond)
765 {
766  octave_idx_type a_nr = a.rows ();
767  octave_idx_type a_nc = a.cols ();
768 
769  if (a_nr != a_nc)
770  (*current_liboctave_error_handler) ("chol: requires square matrix");
771 
772  F77_INT n = to_f77_int (a_nc);
773  F77_INT info;
774 
775  m_is_upper = upper;
776 
777  m_chol_mat.clear (n, n);
778  if (m_is_upper)
779  for (octave_idx_type j = 0; j < n; j++)
780  {
781  for (octave_idx_type i = 0; i <= j; i++)
782  m_chol_mat.xelem (i, j) = a(i, j);
783  for (octave_idx_type i = j+1; i < n; i++)
784  m_chol_mat.xelem (i, j) = 0.0;
785  }
786  else
787  for (octave_idx_type j = 0; j < n; j++)
788  {
789  for (octave_idx_type i = 0; i < j; i++)
790  m_chol_mat.xelem (i, j) = 0.0;
791  for (octave_idx_type i = j; i < n; i++)
792  m_chol_mat.xelem (i, j) = a(i, j);
793  }
794  Complex *h = m_chol_mat.fortran_vec ();
795 
796  // Calculate the norm of the matrix, for later use.
797  double anorm = 0;
798  if (calc_cond)
799  anorm = octave::xnorm (a, 1);
800 
801  if (m_is_upper)
802  F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n,
803  F77_DBLE_CMPLX_ARG (h), n, info
804  F77_CHAR_ARG_LEN (1)));
805  else
806  F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n,
807  F77_DBLE_CMPLX_ARG (h), n, info
808  F77_CHAR_ARG_LEN (1)));
809 
810  m_rcond = 0.0;
811  if (info > 0)
812  m_chol_mat.resize (info - 1, info - 1);
813  else if (calc_cond)
814  {
815  F77_INT zpocon_info = 0;
816 
817  // Now calculate the condition number for non-singular matrix.
818  Array<Complex> z (dim_vector (2*n, 1));
819  Complex *pz = z.fortran_vec ();
820  Array<double> rz (dim_vector (n, 1));
821  double *prz = rz.fortran_vec ();
822  F77_XFCN (zpocon, ZPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n,
823  F77_DBLE_CMPLX_ARG (h), n, anorm, m_rcond,
824  F77_DBLE_CMPLX_ARG (pz), prz, zpocon_info
825  F77_CHAR_ARG_LEN (1)));
826 
827  if (zpocon_info != 0)
828  info = -1;
829  }
830 
831  return info;
832 }
833 
834 #if defined (HAVE_QRUPDATE)
835 
// qrupdate-backed rank-1 update for ComplexMatrix (R'*R -> R'*R + u*u') via
// ZCH1UP.  u must have n elements.
// NOTE(review): the declarator line (original 838) is absent here.
836 template <>
837 OCTAVE_API void
839 {
840  F77_INT n = to_f77_int (m_chol_mat.rows ());
841 
842  if (u.numel () != n)
843  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
844 
// Writable copy for the non-const Fortran argument.
845  ComplexColumnVector utmp = u;
846 
// Real workspace of n doubles passed to ZCH1UP.
847  OCTAVE_LOCAL_BUFFER (double, rw, n);
848 
849  F77_XFCN (zch1up, ZCH1UP, (n,
850  F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
851  n,
852  F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
853  rw));
854 }
855 
// qrupdate-backed rank-1 downdate for ComplexMatrix (R'*R -> R'*R - u*u')
// via ZCH1DN.  Returns ZCH1DN's INFO unchanged.
// NOTE(review): the return-type and declarator lines (original 857-858)
// are absent here.
856 template <>
859 {
860  F77_INT info = -1;
861 
862  F77_INT n = to_f77_int (m_chol_mat.rows ());
863 
864  if (u.numel () != n)
865  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
866 
867  ComplexColumnVector utmp = u;
868 
869  OCTAVE_LOCAL_BUFFER (double, rw, n);
870 
871  F77_XFCN (zch1dn, ZCH1DN, (n,
872  F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
873  n,
874  F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
875  rw, info));
876 
877  return info;
878 }
879 
// qrupdate-backed symmetric insertion for ComplexMatrix via ZCHINX: inserts
// u as row/column j (0-based, 0 <= j <= n), growing the factor to
// (n+1)-by-(n+1).  Returns ZCHINX's INFO unchanged.
// NOTE(review): original lines 881-882 (return type and the first part of
// the declarator) are absent here.
880 template <>
883  octave_idx_type j_arg)
884 {
885  F77_INT info = -1;
886 
887  F77_INT n = to_f77_int (m_chol_mat.rows ());
888  F77_INT j = to_f77_int (j_arg);
889 
890  if (u.numel () != n + 1)
891  (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
892  if (j < 0 || j > n)
893  (*current_liboctave_error_handler) ("cholinsert: index out of range");
894 
895  ComplexColumnVector utmp = u;
896 
897  OCTAVE_LOCAL_BUFFER (double, rw, n);
898 
// Grow storage first; ZCHINX receives the old order n with leading
// dimension n+1.
899  m_chol_mat.resize (n+1, n+1);
900  F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
901 
902  F77_XFCN (zchinx, ZCHINX, (n,
903  F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
904  ldcm, j + 1,
905  F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
906  rw, info));
907 
908  return info;
909 }
910 
// qrupdate-backed symmetric deletion for ComplexMatrix via ZCHDEX: removes
// row/column j (0-based) and shrinks the factor to (n-1)-by-(n-1).
// NOTE(review): the declarator line (original 913) is absent here.
911 template <>
912 OCTAVE_API void
914 {
915  F77_INT n = to_f77_int (m_chol_mat.rows ());
916  F77_INT j = to_f77_int (j_arg);
917 
918  if (j < 0 || j > n-1)
919  (*current_liboctave_error_handler) ("choldelete: index out of range");
920 
921  OCTAVE_LOCAL_BUFFER (double, rw, n);
922 
923  F77_XFCN (zchdex, ZCHDEX, (n,
924  F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
925  n, j + 1, rw));
926 
927  m_chol_mat.resize (n-1, n-1);
928 }
929 
// qrupdate-backed symmetric shift for ComplexMatrix via ZCHSHX: moves
// row/column i to position j (both 0-based, in [0, n-1]).  ZCHSHX takes a
// complex workspace `w` and a real workspace `rw` of n doubles.
// NOTE(review): original line 932 (start of the declarator) and line 942
// (presumably declaring the complex workspace `w` with n elements) are
// absent here — confirm.
930 template <>
931 OCTAVE_API void
933  octave_idx_type j_arg)
934 {
935  F77_INT n = to_f77_int (m_chol_mat.rows ());
936  F77_INT i = to_f77_int (i_arg);
937  F77_INT j = to_f77_int (j_arg);
938 
939  if (i < 0 || i > n-1 || j < 0 || j > n-1)
940  (*current_liboctave_error_handler) ("cholshift: index out of range");
941 
943  OCTAVE_LOCAL_BUFFER (double, rw, n);
944 
945  F77_XFCN (zchshx, ZCHSHX, (n,
946  F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
947  n, i + 1, j + 1,
948  F77_DBLE_CMPLX_ARG (w), rw));
949 }
950 
951 #endif
952 
953 template <>
956  bool calc_cond)
957 {
958  octave_idx_type a_nr = a.rows ();
959  octave_idx_type a_nc = a.cols ();
960 
961  if (a_nr != a_nc)
962  (*current_liboctave_error_handler) ("chol: requires square matrix");
963 
964  F77_INT n = to_f77_int (a_nc);
965  F77_INT info;
966 
967  m_is_upper = upper;
968 
969  m_chol_mat.clear (n, n);
970  if (m_is_upper)
971  for (octave_idx_type j = 0; j < n; j++)
972  {
973  for (octave_idx_type i = 0; i <= j; i++)
974  m_chol_mat.xelem (i, j) = a(i, j);
975  for (octave_idx_type i = j+1; i < n; i++)
976  m_chol_mat.xelem (i, j) = 0.0f;
977  }
978  else
979  for (octave_idx_type j = 0; j < n; j++)
980  {
981  for (octave_idx_type i = 0; i < j; i++)
982  m_chol_mat.xelem (i, j) = 0.0f;
983  for (octave_idx_type i = j; i < n; i++)
984  m_chol_mat.xelem (i, j) = a(i, j);
985  }
986  FloatComplex *h = m_chol_mat.fortran_vec ();
987 
988  // Calculate the norm of the matrix, for later use.
989  float anorm = 0;
990  if (calc_cond)
991  anorm = octave::xnorm (a, 1);
992 
993  if (m_is_upper)
994  F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
995  n, F77_CMPLX_ARG (h), n, info
996  F77_CHAR_ARG_LEN (1)));
997  else
998  F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
999  n, F77_CMPLX_ARG (h), n, info
1000  F77_CHAR_ARG_LEN (1)));
1001 
1002  m_rcond = 0.0;
1003  if (info > 0)
1004  m_chol_mat.resize (info - 1, info - 1);
1005  else if (calc_cond)
1006  {
1007  F77_INT cpocon_info = 0;
1008 
1009  // Now calculate the condition number for non-singular matrix.
1010  Array<FloatComplex> z (dim_vector (2*n, 1));
1011  FloatComplex *pz = z.fortran_vec ();
1012  Array<float> rz (dim_vector (n, 1));
1013  float *prz = rz.fortran_vec ();
1014  F77_XFCN (cpocon, CPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n,
1015  F77_CMPLX_ARG (h), n, anorm, m_rcond,
1016  F77_CMPLX_ARG (pz), prz, cpocon_info
1017  F77_CHAR_ARG_LEN (1)));
1018 
1019  if (cpocon_info != 0)
1020  info = -1;
1021  }
1022 
1023  return info;
1024 }
1025 
1026 #if defined (HAVE_QRUPDATE)
1027 
// qrupdate-backed rank-1 update for FloatComplexMatrix (R'*R -> R'*R + u*u')
// via CCH1UP.  u must have n elements.
// NOTE(review): the declarator line (original 1030) is absent here.
1028 template <>
1029 OCTAVE_API void
1031 {
1032  F77_INT n = to_f77_int (m_chol_mat.rows ());
1033 
1034  if (u.numel () != n)
1035  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1036 
// Writable copy for the non-const Fortran argument.
1037  FloatComplexColumnVector utmp = u;
1038 
// Real workspace of n floats passed to CCH1UP.
1039  OCTAVE_LOCAL_BUFFER (float, rw, n);
1040 
1041  F77_XFCN (cch1up, CCH1UP, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1042  n, F77_CMPLX_ARG (utmp.fortran_vec ()), rw));
1043 }
1044 
// qrupdate-backed rank-1 downdate for FloatComplexMatrix
// (R'*R -> R'*R - u*u') via CCH1DN.  Returns CCH1DN's INFO unchanged.
// NOTE(review): the return-type and declarator lines (original 1046-1047)
// are absent here.
1045 template <>
1048 {
1049  F77_INT info = -1;
1050 
1051  F77_INT n = to_f77_int (m_chol_mat.rows ());
1052 
1053  if (u.numel () != n)
1054  (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1055 
1056  FloatComplexColumnVector utmp = u;
1057 
1058  OCTAVE_LOCAL_BUFFER (float, rw, n);
1059 
1060  F77_XFCN (cch1dn, CCH1DN, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1061  n, F77_CMPLX_ARG (utmp.fortran_vec ()),
1062  rw, info));
1063 
1064  return info;
1065 }
1066 
// qrupdate-backed symmetric insertion for FloatComplexMatrix via CCHINX:
// inserts u as row/column j (0-based, 0 <= j <= n), growing the factor to
// (n+1)-by-(n+1).  Returns CCHINX's INFO unchanged.
// NOTE(review): original lines 1068-1069 (return type and the first part
// of the declarator) are absent here.
1067 template <>
1070  octave_idx_type j_arg)
1071 {
1072  F77_INT info = -1;
1073  F77_INT j = to_f77_int (j_arg);
1074 
1075  F77_INT n = to_f77_int (m_chol_mat.rows ());
1076 
1077  if (u.numel () != n + 1)
1078  (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
1079  if (j < 0 || j > n)
1080  (*current_liboctave_error_handler) ("cholinsert: index out of range");
1081 
1082  FloatComplexColumnVector utmp = u;
1083 
1084  OCTAVE_LOCAL_BUFFER (float, rw, n);
1085 
// Grow storage first; CCHINX receives the old order n with leading
// dimension n+1.
1086  m_chol_mat.resize (n+1, n+1);
1087  F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
1088 
1089  F77_XFCN (cchinx, CCHINX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1090  ldcm, j + 1,
1091  F77_CMPLX_ARG (utmp.fortran_vec ()),
1092  rw, info));
1093 
1094  return info;
1095 }
1096 
// qrupdate-backed symmetric deletion for FloatComplexMatrix via CCHDEX:
// removes row/column j (0-based) and shrinks the factor to (n-1)-by-(n-1).
// NOTE(review): the declarator line (original 1099) is absent here.
1097 template <>
1098 OCTAVE_API void
1100 {
1101  F77_INT n = to_f77_int (m_chol_mat.rows ());
1102  F77_INT j = to_f77_int (j_arg);
1103 
1104  if (j < 0 || j > n-1)
1105  (*current_liboctave_error_handler) ("choldelete: index out of range");
1106 
1107  OCTAVE_LOCAL_BUFFER (float, rw, n);
1108 
1109  F77_XFCN (cchdex, CCHDEX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1110  n, j + 1, rw));
1111 
1112  m_chol_mat.resize (n-1, n-1);
1113 }
1114 
// qrupdate-backed symmetric shift for FloatComplexMatrix via CCHSHX: moves
// row/column i to position j (both 0-based, in [0, n-1]).  CCHSHX takes a
// complex workspace `w` and a real workspace `rw` of n floats.
// NOTE(review): original line 1117 (start of the declarator) and line 1127
// (presumably declaring the complex workspace `w` with n elements) are
// absent here — confirm.
1115 template <>
1116 OCTAVE_API void
1118  octave_idx_type j_arg)
1119 {
1120  F77_INT n = to_f77_int (m_chol_mat.rows ());
1121  F77_INT i = to_f77_int (i_arg);
1122  F77_INT j = to_f77_int (j_arg);
1123 
1124  if (i < 0 || i > n-1 || j < 0 || j > n-1)
1125  (*current_liboctave_error_handler) ("cholshift: index out of range");
1126 
1128  OCTAVE_LOCAL_BUFFER (float, rw, n);
1129 
1130  F77_XFCN (cchshx, CCHSHX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1131  n, i + 1, j + 1, F77_CMPLX_ARG (w), rw));
1132 }
1133 
1134 #endif
1135 
1136 // Instantiations we need.
1137 
// Explicitly instantiate the chol class for the four dense matrix types
// whose init (and, with qrupdate, update/downdate/insert/delete/shift)
// members are specialized in this file.
1138 template class chol<Matrix>;
1139 
1140 template class chol<FloatMatrix>;
1141 
1142 template class chol<ComplexMatrix>;
1143 
1144 template class chol<FloatComplexMatrix>;
1145 
// NOTE(review): the explicit chol2inv instantiations (original lines
// 1146-1156) are only partially present here.  The line below is the tail
// of the Matrix instantiation; per the index entries at the end of this
// listing, it is presumably preceded by `template OCTAVE_API Matrix`, with
// matching FloatMatrix, ComplexMatrix and FloatComplexMatrix instantiations.
1147 chol2inv<Matrix> (const Matrix& r);
1148 
1151 
1154 
1157 
OCTAVE_END_NAMESPACE(octave)
ComplexColumnVector conj(const ComplexColumnVector &a)
Definition: CColVector.cc:217
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type numel(void) const
Number of elements in the array.
Definition: Array.h:414
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type rows(void) const
Definition: Array.h:459
OCTARRAY_API T * fortran_vec(void)
Size of the specified dimension.
Definition: Array-base.cc:1766
OCTARRAY_OVERRIDABLE_FUNC_API octave_idx_type cols(void) const
Definition: Array.h:469
OCTARRAY_OVERRIDABLE_FUNC_API T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:524
Definition: dMatrix.h:42
Definition: chol.h:38
OCTAVE_API void set(const T &R)
Definition: chol.cc:257
OCTAVE_API void delete_sym(octave_idx_type j)
OCTAVE_API void update(const VT &u)
OCTAVE_API octave_idx_type downdate(const VT &u)
OCTAVE_API T inverse(void) const
Definition: chol.cc:250
OCTAVE_API octave_idx_type insert_sym(const VT &u, octave_idx_type j)
OCTAVE_API octave_idx_type init(const T &a, bool upper, bool calc_cond)
OCTAVE_API void shift_sym(octave_idx_type i, octave_idx_type j)
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
ColumnVector imag(const ComplexColumnVector &a)
Definition: dColVector.cc:143
OCTAVE_BEGIN_NAMESPACE(octave) static octave_value daspk_fcn
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:45
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
template OCTAVE_API FloatMatrix chol2inv< FloatMatrix >(const FloatMatrix &r)
template OCTAVE_API Matrix chol2inv< Matrix >(const Matrix &r)
static Matrix chol2inv_internal(const Matrix &r, bool is_upper=true)
Definition: chol.cc:53
template OCTAVE_API ComplexMatrix chol2inv< ComplexMatrix >(const ComplexMatrix &r)
template OCTAVE_API FloatComplexMatrix chol2inv< FloatComplexMatrix >(const FloatComplexMatrix &r)
T chol2inv(const T &r)
Definition: chol.cc:242
#define OCTAVE_API
Definition: main.in.cc:55
octave_idx_type n
Definition: mx-inlines.cc:753
T * r
Definition: mx-inlines.cc:773
std::complex< double > w(std::complex< double > z, double relerr=0)
std::complex< double > Complex
Definition: oct-cmplx.h:33
std::complex< float > FloatComplex
Definition: oct-cmplx.h:34
#define OCTAVE_LOCAL_BUFFER(T, buf, size)
Definition: oct-locbuf.h:44
double xnorm(const ColumnVector &x, double p)
Definition: oct-norm.cc:585
OCTAVE_API void warn_qrupdate_once(void)