00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #if !defined (octave_bsxfun_defs_h)
00025 #define octave_bsxfun_defs_h 1
00026
00027 #include <algorithm>
00028 #include <iostream>
00029
00030 #include "dim-vector.h"
00031 #include "oct-locbuf.h"
00032 #include "lo-error.h"
00033
00034 #include "mx-inlines.cc"
00035
00036 template <class R, class X, class Y>
00037 Array<R>
00038 do_bsxfun_op (const Array<X>& x, const Array<Y>& y,
00039 void (*op_vv) (size_t, R *, const X *, const Y *),
00040 void (*op_sv) (size_t, R *, X, const Y *),
00041 void (*op_vs) (size_t, R *, const X *, Y))
00042 {
00043 int nd = std::max (x.ndims (), y.ndims ());
00044 dim_vector dvx = x.dims ().redim (nd), dvy = y.dims ().redim (nd);
00045
00046
00047 dim_vector dvr;
00048 dvr.resize (nd);
00049 for (int i = 0; i < nd; i++)
00050 {
00051 octave_idx_type xk = dvx(i), yk = dvy(i);
00052 if (xk == 1)
00053 dvr(i) = yk;
00054 else if (yk == 1 || xk == yk)
00055 dvr(i) = xk;
00056 else
00057 {
00058 (*current_liboctave_error_handler)
00059 ("bsxfun: nonconformant dimensions: %s and %s",
00060 x.dims ().str ().c_str (), y.dims ().str ().c_str ());
00061 break;
00062 }
00063 }
00064
00065 Array<R> retval (dvr);
00066
00067 const X *xvec = x.fortran_vec ();
00068 const Y *yvec = y.fortran_vec ();
00069 R *rvec = retval.fortran_vec ();
00070
00071
00072 octave_idx_type start, ldr = 1;
00073 for (start = 0; start < nd; start++)
00074 {
00075 if (dvx(start) != dvy(start))
00076 break;
00077 ldr *= dvr(start);
00078 }
00079
00080 if (retval.is_empty ())
00081 ;
00082 else if (start == nd)
00083 op_vv (retval.numel (), rvec, xvec, yvec);
00084 else
00085 {
00086
00087 bool xsing = false, ysing = false;
00088 if (ldr == 1)
00089 {
00090 xsing = dvx(start) == 1;
00091 ysing = dvy(start) == 1;
00092 if (xsing || ysing)
00093 {
00094 ldr *= dvx(start) * dvy(start);
00095 start++;
00096 }
00097 }
00098 dim_vector cdvx = dvx.cumulative (), cdvy = dvy.cumulative ();
00099
00100 for (int i = std::max (start, octave_idx_type (1)); i < nd; i++)
00101 {
00102 if (dvx(i) == 1)
00103 cdvx(i-1) = 0;
00104 if (dvy(i) == 1)
00105 cdvy(i-1) = 0;
00106 }
00107
00108 octave_idx_type niter = dvr.numel (start);
00109
00110 OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, nd, 0);
00111 for (octave_idx_type iter = 0; iter < niter; iter++)
00112 {
00113 octave_quit ();
00114
00115
00116
00117 octave_idx_type xidx = cdvx.cum_compute_index (idx);
00118 octave_idx_type yidx = cdvy.cum_compute_index (idx);
00119 octave_idx_type ridx = dvr.compute_index (idx);
00120
00121
00122 if (xsing)
00123 op_sv (ldr, rvec + ridx, xvec[xidx], yvec + yidx);
00124 else if (ysing)
00125 op_vs (ldr, rvec + ridx, xvec + xidx, yvec[yidx]);
00126 else
00127 op_vv (ldr, rvec + ridx, xvec + xidx, yvec + yidx);
00128
00129 dvr.increment_index (idx + start, start);
00130 }
00131 }
00132
00133 return retval;
00134 }
00135
00136 template <class R, class X>
00137 void
00138 do_inplace_bsxfun_op (Array<R>& r, const Array<X>& x,
00139 void (*op_vv) (size_t, R *, const X *),
00140 void (*op_vs) (size_t, R *, X))
00141 {
00142 dim_vector dvr = r.dims (), dvx = x.dims ();
00143 octave_idx_type nd = r.ndims ();
00144 dvx.redim (nd);
00145
00146 const X* xvec = x.fortran_vec ();
00147 R* rvec = r.fortran_vec ();
00148
00149
00150 octave_idx_type start, ldr = 1;
00151 for (start = 0; start < nd; start++)
00152 {
00153 if (dvr(start) != dvx(start))
00154 break;
00155 ldr *= dvr(start);
00156 }
00157
00158 if (r.is_empty ())
00159 ;
00160 else if (start == nd)
00161 op_vv (r.numel (), rvec, xvec);
00162 else
00163 {
00164
00165 bool xsing = false;
00166 if (ldr == 1)
00167 {
00168 xsing = dvx(start) == 1;
00169 if (xsing)
00170 {
00171 ldr *= dvr(start) * dvx(start);
00172 start++;
00173 }
00174 }
00175
00176 dim_vector cdvx = dvx.cumulative ();
00177
00178 for (int i = std::max (start, octave_idx_type (1)); i < nd; i++)
00179 {
00180 if (dvx(i) == 1)
00181 cdvx(i-1) = 0;
00182 }
00183
00184 octave_idx_type niter = dvr.numel (start);
00185
00186 OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, nd, 0);
00187 for (octave_idx_type iter = 0; iter < niter; iter++)
00188 {
00189 octave_quit ();
00190
00191
00192
00193 octave_idx_type xidx = cdvx.cum_compute_index (idx);
00194 octave_idx_type ridx = dvr.compute_index (idx);
00195
00196
00197 if (xsing)
00198 op_vs (ldr, rvec + ridx, xvec[xidx]);
00199 else
00200 op_vv (ldr, rvec + ridx, xvec + xidx);
00201
00202 dvr.increment_index (idx + start, start);
00203 }
00204 }
00205 }
00206
// Declarator for a same-type broadcasting operation:
//   ARRAY_T bsxfun_<OPNAME> (const ARRAY_T& x, const ARRAY_T& y)
// Follow it with ';' for a declaration or with a body for a definition;
// the body refers to the operands as x and y.
#define BSXFUN_OP_DEF(OPNAME, ARRAY_T) \
  ARRAY_T bsxfun_ ## OPNAME (const ARRAY_T& x, const ARRAY_T& y)

// Declarator for a mixed-type operation.  The operands have types LHS_T
// and RHS_T, and the result has type RESULT_T.
#define BSXFUN_OP2_DEF(OPNAME, RESULT_T, LHS_T, RHS_T) \
  RESULT_T bsxfun_ ## OPNAME (const LHS_T& x, const RHS_T& y)

// Declarator for a comparison operation, whose result is a boolNDArray.
#define BSXFUN_REL_DEF(OPNAME, ARRAY_T) \
  boolNDArray bsxfun_ ## OPNAME (const ARRAY_T& x, const ARRAY_T& y)
00215
// Define bsxfun_<OPNAME> for two ARRAY_T operands by forwarding to
// do_bsxfun_op.  KERNEL is passed for all three kernel slots (vector-vector,
// scalar-vector, vector-scalar).  Overload resolution against each slot's
// function-pointer type picks the matching variant.
#define BSXFUN_OP_DEF_MXLOOP(OPNAME, ARRAY_T, KERNEL) \
  BSXFUN_OP_DEF (OPNAME, ARRAY_T) \
  { \
    typedef ARRAY_T::element_type elt_type; \
    return do_bsxfun_op<elt_type, elt_type, elt_type> \
      (x, y, KERNEL, KERNEL, KERNEL); \
  }

// Mixed-type variant: the operands are LHS_T and RHS_T and the result is
// RESULT_T, each contributing its own element type.
#define BSXFUN_OP2_DEF_MXLOOP(OPNAME, RESULT_T, LHS_T, RHS_T, KERNEL) \
  BSXFUN_OP2_DEF (OPNAME, RESULT_T, LHS_T, RHS_T) \
  { \
    typedef RESULT_T::element_type r_elt_type; \
    typedef LHS_T::element_type x_elt_type; \
    typedef RHS_T::element_type y_elt_type; \
    return do_bsxfun_op<r_elt_type, x_elt_type, y_elt_type> \
      (x, y, KERNEL, KERNEL, KERNEL); \
  }

// Comparison variant: both operands are ARRAY_T and the element result is
// bool.
#define BSXFUN_REL_DEF_MXLOOP(OPNAME, ARRAY_T, KERNEL) \
  BSXFUN_REL_DEF (OPNAME, ARRAY_T) \
  { \
    typedef ARRAY_T::element_type elt_type; \
    return do_bsxfun_op<bool, elt_type, elt_type> \
      (x, y, KERNEL, KERNEL, KERNEL); \
  }
00230
// Define the standard arithmetic family (add, sub, mul, div, min, max)
// for ARRAY_T in terms of the matching mx_inline_* kernels.
#define BSXFUN_STDOP_DEFS_MXLOOP(ARRAY_T) \
  BSXFUN_OP_DEF_MXLOOP (add, ARRAY_T, mx_inline_add) \
  BSXFUN_OP_DEF_MXLOOP (sub, ARRAY_T, mx_inline_sub) \
  BSXFUN_OP_DEF_MXLOOP (mul, ARRAY_T, mx_inline_mul) \
  BSXFUN_OP_DEF_MXLOOP (div, ARRAY_T, mx_inline_div) \
  BSXFUN_OP_DEF_MXLOOP (min, ARRAY_T, mx_inline_xmin) \
  BSXFUN_OP_DEF_MXLOOP (max, ARRAY_T, mx_inline_xmax)

// Define the six comparison operations (eq, ne, lt, le, gt, ge) for
// ARRAY_T in terms of the matching mx_inline_* kernels.
#define BSXFUN_STDREL_DEFS_MXLOOP(ARRAY_T) \
  BSXFUN_REL_DEF_MXLOOP (eq, ARRAY_T, mx_inline_eq) \
  BSXFUN_REL_DEF_MXLOOP (ne, ARRAY_T, mx_inline_ne) \
  BSXFUN_REL_DEF_MXLOOP (lt, ARRAY_T, mx_inline_lt) \
  BSXFUN_REL_DEF_MXLOOP (le, ARRAY_T, mx_inline_le) \
  BSXFUN_REL_DEF_MXLOOP (gt, ARRAY_T, mx_inline_gt) \
  BSXFUN_REL_DEF_MXLOOP (ge, ARRAY_T, mx_inline_ge)
00246
00247
// Define the four mixed-type power operations for INT_ARRAY_T (presumably
// an integer N-d array type; verify against callers).  INT_ARRAY_T is
// combined with NDArray or FloatNDArray on either side, and the result is
// always INT_ARRAY_T.
#define BSXFUN_POW_MIXED_MXLOOP(INT_ARRAY_T) \
  BSXFUN_OP2_DEF_MXLOOP (pow, INT_ARRAY_T, INT_ARRAY_T, NDArray, mx_inline_pow) \
  BSXFUN_OP2_DEF_MXLOOP (pow, INT_ARRAY_T, INT_ARRAY_T, FloatNDArray, mx_inline_pow) \
  BSXFUN_OP2_DEF_MXLOOP (pow, INT_ARRAY_T, NDArray, INT_ARRAY_T, mx_inline_pow) \
  BSXFUN_OP2_DEF_MXLOOP (pow, INT_ARRAY_T, FloatNDArray, INT_ARRAY_T, mx_inline_pow)
00253
00254 #endif