File size: 216 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch

from s3prl.nn import FrameLevel


def test_FrameLevel(helpers):
    """Run ``FrameLevel`` on a random batch and check its outputs.

    Builds ``FrameLevel(3, 4, [5, 6])`` and feeds it a random float tensor of
    shape ``(32, 10, 3)``, with ``x_len`` set to 3 for every batch item. It then
    asserts the shape of the returned features and that the returned lengths
    match the ones passed in. The original version only ran the forward pass
    and asserted nothing.

    Args:
        helpers: pytest fixture requested by this test. It is not used in
            the body (presumably provided by the project's conftest), but it
            is kept so that pytest's fixture resolution is unchanged.
    """
    batch_size, seq_len, input_size, output_size = 32, 10, 3, 4
    module = FrameLevel(input_size, output_size, [5, 6])
    x = torch.randn(batch_size, seq_len, input_size)
    # Every batch item gets the same length value (3). This test only checks
    # that lengths come back unchanged, not any masking behavior.
    x_len = (torch.ones(batch_size) * 3).long()
    h, hl = module(x, x_len)

    # NOTE(review): assumes FrameLevel maps the last dim input_size ->
    # output_size independently per frame and returns x_len unchanged.
    # Confirm against the s3prl.nn.FrameLevel implementation.
    assert h.shape == (batch_size, seq_len, output_size)
    assert torch.equal(hl, x_len)