#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
|
|
#include <cstdint>
#include <vector>

#include "absl/time/time.h"
#include "glog/logging.h"  // Assumed logging header; provides the CHECK/LOG/VLOG macros used below.
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
#include "sparse_matmul/compute/matmul_generic.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"

#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
#include <cpuid.h>
#endif
|
|
|
namespace csrblocksparse {

// Block size used by the block-sparse kernels in this file.
constexpr int kBlockSize = 4;

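// Base class for Matmul holding the type-independent state: the CPU feature
// flags detected once at construction (via cpuid on x86), which the
// specializations below consult when deciding whether the AVX2 kernels may be
// used.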
class MatmulBase {
 public:
  // Detects the available SIMD instruction sets once at construction time.
  MatmulBase() {
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
    unsigned int eax, ebx, ecx, edx;
    if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
      using_avx_ = (ecx & bit_AVX) != 0;
      if (using_avx_) {
        __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
        using_avx2_ = (ebx & bit_AVX2) != 0;
        using_avx512_ = (ebx & bit_AVX512F) != 0 &&
                        (ebx & bit_AVX512DQ) != 0 && (ebx & bit_AVX512BW) != 0;
        VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
      } else {
        LOG(ERROR) << "AVX not found at all!";
      }
    }
#else
    // Non-x86 builds assume an aarch64 target.
    using_aarch64_ = true;
#endif
  }

 protected:
  // CPU feature flags, filled in by the constructor.
  bool using_avx512_ = false;
  bool using_avx2_ = false;
  bool using_avx_ = false;
  bool using_aarch64_ = false;
};
|
|
|
|
|
|
|
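// The primary template is a catch-all: any (WeightType, RhsType) combination
// without a dedicated specialization below fails with a run-time CHECK.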
template <typename WeightType, typename RhsType>
class Matmul : public MatmulBase {
 public:
  template <typename OutType>
  void MatVec4x4(const WeightType* weights, const RhsType* rhs,
                 const typename TypeOfProduct<WeightType, RhsType>::type* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, OutType* output) {
    CHECK(false) << "Unsupported combination of types used!";
  }
  template <typename OutType>
  void MatVec8x4(const WeightType* weights, const RhsType* rhs,
                 const typename TypeOfProduct<WeightType, RhsType>::type* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, OutType* output) {
    CHECK(false) << "Unsupported combination of types used!";
  }
};
|
|
|
|
|
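// Full specialization for float weights and float activations: both block
// shapes dispatch to the generic implementation in matmul_generic.h.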
template <>
class Matmul<float, float> : public MatmulBase {
 public:
  void MatVec4x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/4,
                               /*block_width=*/4, relu, replicas, stride,
                               output);
  }
  void MatVec8x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/8,
                               /*block_width=*/4, relu, replicas, stride,
                               output);
  }
};
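
// Example (illustrative sketch only, not part of the library API): callers
// normally obtain the compressed arrays from the sparse matrix classes that
// accompany these kernels rather than building them by hand. Assuming a
// block-CSR layout in which nnz_per_row gives the number of non-zero weight
// blocks per block-row, rhs_indices gives each block's column position, and
// weights stores the block contents in kernel order, a call might look like:
//
//   csrblocksparse::Matmul<float, float> matmul;
//   matmul.MatVec4x4(weights.data(), rhs.data(), bias.data(),
//                    nnz_per_row.data(), rhs_indices.data(),
//                    /*start_row=*/0, /*end_row=*/num_block_rows,
//                    /*relu=*/true, /*replicas=*/1, /*stride=*/0,
//                    output.data());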
|
|
|
|
|
|
|
|
|
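// Partial specialization for fixed16 weights and fixed16 activations. When
// compiled with AVX2 support it requires the AVX2 cpu flag and uses the
// fixed-point AVX2 kernels; the aarch64 path is not yet implemented, and all
// other builds fall back to the generic fixed-point kernels.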
template <int WeightBits, int RhsBits>
class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
 public:
  using WeightType = fixed16<WeightBits>;
  using RhsType = fixed16<RhsBits>;

  template <typename OutType>
  void MatVec4x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    // Number of bits by which the accumulator (which carries the mantissa
    // precision of the weight*rhs product) must be shifted right to produce
    // OutType.
    constexpr int kShiftAmount =
        TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
        OutType::kMantissaBits;
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined __AVX2__
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    if (sizeof(*output) == 4) {
      int32_t* out32 = reinterpret_cast<int32_t*>(output);
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride, out32);
    } else {
      int16_t* out16 = reinterpret_cast<int16_t*>(output);
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride, out16);
    }
#elif defined __aarch64__
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/4,
                               /*block_width=*/4, relu, sizeof(*output),
                               kShiftAmount, replicas, stride, output);
#endif
  }
|
|
|
  template <typename OutType>
  void MatVec8x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    constexpr int kShiftAmount =
        TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
        OutType::kMantissaBits;
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined __AVX2__
    CHECK(replicas == 1 && sizeof(*output) == 4)
        << "Only replicas == 1 and fixed32 output are implemented for AVX2!";
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    int32_t* out32 = reinterpret_cast<int32_t*>(output);
    detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, relu, kShiftAmount, out32);
#elif defined __aarch64__
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/8,
                               /*block_width=*/4, relu, sizeof(*output),
                               kShiftAmount, replicas, stride, output);
#endif
  }
};
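
// Worked example of the shift computation above (a sketch; it assumes, per
// numerics/fixed_types.h and numerics/type_utils.h, that fixed16<N> carries
// 15 - N mantissa bits and that the product type carries the sum of the two
// input mantissa counts):
//
//   using Weight = csrblocksparse::fixed16<4>;  // 11 mantissa bits.
//   using Rhs = csrblocksparse::fixed16<2>;     // 13 mantissa bits.
//   // The product accumulates with 24 mantissa bits, so writing the result
//   // as fixed16<2> (13 mantissa bits) uses kShiftAmount = 24 - 13 = 11.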
|
|
|
}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_