# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import random | |
| import torch | |
| from audiocraft.modules.lstm import StreamableLSTM | |
class TestStreamableLSTM:
    """Shape checks for ``StreamableLSTM`` with the ``skip`` flag off and on."""

    def _check_output_shape(self, skip):
        """Feed a random ``(B, C, T)`` tensor through the LSTM and check the shape.

        The leading underscore keeps this helper from being collected as a
        test on its own; it is shared by both test methods so they cannot
        drift apart.

        Args:
            skip: Value forwarded to the ``skip`` keyword of ``StreamableLSTM``.
        """
        # T is drawn at random on every run; it is reported in the assertion
        # message below so a failing length can be reproduced.
        B, C, T = 4, 2, random.randint(1, 100)
        lstm = StreamableLSTM(C, 3, skip=skip)
        x = torch.randn(B, C, T)
        y = lstm(x)
        assert y.shape == torch.Size([B, C, T]), (
            f"unexpected output shape {tuple(y.shape)} for input shape "
            f"{(B, C, T)} with skip={skip}"
        )

    def test_lstm(self):
        """The output keeps the input shape when ``skip=False``."""
        self._check_output_shape(skip=False)

    def test_lstm_skip(self):
        """The output keeps the input shape when ``skip=True``."""
        self._check_output_shape(skip=True)