#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Classify a single WAV file with a packaged CNN audio classifier.

The model archive (a zip) is extracted into a private temporary directory,
the label vocabulary and model are loaded from it, the audio is normalised to
floats, and the top-scoring label and its probability are printed.
"""
import argparse
import os
from pathlib import Path
import shutil
import sys
import tempfile
import zipfile

pwd = os.path.abspath(os.path.dirname(__file__))
# Make the project root importable for the project-local imports below.
sys.path.append(os.path.join(pwd, "../../"))

from scipy.io import wavfile
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel


def get_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``model_file``, ``wav_file``, ``device`` and
        ``max_samples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file",
        default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--wav_file",
        # NOTE(review): developer-machine default; pass --wav_file explicitly elsewhere.
        default=r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav",
        type=str,
    )
    parser.add_argument("--device", default="cpu", type=str)
    parser.add_argument(
        "--max_samples",
        default=16000,
        type=int,
        help="Keep at most this many leading samples (frames) of the audio; <= 0 keeps all.",
    )
    args = parser.parse_args()
    return args


def _load_model_and_vocabulary(model_file, device):
    """Extract ``model_file`` to a private temp dir and load the model and vocabulary.

    The archive is expected to contain a top-level directory named after the
    zip's stem, holding a ``vocabulary`` sub-directory and the pretrained model
    files.

    Args:
        model_file: Path to the zipped model package.
        device: torch.device the model is moved to.

    Returns:
        (model, vocabulary) with the model in eval mode on ``device``.
    """
    # A fresh, uniquely named directory avoids clobbering a shared fixed path
    # (and racing with other runs); it is removed automatically afterwards.
    with tempfile.TemporaryDirectory(prefix="vm_sound_classification_") as tmp_dir:
        out_root = Path(tmp_dir)
        print(out_root)
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=out_root)

        tgt_path = out_root / model_file.stem
        vocab_path = tgt_path / "vocabulary"
        # NOTE(review): assumes from_files / from_pretrained read everything
        # eagerly, since the extracted files are deleted when this block exits
        # — confirm for these project classes.
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        model = WaveClassifierPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=tgt_path.as_posix(),
        )

    model.to(device)
    model.eval()
    return model, vocabulary


def _to_float_waveform(waveform):
    """Convert samples returned by ``scipy.io.wavfile.read`` to float64 in roughly [-1, 1).

    Signed-integer PCM is divided by its full-scale value (2**15 for int16,
    2**31 for int32), unsigned PCM (uint8) is re-centred around zero, and
    floating-point data is passed through as already normalised. Multi-channel
    input of shape ``(frames, channels)`` is down-mixed to mono by averaging.

    Raises:
        ValueError: if the sample dtype is not integer or floating point.
    """
    kind = waveform.dtype.kind
    full_scale = 1 << (8 * waveform.dtype.itemsize - 1)
    if kind == "i":
        waveform = waveform / full_scale
    elif kind == "u":
        # Convert before subtracting so unsigned arithmetic cannot wrap around.
        waveform = (waveform.astype("float64") - full_scale) / full_scale
    elif kind == "f":
        waveform = waveform.astype("float64")
    else:
        raise ValueError(f"unsupported wav sample dtype: {waveform.dtype}")

    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)
    return waveform


def main():
    """Run the classifier on ``--wav_file`` and print the predicted label and probability."""
    args = get_args()

    model_file = Path(args.model_file)
    device = torch.device(args.device)

    model, vocabulary = _load_model_and_vocabulary(model_file, device)

    # infer
    # NOTE(review): sample_rate is not checked against the rate the model was
    # trained on (presumably 16 kHz given the 16000-sample default) — confirm.
    sample_rate, waveform = wavfile.read(args.wav_file)
    if args.max_samples > 0:
        waveform = waveform[:args.max_samples]
    waveform = _to_float_waveform(waveform)
    waveform = torch.tensor(waveform, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)  # add batch dim -> (1, num_samples)
    waveform = waveform.to(device)
    print(waveform.shape)

    with torch.no_grad():
        # Call the module (not .forward) so registered hooks run.
        logits = model(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu().numpy()[0]
    prob = probs.cpu().numpy()[0][label_idx]
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    print(label_str)
    print(prob)


if __name__ == '__main__':
    main()