lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
2.77 kB
import torch
import logging
import argparse
from pathlib import Path
from s3prl import hub
from s3prl.util.pseudo_data import get_pseudo_wavs
SAMPLE_RATE = 16000
logger = logging.getLogger(__name__)
def extract_single_name(
    name: str,
    ckpt: str,
    legacy: bool,
    output_dir: str,
    device: str,
    refresh: bool = False,
):
    """Extract hidden states of one ``s3prl.hub`` upstream on pseudo wavs.

    The upstream ``hub.<name>`` is built with ``ckpt`` and ``legacy``, moved to
    ``device``, put in eval mode, and run without gradients on the inputs from
    ``get_pseudo_wavs``. The ``"hidden_states"`` entry of its output is moved
    to CPU and saved, as a list of tensors, to ``<output_dir>/<name>.pt``.

    Args:
        name: Attribute name of the upstream in ``s3prl.hub``.
        ckpt: Checkpoint forwarded to the upstream constructor (the CLI below
            passes ``None`` when ``--ckpt`` is not given).
        legacy: Forwarded to the upstream constructor as ``legacy=``.
        output_dir: Directory for the ``.pt`` file; created if missing.
        device: Torch device used for both the model and the pseudo wavs.
        refresh: If False and ``<name>.pt`` already exists, nothing is done.
            If True, the file is recomputed and overwritten.

    Raises:
        AttributeError: If ``hub`` has no attribute ``name``.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    output_path = (out_dir / f"{name}.pt").resolve()
    if output_path.is_file() and not refresh:
        logger.info("Skip %s: %s already exists", name, output_path)
        return

    model = getattr(hub, name)(ckpt=ckpt, legacy=legacy).to(device)
    model.eval()
    with torch.no_grad():
        hidden_states = model(get_pseudo_wavs(device=device))["hidden_states"]
    hs = [h.detach().cpu() for h in hidden_states]

    # Save to a sibling temp file, then atomically rename it into place.
    # Writing output_path directly could leave a truncated file after a crash
    # or Ctrl-C, and the existence check above would then skip it forever.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    torch.save(hs, str(tmp_path))
    tmp_path.replace(output_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract s3prl upstream hidden states on pseudo wavs."
    )
    parser.add_argument("output_dir", help="directory for the <name>.pt files")
    parser.add_argument(
        "--all",
        action="store_true",
        help="extract for every registered upstream (minus the skipped ones)",
    )
    parser.add_argument("--name", help="single s3prl.hub upstream to extract")
    parser.add_argument("--ckpt", help="checkpoint forwarded to the upstream")
    parser.add_argument("--device", default="cpu")
    parser.add_argument(
        "--legacy",
        action="store_true",
        help="forwarded to the upstream constructor as legacy=True",
    )
    parser.add_argument(
        "--refresh",
        action="store_true",
        help="re-extract even if the output file already exists",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if args.all:
        # Upstreams skipped in --all mode:
        # - customized_upstream: excluded by exact name.
        # - names containing "mos": MOS models have no "hidden_states" key;
        #   they only return a single MOS score.
        # - names containing "stft_mag": this upstream currently needs a
        #   config file to be passed and is low priority, so it is skipped.
        # - names containing "pase": pase_plus needs many dependencies and is
        #   hard to test.
        # - xls_r_1b / xls_r_2b: models too large, download takes too long.
        excluded_names = {"customized_upstream", "xls_r_1b", "xls_r_2b"}
        excluded_substrings = ("mos", "stft_mag", "pase")
        names = [
            name
            for name in hub.options(only_registered_ckpt=True)
            if name not in excluded_names
            and not any(sub in name for sub in excluded_substrings)
        ]
        logger.info(f"Extract for: {names}")
    else:
        # Validate explicitly: an assert would be stripped under `python -O`.
        if args.name is None:
            parser.error("--name is required unless --all is given")
        names = [args.name]

    for name in names:
        extract_single_name(
            name,
            args.ckpt,
            args.legacy,
            args.output_dir,
            args.device,
            args.refresh,
        )