Spaces:
Sleeping
Sleeping
| // Copyright 2021 Google LLC | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| namespace { | |
| using csrblocksparse::ARInputsMode; | |
| template <typename GRUStateType, typename InputType, typename SampleType = void, | |
| csrblocksparse::ARInputsMode kInputsMode, bool kSplitGates> | |
| csrblocksparse::CacheAlignedVector<GRUStateType> TestGruGates() { | |
| using SampleWeightType = float; | |
| constexpr int kStateSize = 16; | |
| csrblocksparse::CacheAlignedVector<SampleWeightType> qr(6 * kStateSize); | |
| csrblocksparse::CacheAlignedVector<SampleWeightType> w(3 * kStateSize); | |
| csrblocksparse::CacheAlignedVector<InputType> gru_gates(3 * kStateSize); | |
| csrblocksparse::CacheAlignedVector<InputType> gru_other_gates(3 * kStateSize); | |
| csrblocksparse::CacheAlignedVector<InputType> conditioning(3 * kStateSize); | |
| csrblocksparse::CacheAlignedVector<GRUStateType> gru_h(kStateSize); | |
| csrblocksparse::GruGates<GRUStateType, InputType, SampleType> gru_gates_impl; | |
| const SampleType kCoarseAtSMinus1(0.03f); | |
| const SampleType kFineAtSMinus1(0.07f); | |
| const SampleType kCoarseAtS(-0.02f); | |
| qr.FillOnes(); | |
| w.FillOnes(); | |
| gru_gates.FillRandom(); | |
| gru_other_gates.FillRandom(); | |
| conditioning.FillRandom(); | |
| gru_h.FillZero(); | |
| gru_gates_impl.template GruWithARInput<kInputsMode, kSplitGates>( | |
| /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(), | |
| conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1, | |
| qr.data(), | |
| /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(), | |
| gru_other_gates.data()); | |
| return gru_h; | |
| } | |
| TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) { | |
| // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers | |
| // will also need to change. | |
| const std::vector<float> kGoldenValues = { | |
| 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f, | |
| 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f}; | |
| csrblocksparse::CacheAlignedVector<float> gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k2ARInputs, | |
| /*kSplitGates=*/true>(); | |
| ASSERT_EQ(kGoldenValues.size(), gru_h.size()); | |
| for (int i = 0; i < gru_h.size(); ++i) { | |
| EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; | |
| } | |
| } | |
| TEST(GruGates, FloatWaveRNNFineMatchesGolden) { | |
| // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers | |
| // will also need to change. | |
| const std::vector<float> kGoldenValues = { | |
| 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f, | |
| 0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f}; | |
| csrblocksparse::CacheAlignedVector<float> gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k3ARInputs, | |
| /*kSplitGates=*/true>(); | |
| ASSERT_EQ(kGoldenValues.size(), gru_h.size()); | |
| for (int i = 0; i < gru_h.size(); ++i) { | |
| EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; | |
| } | |
| } | |
| TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) { | |
| // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers | |
| // will also need to change. | |
| const std::vector<float> kGoldenValues = { | |
| 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f, | |
| 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f}; | |
| csrblocksparse::CacheAlignedVector<float> gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k2ARInputs, | |
| /*kSplitGates=*/false>(); | |
| ASSERT_EQ(kGoldenValues.size(), gru_h.size()); | |
| for (int i = 0; i < gru_h.size(); ++i) { | |
| EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; | |
| } | |
| } | |
| TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) { | |
| using GRUMatMulOutType = csrblocksparse::fixed32<11>; | |
| using GRUStateType = csrblocksparse::fixed16<2>; | |
| using SampleType = csrblocksparse::fixed16<0>; | |
| csrblocksparse::CacheAlignedVector<float> float_gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k2ARInputs, | |
| /*kSplitGates=*/true>(); | |
| csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h = | |
| TestGruGates<GRUStateType, GRUMatMulOutType, SampleType, | |
| ARInputsMode::k2ARInputs, /*kSplitGates=*/true>(); | |
| ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); | |
| for (int i = 0; i < fixed_gru_h.size(); ++i) { | |
| EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3) | |
| << "i=" << i; | |
| } | |
| } | |
| TEST(GruGates, FixedWaveRNNFineMatchesFloat) { | |
| using GRUMatMulOutType = csrblocksparse::fixed32<11>; | |
| using GRUStateType = csrblocksparse::fixed16<2>; | |
| using SampleType = csrblocksparse::fixed16<0>; | |
| csrblocksparse::CacheAlignedVector<float> float_gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k3ARInputs, | |
| /*kSplitGates=*/true>(); | |
| csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h = | |
| TestGruGates<GRUStateType, GRUMatMulOutType, SampleType, | |
| ARInputsMode::k3ARInputs, /*kSplitGates=*/true>(); | |
| ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); | |
| for (int i = 0; i < fixed_gru_h.size(); ++i) { | |
| EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3) | |
| << "i=" << i; | |
| } | |
| } | |
| TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) { | |
| using GRUMatMulOutType = csrblocksparse::fixed32<11>; | |
| using GRUStateType = csrblocksparse::fixed16<2>; | |
| using SampleType = csrblocksparse::fixed16<0>; | |
| csrblocksparse::CacheAlignedVector<float> float_gru_h = | |
| TestGruGates<float, float, float, ARInputsMode::k2ARInputs, | |
| /*kSplitGates=*/false>(); | |
| csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h = | |
| TestGruGates<GRUStateType, GRUMatMulOutType, SampleType, | |
| ARInputsMode::k2ARInputs, /*kSplitGates=*/false>(); | |
| ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); | |
| for (int i = 0; i < fixed_gru_h.size(); ++i) { | |
| EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3) | |
| << "i=" << i; | |
| } | |
| } | |
| } // namespace | |