File size: 216 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
from s3prl.nn import FrameLevel
def test_FrameLevel(helpers):
    """Run a forward pass through ``FrameLevel`` and check the output shapes.

    Builds a ``FrameLevel`` with input size 3, output size 4 and hidden sizes
    ``[5, 6]``, feeds it a random batch, and verifies that the returned
    features and lengths have the expected shapes.

    Args:
        helpers: pytest fixture requested by the original test; it is not
            used in the body and is kept so the test signature is unchanged.
    """
    batch_size, seq_len = 32, 10
    input_size, output_size = 3, 4

    module = FrameLevel(input_size, output_size, [5, 6])
    x = torch.randn(batch_size, seq_len, input_size)
    # Every sequence is declared to have 3 valid frames.
    x_len = (torch.ones(batch_size) * 3).long()

    h, hl = module(x, x_len)

    # Expected shapes follow the documented FrameLevel.forward contract:
    # features are (batch, seq_len, output_size) and lengths are (batch,).
    assert tuple(h.shape) == (batch_size, seq_len, output_size)
    assert tuple(hl.shape) == (batch_size,)
|