41 if (args.length () != 3)
44 bool is_single = (args(0).is_single_type () || args(1).is_single_type ()
45 || args(2).is_single_type ());
48 int numel_x = args(0).numel ();
49 int numel_a = args(1).numel ();
50 int numel_b = args(2).numel ();
51 int len = std::max (std::max (numel_x, numel_a), numel_b);
67 x = args(0).float_array_value ();
71 a =
FloatNDArray (output_dv, args(1).float_scalar_value ());
73 a = args(1).float_array_value ();
76 b =
FloatNDArray (output_dv, args(2).float_scalar_value ());
78 b = args(2).float_array_value ();
81 static const float tiny = math::exp2 (-50.0f);
82 static constexpr float eps = std::numeric_limits<float>::epsilon ();
83 float xj, x2, y, Cj, Dj, aj, bj, Deltaj, alpha_j, beta_j;
102 beta_j = aj - (aj * (aj + bj)) / (aj + 1) * xj;
107 while ((std::abs ((Deltaj - 1)) >
eps) && (j < maxit))
109 Dj = beta_j + alpha_j * Dj;
112 Cj = beta_j + alpha_j / Cj;
118 alpha_j = ((aj + j - 1) * (aj + bj + j -1) * (bj - j) * j)
119 / ((aj + 2 * j - 1) * (aj + 2 * j - 1)) * x2;
120 beta_j = aj + 2 * j + ((j * (bj - j)) / (aj + 2 * j - 1)
121 - ((aj + j) * (aj + bj + j)) / (aj + 2 * j + 1)) * xj;
137 x =
NDArray (output_dv, args(0).scalar_value ());
139 x = args(0).array_value ();
142 a =
NDArray (output_dv, args(1).scalar_value ());
144 a = args(1).array_value ();
147 b =
NDArray (output_dv, args(2).scalar_value ());
149 b = args(2).array_value ();
152 static const double tiny = math::exp2 (-100.0);
153 static constexpr double eps = std::numeric_limits<double>::epsilon ();
154 double xj, x2, y, Cj, Dj, aj, bj, Deltaj, alpha_j, beta_j;
173 beta_j = aj - (aj * (aj + bj)) / (aj + 1) * xj;
178 while ((std::abs ((Deltaj - 1)) >
eps) && (j < maxit))
180 Dj = beta_j + alpha_j * Dj;
183 Cj = beta_j + alpha_j / Cj;
189 alpha_j = ((aj + j - 1) * (aj + bj + j - 1) * (bj - j) * j)
190 / ((aj + 2 * j - 1) * (aj + 2 * j - 1)) * x2;
191 beta_j = aj + 2 * j + ((j * (bj - j)) / (aj + 2 * j - 1)
192 - ((aj + j) * (aj + bj + j)) / (aj + 2 * j + 1)) * xj;