File size: 2,535 Bytes
bfa885e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import shutil
import sys
import tempfile
import zipfile

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

from scipy.io import wavfile
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel


def get_args():
    """Parse the command-line options for single-file inference.

    Returns:
        argparse.Namespace exposing ``model_file``, ``wav_file`` and
        ``device`` (all strings).
    """
    default_model_file = (project_path / "trained_models/vm_sound_classification3.zip").as_posix()
    # NOTE(review): this default is a developer-local Windows path; pass
    # --wav_file explicitly on any other machine.
    default_wav_file = r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_file", default=default_model_file, type=str)
    parser.add_argument("--wav_file", default=default_wav_file, type=str)
    parser.add_argument("--device", default="cpu", type=str)
    return parser.parse_args()


def _extract_model_archive(model_file: Path) -> Path:
    """Extract the model zip into a fixed scratch directory and return that directory.

    The directory ``<tempdir>/vm_sound_classification`` is wiped and recreated
    on every call, so each run sees only the files from ``model_file``.
    """
    # Open the archive first so a missing or corrupt zip fails before anything is deleted.
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
        print(out_root)
        # NOTE(review): the path is fixed, so two concurrent runs would clobber each other's files.
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)
    return out_root


def _load_waveform(wav_file: str, max_samples: int = 16000) -> torch.Tensor:
    """Read a WAV file and return its first ``max_samples`` samples as a (1, T) float32 tensor.

    Integer PCM is scaled by its full-scale value according to sample width,
    so 16-bit input is divided by 2**15 exactly as before. Floating-point WAV
    data is passed through unscaled. Multi-channel audio is averaged down to
    a single channel.
    """
    # NOTE(review): the sample rate is not checked; the fixed window length
    # implicitly assumes a particular rate — confirm against the model's training setup.
    _sample_rate, samples = wavfile.read(wav_file)
    samples = samples[:max_samples]

    full_scale = 1 << (samples.dtype.itemsize * 8 - 1)
    if samples.dtype.kind == "i":
        # Signed PCM (int16, and int32 which scipy also uses for left-justified 24-bit data).
        samples = samples / full_scale
    elif samples.dtype.kind == "u":
        # Unsigned PCM (8-bit WAV) is offset-binary, centred on half of full range.
        samples = (samples.astype("float64") - full_scale) / full_scale
    # Floating-point samples are left as-is: they are already nominally in [-1, 1].

    if samples.ndim > 1:
        # scipy returns multi-channel audio as (num_samples, num_channels); downmix to mono.
        samples = samples.mean(axis=1)

    waveform = torch.tensor(samples, dtype=torch.float32)
    return torch.unsqueeze(waveform, dim=0)


def main():
    """Classify a single WAV file with a zipped CNN audio classifier.

    Parses CLI options, unpacks the model archive, loads the vocabulary and
    model, runs one forward pass on the first 16000 samples of the WAV file,
    and prints the extraction directory, the input shape, the top label and
    its softmax probability.
    """
    args = get_args()

    model_file = Path(args.model_file)
    device = torch.device(args.device)

    out_root = _extract_model_archive(model_file)

    # Assumes the archive holds a top-level directory named after the zip's stem,
    # which is handed to from_pretrained and contains a "vocabulary" sub-directory.
    tgt_path = out_root / model_file.stem
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    model = WaveClassifierPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=tgt_path.as_posix(),
    )
    model.to(device)
    model.eval()

    # infer
    waveform = _load_waveform(args.wav_file).to(device)
    print(waveform.shape)
    with torch.no_grad():
        logits = model.forward(waveform)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        label_idx = torch.argmax(probs, dim=-1)

    # The batch holds exactly one clip, so take row 0.
    label_idx = int(label_idx.cpu().numpy()[0])
    prob = probs.cpu().numpy()[0][label_idx]

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    print(label_str)
    print(prob)


# Run inference only when executed as a script, not when imported.
if __name__ == '__main__':
    main()