HoneyTian committed
Commit 7225f3a · 1 Parent(s): 9d169ba
examples/sound_classification_by_lstm/step_6_export_onnx_model.py CHANGED

@@ -120,9 +120,9 @@ def main():
         "logits", "new_h", "new_c"
     ]
     logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")
     return

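For reference, the prints being commented out here came from a smoke test of the exported model. Below is a minimal sketch of how such an exported LSTM classifier is typically exercised with onnxruntime; the model path, input-feed keys, and tensor shapes are illustrative assumptions, and only the output names ("logits", "new_h", "new_c") come from the diff itself.

# Minimal sketch of exercising the exported model with onnxruntime.
# The path, input names, and shapes below are hypothetical; only the
# output names ("logits", "new_h", "new_c") appear in the change above.
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession("lstm_classifier.onnx")  # hypothetical path

# Hypothetical shapes: batch=1, 100 frames of 80-dim features;
# 2-layer LSTM with hidden size 64 for the recurrent states.
inputs = np.zeros(shape=(1, 100, 80), dtype=np.float32)
h = np.zeros(shape=(2, 1, 64), dtype=np.float32)
c = np.zeros(shape=(2, 1, 64), dtype=np.float32)

input_feed = {"inputs": inputs, "h": h, "c": c}  # assumed input names
output_names = ["logits", "new_h", "new_c"]
logits, new_h, new_c = ort_session.run(output_names, input_feed)
print(f"logits: {logits.shape}")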
examples/sound_classification_by_lstm/step_8_test_onnx_model.py CHANGED

@@ -31,7 +31,8 @@ def get_args():
     )
     parser.add_argument(
         "--wav_file",
-        default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        # default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
+        default=(project_path / "data/examples/examples/zh-TW/voicemail/00a1d109-23c2-4b8b-a066-993ac2ae8260_zh-TW_1672210785598.wav").as_posix(),
         type=str
     )

@@ -107,10 +108,23 @@ def main():
         "logits", "new_h", "new_c"
     ]
     logits, new_h, new_c = ort_session.run(output_names, input_feed)
-    print(f"logits: {logits.shape}")
-    print(f"new_h: {new_h.shape}")
-    print(f"new_c: {new_c.shape}")
+    # print(f"logits: {logits.shape}")
+    # print(f"new_h: {new_h.shape}")
+    # print(f"new_c: {new_c.shape}")

+    logits = torch.tensor(logits, dtype=torch.float32)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    label_idx = torch.argmax(probs, dim=-1)
+
+    label_idx = label_idx.cpu()
+    probs = probs.cpu()
+
+    label_idx = label_idx.numpy()[0]
+    prob = probs.numpy()[0][label_idx]
+
+    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+    print(label_str)
+    print(prob)
     return

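Note that the post-processing added above converts the onnxruntime numpy output back into torch tensors just to apply softmax and argmax. A numpy-only equivalent is sketched below, assuming logits is a float array of shape (batch, num_labels); vocabulary.get_token_from_index is the same lookup used in the diff.

# Numpy-only equivalent of the torch post-processing added above;
# assumes logits has shape (batch, num_labels).
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - np.max(x, axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

probs = softmax(logits, axis=-1)
label_idx = int(np.argmax(probs[0]))
prob = float(probs[0][label_idx])
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
print(label_str)
print(prob)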
tabs/cls_tab.py CHANGED

@@ -1,11 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+import json
 from functools import lru_cache
+from os import times
 from pathlib import Path
 import platform
 import shutil
 import tempfile
+import time
 import zipfile
 from typing import Tuple

@@ -61,10 +64,12 @@ def when_click_cls_button(audio_t,
     inputs = torch.tensor(inputs, dtype=torch.float32)
     inputs = torch.unsqueeze(inputs, dim=0)

+    time_begin = time.time()
     with torch.no_grad():
         logits = model.forward(inputs)
         probs = torch.nn.functional.softmax(logits, dim=-1)
         label_idx = torch.argmax(probs, dim=-1)
+    time_cost = time.time() - time_begin

     label_idx = label_idx.cpu()
     probs = probs.cpu()

@@ -74,7 +79,13 @@ def when_click_cls_button(audio_t,

     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")

-    return label_str, round(prob, 4)
+    result = {
+        "label": label_str,
+        "prob": round(float(prob), 4),
+        "time_cost": round(time_cost, 4),
+    }
+    result = json.dumps(result, ensure_ascii=False, indent=4)
+    return result


 def get_cls_tab(examples_dir: str, trained_model_dir: str):

@@ -121,13 +132,12 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):

             cls_button = gr.Button("run", variant="primary")
         with gr.Column(scale=3):
-            cls_label = gr.Textbox(label="label")
-            cls_probability = gr.Number(label="probability")
+            cls_outputs = gr.Textbox(label="outputs", lines=1, max_lines=15)

     gr.Examples(
         cls_examples,
         inputs=[cls_audio, cls_model_name, cls_ground_true],
-        outputs=[cls_label, cls_probability],
+        outputs=[cls_outputs],
         fn=when_click_cls_button,
         examples_per_page=5,
     )

@@ -135,7 +145,7 @@ def get_cls_tab(examples_dir: str, trained_model_dir: str):
     cls_button.click(
         when_click_cls_button,
         inputs=[cls_audio, cls_model_name, cls_ground_true],
-        outputs=[cls_label, cls_probability],
+        outputs=[cls_outputs],
     )

     return locals()
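After this change the tab renders one JSON string in a single Textbox in place of the separate label and probability widgets. An illustrative payload returned by when_click_cls_button (all values hypothetical; "voicemail" is borrowed from the example path above) would look like:

{
    "label": "voicemail",
    "prob": 0.9876,
    "time_cost": 0.0123
}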