File size: 833 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from utility.helper import get_transformer_tester

# Checkpoint file handed to get_transformer_tester through `from_path`.
example_path = './result/result_transformer/tera/fmllrBase960-F-N-K-libri/states-1000000.ckpt'
tester = get_transformer_tester(from_path=example_path)

# Dummy batch of spectrograms, laid out as (batch_size, seq_len, feature_size).
spec = torch.zeros((3, 800, 40))

# The four calls below walk through every all_layers / tile combination.
# Each one rebinds `reps`, so only the result of the last call is kept.

# all_layers=True, tile=True
#   reps.shape: (batch_size, num_hidden_layers, seq_len, hidden_size)
reps = tester.forward(spec=spec, all_layers=True, tile=True)

# all_layers=True, tile=False
#   reps.shape: (batch_size, num_hidden_layers, seq_len // downsample_rate, hidden_size)
reps = tester.forward(spec=spec, all_layers=True, tile=False)

# all_layers=False, tile=True
#   reps.shape: (batch_size, seq_len, hidden_size)
reps = tester.forward(spec=spec, all_layers=False, tile=True)

# all_layers=False, tile=False
#   reps.shape: (batch_size, seq_len // downsample_rate, hidden_size)
reps = tester.forward(spec=spec, all_layers=False, tile=False)