Spaces:
Sleeping
Sleeping
| /* | |
| * Copyright 2021 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| namespace csrblocksparse { | |
| // MaskedSparseMatrix serves two purposes: | |
| // 1) It is useful as a reference implementation of SpMV for correctness | |
| // checking the much more complicated implementations in CSRBlockSparseMatrix | |
| // 2) This is the format that sparse matrices are represented after pruning | |
| // in TF. This class provides a bridge to getting these parameters into | |
| // a compressed form suitable for computation and serialization. | |
| // | |
| // MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf); | |
| // CSRBlockSparseMatrix<float, bfloat16, int16_t> csr_matrix(matrix); | |
| // csr_matrix.Multiply(rhs, bias, &out); | |
| template <typename T> | |
| class MaskedSparseMatrix { | |
| public: | |
| MaskedSparseMatrix() {} | |
| // Construct a MaskedSparseMatrix of the given size, sparsity and block size. | |
| // This is mainly useful for testing. | |
| MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1, | |
| int block_width = 1, float constant = 1.f, | |
| bool random = true) | |
| : rows_(rows), cols_(cols), sparsity_(sparsity) { | |
| CHECK_EQ(rows % block_height, 0); | |
| CHECK_EQ(cols % block_width, 0); | |
| init(sparsity, block_height, block_width, constant, random); | |
| } | |
| // Construct from an existing mask and values (most likely from a TF model). | |
| template <typename MaskType> | |
| MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values) | |
| : rows_(rows), cols_(cols) { | |
| mask_.resize(rows * cols); | |
| values_.resize(rows * cols); | |
| std::copy_n(mask, rows * cols, mask_.begin()); | |
| std::copy_n(values, rows * cols, values_.begin()); | |
| sparsity_ = | |
| 1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size(); | |
| } | |
| const std::vector<int>& mask() const { return mask_; } | |
| const std::vector<T>& values() const { return values_; } | |
| T* data() { return values_.data(); } | |
| const T* data() const { return values_.data(); } | |
| int rows() const { return rows_; } | |
| int cols() const { return cols_; } | |
| float sparsity() const { return sparsity_; } | |
| void Print() const { | |
| absl::PrintF("-------Values---------\n"); | |
| for (int r = 0; r < rows_; ++r) { | |
| for (int c = 0; c < cols_; ++c) { | |
| absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c])); | |
| } | |
| absl::PrintF("\n"); | |
| } | |
| absl::PrintF("-------Mask---------\n"); | |
| for (int r = 0; r < rows_; ++r) { | |
| for (int c = 0; c < cols_; ++c) { | |
| printf("%2d ", mask_[r * cols_ + c]); | |
| } | |
| absl::PrintF("\n"); | |
| } | |
| } | |
| // This routine is useful for rounding the possibly higher precision values | |
| // stored in this class to a lower precision, so that correctness checks | |
| // between this class and CSRBlockSparseMatrix can have a tighter tolerance. | |
| template <typename U> | |
| void CastWeights() { | |
| for (int i = 0; i < values_.size(); ++i) { | |
| values_[i] = static_cast<T>(U(values_[i])); | |
| } | |
| } | |
| // Only meant for correctness checking. | |
| // RhsClassType is meant to be either CacheAlignedVector OR | |
| // FatCacheAlignedVector. | |
| // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR. | |
| // |bias| is broadcast if |rhs| has more than one column. | |
| template <typename RhsClassType, typename BiasType, typename OutClassType, | |
| typename RhsType = typename RhsClassType::value_type, | |
| typename OutType = typename OutClassType::value_type> | |
| void SpMM_bias(const RhsClassType& rhs, | |
| const CacheAlignedVector<BiasType>& bias, OutClassType* out, | |
| bool relu = false) { | |
| for (int r = 0; r < rows_; ++r) { | |
| for (int n = 0; n < rhs.cols(); ++n) { | |
| float sum = 0.f; | |
| const RhsType* rhs_ptr = rhs.data() + n * rhs.rows(); | |
| OutType* out_ptr = out->data() + n * out->rows(); | |
| const int* mask_ptr = mask_.data() + r * cols_; | |
| const T* value_ptr = values_.data() + r * cols_; | |
| for (int c = 0; c < cols_; ++c) { | |
| sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) * | |
| static_cast<float>(rhs_ptr[c]); | |
| } | |
| out_ptr[r] = static_cast<OutType>( | |
| relu ? std::max(sum + static_cast<float>(bias[r]), 0.f) | |
| : sum + static_cast<float>(bias[r])); | |
| } | |
| } | |
| } | |
| private: | |
| // Generate a random matrix with the specified sparsity. | |
| // Useful for testing. | |
| void init(float sparsity, int block_height, int block_width, float constant, | |
| bool random = true) { | |
| int reduced_rows = rows_ / block_height; | |
| int reduced_cols = cols_ / block_width; | |
| mask_.resize(rows_ * cols_, 0); | |
| // Fill with non-zero value to make sure masking works. | |
| values_.resize(rows_ * cols_, static_cast<T>(2.f)); | |
| std::mt19937 generator(0); | |
| std::uniform_real_distribution<float> dist_sparsity; | |
| std::uniform_real_distribution<float> dist_value(-1.f, 1.f); | |
| int nnz = 0; | |
| while (nnz == 0) { | |
| for (int r = 0; r < reduced_rows; ++r) { | |
| for (int c = 0; c < reduced_cols; ++c) { | |
| if (dist_sparsity(generator) > sparsity) { | |
| nnz++; | |
| for (int i = 0; i < block_height; ++i) { | |
| for (int j = 0; j < block_width; ++j) { | |
| mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1; | |
| values_[(r * block_height + i) * cols_ + block_width * c + j] = | |
| static_cast<T>(random ? dist_value(generator) : constant); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| std::vector<int> mask_; | |
| std::vector<T> values_; | |
| int rows_; | |
| int cols_; | |
| float sparsity_; | |
| }; | |
| template <typename T> | |
| class MaskedLinearLayer { | |
| public: | |
| MaskedLinearLayer(MaskedSparseMatrix<T>&& weights, | |
| CacheAlignedVector<T>&& bias) | |
| : weights_(std::move(weights)), bias_(std::move(bias)) {} | |
| MaskedLinearLayer() {} | |
| template <typename U> | |
| void CastWeights() { | |
| weights_.template CastWeights<U>(); | |
| } | |
| // Does Ax + b where A is a masked sparse ROW MAJOR matrix and | |
| // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is | |
| // broadcast is rhs has more than one column. | |
| template <typename FatVector> | |
| void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) { | |
| static_assert(std::is_same<typename FatVector::value_type, T>::value, | |
| "FatVector value_type must match masked_linear_layer type"); | |
| weights_.SpMM_bias(rhs, bias_, out, relu); | |
| } | |
| private: | |
| MaskedSparseMatrix<T> weights_; | |
| CacheAlignedVector<T> bias_; | |
| }; | |
| } // namespace csrblocksparse | |