/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <immintrin.h>

#include <cstdint>
#include <type_traits>

// TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
#include "sparse_matmul/numerics/fast_transcendentals.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"

namespace csrblocksparse {
namespace detail {

template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFloatTypes
    : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
                                       std::is_same<RhsType, float>::value &&
                                       std::is_same<OutType, float>::value> {};
#if defined __AVX2__

// 16-bit inputs, 32-bit output exponent matches sum of input exponents
// OR
// 16-bit inputs, 16-bit output - will shift to match exponent
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFixedTypes
    : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
                                    IsFixed16Type<RhsType>::value) &&
                                       (IsFixed32Type<OutType>::value ||
                                        IsFixed16Type<OutType>::value)> {};

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool,
          !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
              !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct IsAddableFixedTypes
    : std::integral_constant<bool, IsFixed32Type<Type>::value ||
                                       IsFixed16Type<Type>::value> {};
template <typename Type>
struct ShouldEnableGenericAdd
    : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};

#else  // No AVX2.

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct ShouldEnableGenericAdd : std::true_type {};
#endif  // __AVX2__
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
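
// For example, with AVX2 available, the all-float combination and the
// fixed16 x fixed16 -> fixed16/fixed32 combinations are handled by the
// hand-written kernels below, so their generic fallbacks are disabled; any
// other type combination falls back to the generic implementations. The
// SpMV_1x1 and SpMM5_1x1 kernels always use the generic code.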
// The computational routines do NO error checking for speed. It is assumed
// that this has been handled by CSRBlockSparseMatrix.

// In-line function to extract results from a pair of registers and store in
// memory. Note that the non-const references are registers, and are modified
// by this function!
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
                            float** out_ptr) {
  // Horizontally add the results. We have 2 registers, |sum1| and |sum2|, that
  // each contain 2 sets of 4 values that need to be added.
  sum1 = _mm256_hadd_ps(sum1, sum2);
  sum1 = _mm256_hadd_ps(sum1, sum1);
  // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
  // |res1|, |res3|]
  if (relu) {
    sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
  }
  // It is really hard in AVX to cross the 128 bit 'lanes' and this is the
  // *only* way to do it.
  // Get the top half of |sum1| into the bottom of |sum2|.
  sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
  // Interleave the values between the two registers.
  sum1 = _mm256_unpacklo_ps(sum1, sum2);
  // Save the lower 128 bits (4 floats).
  __m128 result = _mm256_extractf128_ps(sum1, 0);
  _mm_store_ps(*out_ptr, result);
  *out_ptr += 4;
}
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is a vector. Weights are stored for
// this routine by making each 4x4 block contiguous. Blocks are ordered in
// standard row-major format. Column indices are converted to deltas and then
// multiplied by sizeof(float) to convert to bytes, so that the value can be
// used directly to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
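//
// Illustrative sketch (hypothetical values) of the arrays this kernel walks
// for one block-row containing two 4x4 blocks whose rhs columns start at
// elements 0 and 8:
//   *nnz_per_row     == 2                       (4x4 blocks in the row)
//   col_deltas_bytes -> {0, 8 * sizeof(float)}  (byte deltas between the
//                                                blocks' rhs columns)
//   weights_ptr      -> 2 * 16 contiguous floats, one block after another
//   bias_ptr         -> the row's 4 biases, pre-multiplied by .25f.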
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Broadcast each bias to 4 slots; the horizontal adds at the end then sum
    // it 4 times, undoing the division by 4 in the input biases.
    __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
    }
    Extract4Results(relu, sum1, sum2, &out_ptr);
  }
}
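
// For reference, a scalar sketch of what one iteration of the outer loop
// above computes (ignoring the .25f bias convention):
//   for (int i = 0; i < 4; ++i) {
//     float acc = bias[i];
//     for each 4x4 block b in this block-row:
//       for (int j = 0; j < 4; ++j) acc += block_b[i][j] * rhs[col_b + j];
//     out[i] = relu ? std::max(acc, 0.f) : acc;
//   }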
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is a vector that is
// broadcast across all 5 columns. Weights are stored for this routine by
// making each 4x4 block contiguous. Blocks are ordered in standard row-major
// format. Column indices are converted to deltas and then multiplied by
// sizeof(float) to convert to bytes, so that the value can be used directly to
// offset the pointer into the rhs vectors.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
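//
// Processing 5 rhs columns in one pass lets each pair of weight loads
// (|weights1|, |weights2| below) be reused for all 5 accumulator pairs, so the
// weight bandwidth is amortized across the columns.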
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will accumulate the results in 10 registers, |sum1_0| to |sum2_4|.
    // Broadcast each bias to 4 slots; the horizontal adds at the end then sum
    // it 4 times, undoing the division by 4 in the input biases.
    __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum1_1 = sum1_0;
    __m256 sum2_1 = sum2_0;
    __m256 sum1_2 = sum1_0;
    __m256 sum2_2 = sum2_0;
    __m256 sum1_3 = sum1_0;
    __m256 sum2_3 = sum2_0;
    __m256 sum1_4 = sum1_0;
    __m256 sum2_4 = sum2_0;
    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
      sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
      sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
      sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
      sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
      sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
      sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
      sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
      sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
    }
    Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
    Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
    Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
    Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
    Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
  }
}
// In-line function to finish the computation of the result as 4x int32 in
// |sum|.
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
  // Horizontally add the results. We have 1 register that contains results
  // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
  // cross lanes, so we end up with [0 1 0 1 2 3 2 3].
  sum = _mm256_hadd_epi32(sum, sum);
  // Permutes the middle two pairs to get the answers together.
  sum = _mm256_permute4x64_epi64(sum, 0xd8);
  if (kShiftAmount > 0) {
    // Shift right with rounding to get the right number of mantissa bits.
    __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, kShiftAmount);
  }
  // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
  // |res2|, |res3|].
  if (relu) {
    sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
  }
}
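
// Worked example of the rounding shift above: with kShiftAmount == 5, a lane
// holding 177 becomes (177 + 16) >> 5 == 6, whereas plain truncation would
// give 177 >> 5 == 5.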
// In-line function to extract the 4x int32 results from |sum| to memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
                           int32_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Save the lower 128 bits (4x int32).
  __m128i result = _mm256_extractf128_si256(sum, 0);
  _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
  *out_ptr += 4;
}

// In-line function to extract the 4x int32 results from |sum| to 4x int16 in
// memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
                           int16_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
  // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
  // bits, replicated in the next 64 bits.
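  // For example, an int32 lane holding 40000 saturates to 32767 and -40000
  // saturates to -32768.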
  sum = _mm256_packs_epi32(sum, sum);
  // Save 4x int16 from the bottom 64 bits.
  *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
  *out_ptr += 4;
}
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is a vector. Weights are stored for
// this routine by making each 4x4 block contiguous. Blocks are ordered in
// standard row-major format. Column indices are converted to deltas and then
// multiplied by 2 to convert to bytes, so that the value can be used directly
// to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1].
    // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
    // register outside the row loop.
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    // TODO(b/188702959): consider moving this to where the biases are
    // computed.
    __m256i sum = _mm256_add_epi32(biases, biases);
    // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
    // metal notation of *|nnz_per_row|++, so measure the effect of putting the
    // increment in the for loop.
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of |rhs_64|.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16 lanes of 16x16=32 bit multiplies and
      // horizontally adds adjacent pairs to make 8x32 bit results. Add these
      // to the sum.
      sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum,
                     reinterpret_cast<int32_t**>(&out_ptr));
    } else {
      Extract4xint16(relu, kShiftAmount, sum,
                     reinterpret_cast<int16_t**>(&out_ptr));
    }
  }
}
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is a vector that is
// broadcast across all 5 columns. Weights are stored for this routine by
// making each 4x4 block contiguous. Blocks are ordered in standard row-major
// format. Column indices are converted to deltas and then multiplied by 2 to
// convert to bytes, so that the value can be used directly to offset the
// pointer into the rhs vectors.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will accumulate the results in 5 registers, |sum_0| to |sum_4|.
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1].
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    __m256i sum_0 = _mm256_add_epi32(biases, biases);
    __m256i sum_1 = sum_0;
    __m256i sum_2 = sum_0;
    __m256i sum_3 = sum_0;
    __m256i sum_4 = sum_0;
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of |rhs_64|.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16 lanes of 16x16=32 bit multiplies and
      // horizontally adds adjacent pairs to make 8x32 bit results. Add these
      // to the sum.
      sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int32_t**>(&out_ptrs[0]));
      Extract4xint32(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int32_t**>(&out_ptrs[1]));
      Extract4xint32(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int32_t**>(&out_ptrs[2]));
      Extract4xint32(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int32_t**>(&out_ptrs[3]));
      Extract4xint32(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int32_t**>(&out_ptrs[4]));
    } else {
      Extract4xint16(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int16_t**>(&out_ptrs[0]));
      Extract4xint16(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int16_t**>(&out_ptrs[1]));
      Extract4xint16(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int16_t**>(&out_ptrs[2]));
      Extract4xint16(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int16_t**>(&out_ptrs[3]));
      Extract4xint16(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int16_t**>(&out_ptrs[4]));
    }
  }
}

// Processes one GRU gate input with sigmoid.
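// In GRU terms this evaluates a gate activation sigmoid(a + b), where |input|
// carries one fixed32 contribution (e.g. from the input matmul) and |gate_ptr|
// (plus |gate_other_ptr| when SplitGates is true) carries the other; the
// result comes from the lookup-table based |fixed32_sigmoid_fixed16|.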
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
                              const __m256i& input,
                              const int32_t* sigmoid_table) {
  __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
  if (SplitGates) {
    __m256i other =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
    gate = _mm256_add_epi32(gate, other);
  }
  gate = _mm256_add_epi32(gate, input);
  // Compute sigmoids on reset and update.
  return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
                                                 StateMantissaBits>(
      sigmoid_table, gate);
}

// Processes the tanh and the final combination, returning the new GRU state.
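// The combination computed below is the standard GRU state update
//   h_new = update * h + (1 - update) * hbar,
// evaluated in fixed point as ((h - hbar) * update) >> StateMantissaBits +
// hbar, where hbar = tanh(cell + reset * (recurrent gate contribution)).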
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
                            const __m256i& update,
                            const __m256i& rounding_offset,
                            const void* gate_ptr, const void* gate_other_ptr,
                            const void* gru_h_ptr, const int32_t* tanh_table) {
  // Multiply the cell GRU output and the reset. There is a slight danger of
  // loss of precision here, so use 32x32=64 bit and shift back after.
  __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
  if (SplitGates) {
    __m256i other_gru =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
    gru = _mm256_add_epi32(gru, other_gru);
  }
  // This only computes the products of the low-order 32 bits of each pair.
  __m256i gru_lo = _mm256_mul_epi32(gru, reset);
  // Swap odd and even 32-bit units and do it again to get the high products.
  gru = _mm256_shuffle_epi32(gru, 0xb1);
  __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
  // Now shift right to compensate for the multiply and re-interleave the
  // 32-bit results.
  // NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
  // Fortunately it doesn't matter, as the results are being truncated to 32
  // bits and we aren't shifting right by more than 32 bits here.
  gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
  // The upper results are shifted LEFT, so we can use blend to recombine in
  // a single instruction.
  gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
  // Recombine the 32 bit results from lo and hi, alternating.
  gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
  gru = _mm256_add_epi32(cell, gru);
  // Compute tanh on the result. Although this instantly discards a bunch of
  // bits, there were only 7 surplus bits for the multiply, which isn't enough
  // to do it as 16x16=32.
  __m256i hbar =
      csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
                                           StateMantissaBits>(tanh_table, gru);
  // Load the 16-bit previous GRU state and sign-extend to 32 bits.
  gru = _mm256_cvtepi16_epi32(
      _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
  gru = _mm256_sub_epi32(gru, hbar);
  // Since |gru| is 16-bit sign-extended to 32 bits, and |update| is the output
  // of a sigmoid, so it always fits in 16 bits and is never negative, we can
  // use |madd_epi16| to do a 16x16=32 multiply with horizontal adding: the
  // addend will always be zero, and this is twice as fast as a full-blown
  // 32x32=32 multiply. The only possible problem is if the subtract above
  // caused overflow.
  gru = _mm256_madd_epi16(gru, update);
  // Renormalize to fixed16. This time rounding is critical, as this is the
  // output GRU state.
  gru = _mm256_add_epi32(gru, rounding_offset);
  gru = _mm256_srai_epi32(gru, StateMantissaBits);
  return _mm256_add_epi32(gru, hbar);
}
template <typename Type>
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 8;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi32(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}

template <typename Type>
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 16;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi16(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}
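
// NOTE: Both SumVectors overloads use aligned 256-bit loads/stores and have no
// scalar tail loop, so |add1|, |add2| and |result| must be 32-byte aligned at
// |start| and (end - start) must be a multiple of kSIMDWidth elements (8 for
// fixed32, 16 for fixed16). The additions wrap on overflow rather than
// saturating.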
}  // namespace detail
}  // namespace csrblocksparse