Spaces:
Sleeping
Sleeping
| /* | |
| * Copyright 2021 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| namespace csrblocksparse { | |
| namespace detail { | |
| template <typename T> | |
| struct IsFloatOrBfloat | |
| : std::integral_constant<bool, std::is_same<T, float>::value || | |
| std::is_same<T, bfloat16>::value> {}; | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| struct IsAllowableFloatTypes | |
| : std::integral_constant<bool, IsFloatOrBfloat<WeightType>::value && | |
| std::is_same<RhsType, float>::value && | |
| std::is_same<OutType, float>::value> {}; | |
| // 16-bit inputs, 32-bit output exponent matches sum of input exponents | |
| // OR | |
| // 16-bit inputs, 16-bit output - will shift to match exponent | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| struct IsAllowableFixedTypes | |
| : std::integral_constant<bool, (IsFixed16Type<WeightType>::value && | |
| IsFixed16Type<RhsType>::value) && | |
| (IsFixed32Type<OutType>::value || | |
| IsFixed16Type<OutType>::value)> {}; | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| struct ShouldEnableGenericKernel | |
| : std::integral_constant< | |
| bool, | |
| !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value && | |
| !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {}; | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| struct ShouldEnableGenericSpMV_4x4 | |
| : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| struct ShouldEnableGenericSpMM5_4x4 | |
| : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; | |
// The generic 1x1 SpMV kernel is enabled for every type combination; the
// template parameters are not inspected.
template <typename Weight, typename Rhs, typename Out>
struct ShouldEnableGenericSpMV_1x1 : std::integral_constant<bool, true> {};
// The generic 1x1 SpMM5 kernel is enabled for every type combination; the
// template parameters are not inspected.
template <typename Weight, typename Rhs, typename Out>
struct ShouldEnableGenericSpMM5_1x1 : std::integral_constant<bool, true> {};
| template <typename Type> | |
| struct IsAddableFixedTypes | |
| : std::integral_constant<bool, IsFixed32Type<Type>::value || | |
| IsFixed16Type<Type>::value> {}; | |
| template <typename Type> | |
| struct ShouldEnableGenericAdd | |
| : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {}; | |
| // The computational routines do NO error checking for speed. It is assumed | |
| // that this has been handled by CSRBlockSparseMatrix. | |
| // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 | |
| // blocked pattern, x is a vector and b is vector. Weights are stored for this | |
| // routine by making each 4x4 block contiguous. Blocks are ordered in standard | |
| // row-major format. column indices are converted to deltas and then multiplied | |
| // by 2 to convert to bytes, so that the value can be used directly to offset | |
| // the pointer into the rhs vector. | |
| // | |
| // NOTE: The bias is expected to have be multiplied by .25f prior to calling | |
| // this function. This is automatically taken care of in SparseLinearLayer. | |
| // The bias is reconstructed through horizontal additions, leads to a small | |
| // speedup by reducing latencies at the end of the loop. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<std::is_same<WeightType, bfloat16>::value && | |
| std::is_same<RhsType, float>::value && | |
| std::is_same<OutType, float>::value>::type | |
| SpMV_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const float* rhs_ptr, | |
| const float* bias_ptr, float* out_ptr, int64_t assigned_rows, | |
| int64_t rows /* only used in SpMM variants */, | |
| int64_t cols /* only used in SpMM variants */, int relu) { | |
| /* This instrinsic version exists for reference, note that in the | |
| intrinsic version col_deltas_bytes should NOT actually be in bytes, | |
| but rather elements. Intrinsics are 25-35% slower than the | |
| assembly version. | |
| for (int r = 0; r < rows; r += 4) { | |
| int reduced_col_count = nnz_per_row[r / 4]; | |
| float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); | |
| for (int c = 0; c < reduced_col_count; ++c) { | |
| int32_t offset = *col_deltas_bytes; col_deltas_bytes++; | |
| rhs_ptr += offset; | |
| float32x4_t rhs = vld1q_f32(rhs_ptr); | |
| uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); | |
| float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); | |
| float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); | |
| float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); | |
| accum0 = vmlaq_f32(accum0, lhs0, rhs); | |
| accum1 = vmlaq_f32(accum1, lhs1, rhs); | |
| accum2 = vmlaq_f32(accum2, lhs2, rhs); | |
| accum3 = vmlaq_f32(accum3, lhs3, rhs); | |
| } | |
| float32x4_t reduce0 = vpaddq_f32(accum0, accum1); | |
| float32x4_t reduce1 = vpaddq_f32(accum2, accum3); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out_ptr + r, reduce2); | |
| } */ | |
| // If the relu is handled in the routine with a comparison and vbit (insert | |
| // if true), or by branching, then it is slightly, but noticeably slower | |
| // ~5%, the outer branch avoids that penalty. | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Convert bfloat16 -> float32. | |
| "shll v4.4s, v2.4h, #16\n" | |
| "shll2 v5.4s, v2.8h, #16\n" | |
| "shll v6.4s, v3.4h, #16\n" | |
| "shll2 v7.4s, v3.8h, #16\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" | |
| "fmla v29.4s, v5.4s, v0.4s\n" | |
| "fmla v30.4s, v6.4s, v0.4s\n" | |
| "fmla v31.4s, v7.4s, v0.4s\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result | |
| "faddp v28.4s, v28.4s, v29.4s\n" | |
| "faddp v30.4s, v30.4s, v31.4s\n" | |
| "faddp v28.4s, v28.4s, v30.4s\n" | |
| // Do relu if requested. | |
| "fmax v28.4s, v28.4s, v25.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Convert bfloat16 -> float32. | |
| "shll v4.4s, v2.4h, #16\n" | |
| "shll2 v5.4s, v2.8h, #16\n" | |
| "shll v6.4s, v3.4h, #16\n" | |
| "shll2 v7.4s, v3.8h, #16\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" | |
| "fmla v29.4s, v5.4s, v0.4s\n" | |
| "fmla v30.4s, v6.4s, v0.4s\n" | |
| "fmla v31.4s, v7.4s, v0.4s\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "faddp v28.4s, v28.4s, v29.4s\n" | |
| "faddp v30.4s, v30.4s, v31.4s\n" | |
| "faddp v28.4s, v28.4s, v30.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 | |
| // blocked pattern, x is a fat vector with 5 columns and b is vector. b is | |
| // broadcast. Weights are stored for this routine by making each 4x4 block | |
| // contiguous. Blocks are ordered in standard row-major format. column indices | |
| // are converted to deltas and then multiplied by 2 to convert to bytes, so | |
| // that the value can be used directly to offset the pointer into the rhs | |
| // vector. | |
| // | |
| // NOTE: The bias is expected to have be multiplied by .25f prior to calling | |
| // this function. This is automatically taken care of in SparseLinearLayer. | |
| // The bias is reconstructed through horizontal additions, leads to a small | |
| // speedup by reducing latencies at the end of the loop. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<std::is_same<WeightType, bfloat16>::value && | |
| std::is_same<RhsType, float>::value && | |
| std::is_same<OutType, float>::value>::type | |
| SpMM5_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const float* rhs_ptr, | |
| const float* bias_ptr, float* out_ptr, int64_t assigned_rows, | |
| int64_t rows, int64_t cols, int relu) { | |
| /* This instrinsic version exists for reference, note that in the | |
| intrinsic version col_deltas_bytes should NOT actually be in bytes, | |
| but rather elements. Intrinsics are 25-35% slower than the | |
| assembly version. | |
| for (int r = 0; r < rows; r += 4) { | |
| int reduced_col_count = nnz_per_row[r / 4]; | |
| float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); | |
| float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); | |
| ... | |
| for (int c = 0; c < reduced_col_count; ++c) { | |
| int32_t offset = *col_deltas_bytes; col_deltas_bytes++; | |
| rhs_ptr += offset; | |
| float32x4_t rhs = vld1q_f32(rhs_ptr); | |
| float32x4_t rhs2 = vld1q_f32(rhs2_ptr); | |
| float32x4_t rhs3 = vld1q_f32(rhs3_ptr); | |
| float32x4_t rhs4 = vld1q_f32(rhs4_ptr); | |
| float32x4_t rhs5 = vld1q_f32(rhs5_ptr); | |
| uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); | |
| float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); | |
| float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); | |
| float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); | |
| accum0 = vmlaq_f32(accum0, lhs0, rhs); | |
| accum1 = vmlaq_f32(accum1, lhs1, rhs); | |
| accum2 = vmlaq_f32(accum2, lhs2, rhs); | |
| accum3 = vmlaq_f32(accum3, lhs3, rhs); | |
| accum4 = vmlaq_f32(accum0, lhs0, rhs2); | |
| accum5 = vmlaq_f32(accum1, lhs1, rhs2); | |
| accum6 = vmlaq_f32(accum2, lhs2, rhs2); | |
| accum7 = vmlaq_f32(accum3, lhs3, rhs2); | |
| ... | |
| } | |
| float32x4_t reduce0 = vpaddq_f32(accum0, accum1); | |
| float32x4_t reduce1 = vpaddq_f32(accum2, accum3); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out_ptr + r, reduce2); | |
| float32x4_t reduce0 = vpaddq_f32(accum4, accum5); | |
| float32x4_t reduce1 = vpaddq_f32(accum6, accum7); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out2_ptr + r, reduce2); | |
| ... | |
| } */ | |
| // If the relu is handled in the routine with a comparison and vbit (insert | |
| // if true), or by branching, then it is slightly, but noticeably slower | |
| // ~5%, the outer branch avoids that penalty. | |
| // | |
| // Pointers to the columns. | |
| const float* rhs2_ptr = rhs_ptr + cols; | |
| float* out2_ptr = out_ptr + rows; | |
| const float* rhs3_ptr = rhs_ptr + 2 * cols; | |
| float* out3_ptr = out_ptr + 2 * rows; | |
| const float* rhs4_ptr = rhs_ptr + 3 * cols; | |
| float* out4_ptr = out_ptr + 3 * rows; | |
| const float* rhs5_ptr = rhs_ptr + 4 * cols; | |
| float* out5_ptr = out_ptr + 4 * rows; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" | |
| "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" | |
| "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" | |
| "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Convert bfloat16 -> float32. | |
| "shll v4.4s, v2.4h, #16\n" | |
| "shll2 v5.4s, v2.8h, #16\n" | |
| "shll v6.4s, v3.4h, #16\n" | |
| "shll2 v7.4s, v3.8h, #16\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column | |
| "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column | |
| "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column | |
| "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column | |
| "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column | |
| "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column | |
| "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column | |
| "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column | |
| "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column | |
| "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column | |
| "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column | |
| "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column | |
| "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column | |
| "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column | |
| "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column | |
| "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column | |
| "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column | |
| "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column | |
| "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column | |
| "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "movi v0.4s, #0\n" | |
| "faddp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "faddp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "faddp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "faddp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "faddp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Do relu as requested. | |
| "fmax v28.4s, v28.4s, v0.4s\n" | |
| "fmax v23.4s, v23.4s, v0.4s\n" | |
| "fmax v19.4s, v19.4s, v0.4s\n" | |
| "fmax v15.4s, v15.4s, v0.4s\n" | |
| "fmax v11.4s, v11.4s, v0.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), | |
| [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), | |
| [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), | |
| [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" | |
| "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" | |
| "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" | |
| "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Convert bfloat16 -> float32. | |
| "shll v4.4s, v2.4h, #16\n" | |
| "shll2 v5.4s, v2.8h, #16\n" | |
| "shll v6.4s, v3.4h, #16\n" | |
| "shll2 v7.4s, v3.8h, #16\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column | |
| "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column | |
| "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column | |
| "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column | |
| "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column | |
| "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column | |
| "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column | |
| "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column | |
| "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column | |
| "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column | |
| "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column | |
| "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column | |
| "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column | |
| "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column | |
| "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column | |
| "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column | |
| "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column | |
| "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column | |
| "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column | |
| "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "faddp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "faddp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "faddp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "faddp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "faddp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), | |
| [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), | |
| [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), | |
| [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // float implementations below the line. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<std::is_same<WeightType, float>::value && | |
| std::is_same<RhsType, float>::value && | |
| std::is_same<OutType, float>::value>::type | |
| SpMV_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const float* rhs_ptr, | |
| const float* bias_ptr, float* out_ptr, int64_t assigned_rows, | |
| int64_t rows /* only used in SpMM variants */, | |
| int64_t cols /* only used in SpMM variants */, int relu) { | |
| /* This instrinsic version exists for reference, note that in the | |
| intrinsic version col_deltas_bytes should NOT actually be in bytes, | |
| but rather elements. Intrinsics are 25-35% slower than the | |
| assembly version. | |
| for (int r = 0; r < rows; r += 4) { | |
| int reduced_col_count = nnz_per_row[r / 4]; | |
| float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); | |
| for (int c = 0; c < reduced_col_count; ++c) { | |
| int32_t offset = *col_deltas_bytes; col_deltas_bytes++; | |
| rhs_ptr += offset; | |
| float32x4_t rhs = vld1q_f32(rhs_ptr); | |
| uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); | |
| float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); | |
| float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); | |
| float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); | |
| accum0 = vmlaq_f32(accum0, lhs0, rhs); | |
| accum1 = vmlaq_f32(accum1, lhs1, rhs); | |
| accum2 = vmlaq_f32(accum2, lhs2, rhs); | |
| accum3 = vmlaq_f32(accum3, lhs3, rhs); | |
| } | |
| float32x4_t reduce0 = vpaddq_f32(accum0, accum1); | |
| float32x4_t reduce1 = vpaddq_f32(accum2, accum3); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out_ptr + r, reduce2); | |
| } */ | |
| // If the relu is handled in the routine with a comparison and vbit (insert | |
| // if true), or by branching, then it is slightly, but noticeably slower | |
| // ~5%, the outer branch avoids that penalty. | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" | |
| "fmla v29.4s, v5.4s, v0.4s\n" | |
| "fmla v30.4s, v6.4s, v0.4s\n" | |
| "fmla v31.4s, v7.4s, v0.4s\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "faddp v28.4s, v28.4s, v29.4s\n" | |
| "faddp v30.4s, v30.4s, v31.4s\n" | |
| "faddp v28.4s, v28.4s, v30.4s\n" | |
| // Do relu as requested. | |
| "fmax v28.4s, v28.4s, v25.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" | |
| "fmla v29.4s, v5.4s, v0.4s\n" | |
| "fmla v30.4s, v6.4s, v0.4s\n" | |
| "fmla v31.4s, v7.4s, v0.4s\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "faddp v28.4s, v28.4s, v29.4s\n" | |
| "faddp v30.4s, v30.4s, v31.4s\n" | |
| "faddp v28.4s, v28.4s, v30.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 | |
| // blocked pattern, x is a fat vector with 5 columns and b is vector. b is | |
| // broadcast. Weights are stored for this routine by making each 4x4 block | |
| // contiguous. Blocks are ordered in standard row-major format. column indices | |
| // are converted to deltas and then multiplied by 2 to convert to bytes, so | |
| // that the value can be used directly to offset the pointer into the rhs | |
| // vector. | |
| // | |
| // NOTE: The bias is expected to have be multiplied by .25f prior to calling | |
| // this function. This is automatically taken care of in sparse_linear_layer. | |
| // The bias is reconstructed through horizontal additions, leads to a small | |
| // speedup by reducing latencies at the end of the loop. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<std::is_same<WeightType, float>::value && | |
| std::is_same<RhsType, float>::value && | |
| std::is_same<OutType, float>::value>::type | |
| SpMM5_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const float* rhs_ptr, | |
| const float* bias_ptr, float* out_ptr, int64_t assigned_rows, | |
| int64_t rows, int64_t cols, int relu) { | |
| /* This instrinsic version exists for reference, note that in the | |
| intrinsic version col_deltas_bytes should NOT actually be in bytes, | |
| but rather elements. Intrinsics are 25-35% slower than the | |
| assembly version. | |
| for (int r = 0; r < rows; r += 4) { | |
| int reduced_col_count = nnz_per_row[r / 4]; | |
| float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); | |
| float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); | |
| float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); | |
| float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); | |
| float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); | |
| ... | |
| for (int c = 0; c < reduced_col_count; ++c) { | |
| int32_t offset = *col_deltas_bytes; col_deltas_bytes++; | |
| rhs_ptr += offset; | |
| float32x4_t rhs = vld1q_f32(rhs_ptr); | |
| float32x4_t rhs2 = vld1q_f32(rhs2_ptr); | |
| float32x4_t rhs3 = vld1q_f32(rhs3_ptr); | |
| float32x4_t rhs4 = vld1q_f32(rhs4_ptr); | |
| float32x4_t rhs5 = vld1q_f32(rhs5_ptr); | |
| uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; | |
| float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); | |
| float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); | |
| float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); | |
| float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); | |
| accum0 = vmlaq_f32(accum0, lhs0, rhs); | |
| accum1 = vmlaq_f32(accum1, lhs1, rhs); | |
| accum2 = vmlaq_f32(accum2, lhs2, rhs); | |
| accum3 = vmlaq_f32(accum3, lhs3, rhs); | |
| accum4 = vmlaq_f32(accum0, lhs0, rhs2); | |
| accum5 = vmlaq_f32(accum1, lhs1, rhs2); | |
| accum6 = vmlaq_f32(accum2, lhs2, rhs2); | |
| accum7 = vmlaq_f32(accum3, lhs3, rhs2); | |
| ... | |
| } | |
| float32x4_t reduce0 = vpaddq_f32(accum0, accum1); | |
| float32x4_t reduce1 = vpaddq_f32(accum2, accum3); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out_ptr + r, reduce2); | |
| float32x4_t reduce0 = vpaddq_f32(accum4, accum5); | |
| float32x4_t reduce1 = vpaddq_f32(accum6, accum7); | |
| float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); | |
| vst1q_f32(out2_ptr + r, reduce2); | |
| ... | |
| } */ | |
| // If the relu is handled in the routine with a comparison and vbit (insert | |
| // if true), or by branching, then it is slightly, but noticeably slower | |
| // ~5%, the outer branch avoids that penalty. | |
| // | |
| // Pointers to the columns. | |
| const float* rhs2_ptr = rhs_ptr + cols; | |
| float* out2_ptr = out_ptr + rows; | |
| const float* rhs3_ptr = rhs_ptr + 2 * cols; | |
| float* out3_ptr = out_ptr + 2 * rows; | |
| const float* rhs4_ptr = rhs_ptr + 3 * cols; | |
| float* out4_ptr = out_ptr + 3 * rows; | |
| const float* rhs5_ptr = rhs_ptr + 4 * cols; | |
| float* out5_ptr = out_ptr + 4 * rows; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" | |
| "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" | |
| "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" | |
| "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column | |
| "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column | |
| "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column | |
| "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column | |
| "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column | |
| "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column | |
| "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column | |
| "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column | |
| "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column | |
| "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column | |
| "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column | |
| "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column | |
| "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column | |
| "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column | |
| "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column | |
| "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column | |
| "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column | |
| "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column | |
| "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column | |
| "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "movi v0.4s, #0\n" | |
| "faddp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "faddp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "faddp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "faddp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "faddp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Do relu as requested. | |
| "fmax v28.4s, v28.4s, v0.4s\n" | |
| "fmax v23.4s, v23.4s, v0.4s\n" | |
| "fmax v19.4s, v19.4s, v0.4s\n" | |
| "fmax v15.4s, v15.4s, v0.4s\n" | |
| "fmax v11.4s, v11.4s, v0.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), | |
| [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), | |
| [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), | |
| [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" | |
| "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" | |
| "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" | |
| "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" | |
| "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" | |
| // Multiply-accumulate. | |
| "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column | |
| "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column | |
| "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column | |
| "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column | |
| "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column | |
| "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column | |
| "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column | |
| "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column | |
| "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column | |
| "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column | |
| "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column | |
| "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column | |
| "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column | |
| "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column | |
| "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column | |
| "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column | |
| "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column | |
| "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column | |
| "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column | |
| "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "faddp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "faddp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "faddp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "faddp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "faddp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "faddp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "faddp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), | |
| [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), | |
| [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), | |
| [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), | |
| [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), | |
| [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), | |
| [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), | |
| [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the output must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if< | |
| IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && | |
| std::is_same<OutType, typename TypeOfProduct<WeightType, | |
| RhsType>::type>::value>::type | |
| SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, | |
| int64_t rows /* only used in SpMM variants */, | |
| int64_t cols /* only used in SpMM variants */, int relu) { | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| // Do relu if requested. | |
| "smax v28.4s, v28.4s, v25.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the output must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if< | |
| IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && | |
| std::is_same<OutType, typename TypeOfProduct<WeightType, | |
| RhsType>::type>::value>::type | |
| SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, | |
| int relu) { | |
| // Pointers to the columns. | |
| const RhsType* rhs2_ptr = rhs_ptr + cols; | |
| OutType* out2_ptr = out_ptr + rows; | |
| const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; | |
| OutType* out3_ptr = out_ptr + 2 * rows; | |
| const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; | |
| OutType* out4_ptr = out_ptr + 3 * rows; | |
| const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; | |
| OutType* out5_ptr = out_ptr + 4 * rows; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "movi v0.4s, #0\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Do relu as requested. | |
| "smax v28.4s, v28.4s, v0.4s\n" | |
| "smax v23.4s, v23.4s, v0.4s\n" | |
| "smax v19.4s, v19.4s, v0.4s\n" | |
| "smax v15.4s, v15.4s, v0.4s\n" | |
| "smax v11.4s, v11.4s, v0.4s\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the bias must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<IsFixed16Type<WeightType>::value && | |
| IsFixed16Type<RhsType>::value && | |
| IsFixed16Type<OutType>::value>::type | |
| SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, | |
| int64_t rows /* only used in SpMM variants */, | |
| int64_t cols /* only used in SpMM variants */, int relu) { | |
| constexpr int kShiftAmount = 15 - WeightType::kExponentBits - | |
| RhsType::kExponentBits + OutType::kExponentBits; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| // Do relu if requested. | |
| "smax v28.4s, v28.4s, v25.4s\n" | |
| "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v26.4h}, [%[out_ptr]], #8\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v26.4h}, [%[out_ptr]], #8\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the output must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if<IsFixed16Type<WeightType>::value && | |
| IsFixed16Type<RhsType>::value && | |
| IsFixed16Type<OutType>::value>::type | |
| SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, | |
| int relu) { | |
| constexpr int kShiftAmount = 15 - WeightType::kExponentBits - | |
| RhsType::kExponentBits + OutType::kExponentBits; | |
| // Pointers to the columns. | |
| const RhsType* rhs2_ptr = rhs_ptr + cols; | |
| OutType* out2_ptr = out_ptr + rows; | |
| const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; | |
| OutType* out3_ptr = out_ptr + 2 * rows; | |
| const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; | |
| OutType* out4_ptr = out_ptr + 3 * rows; | |
| const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; | |
| OutType* out5_ptr = out_ptr + 4 * rows; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "movi v0.4s, #0\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Do relu as requested. | |
| "smax v28.4s, v28.4s, v0.4s\n" | |
| "smax v23.4s, v23.4s, v0.4s\n" | |
| "smax v19.4s, v19.4s, v0.4s\n" | |
| "smax v15.4s, v15.4s, v0.4s\n" | |
| "smax v11.4s, v11.4s, v0.4s\n" | |
| "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" | |
| "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" | |
| "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" | |
| "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" | |
| "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v26.4h}, [%[out_ptr]], #8\n" | |
| "st1 {v22.4h}, [%[out2_ptr]], #8\n" | |
| "st1 {v18.4h}, [%[out3_ptr]], #8\n" | |
| "st1 {v14.4h}, [%[out4_ptr]], #8\n" | |
| "st1 {v10.4h}, [%[out5_ptr]], #8\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" | |
| "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" | |
| "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" | |
| "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" | |
| "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v26.4h}, [%[out_ptr]], #8\n" | |
| "st1 {v22.4h}, [%[out2_ptr]], #8\n" | |
| "st1 {v18.4h}, [%[out3_ptr]], #8\n" | |
| "st1 {v14.4h}, [%[out4_ptr]], #8\n" | |
| "st1 {v10.4h}, [%[out5_ptr]], #8\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the output must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if< | |
| IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && | |
| IsFixed32Type<OutType>::value && | |
| !std::is_same<OutType, typename TypeOfProduct<WeightType, | |
| RhsType>::type>::value>::type | |
| SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, | |
| int64_t rows /* only used in SpMM variants */, | |
| int64_t cols /* only used in SpMM variants */, int relu) { | |
| constexpr int kShiftAmount = | |
| TypeOfProduct<WeightType, RhsType>::type::kMantissaBits - | |
| OutType::kMantissaBits; | |
| static_assert(kShiftAmount > 0, | |
| "Result must have fewer mantissa bits than product"); | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| // Do relu if requested. | |
| "smax v28.4s, v28.4s, v25.4s\n" | |
| "srshr v28.4s, v28.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "movi v25.4s, #0\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // accum_0 = 0 | |
| "dup v29.4s, v27.s[1]\n" // accum_1 = 0 | |
| "dup v30.4s, v27.s[2]\n" // accum_2 = 0 | |
| "dup v31.4s, v27.s[3]\n" // accum_3 = 0 | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| // Duplicate the lower half into the upper half. | |
| "mov v0.d[1], v0.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" | |
| "smlal v30.4s, v3.4h, v0.4h\n" | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| // Horizontally add accumulators and store result. | |
| "addp v28.4s, v28.4s, v29.4s\n" | |
| "addp v30.4s, v30.4s, v31.4s\n" | |
| "addp v28.4s, v28.4s, v30.4s\n" | |
| "srshr v28.4s, v28.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| // Note that the number of exponent bits in the output must exactly match | |
| // the sum of the input and rhs types. | |
| template <typename WeightType, typename RhsType, typename OutType> | |
| typename std::enable_if< | |
| IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value && | |
| IsFixed32Type<OutType>::value && | |
| !std::is_same<OutType, typename TypeOfProduct<WeightType, | |
| RhsType>::type>::value>::type | |
| SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
| const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
| const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
| OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, | |
| int relu) { | |
| constexpr int kShiftAmount = | |
| TypeOfProduct<WeightType, RhsType>::type::kMantissaBits - | |
| OutType::kMantissaBits; | |
| static_assert(kShiftAmount > 0, | |
| "Result must have fewer mantissa bits than product"); | |
| // Pointers to the columns. | |
| const RhsType* rhs2_ptr = rhs_ptr + cols; | |
| OutType* out2_ptr = out_ptr + rows; | |
| const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; | |
| OutType* out3_ptr = out_ptr + 2 * rows; | |
| const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; | |
| OutType* out4_ptr = out_ptr + 3 * rows; | |
| const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; | |
| OutType* out5_ptr = out_ptr + 4 * rows; | |
| if (relu) { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "movi v0.4s, #0\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| // Do relu as requested. | |
| "smax v28.4s, v28.4s, v0.4s\n" | |
| "smax v23.4s, v23.4s, v0.4s\n" | |
| "smax v19.4s, v19.4s, v0.4s\n" | |
| "smax v15.4s, v15.4s, v0.4s\n" | |
| "smax v11.4s, v11.4s, v0.4s\n" | |
| "srshr v28.4s, v28.4s, %[shift_amount]\n" | |
| "srshr v23.4s, v23.4s, %[shift_amount]\n" | |
| "srshr v19.4s, v19.4s, %[shift_amount]\n" | |
| "srshr v15.4s, v15.4s, %[shift_amount]\n" | |
| "srshr v11.4s, v11.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } else { | |
| asm( | |
| // Load the first two column deltas. | |
| "ldrsh x7, [%[col_deltas_bytes]], #2\n" | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // ld1 doesn't support pre-index, so we do the first addition here. | |
| "add %[rhs_ptr], %[rhs_ptr], x7\n" | |
| "add %[rhs2_ptr], %[rhs2_ptr], x7\n" | |
| "add %[rhs3_ptr], %[rhs3_ptr], x7\n" | |
| "add %[rhs4_ptr], %[rhs4_ptr], x7\n" | |
| "add %[rhs5_ptr], %[rhs5_ptr], x7\n" | |
| LABEL_ROW_LOOP | |
| ":\n" | |
| // Load the bias. | |
| "ld1 {v27.4s}, [%[bias_ptr]], #16\n" | |
| // Zero out local accumulators. | |
| "dup v28.4s, v27.s[0]\n" // for 1st column | |
| "dup v29.4s, v27.s[1]\n" // for 1st column | |
| "dup v30.4s, v27.s[2]\n" // for 1st column | |
| "dup v31.4s, v27.s[3]\n" // for 1st column | |
| "dup v23.4s, v27.s[0]\n" // for 2nd column | |
| "dup v24.4s, v27.s[1]\n" // for 2nd column | |
| "dup v25.4s, v27.s[2]\n" // for 2nd column | |
| "dup v26.4s, v27.s[3]\n" // for 2nd column | |
| "dup v19.4s, v27.s[0]\n" // for 3rd column | |
| "dup v20.4s, v27.s[1]\n" // for 3rd column | |
| "dup v21.4s, v27.s[2]\n" // for 3rd column | |
| "dup v22.4s, v27.s[3]\n" // for 3rd column | |
| "dup v15.4s, v27.s[0]\n" // for 4th column | |
| "dup v16.4s, v27.s[1]\n" // for 4th column | |
| "dup v17.4s, v27.s[2]\n" // for 4th column | |
| "dup v18.4s, v27.s[3]\n" // for 4th column | |
| "dup v11.4s, v27.s[0]\n" // for 5th column | |
| "dup v12.4s, v27.s[1]\n" // for 5th column | |
| "dup v13.4s, v27.s[2]\n" // for 5th column | |
| "dup v14.4s, v27.s[3]\n" // for 5th column | |
| // Update the stopping condition for this set of rows. | |
| "ldr w6, [%[nnz_per_row]], #4\n" | |
| "cmp w6, #0\n" | |
| // Skip the body if there isn't anything in this row. | |
| "beq " LABEL_SKIP_COL_LOOP "f\n" | |
| LABEL_COL_LOOP | |
| ":\n" | |
| // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. | |
| "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" | |
| "mov v0.d[1], v0.d[0]\n" | |
| "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" | |
| "mov v1.d[1], v1.d[0]\n" | |
| "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" | |
| "mov v8.d[1], v8.d[0]\n" | |
| "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" | |
| "mov v9.d[1], v9.d[0]\n" | |
| "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" | |
| "mov v10.d[1], v10.d[0]\n" | |
| // Start this load now, which we won't need until the end of the loop. | |
| "ldrsh x8, [%[col_deltas_bytes]], #2\n" | |
| // Load 16 Lhs cells corresponding to a 4x4 block. | |
| "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" | |
| // Multiply-accumulate. | |
| "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column | |
| "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column | |
| "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column | |
| "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh | |
| "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column | |
| "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column | |
| "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column | |
| "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column | |
| "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column | |
| "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column | |
| "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column | |
| "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column | |
| "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column | |
| "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column | |
| "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column | |
| "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column | |
| "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column | |
| "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column | |
| // Loop. Decrement loop index. | |
| "subs w6, w6, #1\n" // decrement (reduced) columns left | |
| "bne " LABEL_COL_LOOP "b\n" | |
| LABEL_SKIP_COL_LOOP | |
| ":\n" | |
| "addp v28.4s, v28.4s, v29.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v24.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v20.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v16.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v12.4s\n" // 5th column | |
| "addp v30.4s, v30.4s, v31.4s\n" // 1st column | |
| "addp v25.4s, v25.4s, v26.4s\n" // 2nd column | |
| "addp v21.4s, v21.4s, v22.4s\n" // 3rd column | |
| "addp v17.4s, v17.4s, v18.4s\n" // 4th column | |
| "addp v13.4s, v13.4s, v14.4s\n" // 5th column | |
| "addp v28.4s, v28.4s, v30.4s\n" // 1st column | |
| "addp v23.4s, v23.4s, v25.4s\n" // 2nd column | |
| "addp v19.4s, v19.4s, v21.4s\n" // 3rd column | |
| "addp v15.4s, v15.4s, v17.4s\n" // 4th column | |
| "addp v11.4s, v11.4s, v13.4s\n" // 5th column | |
| "srshr v28.4s, v28.4s, %[shift_amount]\n" | |
| "srshr v23.4s, v23.4s, %[shift_amount]\n" | |
| "srshr v19.4s, v19.4s, %[shift_amount]\n" | |
| "srshr v15.4s, v15.4s, %[shift_amount]\n" | |
| "srshr v11.4s, v11.4s, %[shift_amount]\n" | |
| // Store accumulators. | |
| "st1 {v28.4s}, [%[out_ptr]], #16\n" | |
| "st1 {v23.4s}, [%[out2_ptr]], #16\n" | |
| "st1 {v19.4s}, [%[out3_ptr]], #16\n" | |
| "st1 {v15.4s}, [%[out4_ptr]], #16\n" | |
| "st1 {v11.4s}, [%[out5_ptr]], #16\n" | |
| // Decrement rows remaining. | |
| "subs %[assigned_rows], %[assigned_rows], #1\n" | |
| "bne " LABEL_ROW_LOOP "b\n" | |
| // clang-format off | |
| : // outputs | |
| [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), | |
| [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), | |
| [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), | |
| [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), | |
| [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), | |
| [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), | |
| [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), | |
| [rhs5_ptr] "+r"(rhs5_ptr) | |
| : // inputs | |
| [shift_amount] "I"(kShiftAmount) | |
| : // clobbers | |
| "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", | |
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |
| "v26", "v27", "v28", "v29", "v30", "v31"); | |
| // clang-format on | |
| } | |
| } | |
| template <typename Type> | |
| typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors( | |
| int start, int end, const Type* add1, const Type* add2, Type* result) { | |
| constexpr int kSIMDWidth = 4; | |
| for (int i = start; i < end; i += kSIMDWidth) { | |
| int32x4_t add1_int = vld1q_s32(reinterpret_cast<const int32_t*>(add1 + i)); | |
| int32x4_t add2_int = vld1q_s32(reinterpret_cast<const int32_t*>(add2 + i)); | |
| int32x4_t result_int = vqaddq_s32(add1_int, add2_int); | |
| vst1q_s32(reinterpret_cast<int32_t*>(result + i), result_int); | |
| } | |
| } | |
| template <typename Type> | |
| typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors( | |
| int start, int end, const Type* add1, const Type* add2, Type* result) { | |
| constexpr int kSIMDWidth = 8; | |
| for (int i = start; i < end; i += kSIMDWidth) { | |
| int16x8_t add1_int = vld1q_s16(reinterpret_cast<const int16_t*>(add1 + i)); | |
| int16x8_t add2_int = vld1q_s16(reinterpret_cast<const int16_t*>(add2 + i)); | |
| int16x8_t result_int = vqaddq_s16(add1_int, add2_int); | |
| vst1q_s16(reinterpret_cast<int16_t*>(result + i), result_int); | |
| } | |
| } | |
| } // namespace detail | |
| } // namespace csrblocksparse | |