/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// IWYU pragma: begin_exports
// The project headers that provide CacheAlignedVector, MaskedSparseMatrix,
// Matmul, ThreadBounds, SpinBarrier, the SpMM/SpMV kernels in namespace
// detail, the CHECK* macros and the numeric types (bfloat16, fixed-point) are
// exported from here; their exact paths are omitted in this listing.
// IWYU pragma: end_exports

namespace csrblocksparse {

// CsrBlockSparseMatrix stores a modified block compressed sparse row
// representation of a sparse matrix. The ordering of the weights is modified
// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
// of columns of weights are stored contiguously before moving on to the next
// row. The 4x4 case stores each block contiguously.
//
// Currently it is constructed from a MaskedSparseMatrix, which uses a dense
// binary mask representation. The construction generates the compressed
// representation. Further iterations will support a direct serialization
// of the compressed representation.
//
// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
// CsrBlockSparseMatrix matrix(masked_matrix)
//
// matrix.SpMM_bias(rhs, bias, &out);
//
// This class is thread compatible.
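//
// A fuller usage sketch (illustrative; assumes |rhs|, |bias| and |out| are
// the library's aligned vector types exposing value_type, data(), rows(),
// cols() and col_stride() as used by SpMM_bias() below):
//
//   CsrBlockSparseMatrix<bfloat16, float> matrix(masked_matrix);
//   matrix.PrepareForThreads(1);
//   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/true, /*tid=*/0);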

template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
class CsrBlockSparseMatrix {
 public:
  CsrBlockSparseMatrix() {}

  // Reference used to indicate that this is an input and not an output.
  CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
    ReadFromFlatBuffer(buffer, len);
    ComputeRHSIndices();
  }

  template <typename InputType>
  CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
    sparsity_ = masked_matrix.sparsity();
    rows_ = masked_matrix.rows();
    cols_ = masked_matrix.cols();

    DetermineBlockSize(masked_matrix);

    if (block_width_ == 1 && block_height_ == 1)
      col_multiple_ = 8;
    else
      col_multiple_ = 1;

    std::vector<InputType> weights(masked_matrix.values().begin(),
                                   masked_matrix.values().end());

    reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
    rows_ = reduced_rows_ * block_height_;
    reduced_cols_ = cols_ / block_width_;

    // Calculate the reduced CSR representation of the matrix.
    std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
    std::vector<int> row_offsets = {0};
    int nnz = 0;
    const auto& mask = masked_matrix.mask();
    for (int r = 0; r < reduced_rows_; ++r) {
      for (int c = 0; c < reduced_cols_; ++c) {
        int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
        reduced_mask[r * reduced_cols_ + c] = mask_val;
        nnz += mask_val;
      }
      row_offsets.push_back(nnz);
    }

    // Make sure the reduced representation has the correct number of columns.
    MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);

    std::vector<int> col_indices;
    std::vector<WeightType> weights_csr;
    std::vector<int> nnz_per_row;
    MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
                        &weights_csr);

    // Generate column deltas from |col_indices|.
    std::vector<DeltaType> col_deltas;
    for (int i = 0; i < col_indices.size(); ++i) {
      // |col_indices| index the RHS vector, so the jump from the previous
      // index is converted into bytes of RhsType.
      int64_t diff = sizeof(RhsType);
      if (i == 0)
        diff *= block_width_ * (col_indices[i]);
      else
        diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
      CHECK(diff < std::numeric_limits<DeltaType>::max())
          << "delta between column indices in bytes " << diff
          << " exceeded the maximum size of the DeltaType "
          << std::numeric_limits<DeltaType>::max();
      col_deltas.push_back(static_cast<DeltaType>(diff));
    }

    // Because of pre-fetching we need some extra values at the end.
    col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
    nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());

    weights_ = CacheAlignedVector<WeightType>(weights_csr);
    col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
    nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
    ComputeRHSIndices();

    num_threads_ = 0;
    PrepareForThreads(1);
  }
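
  // Worked example of the |col_deltas_| encoding built by the constructor
  // above (illustrative values, not from a real model): with
  // block_width_ == 1, sizeof(RhsType) == 4 and CSR column indices {2, 5, 6}
  // in the first row, the stored byte deltas are
  // {2 * 4, (5 - 2) * 4, (6 - 5) * 4} = {8, 12, 4}, i.e. each entry is the
  // jump in bytes from the previous rhs read position.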

  // Constructor makes a matrix from the given weights, deltas and nnz, taking
  // the other parameters from |src_matrix|. |cols| is the number of raw
  // columns (NOT blocks) of the new matrix.
  CsrBlockSparseMatrix(
      const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
      const std::vector<WeightType>& new_weights,
      const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz,
      int cols) {
    num_threads_ = 0;
    col_multiple_ = src_matrix.col_multiple_;
    block_width_ = src_matrix.block_width_;
    block_height_ = src_matrix.block_height_;
    reduced_rows_ = new_nnz.size();
    rows_ = reduced_rows_ * block_height_;
    cols_ = cols;
    reduced_cols_ = cols_ / block_width_;
    weights_ = CacheAlignedVector<WeightType>(new_weights);
    col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
    nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
    sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
    ComputeRHSIndices();
    name_ = src_matrix.name_;
    PrepareForThreads(1);
  }

  // Factory method takes a column slice out of *this and returns a sparse
  // matrix that takes as inputs [|start_col|, |end_col|) of *this, and
  // returns the same number of outputs, but only a partial result.
  // If |keep_rhs_size|, then the new matrix takes the same rhs as the current
  // matrix, but uses a subset of it, instead of expecting just the reduced rhs.
  // If |start_col| > |end_col|, then we slice out the complement of the defined
  // interval, i.e. [0, |end_col|) + [|start_col|, current end).
  // Note that |start_col| and |end_col| are in raw column coordinates, NOT
  // block units.
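  //
  // Example (illustrative): a 1024-column matrix could be split into two
  // column halves with
  //   CsrBlockSparseMatrix first = matrix.SplitByColumn(0, 512);
  //   CsrBlockSparseMatrix second = matrix.SplitByColumn(512, 1024);
  // where the two partial results sum to the full matrix product (bias
  // handling is up to the caller).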
  CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
                                     bool keep_rhs_size = false) const {
    int weight_index = 0;
    int delta_index = 0;
    std::vector<DeltaType> new_deltas;
    std::vector<WeightType> new_weights;
    std::vector<int> new_nnz(reduced_rows_);
    int col = 0;
    int prev_col = keep_rhs_size ? 0 : start_col;
    for (int r = 0; r < reduced_rows_; ++r) {
      int reduced_col_count = nnz_per_row_[r];
      for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
        col += col_deltas_[delta_index] / sizeof(RhsType);
        if ((start_col < end_col && start_col <= col && col < end_col) ||
            (start_col > end_col && (col < end_col || col >= start_col))) {
          ++new_nnz[r];
          new_deltas.push_back((col - prev_col) * sizeof(RhsType));
          prev_col = col;
          for (int i = 0; i < block_width_ * block_height_;
               ++i, ++weight_index) {
            new_weights.push_back(weights_[weight_index]);
          }
        } else {
          weight_index += block_width_ * block_height_;
        }
      }
    }
    int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
    return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
                                new_cols);
  }

  // Factory method takes a row slice out of *this and returns a sparse
  // matrix that takes the same inputs as *this, and returns the outputs for
  // the range [|start_row|, |end_row|).
  // Note that |start_row| and |end_row| are in raw row coordinates, NOT
  // block units.
  CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
    int start_reduced = start_row / block_height_;
    int end_reduced = end_row / block_height_;
    std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
                             nnz_per_row_.data() + end_reduced);
    int weight_start = 0;
    for (int r = 0; r < start_reduced; ++r) {
      weight_start += nnz_per_row_[r];
    }
    int weight_end = weight_start;
    for (int r = start_reduced; r < end_reduced; ++r) {
      weight_end += nnz_per_row_[r];
    }
    int delta_start = 0;
    for (int i = 0; i < weight_start; ++i) {
      delta_start += col_deltas_[i];
    }
    std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
                                      col_deltas_.data() + weight_end);
    new_deltas[0] += delta_start;
    int block_size = block_height_ * block_width_;
    std::vector<WeightType> new_weights(
        weights_.data() + weight_start * block_size,
        weights_.data() + weight_end * block_size);
    return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
  }

  // Combines adjacent row blocks, doubling the block height.
  // This necessarily involves adding zero weights where the blocks don't align
  // across adjacent pairs of rows, so use with caution, as the resulting
  // matrix is likely to run slower if it was very sparse to begin with.
  // In the few cases where the blocks do mostly align, the resulting matmul
  // could be much faster, as the number of reads of the rhs will be halved.
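  //
  // Illustrative example: if row-block 2r has blocks at block-columns {0, 3}
  // and row-block 2r + 1 has blocks at {3, 5}, the combined double-height row
  // has blocks at {0, 3, 5}; the blocks at 0 and 5 are padded with a zero
  // half, while the block at 3 stacks the two original blocks.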
  void DoubleBlockHeight() {
    int new_rows = reduced_rows_ / 2;
    std::vector<int> new_nnz(new_rows);
    std::vector<DeltaType> new_rhs_indices;
    std::vector<WeightType> new_weights;
    int rhs_index1 = 0;
    int rhs_index2 = 0;
    int block_size = block_height_ * block_width_;
    for (int r = 0; r < new_rows; ++r) {
      int start_nnz = new_rhs_indices.size();
      rhs_index2 += nnz_per_row_[r * 2];
      int end1 = rhs_index1 + nnz_per_row_[r * 2];
      int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
      // Run over a pair of rows with 2 iterators, combining blocks as we go,
      // or padding with zeros where the block positions don't match.
      while (rhs_index1 < end1 || rhs_index2 < end2) {
        int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
        int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
        if (col1 < col2) {
          // Need zero weights for row2 to pad out weights block.
          new_rhs_indices.push_back(col1);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index1 * block_size,
                             weights_.data() + (rhs_index1 + 1) * block_size);
          new_weights.insert(new_weights.end(), block_size,
                             static_cast<WeightType>(0.0f));
          ++rhs_index1;
        } else if (col1 > col2) {
          // Need zero weights for row1 to pad out weights block.
          new_rhs_indices.push_back(col2);
          new_weights.insert(new_weights.end(), block_size,
                             static_cast<WeightType>(0.0f));
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index2 * block_size,
                             weights_.data() + (rhs_index2 + 1) * block_size);
          ++rhs_index2;
        } else {
          // Combine weights for both row1 and row2.
          new_rhs_indices.push_back(col1);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index1 * block_size,
                             weights_.data() + (rhs_index1 + 1) * block_size);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index2 * block_size,
                             weights_.data() + (rhs_index2 + 1) * block_size);
          ++rhs_index1;
          ++rhs_index2;
        }
      }
      rhs_index1 = rhs_index2;
      new_nnz[r] = new_rhs_indices.size() - start_nnz;
    }
    block_height_ *= 2;
    reduced_rows_ /= 2;
    weights_ = CacheAlignedVector<WeightType>(new_weights);
    rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
    nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
    sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
    ComputeColDeltas();
    if (num_threads_ > 0) {
      int num_threads = num_threads_;
      num_threads_ = 0;
      PrepareForThreads(num_threads);
    }
  }

  // Serializes the matrix into |csr_flatbuffer| and returns the number of
  // bytes written. The buffer allocated internally is copied into the string
  // and freed before returning.
  // TODO(b/189958858): Both Read and Write need to eventually handle the
  // different possible HalfType and DeltaType values, but punting for now as
  // there is only one supported combination.
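  //
  // Serialized layout, in write order: 11 ints (rows, cols, reduced_rows,
  // reduced_cols, block_width, block_height, col_multiple, num_threads,
  // weights size, col_deltas size, nnz_per_row size), then one float
  // (sparsity), then the raw weights, col_deltas and nnz_per_row arrays.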
  std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
    std::size_t bytes = 0;
    bytes += FixedParameterSize();
    bytes += weights_.size() * sizeof(WeightType);
    bytes += col_deltas_.size() * sizeof(DeltaType);
    bytes += nnz_per_row_.size() * sizeof(int);

    uint8_t* bytes_ptr_ptr =
        reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));

    int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
    *int_bytes_ptr++ = rows_;
    *int_bytes_ptr++ = cols_;
    *int_bytes_ptr++ = reduced_rows_;
    *int_bytes_ptr++ = reduced_cols_;
    *int_bytes_ptr++ = block_width_;
    *int_bytes_ptr++ = block_height_;
    *int_bytes_ptr++ = col_multiple_;
    *int_bytes_ptr++ = num_threads_;
    *int_bytes_ptr++ = weights_.size();
    *int_bytes_ptr++ = col_deltas_.size();
    *int_bytes_ptr++ = nnz_per_row_.size();

    float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
    *float_bytes_ptr++ = sparsity_;

    uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
    memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
    bytes_ptr += weights_.size() * sizeof(WeightType);

    memcpy(bytes_ptr, col_deltas_.data(),
           col_deltas_.size() * sizeof(DeltaType));
    bytes_ptr += col_deltas_.size() * sizeof(DeltaType);

    memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
    bytes_ptr += nnz_per_row_.size() * sizeof(int);

    csr_flatbuffer->resize(bytes);
    csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
    free(bytes_ptr_ptr);

    return bytes;
  }

  void ReadFromFlatBuffer(const uint8_t* const& bytes,
                          const std::size_t& len) {
    CHECK_GE(len, FixedParameterSize());

    const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
    rows_ = *int_bytes_ptr++;
    cols_ = *int_bytes_ptr++;
    reduced_rows_ = *int_bytes_ptr++;
    reduced_cols_ = *int_bytes_ptr++;
    block_width_ = *int_bytes_ptr++;
    block_height_ = *int_bytes_ptr++;
    col_multiple_ = *int_bytes_ptr++;
    int num_threads = *int_bytes_ptr++;
    int32_t weights_size = *int_bytes_ptr++;
    int32_t col_deltas_size = *int_bytes_ptr++;
    int32_t nnz_per_row_size = *int_bytes_ptr++;

    // Make sure negative sizes don't mess things up.
    weights_size = std::max(0, weights_size);
    col_deltas_size = std::max(0, col_deltas_size);
    nnz_per_row_size = std::max(0, nnz_per_row_size);

    const float* float_bytes_ptr =
        reinterpret_cast<const float*>(int_bytes_ptr);
    sparsity_ = *float_bytes_ptr++;

    std::size_t total_bytes =
        FixedParameterSize() + weights_size * sizeof(WeightType) +
        col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
    CHECK_EQ(total_bytes, len)
        << "total bytes: " << total_bytes << ", actual len given: " << len;

    const uint8_t* bytes_ptr =
        reinterpret_cast<const uint8_t*>(float_bytes_ptr);
    std::vector<WeightType> weights_raw(weights_size);
    memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
    weights_ = CacheAlignedVector<WeightType>(weights_raw);
    bytes_ptr += weights_size * sizeof(WeightType);

    std::vector<DeltaType> deltas_raw(col_deltas_size);
    memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
    col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
    bytes_ptr += col_deltas_size * sizeof(DeltaType);

    std::vector<int> nnz_raw(nnz_per_row_size);
    memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
    nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);
    num_threads_ = 0;
    PrepareForThreads(num_threads);
  }

  // Multiplies a sparse matrix by a possibly dense matrix. Often the rhs
  // matrix has only a small number of columns, hence the term "fat vector".
  // 1x1 and 4x4 have specializations for output columns (i.e. fatness) >= 5,
  // and often achieve twice as many GFlops when multiplying a right hand side
  // that has 5 or more columns. (Best is a multiple of 5.)
  // 16x1 doesn't have enough registers and just loops over the width 1 kernel.
  //
  // |rhs| and |out| are COLUMN MAJOR.
  //
  // Fast Tuples WeightType, BiasType, RhsType, OutType are:
  //   (float, float, float, float)
  //   (bfloat16, float, float, float)
  // and only on ARM64. All other cases use a slow generic implementation.
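  //
  // Multi-threaded usage sketch (hedged: assumes SpinBarrier is constructed
  // with the number of participating threads, and that the caller launches
  // one worker per tid after PrepareForThreads()):
  //
  //   int threads = matrix.PrepareForThreads(4);
  //   SpinBarrier barrier(threads);
  //   // In worker thread |tid| (0 <= tid < threads):
  //   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/false, tid, &barrier);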
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                 bool relu = false, int tid = 0,
                 SpinBarrier* barrier = nullptr) const {
    static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
                  "Rhs types must match");
    CHECK_LT(tid, num_threads_);
    CHECK_EQ(rhs.cols(), out->cols());
    CHECK_EQ(rhs.rows(), cols_);
    CHECK_GE(out->rows(), rows_);
    int cols_to_go = out->cols();
    int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
    const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
    OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
    const WeightType* weights_ptr =
        thread_bounds_.OffsetWeights(weights_.data(), tid);
    const DeltaType* delta_ptr =
        thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
    int offset = *delta_ptr / sizeof(RhsType);
    rhs_ptr -= offset;
    const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
    int assigned_rows =
        thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
    const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);

    while (cols_to_go > 0) {
      if (block_width_ == 4 && block_height_ == 4) {
        if (cols_to_go >= 5) {
          detail::SpMM5_4x4<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        } else {
          detail::SpMV_4x4<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        }
      } else {
        if (cols_to_go >= 5) {
          detail::SpMM5_1x1<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        } else {
          detail::SpMV_1x1<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        }
      }

      if (cols_to_go >= 5) {
        cols_to_go -= 5;
        rhs_ptr += rhs.col_stride() * 5;
        out_ptr += out->col_stride() * 5;
      } else {
        cols_to_go--;
        rhs_ptr += rhs.col_stride();
        out_ptr += out->col_stride();
      }
      if (barrier) barrier->barrier();
    }
  }

  template <typename MVRhsType, typename MVBiasType, typename OutType>
  void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
              int replicas, int output_stride, OutType* output) {
    CHECK_LT(tid, num_threads_);
    CHECK_EQ(block_width_, 4) << "Block width must be 4!";
    if (block_height_ == 8) {
      matmul_.MatVec8x4(
          thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
          thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
          thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
          thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
          replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
    } else {
      CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
      matmul_.MatVec4x4(
          thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
          thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
          thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
          thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
          replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
    }
  }

  int rows() const { return rows_; }
  int cols() const { return cols_; }
  int block_height() const { return block_height_; }
  int block_width() const { return block_width_; }
  float sparsity() const { return sparsity_; }
  int num_threads() const { return num_threads_; }
  const ThreadBounds& thread_bounds() const { return thread_bounds_; }
  const CacheAlignedVector<DeltaType>& rhs_indices() const {
    return rhs_indices_;
  }
  const std::string& name() const { return name_; }
  void set_name(const std::string& name) { name_ = name; }
  const std::vector<int>& split_points() const {
    return thread_bounds_.row_starts();
  }

  std::size_t bytes() const {
    return weights_.size() * sizeof(WeightType) +
           col_deltas_.size() * sizeof(DeltaType) +
           nnz_per_row_.size() * sizeof(int);
  }

  // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
  // and then samples from the output (softmax distribution) layer.
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
  SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                   float temperature, int tid, SpinBarrier* barrier,
                   std::minstd_rand* gen,
                   CacheAlignedVector<float>* scratch) const {
    SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
    return out->Sample(temperature, gen, scratch);
  }

  // Fixed32 version.
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
  SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                   float temperature, int tid, SpinBarrier* barrier,
                   std::minstd_rand* gen,
                   CacheAlignedVector<float>* scratch) const {
    // We don't pass the barrier on, as we have more work to do.
    SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
    return out->ReducingSample(gen, scratch, tid, temperature, barrier);
  }

  void Print() const {
    std::cout << "Weights\n";
    weights_.Print();
    std::cout << std::endl;
    std::cout << "Deltas\n";
    col_deltas_.Print();
    std::cout << std::endl;
    std::cout << "nnz\n";
    nnz_per_row_.Print();
    std::cout << std::endl;
  }

  // Splits the computation amongst threads by rows, based on the number of
  // non-zeros, with the addition of a constant to account for the work of the
  // bias and the horizontal add at the end. It also guarantees that each
  // thread writes only whole cache lines, based on the size of OutType.
  // The |cache_line_size| arg is used only for testing. Normally it is
  // provided through the architecture #defines.
  // Each thread gets a contiguous row range (|split_points|):
  // thread t does rows [split_points[t], split_points[t + 1]).
  // Each thread also needs to know how many non-zeros came before it, so it
  // can skip them (|nnz_to_skip|), and what the offset into the rhs vector
  // would have been at the split point (|rhs_to_skip|).
  //
  // Some tricky corner cases, where the number of non-zeros doesn't split
  // nicely amongst the number of requested threads, are not handled and
  // default to one thread; these cases only happen in tests and not in the
  // matrices that occur in real models.
  //
  // Returns the maximum number of threads that can be used; <= |num_threads|.
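  //
  // For example, with the default 64-byte cache line, 4x4 blocks and int32
  // outputs, ReducedRowsPerCacheLine() is 64 / (4 * 4) = 4, so each thread's
  // row range is aligned to multiples of 4 block rows (16 raw rows).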
  template <typename OutType = int32_t>
  int PrepareForThreads(int num_threads, int cache_line_size = -1) {
    CHECK_GT(num_threads, 0);
    // We've already prepared for this number of threads; nothing to do.
    if (num_threads == num_threads_) return num_threads_;
    num_threads_ = num_threads;
    thread_bounds_.PrepareForThreads(
        block_width_, block_height_, num_threads_,
        ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
        nnz_per_row_.data());
    return num_threads_;
  }

  // Computes and stores the |rhs_indices_| from the |col_deltas_|.
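  // Worked example (illustrative): with block_width_ == 4,
  // sizeof(RhsType) == 4 and |col_deltas_| == {0, 16, 32} bytes, the
  // cumulative element positions are {0, 4, 12}, so |rhs_indices_| in block
  // units is {0, 1, 3}.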
  void ComputeRHSIndices() {
    std::vector<int> cumulative_deltas = CumulativeColDeltas();
    std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
                                       reduced_rows_);
    int total_indices = 0;
    int delta_index = 0;
    for (int r = 0; r < reduced_rows_; ++r) {
      for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
        rhs_indices[total_indices++] =
            cumulative_deltas[delta_index] / block_width_;
      }
    }
    rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
  }

  // Computes and stores the |col_deltas_| from the |rhs_indices_|.
  void ComputeColDeltas() {
    std::vector<int> col_deltas(rhs_indices_.size());
    int prev_index = 0;
    for (int i = 0; i < rhs_indices_.size(); ++i) {
      int offset = rhs_indices_[i] - prev_index;
      prev_index = rhs_indices_[i];
      col_deltas[i] = offset * block_width_ * sizeof(RhsType);
    }
    col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
  }

  // Computes and returns the inclusive prefix sum of the deltas, i.e. absolute
  // positions in RhsType elements.
  std::vector<int> CumulativeColDeltas() const {
    std::vector<int> cum_col_deltas(col_deltas_.size());
    for (int i = 0; i < col_deltas_.size(); ++i) {
      cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
      if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
    }
    return cum_col_deltas;
  }

 private:
  constexpr std::size_t FixedParameterSize() const {
    return sizeof(int)      // rows
           + sizeof(int)    // cols
           + sizeof(int)    // reduced_rows
           + sizeof(int)    // reduced_cols
           + sizeof(int)    // block_width
           + sizeof(int)    // block_height
           + sizeof(float)  // sparsity
           + sizeof(int)    // col_multiple
           + sizeof(int)    // num_threads_
           + sizeof(int)    // weights_.size()
           + sizeof(int)    // col_deltas_.size()
           + sizeof(int);   // nnz_per_row_.size()
  }

  // Possible block sizes are only those that are supported by the
  // computation; the default is 1x1 and the other options are 4x4 and 16x1.
  template <typename InputType>
  void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
    const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
    int rows = masked_matrix.rows();
    int cols = masked_matrix.cols();
    for (const auto& block_size : kPreferredOrder) {
      int block_height, block_width;
      std::tie(block_height, block_width) = block_size;
      if (cols % block_width != 0) continue;

      int reduced_rows = (rows + block_height - 1) / block_height;
      int reduced_cols = cols / block_width;

      // For each possible block, confirm that it is either all 0s or all 1s.
      bool all_same = true;
      const auto& mask = masked_matrix.mask();
      for (int r = 0; r < reduced_rows; ++r) {
        for (int c = 0; c < reduced_cols; ++c) {
          int val = mask[r * block_height * cols + c * block_width];
          for (int i = 0; i < block_height; ++i) {
            for (int j = 0; j < block_width; ++j) {
              int index = (r * block_height + i) * cols + c * block_width + j;
              if (index < masked_matrix.mask().size()) {
                all_same &= (masked_matrix.mask()[index] == val);
              }
            }
          }
        }
      }

      // If this block configuration is possible, accept it.
      if (all_same) {
        block_height_ = block_height;
        block_width_ = block_width;
        return;
      }
    }

    // No large blocks were found, default to 1x1.
    block_height_ = 1;
    block_width_ = 1;
  }

  // The CSR descriptors are for the reduced matrix; |weights| is the full
  // matrix.
  template <typename InputType>
  void MakeColumnsMultiple(const std::vector<int>& row_offsets,
                           std::vector<int>* reduced_mask,
                           std::vector<InputType>* weights) {
    if (col_multiple_ > 0) {
      // Make sure each row has a number of columns that is a multiple of
      // |col_multiple|.
      for (int r = 1; r < row_offsets.size(); ++r) {
        int num_row = row_offsets[r] - row_offsets[r - 1];
        int num_needed = col_multiple_ - num_row % col_multiple_;
        if (num_needed < col_multiple_) {
          // Find gaps in the columns where we can insert a column of 0
          // weights.
          int num_added = 0;
          for (int c = 0; c < reduced_cols_; ++c) {
            if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
              (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;

              // Zero out the weights that correspond to this block.
              for (int i = 0; i < block_height_; ++i) {
                for (int j = 0; j < block_width_; ++j) {
                  (*weights)[((r - 1) * block_height_ + i) * cols_ +
                             block_width_ * c + j] = InputType(0.f);
                }
              }
              num_added++;
            }
            if (num_added == num_needed) break;
          }
        }
      }
    }
  }

  // Given the final dense mask and weights, convert to the compressed
  // block CSR representation.
  template <typename InputType>
  void MaskAndWeightsToCsr(const std::vector<int>& mask,
                           const std::vector<InputType>& weights,
                           std::vector<int>* nnz_per_row,
                           std::vector<int>* col_indices,
                           std::vector<WeightType>* weights_csr) {
    std::vector<int> row_offsets = {0};
    int nnz = 0;
    // Standard CSR format.
    if (block_width_ == 1 && block_height_ == 1) {
      for (int r = 0; r < rows_; ++r) {
        for (int c = 0; c < cols_; ++c) {
          if (mask[r * cols_ + c] == 1) {
            nnz++;
            col_indices->push_back(c);
            weights_csr->push_back(WeightType(weights[r * cols_ + c]));
          }
        }
        row_offsets.push_back(nnz);
      }
    } else if (block_width_ == 4 && block_height_ == 4) {
      // Weights are stored contiguously for each block in this case.
      for (int r = 0; r < reduced_rows_; ++r) {
        for (int c = 0; c < reduced_cols_; ++c) {
          if (mask[r * reduced_cols_ + c] == 1) {
            col_indices->push_back(c);
            nnz++;
            for (int i = 0; i < block_height_; ++i) {
              for (int j = 0; j < block_width_; ++j) {
                int row_index = (block_height_ * r + i) * cols_;
                int w_index = row_index + block_width_ * c + j;
                WeightType weight = w_index < weights.size()
                                        ? WeightType(weights[w_index])
                                        : WeightType(0.0f);
                weights_csr->push_back(weight);
              }
            }
          }
        }
        row_offsets.push_back(nnz);
      }
    }
    for (int i = 1; i < row_offsets.size(); ++i)
      nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
  }

  // Returns the number of block rows per cache line. This is the minimum unit
  // into which the calculation is broken for threads.
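  // For example, a 64-byte cache line with block_height_ == 4 and
  // sizeof(OutType) == 4 gives 64 / (4 * 4) = 4 block rows per cache line.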
  template <typename OutType>
  int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
    int line_size = kCacheLineSize;
    if (override_cache_line_size >= 1) line_size = override_cache_line_size;
    return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
  }

  int col_multiple_;
  int rows_;
  int cols_;
  int reduced_rows_;
  int reduced_cols_;
  float sparsity_;
  int block_width_;
  int block_height_;
  int num_threads_;
  std::string name_;

  CacheAlignedVector<WeightType> weights_;
  CacheAlignedVector<DeltaType> col_deltas_;
  CacheAlignedVector<int> nnz_per_row_;
  // |thread_bounds_| and |rhs_indices_| don't need to be serialized, as they
  // are always recalculated from serialized data.
  CacheAlignedVector<DeltaType> rhs_indices_;
  Matmul<WeightType, RhsType> matmul_;
  ThreadBounds thread_bounds_;
  static constexpr int kCacheLineSize = 64;
};

// Converts a sparse matrix represented with (|mask|, |weights|, |rows|,
// |cols|) into the CSR format, and returns that as a serialized string.
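// Usage sketch (hedged: assumes |mask| and |weights| are row-major with
// rows * cols entries):
//   std::string csr = ConvertDenseToSparseRepresentation_Int16Deltas(
//       mask, weights, rows, cols);
// The resulting string can later be reloaded through the
// CsrBlockSparseMatrix(buffer, len) constructor above.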
template <typename MaskType>
std::string ConvertDenseToSparseRepresentation_Int16Deltas(
    const std::vector<MaskType>& mask, const std::vector<float>& weights,
    const int rows, const int cols) {
  MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
                                           weights.data());
  CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
      sparse_masked_weights(masked_weights);
  std::string buffer;
  sparse_masked_weights.WriteToFlatBuffer(&buffer);
  return buffer;
}

}  // namespace csrblocksparse