GNU Octave  9.1.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
Sparse-perm-op-defs.h
Go to the documentation of this file.
1 ////////////////////////////////////////////////////////////////////////
2 //
3 // Copyright (C) 2009-2024 The Octave Project Developers
4 //
5 // See the file COPYRIGHT.md in the top-level directory of this
6 // distribution or <https://octave.org/copyright/>.
7 //
8 // This file is part of Octave.
9 //
10 // Octave is free software: you can redistribute it and/or modify it
11 // under the terms of the GNU General Public License as published by
12 // the Free Software Foundation, either version 3 of the License, or
13 // (at your option) any later version.
14 //
15 // Octave is distributed in the hope that it will be useful, but
16 // WITHOUT ANY WARRANTY; without even the implied warranty of
17 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 // GNU General Public License for more details.
19 //
20 // You should have received a copy of the GNU General Public License
21 // along with Octave; see the file COPYING. If not, see
22 // <https://www.gnu.org/licenses/>.
23 //
24 ////////////////////////////////////////////////////////////////////////
25 
26 #if ! defined (octave_Sparse_perm_op_defs_h)
27 #define octave_Sparse_perm_op_defs_h 1
28 
29 #include "octave-config.h"
30 
31 #include "PermMatrix.h"
32 #include "lo-array-errwarn.h"
33 #include "oct-locbuf.h"
34 #include "oct-sort.h"
35 #include "quit.h"
36 
37 // Matrix multiplication
38 
39 template <typename SM>
40 SM octinternal_do_mul_colpm_sm (const octave_idx_type *pcol, const SM& a)
41 // Relabel the rows according to pcol.
42 {
43  const octave_idx_type nr = a.rows ();
44  const octave_idx_type nc = a.cols ();
45  const octave_idx_type nent = a.nnz ();
46  SM r (nr, nc, nent);
47 
49 
50  for (octave_idx_type j = 0; j <= nc; ++j)
51  r.xcidx (j) = a.cidx (j);
52 
53  for (octave_idx_type j = 0; j < nc; j++)
54  {
55  octave_quit ();
56 
57  OCTAVE_LOCAL_BUFFER (octave_idx_type, sidx, r.xcidx (j+1) - r.xcidx (j));
58  for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
59  {
60  sidx[ii++]=i;
61  r.xridx (i) = pcol[a.ridx (i)];
62  }
63  sort.sort (r.xridx () + r.xcidx (j), sidx, r.xcidx (j+1) - r.xcidx (j));
64  for (octave_idx_type i = r.xcidx (j), ii = 0; i < r.xcidx (j+1); i++)
65  r.xdata (i) = a.data (sidx[ii++]);
66  }
67 
68  return r;
69 }
70 
71 template <typename SM>
72 SM octinternal_do_mul_pm_sm (const PermMatrix& p, const SM& a)
73 {
74  const octave_idx_type nr = a.rows ();
75  if (p.cols () != nr)
76  octave::err_nonconformant ("operator *",
77  p.rows (), p.cols (), a.rows (), a.cols ());
78 
79  return octinternal_do_mul_colpm_sm (p.col_perm_vec ().data (), a);
80 }
81 
82 template <typename SM>
83 SM octinternal_do_mul_sm_rowpm (const SM& a, const octave_idx_type *prow)
84 // For a row permutation, iterate across the source a and stuff the
85 // results into the correct destination column in r.
86 {
87  const octave_idx_type nr = a.rows ();
88  const octave_idx_type nc = a.cols ();
89  const octave_idx_type nent = a.nnz ();
90  SM r (nr, nc, nent);
91 
92  for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
93  r.xcidx (prow[j_src]) = a.cidx (j_src+1) - a.cidx (j_src);
94  octave_idx_type k = 0;
95  for (octave_idx_type j = 0; j < nc; ++j)
96  {
97  const octave_idx_type tmp = r.xcidx (j);
98  r.xcidx (j) = k;
99  k += tmp;
100  }
101  r.xcidx (nc) = nent;
102 
103  octave_idx_type k_src = 0;
104  for (octave_idx_type j_src = 0; j_src < nc; ++j_src)
105  {
106  octave_quit ();
107  const octave_idx_type j = prow[j_src];
108  const octave_idx_type kend_src = a.cidx (j_src + 1);
109  for (k = r.xcidx (j); k_src < kend_src; ++k, ++k_src)
110  {
111  r.xridx (k) = a.ridx (k_src);
112  r.xdata (k) = a.data (k_src);
113  }
114  }
115  assert (k_src == nent);
116 
117  return r;
118 }
119 
120 template <typename SM>
121 SM octinternal_do_mul_sm_colpm (const SM& a, const octave_idx_type *pcol)
122 // For a column permutation, iterate across the destination r and pull
123 // data from the correct column of a.
124 {
125  const octave_idx_type nr = a.rows ();
126  const octave_idx_type nc = a.cols ();
127  const octave_idx_type nent = a.nnz ();
128  SM r (nr, nc, nent);
129 
130  for (octave_idx_type j = 0; j < nc; ++j)
131  {
132  const octave_idx_type j_src = pcol[j];
133  r.xcidx (j+1) = r.xcidx (j) + (a.cidx (j_src+1) - a.cidx (j_src));
134  }
135  assert (r.xcidx (nc) == nent);
136 
137  octave_idx_type k = 0;
138  for (octave_idx_type j = 0; j < nc; ++j)
139  {
140  octave_quit ();
141  const octave_idx_type j_src = pcol[j];
142  octave_idx_type k_src;
143  const octave_idx_type kend_src = a.cidx (j_src + 1);
144  for (k_src = a.cidx (j_src); k_src < kend_src; ++k_src, ++k)
145  {
146  r.xridx (k) = a.ridx (k_src);
147  r.xdata (k) = a.data (k_src);
148  }
149  }
150  assert (k == nent);
151 
152  return r;
153 }
154 
155 template <typename SM>
156 SM octinternal_do_mul_sm_pm (const SM& a, const PermMatrix& p)
157 {
158  const octave_idx_type nc = a.cols ();
159  if (p.rows () != nc)
160  octave::err_nonconformant ("operator *",
161  a.rows (), a.cols (), p.rows (), p.cols ());
162 
163  return octinternal_do_mul_sm_colpm (a, p.col_perm_vec ().data ());
164 }
165 
166 #endif
SM octinternal_do_mul_sm_colpm(const SM &a, const octave_idx_type *pcol)
SM octinternal_do_mul_pm_sm(const PermMatrix &p, const SM &a)
SM octinternal_do_mul_sm_rowpm(const SM &a, const octave_idx_type *prow)
SM octinternal_do_mul_colpm_sm(const octave_idx_type *pcol, const SM &a)
SM octinternal_do_mul_sm_pm(const SM &a, const PermMatrix &p)
const T * data() const
Size of the specified dimension.
Definition: Array.h:663
octave_idx_type rows() const
Definition: PermMatrix.h:62
octave_idx_type cols() const
Definition: PermMatrix.h:63
const Array< octave_idx_type > & col_perm_vec() const
Definition: PermMatrix.h:83
void sort(T *data, octave_idx_type nel)
Definition: oct-sort.cc:1522
void err_nonconformant(const char *op, octave_idx_type op1_len, octave_idx_type op2_len)
T * r
Definition: mx-inlines.cc:781
#define OCTAVE_LOCAL_BUFFER(T, buf, size)
Definition: oct-locbuf.h:44