File size: 1,152 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import random
import argparse
import numpy as np
import torch
from s3prl.nn import S3PRLUpstream
from torch.nn.utils.rnn import pad_sequence
# Smoke test for an S3PRL upstream: build a small batch of random
# variable-length waveforms and run a single forward pass through the model.

SAMPLE_RATE = 16000  # samples per second used to size the synthetic waveforms
BATCH_SIZE = 3  # number of waveforms in the dummy batch

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run one forward pass of an S3PRL upstream on random audio."
    )
    parser.add_argument("upstream", help="upstream name passed to S3PRLUpstream")
    parser.add_argument(
        "--ckpt", help="optional checkpoint passed through to S3PRLUpstream"
    )
    parser.add_argument(
        "--device",
        # Fall back to CPU so the script also runs on machines without CUDA;
        # an explicit --device value is honoured exactly as before.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="torch device to run on (default: cuda if available, else cpu)",
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    args = parser.parse_args()

    # Seed every RNG that can influence the dummy inputs or model construction.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        # Prefer deterministic cuDNN kernels over autotuned ones so repeated
        # runs with the same seed are reproducible.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    upstream = S3PRLUpstream(args.upstream, args.ckpt).to(args.device)
    upstream.eval()

    # Random waveforms of 1 to 15 seconds (at SAMPLE_RATE) each. They are drawn
    # with the default (CPU) generator and then moved, so a given seed yields
    # the same samples regardless of --device.
    wavs = [
        torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device)
        for _ in range(BATCH_SIZE)
    ]
    wavs_len = torch.tensor(
        [wav.size(0) for wav in wavs], dtype=torch.long, device=args.device
    )
    # Zero-pad on the right to the longest waveform -> (BATCH_SIZE, max_len).
    wavs = pad_sequence(wavs, batch_first=True)

    with torch.no_grad():
        # The outputs are not inspected here; completing the forward pass
        # without error is the check this script performs.
        hidden, hidden_len = upstream(wavs, wavs_len)
|