File size: 277 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import pytest
import torch

from s3prl.nn import BeamDecoder


@pytest.mark.extra_dependency
def test_beam_decoder():
    """Run ``BeamDecoder.decode`` on random log-normalized emissions.

    Builds a default ``BeamDecoder`` and feeds it a (4, 100, 31) tensor of
    random scores normalized with ``log_softmax`` over the last dimension.
    It then checks that ``decode`` returns one result per leading-dimension
    entry.
    """
    decoder = BeamDecoder()
    # NOTE(review): names assume a (batch, frames, vocab) layout, which
    # log_softmax over dim=2 suggests; confirm against BeamDecoder.decode.
    batch_size, num_frames, vocab_size = 4, 100, 31
    # A local, seeded generator makes failures reproducible without
    # mutating torch's global RNG state for other tests.
    generator = torch.Generator().manual_seed(0)
    emissions = torch.randn(
        (batch_size, num_frames, vocab_size), generator=generator
    )
    emissions = torch.log_softmax(emissions, dim=2)
    hyps = decoder.decode(emissions)
    # The original test discarded `hyps`, checking only that decode did not
    # raise. NOTE(review): assumes decode returns one entry per batch item
    # (torchaudio CTC decoder convention); confirm.
    assert len(hyps) == batch_size