lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.16 kB
import logging
import argparse
from pathlib import Path
import torch
from s3prl.nn import S3PRLUpstream
from s3prl.util.pseudo_data import get_pseudo_wavs
from s3prl.util.override import parse_overrides
logger = logging.getLogger(__name__)

# Audio sample rate in Hz. Not referenced by this script; kept as a
# module-level constant.
SAMPLE_RATE = 16000


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    The positional ``name`` selects the S3PRL upstream to load. Arguments the
    parser does not recognise are left over by ``parse_known_args`` so that
    ``main`` can hand them to ``parse_overrides``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an S3PRL upstream on pseudo waveforms and save the hidden "
            "states it returns."
        )
    )
    parser.add_argument("name", help="name of the S3PRL upstream to load")
    parser.add_argument(
        "--output_dir",
        default="./sample_hidden_states",
        help="directory where <name>.pt is written (created if missing)",
    )
    parser.add_argument(
        "--refresh",
        action="store_true",
        help="passed as refresh=True to S3PRLUpstream",
    )
    parser.add_argument(
        "--device",
        # Fall back to CPU so a default invocation also works on machines
        # without CUDA; an explicit --device value is always honoured.
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="torch device to run the upstream on",
    )
    return parser


def main() -> None:
    """Run an upstream on pseudo waveforms and save its hidden states.

    Unrecognised command-line arguments are parsed with ``parse_overrides``
    and forwarded to ``S3PRLUpstream`` as ``extra_conf``. The hidden states
    returned by the model are moved to CPU and saved with ``torch.save`` as a
    list to ``<output_dir>/<name>.pt``.
    """
    args, others = _build_parser().parse_known_args()
    overrides = parse_overrides(others)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    model = S3PRLUpstream(args.name, refresh=args.refresh, extra_conf=overrides)
    model = model.to(args.device)
    model.eval()

    with torch.no_grad():
        # NOTE(review): presumably x is a padded batch of waveforms and x_len
        # holds the per-item lengths — confirm against get_pseudo_wavs.
        x, x_len = get_pseudo_wavs(padded=True)
        hs, _hs_len = model(x.to(args.device), x_len.to(args.device))
        # The lengths returned alongside hs are not saved; each tensor is
        # stored exactly as the model returned it, only moved to CPU.
        hidden_states = [h.detach().cpu() for h in hs]

    output_path = output_dir / f"{args.name}.pt"
    torch.save(hidden_states, output_path)
    logger.info("Saved %d hidden states to %s", len(hidden_states), output_path)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()