File size: 1,152 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import random
import argparse
import numpy as np

import torch
from s3prl.nn import S3PRLUpstream
from torch.nn.utils.rnn import pad_sequence

# Scale factor for the random waveform lengths drawn in the script body:
# each waveform gets between SAMPLE_RATE * 1 and SAMPLE_RATE * 15 samples
# (presumably 1-15 seconds of 16 kHz audio -- confirm against the upstream).
SAMPLE_RATE = 16000
# Number of random waveforms generated for the batch fed to the upstream.
BATCH_SIZE = 3


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels and turn off the auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _random_batch(device):
    """Build a padded batch of BATCH_SIZE random 1-D waveforms on ``device``.

    Each waveform length is drawn with ``random.randint`` from
    ``[SAMPLE_RATE * 1, SAMPLE_RATE * 15]`` and filled with ``torch.randn``.

    Args:
        device: Target device handed to ``Tensor.to``.

    Returns:
        A ``(padded, lengths)`` tuple: the waveforms stacked by
        ``pad_sequence(..., batch_first=True)`` and a ``LongTensor`` holding
        the unpadded length of each waveform.
    """
    shortest, longest = SAMPLE_RATE * 1, SAMPLE_RATE * 15
    signals = []
    for _ in range(BATCH_SIZE):
        # Draw the length first, then the noise, once per waveform.
        num_samples = random.randint(shortest, longest)
        signals.append(torch.randn(num_samples).to(device))
    lengths = torch.LongTensor([len(signal) for signal in signals]).to(device)
    padded = pad_sequence(signals, batch_first=True)
    return padded, lengths


if __name__ == "__main__":
    # Command line: positional upstream name plus optional checkpoint,
    # device and seed.
    parser = argparse.ArgumentParser()
    parser.add_argument("upstream")
    parser.add_argument("--ckpt")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    # Seed before building the model and the inputs.
    _seed_everything(args.seed)

    upstream = S3PRLUpstream(args.upstream, args.ckpt).to(args.device)
    wavs, wavs_len = _random_batch(args.device)

    # Single forward pass in eval mode with gradient tracking disabled.
    upstream.eval()
    with torch.no_grad():
        hidden, hidden_len = upstream(wavs, wavs_len)