/* Building blocks for implementing special functions */ #pragma once #include "config.h" #include "error.h" namespace xsf { namespace detail { /* Result type of a "generator", a callable object that produces a value * each time it is called. */ template using generator_result_t = typename std::decay::type>::type; /* Used to deduce the type of the numerator/denominator of a fraction. */ template struct pair_traits; template struct pair_traits> { using value_type = T; }; template using pair_value_t = typename pair_traits::value_type; /* Used to extract the "value type" of a complex type. */ template struct real_type { using type = T; }; template struct real_type> { using type = T; }; template using real_type_t = typename real_type::type; // Return NaN, handling both real and complex types. template XSF_HOST_DEVICE inline typename std::enable_if::value, T>::type maybe_complex_NaN() { return std::numeric_limits::quiet_NaN(); } template XSF_HOST_DEVICE inline typename std::enable_if::value, T>::type maybe_complex_NaN() { using V = typename T::value_type; return {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}; } // Series evaluators. template > XSF_HOST_DEVICE T series_eval(Generator &g, T init_val, real_type_t tol, std::uint64_t max_terms, const char *func_name) { /* Sum an infinite series to a given precision. * * g : a generator of terms for the series. * * init_val : A starting value that terms are added to. This argument determines the * type of the result. * * tol : relative tolerance for stopping criterion. * * max_terms : The maximum number of terms to add before giving up and declaring * non-convergence. * * func_name : The name of the function within SciPy where this call to series_eval * will ultimately be used. This is needed to pass to set_error in case * of non-convergence. */ T result = init_val; T term; for (std::uint64_t i = 0; i < max_terms; ++i) { term = g(); result += term; if (std::abs(term) < std::abs(result) * tol) { return result; } } // Exceeded max terms without converging. Return NaN. set_error(func_name, SF_ERROR_NO_RESULT, NULL); return maybe_complex_NaN(); } template > XSF_HOST_DEVICE T series_eval_fixed_length(Generator &g, T init_val, std::uint64_t num_terms) { /* Sum a fixed number of terms from a series. * * g : a generator of terms for the series. * * init_val : A starting value that terms are added to. This argument determines the * type of the result. * * max_terms : The number of terms from the series to sum. * */ T result = init_val; for (std::uint64_t i = 0; i < num_terms; ++i) { result += g(); } return result; } /* Performs one step of Kahan summation. */ template XSF_HOST_DEVICE void kahan_step(T &sum, T &comp, T x) { T y = x - comp; T t = sum + y; comp = (t - sum) - y; sum = t; } /* Evaluates an infinite series using Kahan summation. * * Denote the series by * * S = a[0] + a[1] + a[2] + ... * * And for n = 0, 1, 2, ..., denote its n-th partial sum by * * S[n] = a[0] + a[1] + ... + a[n] * * This function computes S[0], S[1], ... until a[n] is sufficiently * small or if the maximum number of terms have been evaluated. * * Parameters * ---------- * g * Reference to generator that yields the sequence of values a[1], * a[2], a[3], ... * * tol * Relative tolerance for convergence. Specifically, stop iteration * as soon as `abs(a[n]) <= tol * abs(S[n])` for some n >= 1. * * max_terms * Maximum number of terms after a[0] to evaluate. It should be set * large enough such that the convergence criterion is guaranteed * to have been satisfied within that many terms if there is no * rounding error. * * init_val * a[0]. Default is zero. The type of this parameter (T) is used * for intermediary computations as well as the result. * * Return Value * ------------ * If the convergence criterion is satisfied by some `n <= max_terms`, * returns `(S[n], n)`. Otherwise, returns `(S[max_terms], 0)`. */ template > XSF_HOST_DEVICE std::pair series_eval_kahan(Generator &&g, real_type_t tol, std::uint64_t max_terms, T init_val = T(0)) { using std::abs; T sum = init_val; T comp = T(0); for (std::uint64_t i = 0; i < max_terms; ++i) { T term = g(); kahan_step(sum, comp, term); if (abs(term) <= tol * abs(sum)) { return {sum, i + 1}; } } return {sum, 0}; } /* Generator that yields the difference of successive convergents of a * continued fraction. * * Let f[n] denote the n-th convergent of a continued fraction: * * a[1] a[2] a[n] * f[n] = b[0] + ------ ------ ... ---- * b[1] + b[2] + b[n] * * with f[0] = b[0]. This generator yields the sequence of values * f[1]-f[0], f[2]-f[1], f[3]-f[2], ... * * Constructor Arguments * --------------------- * cf * Reference to generator that yields the terms of the continued * fraction as (numerator, denominator) pairs, starting from * (a[1], b[1]). * * `cf` must outlive the ContinuedFractionSeriesGenerator object. * * The constructed object always eagerly retrieves the next term * of the continued fraction. Specifically, (a[1], b[1]) is * retrieved upon construction, and (a[n], b[n]) is retrieved after * (n-1) calls of `()`. * * Type Arguments * -------------- * T * Type in which computations are performed and results are turned. * * Remarks * ------- * The series is computed using the recurrence relation described in [1]. * Let v[n], n >= 1 denote the terms of the series. Then * * v[1] = a[1] / b[1] * v[n] = v[n-1] * r[n-1], n >= 2 * * where * * -(a[n] + a[n] * r[n-1]) * r[1] = 0, r[n] = ------------------------------------------, n >= 2 * (a[n] + a[n] * r[n-1]) + (b[n] * b[n-1]) * * No error checking is performed. The caller must ensure that all terms * are finite and that intermediary computations do not trigger floating * point exceptions such as overflow. * * The numerical stability of this method depends on the characteristics * of the continued fraction being evaluated. * * Reference * --------- * [1] Gautschi, W. (1967). “Computational Aspects of Three-Term * Recurrence Relations.” SIAM Review, 9(1):24-82. */ template >> class ContinuedFractionSeriesGenerator { public: XSF_HOST_DEVICE explicit ContinuedFractionSeriesGenerator(Generator &cf) : cf_(cf) { init(); } XSF_HOST_DEVICE T operator()() { T v = v_; advance(); return v; } private: XSF_HOST_DEVICE void init() { auto [num, denom] = cf_(); T a = num; T b = denom; r_ = T(0); v_ = a / b; b_ = b; } XSF_HOST_DEVICE void advance() { auto [num, denom] = cf_(); T a = num; T b = denom; T p = a + a * r_; T q = p + b * b_; r_ = -p / q; v_ = v_ * r_; b_ = b; } Generator &cf_; // reference to continued fraction generator T v_; // v[n] == f[n] - f[n-1], n >= 1 T r_; // r[1] = 0, r[n] = v[n]/v[n-1], n >= 2 T b_; // last denominator, i.e. b[n-1] }; /* Converts a continued fraction into a series whose terms are the * difference of its successive convergents. * * See ContinuedFractionSeriesGenerator for details. */ template >> XSF_HOST_DEVICE ContinuedFractionSeriesGenerator continued_fraction_series(Generator &cf) { return ContinuedFractionSeriesGenerator(cf); } /* Find initial bracket for a bracketing scalar root finder. A valid bracket is a pair of points a < b for * which the signs of f(a) and f(b) differ. If f(x0) = 0, where x0 is the initial guess, this bracket finder * will return the bracket (x0, x0). It is expected that the rootfinder will check if the bracket * endpoints are roots. * * This is a private function intended specifically for the situation where * the goal is to invert a CDF function F for a parametrized family of distributions with respect to one * parameter, when the other parameters are known, and where F is monotonic with respect to the unknown parameter. */ template XSF_HOST_DEVICE inline std::tuple bracket_root_for_cdf_inversion( Function func, double x0, double xmin, double xmax, double step0_left, double step0_right, double factor_left, double factor_right, bool increasing, std::uint64_t maxiter ) { double y0 = func(x0); if (y0 == 0) { // Initial guess is correct. return {x0, x0, y0, y0, 0}; } double y0_sgn = std::signbit(y0); bool search_left; /* The frontier is the new leading endpoint of the expanding bracket. The * interior endpoint trails behind the frontier. In each step, the old frontier * endpoint becomes the new interior endpoint. */ double interior, frontier, y_interior, y_frontier, y_interior_sgn, y_frontier_sgn, boundary, factor; if ((increasing && y0 < 0) || (!increasing && y0 > 0)) { /* If func is increasing and func(x_right) < 0 or if func is decreasing and * f(y_right) > 0, we should expand the bracket to the right. */ interior = x0, y_interior = y0; frontier = x0 + step0_right; y_interior_sgn = y0_sgn; search_left = false; boundary = xmax; factor = factor_right; } else { /* Otherwise we move and expand the bracket to the left. */ interior = x0, y_interior = y0; frontier = x0 + step0_left; y_interior_sgn = y0_sgn; search_left = true; boundary = xmin; factor = factor_left; } bool reached_boundary = false; for (std::uint64_t i = 0; i < maxiter; i++) { y_frontier = func(frontier); y_frontier_sgn = std::signbit(y_frontier); if (y_frontier_sgn != y_interior_sgn || (y_frontier == 0.0)) { /* Stopping condition, func evaluated at endpoints of bracket has opposing signs, * meeting requirement for bracketing root finder. (Or endpoint has reached a * zero.) */ if (search_left) { /* Ensure we return an interval (a, b) with a < b. */ std::swap(interior, frontier); std::swap(y_interior, y_frontier); } return {interior, frontier, y_interior, y_frontier, 0}; } if (reached_boundary) { /* We've reached a boundary point without finding a root . */ return { std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), search_left ? 1 : 2 }; } double step = (frontier - interior) * factor; interior = frontier; y_interior = y_frontier; y_interior_sgn = y_frontier_sgn; frontier += step; if ((search_left && frontier <= boundary) || (!search_left && frontier >= boundary)) { /* If the frontier has reached the boundary, set a flag so the algorithm will know * not to search beyond this point. */ frontier = boundary; reached_boundary = true; } } /* Failed to converge within maxiter iterations. If maxiter is sufficiently high and * factor_left and factor_right are set appropriately, this should only happen due to * a bug in this function. Limiting the number of iterations is a defensive programming measure. */ return { std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), 3 }; } /* Find root of a scalar function using Chandrupatla's algorithm */ template XSF_HOST_DEVICE inline std::pair find_root_chandrupatla( Function func, double x1, double x2, double f1, double f2, double rtol, double atol, std::uint64_t maxiter ) { if (f1 == 0) { return {x1, 0}; } if (f2 == 0) { return {x2, 0}; } double t = 0.5, x3, f3; for (uint64_t i = 0; i < maxiter; i++) { double x = x1 + t * (x2 - x1); double f = func(x); if (std::signbit(f) == std::signbit(f1)) { x3 = x1; x1 = x; f3 = f1; f1 = f; } else { x3 = x2; x2 = x1; x1 = x; f3 = f2; f2 = f1; f1 = f; } double xm, fm; if (std::abs(f2) < std::abs(f1)) { xm = x2; fm = f2; } else { xm = x1; fm = f1; } double tol = 2.0 * rtol * std::abs(xm) + 0.5 * atol; double tl = tol / std::abs(x2 - x1); if (tl > 0.5 || fm == 0) { return {xm, 0}; } double xi = (x1 - x2) / (x3 - x2); double phi = (f1 - f2) / (f3 - f2); double fl = 1.0 - std::sqrt(1.0 - xi); double fh = std::sqrt(xi); if ((fl < phi) && (phi < fh)) { t = (f1 / (f2 - f1)) * (f3 / (f2 - f3)) + (f1 / (f3 - f1)) * (f2 / (f3 - f2)) * ((x3 - x1) / (x2 - x1)); } else { t = 0.5; } t = std::fmin(std::fmax(t, tl), 1.0 - tl); } return {std::numeric_limits::quiet_NaN(), 1}; } } // namespace detail } // namespace xsf