Spaces:
Sleeping
Sleeping
| /* | |
| * Copyright 2021 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| namespace csrblocksparse { | |
| // Inputs to exp are clipped to [kMinExpInput, kMaxExpInput] to prevent overflow/underflow | |
| // in a 32 bit float representation: e^80 ~ 5.5e34, which is below maxfloat (~3.4e38). | |
| constexpr float kMaxExpInput = 80.f; | |
| constexpr int kMaxExpInputInt = static_cast<int>(kMaxExpInput); | |
| constexpr float kMinExpInput = -80.f; | |
| // Clip bound for tanh: tanh(9) ~ 0.99999997, which cannot be resolved from 1 in a float32. | |
| constexpr float kMaxTanhInput = 9.f; | |
| constexpr float kMinTanhInput = -9.f; | |
| // Clip bound for sigmoid: sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32. | |
| constexpr float kMaxSigmoidInput = 18.f; | |
| constexpr float kMinSigmoidInput = -18.f; | |
| // Bit pattern of the float multiplier used by fast_exp; as a float, kAConstant ~= 2^23 / ln 2. | |
| constexpr uint32_t kAConstant = 0x4b38aa3b; | |
| // Bit pattern of the float offset used by fast_exp; as a float, kBConstant ~= (127 << 23) - 366000. | |
| constexpr uint32_t kBConstant = 0x4e7de9a9; | |
| // Coefficients of the rational approximation tanh(x) ~ p(x) / q(x) used by fast_tanh. | |
| // Monomial coefficients of the numerator polynomial p (odd): kTanhAlphaN multiplies x^N. | |
| constexpr float kTanhAlpha1 = 4.89352455891786e-03; | |
| constexpr float kTanhAlpha3 = 6.37261928875436e-04; | |
| constexpr float kTanhAlpha5 = 1.48572235717979e-05; | |
| constexpr float kTanhAlpha7 = 5.12229709037114e-08; | |
| constexpr float kTanhAlpha9 = -8.60467152213735e-11; | |
| constexpr float kTanhAlpha11 = 2.00018790482477e-13; | |
| constexpr float kTanhAlpha13 = -2.76076847742355e-16; | |
| // Monomial coefficients of the denominator polynomial q (even): kTanhBetaN multiplies x^N. | |
| constexpr float kTanhBeta0 = 4.89352518554385e-03; | |
| constexpr float kTanhBeta2 = 2.26843463243900e-03; | |
| constexpr float kTanhBeta4 = 1.18534705686654e-04; | |
| constexpr float kTanhBeta6 = 1.19825839466702e-06; | |
| // Coefficients of the rational approximation sigmoid(x) - 0.5 ~ p(x) / q(x) used by fast_sigmoid. | |
| // Monomial coefficients of the numerator polynomial p (odd): kSigmoidAlphaN multiplies x^N. | |
| constexpr float kSigmoidAlpha1 = 2.48287947061529e-01; | |
| constexpr float kSigmoidAlpha3 = 8.51377133304701e-03; | |
| constexpr float kSigmoidAlpha5 = 6.08574864600143e-05; | |
| constexpr float kSigmoidAlpha7 = 1.15627324459942e-07; | |
| constexpr float kSigmoidAlpha9 = 4.37031012579801e-11; | |
| // Monomial coefficients of the denominator polynomial q (even): kSigmoidBetaN multiplies x^N. | |
| constexpr float kSigmoidBeta0 = 9.93151921023180e-01; | |
| constexpr float kSigmoidBeta2 = 1.16817656904453e-01; | |
| constexpr float kSigmoidBeta4 = 1.70198817374094e-03; | |
| constexpr float kSigmoidBeta6 = 6.29106785017040e-06; | |
| constexpr float kSigmoidBeta8 = 5.76102136993427e-09; | |
| constexpr float kSigmoidBeta10 = 6.10247389755681e-13; | |
| // x is the first term of the Taylor series approximation of tanh near 0: | |
| // tanh(x) = x - x^3 / 3 + 2x^5 / 15 - 17x^7 / 315 + ... | |
| // Because the leading error term of tanh(x) - x is O(x^3), x alone is good over a fairly wide | |
| // interval, so it is used for |x| up to kTanhLinearRegion, where the exp-based approximation is inaccurate. | |
| // Similarly for sigmoid (region kSigmoidLinearRegion), where the first-order term is .25x. | |
| constexpr float kTanhLinearRegion = .15f; | |
| constexpr float kSigmoidLinearRegion = .75f; | |
| // Maximum shift factor for 1/log 2 to keep it inside int32; kLogFactor is 1/log 2 scaled by that shift. | |
| constexpr int kMaxLog2Shift = 30; | |
| static const int kLogFactor = static_cast<int>((1 << kMaxLog2Shift) / log(2.f)); | |
| static const float kOneOverLog2 = 1.0f / log(2.f); | |
| // Number of real (explicitly stored) mantissa bits in IEEE float32. | |
| constexpr int kFloatMantissaBits = 23; | |
| // Offset to correct the exponent value in the resulting float: the bias 127 in exponent position. | |
| constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits; | |
| // Mask for mantissa (the low 23 bits). | |
| constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1; | |
| // Mask for everything above the mantissa, i.e. the exponent and sign bits. | |
| constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask; | |
| // ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============ | |
| // Summary: Use the exponent-mantissa representation of a floating point number | |
| // to give exponentiation of 2 for free. Suppose we want f(z) = e^z = 2^(x+n) (for | |
| // some fixed-point z expressed as an integer with an implied binary point within | |
| // it). We first compute x+n = z / ln 2, and then split x+n into | |
| // n = int(x+n) and x = fract(x+n) in [0, 1). Using n and 2^x as the | |
| // exponent and mantissa of a floating point number gives a float equal to | |
| // e^z. For original reference see: | |
| // http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf | |
| // Important detail: | |
| // IEEE floats are stored normalized, ie 1.bbbbbbb... x 2^exponent. The leading | |
| // 1 bit is not actually stored, (as it is always 1), providing an extra bit of | |
| // precision. | |
| // Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus | |
| // need a mapping x in [0, 1) -> u in [0, 1) and the 1 + is provided by the | |
| // representation. | |
| // In the original paper cited above, the mapping is u = x - c, where c is set | |
| // to minimize the average error. The function to compute exp(x) this way is | |
| // incredibly simple and computationally cheap, but not very accurate. | |
| // Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1) for | |
| // which it is far easier to construct accurate approximations with small | |
| // polynomials than a full range exp(x), and this is what the cubic and quartic | |
| // versions below do. An important feature of these functions is that they | |
| // constrain the solution to be exact at 0 and 1 so there is continuity at each | |
| // integer boundary where we wrap from 1 to 0 and increment the power of 2. | |
| // Coefficients for quartic representation of 2^x - 1 for x on [0,1). | |
| // The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the | |
| // coefficients of a quadratic are all that is required: a = Factor2, b = Factor1, c = Factor0. | |
| // Coefficients came from numerical experiments. | |
| constexpr float kExpQuarticFactor2 = 0.0135302434f; | |
| constexpr float kExpQuarticFactor1 = 0.0656107542f; | |
| constexpr float kExpQuarticFactor0 = 0.306963906f; | |
| // Coefficients for cubic representation of 2^x - 1 for x on [0,1] | |
| // The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the | |
| // coefficients of a linear function are all that is required (presumably m = Factor1, c = Factor0). | |
| // Coefficients came from numerical experiments. | |
| constexpr float kExpCubicFactor1 = 0.0780252018f; | |
| constexpr float kExpCubicFactor0 = 0.304684167f; | |
| // Coefficients are optimized to minimize the absolute error on | |
| // tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x. | |
| // Enum that determines how a transcendental is computed: polynomial order and arithmetic type. | |
| enum TranscendentalMode { | |
| // Cubic (order 3) polynomial using 16 bit integer arithmetic. | |
| TM_ORDER3_16BIT, | |
| // Quartic (order 4) polynomial using 16 bit integer arithmetic. | |
| TM_ORDER4_16BIT, | |
| // Quartic (order 4) polynomial using 32 bit float arithmetic. | |
| TM_ORDER4_FLOAT, | |
| }; | |
// Converts x to fixed point with 15 fractional bits: scales by 2^15 and rounds
// to the nearest integer, with halfway cases rounded away from zero.
// The result is returned as an int; it only fits an int16 for x in [-1, 1).
inline int FloatAsInt16(float x) {
  const float scaled = x * (1 << 15);
  // static_cast truncates toward zero, so the rounding offset must carry the
  // sign of the value; a fixed +0.5f rounds negative inputs the wrong way
  // (e.g. -0.25 would become -8191 instead of -8192).
  return static_cast<int>(scaled < 0.0f ? scaled - 0.5f : scaled + 0.5f);
}
// Converts x to fixed point with 30 fractional bits: scales by 2^30 and rounds
// to the nearest integer, with halfway cases rounded away from zero.
// The result only fits an int for x in [-2, 2).
inline int FloatAsInt32(float x) {
  const float scaled = x * (1 << 30);
  // static_cast truncates toward zero, so the rounding offset must carry the
  // sign of the value; a fixed +0.5f rounds small negative inputs toward zero
  // (e.g. -2^-30 would become 0 instead of -1).
  return static_cast<int>(scaled < 0.0f ? scaled - 0.5f : scaled + 0.5f);
}
| constexpr int kMaxSigmoidInputInt = static_cast<int>(kMaxSigmoidInput);  /* Integer form of the sigmoid clip bound; used as the ClipToBounds template argument. */ | |
| // Computes and returns 2^(x>>23) ie 2^u where x = u << 23 bits, for each of the 4 lanes. | |
| // Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR | |
| // THE FLOATING EXPONENT TRICK above for details. Does no clipping of its own. | |
| // Returns the true value, ie not scaled. | |
| inline float32x4_t float32_pow2(float32x4_t x) { | |
| // The input is already shifted left by 23 bits, so when we convert to int, | |
| // the bottom 23 bits are the fractional part, and the top bits are the | |
| // integer part. We want to compute a function of the fractional part, so | |
| // we will mask it off and manipulate it. | |
| int32x4_t exp_int_x = vcvtq_s32_f32(x); | |
| // Mask keeping the top 15 of the 23 fractional bits, for conversion to fixed16<0>. | |
| int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00); | |
| // Mask keeping all 23 fractional bits, for conversion to fixed32<1>. | |
| int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff); | |
| // Narrowing shift to convert to fixed16<0> (15 fractional bits). | |
| int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8); | |
| // Shift to convert to fixed32<1> (30 fractional bits). | |
| int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7); | |
| // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part. | |
| // Ordering these lines carefully makes it faster, as some of the multiply | |
| // operations can pipeline instead of waiting for the previous result. | |
| int32x4_t x_squared = vmull_s16(x_16, x_16); | |
| int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1)); | |
| int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0)); | |
| int32x4_t bx_plus_c = vmlal_s16(c, b, x_16); | |
| int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2)); | |
| // Finish the quadratic: result = ax^2 + bx + c (x^2 is first narrowed back to 15 fractional bits). | |
| int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15)); | |
| int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32); | |
| // Multiply by x^2 - x, which equals x(x - 1). | |
| result = vqrdmulhq_s32(result, x_squared_minus_x); | |
| // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit | |
| // inputs, made 60-mantissa bit result, doubled it to 61 bits, then discarded | |
| // the bottom 32 making 29, so shift right 6 to get 23. | |
| result = vshrq_n_s32(result, 6); | |
| // Add the constant to normalize the exponent for IEEE format, then add the polynomial correction. | |
| int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset); | |
| exp_int_x = vaddq_s32(exp_int_x, exp_offset); | |
| exp_int_x = vaddq_s32(exp_int_x, result); | |
| // Reinterpret the bits as float (no numeric conversion), as we just computed the exponent | |
| // and mantissa and assembled them in IEEE format. | |
| return vreinterpretq_f32_s32(exp_int_x); | |
| } | |
| // Scaled float to float exp approximation, using a quartic refinement of | |
| // the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK | |
| // above for details. Input is a fixed32<31 - mantissa_bits> that has been | |
| // converted to a float without any further shifting. MUST HAVE ALREADY BEEN | |
| // CLIPPED to a suitable range for exp! mantissa_bits must not exceed kFloatMantissaBits (the shift below would be negative). | |
| // Returns a vector of standard unscaled floats. | |
| inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits, | |
| float32x4_t x) { | |
| // Divide by log 2 to convert problem to 2^x, and scale to match the | |
| // mantissa bits required by IEEE floats (both folded into one multiply). | |
| // kXShift is the shift of the FP mantissa relative to the input mantissa. The local kLogFactor shadows the file-level int kLogFactor. | |
| const int kXShift = kFloatMantissaBits - mantissa_bits; | |
| const float kLogFactor = static_cast<float>(1 << kXShift); | |
| float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); | |
| float32x4_t y = vmulq_f32(x, factor); | |
| // Now compute 2^x on the scaled value. | |
| return float32_pow2(y); | |
| } | |
| // Vector exp approximation. Uses the trick that 2^x can be computed by shifting an integer into the | |
| // exponent, see the following reference for a derivation using double: goo.gl/aUVTK3 | |
| // Input x is clamped to [kMinExpInput, kMaxExpInput], i.e. [-80, 80] (this comment previously said [-64, 64]). | |
| // NOTE(review): the original comment said the clamp also covers infinity and NaN; NaN handling of vminq/vmaxq is not evident here - confirm. | |
| // Accurate to within 3% relative across the entire range. | |
| // Fully pipelined throughput is about 10 cycles per fast_exp call. | |
| inline float32x4_t fast_exp(float32x4_t x) { | |
| // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. | |
| // Load A and B, which are defined as integer bit patterns, into float registers. | |
| float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); | |
| float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); | |
| // Make sure x is within the allowed range. | |
| x = vminq_f32(x, vdupq_n_f32(kMaxExpInput)); | |
| x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput)); | |
| // res = A * x + B. | |
| // This shifts x into the exponent field and adds the bias. | |
| res = vmlaq_f32(res, A, x); | |
| // Convert to an integer and reinterpret its bits as a float; this is what uses the | |
| // floating point format to compute 2^x. | |
| int32x4_t x_int = vcvtnq_s32_f32(res); | |
| return vreinterpretq_f32_s32(x_int);  /* NOTE(review): the per-lane expf code below is unreachable as shown; presumably a fallback variant whose preprocessor guard is missing - confirm. */ | |
| float32x4_t return_val = vdupq_n_f32(0.f); | |
| float exponent = expf(vgetq_lane_f32(x, 0)); | |
| return_val = vld1q_lane_f32(&exponent, return_val, 0); | |
| exponent = expf(vgetq_lane_f32(x, 1)); | |
| return_val = vld1q_lane_f32(&exponent, return_val, 1); | |
| exponent = expf(vgetq_lane_f32(x, 2)); | |
| return_val = vld1q_lane_f32(&exponent, return_val, 2); | |
| exponent = expf(vgetq_lane_f32(x, 3)); | |
| return_val = vld1q_lane_f32(&exponent, return_val, 3); | |
| return return_val; | |
| } | |
| // Fixed point input version: x holds 4 fixed point values with 31 - ExponentBits fractional bits. | |
| // It converts the input to floating point, then calls the floating point fast_exp function. | |
| // There is another version, fast_exp_fixed, that never does a conversion and is less accurate, | |
| // but much faster. | |
| template <int ExponentBits> | |
| inline float32x4_t fast_exp(int32x4_t x) { | |
| return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits)); | |
| } | |
| // Performs an exp estimate without doing any floating point operations. The result is a floating | |
| // point number. x is clamped to +/-(80 << (31 - ExponentBits)). See scalar version for an explanation. | |
| template <int ExponentBits> | |
| inline float32x4_t fast_exp_fixed(int32x4_t x) { | |
| static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); | |
| constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8)); | |
| constexpr int kB = (127 << 23) - 366000; | |
| constexpr int maxInput = 80 << (31 - ExponentBits); | |
| constexpr int minInput = -maxInput; | |
| int32x4_t A = vdupq_n_s32(kA); | |
| int32x4_t res = vdupq_n_s32(kB); | |
| // Make sure x is within the allowed range. | |
| x = vminq_s32(x, vdupq_n_s32(maxInput)); | |
| x = vmaxq_s32(x, vdupq_n_s32(minInput)); | |
| // res = A * x + B, all in integer arithmetic. | |
| // This shifts x into the exponent field and adds the bias; the bits are then reinterpreted as float. | |
| res = vmlaq_s32(res, A, x); | |
| return vreinterpretq_f32_s32(res); | |
| } | |
| // fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON. | |
| namespace detail { | |
| // As fast_exp(float32x4_t) but with no range check, so that tanh can do the range check once. | |
| // The input is NOT clamped here (the old comment claimed it was); the caller must clamp x first. | |
| inline float32x4_t fast_exp_norange_check(float32x4_t x) { | |
| float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); | |
| float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); | |
| res = vmlaq_f32(res, A, x); | |
| int32x4_t x_int = vcvtnq_s32_f32(res); | |
| return vreinterpretq_f32_s32(x_int); | |
| } | |
| } // namespace detail | |
| // Clips each of the 4 float lanes of x to [-kLimit,kLimit] and returns the clipped vector. | |
| inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) { | |
| // Apply the upper bound first, then negate the limit and apply the lower bound. | |
| float32x4_t clip_limit = vdupq_n_f32(kLimit); | |
| float32x4_t clipped_x = vminq_f32(x, clip_limit); | |
| clip_limit = vnegq_f32(clip_limit); | |
| return vmaxq_f32(clipped_x, clip_limit); | |
| } | |
| inline float32x4_t float_tanh_float(const float32x4_t& x) {  /* tanh of 4 floats: (e^2x - 1) / (e^2x + 1) via float32_pow2, with a Taylor series result selected for |x| <= 1/9. */ | |
| float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x); | |
| // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and | |
| // scale to the mantissa bits required by float32_pow2 all in one multiply. | |
| // The +1 in the shift is what doubles the input. The local kLogFactor shadows the file-level int kLogFactor. | |
| const float kLogFactor = static_cast<float>(1 << (kFloatMantissaBits + 1)); | |
| float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); | |
| clipped_x = vmulq_f32(clipped_x, factor); | |
| // Now compute 2^x, which equals e^(2 * clipped input). | |
| float32x4_t exp_result = float32_pow2(clipped_x); | |
| // Now compute tanh using (e^2x - 1) / (e^2x + 1). | |
| float32x4_t one = vdupq_n_f32(1.0f); | |
| float32x4_t numerator = vsubq_f32(exp_result, one); | |
| float32x4_t denominator = vaddq_f32(exp_result, one); | |
| float32x4_t recp = vrecpeq_f32(denominator); | |
| // One Newton-Raphson iteration on the reciprocal estimate, accuracy is important for audio quality | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator)); | |
| recp = vmulq_f32(recp, numerator); | |
| // Compute 3rd-order Taylor tanh ~ x - x^3/3 (from the original, unscaled x) for high accuracy | |
| // and thus low relative error close to 0. | |
| float32x4_t third = vdupq_n_f32(1.0f / 3.0f); | |
| float32x4_t taylor = vmulq_f32(x, x); | |
| taylor = vmulq_f32(taylor, x); | |
| taylor = vmulq_f32(taylor, third); | |
| taylor = vsubq_f32(x, taylor); | |
| // Test |x| <= 1/9, roughly where the errors cross over, without needing yet | |
| // another constant (1/9 is formed as 1/3 * 1/3). Lanes passing the test take the Taylor value. | |
| float32x4_t ninth = vmulq_f32(third, third); | |
| uint32x4_t cmp_results = vcaleq_f32(x, ninth); | |
| return vbslq_f32(cmp_results, taylor, recp); | |
| } | |
| // Calculates tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) for each of the 4 lanes. | |
| // Input x is clamped to [-9, 9]. | |
| // See test program for bounds. Throughput of FAST is 334 Mega/sec, throughput of accurate is 232 Mega/sec. | |
| // NOTE(review): only the first return (float_tanh_float) is reachable as shown; the later sections (rational polynomial, fast_exp based, tanhf per lane) redeclare recp and look like build variants whose preprocessor guards are missing - confirm. | |
| inline float32x4_t fast_tanh(float32x4_t x) { | |
| return float_tanh_float(x); | |
| x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); | |
| x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); | |
| // The monomial coefficients of the numerator polynomial (odd). | |
| const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1); | |
| const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3); | |
| const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5); | |
| const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7); | |
| const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9); | |
| const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11); | |
| const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13); | |
| // The monomial coefficients of the denominator polynomial (even). | |
| const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0); | |
| const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2); | |
| const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4); | |
| const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6); | |
| // Since the polynomials are odd/even, we need x^2. | |
| const float32x4_t x2 = vmulq_f32(x, x); | |
| // Evaluate the numerator polynomial p by Horner's scheme in x^2, then multiply by x. | |
| float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13); | |
| p = vmlaq_f32(alpha_9, x2, p); | |
| p = vmlaq_f32(alpha_7, x2, p); | |
| p = vmlaq_f32(alpha_5, x2, p); | |
| p = vmlaq_f32(alpha_3, x2, p); | |
| p = vmlaq_f32(alpha_1, x2, p); | |
| p = vmulq_f32(x, p); | |
| // Evaluate the denominator polynomial q. | |
| float32x4_t q = vmlaq_f32(beta_4, x2, beta_6); | |
| q = vmlaq_f32(beta_2, x2, q); | |
| q = vmlaq_f32(beta_0, x2, q); | |
| // Divide the numerator by the denominator (reciprocal estimate plus one Newton-Raphson step). | |
| float32x4_t recp = vrecpeq_f32(q); | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); | |
| return vmulq_f32(p, recp); | |
| // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. | |
| x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); | |
| x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); | |
| float32x4_t exp_est = detail::fast_exp_norange_check(x); | |
| float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x); | |
| // If we're in the linear region (|x| <= kTanhLinearRegion). | |
| // caleq = compare absolute <= | |
| uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion)); | |
| float32x4_t diff = vsubq_f32(exp_est, neg_exp_est); | |
| float32x4_t sum = vaddq_f32(exp_est, neg_exp_est); | |
| float32x4_t recp = vrecpeq_f32(sum); | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, sum)); | |
| float32x4_t tanh_estimate = vmulq_f32(diff, recp); | |
| // Based on comparison, possibly copy x through instead of calculated value. | |
| // TODO(b/191497441): Is the compiler generating VBIT or VBSL ? VBIT is one | |
| // cycle and VBSL is two... documentation suggests it can do either. | |
| return vbslq_f32(cmp_results, x, tanh_estimate); | |
| float32x4_t return_val = vdupq_n_f32(0.f); | |
| float tanh_value = tanhf(vgetq_lane_f32(x, 0)); | |
| return_val = vld1q_lane_f32(&tanh_value, return_val, 0); | |
| tanh_value = tanhf(vgetq_lane_f32(x, 1)); | |
| return_val = vld1q_lane_f32(&tanh_value, return_val, 1); | |
| tanh_value = tanhf(vgetq_lane_f32(x, 2)); | |
| return_val = vld1q_lane_f32(&tanh_value, return_val, 2); | |
| tanh_value = tanhf(vgetq_lane_f32(x, 3)); | |
| return_val = vld1q_lane_f32(&tanh_value, return_val, 3); | |
| return return_val; | |
| } | |
| // Sigmoid of each of the 4 lanes. The first path computes 0.5 + 0.5 * fast_tanh(0.5 * x). | |
| // Input x is effectively clamped to [-18, 18]: explicitly in the polynomial section, and via fast_tanh's [-9, 9] clamp of 0.5 * x in the first path. | |
| // See tests for error bounds. Using SIGMOID_AS_TANH with ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate. Using SIGMOID_AS_TANH with just FAST is slower, but more accurate. | |
| // NOTE(review): only the first return is reachable as shown; the later sections redeclare half and recp and look like build variants whose preprocessor guards are missing - confirm. | |
| // SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec | |
| // SIGMOID_AS_TANH, FAST is 290 Mega/sec | |
| // FAST is 340 Mega/sec | |
| inline float32x4_t fast_sigmoid(float32x4_t x) { | |
| float32x4_t half = vdupq_n_f32(0.5f); | |
| return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x))); | |
| x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput)); | |
| x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput)); | |
| // The monomial coefficients of the numerator polynomial (odd). | |
| const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1); | |
| const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3); | |
| const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5); | |
| const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7); | |
| const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9); | |
| // The monomial coefficients of the denominator polynomial (even). | |
| const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0); | |
| const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2); | |
| const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4); | |
| const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6); | |
| const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8); | |
| const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10); | |
| // Since the polynomials are odd/even, we need x^2. | |
| const float32x4_t x2 = vmulq_f32(x, x); | |
| // Evaluate the numerator polynomial p by Horner's scheme in x^2, then multiply by x. | |
| float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9); | |
| p = vmlaq_f32(alpha_5, x2, p); | |
| p = vmlaq_f32(alpha_3, x2, p); | |
| p = vmlaq_f32(alpha_1, x2, p); | |
| p = vmulq_f32(x, p); | |
| // Evaluate the denominator polynomial q. | |
| float32x4_t q = vmlaq_f32(beta_8, x2, beta_10); | |
| q = vmlaq_f32(beta_6, x2, q); | |
| q = vmlaq_f32(beta_4, x2, q); | |
| q = vmlaq_f32(beta_2, x2, q); | |
| q = vmlaq_f32(beta_0, x2, q); | |
| // Divide the numerator by the denominator (reciprocal estimate plus one Newton-Raphson step) and add 0.5. | |
| float32x4_t recp = vrecpeq_f32(q); | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); | |
| return vmlaq_f32(vdupq_n_f32(0.5f), p, recp); | |
| float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f)); | |
| float32x4_t recp = vrecpeq_f32(denom); | |
| // Newton-Raphson iteration, accuracy is important for audio quality. | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); | |
| float32x4_t half = vdupq_n_f32(0.5f); | |
| float32x4_t quarter = vdupq_n_f32(0.245f); | |
| float32x4_t linear_approx = vmlaq_f32(half, quarter, x); | |
| uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion)); | |
| return vbslq_f32(cmp_results, linear_approx, recp); | |
| float32x4_t return_val = vdupq_n_f32(0.f); | |
| float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0))); | |
| return_val = vld1q_lane_f32(&result, return_val, 0); | |
| result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1))); | |
| return_val = vld1q_lane_f32(&result, return_val, 1); | |
| result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2))); | |
| return_val = vld1q_lane_f32(&result, return_val, 2); | |
| result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3))); | |
| return_val = vld1q_lane_f32(&result, return_val, 3); | |
| return return_val; | |
| } | |
| // Scalar implementations, mainly useful for testing. This one broadcasts x to all lanes, runs the vector fast_exp and returns lane 0. | |
| inline float fast_exp(float x) { | |
| return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0); | |
| } | |
| template <int ExponentBits> | |
| inline float fast_exp(fixed32<ExponentBits> x) {  /* Scalar wrapper: broadcasts x.raw_val(), runs the vector fast_exp<ExponentBits> and returns lane 0. */ | |
| return vgetq_lane_f32(fast_exp<ExponentBits>(vdupq_n_s32(x.raw_val())), 0); | |
| } | |
| // Returns the exponential (e^x) of a fixed point number in floating point without ever | |
| // doing any conversions. Less accurate than the version that does conversions, | |
| // but still accurate to within 4% relative for x < 16. Scalar wrapper: returns lane 0 of the vector fast_exp_fixed. | |
| template <int ExponentBits> | |
| inline float fast_exp_fixed(fixed32<ExponentBits> x) { | |
| return vgetq_lane_f32(fast_exp_fixed<ExponentBits>(vdupq_n_s32(x.raw_val())), | |
| 0); | |
| } | |
| inline float fast_sigmoid(float x) {  /* Scalar wrapper: broadcasts x, runs the vector fast_sigmoid and returns lane 0. */ | |
| return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0); | |
| } | |
| inline float fast_tanh(float x) {  /* Scalar wrapper: broadcasts x, runs the vector fast_tanh and returns lane 0. */ | |
| return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0); | |
| } | |
| // Clips integer input to [-|kLimit|, |kLimit|], expressed in the fixed point scale of the input. | |
| // Input: register containing 4x fixed32 with |mantissa_bits|. | |
| // Output: register containing 4x fixed32 limited to | |
| // [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. | |
| template <int kLimit> | |
| inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) { | |
| // Apply the lower bound first, then negate the limit and apply the upper bound. | |
| int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits)); | |
| int32x4_t clipped_x = vmaxq_s32(x, clip_limit); | |
| clip_limit = vnegq_s32(clip_limit); | |
| return vminq_s32(clipped_x, clip_limit); | |
| } | |
| // Fixed32 sigmoid approximation, 1 / (1 + e^-x), via the polynomial refinement of the exponent | |
| // trick in fixed32_exp_float_preclipped. x is negated and clipped to +/-kMaxSigmoidInputInt before the exp. | |
| // Input: Register containing 4x fixed32 with |mantissa_bits|. | |
| // Output: Register containing 4x float results. | |
| inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits, | |
| const int32x4_t x) { | |
| int32x4_t input = vnegq_s32(x); | |
| float32x4_t y = | |
| vcvtq_f32_s32(ClipToBounds<kMaxSigmoidInputInt>(mantissa_bits, input)); | |
| y = fixed32_exp_float_preclipped(mantissa_bits, y); | |
| float32x4_t one = vdupq_n_f32(1.0f); | |
| // The raw reciprocal estimate is not accurate enough on its own, so it is refined below (no full division is used). | |
| float32x4_t denom = vaddq_f32(y, one); | |
| float32x4_t recp = vrecpeq_f32(denom); | |
| // Newton-Raphson iteration, accuracy is important for audio quality | |
| recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); | |
| return recp; | |
| } | |
| template <int ExponentBits> | |
| inline float32x4_t fast_sigmoid(int32x4_t x) { | |
| // Sigmoid of 4 fixed32<ExponentBits> values held in x. Computation will fail to produce the | |
| // right result if the input mantissa bits exceeds the number in a float, hence the static_assert. | |
| static_assert(kFloatMantissaBits >= fixed32<ExponentBits>::kMantissaBits, | |
| "Mantissa bits must be at most 23!"); | |
| return fixed32_sigmoid_float(fixed32<ExponentBits>::kMantissaBits, x);  /* NOTE(review): the return below is unreachable as shown; presumably an alternative variant whose preprocessor guard is missing - confirm. */ | |
| return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32<ExponentBits>::kMantissaBits)); | |
| } | |
| template <int ExponentBits> | |
| inline float fast_sigmoid(fixed32<ExponentBits> x) {  /* Scalar wrapper: broadcasts x.raw_val(), runs the vector fast_sigmoid<ExponentBits> and returns lane 0. */ | |
| return vgetq_lane_f32(fast_sigmoid<ExponentBits>(vdupq_n_s32(x.raw_val())), | |
| 0); | |
| } | |
| inline float fast_exp(float x) {  /* Scalar exp approximation via the exponent trick. NOTE(review): fast_exp(float) is also defined earlier in this file; presumably the two belong to different preprocessor branches - confirm. */ | |
| if (isnan(x)) return 0.0f;  /* NaN input yields 0. */ | |
| x = std::max(std::min(x, kMaxExpInput), kMinExpInput);  /* Clamp to [kMinExpInput, kMaxExpInput]. */ | |
| float AConstant, BConstant; | |
| memcpy(&AConstant, &kAConstant, sizeof(int));  /* Reinterpret the stored bit patterns as floats. */ | |
| memcpy(&BConstant, &kBConstant, sizeof(int)); | |
| float y = x * AConstant + BConstant;  /* Shifts x into the exponent field and adds the bias. */ | |
| int x_int = static_cast<int>(y);  /* Truncating conversion to int. */ | |
| float ret; | |
| memcpy(&ret, &x_int, sizeof(float));  /* Reinterpret the integer bits as the float result. */ | |
| return ret;  /* NOTE(review): the expf return below is unreachable as shown; presumably a fallback variant - confirm. */ | |
| return expf(x); | |
| } | |
| template <int ExponentBits> | |
| inline float fast_exp(fixed32<ExponentBits> x) {  /* Converts the fixed point value to float and forwards to fast_exp(float). */ | |
| return fast_exp(static_cast<float>(x)); | |
| } | |
| template <int ExponentBits> | |
| inline float fast_exp_fixed(fixed32<ExponentBits> x) { | |
| static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); | |
| int matched_decimal = | |
| std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))), | |
| -(80 << (31 - ExponentBits))); | |
| // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit | |
| // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount to | |
| // line the decimal point up with the 32-bit float representation. | |
| // (MantissaBits of x) + (MantissaBits of constant) = 23 | |
| // 23 - (MantissaBits of x) = MantissaBits of constant | |
| // 23 - (31 - ExponentBits of x) = ... | |
| // (ExponentBits of x - 8) = MantissaBits of constant | |
| const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8)); | |
| // Same rationale as for floating point versions, bias exponent, subtract | |
| // 366000 to reduce error by centering approximation, instead of being | |
| // one-sided. | |
| const int B = (127 << 23) - 366000; | |
| matched_decimal = A * matched_decimal + B; | |
| float ret_val; | |
| memcpy(&ret_val, &matched_decimal, sizeof(float)); | |
| return ret_val; | |
| } | |
| inline float fast_tanh(float x) { | |
| // Doesn't do anything fancy, just a 13/6-degree rational interpolant which is accurate up to a | |
| // couple of ulp in the range [-9, 9], outside of which fl(tanh(x)) = +/-1. x is clamped to that range first. | |
| // NOTE(review): the code after `return p / q;` is unreachable as shown; presumably alternative variants whose preprocessor guards are missing - confirm. | |
| x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); | |
| // Since the polynomials are odd/even, we need x^2. | |
| float x2 = x * x; | |
| // Evaluate numerator p by Horner's scheme in x^2, then multiply by x. | |
| float p = kTanhAlpha11 + x2 * kTanhAlpha13; | |
| p = kTanhAlpha9 + x2 * p; | |
| p = kTanhAlpha7 + x2 * p; | |
| p = kTanhAlpha5 + x2 * p; | |
| p = kTanhAlpha3 + x2 * p; | |
| p = kTanhAlpha1 + x2 * p; | |
| p = x * p; | |
| // Evaluate denominator q by Horner's scheme in x^2. | |
| float q = kTanhBeta4 + x2 * kTanhBeta6; | |
| q = kTanhBeta2 + x2 * q; | |
| q = kTanhBeta0 + x2 * q; | |
| return p / q; | |
| if (std::abs(x) < kTanhLinearRegion) { | |
| return x; | |
| } else { | |
| x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); | |
| float positive = fast_exp(x); | |
| float negative = fast_exp(-x); | |
| return (positive - negative) / (positive + negative); | |
| } | |
| return tanhf(x); | |
| } | |
| inline float fast_sigmoid(float x) { | |
| return .5f * fast_tanh(.5f * x) + .5f;  /* Sigmoid via tanh. NOTE(review): everything below is unreachable as shown; presumably alternative variants whose preprocessor guards are missing - confirm. */ | |
| // Doesn't do anything fancy, just a 9/10-degree rational interpolant which | |
| // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range | |
| // [-18, 18], outside of which the fl(sigmoid(x)) = {0|1}. The shifted | |
| // sigmoid is interpolated because it was easier to make the fit converge. | |
| // See GenericPacketMath.h* in the open source Eigen library. | |
| x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput); | |
| // Since the polynomials are odd/even, we need x^2. | |
| float x2 = x * x; | |
| // Evaluate numerator p by Horner's scheme in x^2, then multiply by x. | |
| float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9; | |
| p = kSigmoidAlpha5 + x2 * p; | |
| p = kSigmoidAlpha3 + x2 * p; | |
| p = kSigmoidAlpha1 + x2 * p; | |
| p = x * p; | |
| // Evaluate denominator q by Horner's scheme in x^2. | |
| float q = kSigmoidBeta8 + x2 * kSigmoidBeta10; | |
| q = kSigmoidBeta6 + x2 * q; | |
| q = kSigmoidBeta4 + x2 * q; | |
| q = kSigmoidBeta2 + x2 * q; | |
| q = kSigmoidBeta0 + x2 * q; | |
| return p / q + 0.5f; | |
| if (std::abs(x) < kSigmoidLinearRegion) { | |
| return .245 * x + .5; | |
| } else { | |
| return 1.f / (1.f + fast_exp(-x)); | |
| } | |
| return 1.f / (1.f + expf(-x)); | |
| } | |
| template <int ExponentBits> | |
| inline float fast_sigmoid(fixed32<ExponentBits> x) {  /* Converts the fixed point value to float and forwards to fast_sigmoid(float). */ | |
| return fast_sigmoid(static_cast<float>(x)); | |
| } | |
| // Number of exponent bits to use for tanh (presumably the fixed32 format of the table input - confirm). | |
| static constexpr int kNumTanhExpBits = 3; | |
| // Number of exponent bits to use for sigmoid (presumably the fixed32 format of the table input - confirm). | |
| static constexpr int kNumSigmoidExpBits = 4; | |
| // Number of extra bits to shift sigmoid, due to its low gradient. | |
| static constexpr int kNumExtraSigmoidShiftBits = 1; | |
// Lookup table implementing tanh on fixed32 input. The table is static, is
// built on first request, and is never deleted (as per the style guide). Its
// entries are fixed32 values carrying |num_mantissa_bits_out| mantissa bits,
// which is assumed to be fewer than the input mantissa bits.
// NOTE: intended only for fixed16 outputs that are sign-extended to 32 bits
// for convenience. Asking for more than |kMaxMantissaBits| of output precision
// returns nullptr.
const int* TanhTable(int num_mantissa_bits_out);
// Same contract as TanhTable, but the table implements sigmoid.
const int* SigmoidTable(int num_mantissa_bits_out);
| // Scalar/generic function to compute and return the fast approximation to exp | |
| // via a polynomial refinement of the floating point exponent trick. | |
| // TM_ORDER4_16BIT:Max relative error < 5e-6, absolute error < 1e-5 for x < 1. | |
| // TM_ORDER3_16BIT:Max relative error < 1.1e-4, absolute error < 3e-4 for x | |
| // < 1. | |
| template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| float fixed32_exp(fixed32<kExponentBits> x) { | |
| constexpr int kMantissaBits = MantissaBitsOf<fixed32<kExponentBits>>::value; | |
| // Clip x to min/max exp input to avoid infinities. | |
| int64_t clipped_x = | |
| std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits), | |
| -(kMaxExpInputInt << kMantissaBits)); | |
| // First convert problem from e^x to 2^x by multiplying by 1/log(2). | |
| // To maximize precision, log_factor is shifted left the maximum amount to | |
| // keep within int32, and we shift x left a further amount such that the | |
| // binary point of the product sits in the correct place in the top 32 bits of | |
| // the result to be used directly as a float. We can't do that directly, as x | |
| // would overflow, so we have to shift by 1 bit less and shift the result by | |
| // 1 bit less to match. | |
| constexpr int kXShift = | |
| kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits; | |
| static_assert(kXShift >= 0, | |
| "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift"); | |
| clipped_x <<= kXShift; | |
| int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset; | |
| // Separate the resulting fixed-point into integer and fractional parts. | |
| int int_part = float_as_int & kFloatExponentMask; | |
| int float_part = float_as_int & kFloatMantissaMask; | |
| float fraction = static_cast<float>(float_part) / (1 << kFloatMantissaBits); | |
| // Compute the mantissa = 2^fraction using: | |
| // fraction - fraction*(1-fraction)*(polynomial of fraction) | |
| // This guarantees exactness at 0 and 1, providing continuity of the error at | |
| // integer boundaries. | |
| float mantissa; | |
| if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) { | |
| mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction + | |
| kExpQuarticFactor0; | |
| } else if (kOrder == TM_ORDER3_16BIT) { | |
| mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0; | |
| } | |
| mantissa = fraction - fraction * (1.0f - fraction) * mantissa; | |
| // Since the function above guarantees to stay within [0, 1), we could do all | |
| // the above in fixed point if necessary, in which case, we can just stuff | |
| // the bottom kFloatMantissaBits in with the exponent and we are done. | |
| // In the floating point world, it is simpler to just multiply them together. | |
| float result; | |
| memcpy(&result, &int_part, sizeof(float)); | |
| return result * (1.0f + mantissa); | |
| } | |
| // Computes and returns tanh(x) fixed32->float using a polynomial refinement of | |
| // the floating point exponent trick. | |
| // kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01. | |
| // kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01 | |
| template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| float fixed32_tanh(fixed32<kExponentBits> x) { | |
| float float_x = static_cast<float>(x); | |
| if (std::abs(float_x) < 1.0f / 9.0f) { | |
| return float_x * (1 - float_x * float_x / 3.0f); | |
| } | |
| x = static_cast<fixed32<kExponentBits>>(x.raw_val() * 2); | |
| float exp_2x = fixed32_exp<kExponentBits, kOrder>(x); | |
| return (exp_2x - 1.0f) / (exp_2x + 1.0f); | |
| } | |
| // Computes and returns sigmoid(x) fixed32->float using a polynomial refinement | |
| // of the floating point exponent trick. | |
| // TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6. | |
| // TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4. | |
| template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| float fixed32_sigmoid(fixed32<kExponentBits> x) { | |
| x = static_cast<fixed32<kExponentBits>>(-x.raw_val()); | |
| float exp_x = fixed32_exp<kExponentBits, kOrder>(x); | |
| return 1.0f / (exp_x + 1.0f); | |
| } | |
| // Inline function to access an int32 data table by shifting |x| right by | |
| // |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8 | |
| // indices and 8 results are returned. The data table is of size | |
| // |kTableOffset| * 2 + 1. | |
| template <int kNumShiftBits, int kTableOffset> | |
| inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) { | |
| // Shift right with rounding to match input and output precision. | |
| __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1)); | |
| shifted = _mm256_add_epi32(x, shifted); | |
| shifted = _mm256_srai_epi32(shifted, kNumShiftBits); | |
| // Add the offset. | |
| __m256i addend = _mm256_set1_epi32(kTableOffset); | |
| shifted = _mm256_add_epi32(shifted, addend); | |
| // And clamp to the indices of the LUT. | |
| addend = _mm256_add_epi32(addend, addend); | |
| shifted = _mm256_min_epi32(shifted, addend); | |
| shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256()); | |
| // Lookup the results in the table. | |
| return _mm256_i32gather_epi32(data_table, shifted, 4); | |
| } | |
| // Fixed32 to fixed16-in-an-int32 tanh LUT function. | |
| // Input: register containins 8x fixed32 with |NumInputMantissaBits|. | |
| // Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but | |
| // note that they are sign-extended to 32 bits and are therefore basically the | |
| // same as fixed32 with |NumOutputMantissaBits|. | |
| template <int NumInputMantissaBits, int NumOutputMantissaBits> | |
| inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) { | |
| // Lose the unnecessary input precision. | |
| constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits; | |
| constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits); | |
| return index_data_table<kNumShiftBits, kTableOffset>(tanh_table, x); | |
| } | |
| // Fixed32 to fixed16-in-an-int32 sigmoid LUT function. | |
| // Input: register containins 8x fixed32 with |NumInputMantissaBits|. | |
| // Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but | |
| // note that they are sign-extended to 32 bits and are therefore basically the | |
| // same as fixed32 with |NumOutputMantissaBits|. | |
| template <int NumInputMantissaBits, int NumOutputMantissaBits> | |
| inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table, | |
| const __m256i& x) { | |
| // Lose the unnecessary input precision. | |
| constexpr int kNumShiftBits = | |
| kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits; | |
| constexpr int kTableOffset = 1 | |
| << (NumOutputMantissaBits + kNumSigmoidExpBits - | |
| kNumExtraSigmoidShiftBits); | |
| return index_data_table<kNumShiftBits, kTableOffset>(sigmoid_table, x); | |
| } | |
| // Convert 2x registers of 8x float32 into 1 register of 16x16 bit fixed int, | |
| // assuming that the floats are already scaled up. | |
| inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) { | |
| __m256i int0 = _mm256_cvtps_epi32(x0); | |
| __m256i int1 = _mm256_cvtps_epi32(x1); | |
| int0 = _mm256_packs_epi32(int0, int1); | |
| // Swap the middle 64 bit elements so the results are in the right order. | |
| return _mm256_permute4x64_epi64(int0, 0xd8); | |
| } | |
| // Clips integer input to [-|kLimit|, |kLimit|]. | |
| // Input: register containins 8x fixed32 with |mantissa_bits|. | |
| // Output: register containing 8x fixed32 limited to | |
| // [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. | |
| template <int kLimit> | |
| inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) { | |
| // Clip to the input bounds for this approximation. | |
| __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits)); | |
| __m256i clipped_x = _mm256_max_epi32(x, clip_limit); | |
| // This quickly negates the limit without having to load another constant. | |
| clip_limit = _mm256_sign_epi32(clip_limit, clip_limit); | |
| return _mm256_min_epi32(clipped_x, clip_limit); | |
| } | |
| // Clips float input to [-|kLimit|, |kLimit|]. | |
| // Input: register containins 8x float. | |
| // Output: register containing 8x float limited to [-|kLimit|, |kLimit|]. | |
| inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) { | |
| __m256 clip_limit = _mm256_set1_ps(kLimit); | |
| __m256 clipped_x = _mm256_min_ps(x, clip_limit); | |
| clip_limit = _mm256_set1_ps(-kLimit); | |
| return _mm256_max_ps(clipped_x, clip_limit); | |
| } | |
| // Float to float power of 2 approximation, using a quartic refinement of | |
| // the exponent trick. For TM_ORDER4_16BIT and TM_ORDER3_16BIT, implementation | |
| // is entirely in integer, using 16x16=16 multiplication, using AVX2, which | |
| // enables 16 elements to be computed in parallel, hence the double register | |
| // input/output args. | |
| // The price paid for this speed is an increase in error over the (scalar) int32 | |
| // example implementations above by a variable factor of 4-10. | |
| // For the TM_ORDER4_FLOAT case, the computation is all done in float, solving | |
| // this lower precision problem. | |
| // NOTE: The input must have already been clipped to prevent overflow, which | |
| // sets the practical limit to +/-126 << kFloatMantissaBits. | |
| // NOTE: The input is a scaled float, as if converted raw from int, and the | |
| // scale factor is fixed at kFloatMantissaBits! | |
| // Input: 2x register containining 8x float * 1 << kFloatMantissaBits. | |
| // Output: 2x register containing 8x float. | |
| // TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. | |
| // TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. | |
| // TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. | |
| template <TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| inline void float32_pow2(__m256& x0, __m256& x1) { | |
| // Convert straight to int. | |
| __m256i exp_int_x0 = _mm256_cvtps_epi32(x0); | |
| __m256i exp_int_x1 = _mm256_cvtps_epi32(x1); | |
| __m256i result_x0, result_x1; | |
| static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT || | |
| kOrder == TM_ORDER3_16BIT, | |
| "Invalid order."); | |
| if (kOrder == TM_ORDER4_FLOAT) { | |
| __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff); | |
| __m256 float_factor = | |
| _mm256_set1_ps(1.0f / static_cast<float>(1 << kFloatMantissaBits)); | |
| __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0); | |
| __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1); | |
| __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor); | |
| __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor); | |
| // Compute the polynomial of the fractional part. | |
| // Ordering these lines carefully makes it faster, as some of the multiply | |
| // operations can pipeline instead of waiting for the previous result. | |
| __m256 x_squared0 = _mm256_mul_ps(float0, float0); | |
| __m256 x_squared1 = _mm256_mul_ps(float1, float1); | |
| __m256 b = _mm256_set1_ps(kExpQuarticFactor1); | |
| __m256 b_x0 = _mm256_mul_ps(b, float0); | |
| __m256 b_x1 = _mm256_mul_ps(b, float1); | |
| __m256 a = _mm256_set1_ps(kExpQuarticFactor2); | |
| __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0); | |
| __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1); | |
| __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0); | |
| __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1); | |
| __m256 c = _mm256_set1_ps(kExpQuarticFactor0); | |
| b_x0 = _mm256_add_ps(b_x0, c); | |
| b_x1 = _mm256_add_ps(b_x1, c); | |
| float_factor = _mm256_set1_ps(static_cast<float>(1 << kFloatMantissaBits)); | |
| a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0); | |
| a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1); | |
| a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0); | |
| a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1); | |
| result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor)); | |
| result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor)); | |
| } else { | |
| // Combine the fractional part of both inputs into a single register. | |
| // The representation is fixed16<0>, ie 15 mantissa bits. | |
| __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00); | |
| __m256i x_01 = | |
| _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8); | |
| x_01 = _mm256_or_si256( | |
| x_01, | |
| _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8)); | |
| // Compute the polynomial of the fractional part. | |
| // Ordering these lines carefully makes it faster, as some of the multiply | |
| // operations can pipeline instead of waiting for the previous result. | |
| __m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01); | |
| __m256i result, x_squared_minus_x; | |
| if (kOrder == TM_ORDER4_16BIT) { | |
| __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1)); | |
| __m256i b_x = _mm256_mulhrs_epi16(b, x_01); | |
| __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2)); | |
| __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared); | |
| x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); | |
| // LOG(INFO) << "x_squared_minus_x=" << | |
| // static_cast<int16>(_mm256_extract_epi16(x_squared_minus_x, 0)) / | |
| // 32768.0f; | |
| __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); | |
| b_x = _mm256_add_epi16(b_x, c); | |
| // LOG(INFO) << "bx+c=" << static_cast<int16>(_mm256_extract_epi16(b_x, | |
| // 0)) / 32768.0f; | |
| result = _mm256_add_epi16(a_x_squared, b_x); | |
| } else { // kOrder = TM_ORDER3_16BIT | |
| __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1)); | |
| __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); | |
| __m256i a_x = _mm256_mulhrs_epi16(a, x_01); | |
| x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); | |
| result = _mm256_add_epi16(a_x, b); | |
| } | |
| result = _mm256_mulhrs_epi16(result, x_squared_minus_x); | |
| // Extract 16x16-bit results back to the separate sets of 8x32. | |
| result_x0 = _mm256_slli_epi32(result, 16); | |
| result_x0 = _mm256_srai_epi32(result_x0, 8); | |
| result_x1 = _mm256_srai_epi32(result, 16); | |
| result_x1 = _mm256_slli_epi32(result_x1, 8); | |
| } | |
| // Add the constant to normalize the exponent. | |
| __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset); | |
| exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset); | |
| exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0); | |
| exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset); | |
| exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1); | |
| // Cast back to float, as we just computed the exponent and mantissa and | |
| // assembled them in IEEE format. | |
| x0 = _mm256_castsi256_ps(exp_int_x0); | |
| x1 = _mm256_castsi256_ps(exp_int_x1); | |
| } | |
| // Fixed32 to to float exp approximation, using a quartic/cubic refinement of | |
| // the exponent trick. Implementation is entirely in integer, using 16x16=16 | |
| // multiplication, using AVX2, which enables 16 elements to be computed in | |
| // parallel, hence the double register input/output args. | |
| // The price paid for this speed is an increase in error over the (scalar) int32 | |
| // example implementations above by a variable factor of 4-10. | |
| // The TM_ORDER4_FLOAT version uses floats and improves the precision. | |
| // Input: 2x registers containins 8x fixed32 with kMantissaBits. | |
| // Output: 2x registers containing 8x float32. | |
| // TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. | |
| // TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. | |
| // TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. | |
| template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| inline void float_exp_float_preclipped(__m256& y0, __m256& y1) { | |
| // Divide by log 2 to convert problem to 2^x, and scale to match the | |
| // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it is | |
| // much easier to do this in float, even with the double conversion, as 16 bit | |
| // is not precise enough here. | |
| // This is the shift of the FP mantissa relative to the input mantissa. | |
| constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits; | |
| constexpr float kLogFactor = static_cast<float>(1 << kXShift); | |
| __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); | |
| y0 = _mm256_mul_ps(y0, factor); | |
| y1 = _mm256_mul_ps(y1, factor); | |
| // Now compute 2^x. | |
| float32_pow2<kOrder>(y0, y1); | |
| } | |
| template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT> | |
| inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0, | |
| __m256& y1) { | |
| // Clip to acceptable bounds to prevent overflow, and convert to float. | |
| y0 = | |
| _mm256_cvtepi32_ps(ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x0)); | |
| y1 = | |
| _mm256_cvtepi32_ps(ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x1)); | |
| float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1); | |
| } | |
| // Float->float tanh approximation via the exponent trick. | |
| // Note that the input is scaled floats, as if converted raw from fixed16/32. | |
| // Input: 2x registers containing 8x float scaled by input_mantissa_bits. | |
| // Output: two registers containing 8x float. | |
| // TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6. | |
| // TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5. | |
| // TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4. | |
| template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT> | |
| inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0, | |
| __m256& y1) { | |
| // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and | |
| // scale to the mantissa bits required by float32_pow2 all in one multiply. | |
| // This is the shift of the FP mantissa relative to the input mantissa. | |
| // Add one to double the input. | |
| const float kLogFactor = | |
| static_cast<float>(1 << (kFloatMantissaBits - kInputMantissaBits + 1)); | |
| __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); | |
| // Clip to suitable input bounds for tanh. | |
| __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits)); | |
| __m256 clip0 = _mm256_min_ps(x0, clip_limit); | |
| __m256 clip1 = _mm256_min_ps(x1, clip_limit); | |
| clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits)); | |
| clip0 = _mm256_max_ps(clip0, clip_limit); | |
| clip1 = _mm256_max_ps(clip1, clip_limit); | |
| __m256 exp0 = _mm256_mul_ps(clip0, factor); | |
| __m256 exp1 = _mm256_mul_ps(clip1, factor); | |
| // Now compute 2^x. | |
| float32_pow2<kOrder>(exp0, exp1); | |
| // Now compute tanh using (e^2x - 1) / (e^2x + 1). | |
| __m256 one = _mm256_set1_ps(1.0f); | |
| __m256 numerator = _mm256_sub_ps(exp0, one); | |
| __m256 denominator = _mm256_add_ps(exp0, one); | |
| // Approximate reciprocal is not accurate enough - use full division. | |
| exp0 = _mm256_div_ps(numerator, denominator); | |
| numerator = _mm256_sub_ps(exp1, one); | |
| denominator = _mm256_add_ps(exp1, one); | |
| exp1 = _mm256_div_ps(numerator, denominator); | |
| // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low | |
| // relative error close to 0. | |
| // Normalize the inputs back to proper floats. | |
| factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits)); | |
| clip0 = _mm256_mul_ps(clip0, factor); | |
| clip1 = _mm256_mul_ps(clip1, factor); | |
| __m256 third = _mm256_set1_ps(-1.0f / 3.0f); | |
| __m256 taylor0 = _mm256_mul_ps(clip0, clip0); | |
| __m256 taylor1 = _mm256_mul_ps(clip1, clip1); | |
| taylor0 = _mm256_mul_ps(taylor0, clip0); | |
| taylor1 = _mm256_mul_ps(taylor1, clip1); | |
| // TODO(b/191497441): The next two pairs of instructions could be combined to | |
| // _mm256_fmadd_ps, but requires -mfma compilation option, eg: | |
| // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0); | |
| taylor0 = _mm256_mul_ps(taylor0, third); | |
| taylor1 = _mm256_mul_ps(taylor1, third); | |
| taylor0 = _mm256_add_ps(clip0, taylor0); | |
| taylor1 = _mm256_add_ps(clip1, taylor1); | |
| // Test |x| <= 1/9, roughly where the errors cross over, without needing yet | |
| // another constant. | |
| third = _mm256_mul_ps(third, third); | |
| __m256 neg_zero = _mm256_set1_ps(-0.0f); | |
| clip0 = _mm256_andnot_ps(neg_zero, clip0); | |
| clip1 = _mm256_andnot_ps(neg_zero, clip1); | |
| __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ); | |
| __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ); | |
| y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0); | |
| y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1); | |
| } | |
| // Fixed32 sigmoid approximation via the AVX2 implementation of the exponent | |
| // trick. | |
| // Input: 2x registers containins 8x float containing converted fixed32 scaled | |
| // with kInputMantissaBits. | |
| // Output: 2x registers containing 8x float. | |
| // TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6. | |
| // TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6. | |
| // TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4. | |
| template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT> | |
| inline void float_sigmoid_float(__m256& y0, __m256& y1) { | |
| constexpr float kInputFactor = static_cast<float>(1 << kInputMantissaBits); | |
| // Negate the inputs. | |
| __m256 minus_zero = _mm256_set1_ps(-0.0f); | |
| y0 = _mm256_xor_ps(y0, minus_zero); | |
| y1 = _mm256_xor_ps(y1, minus_zero); | |
| y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0); | |
| y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1); | |
| float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1); | |
| __m256 one = _mm256_set1_ps(1.0f); | |
| // Approximate reciprocal is not accurate enough - use full division. | |
| y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one)); | |
| y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one)); | |
| } | |
| } // namespace csrblocksparse | |