GNU Octave 7.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
chol.cc
Go to the documentation of this file.
1////////////////////////////////////////////////////////////////////////
2//
3// Copyright (C) 1994-2022 The Octave Project Developers
4//
5// See the file COPYRIGHT.md in the top-level directory of this
6// distribution or <https://octave.org/copyright/>.
7//
8// This file is part of Octave.
9//
10// Octave is free software: you can redistribute it and/or modify it
11// under the terms of the GNU General Public License as published by
12// the Free Software Foundation, either version 3 of the License, or
13// (at your option) any later version.
14//
15// Octave is distributed in the hope that it will be useful, but
16// WITHOUT ANY WARRANTY; without even the implied warranty of
17// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18// GNU General Public License for more details.
19//
20// You should have received a copy of the GNU General Public License
21// along with Octave; see the file COPYING. If not, see
22// <https://www.gnu.org/licenses/>.
23//
24////////////////////////////////////////////////////////////////////////
25
26#if defined (HAVE_CONFIG_H)
27# include "config.h"
28#endif
29
30#include "Array.h"
31#include "CColVector.h"
32#include "CMatrix.h"
33#include "chol.h"
34#include "dColVector.h"
35#include "dMatrix.h"
36#include "fCColVector.h"
37#include "fCMatrix.h"
38#include "fColVector.h"
39#include "fMatrix.h"
40#include "lo-error.h"
41#include "lo-lapack-proto.h"
42#include "lo-qrupdate-proto.h"
43#include "oct-locbuf.h"
44#include "oct-norm.h"
45
46#if ! defined (HAVE_QRUPDATE)
47# include "qr.h"
48#endif
49
50namespace octave
51{
52 static Matrix
53 chol2inv_internal (const Matrix& r, bool is_upper = true)
54 {
55 Matrix retval;
56
57 octave_idx_type r_nr = r.rows ();
58 octave_idx_type r_nc = r.cols ();
59
60 if (r_nr != r_nc)
61 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
62
63 F77_INT n = to_f77_int (r_nc);
64 F77_INT info;
65
66 Matrix tmp = r;
67 double *v = tmp.fortran_vec ();
68
69 if (is_upper)
70 F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
71 v, n, info
72 F77_CHAR_ARG_LEN (1)));
73 else
74 F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
75 v, n, info
76 F77_CHAR_ARG_LEN (1)));
77
78 // FIXME: Should we check info exit value and possibly report an error?
79
80 // If someone thinks of a more graceful way of doing this
81 // (or faster for that matter :-)), please let me know!
82
83 if (n > 1)
84 {
85 if (is_upper)
86 for (octave_idx_type j = 0; j < r_nc; j++)
87 for (octave_idx_type i = j+1; i < r_nr; i++)
88 tmp.xelem (i, j) = tmp.xelem (j, i);
89 else
90 for (octave_idx_type j = 0; j < r_nc; j++)
91 for (octave_idx_type i = j+1; i < r_nr; i++)
92 tmp.xelem (j, i) = tmp.xelem (i, j);
93 }
94
95 retval = tmp;
96
97 return retval;
98 }
99
100 static FloatMatrix
101 chol2inv_internal (const FloatMatrix& r, bool is_upper = true)
102 {
103 FloatMatrix retval;
104
105 octave_idx_type r_nr = r.rows ();
106 octave_idx_type r_nc = r.cols ();
107
108 if (r_nr != r_nc)
109 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
110
111 F77_INT n = to_f77_int (r_nc);
112 F77_INT info;
113
114 FloatMatrix tmp = r;
115 float *v = tmp.fortran_vec ();
116
117 if (is_upper)
118 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
119 v, n, info
120 F77_CHAR_ARG_LEN (1)));
121 else
122 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
123 v, n, info
124 F77_CHAR_ARG_LEN (1)));
125
126 // FIXME: Should we check info exit value and possibly report an error?
127
128 // If someone thinks of a more graceful way of doing this (or
129 // faster for that matter :-)), please let me know!
130
131 if (n > 1)
132 {
133 if (is_upper)
134 for (octave_idx_type j = 0; j < r_nc; j++)
135 for (octave_idx_type i = j+1; i < r_nr; i++)
136 tmp.xelem (i, j) = tmp.xelem (j, i);
137 else
138 for (octave_idx_type j = 0; j < r_nc; j++)
139 for (octave_idx_type i = j+1; i < r_nr; i++)
140 tmp.xelem (j, i) = tmp.xelem (i, j);
141 }
142
143 retval = tmp;
144
145 return retval;
146 }
147
148 static ComplexMatrix
149 chol2inv_internal (const ComplexMatrix& r, bool is_upper = true)
150 {
151 ComplexMatrix retval;
152
153 octave_idx_type r_nr = r.rows ();
154 octave_idx_type r_nc = r.cols ();
155
156 if (r_nr != r_nc)
157 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
158
159 F77_INT n = to_f77_int (r_nc);
160 F77_INT info;
161
162 ComplexMatrix tmp = r;
163
164 if (is_upper)
165 F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
166 F77_DBLE_CMPLX_ARG (tmp.fortran_vec ()), n, info
167 F77_CHAR_ARG_LEN (1)));
168 else
169 F77_XFCN (zpotri, ZPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
170 F77_DBLE_CMPLX_ARG (tmp.fortran_vec ()), n, info
171 F77_CHAR_ARG_LEN (1)));
172
173 // If someone thinks of a more graceful way of doing this (or
174 // faster for that matter :-)), please let me know!
175
176 if (n > 1)
177 {
178 if (is_upper)
179 for (octave_idx_type j = 0; j < r_nc; j++)
180 for (octave_idx_type i = j+1; i < r_nr; i++)
181 tmp.xelem (i, j) = std::conj (tmp.xelem (j, i));
182 else
183 for (octave_idx_type j = 0; j < r_nc; j++)
184 for (octave_idx_type i = j+1; i < r_nr; i++)
185 tmp.xelem (j, i) = std::conj (tmp.xelem (i, j));
186 }
187
188 retval = tmp;
189
190 return retval;
191 }
192
193 static FloatComplexMatrix
194 chol2inv_internal (const FloatComplexMatrix& r, bool is_upper = true)
195 {
196 FloatComplexMatrix retval;
197
198 octave_idx_type r_nr = r.rows ();
199 octave_idx_type r_nc = r.cols ();
200
201 if (r_nr != r_nc)
202 (*current_liboctave_error_handler) ("chol2inv requires square matrix");
203
204 F77_INT n = to_f77_int (r_nc);
205 F77_INT info;
206
207 FloatComplexMatrix tmp = r;
208
209 if (is_upper)
210 F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
211 F77_CMPLX_ARG (tmp.fortran_vec ()), n, info
212 F77_CHAR_ARG_LEN (1)));
213 else
214 F77_XFCN (cpotri, CPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
215 F77_CMPLX_ARG (tmp.fortran_vec ()), n, info
216 F77_CHAR_ARG_LEN (1)));
217
218 // If someone thinks of a more graceful way of doing this (or
219 // faster for that matter :-)), please let me know!
220
221 if (n > 1)
222 {
223 if (is_upper)
224 for (octave_idx_type j = 0; j < r_nc; j++)
225 for (octave_idx_type i = j+1; i < r_nr; i++)
226 tmp.xelem (i, j) = std::conj (tmp.xelem (j, i));
227 else
228 for (octave_idx_type j = 0; j < r_nc; j++)
229 for (octave_idx_type i = j+1; i < r_nr; i++)
230 tmp.xelem (j, i) = std::conj (tmp.xelem (i, j));
231 }
232
233 retval = tmp;
234
235 return retval;
236 }
237
238 namespace math
239 {
// Public entry point: forward R to the chol2inv_internal overload for its
// matrix type, treating R as an upper triangular factor (the overloads'
// default for IS_UPPER, spelled out here).
template <typename T>
T
chol2inv (const T& r)
{
  const bool is_upper = true;

  return chol2inv_internal (r, is_upper);
}
246
247 // Compute the inverse of a matrix using the Cholesky factorization.
248 template <typename T>
249 T
250 chol<T>::inverse (void) const
251 {
252 return chol2inv_internal (m_chol_mat, m_is_upper);
253 }
254
255 template <typename T>
256 void
257 chol<T>::set (const T& R)
258 {
259 if (! R.issquare ())
260 (*current_liboctave_error_handler) ("chol: requires square matrix");
261
262 m_chol_mat = R;
263 }
264
265#if ! defined (HAVE_QRUPDATE)
266
267 template <typename T>
268 void
269 chol<T>::update (const VT& u)
270 {
272
273 octave_idx_type n = m_chol_mat.rows ();
274
275 if (u.numel () != n)
276 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
277
278 init (m_chol_mat.hermitian () * m_chol_mat + T (u) * T (u).hermitian (),
279 true, false);
280 }
281
282 template <typename T>
283 bool
284 singular (const T& a)
285 {
286 static typename T::element_type zero (0);
287 for (octave_idx_type i = 0; i < a.rows (); i++)
288 if (a(i, i) == zero) return true;
289 return false;
290 }
291
292 template <typename T>
294 chol<T>::downdate (const VT& u)
295 {
297
298 octave_idx_type info = -1;
299
300 octave_idx_type n = m_chol_mat.rows ();
301
302 if (u.numel () != n)
303 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
304
305 if (singular (m_chol_mat))
306 info = 2;
307 else
308 {
309 info = init (m_chol_mat.hermitian () * m_chol_mat
310 - T (u) * T (u).hermitian (), true, false);
311 if (info) info = 1;
312 }
313
314 return info;
315 }
316
317 template <typename T>
319 chol<T>::insert_sym (const VT& u, octave_idx_type j)
320 {
321 static typename T::element_type zero (0);
322
324
325 octave_idx_type info = -1;
326
327 octave_idx_type n = m_chol_mat.rows ();
328
329 if (u.numel () != n + 1)
330 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
331 if (j < 0 || j > n)
332 (*current_liboctave_error_handler) ("cholinsert: index out of range");
333
334 if (singular (m_chol_mat))
335 info = 2;
336 else if (std::imag (u(j)) != zero)
337 info = 3;
338 else
339 {
340 T a = m_chol_mat.hermitian () * m_chol_mat;
341 T a1 (n+1, n+1);
342 for (octave_idx_type k = 0; k < n+1; k++)
343 for (octave_idx_type l = 0; l < n+1; l++)
344 {
345 if (l == j)
346 a1(k, l) = u(k);
347 else if (k == j)
348 a1(k, l) = math::conj (u(l));
349 else
350 a1(k, l) = a(k < j ? k : k-1, l < j ? l : l-1);
351 }
352 info = init (a1, true, false);
353 if (info) info = 1;
354 }
355
356 return info;
357 }
358
359 template <typename T>
360 void
362 {
364
365 octave_idx_type n = m_chol_mat.rows ();
366
367 if (j < 0 || j > n-1)
368 (*current_liboctave_error_handler) ("choldelete: index out of range");
369
370 T a = m_chol_mat.hermitian () * m_chol_mat;
371 a.delete_elements (1, idx_vector (j));
372 a.delete_elements (0, idx_vector (j));
373 init (a, true, false);
374 }
375
376 template <typename T>
377 void
379 {
381
382 octave_idx_type n = m_chol_mat.rows ();
383
384 if (i < 0 || i > n-1 || j < 0 || j > n-1)
385 (*current_liboctave_error_handler) ("cholshift: index out of range");
386
387 T a = m_chol_mat.hermitian () * m_chol_mat;
389 for (octave_idx_type k = 0; k < n; k++) p(k) = k;
390 if (i < j)
391 {
392 for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
393 p(j) = i;
394 }
395 else if (j < i)
396 {
397 p(j) = i;
398 for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
399 }
400
401 init (a.index (idx_vector (p), idx_vector (p)), true, false);
402 }
403
404#endif
405
406 // Specializations.
407
408 template <>
410 chol<Matrix>::init (const Matrix& a, bool upper, bool calc_cond)
411 {
412 octave_idx_type a_nr = a.rows ();
413 octave_idx_type a_nc = a.cols ();
414
415 if (a_nr != a_nc)
416 (*current_liboctave_error_handler) ("chol: requires square matrix");
417
418 F77_INT n = to_f77_int (a_nc);
419 F77_INT info;
420
421 m_is_upper = upper;
422
423 m_chol_mat.clear (n, n);
424 if (m_is_upper)
425 for (octave_idx_type j = 0; j < n; j++)
426 {
427 for (octave_idx_type i = 0; i <= j; i++)
428 m_chol_mat.xelem (i, j) = a(i, j);
429 for (octave_idx_type i = j+1; i < n; i++)
430 m_chol_mat.xelem (i, j) = 0.0;
431 }
432 else
433 for (octave_idx_type j = 0; j < n; j++)
434 {
435 for (octave_idx_type i = 0; i < j; i++)
436 m_chol_mat.xelem (i, j) = 0.0;
437 for (octave_idx_type i = j; i < n; i++)
438 m_chol_mat.xelem (i, j) = a(i, j);
439 }
440 double *h = m_chol_mat.fortran_vec ();
441
442 // Calculate the norm of the matrix, for later use.
443 double anorm = 0;
444 if (calc_cond)
445 anorm = octave::xnorm (a, 1);
446
447 if (m_is_upper)
448 F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
449 F77_CHAR_ARG_LEN (1)));
450 else
451 F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
452 F77_CHAR_ARG_LEN (1)));
453
454 m_rcond = 0.0;
455 if (info > 0)
456 m_chol_mat.resize (info - 1, info - 1);
457 else if (calc_cond)
458 {
459 F77_INT dpocon_info = 0;
460
461 // Now calculate the condition number for non-singular matrix.
462 Array<double> z (dim_vector (3*n, 1));
463 double *pz = z.fortran_vec ();
465 if (m_is_upper)
466 F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
467 n, anorm, m_rcond, pz, iz, dpocon_info
468 F77_CHAR_ARG_LEN (1)));
469 else
470 F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
471 n, anorm, m_rcond, pz, iz, dpocon_info
472 F77_CHAR_ARG_LEN (1)));
473
474 if (dpocon_info != 0)
475 info = -1;
476 }
477
478 return info;
479 }
480
481#if defined (HAVE_QRUPDATE)
482
483 template <>
484 OCTAVE_API void
486 {
487 F77_INT n = to_f77_int (m_chol_mat.rows ());
488
489 if (u.numel () != n)
490 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
491
492 ColumnVector utmp = u;
493
494 OCTAVE_LOCAL_BUFFER (double, w, n);
495
496 F77_XFCN (dch1up, DCH1UP, (n, m_chol_mat.fortran_vec (), n,
497 utmp.fortran_vec (), w));
498 }
499
500 template <>
503 {
504 F77_INT info = -1;
505
506 F77_INT n = to_f77_int (m_chol_mat.rows ());
507
508 if (u.numel () != n)
509 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
510
511 ColumnVector utmp = u;
512
513 OCTAVE_LOCAL_BUFFER (double, w, n);
514
515 F77_XFCN (dch1dn, DCH1DN, (n, m_chol_mat.fortran_vec (), n,
516 utmp.fortran_vec (), w, info));
517
518 return info;
519 }
520
521 template <>
524 {
525 F77_INT info = -1;
526
527 F77_INT n = to_f77_int (m_chol_mat.rows ());
528 F77_INT j = to_f77_int (j_arg);
529
530 if (u.numel () != n + 1)
531 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
532 if (j < 0 || j > n)
533 (*current_liboctave_error_handler) ("cholinsert: index out of range");
534
535 ColumnVector utmp = u;
536
537 OCTAVE_LOCAL_BUFFER (double, w, n);
538
539 m_chol_mat.resize (n+1, n+1);
540 F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
541
542 F77_XFCN (dchinx, DCHINX, (n, m_chol_mat.fortran_vec (), ldcm,
543 j + 1, utmp.fortran_vec (), w, info));
544
545 return info;
546 }
547
548 template <>
549 OCTAVE_API void
551 {
552 F77_INT n = to_f77_int (m_chol_mat.rows ());
553 F77_INT j = to_f77_int (j_arg);
554
555 if (j < 0 || j > n-1)
556 (*current_liboctave_error_handler) ("choldelete: index out of range");
557
558 OCTAVE_LOCAL_BUFFER (double, w, n);
559
560 F77_XFCN (dchdex, DCHDEX, (n, m_chol_mat.fortran_vec (), n, j + 1, w));
561
562 m_chol_mat.resize (n-1, n-1);
563 }
564
565 template <>
566 OCTAVE_API void
568 {
569 F77_INT n = to_f77_int (m_chol_mat.rows ());
570 F77_INT i = to_f77_int (i_arg);
571 F77_INT j = to_f77_int (j_arg);
572
573 if (i < 0 || i > n-1 || j < 0 || j > n-1)
574 (*current_liboctave_error_handler) ("cholshift: index out of range");
575
576 OCTAVE_LOCAL_BUFFER (double, w, 2*n);
577
578 F77_XFCN (dchshx, DCHSHX, (n, m_chol_mat.fortran_vec (), n,
579 i + 1, j + 1, w));
580 }
581
582#endif
583
584 template <>
586 chol<FloatMatrix>::init (const FloatMatrix& a, bool upper, bool calc_cond)
587 {
588 octave_idx_type a_nr = a.rows ();
589 octave_idx_type a_nc = a.cols ();
590
591 if (a_nr != a_nc)
592 (*current_liboctave_error_handler) ("chol: requires square matrix");
593
594 F77_INT n = to_f77_int (a_nc);
595 F77_INT info;
596
597 m_is_upper = upper;
598
599 m_chol_mat.clear (n, n);
600 if (m_is_upper)
601 for (octave_idx_type j = 0; j < n; j++)
602 {
603 for (octave_idx_type i = 0; i <= j; i++)
604 m_chol_mat.xelem (i, j) = a(i, j);
605 for (octave_idx_type i = j+1; i < n; i++)
606 m_chol_mat.xelem (i, j) = 0.0f;
607 }
608 else
609 for (octave_idx_type j = 0; j < n; j++)
610 {
611 for (octave_idx_type i = 0; i < j; i++)
612 m_chol_mat.xelem (i, j) = 0.0f;
613 for (octave_idx_type i = j; i < n; i++)
614 m_chol_mat.xelem (i, j) = a(i, j);
615 }
616 float *h = m_chol_mat.fortran_vec ();
617
618 // Calculate the norm of the matrix, for later use.
619 float anorm = 0;
620 if (calc_cond)
621 anorm = octave::xnorm (a, 1);
622
623 if (m_is_upper)
624 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, n, info
625 F77_CHAR_ARG_LEN (1)));
626 else
627 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, n, info
628 F77_CHAR_ARG_LEN (1)));
629
630 m_rcond = 0.0;
631 if (info > 0)
632 m_chol_mat.resize (info - 1, info - 1);
633 else if (calc_cond)
634 {
635 F77_INT spocon_info = 0;
636
637 // Now calculate the condition number for non-singular matrix.
638 Array<float> z (dim_vector (3*n, 1));
639 float *pz = z.fortran_vec ();
641 if (m_is_upper)
642 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
643 n, anorm, m_rcond, pz, iz, spocon_info
644 F77_CHAR_ARG_LEN (1)));
645 else
646 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
647 n, anorm, m_rcond, pz, iz, spocon_info
648 F77_CHAR_ARG_LEN (1)));
649
650 if (spocon_info != 0)
651 info = -1;
652 }
653
654 return info;
655 }
656
657#if defined (HAVE_QRUPDATE)
658
659 template <>
660 OCTAVE_API void
662 {
663 F77_INT n = to_f77_int (m_chol_mat.rows ());
664
665 if (u.numel () != n)
666 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
667
668 FloatColumnVector utmp = u;
669
670 OCTAVE_LOCAL_BUFFER (float, w, n);
671
672 F77_XFCN (sch1up, SCH1UP, (n, m_chol_mat.fortran_vec (), n,
673 utmp.fortran_vec (), w));
674 }
675
676 template <>
679 {
680 F77_INT info = -1;
681
682 F77_INT n = to_f77_int (m_chol_mat.rows ());
683
684 if (u.numel () != n)
685 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
686
687 FloatColumnVector utmp = u;
688
689 OCTAVE_LOCAL_BUFFER (float, w, n);
690
691 F77_XFCN (sch1dn, SCH1DN, (n, m_chol_mat.fortran_vec (), n,
692 utmp.fortran_vec (), w, info));
693
694 return info;
695 }
696
697 template <>
700 octave_idx_type j_arg)
701 {
702 F77_INT info = -1;
703
704 F77_INT n = to_f77_int (m_chol_mat.rows ());
705 F77_INT j = to_f77_int (j_arg);
706
707 if (u.numel () != n + 1)
708 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
709 if (j < 0 || j > n)
710 (*current_liboctave_error_handler) ("cholinsert: index out of range");
711
712 FloatColumnVector utmp = u;
713
714 OCTAVE_LOCAL_BUFFER (float, w, n);
715
716 m_chol_mat.resize (n+1, n+1);
717 F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
718
719 F77_XFCN (schinx, SCHINX, (n, m_chol_mat.fortran_vec (), ldcm,
720 j + 1, utmp.fortran_vec (), w, info));
721
722 return info;
723 }
724
725 template <>
726 OCTAVE_API void
728 {
729 F77_INT n = to_f77_int (m_chol_mat.rows ());
730 F77_INT j = to_f77_int (j_arg);
731
732 if (j < 0 || j > n-1)
733 (*current_liboctave_error_handler) ("choldelete: index out of range");
734
735 OCTAVE_LOCAL_BUFFER (float, w, n);
736
737 F77_XFCN (schdex, SCHDEX, (n, m_chol_mat.fortran_vec (), n,
738 j + 1, w));
739
740 m_chol_mat.resize (n-1, n-1);
741 }
742
743 template <>
744 OCTAVE_API void
746 {
747 F77_INT n = to_f77_int (m_chol_mat.rows ());
748 F77_INT i = to_f77_int (i_arg);
749 F77_INT j = to_f77_int (j_arg);
750
751 if (i < 0 || i > n-1 || j < 0 || j > n-1)
752 (*current_liboctave_error_handler) ("cholshift: index out of range");
753
754 OCTAVE_LOCAL_BUFFER (float, w, 2*n);
755
756 F77_XFCN (schshx, SCHSHX, (n, m_chol_mat.fortran_vec (), n,
757 i + 1, j + 1, w));
758 }
759
760#endif
761
762 template <>
764 chol<ComplexMatrix>::init (const ComplexMatrix& a, bool upper, bool calc_cond)
765 {
766 octave_idx_type a_nr = a.rows ();
767 octave_idx_type a_nc = a.cols ();
768
769 if (a_nr != a_nc)
770 (*current_liboctave_error_handler) ("chol: requires square matrix");
771
772 F77_INT n = to_f77_int (a_nc);
773 F77_INT info;
774
775 m_is_upper = upper;
776
777 m_chol_mat.clear (n, n);
778 if (m_is_upper)
779 for (octave_idx_type j = 0; j < n; j++)
780 {
781 for (octave_idx_type i = 0; i <= j; i++)
782 m_chol_mat.xelem (i, j) = a(i, j);
783 for (octave_idx_type i = j+1; i < n; i++)
784 m_chol_mat.xelem (i, j) = 0.0;
785 }
786 else
787 for (octave_idx_type j = 0; j < n; j++)
788 {
789 for (octave_idx_type i = 0; i < j; i++)
790 m_chol_mat.xelem (i, j) = 0.0;
791 for (octave_idx_type i = j; i < n; i++)
792 m_chol_mat.xelem (i, j) = a(i, j);
793 }
794 Complex *h = m_chol_mat.fortran_vec ();
795
796 // Calculate the norm of the matrix, for later use.
797 double anorm = 0;
798 if (calc_cond)
799 anorm = octave::xnorm (a, 1);
800
801 if (m_is_upper)
802 F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), n,
803 F77_DBLE_CMPLX_ARG (h), n, info
804 F77_CHAR_ARG_LEN (1)));
805 else
806 F77_XFCN (zpotrf, ZPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), n,
807 F77_DBLE_CMPLX_ARG (h), n, info
808 F77_CHAR_ARG_LEN (1)));
809
810 m_rcond = 0.0;
811 if (info > 0)
812 m_chol_mat.resize (info - 1, info - 1);
813 else if (calc_cond)
814 {
815 F77_INT zpocon_info = 0;
816
817 // Now calculate the condition number for non-singular matrix.
818 Array<Complex> z (dim_vector (2*n, 1));
819 Complex *pz = z.fortran_vec ();
820 Array<double> rz (dim_vector (n, 1));
821 double *prz = rz.fortran_vec ();
822 F77_XFCN (zpocon, ZPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n,
823 F77_DBLE_CMPLX_ARG (h), n, anorm, m_rcond,
824 F77_DBLE_CMPLX_ARG (pz), prz, zpocon_info
825 F77_CHAR_ARG_LEN (1)));
826
827 if (zpocon_info != 0)
828 info = -1;
829 }
830
831 return info;
832 }
833
834#if defined (HAVE_QRUPDATE)
835
836 template <>
837 OCTAVE_API void
839 {
840 F77_INT n = to_f77_int (m_chol_mat.rows ());
841
842 if (u.numel () != n)
843 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
844
845 ComplexColumnVector utmp = u;
846
847 OCTAVE_LOCAL_BUFFER (double, rw, n);
848
849 F77_XFCN (zch1up, ZCH1UP, (n,
850 F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
851 n,
852 F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
853 rw));
854 }
855
856 template <>
859 {
860 F77_INT info = -1;
861
862 F77_INT n = to_f77_int (m_chol_mat.rows ());
863
864 if (u.numel () != n)
865 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
866
867 ComplexColumnVector utmp = u;
868
869 OCTAVE_LOCAL_BUFFER (double, rw, n);
870
871 F77_XFCN (zch1dn, ZCH1DN, (n,
872 F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
873 n,
874 F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
875 rw, info));
876
877 return info;
878 }
879
880 template <>
883 octave_idx_type j_arg)
884 {
885 F77_INT info = -1;
886
887 F77_INT n = to_f77_int (m_chol_mat.rows ());
888 F77_INT j = to_f77_int (j_arg);
889
890 if (u.numel () != n + 1)
891 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
892 if (j < 0 || j > n)
893 (*current_liboctave_error_handler) ("cholinsert: index out of range");
894
895 ComplexColumnVector utmp = u;
896
897 OCTAVE_LOCAL_BUFFER (double, rw, n);
898
899 m_chol_mat.resize (n+1, n+1);
900 F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
901
902 F77_XFCN (zchinx, ZCHINX, (n,
903 F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
904 ldcm, j + 1,
905 F77_DBLE_CMPLX_ARG (utmp.fortran_vec ()),
906 rw, info));
907
908 return info;
909 }
910
911 template <>
912 OCTAVE_API void
914 {
915 F77_INT n = to_f77_int (m_chol_mat.rows ());
916 F77_INT j = to_f77_int (j_arg);
917
918 if (j < 0 || j > n-1)
919 (*current_liboctave_error_handler) ("choldelete: index out of range");
920
921 OCTAVE_LOCAL_BUFFER (double, rw, n);
922
923 F77_XFCN (zchdex, ZCHDEX, (n,
924 F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
925 n, j + 1, rw));
926
927 m_chol_mat.resize (n-1, n-1);
928 }
929
930 template <>
931 OCTAVE_API void
933 octave_idx_type j_arg)
934 {
935 F77_INT n = to_f77_int (m_chol_mat.rows ());
936 F77_INT i = to_f77_int (i_arg);
937 F77_INT j = to_f77_int (j_arg);
938
939 if (i < 0 || i > n-1 || j < 0 || j > n-1)
940 (*current_liboctave_error_handler) ("cholshift: index out of range");
941
943 OCTAVE_LOCAL_BUFFER (double, rw, n);
944
945 F77_XFCN (zchshx, ZCHSHX, (n,
946 F77_DBLE_CMPLX_ARG (m_chol_mat.fortran_vec ()),
947 n, i + 1, j + 1,
948 F77_DBLE_CMPLX_ARG (w), rw));
949 }
950
951#endif
952
953 template <>
956 bool calc_cond)
957 {
958 octave_idx_type a_nr = a.rows ();
959 octave_idx_type a_nc = a.cols ();
960
961 if (a_nr != a_nc)
962 (*current_liboctave_error_handler) ("chol: requires square matrix");
963
964 F77_INT n = to_f77_int (a_nc);
965 F77_INT info;
966
967 m_is_upper = upper;
968
969 m_chol_mat.clear (n, n);
970 if (m_is_upper)
971 for (octave_idx_type j = 0; j < n; j++)
972 {
973 for (octave_idx_type i = 0; i <= j; i++)
974 m_chol_mat.xelem (i, j) = a(i, j);
975 for (octave_idx_type i = j+1; i < n; i++)
976 m_chol_mat.xelem (i, j) = 0.0f;
977 }
978 else
979 for (octave_idx_type j = 0; j < n; j++)
980 {
981 for (octave_idx_type i = 0; i < j; i++)
982 m_chol_mat.xelem (i, j) = 0.0f;
983 for (octave_idx_type i = j; i < n; i++)
984 m_chol_mat.xelem (i, j) = a(i, j);
985 }
986 FloatComplex *h = m_chol_mat.fortran_vec ();
987
988 // Calculate the norm of the matrix, for later use.
989 float anorm = 0;
990 if (calc_cond)
991 anorm = octave::xnorm (a, 1);
992
993 if (m_is_upper)
994 F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
995 n, F77_CMPLX_ARG (h), n, info
996 F77_CHAR_ARG_LEN (1)));
997 else
998 F77_XFCN (cpotrf, CPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
999 n, F77_CMPLX_ARG (h), n, info
1000 F77_CHAR_ARG_LEN (1)));
1001
1002 m_rcond = 0.0;
1003 if (info > 0)
1004 m_chol_mat.resize (info - 1, info - 1);
1005 else if (calc_cond)
1006 {
1007 F77_INT cpocon_info = 0;
1008
1009 // Now calculate the condition number for non-singular matrix.
1010 Array<FloatComplex> z (dim_vector (2*n, 1));
1011 FloatComplex *pz = z.fortran_vec ();
1012 Array<float> rz (dim_vector (n, 1));
1013 float *prz = rz.fortran_vec ();
1014 F77_XFCN (cpocon, CPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n,
1015 F77_CMPLX_ARG (h), n, anorm, m_rcond,
1016 F77_CMPLX_ARG (pz), prz, cpocon_info
1017 F77_CHAR_ARG_LEN (1)));
1018
1019 if (cpocon_info != 0)
1020 info = -1;
1021 }
1022
1023 return info;
1024 }
1025
1026#if defined (HAVE_QRUPDATE)
1027
1028 template <>
1029 OCTAVE_API void
1031 {
1032 F77_INT n = to_f77_int (m_chol_mat.rows ());
1033
1034 if (u.numel () != n)
1035 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1036
1037 FloatComplexColumnVector utmp = u;
1038
1039 OCTAVE_LOCAL_BUFFER (float, rw, n);
1040
1041 F77_XFCN (cch1up, CCH1UP, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1042 n, F77_CMPLX_ARG (utmp.fortran_vec ()), rw));
1043 }
1044
1045 template <>
1048 {
1049 F77_INT info = -1;
1050
1051 F77_INT n = to_f77_int (m_chol_mat.rows ());
1052
1053 if (u.numel () != n)
1054 (*current_liboctave_error_handler) ("cholupdate: dimension mismatch");
1055
1056 FloatComplexColumnVector utmp = u;
1057
1058 OCTAVE_LOCAL_BUFFER (float, rw, n);
1059
1060 F77_XFCN (cch1dn, CCH1DN, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1061 n, F77_CMPLX_ARG (utmp.fortran_vec ()),
1062 rw, info));
1063
1064 return info;
1065 }
1066
1067 template <>
1070 octave_idx_type j_arg)
1071 {
1072 F77_INT info = -1;
1073 F77_INT j = to_f77_int (j_arg);
1074
1075 F77_INT n = to_f77_int (m_chol_mat.rows ());
1076
1077 if (u.numel () != n + 1)
1078 (*current_liboctave_error_handler) ("cholinsert: dimension mismatch");
1079 if (j < 0 || j > n)
1080 (*current_liboctave_error_handler) ("cholinsert: index out of range");
1081
1082 FloatComplexColumnVector utmp = u;
1083
1084 OCTAVE_LOCAL_BUFFER (float, rw, n);
1085
1086 m_chol_mat.resize (n+1, n+1);
1087 F77_INT ldcm = to_f77_int (m_chol_mat.rows ());
1088
1089 F77_XFCN (cchinx, CCHINX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1090 ldcm, j + 1,
1091 F77_CMPLX_ARG (utmp.fortran_vec ()),
1092 rw, info));
1093
1094 return info;
1095 }
1096
1097 template <>
1098 OCTAVE_API void
1100 {
1101 F77_INT n = to_f77_int (m_chol_mat.rows ());
1102 F77_INT j = to_f77_int (j_arg);
1103
1104 if (j < 0 || j > n-1)
1105 (*current_liboctave_error_handler) ("choldelete: index out of range");
1106
1107 OCTAVE_LOCAL_BUFFER (float, rw, n);
1108
1109 F77_XFCN (cchdex, CCHDEX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1110 n, j + 1, rw));
1111
1112 m_chol_mat.resize (n-1, n-1);
1113 }
1114
1115 template <>
1116 OCTAVE_API void
1118 octave_idx_type j_arg)
1119 {
1120 F77_INT n = to_f77_int (m_chol_mat.rows ());
1121 F77_INT i = to_f77_int (i_arg);
1122 F77_INT j = to_f77_int (j_arg);
1123
1124 if (i < 0 || i > n-1 || j < 0 || j > n-1)
1125 (*current_liboctave_error_handler) ("cholshift: index out of range");
1126
1128 OCTAVE_LOCAL_BUFFER (float, rw, n);
1129
1130 F77_XFCN (cchshx, CCHSHX, (n, F77_CMPLX_ARG (m_chol_mat.fortran_vec ()),
1131 n, i + 1, j + 1, F77_CMPLX_ARG (w), rw));
1132 }
1133
1134#endif
1135
1136 // Instantiations we need.
1137
1138 template class chol<Matrix>;
1139
1140 template class chol<FloatMatrix>;
1141
1142 template class chol<ComplexMatrix>;
1143
1144 template class chol<FloatComplexMatrix>;
1145
1147 chol2inv<Matrix> (const Matrix& r);
1148
1151
1154
1157 }
1158}
ComplexColumnVector conj(const ComplexColumnVector &a)
Definition: CColVector.cc:217
T & xelem(octave_idx_type n)
Size of the specified dimension.
Definition: Array.h:504
octave_idx_type numel(void) const
Number of elements in the array.
Definition: Array.h:411
octave_idx_type cols(void) const
Definition: Array.h:457
octave_idx_type rows(void) const
Definition: Array.h:449
OCTARRAY_API T * fortran_vec(void)
Size of the specified dimension.
Definition: Array.cc:1744
Definition: dMatrix.h:42
Definition: mx-defs.h:39
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:94
OCTAVE_API void set(const T &R)
Definition: chol.cc:257
OCTAVE_API octave_idx_type insert_sym(const VT &u, octave_idx_type j)
OCTAVE_API octave_idx_type init(const T &a, bool upper, bool calc_cond)
OCTAVE_API T inverse(void) const
Definition: chol.cc:250
OCTAVE_API void shift_sym(octave_idx_type i, octave_idx_type j)
OCTAVE_API void update(const VT &u)
OCTAVE_API octave_idx_type downdate(const VT &u)
OCTAVE_API void delete_sym(octave_idx_type j)
ColumnVector imag(const ComplexColumnVector &a)
Definition: dColVector.cc:143
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:316
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:310
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:45
octave_f77_int_type F77_INT
Definition: f77-fcn.h:306
#define OCTAVE_API
Definition: main.in.cc:55
std::complex< double > w(std::complex< double > z, double relerr=0)
template OCTAVE_API FloatComplexMatrix chol2inv< FloatComplexMatrix >(const FloatComplexMatrix &r)
T chol2inv(const T &r)
Definition: chol.cc:242
template OCTAVE_API Matrix chol2inv< Matrix >(const Matrix &r)
template OCTAVE_API FloatMatrix chol2inv< FloatMatrix >(const FloatMatrix &r)
OCTAVE_API void warn_qrupdate_once(void)
template OCTAVE_API ComplexMatrix chol2inv< ComplexMatrix >(const ComplexMatrix &r)
double conj(double x)
Definition: lo-mappers.h:76
static Matrix chol2inv_internal(const Matrix &r, bool is_upper=true)
Definition: chol.cc:53
double xnorm(const ColumnVector &x, double p)
Definition: oct-norm.cc:585
std::complex< double > Complex
Definition: oct-cmplx.h:33
std::complex< float > FloatComplex
Definition: oct-cmplx.h:34
#define OCTAVE_LOCAL_BUFFER(T, buf, size)
Definition: oct-locbuf.h:44