Spaces:
Running
Running
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import sys | |
| import tempfile | |
| import zipfile | |
| pwd = os.path.abspath(os.path.dirname(__file__)) | |
| sys.path.append(os.path.join(pwd, "../../")) | |
| from scipy.io import wavfile | |
| import torch | |
| from project_settings import project_path | |
| from toolbox.torch.utils.data.vocabulary import Vocabulary | |
| from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel | |
def get_args():
    """Build the command-line parser for the single-file inference demo and parse argv.

    Returns:
        argparse.Namespace with string attributes ``model_file`` (path to the
        zipped pretrained model), ``wav_file`` (audio file to classify) and
        ``device`` (a torch device string such as ``"cpu"``).
    """
    default_model_file = project_path / "trained_models/vm_sound_classification3.zip"
    # NOTE(review): developer-local absolute path; on any other machine
    # --wav_file must be passed explicitly.
    default_wav_file = r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", default=default_model_file.as_posix(), type=str)
    parser.add_argument("--wav_file", default=default_wav_file, type=str)
    parser.add_argument("--device", default="cpu", type=str)
    return parser.parse_args()
def _extract_model_archive(model_file):
    """Unpack the zipped model into a fixed scratch dir and return the model directory.

    The scratch directory ``<system temp>/vm_sound_classification`` is printed,
    wiped if it already exists, recreated, and the archive is extracted into it.
    The returned path is ``<scratch>/<archive stem>``; this assumes the archive
    holds a top-level folder named after the zip file's stem.
    """
    # Open the archive first so a missing/corrupt zip fails before anything is deleted.
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
        print(out_root)
        if out_root.exists():
            # Clear leftovers from a previous run so stale files cannot be picked up.
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)
    return out_root / model_file.stem


def _to_mono_float(samples):
    """Convert raw ``scipy.io.wavfile.read`` samples to mono floats in roughly [-1, 1].

    Scaling follows the sample dtype instead of assuming 16-bit PCM:
    signed integers are divided by 2**(bits-1) (32768 for int16, as before),
    unsigned integers (8-bit PCM) are re-centred on 2**(bits-1) first, and
    float data is already normalized and is left unscaled. 2-D input of shape
    (frames, channels) is averaged over the channel axis.

    Raises:
        ValueError: if the dtype is not an integer or floating-point type.
    """
    kind = samples.dtype.kind
    half_range = float(1 << (8 * samples.dtype.itemsize - 1))
    if kind == "i":
        samples = samples / half_range
    elif kind == "u":
        samples = (samples.astype("float64") - half_range) / half_range
    elif kind != "f":
        raise ValueError(f"unsupported WAV sample dtype: {samples.dtype}")
    if samples.ndim > 1:
        # The model is fed a (1, num_samples) batch, so collapse channels to one.
        samples = samples.mean(axis=1)
    return samples


def main():
    """Classify one WAV file with a zipped ``WaveClassifierPretrainedModel``.

    Reads ``--model_file``, ``--wav_file`` and ``--device`` via ``get_args``,
    unpacks and loads the model and its vocabulary, runs the classifier on at
    most the first 16000 samples of the audio, and prints the winning label
    and its softmax probability.
    """
    max_samples = 16000

    args = get_args()
    model_file = Path(args.model_file)
    device = torch.device(args.device)

    tgt_path = _extract_model_archive(model_file)
    vocab_path = tgt_path / "vocabulary"
    vocabulary = Vocabulary.from_files(vocab_path.as_posix())
    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=tgt_path.as_posix(),
    )
    model.to(device)
    model.eval()

    # infer
    # NOTE(review): sample_rate is read but never checked against the rate the
    # model was trained on — TODO confirm and validate/resample here.
    sample_rate, samples = wavfile.read(args.wav_file)
    samples = _to_mono_float(samples[:max_samples])
    waveform = torch.tensor(samples, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)  # batch of one: (1, num_samples)
    waveform = waveform.to(device)
    print(waveform.shape)

    with torch.no_grad():
        # Call the module itself (not .forward) so any registered hooks run.
        logits = model(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu().numpy()[0]
    prob = probs.cpu().numpy()[0][label_idx]
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    print(label_str)
    print(prob)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()