Sam Chaudry
Upload folder using huggingface_hub
7885a28 verified
raw
history blame
16.1 kB
/* Building blocks for implementing special functions */
#pragma once
#include "config.h"
#include "error.h"
namespace xsf {
namespace detail {
/* generator_result_t<Generator>
 *
 * The value type produced by a "generator": a callable object that takes no
 * arguments and returns a new value on each call. References and cv-qualifiers
 * on the call's return type are stripped (std::decay).
 */
template <typename Generator>
using generator_result_t = typename std::decay<
    typename std::invoke_result<Generator>::type
>::type;
/* pair_traits<Pair>::value_type (alias pair_value_t<Pair>) names the common
 * element type of a homogeneous std::pair<T, T>. It is used to deduce the
 * numerator/denominator type of continued-fraction terms. The primary
 * template is declared but never defined, so any other Pair is rejected at
 * compile time.
 */
template <typename Pair>
struct pair_traits;

template <typename Elem>
struct pair_traits<std::pair<Elem, Elem>> {
    using value_type = Elem;
};

template <typename Pair>
using pair_value_t = typename pair_traits<Pair>::value_type;
/* real_type<T>::type (alias real_type_t<T>) maps std::complex<R> to its
 * component type R and leaves every other type unchanged. Tolerances are
 * expressed in this type, since they are real even for complex computations.
 */
template <typename T>
struct real_type {
    // Non-complex types are already real.
    using type = T;
};

template <typename R>
struct real_type<std::complex<R>> {
    // Strip the complex wrapper, keeping the component type.
    using type = R;
};

template <typename T>
using real_type_t = typename real_type<T>::type;
// maybe_complex_NaN<T>() returns a quiet NaN of type T.
// - For floating-point T it is std::numeric_limits<T>::quiet_NaN().
// - For any other T (expected to be a complex type exposing `value_type`),
//   both components are set to the component type's quiet NaN.
template <typename T>
XSF_HOST_DEVICE inline auto maybe_complex_NaN() -> typename std::enable_if<std::is_floating_point<T>::value, T>::type {
    return std::numeric_limits<T>::quiet_NaN();
}
template <typename T>
XSF_HOST_DEVICE inline auto maybe_complex_NaN() -> typename std::enable_if<!std::is_floating_point<T>::value, T>::type {
    using Component = typename T::value_type;
    const Component nan_part = std::numeric_limits<Component>::quiet_NaN();
    return {nan_part, nan_part};
}
// Series evaluators.
/* Sum an infinite series to a given relative precision.
 *
 * g         : generator of successive series terms.
 * init_val  : value the terms are added to; its type is the result type.
 * tol       : relative tolerance; summation stops at the first term whose
 *             magnitude is strictly below tol * |partial sum|.
 * max_terms : maximum number of terms to add before declaring non-convergence.
 * func_name : name of the SciPy function this evaluation ultimately serves,
 *             forwarded to set_error if the series does not converge.
 *
 * Returns the partial sum on convergence. Otherwise it reports
 * SF_ERROR_NO_RESULT and returns NaN (complex NaN for complex T).
 */
template <typename Generator, typename T = generator_result_t<Generator>>
XSF_HOST_DEVICE T
series_eval(Generator &g, T init_val, real_type_t<T> tol, std::uint64_t max_terms, const char *func_name) {
    T partial_sum = init_val;
    std::uint64_t terms_added = 0;
    while (terms_added < max_terms) {
        const T next_term = g();
        partial_sum += next_term;
        ++terms_added;
        if (std::abs(next_term) < std::abs(partial_sum) * tol) {
            return partial_sum;
        }
    }
    // Budget of terms exhausted without meeting the tolerance.
    set_error(func_name, SF_ERROR_NO_RESULT, NULL);
    return maybe_complex_NaN<T>();
}
template <typename Generator, typename T = generator_result_t<Generator>>
XSF_HOST_DEVICE T series_eval_fixed_length(Generator &g, T init_val, std::uint64_t num_terms) {
    /* Sum a fixed number of terms from a series.
     *
     * g : a generator of terms for the series.
     *
     * init_val : A starting value that terms are added to. This argument determines the
     *            type of the result.
     *
     * num_terms : The number of terms from the series to sum. The generator is called
     *             exactly this many times; no convergence check is performed and no
     *             error is ever reported.
     *
     * Returns init_val plus the sum of the first num_terms generated terms.
     */
    T result = init_val;
    for (std::uint64_t i = 0; i < num_terms; ++i) {
        result += g();
    }
    return result;
}
/* One step of Kahan (compensated) summation.
 *
 * Adds x to the running total `sum`. The compensation `comp` carries the
 * low-order bits lost to rounding in earlier additions. Both `sum` and
 * `comp` are updated in place; series_eval_kahan starts `comp` at zero.
 */
template <typename T>
XSF_HOST_DEVICE void kahan_step(T &sum, T &comp, T x) {
    const T adjusted = x - comp;       // re-inject the previously lost low-order part
    const T updated = sum + adjusted;  // only the high-order part survives this add
    comp = (updated - sum) - adjusted; // algebraically zero; captures the rounding error
    sum = updated;
}
/* Evaluate an infinite series with Kahan (compensated) summation.
 *
 * Write the series as
 *
 *     S = a[0] + a[1] + a[2] + ...
 *
 * with partial sums S[n] = a[0] + a[1] + ... + a[n]. Terms are accumulated
 * one at a time until a term is small relative to the running sum, or until
 * the term budget is exhausted.
 *
 * Parameters
 * ----------
 * g
 *     Generator (taken by forwarding reference) whose successive calls yield
 *     a[1], a[2], a[3], ...
 *
 * tol
 *     Relative tolerance. Iteration stops at the first n >= 1 with
 *     `abs(a[n]) <= tol * abs(S[n])`.
 *
 * max_terms
 *     Maximum number of terms after a[0] to pull from `g`. Choose it large
 *     enough that, absent rounding error, the criterion above is certain to
 *     be met within that many terms.
 *
 * init_val
 *     The leading term a[0] (default zero). Its type T is used for all
 *     intermediate arithmetic and for the result.
 *
 * Return Value
 * ------------
 * `(S[n], n)` for the first n <= max_terms meeting the criterion, and
 * `(S[max_terms], 0)` if none does.
 */
template <typename Generator, typename T = generator_result_t<Generator>>
XSF_HOST_DEVICE std::pair<T, std::uint64_t>
series_eval_kahan(Generator &&g, real_type_t<T> tol, std::uint64_t max_terms, T init_val = T(0)) {
    using std::abs;
    T sum = init_val;
    T comp = T(0);
    std::uint64_t n = 0;
    while (n < max_terms) {
        ++n; // index of the term about to be generated
        const T a_n = g();
        kahan_step(sum, comp, a_n);
        if (abs(a_n) <= tol * abs(sum)) {
            return {sum, n};
        }
    }
    // Not converged: report the full partial sum with a zero term count.
    return {sum, 0};
}
/* Generator that yields the difference of successive convergents of a
 * continued fraction.
 *
 * Let f[n] denote the n-th convergent of a continued fraction:
 *
 *                 a[1]   a[2]         a[n]
 *   f[n] = b[0] + ------ ------ ... ------
 *                 b[1] + b[2] +       b[n]
 *
 * with f[0] = b[0]. This generator yields the sequence of values
 * f[1]-f[0], f[2]-f[1], f[3]-f[2], ...
 *
 * Constructor Arguments
 * ---------------------
 * cf
 *     Reference to generator that yields the terms of the continued
 *     fraction as (numerator, denominator) pairs, starting from
 *     (a[1], b[1]).
 *
 *     `cf` must outlive the ContinuedFractionSeriesGenerator object.
 *
 *     The constructed object always eagerly retrieves the next term
 *     of the continued fraction. Specifically, (a[1], b[1]) is
 *     retrieved upon construction, and (a[n], b[n]) is retrieved after
 *     (n-1) calls of `()`.
 *
 * Type Arguments
 * --------------
 * T
 *     Type in which computations are performed and results are returned.
 *
 * Remarks
 * -------
 * The series is computed using the recurrence relation described in [1].
 * Let v[n] = f[n] - f[n-1], n >= 1, denote the terms of the series, and
 * let r[n] = v[n] / v[n-1] for n >= 2 (with r[1] = 0 by convention). Then
 *
 *   v[1] = a[1] / b[1]
 *   v[n] = v[n-1] * r[n],  n >= 2
 *
 * where
 *
 *                              -(a[n] + a[n] * r[n-1])
 *   r[1] = 0,  r[n] = ------------------------------------------,  n >= 2
 *                      (a[n] + a[n] * r[n-1]) + (b[n] * b[n-1])
 *
 * No error checking is performed. The caller must ensure that all terms
 * are finite and that intermediary computations do not trigger floating
 * point exceptions such as overflow.
 *
 * The numerical stability of this method depends on the characteristics
 * of the continued fraction being evaluated.
 *
 * Reference
 * ---------
 * [1] Gautschi, W. (1967). "Computational Aspects of Three-Term
 *     Recurrence Relations." SIAM Review, 9(1):24-82.
 */
template <typename Generator, typename T = pair_value_t<generator_result_t<Generator>>>
class ContinuedFractionSeriesGenerator {
  public:
    XSF_HOST_DEVICE explicit ContinuedFractionSeriesGenerator(Generator &cf) : cf_(cf) { init(); }

    // Returns the current term v[n] and eagerly advances the state to v[n+1].
    XSF_HOST_DEVICE T operator()() {
        T v = v_;
        advance();
        return v;
    }

  private:
    // Consumes (a[1], b[1]) and sets up v[1] = a[1] / b[1], r[1] = 0.
    XSF_HOST_DEVICE void init() {
        auto [num, denom] = cf_();
        T a = num;
        T b = denom;
        r_ = T(0);
        v_ = a / b;
        b_ = b;
    }

    // Consumes (a[n], b[n]), computes r[n] from r[n-1] and b[n-1], then
    // v[n] = v[n-1] * r[n].
    XSF_HOST_DEVICE void advance() {
        auto [num, denom] = cf_();
        T a = num;
        T b = denom;
        T p = a + a * r_;  // a[n] * (1 + r[n-1])
        T q = p + b * b_;  // a[n] * (1 + r[n-1]) + b[n] * b[n-1]
        r_ = -p / q;       // r[n]
        v_ = v_ * r_;      // v[n] = v[n-1] * r[n]
        b_ = b;
    }

    Generator &cf_; // reference to continued fraction generator
    T v_;           // v[n] == f[n] - f[n-1], n >= 1
    T r_;           // r[1] = 0, r[n] = v[n]/v[n-1], n >= 2
    T b_;           // last denominator, i.e. b[n-1]
};
/* Turn a continued fraction into the series of differences of its successive
 * convergents.
 *
 * `cf` yields (numerator, denominator) pairs. It is wrapped, by reference, in
 * a ContinuedFractionSeriesGenerator, so `cf` must outlive the returned
 * object. See ContinuedFractionSeriesGenerator for the recurrence used.
 */
template <typename Generator, typename T = pair_value_t<generator_result_t<Generator>>>
XSF_HOST_DEVICE ContinuedFractionSeriesGenerator<Generator, T> continued_fraction_series(Generator &cf) {
    using SeriesGenerator = ContinuedFractionSeriesGenerator<Generator, T>;
    return SeriesGenerator(cf);
}
/* Find initial bracket for a bracketing scalar root finder. A valid bracket is a pair of points a < b for
 * which the signs of f(a) and f(b) differ. If f(x0) = 0, where x0 is the initial guess, this bracket finder
 * will return the bracket (x0, x0). It is expected that the rootfinder will check if the bracket
 * endpoints are roots.
 *
 * This is a private function intended specifically for the situation where
 * the goal is to invert a CDF function F for a parametrized family of distributions with respect to one
 * parameter, when the other parameters are known, and where F is monotonic with respect to the unknown parameter.
 *
 * Parameters
 * ----------
 * func : callable double -> double whose sign change is sought.
 * x0 : initial guess (presumably inside [xmin, xmax]; not checked).
 * xmin, xmax : search limits. Apart from x0 itself, func is never evaluated outside [xmin, xmax].
 * step0_left, step0_right : offset added to x0 for the first trial point when searching left / right.
 *     step0_left is added to x0, so it should be negative to move left.
 * factor_left, factor_right : each subsequent step is the previous step multiplied by this factor.
 * increasing : true if func is increasing in x, false if decreasing.
 * maxiter : maximum number of evaluations of func after the one at x0.
 *
 * Returns
 * -------
 * (a, b, func(a), func(b), status) with status
 *     0 : success; a <= b and func changes sign on [a, b] (or an endpoint is a root),
 *     1 : xmin was reached without bracketing a root,
 *     2 : xmax was reached without bracketing a root,
 *     3 : maxiter evaluations were exhausted.
 * For nonzero status all four doubles are NaN.
 */
template <typename Function>
XSF_HOST_DEVICE inline std::tuple<double, double, double, double, int> bracket_root_for_cdf_inversion(
    Function func, double x0, double xmin, double xmax, double step0_left,
    double step0_right, double factor_left, double factor_right, bool increasing, std::uint64_t maxiter
) {
    const double nan_value = std::numeric_limits<double>::quiet_NaN();
    double y0 = func(x0);
    if (y0 == 0) {
        // Initial guess is correct.
        return {x0, x0, y0, y0, 0};
    }
    /* If func is increasing and func(x0) < 0, or func is decreasing and func(x0) > 0,
     * the root lies to the right of x0; otherwise we expand the bracket to the left. */
    bool search_left = !((increasing && y0 < 0) || (!increasing && y0 > 0));
    double boundary = search_left ? xmin : xmax;
    double factor = search_left ? factor_left : factor_right;
    /* The frontier is the new leading endpoint of the expanding bracket. The
     * interior endpoint trails behind the frontier. In each step, the old frontier
     * endpoint becomes the new interior endpoint. */
    double interior = x0;
    double y_interior = y0;
    bool y_interior_sgn = std::signbit(y0);
    double frontier = x0 + (search_left ? step0_left : step0_right);
    bool reached_boundary = false;
    for (std::uint64_t i = 0; i < maxiter; i++) {
        if ((search_left && frontier <= boundary) || (!search_left && frontier >= boundary)) {
            /* Clamp before evaluating so func is never evaluated past the boundary, including
             * on the very first step. The flag tells the algorithm not to search any further. */
            frontier = boundary;
            reached_boundary = true;
        }
        double y_frontier = func(frontier);
        bool y_frontier_sgn = std::signbit(y_frontier);
        if (y_frontier_sgn != y_interior_sgn || (y_frontier == 0.0)) {
            /* Stopping condition: func has opposing signs at the bracket endpoints (or the
             * frontier is a zero), meeting the requirement for a bracketing root finder. */
            if (search_left) {
                // Ensure we return an interval (a, b) with a < b.
                return {frontier, interior, y_frontier, y_interior, 0};
            }
            return {interior, frontier, y_interior, y_frontier, 0};
        }
        if (reached_boundary) {
            // We've reached a boundary point without finding a root.
            return {nan_value, nan_value, nan_value, nan_value, search_left ? 1 : 2};
        }
        double step = (frontier - interior) * factor;
        interior = frontier;
        y_interior = y_frontier;
        y_interior_sgn = y_frontier_sgn;
        frontier += step;
    }
    /* Failed to converge within maxiter iterations. If maxiter is sufficiently high and
     * factor_left and factor_right are set appropriately, this should only happen due to
     * a bug in this function. Limiting the number of iterations is a defensive programming measure. */
    return {nan_value, nan_value, nan_value, nan_value, 3};
}
/* Find root of a scalar function using Chandrupatla's algorithm.
 *
 * Parameters
 * ----------
 * func : callable double -> double.
 * x1, x2 : endpoints of the initial bracket.
 * f1, f2 : func(x1) and func(x2). They are expected to have opposite signs (for example a bracket
 *     from bracket_root_for_cdf_inversion); this is not checked. If either is exactly zero, the
 *     corresponding endpoint is returned immediately.
 * rtol, atol : relative / absolute tolerances. Iteration stops when the bracket width |x2 - x1|
 *     is below 2 * (2 * rtol * |xm| + 0.5 * atol), or when func(xm) == 0. Here xm is the current
 *     endpoint with the smaller |f|.
 * maxiter : maximum number of evaluations of func.
 *
 * Returns
 * -------
 * (root, 0) on success, or (NaN, 1) if maxiter evaluations did not suffice.
 *
 * Each iteration evaluates func at x1 + t * (x2 - x1). The next t comes from interpolation through
 * the three most recent points when Chandrupatla's criterion 1 - sqrt(1 - xi) < phi < sqrt(xi)
 * holds; otherwise t = 0.5 (bisection). t is kept at least tl away from either endpoint.
 */
template <typename Function>
XSF_HOST_DEVICE inline std::pair<double, int> find_root_chandrupatla(
    Function func, double x1, double x2, double f1, double f2, double rtol,
    double atol, std::uint64_t maxiter
) {
    if (f1 == 0) {
        return {x1, 0};
    }
    if (f2 == 0) {
        return {x2, 0};
    }
    // Fraction of the way from x1 to x2 at which func is evaluated next.
    double t = 0.5;
    // Most recently discarded point; always assigned in an iteration before it is read.
    double x3 = 0.0;
    double f3 = 0.0;
    for (std::uint64_t i = 0; i < maxiter; i++) {
        double x = x1 + t * (x2 - x1);
        double f = func(x);
        if (std::signbit(f) == std::signbit(f1)) {
            // The sign change now lies between x and x2: x1 is discarded into x3.
            x3 = x1;
            x1 = x;
            f3 = f1;
            f1 = f;
        } else {
            // The sign change now lies between x and the old x1: x2 is discarded into x3.
            x3 = x2;
            x2 = x1;
            x1 = x;
            f3 = f2;
            f2 = f1;
            f1 = f;
        }
        // Best current estimate: the endpoint with the smaller residual.
        double xm, fm;
        if (std::abs(f2) < std::abs(f1)) {
            xm = x2;
            fm = f2;
        } else {
            xm = x1;
            fm = f1;
        }
        double tol = 2.0 * rtol * std::abs(xm) + 0.5 * atol;
        double tl = tol / std::abs(x2 - x1);
        if (tl > 0.5 || fm == 0) {
            return {xm, 0};
        }
        double xi = (x1 - x2) / (x3 - x2);
        double phi = (f1 - f2) / (f3 - f2);
        double fl = 1.0 - std::sqrt(1.0 - xi);
        double fh = std::sqrt(xi);
        if ((fl < phi) && (phi < fh)) {
            t = (f1 / (f2 - f1)) * (f3 / (f2 - f3)) + (f1 / (f3 - f1)) * (f2 / (f3 - f2)) * ((x3 - x1) / (x2 - x1));
        } else {
            t = 0.5;
        }
        // Keep the next trial point at least tl (relative) away from both bracket endpoints.
        t = std::fmin(std::fmax(t, tl), 1.0 - tl);
    }
    return {std::numeric_limits<double>::quiet_NaN(), 1};
}
} // namespace detail
} // namespace xsf