HoneyTian committed
Commit e90b328 · 1 Parent(s): fbd43a1
examples/fsmn_vad_by_webrtcvad/run.sh CHANGED
@@ -127,13 +127,11 @@ fi
 
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-   $verbose && echo "stage 4: test model"
+   $verbose && echo "stage 4: export model"
    cd "${work_dir}" || exit 1
-   python3 step_3_evaluation.py \
-   --valid_dataset "${valid_dataset}" \
+   python3 step_5_export_model.py \
    --model_dir "${file_dir}/best" \
-   --evaluation_audio_dir "${evaluation_audio_dir}" \
-   --limit "${limit}" \
 
 fi
 
@@ -145,7 +143,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   mkdir -p ${final_model_dir}
 
   cp "${file_dir}/best"/* "${final_model_dir}"
- cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
 
  cd "${final_model_dir}/.." || exit 1;
 
+   --output_dir "${file_dir}/best" \
examples/fsmn_vad_by_webrtcvad/step_5_export_model.py ADDED
@@ -0,0 +1,77 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import torch
+
+ from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadModel, FSMNVadPretrainedModel, FSMNVadModelExport
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     # parser.add_argument("--model_dir", default="file_dir/best", type=str)
+     # parser.add_argument("--output_dir", default="file_dir/best", type=str)
+
+     parser.add_argument(
+         "--model_dir",
+         default=r"D:\Users\tianx\HuggingSpaces\cc_vad\trained_models\fsmn-vad-by-webrtcvad-nx2-dns3\fsmn-vad-by-webrtcvad-nx2-dns3",
+         type=str
+     )
+     parser.add_argument(
+         "--output_dir",
+         default=r"D:\Users\tianx\HuggingSpaces\cc_vad\trained_models\fsmn-vad-by-webrtcvad-nx2-dns3\fsmn-vad-by-webrtcvad-nx2-dns3",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     output_dir = Path(args.output_dir)
+     output_file = output_dir / "model.onnx"
+
+     model = FSMNVadPretrainedModel.from_pretrained(args.model_dir)
+     model.eval()
+     config = model.config
+
+     basic_block_layers = config.fsmn_basic_block_layers
+     hidden_size = config.fsmn_basic_block_hidden_size
+     basic_block_lorder = config.fsmn_basic_block_lorder
+     basic_block_lstride = config.fsmn_basic_block_lstride
+
+     model_export = FSMNVadModelExport(model)
+
+     b = 1
+     inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
+     cache_list = [
+         torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+     ] * basic_block_layers
+     cache_list = torch.stack(cache_list, dim=0)
+
+     torch.onnx.export(model_export,
+                       args=(inputs, cache_list),
+                       f=output_file.as_posix(),
+                       input_names=["inputs", "cache_list"],
+                       output_names=["logits", "probs", "lsnr", "new_cache_list"],
+                       dynamic_axes={
+                           "inputs": {0: "batch_size", 2: "num_samples"},
+                           "cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                           "logits": {0: "batch_size"},
+                           "probs": {0: "batch_size"},
+                           "lsnr": {0: "batch_size"},
+                           "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                       })
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
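
Editorial note: to sanity-check the exported graph outside the training pipeline, a minimal onnxruntime probe can be run against the written model.onnx. This sketch is not part of the commit; the cache dimensions below are illustrative placeholders that must be replaced with the values from the model's FSMNVadConfig, and the model path assumes the run.sh layout.

import numpy as np
import onnxruntime as ort

# Placeholder values; the real ones come from FSMNVadConfig
# (fsmn_basic_block_layers, fsmn_basic_block_hidden_size, ...).
basic_block_layers = 4
hidden_size = 32
basic_block_lorder = 20
basic_block_lstride = 1

session = ort.InferenceSession("file_dir/best/model.onnx")

# Dummy feed: 16000 samples of silence plus zero-initialized FSMN caches,
# mirroring the dummy tensors used for the export above.
inputs = np.zeros(shape=(1, 1, 16000), dtype=np.float32)
cache_list = np.zeros(
    shape=(basic_block_layers, 1, hidden_size,
           (basic_block_lorder - 1) * basic_block_lstride, 1),
    dtype=np.float32,
)

logits, probs, lsnr, new_cache_list = session.run(
    ["logits", "probs", "lsnr", "new_cache_list"],
    {"inputs": inputs, "cache_list": cache_list},
)
print(probs.shape)  # expected [1, t, 1]: frame-level speech probabilities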
main.py CHANGED
@@ -4,21 +4,24 @@ import argparse
 from functools import lru_cache
 import json
 import logging
+ from pathlib import Path
 import platform
+ import shutil
 import tempfile
 import time
 from typing import Dict, Tuple
+ import zipfile
 
 import gradio as gr
- import librosa
- import librosa.display
+ from huggingface_hub import snapshot_download
 import matplotlib.pyplot as plt
 import numpy as np
 
 import log
 from project_settings import environment, project_path, log_directory, time_zone_info
 from toolbox.os.command import Command
- from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad import InferenceFSMNVad
+ from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
+ from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
 from toolbox.torchaudio.utils.visualization import process_speech_probs
 
 log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
@@ -28,6 +31,22 @@ logger = logging.getLogger("main")
 
 def get_args():
     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--examples_dir",
+         # default=(project_path / "data").as_posix(),
+         default=(project_path / "data/examples").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--models_repo_id",
+         default="qgyd2021/cc_vad",
+         type=str
+     )
+     parser.add_argument(
+         "--trained_model_dir",
+         default=(project_path / "trained_models").as_posix(),
+         type=str
+     )
     parser.add_argument(
         "--hf_token",
         default=environment.get("hf_token"),
@@ -49,7 +68,9 @@ def shell(cmd: str):
 
 def get_infer_cls_by_model_name(model_name: str):
     if model_name.__contains__("fsmn"):
-         infer_cls = InferenceFSMNVad
+         infer_cls = InferenceFSMNVadOnnx
+     elif model_name.__contains__("silero"):
+         infer_cls = InferenceSileroVad
     else:
         raise AssertionError
     return infer_cls
@@ -111,7 +132,8 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None, engine
 
     probs = vad_info["probs"]
    lsnr = vad_info["lsnr"]
-     lsnr = lsnr / np.max(np.abs(lsnr))
+     # lsnr = lsnr / np.max(np.abs(lsnr))
+     lsnr = lsnr / 30
 
     frame_step = infer_engine.config.hop_size
     probs = process_speech_probs(audio, probs, frame_step)
@@ -128,6 +150,18 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None, engine
 def main():
     args = get_args()
 
+     examples_dir = Path(args.examples_dir)
+     trained_model_dir = Path(args.trained_model_dir)
+ 
+     # download models
+     if not trained_model_dir.exists():
+         trained_model_dir.mkdir(parents=True, exist_ok=True)
+         _ = snapshot_download(
+             repo_id=args.models_repo_id,
+             local_dir=trained_model_dir.as_posix(),
+             token=args.hf_token,
+         )
+ 
     # engines
     global vad_engines
     vad_engines = {
@@ -152,6 +186,25 @@ def main():
     # choices
     vad_engine_choices = list(vad_engines.keys())
 
+     # examples
+     if not examples_dir.exists():
+         example_zip_file = trained_model_dir / "examples.zip"
+         with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
+             out_root = examples_dir
+             if out_root.exists():
+                 shutil.rmtree(out_root.as_posix())
+             out_root.mkdir(parents=True, exist_ok=True)
+             f_zip.extractall(path=out_root)
+ 
+     # examples
+     examples = list()
+     for filename in examples_dir.glob("**/*.wav"):
+         examples.append([
+             filename.as_posix(),
+             None,
+             vad_engine_choices[0],
+         ])
+ 
     # ui
     with gr.Blocks() as blocks:
         gr.Markdown(value="vad.")
@@ -175,7 +228,15 @@ def main():
         vad_button.click(
             when_click_vad_button,
             inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
-             outputs=[vad_vad_image, vad_lsnr_image, vad_message]
+             outputs=[vad_vad_image, vad_lsnr_image, vad_message],
+         )
+         gr.Examples(
+             examples=examples,
+             inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
+             outputs=[vad_vad_image, vad_lsnr_image, vad_message],
+             fn=when_click_vad_button,
+             # cache_examples=True,
+             # cache_mode="lazy",
         )
         with gr.TabItem("shell"):
             shell_text = gr.Textbox(label="cmd")
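
Editorial note: the engine classes that main.py now maps model names onto can also be used standalone. A minimal sketch of the ONNX path follows; the wav path is an assumed placeholder, and the zip name matches the default used in inference_fsmn_vad_onnx.py.

from scipy.io import wavfile

from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx

# In main.py the model zips are fetched into trained_models/ by snapshot_download.
infer_engine = InferenceFSMNVadOnnx("trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip")

sample_rate, signal = wavfile.read("data/examples/some_example.wav")  # placeholder wav
signal = signal / (1 << 15)  # int16 -> [-1, 1]

vad_info = infer_engine.infer(signal)
probs = vad_info["probs"]  # frame-level speech probabilities
lsnr = vad_info["lsnr"]    # frame-level local SNR estimate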
requirements.txt CHANGED
@@ -12,3 +12,5 @@ overrides==7.7.0
 webrtcvad==2.0.10
 matplotlib==3.10.3
 google-genai
+ onnx==1.18.0
+ onnxruntime==1.22.1
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED
@@ -183,6 +183,29 @@ class BasicBlock(nn.Module):
         return x4, new_cache
 
 
+ class BasicBlockExport(nn.Module):
+     def __init__(self, model: BasicBlock):
+         super(BasicBlockExport, self).__init__()
+         self.linear = model.linear
+         self.fsmn_block = model.fsmn_block
+         self.affine = model.affine
+         self.relu = model.relu
+ 
+     def forward(self, inputs: torch.Tensor, cache: torch.Tensor):
+         # inputs shape: [b, t, f]
+         x1 = self.linear.forward(inputs)
+         # x1 shape: [b, t, f']
+ 
+         x2, new_cache = self.fsmn_block.forward(x1, cache=cache)
+         # x2 shape: [b, t, f']
+ 
+         x3 = self.affine.forward(x2)
+         # x3 shape: [b, t, f]
+ 
+         x4 = self.relu(x3)
+         return x4, new_cache
+ 
+ 
 class FSMN(nn.Module):
     def __init__(
         self,
@@ -251,28 +274,193 @@ class FSMN(nn.Module):
         return outputs, new_cache_list
 
 
- def main():
+ class FSMNExport(nn.Module):
+     def __init__(self, model: FSMN):
+         super(FSMNExport, self).__init__()
+         self.in_linear1 = model.in_linear1
+         self.in_linear2 = model.in_linear2
+         self.relu = model.relu
+ 
+         self.out_linear1 = model.out_linear1
+         self.out_linear2 = model.out_linear2
+ 
+         self.fsmn_basic_block_list = nn.ModuleList(modules=[])
+         for i, d in enumerate(model.fsmn_basic_block_list):
+             if isinstance(d, BasicBlock):
+                 self.fsmn_basic_block_list.append(BasicBlockExport(d))
+ 
+     def forward(self,
+                 inputs: torch.Tensor,
+                 cache_list: torch.Tensor,
+                 ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # cache_list shape: [basic_block_layers, b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1]
+ 
+         # inputs shape: [b, t, f]
+         x = self.in_linear1.forward(inputs)
+         # x shape: [b, t, input_affine_dim]
+         x = self.in_linear2.forward(x)
+         # x shape: [b, t, f]
+ 
+         x = self.relu(x)
+ 
+         new_cache_list = list()
+         for idx, fsmn_basic_block in enumerate(self.fsmn_basic_block_list):
+             cache = cache_list[idx]
+             x, new_cache = fsmn_basic_block.forward(x, cache)
+             new_cache_list.append(new_cache)
+         new_cache_list = torch.stack(new_cache_list, dim=0)
+ 
+         # x shape: [b, t, f]
+         x = self.out_linear1.forward(x)
+         outputs = self.out_linear2.forward(x)
+         # outputs shape: [b, t, f]
+ 
+         return outputs, new_cache_list
+ 
+ 
+ def main1():
+     import onnx
+     import onnxruntime as ort
+ 
+     input_size = 32
+     input_affine_size = 16
+     hidden_size = 16
+     basic_block_layers = 3
+     basic_block_hidden_size = 16
+     basic_block_lorder = 3
+     basic_block_rorder = 0
+     basic_block_lstride = 1
+     basic_block_rstride = 1
+     output_affine_size = 16
+     output_size = 32
+ 
+     basic_block = BasicBlock(
+         input_size=hidden_size,
+         hidden_size=basic_block_hidden_size,
+         lorder=basic_block_lorder,
+         rorder=basic_block_rorder,
+         lstride=basic_block_lstride,
+         rstride=basic_block_rstride,
+     )
+ 
+     b = 1
+     t = 198
+     f = hidden_size
+     inputs = torch.randn(size=(b, t, f), dtype=torch.float32)
+ 
+     result, _ = basic_block.forward(inputs)
+     print(f"result.shape: {result.shape}")
+ 
+     basic_block_export = BasicBlockExport(model=basic_block)
+ 
+     cache = torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1))
+     result, new_cache = basic_block_export.forward(inputs, cache)
+     print(f"result.shape: {result.shape}")
+     print(f"new_cache.shape: {new_cache.shape}")
+ 
+     torch.onnx.export(basic_block_export,
+                       args=(inputs, cache),
+                       f="basic_block.onnx",
+                       input_names=["inputs", "cache"],
+                       output_names=["outputs", "new_cache"],
+                       dynamic_axes={
+                           "inputs": {0: "batch_size"},
+                           "cache": {0: "batch_size"},
+                           "outputs": {0: "batch_size"},
+                           "new_cache": {0: "batch_size"},
+                       })
+ 
+     ort_session = ort.InferenceSession("basic_block.onnx")
+     input_feed = {
+         "inputs": inputs.numpy(),
+         "cache": cache.numpy(),
+     }
+     output_names = [
+         "outputs",
+         "new_cache"
+     ]
+     outputs = ort_session.run(output_names, input_feed)
+     print(outputs)
+     print(len(outputs))
+     return
+ 
+ 
+ def main2():
+     import onnx
+     import onnxruntime as ort
+ 
+     input_size = 32
+     input_affine_size = 16
+     hidden_size = 16
+     basic_block_layers = 3
+     basic_block_hidden_size = 16
+     basic_block_lorder = 3
+     basic_block_rorder = 0
+     basic_block_lstride = 1
+     basic_block_rstride = 1
+     output_affine_size = 16
+     output_size = 32
+ 
     fsmn = FSMN(
-         input_size=32,
-         input_affine_size=16,
-         hidden_size=16,
-         basic_block_layers=3,
-         basic_block_hidden_size=16,
-         basic_block_lorder=3,
-         basic_block_rorder=0,
-         basic_block_lstride=1,
-         basic_block_rstride=1,
-         output_affine_size=16,
-         output_size=32,
+         input_size=input_size,
+         input_affine_size=input_affine_size,
+         hidden_size=hidden_size,
+         basic_block_layers=basic_block_layers,
+         basic_block_hidden_size=basic_block_hidden_size,
+         basic_block_lorder=basic_block_lorder,
+         basic_block_rorder=basic_block_rorder,
+         basic_block_lstride=basic_block_lstride,
+         basic_block_rstride=basic_block_rstride,
+         output_affine_size=output_affine_size,
+         output_size=output_size,
     )
 
-     inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
+     b = 1
+     t = 198
+     f = input_size
+     inputs = torch.randn(size=(b, t, f), dtype=torch.float32)
 
     result, _ = fsmn.forward(inputs)
-     print(result.shape)
- 
+     print(f"result.shape: {result.shape}")
+ 
+     fsmn_export = FSMNExport(model=fsmn)
+ 
+     cache_list = [
+         torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+         torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+         torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+     ]
+     cache_list = torch.stack(cache_list, dim=0)
+     result, new_cache_list = fsmn_export.forward(inputs, cache_list)
+     print(f"result.shape: {result.shape}")
+     print(f"new_cache_list.shape: {new_cache_list.shape}")
+ 
+     torch.onnx.export(fsmn_export,
+                       args=(inputs, cache_list),
+                       f="fsmn.onnx",
+                       input_names=["inputs", "cache_list"],
+                       output_names=["outputs", "new_cache_list"],
+                       dynamic_axes={
+                           "inputs": {0: "batch_size"},
+                           "cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                           "outputs": {0: "batch_size"},
+                           "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                       })
+ 
+     ort_session = ort.InferenceSession("fsmn.onnx")
+     input_feed = {
+         "inputs": inputs.numpy(),
+         "cache_list": cache_list.numpy(),
+     }
+     output_names = [
+         "outputs",
+         "new_cache_list"
+     ]
+     outputs, new_cache_list = ort_session.run(output_names, input_feed)
+     print(f"outputs.shape: {outputs.shape}")
+     print(f"new_cache_list.shape: {new_cache_list.shape}")
     return
 
 
 if __name__ == "__main__":
-     main()
+     main2()
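
Editorial note: the Export wrappers above only repackage the existing layers so that the per-layer caches travel as one stacked tensor, which is simpler to declare as a single ONNX input than a Python list of tensors. A rough parity check follows; this is a sketch that assumes FSMN.forward starts from zero caches when none are passed, with toy dimensions borrowed from main2.

import torch

from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN, FSMNExport

fsmn = FSMN(
    input_size=32, input_affine_size=16, hidden_size=16,
    basic_block_layers=3, basic_block_hidden_size=16,
    basic_block_lorder=3, basic_block_rorder=0,
    basic_block_lstride=1, basic_block_rstride=1,
    output_affine_size=16, output_size=32,
)
fsmn_export = FSMNExport(model=fsmn)

inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
# [basic_block_layers, b, hidden_size, (lorder - 1) * lstride, 1]
cache_list = torch.zeros(size=(3, 1, 16, 2, 1))

with torch.no_grad():
    out_a, _ = fsmn.forward(inputs)
    out_b, _ = fsmn_export.forward(inputs, cache_list)

# Expected to be (near) identical under the zero-cache assumption.
print(torch.allclose(out_a, out_b, atol=1e-6))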
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py CHANGED
@@ -18,7 +18,7 @@ torch.set_num_threads(1)
 
 from project_settings import project_path
 from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
- from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel, MODEL_FILE
+ from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel, MODEL_FILE, FSMNVadModelExport
 from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
 
 
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad_onnx.py ADDED
@@ -0,0 +1,168 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import logging
+ from pathlib import Path
+ import shutil
+ import tempfile, time
+ from typing import List
+ import zipfile
+
+ from scipy.io import wavfile
+ import numpy as np
+ import torch
+ import onnxruntime as ort
+
+ torch.set_num_threads(1)
+
+ from project_settings import project_path
+ from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
+ from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
+
+
+ logger = logging.getLogger("toolbox")
+
+
+ class InferenceFSMNVadOnnx(object):
+     def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+         self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+         self.device = torch.device(device)
+
+         logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+         config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
+         logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+         self.config = config
+         self.ort_session = ort_session
+
+     def load_models(self, model_path: str):
+         model_path = Path(model_path)
+         if model_path.name.endswith(".zip"):
+             with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                 out_root = Path(tempfile.gettempdir()) / "cc_vad"
+                 out_root.mkdir(parents=True, exist_ok=True)
+                 f_zip.extractall(path=out_root)
+             model_path = out_root / model_path.stem
+
+         config = FSMNVadConfig.from_pretrained(
+             pretrained_model_name_or_path=model_path.as_posix(),
+         )
+         ort_session = ort.InferenceSession(
+             path_or_bytes=(model_path / "model.onnx").as_posix()
+         )
+
+         shutil.rmtree(model_path)
+         return config, ort_session
+
+     def infer(self, signal: np.ndarray) -> np.ndarray:
+         # signal shape: [num_samples,], value between -1 and 1.
+
+         inputs = torch.tensor(signal, dtype=torch.float32)
+         inputs = torch.unsqueeze(inputs, dim=0)
+         inputs = torch.unsqueeze(inputs, dim=0)
+         # inputs shape: [1, 1, num_samples]
+
+         b = 1
+         cache_list = [
+             torch.zeros(size=(
+                 b, self.config.fsmn_basic_block_hidden_size,
+                 (self.config.fsmn_basic_block_lorder - 1) * self.config.fsmn_basic_block_lstride,
+                 1
+             )),
+         ] * self.config.fsmn_basic_block_layers
+         cache_list = torch.stack(cache_list, dim=0)
+
+         input_feed = {
+             "inputs": inputs.numpy(),
+             "cache_list": cache_list.numpy(),
+         }
+         output_names = [
+             "logits", "probs", "lsnr", "new_cache_list"
+         ]
+         logits, probs, lsnr, new_cache_list = self.ort_session.run(output_names, input_feed)
+         # probs shape: [b, t, 1]
+         probs = np.squeeze(probs, axis=-1)
+         # probs shape: [b, t]
+         probs = probs[0]
+
+         # lsnr shape: [b, t, 1]
+         lsnr = np.squeeze(lsnr, axis=-1)
+         # lsnr shape: [b, t]
+         lsnr = lsnr[0]
+
+         result = {
+             "probs": probs,
+             "lsnr": lsnr,
+         }
+         return result
+
+     def post_process(self, probs: List[float]):
+         return
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--wav_file",
+         # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+         # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+         # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
+         type=str,
+     )
+     args = parser.parse_args()
+     return args
+
+
+ SAMPLE_RATE = 8000
+
+
+ def main():
+     args = get_args()
+
+     sample_rate, signal = wavfile.read(args.wav_file)
+     if SAMPLE_RATE != sample_rate:
+         raise AssertionError
+     signal = signal / (1 << 15)
+
+     infer = InferenceFSMNVadOnnx(
+         # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
+         pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
+     )
+     frame_step = infer.config.hop_size
+
+     # infer() returns a dict with "probs" and "lsnr"; keep the frame-level speech probabilities.
+     vad_info = infer.infer(signal)
+     speech_probs = vad_info["probs"].tolist()
+
+     speech_probs = process_speech_probs(
+         signal=signal,
+         speech_probs=speech_probs,
+         frame_step=frame_step,
+     )
+
+     # plot
+     make_visualization(signal, speech_probs, SAMPLE_RATE)
+     return
+
+
+ if __name__ == "__main__":
+     main()
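
Editorial note: post_process is left as a stub in this file. A hypothetical frame-threshold segmentation, purely illustrative (the helper name and the 0.5 threshold are not from the commit), could turn the probabilities into sample-level segments using frame_step = config.hop_size:

from typing import List, Tuple


def naive_segments(probs: List[float], frame_step: int, threshold: float = 0.5) -> List[Tuple[int, int]]:
    """Group consecutive frames whose speech probability exceeds the threshold
    into (start_sample, end_sample) segments."""
    segments = []
    start = None
    for idx, p in enumerate(probs):
        if p > threshold and start is None:
            start = idx * frame_step
        elif p <= threshold and start is not None:
            segments.append((start, idx * frame_step))
            start = None
    if start is not None:
        segments.append((start, len(probs) * frame_step))
    return segments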
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py CHANGED
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
 from toolbox.torchaudio.modules.conv_stft import ConvSTFT
- from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
+ from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN, FSMNExport
 from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
@@ -243,7 +243,45 @@ class FSMNVadPretrainedModel(FSMNVadModel):
         return save_directory
 
 
- def main():
+ class FSMNVadModelExport(nn.Module):
+     def __init__(self, model: FSMNVadModel):
+         super(FSMNVadModelExport, self).__init__()
+         self.stft = model.stft
+         self.fsmn_encoder = FSMNExport(model.fsmn_encoder)
+ 
+         # lsnr
+         self.lsnr_scale = model.lsnr_scale
+         self.lsnr_offset = model.lsnr_offset
+ 
+     def forward(self,
+                 signal: torch.Tensor,
+                 cache_list: torch.Tensor,
+                 ):
+         # signal shape [b, 1, num_samples]
+ 
+         mags = self.stft.forward(signal)
+         # mags shape: [b, f, t]
+ 
+         x = torch.transpose(mags, dim0=1, dim1=2)
+         # x shape: [b, t, f]
+ 
+         logits, new_cache_list = self.fsmn_encoder.forward(x, cache_list)
+         # logits shape: [b, t, 2]
+ 
+         splits = torch.split(logits, split_size_or_sections=[1, 1], dim=-1)
+         vad_logits = splits[0]
+         snr_logits = splits[1]
+         # shape: [b, t, 1]
+         vad_probs = F.sigmoid(vad_logits)
+         # vad_probs shape: [b, t, 1]
+ 
+         lsnr = F.sigmoid(snr_logits) * self.lsnr_scale + self.lsnr_offset
+         # lsnr shape: [b, t, 1]
+ 
+         return vad_logits, vad_probs, lsnr, new_cache_list
+ 
+ 
+ def main1():
     config = FSMNVadConfig()
     model = FSMNVadPretrainedModel(config=config)
 
@@ -253,9 +291,62 @@ def main():
     print(f"logits.shape: {logits.shape}")
     print(f"probs.shape: {probs.shape}")
     print(f"lsnr.shape: {lsnr.shape}")
+     return
+ 
 
+ def main2():
+     import onnx
+     import onnxruntime as ort
+ 
+     config = FSMNVadConfig()
+     model = FSMNVadPretrainedModel(config=config)
+ 
+     basic_block_layers = config.fsmn_basic_block_layers
+     hidden_size = config.fsmn_basic_block_hidden_size
+     basic_block_lorder = config.fsmn_basic_block_lorder
+     basic_block_lstride = config.fsmn_basic_block_lstride
+ 
+     model_export = FSMNVadModelExport(model)
+ 
+     b = 1
+     inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
+     cache_list = [
+         torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+     ] * basic_block_layers
+     cache_list = torch.stack(cache_list, dim=0)
+ 
+     logits, probs, lsnr, new_cache_list = model_export.forward(inputs, cache_list)
+     print(f"logits.shape: {logits.shape}")
+     print(f"new_cache_list.shape: {new_cache_list.shape}")
+ 
+     torch.onnx.export(model_export,
+                       args=(inputs, cache_list),
+                       f="fsmn_vad.onnx",
+                       input_names=["inputs", "cache_list"],
+                       output_names=["logits", "probs", "lsnr", "new_cache_list"],
+                       dynamic_axes={
+                           "inputs": {0: "batch_size", 2: "num_samples"},
+                           "cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                           "logits": {0: "batch_size"},
+                           "probs": {0: "batch_size"},
+                           "lsnr": {0: "batch_size"},
+                           "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                       })
+ 
+     ort_session = ort.InferenceSession("fsmn_vad.onnx")
+     input_feed = {
+         "inputs": inputs.numpy(),
+         "cache_list": cache_list.numpy(),
+     }
+     output_names = [
+         "logits", "probs", "lsnr", "new_cache_list"
+     ]
+     logits, probs, lsnr, new_cache_list = ort_session.run(output_names, input_feed)
+     print(f"logits.shape: {logits.shape}")
+     print(f"new_cache_list.shape: {new_cache_list.shape}")
     return
 
 
 if __name__ == "__main__":
-     main()
+     main2()