File size: 642 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import pytest
import torch
from s3prl.nn.common import UtteranceLevel
from s3prl.nn.pooling import (
AttentiveStatisticsPooling,
MeanPooling,
SelfAttentivePooling,
TemporalStatisticsPooling,
)
@pytest.mark.parametrize(
    "pooling_type",
    # Derive the names from the imported classes rather than hand-typed strings,
    # so a rename or typo fails at import time instead of drifting silently
    # (and the pooling imports above are actually used).
    [
        pooling_cls.__name__
        for pooling_cls in (
            MeanPooling,
            TemporalStatisticsPooling,
            AttentiveStatisticsPooling,
            SelfAttentivePooling,
        )
    ],
)
def test_utterance_level_with_pooling(pooling_type: str):
    """Check that ``UtteranceLevel`` yields one vector per utterance for each pooling type.

    The model is built with the pooling layer selected by its class-name string.
    A random padded batch plus per-utterance lengths is fed through it, and the
    test asserts the output is ``(batch_size, output_size)``.
    """
    batch_size, seq_len, input_size, output_size = 32, 100, 256, 64

    # Positional args presumably map to (input_size, output_size, hidden_sizes,
    # activation_type, activation_conf, pooling_type, pooling_conf).
    # TODO confirm against UtteranceLevel's signature.
    model = UtteranceLevel(
        input_size, output_size, [128], "ReLU", None, pooling_type, None
    )

    xs = torch.randn(batch_size, seq_len, input_size)
    # Lengths 1..batch_size are all <= seq_len, so every utterance is valid but
    # most are padded, which exercises the length-aware pooling path.
    xs_len = torch.arange(batch_size) + 1

    output = model(xs, xs_len)
    assert output.shape == (batch_size, output_size)
|