/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>

#if defined __AVX__ || defined __AVX2__
#include <cpuid.h>  // __get_cpuid, __get_cpuid_count and the bit_AVX* masks.
#endif

#include "glog/logging.h"  // CHECK, LOG, VLOG.

// fixed16/fixed32, TypeOfProduct and the detail::MatVec* kernels used below
// are declared in this library's companion numerics and compute headers.

namespace csrblocksparse {
// The number of elements in a block.
constexpr int kBlockSize = 4;

// Base class for Matmul containing the members that are non-type-specific.
class MatmulBase {
 public:
  // Constructor initializes the flags that determine which implementation to
  // use at run-time, constrained by both compiler flags and cpuid.
  MatmulBase() {
    // Code tested to work on Linux systems and multiple Android emulators.
#if defined __AVX__ || defined __AVX2__
    unsigned int eax, ebx, ecx, edx;
    if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
      using_avx_ = (ecx & bit_AVX) != 0;
      if (using_avx_) {
        __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
        using_avx2_ = (ebx & bit_AVX2) != 0;
        using_avx512_ = (ebx & bit_AVX512F) != 0 &&
                        (ebx & bit_AVX512DQ) != 0 &&
                        (ebx & bit_AVX512BW) != 0;
        VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
      } else {
        LOG(ERROR) << "AVX not found at all!";
      }
    }
#endif
#if defined __aarch64__
    using_aarch64_ = true;
#endif
  }
 protected:
  // Flags that define what (runtime) architectures are available. Flags that
  // are set are limited by both the compiler flags and runtime environment.
  bool using_avx512_ = false;
  bool using_avx2_ = false;
  bool using_avx_ = false;
  bool using_aarch64_ = false;
};
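
// Illustrative sketch only (not part of this file): the detection above can
// be reproduced standalone. CPUID leaf 1 reports AVX support in ECX; leaf 7,
// subleaf 0, reports AVX2 and AVX-512 support in EBX. The bit_* masks come
// from <cpuid.h> on GCC/Clang.
//
//   unsigned int eax, ebx, ecx, edx;
//   bool avx = __get_cpuid(1, &eax, &ebx, &ecx, &edx) && (ecx & bit_AVX);
//   bool avx2 = avx && __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx) &&
//               (ebx & bit_AVX2);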
// The primary template is a catch-all for the unimplemented cases, which
// reports an error.
template <typename WeightType, typename RhsType>
class Matmul : public MatmulBase {
 public:
  // Sparse inputs, outputs replicated strided for each thread.
  template <typename OutType>
  void MatVec4x4(const WeightType* weights, const RhsType* rhs,
                 const typename TypeOfProduct<WeightType, RhsType>::type* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, OutType* output) {
    // The specializations should take care of every real case.
    CHECK(false) << "Unsupported combination of types used!";
  }
  template <typename OutType>
  void MatVec8x4(const WeightType* weights, const RhsType* rhs,
                 const typename TypeOfProduct<WeightType, RhsType>::type* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, OutType* output) {
    // The specializations should take care of every real case.
    CHECK(false) << "Unsupported combination of types used!";
  }
};
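
// Illustrative sketch only: how specialization selection plays out at a call
// site. For supported type pairings the compiler picks the matching
// specialization; a pairing that only matches the catch-all above CHECK-fails
// at run time with "Unsupported combination of types used!".
//
//   csrblocksparse::Matmul<float, float> mm;  // full specialization below
//   // mm.MatVec4x4(...) dispatches to detail::MatVecFloatGeneric.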
// Full specialization for float.
template <>
class Matmul<float, float> : public MatmulBase {
 public:
  void MatVec4x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/4,
                               /*block_width=*/4, relu, replicas, stride,
                               output);
  }
  void MatVec8x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/8,
                               /*block_width=*/4, relu, replicas, stride,
                               output);
  }
};
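
// Illustrative call sketch (argument shapes assumed, not defined in this
// file): one block-row range of a 4x4-blocked sparse matrix times a dense
// vector, with a single output copy.
//
//   csrblocksparse::Matmul<float, float> mm;
//   mm.MatVec4x4(weights,      // non-zero 4x4 blocks, packed contiguously
//                rhs,          // dense right-hand-side vector
//                bias,         // one bias per output row
//                nnz_per_row,  // number of non-zero blocks per block-row
//                rhs_indices,  // rhs offset for each non-zero block
//                /*start_row=*/0, /*end_row=*/num_block_rows,
//                /*relu=*/true, /*replicas=*/1, /*stride=*/0, output);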
// Partial specialization for fixed types. Covers fixed16 x fixed16 = OutType,
// where OutType should be fixed16 or fixed32. The mantissa bits don't have
// to match.
template <int WeightBits, int RhsBits>
class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
 public:
  using WeightType = fixed16<WeightBits>;
  using RhsType = fixed16<RhsBits>;

  template <typename OutType>
  void MatVec4x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    constexpr int kShiftAmount =
        TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
        OutType::kMantissaBits;
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined __AVX2__
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    if (sizeof(*output) == 4) {
      int32_t* out32 = reinterpret_cast<int32_t*>(output);
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride, out32);
    } else {
      int16_t* out16 = reinterpret_cast<int16_t*>(output);
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride, out16);
    }
#elif defined __aarch64__
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/4,
                               /*block_width=*/4, relu, sizeof(*output),
                               kShiftAmount, replicas, stride, output);
#endif
  }

  template <typename OutType>
  void MatVec8x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    constexpr int kShiftAmount =
        TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
        OutType::kMantissaBits;
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined __AVX2__
    CHECK(replicas == 1 && sizeof(*output) == 4)
        << "Only replicas == 1 and fixed32 output are implemented for AVX2!";
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    int32_t* out32 = reinterpret_cast<int32_t*>(output);
    detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, relu, kShiftAmount, out32);
#elif defined __aarch64__
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/8,
                               /*block_width=*/4, relu, sizeof(*output),
                               kShiftAmount, replicas, stride, output);
#endif
  }
};
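
// Worked example for kShiftAmount above (assuming fixed16<E> carries
// 15 - E mantissa bits and the product type accumulates the mantissa bits of
// its two inputs): with fixed16<2> weights and fixed16<2> rhs, each input has
// 13 mantissa bits, so the product type has 26. Writing to a fixed32 OutType
// with 20 mantissa bits gives kShiftAmount = 26 - 20 = 6, i.e. each
// accumulated product is shifted right by 6 bits before being stored.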
}  // namespace csrblocksparse