update
examples/sound_classification_by_lstm/step_6_export_onnx_model.py  CHANGED
@@ -120,9 +120,9 @@ def main():
         "logits", "new_h", "new_c"
     ]
     logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")
     return
 
 
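For context, the edited block is the post-export smoke test in step_6_export_onnx_model.py: the script opens the exported graph with onnxruntime and runs it once to confirm the three declared outputs come back, and the commit only silences the shape prints. A minimal sketch of that pattern follows; the model filename, input names, and tensor shapes are placeholder assumptions, not values taken from this repo.

# Hypothetical smoke test for an exported LSTM classifier.
# The filename "model.onnx", the input names, and the shapes below are assumptions for illustration.
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession("model.onnx")          # assumed filename

inputs = np.zeros(shape=(1, 198, 80), dtype=np.float32)   # (batch, time, feature) - assumed dims
h = np.zeros(shape=(2, 1, 64), dtype=np.float32)          # (num_layers, batch, hidden) - assumed dims
c = np.zeros(shape=(2, 1, 64), dtype=np.float32)

input_feed = {"inputs": inputs, "h": h, "c": c}            # input names are assumptions
output_names = ["logits", "new_h", "new_c"]                # output names as declared in the diff
logits, new_h, new_c = ort_session.run(output_names, input_feed)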
examples/sound_classification_by_lstm/step_8_test_onnx_model.py  CHANGED
@@ -31,7 +31,8 @@ def get_args():
     )
     parser.add_argument(
         "--wav_file",
-        default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        # default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        default=(project_path / "data/examples/examples/zh-TW/voicemail/00a1d109-23c2-4b8b-a066-993ac2ae8260_zh-TW_1672210785598.wav").as_posix(),
         type=str
     )
 
@@ -107,10 +108,23 @@ def main():
         "logits", "new_h", "new_c"
     ]
     logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")
 
+    logits = torch.tensor(logits, dtype=torch.float32)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    label_idx = torch.argmax(probs, dim=-1)
+
+    label_idx = label_idx.cpu()
+    probs = probs.cpu()
+
+    label_idx = label_idx.numpy()[0]
+    prob = probs.numpy()[0][label_idx]
+
+    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+    print(label_str)
+    print(prob)
     return
 
 
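The added lines in step_8_test_onnx_model.py turn the raw ONNX logits into a predicted label: they wrap the logits back in a torch tensor, apply softmax and argmax, and look the winning index up in the vocabulary. Since onnxruntime already returns numpy arrays, the same post-processing can also be written without the torch round trip; a small self-contained sketch under that assumption (the (1, num_labels) shape and the dummy values are illustrative):

# Sketch: equivalent post-processing in plain numpy, no torch round trip.
# `logits` here is a dummy stand-in for the (1, num_labels) array returned by ort_session.run.
import numpy as np

logits = np.array([[0.2, 1.7, -0.3]], dtype=np.float32)        # dummy ONNX output

exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))  # numerically stable softmax
probs = exp / np.sum(exp, axis=-1, keepdims=True)

label_idx = int(np.argmax(probs, axis=-1)[0])
prob = float(probs[0][label_idx])

# In the real script the index is mapped back to a string with
# vocabulary.get_token_from_index(label_idx, namespace="labels").
print(label_idx)
print(round(prob, 4))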
tabs/cls_tab.py  CHANGED
@@ -1,11 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+import json
 from functools import lru_cache
+from os import times
 from pathlib import Path
 import platform
 import shutil
 import tempfile
+import time
 import zipfile
 from typing import Tuple
 
@@ -61,10 +64,12 @@ def when_click_cls_button(audio_t,
     inputs = torch.tensor(inputs, dtype=torch.float32)
     inputs = torch.unsqueeze(inputs, dim=0)
 
+    time_begin = time.time()
     with torch.no_grad():
         logits = model.forward(inputs)
         probs = torch.nn.functional.softmax(logits, dim=-1)
         label_idx = torch.argmax(probs, dim=-1)
+    time_cost = time.time() - time_begin
 
     label_idx = label_idx.cpu()
     probs = probs.cpu()
@@ -74,7 +79,13 @@ def when_click_cls_button(audio_t,
 
     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
 
-
+    result = {
+        "label": label_str,
+        "prob": round(float(prob), 4),
+        "time_cost": round(time_cost, 4),
+    }
+    result = json.dumps(result, ensure_ascii=False, indent=4)
+    return result
 
 
 def get_cls_tab(examples_dir: str, trained_model_dir: str):
@@ -121,13 +132,12 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):
 
         cls_button = gr.Button("run", variant="primary")
         with gr.Column(scale=3):
-
-            cls_probability = gr.Number(label="probability")
+            cls_outputs = gr.Textbox(label="outputs", lines=1, max_lines=15)
 
         gr.Examples(
             cls_examples,
             inputs=[cls_audio, cls_model_name, cls_ground_true],
-            outputs=[
+            outputs=[cls_outputs],
             fn=when_click_cls_button,
             examples_per_page=5,
         )
@@ -135,7 +145,7 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):
         cls_button.click(
             when_click_cls_button,
             inputs=[cls_audio, cls_model_name, cls_ground_true],
-            outputs=[
+            outputs=[cls_outputs],
         )
 
     return locals()
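Taken together, the tabs/cls_tab.py changes make when_click_cls_button time the forward pass and return a single JSON string, and the tab replaces the old probability widget with one cls_outputs Textbox wired to both gr.Examples and cls_button.click. A stripped-down, self-contained sketch of that wiring is below; the component list and the stub classifier are illustrative assumptions, and only the return-one-JSON-string pattern mirrors the commit.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Minimal sketch of the "one JSON Textbox" output pattern; the classifier here is a stub.
import json
import time

import gradio as gr


def when_click_cls_button(audio_t, model_name: str, ground_true: str) -> str:
    time_begin = time.time()
    label_str, prob = "voicemail", 0.97          # stand-in for the real model forward pass
    time_cost = time.time() - time_begin

    result = {
        "label": label_str,
        "prob": round(float(prob), 4),
        "time_cost": round(time_cost, 4),
    }
    return json.dumps(result, ensure_ascii=False, indent=4)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            cls_audio = gr.Audio(label="audio")
            cls_model_name = gr.Textbox(label="model_name")
            cls_ground_true = gr.Textbox(label="ground_true")
            cls_button = gr.Button("run", variant="primary")
        with gr.Column(scale=3):
            cls_outputs = gr.Textbox(label="outputs", lines=1, max_lines=15)

    cls_button.click(
        when_click_cls_button,
        inputs=[cls_audio, cls_model_name, cls_ground_true],
        outputs=[cls_outputs],
    )

if __name__ == "__main__":
    demo.launch()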