/* Building blocks for implementing special functions */ | |
namespace xsf { | |
namespace detail { | |
/* The value type produced by a "generator": a callable taking no arguments
 * that yields a new value on every invocation. References and cv-qualifiers
 * on the call's return type are stripped.
 */
template <typename Generator>
using generator_result_t = std::decay_t<std::invoke_result_t<Generator>>;
/* Deduces the common element type of a homogeneous std::pair, i.e. the type
 * of the numerator/denominator of a fraction stored as std::pair<T, T>.
 * Only the std::pair<T, T> specialization is defined; any other type leaves
 * pair_traits incomplete.
 */
template <typename Pair>
struct pair_traits;

template <typename Elem>
struct pair_traits<std::pair<Elem, Elem>> {
    using value_type = Elem;
};

/* Shorthand for pair_traits<Pair>::value_type. */
template <typename Pair>
using pair_value_t = typename pair_traits<Pair>::value_type;
/* Maps a scalar type to its underlying real type: std::complex<R> maps to R,
 * every other type maps to itself.
 */
template <typename Scalar>
struct real_type {
    using type = Scalar;
};

template <typename Real>
struct real_type<std::complex<Real>> {
    using type = Real;
};

/* Shorthand for real_type<Scalar>::type. */
template <typename Scalar>
using real_type_t = typename real_type<Scalar>::type;
// Return NaN for a floating point type T. This overload participates only
// when T is a floating point type.
template <typename T>
XSF_HOST_DEVICE inline std::enable_if_t<std::is_floating_point_v<T>, T> maybe_complex_NaN() {
    using Limits = std::numeric_limits<T>;
    return Limits::quiet_NaN();
}
// Return NaN for a non-floating-point type T. T must expose a nested
// `value_type` and be constructible from two values of that type (as
// std::complex is); both components are set to quiet NaN.
template <typename T>
XSF_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T> maybe_complex_NaN() {
    using Component = typename T::value_type;
    const Component nan_part = std::numeric_limits<Component>::quiet_NaN();
    return {nan_part, nan_part};
}
// Series evaluators. | |
template <typename Generator, typename T = generator_result_t<Generator>> | |
XSF_HOST_DEVICE T | |
series_eval(Generator &g, T init_val, real_type_t<T> tol, std::uint64_t max_terms, const char *func_name) { | |
/* Sum an infinite series to a given precision. | |
* | |
* g : a generator of terms for the series. | |
* | |
* init_val : A starting value that terms are added to. This argument determines the | |
* type of the result. | |
* | |
* tol : relative tolerance for stopping criterion. | |
* | |
* max_terms : The maximum number of terms to add before giving up and declaring | |
* non-convergence. | |
* | |
* func_name : The name of the function within SciPy where this call to series_eval | |
* will ultimately be used. This is needed to pass to set_error in case | |
* of non-convergence. | |
*/ | |
T result = init_val; | |
T term; | |
for (std::uint64_t i = 0; i < max_terms; ++i) { | |
term = g(); | |
result += term; | |
if (std::abs(term) < std::abs(result) * tol) { | |
return result; | |
} | |
} | |
// Exceeded max terms without converging. Return NaN. | |
set_error(func_name, SF_ERROR_NO_RESULT, NULL); | |
return maybe_complex_NaN<T>(); | |
} | |
template <typename Generator, typename T = generator_result_t<Generator>>
XSF_HOST_DEVICE T series_eval_fixed_length(Generator &g, T init_val, std::uint64_t num_terms) {
    /* Add a fixed number of series terms onto a starting value.
     *
     * g         : generator; each call yields the next term of the series.
     *
     * init_val  : value the terms are added onto. Its type fixes the type of
     *             the accumulator and of the result.
     *
     * num_terms : exactly how many terms are drawn from g and summed. No
     *             convergence test is performed.
     */
    T total = init_val;
    for (std::uint64_t remaining = num_terms; remaining > 0; --remaining) {
        total += g();
    }
    return total;
}
/* One step of Kahan (compensated) summation: adds `x` to `sum`, using and
 * updating the running compensation `comp` that tracks the low-order bits
 * lost by previous additions. Both `sum` and `comp` are modified in place.
 */
template <typename T>
XSF_HOST_DEVICE void kahan_step(T &sum, T &comp, T x) {
    const T corrected = x - comp;      // incoming value with the previous error folded back in
    const T updated = sum + corrected; // low-order bits of `corrected` may be lost here
    // (updated - sum) is what was actually added; its difference from
    // `corrected` is the rounding error to compensate for next time.
    comp = (updated - sum) - corrected;
    sum = updated;
}
/* Evaluates an infinite series using Kahan summation. | |
* | |
* Denote the series by | |
* | |
* S = a[0] + a[1] + a[2] + ... | |
* | |
* And for n = 0, 1, 2, ..., denote its n-th partial sum by | |
* | |
* S[n] = a[0] + a[1] + ... + a[n] | |
* | |
* This function computes S[0], S[1], ... until a[n] is sufficiently
* small or the maximum number of terms has been evaluated.
* | |
* Parameters | |
* ---------- | |
* g | |
* Reference to generator that yields the sequence of values a[1], | |
* a[2], a[3], ... | |
* | |
* tol | |
* Relative tolerance for convergence. Specifically, stop iteration | |
* as soon as `abs(a[n]) <= tol * abs(S[n])` for some n >= 1. | |
* | |
* max_terms | |
* Maximum number of terms after a[0] to evaluate. It should be set | |
* large enough such that the convergence criterion is guaranteed | |
* to have been satisfied within that many terms if there is no | |
* rounding error. | |
* | |
* init_val | |
* a[0]. Default is zero. The type of this parameter (T) is used | |
* for intermediary computations as well as the result. | |
* | |
* Return Value | |
* ------------ | |
* If the convergence criterion is satisfied by some `n <= max_terms`, | |
* returns `(S[n], n)`. Otherwise, returns `(S[max_terms], 0)`. | |
*/ | |
template <typename Generator, typename T = generator_result_t<Generator>> | |
XSF_HOST_DEVICE std::pair<T, std::uint64_t> | |
series_eval_kahan(Generator &&g, real_type_t<T> tol, std::uint64_t max_terms, T init_val = T(0)) { | |
using std::abs; | |
T sum = init_val; | |
T comp = T(0); | |
for (std::uint64_t i = 0; i < max_terms; ++i) { | |
T term = g(); | |
kahan_step(sum, comp, term); | |
if (abs(term) <= tol * abs(sum)) { | |
return {sum, i + 1}; | |
} | |
} | |
return {sum, 0}; | |
} | |
/* Generator that yields the difference of successive convergents of a | |
* continued fraction. | |
* | |
* Let f[n] denote the n-th convergent of a continued fraction: | |
* | |
*                 a[1]   a[2]       a[n]
*   f[n] = b[0] + ------ ------ ... ----
*                 b[1] + b[2] +     b[n]
* | |
* with f[0] = b[0]. This generator yields the sequence of values | |
* f[1]-f[0], f[2]-f[1], f[3]-f[2], ... | |
* | |
* Constructor Arguments | |
* --------------------- | |
* cf | |
* Reference to generator that yields the terms of the continued | |
* fraction as (numerator, denominator) pairs, starting from | |
* (a[1], b[1]). | |
* | |
* `cf` must outlive the ContinuedFractionSeriesGenerator object. | |
* | |
* The constructed object always eagerly retrieves the next term | |
* of the continued fraction. Specifically, (a[1], b[1]) is | |
* retrieved upon construction, and (a[n], b[n]) is retrieved after | |
* (n-1) calls of `()`. | |
* | |
* Type Arguments | |
* -------------- | |
* T | |
* Type in which computations are performed and results are returned.
* | |
* Remarks | |
* ------- | |
* The series is computed using the recurrence relation described in [1]. | |
* Let v[n], n >= 1 denote the terms of the series. Then
*
*     v[1] = a[1] / b[1]
*     v[n] = v[n-1] * r[n], n >= 2
*
* where
*
*                               -(a[n] + a[n] * r[n-1])
*     r[1] = 0, r[n] = ------------------------------------------, n >= 2
*                      (a[n] + a[n] * r[n-1]) + (b[n] * b[n-1])
* | |
* No error checking is performed. The caller must ensure that all terms | |
* are finite and that intermediary computations do not trigger floating | |
* point exceptions such as overflow. | |
* | |
* The numerical stability of this method depends on the characteristics | |
* of the continued fraction being evaluated. | |
* | |
* Reference | |
* --------- | |
* [1] Gautschi, W. (1967). “Computational Aspects of Three-Term | |
* Recurrence Relations.” SIAM Review, 9(1):24-82. | |
*/ | |
template <typename Generator, typename T = pair_value_t<generator_result_t<Generator>>>
class ContinuedFractionSeriesGenerator {
  public:
    // Binds to `cf` by reference and immediately consumes its first
    // (numerator, denominator) pair to set up the first series term.
    XSF_HOST_DEVICE explicit ContinuedFractionSeriesGenerator(Generator &cf) : cf_(cf) {
        const auto fraction = cf_();
        const T a = fraction.first;
        const T b = fraction.second;
        r_ = T(0);
        v_ = a / b;
        b_ = b;
    }

    // Returns the current series term, then eagerly prepares the next one
    // (which pulls one more pair from the continued fraction generator).
    XSF_HOST_DEVICE T operator()() {
        const T current = v_;
        step();
        return current;
    }

  private:
    // Consumes the next (numerator, denominator) pair and updates the ratio
    // r_, the term v_ and the remembered denominator b_.
    XSF_HOST_DEVICE void step() {
        const auto fraction = cf_();
        const T a = fraction.first;
        const T b = fraction.second;
        const T p = a + a * r_;
        const T q = p + b * b_;
        r_ = -p / q;
        v_ = v_ * r_;
        b_ = b;
    }

    Generator &cf_; // continued fraction generator; must outlive this object
    T v_;           // current term: difference of successive convergents
    T r_;           // ratio of the current term to the previous one (0 for the first term)
    T b_;           // denominator from the most recently consumed pair
};
/* Converts a continued fraction into a series whose terms are the | |
* difference of its successive convergents. | |
* | |
* See ContinuedFractionSeriesGenerator for details. | |
*/ | |
template <typename Generator, typename T = pair_value_t<generator_result_t<Generator>>> | |
XSF_HOST_DEVICE ContinuedFractionSeriesGenerator<Generator, T> continued_fraction_series(Generator &cf) { | |
return ContinuedFractionSeriesGenerator<Generator, T>(cf); | |
} | |
/* Find an initial bracket for a bracketing scalar root finder. A valid bracket is a pair of points
 * a < b for which the signs of f(a) and f(b) differ. If f(x0) = 0, where x0 is the initial guess,
 * the bracket (x0, x0) is returned. It is expected that the root finder will check whether the
 * bracket endpoints are roots.
 *
 * This is a private function intended specifically for the situation where the goal is to invert a
 * CDF function F for a parametrized family of distributions with respect to one parameter, when the
 * other parameters are known, and where F is monotonic with respect to the unknown parameter.
 *
 * Parameters
 * ----------
 * func
 *     Callable taking a double and returning a value convertible to double.
 * x0
 *     Initial guess; first endpoint of the expanding bracket.
 * xmin, xmax
 *     Boundaries for the search. The frontier is never moved past xmin when searching left or
 *     past xmax when searching right, including on the very first step.
 * step0_left, step0_right
 *     Initial step, added to x0, when searching left / right respectively. step0_left is
 *     expected to be negative and step0_right positive (not checked).
 * factor_left, factor_right
 *     Each new step equals the previous step multiplied by this factor.
 * increasing
 *     Whether func is increasing (true) or decreasing (false); together with the sign of
 *     func(x0) this selects the search direction.
 * maxiter
 *     Maximum number of frontier evaluations of func.
 *
 * Return Value
 * ------------
 * (a, b, f(a), f(b), status). On success status is 0 and a <= b. On failure all four doubles are
 * NaN and status is
 *     1 : searched left and reached xmin without a sign change,
 *     2 : searched right and reached xmax without a sign change,
 *     3 : maxiter evaluations were used up.
 */
template <typename Function>
XSF_HOST_DEVICE inline std::tuple<double, double, double, double, int> bracket_root_for_cdf_inversion(
    Function func, double x0, double xmin, double xmax, double step0_left, double step0_right, double factor_left,
    double factor_right, bool increasing, std::uint64_t maxiter
) {
    const double kNaN = std::numeric_limits<double>::quiet_NaN();

    const double y0 = func(x0);
    if (y0 == 0) {
        // Initial guess is correct.
        return {x0, x0, y0, y0, 0};
    }

    /* If func is increasing and func(x0) < 0, or func is decreasing and func(x0) > 0, the root
     * lies to the right of x0; otherwise the bracket is expanded to the left. */
    const bool search_left = !((increasing && y0 < 0) || (!increasing && y0 > 0));
    const double boundary = search_left ? xmin : xmax;
    const double factor = search_left ? factor_left : factor_right;

    /* The frontier is the leading endpoint of the expanding bracket. The interior endpoint
     * trails behind it: in each step, the old frontier becomes the new interior endpoint. */
    double interior = x0;
    double y_interior = y0;
    bool y_interior_sgn = std::signbit(y0);
    double frontier = x0 + (search_left ? step0_left : step0_right);
    bool reached_boundary = false;

    for (std::uint64_t i = 0; i < maxiter; i++) {
        /* Clamp the frontier to the boundary before evaluating func there. Doing this at the
         * top of the loop also covers the initial step, so func is never evaluated beyond
         * the boundary and no bracket extending past it can be returned. */
        if ((search_left && frontier <= boundary) || (!search_left && frontier >= boundary)) {
            frontier = boundary;
            reached_boundary = true;
        }

        const double y_frontier = func(frontier);
        const bool y_frontier_sgn = std::signbit(y_frontier);
        if (y_frontier_sgn != y_interior_sgn || y_frontier == 0.0) {
            /* func has opposing signs at the two endpoints (or the frontier hit a zero):
             * this is a valid bracket. Order the endpoints so that a < b. */
            if (search_left) {
                return {frontier, interior, y_frontier, y_interior, 0};
            }
            return {interior, frontier, y_interior, y_frontier, 0};
        }
        if (reached_boundary) {
            /* The boundary point itself was just tested and there is still no sign change. */
            return {kNaN, kNaN, kNaN, kNaN, search_left ? 1 : 2};
        }

        // Grow the step geometrically and advance the bracket.
        const double step = (frontier - interior) * factor;
        interior = frontier;
        y_interior = y_frontier;
        y_interior_sgn = y_frontier_sgn;
        frontier += step;
    }
    /* Failed to find a bracket within maxiter iterations. If maxiter is sufficiently high and
     * factor_left and factor_right are set appropriately, this should only happen due to a bug
     * in this function. Limiting the number of iterations is a defensive programming measure. */
    return {kNaN, kNaN, kNaN, kNaN, 3};
}
/* Find a root of a scalar function using Chandrupatla's algorithm.
 *
 * Parameters
 * ----------
 * func
 *     Callable taking a double and returning a value convertible to double.
 * x1, x2
 *     Endpoints of the starting bracket.
 * f1, f2
 *     Values of func at x1 and x2, supplied by the caller. They are expected to have
 *     opposite signs (not checked). If either is exactly zero, the matching endpoint is
 *     returned immediately.
 * rtol, atol
 *     Relative and absolute tolerances. Iteration stops once
 *     2 * rtol * abs(xm) + 0.5 * atol exceeds half the width of the current bracket, where
 *     xm is the bracket endpoint with the smaller absolute function value, or once func is
 *     exactly zero at xm.
 * maxiter
 *     Maximum number of evaluations of func.
 *
 * Return Value
 * ------------
 * (xm, 0) on success; (NaN, 1) if the stopping criterion was not met within maxiter
 * iterations.
 */
template <typename Function>
XSF_HOST_DEVICE inline std::pair<double, int> find_root_chandrupatla(
    Function func, double x1, double x2, double f1, double f2, double rtol, double atol, std::uint64_t maxiter
) {
    if (f1 == 0) {
        return {x1, 0};
    }
    if (f2 == 0) {
        return {x2, 0};
    }

    // Fraction of the way from x1 to x2 at which to sample next; 0.5 is bisection.
    double t = 0.5;
    for (std::uint64_t i = 0; i < maxiter; i++) {
        const double x = x1 + t * (x2 - x1);
        const double f = func(x);

        /* Re-label the points so that (x1, x2) is the new bracket with x1 the newest point,
         * and (x3, f3) is the point that was just discarded from the bracket. */
        double x3, f3;
        if (std::signbit(f) == std::signbit(f1)) {
            x3 = x1;
            f3 = f1;
        } else {
            x3 = x2;
            f3 = f2;
            x2 = x1;
            f2 = f1;
        }
        x1 = x;
        f1 = f;

        // Best estimate so far: the bracket endpoint with the smaller |f|.
        double xm, fm;
        if (std::abs(f2) < std::abs(f1)) {
            xm = x2;
            fm = f2;
        } else {
            xm = x1;
            fm = f1;
        }

        const double tol = 2.0 * rtol * std::abs(xm) + 0.5 * atol;
        const double tl = tol / std::abs(x2 - x1);
        if (tl > 0.5 || fm == 0) {
            return {xm, 0};
        }

        /* Use inverse quadratic interpolation through the three points only when the
         * criterion on xi and phi below holds; otherwise fall back to bisection. */
        const double xi = (x1 - x2) / (x3 - x2);
        const double phi = (f1 - f2) / (f3 - f2);
        const double fl = 1.0 - std::sqrt(1.0 - xi);
        const double fh = std::sqrt(xi);
        if ((fl < phi) && (phi < fh)) {
            t = (f1 / (f2 - f1)) * (f3 / (f2 - f3)) + (f1 / (f3 - f1)) * (f2 / (f3 - f2)) * ((x3 - x1) / (x2 - x1));
        } else {
            t = 0.5;
        }
        // Keep the next sample at least a tolerance-sized fraction away from both endpoints.
        t = std::fmin(std::fmax(t, tl), 1.0 - tl);
    }
    return {std::numeric_limits<double>::quiet_NaN(), 1};
}
} // namespace detail | |
} // namespace xsf | |