HoneyTian's picture
update
69ad385
raw
history blame
2.54 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import shutil
import sys
import tempfile
import zipfile
# Put the directory two levels above this script on the import path before the
# project-local imports below run (presumably that is the repository root
# holding `project_settings` and `toolbox` -- confirm against the repo layout).
pwd = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(pwd, "../../"))
from scipy.io import wavfile
import torch
from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
def get_args():
    """Parse the command-line options of this inference script.

    Options:
        --model_file: path of the zipped trained model. Defaults to
            ``trained_models/vm_sound_classification3.zip`` under ``project_path``.
        --wav_file: path of the WAV file to classify.
        --device: torch device string, ``"cpu"`` by default.

    Returns:
        The ``argparse.Namespace`` holding ``model_file``, ``wav_file`` and
        ``device`` (all strings).
    """
    default_model_file = project_path / "trained_models/vm_sound_classification3.zip"
    default_wav_file = r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav"

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_file", default=default_model_file.as_posix(), type=str)
    arg_parser.add_argument("--wav_file", default=default_wav_file, type=str)
    arg_parser.add_argument("--device", default="cpu", type=str)
    return arg_parser.parse_args()
def main():
    """Classify one WAV file with a zipped pretrained model and print the result.

    Steps:
        1. Unpack ``--model_file`` into ``<system temp dir>/vm_sound_classification``;
           an existing directory of that name is deleted first.
        2. Load the vocabulary and the classifier from the unpacked sub-folder
           named after the archive's stem.
        3. Read ``--wav_file``, keep its first 16000 sample frames, divide by
           ``1 << 15`` and run the model on a batch of size one.
        4. Print the predicted label and its softmax probability.

    The extraction directory, the waveform shape, the label and the
    probability are written to stdout; nothing is returned.
    """
    args = get_args()

    model_file = Path(args.model_file)
    device = torch.device(args.device)

    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
        print(out_root)
        # Start from a clean directory so files from an earlier run cannot
        # be mixed with the freshly extracted model.
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # The archive is expected to contain a top-level folder named like the
    # zip file without its extension.
    tgt_path = out_root / model_file.stem

    vocab_path = tgt_path / "vocabulary"
    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=tgt_path.as_posix(),
    )
    model.to(device)
    model.eval()

    # infer
    # NOTE(review): the sample rate is read but not validated; presumably the
    # model expects one fixed rate -- confirm against the training setup.
    sample_rate, waveform = wavfile.read(args.wav_file)
    waveform = waveform[:16000]

    # Scale to roughly [-1, 1); assumes 16-bit PCM samples -- TODO confirm.
    waveform = waveform / (1 << 15)
    waveform = torch.tensor(waveform, dtype=torch.float32)
    waveform = torch.unsqueeze(waveform, dim=0)  # add the batch dimension
    waveform = waveform.to(device)

    print(waveform.shape)
    with torch.no_grad():
        # Call the module itself rather than `.forward()` so that any
        # registered forward hooks are executed.
        logits = model(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    label_idx = label_idx.cpu()
    probs = probs.cpu()

    # Single-item batch: take element 0 as a plain Python int for the lookup.
    label_idx = int(label_idx.numpy()[0])
    prob = probs.numpy()[0][label_idx]

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    print(label_str)
    print(prob)
# Run the inference demo only when this file is executed as a script.
if __name__ == "__main__":
    main()