File size: 5,411 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
/* Translated from Cython into C++ by SciPy developers in 2023.
* Original header with Copyright information appears below.
*/
/* Implementation of the Lambert W function [1]. Based on MPMath
* Implementation [2], and documentation [3].
*
* Copyright: Yosef Meller, 2009
* Author email: [email protected]
*
* Distributed under the same license as SciPy
*
*
* References:
* [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
* available online: https://web.archive.org/web/20230123211413/https://cs.uwaterloo.ca/research/tr/1993/03/W.pdf
* [2] mpmath source code,
*     https://github.com/mpmath/mpmath/blob/c5939823669e1bcce151d89261b802fe0d8978b4/mpmath/functions/functions.py#L435-L461
* [3] mpmath documentation of the Lambert W function,
*     https://web.archive.org/web/20230504171447/https://mpmath.org/doc/current/functions/powers.html#lambert-w-function
*
* TODO: use a series expansion when extremely close to the branch point
* at `-1/e` and make sure that the proper branch is chosen there.
*/
#pragma once
#include "config.h"
#include "error.h"
#include "evalpoly.h"
namespace xsf {
// exp(-1). lambertw() measures the distance of z from -EXPN1 (the branch
// point at -1/e mentioned in the file header) when choosing an initial guess.
constexpr double EXPN1 = 0.36787944117144232159553; // exp(-1)
// Value of the principal branch at z = 1; lambertw() returns it directly
// for z == 1, k == 0 instead of iterating.
constexpr double OMEGA = 0.56714329040978387299997; // W(1, 0)
namespace detail {
/* Initial guess for W(z, 0) close to the branch point: the truncated
 * series of eq. 4.22 in [1], written in terms of
 * p = sqrt(2 * (e * z + 1)).
 */
XSF_HOST_DEVICE inline std::complex<double> lambertw_branchpt(std::complex<double> z) {
    // Series coefficients handed to cevalpoly together with degree 2
    // (presumably highest power first, i.e. -p^2/3 + p - 1; confirm
    // against cevalpoly's convention).
    double series[] = {-1.0 / 3.0, 1.0, -1.0};
    std::complex<double> shifted = M_E * z + 1.0;
    std::complex<double> p = std::sqrt(2.0 * shifted);
    return cevalpoly(series, 2, p);
}
/* Initial guess for W(z, 0) near the origin, using a (3, 2) Pade
 * approximant: z * N(z) / D(z), with N and D both evaluated by
 * cevalpoly at degree 2.
 */
XSF_HOST_DEVICE inline std::complex<double> lambertw_pade0(std::complex<double> z) {
    double numer[] = {12.85106382978723404255, 12.34042553191489361902, 1.0};
    double denom[] = {32.53191489361702127660, 14.34042553191489361702, 1.0};
    /* The caller only uses this for z close to 0, so no overflow-safe
     * evaluation of the numerator for large z is attempted. */
    std::complex<double> top = z * cevalpoly(numer, 2, z);
    std::complex<double> bottom = cevalpoly(denom, 2, z);
    return top / bottom;
}
/* Initial guess from the first two terms of the asymptotic series
 * (eq. 4.20 in [1]): with L = log(z) + 2*pi*k*i, return L - log(L).
 */
XSF_HOST_DEVICE inline std::complex<double> lambertw_asy(std::complex<double> z, long k) {
    const std::complex<double> imag_unit(0, 1);
    // Logarithm of z shifted by 2*pi*k along the imaginary axis.
    std::complex<double> lead = std::log(z) + 2.0 * M_PI * k * imag_unit;
    return lead - std::log(lead);
}
} // namespace detail
/* Lambert W function W(z, k), computed with Halley's iteration
 * (eq. 5.9 in [1]) started from a branch-dependent initial guess.
 *
 * Parameters:
 *   z   - argument.
 *   k   - branch index.
 *   tol - relative tolerance; iteration stops as soon as
 *         |w_new - w| <= tol * |w_new|.
 *
 * Special cases handled before iterating:
 *   - NaN in either component of z: z is returned unchanged.
 *   - Re(z) == +inf: returns z + 2*pi*k*i.
 *   - Re(z) == -inf: returns -z + (2*pi*k + pi)*i.
 *   - z == 0: returns z for k == 0; for any other k reports
 *     SF_ERROR_SINGULAR and returns -inf.
 *   - z == 1 and k == 0: returns the constant OMEGA.
 *
 * If 100 Halley steps do not satisfy the tolerance, SF_ERROR_SLOW is
 * reported and NaN + NaN*j is returned.
 */
XSF_HOST_DEVICE inline std::complex<double> lambertw(std::complex<double> z, long k, double tol) {
    if (std::isnan(z.real()) || std::isnan(z.imag())) {
        return z;
    }
    if (z.real() == std::numeric_limits<double>::infinity()) {
        return z + 2.0 * M_PI * k * std::complex<double>(0, 1);
    }
    if (z.real() == -std::numeric_limits<double>::infinity()) {
        return -z + (2.0 * M_PI * k + M_PI) * std::complex<double>(0, 1);
    }
    if (z == 0.0) {
        if (k == 0) {
            return z;
        }
        set_error("lambertw", SF_ERROR_SINGULAR, nullptr);
        return -std::numeric_limits<double>::infinity();
    }
    if (z == 1.0 && k == 0) {
        // Split out this case because the asymptotic series blows up
        return OMEGA;
    }

    // Get an initial guess for Halley's method
    std::complex<double> w;
    if (k == 0) {
        if (std::abs(z + EXPN1) < 0.3) {
            w = detail::lambertw_branchpt(z);
        } else if (-1.0 < z.real() && z.real() < 1.5 && std::abs(z.imag()) < 1.0 &&
                   -2.5 * std::abs(z.imag()) - 0.2 < z.real()) {
            /* Empirically determined decision boundary where the Pade
             * approximation is more accurate. */
            w = detail::lambertw_pade0(z);
        } else {
            w = detail::lambertw_asy(z, k);
        }
    } else if (k == -1 && std::abs(z) <= EXPN1 && z.imag() == 0.0 && z.real() < 0.0) {
        // Negative real axis between -exp(-1) and 0 on the k = -1 branch.
        w = std::log(-z.real());
    } else {
        w = detail::lambertw_asy(z, k);
    }

    // Halley's method; see 5.9 in [1]. The formulation is chosen once,
    // from the sign of the real part of the initial guess.
    if (w.real() >= 0) {
        // Rearrange the formula to avoid overflow in exp
        for (int i = 0; i < 100; i++) {
            const std::complex<double> ew = std::exp(-w);
            const std::complex<double> wewz = w - z * ew;
            const std::complex<double> wn = w - wewz / (w + 1.0 - (w + 2.0) * wewz / (2.0 * w + 2.0));
            if (std::abs(wn - w) <= tol * std::abs(wn)) {
                return wn;
            }
            w = wn;
        }
    } else {
        for (int i = 0; i < 100; i++) {
            const std::complex<double> ew = std::exp(w);
            const std::complex<double> wew = w * ew;
            const std::complex<double> wewz = wew - z;
            if (wewz == 0.0) {
                /* The residual w*exp(w) - z is exactly zero, so w is
                 * already a root in floating point. Return it now: for
                 * w == -1 (z == -exp(-1)) the correction below divides
                 * by 2w + 2 == 0 and would evaluate 0/0, turning an
                 * exact answer into NaN and a spurious convergence
                 * failure. */
                return w;
            }
            const std::complex<double> wn = w - wewz / (wew + ew - (w + 2.0) * wewz / (2.0 * w + 2.0));
            if (std::abs(wn - w) <= tol * std::abs(wn)) {
                return wn;
            }
            w = wn;
        }
    }

    set_error("lambertw", SF_ERROR_SLOW, "iteration failed to converge: %g + %gj", z.real(), z.imag());
    return {std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::quiet_NaN()};
}
/* Single-precision overload: widens z and tol to double, evaluates the
 * double-precision lambertw, and narrows the result back to float.
 */
XSF_HOST_DEVICE inline std::complex<float> lambertw(std::complex<float> z, long k, float tol) {
    const std::complex<double> z_wide = static_cast<std::complex<double>>(z);
    const double tol_wide = static_cast<double>(tol);
    const std::complex<double> w_wide = lambertw(z_wide, k, tol_wide);
    return static_cast<std::complex<float>>(w_wide);
}
} // namespace xsf
|