#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Classify one wav file with an exported TorchScript sound classifier.

The script unpacks a model archive (a zip holding a traced model and a
vocabulary directory), runs the model on the leading samples of a wav file
and prints the predicted label followed by its softmax probability.
"""
import argparse
import os
from pathlib import Path
import shutil
import sys
import tempfile
import zipfile

# Make the repository root importable when the script is run directly.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

from scipy.io import wavfile
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary


def get_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with ``model_file``, ``wav_file`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file",
        default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
        type=str
    )
    # NOTE(review): this default is a developer-local Windows path; pass
    # --wav_file explicitly on any other machine.
    parser.add_argument(
        "--wav_file",
        default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
        type=str
    )
    parser.add_argument("--device", default="cpu", type=str)
    args = parser.parse_args()
    return args


def _extract_model_archive(model_file: Path) -> Path:
    """Unpack ``model_file`` under the system temp dir and return the model dir.

    The extraction directory ``<tmp>/vm_sound_classification`` is removed and
    recreated on every call, so files from a previous run never leak into the
    current one. The directory is left in place after the script finishes.

    Args:
        model_file: path to the zip archive of the exported model.

    Returns:
        ``<tmp>/vm_sound_classification/<archive stem>``.
    """
    # Open the archive first so that a missing or corrupt zip fails before
    # the previous extraction is deleted.
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
        print(out_root.as_posix())
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # Assumes the archive holds a single top-level folder named after the
    # archive itself - TODO confirm against the export script.
    return out_root / model_file.stem


def _load_waveform(wav_file: str, max_samples: int = 16000) -> torch.Tensor:
    """Read a wav file and return its leading samples as a float32 batch.

    Integer PCM is scaled by the full-scale value of its dtype: int16 is
    divided by ``1 << 15`` and int32 by ``1 << 31``. Unsigned 8-bit PCM is
    re-centred around zero first. Floating-point files are used as they are.

    Args:
        wav_file: path of the wav file to read.
        max_samples: number of leading samples to keep.

    Returns:
        A float32 tensor with a leading batch dimension of size 1.
    """
    # The sample rate is read but not checked here.
    # NOTE(review): the model presumably expects one fixed rate - confirm.
    _, samples = wavfile.read(wav_file)
    samples = samples[:max_samples]

    kind = samples.dtype.kind
    if kind in ("i", "u"):
        # A float divisor keeps NumPy from trying to fit 1 << 31 into int32.
        full_scale = float(1 << (8 * samples.dtype.itemsize - 1))
        if kind == "u":
            # Unsigned PCM is offset-binary: silence sits at mid-range.
            samples = samples.astype("float32") - full_scale
        samples = samples / full_scale

    # NOTE(review): multi-channel files keep their channel axis here; the
    # model presumably expects mono input - confirm.
    waveform = torch.tensor(samples, dtype=torch.float32)
    return torch.unsqueeze(waveform, dim=0)


def main():
    """Run the classifier on one wav file and print the label and its probability."""
    args = get_args()

    model_file = Path(args.model_file)
    device = torch.device(args.device)

    tgt_path = _extract_model_archive(model_file)
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    with open(jit_model_file.as_posix(), "rb") as f:
        # map_location lets an archive traced on another device (e.g. CUDA)
        # load on the requested one.
        model = torch.jit.load(f, map_location=device)
    model.to(device)
    model.eval()
    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    # infer
    waveform = _load_waveform(args.wav_file)
    waveform = waveform.to(device)

    with torch.no_grad():
        logits = model.forward(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu()
    probs = probs.cpu()

    # Only the first (and only) item of the batch is reported.
    label_idx = label_idx.numpy()[0]
    prob = probs.numpy()[0][label_idx]

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    print(label_str)
    print(prob)


if __name__ == '__main__':
    main()