File size: 642 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pytest
import torch

from s3prl.nn.common import UtteranceLevel
from s3prl.nn.pooling import (
    AttentiveStatisticsPooling,
    MeanPooling,
    SelfAttentivePooling,
    TemporalStatisticsPooling,
)


@pytest.mark.parametrize(
    "pooling_type",
    (
        "MeanPooling",
        "TemporalStatisticsPooling",
        "AttentiveStatisticsPooling",
        "SelfAttentivePooling",
    ),
)
def test_utterance_level_with_pooling(pooling_type: str):
    """Smoke-test ``UtteranceLevel`` with each pooling type, passed by name.

    A random padded batch plus one length value per utterance goes through the
    model. The result must be a 2-D tensor with one row per utterance and
    ``output_dim`` columns.
    """
    batch_size, num_frames, input_dim, output_dim = 32, 100, 256, 64

    # Same positional call as before. NOTE(review): the argument meanings are
    # presumably (input_size, output_size, hidden_sizes, activation, ?,
    # pooling_type, ?); confirm against UtteranceLevel's signature.
    model = UtteranceLevel(
        input_dim, output_dim, [128], "ReLU", None, pooling_type, None
    )

    # The model must be built before drawing the input, so that the random
    # number stream is consumed in the same order as the original call.
    features = torch.randn(batch_size, num_frames, input_dim)
    # Lengths 1..batch_size (all <= num_frames). Presumably these are the
    # number of valid frames per utterance; verify.
    lengths = torch.arange(1, batch_size + 1)

    pooled = model(features, lengths)
    assert pooled.shape == (batch_size, output_dim)