update
- examples/fsmn_vad_by_webrtcvad/run.sh +3 -6
- examples/fsmn_vad_by_webrtcvad/step_5_export_model.py +77 -0
- main.py +67 -6
- requirements.txt +2 -0
- toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py +204 -16
- toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py +1 -1
- toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad_onnx.py +168 -0
- toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py +94 -3
examples/fsmn_vad_by_webrtcvad/run.sh
CHANGED
@@ -127,13 +127,11 @@ fi
 
 
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: …
+  $verbose && echo "stage 4: export model"
   cd "${work_dir}" || exit 1
-  python3 …
-    --valid_dataset "${valid_dataset}" \
+  python3 step_5_export_model.py \
     --model_dir "${file_dir}/best" \
-    --… \
-    --limit "${limit}" \
+    --output_dir "${file_dir}/best" \
 
 fi
 
@@ -145,7 +143,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
   mkdir -p ${final_model_dir}
 
   cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
 
   cd "${final_model_dir}/.." || exit 1;
 
examples/fsmn_vad_by_webrtcvad/step_5_export_model.py
ADDED
@@ -0,0 +1,77 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import torch

from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadModel, FSMNVadPretrainedModel, FSMNVadModelExport


def get_args():
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_dir", default="file_dir/best", type=str)
    # parser.add_argument("--output_dir", default="file_dir/best", type=str)

    parser.add_argument(
        "--model_dir",
        default=r"D:\Users\tianx\HuggingSpaces\cc_vad\trained_models\fsmn-vad-by-webrtcvad-nx2-dns3\fsmn-vad-by-webrtcvad-nx2-dns3",
        type=str
    )
    parser.add_argument(
        "--output_dir",
        default=r"D:\Users\tianx\HuggingSpaces\cc_vad\trained_models\fsmn-vad-by-webrtcvad-nx2-dns3\fsmn-vad-by-webrtcvad-nx2-dns3",
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    output_dir = Path(args.output_dir)
    output_file = output_dir / "model.onnx"

    model = FSMNVadPretrainedModel.from_pretrained(args.model_dir)
    model.eval()
    config = model.config

    basic_block_layers = config.fsmn_basic_block_layers
    hidden_size = config.fsmn_basic_block_hidden_size
    basic_block_lorder = config.fsmn_basic_block_lorder
    basic_block_lstride = config.fsmn_basic_block_lstride

    model_export = FSMNVadModelExport(model)

    b = 1
    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
    cache_list = [
        torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
    ] * basic_block_layers
    cache_list = torch.stack(cache_list, dim=0)

    torch.onnx.export(model_export,
                      args=(inputs, cache_list),
                      f=output_file.as_posix(),
                      input_names=["inputs", "cache_list"],
                      output_names=["logits", "probs", "lsnr", "new_cache_list"],
                      dynamic_axes={
                          "inputs": {0: "batch_size", 2: "num_samples"},
                          "cache_list": {0: "basic_block_layers", 1: "batch_size"},
                          "logits": {0: "batch_size"},
                          "probs": {0: "batch_size"},
                          "lsnr": {0: "batch_size"},
                          "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
                      })

    return


if __name__ == "__main__":
    main()
main.py
CHANGED
@@ -4,21 +4,24 @@ import argparse
 from functools import lru_cache
 import json
 import logging
+from pathlib import Path
 import platform
+import shutil
 import tempfile
 import time
 from typing import Dict, Tuple
+import zipfile
 
 import gradio as gr
-import …
-import librosa.display
+from huggingface_hub import snapshot_download
 import matplotlib.pyplot as plt
 import numpy as np
 
 import log
 from project_settings import environment, project_path, log_directory, time_zone_info
 from toolbox.os.command import Command
-from toolbox.torchaudio.models.vad.fsmn_vad.…
+from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
+from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
 from toolbox.torchaudio.utils.visualization import process_speech_probs
 
 log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
 
@@ -28,6 +31,22 @@ logger = logging.getLogger("main")
 
 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--examples_dir",
+        # default=(project_path / "data").as_posix(),
+        default=(project_path / "data/examples").as_posix(),
+        type=str
+    )
+    parser.add_argument(
+        "--models_repo_id",
+        default="qgyd2021/cc_vad",
+        type=str
+    )
+    parser.add_argument(
+        "--trained_model_dir",
+        default=(project_path / "trained_models").as_posix(),
+        type=str
+    )
     parser.add_argument(
         "--hf_token",
         default=environment.get("hf_token"),
 
@@ -49,7 +68,9 @@ def shell(cmd: str):
 
 def get_infer_cls_by_model_name(model_name: str):
     if model_name.__contains__("fsmn"):
-        infer_cls = …
+        infer_cls = InferenceFSMNVadOnnx
+    elif model_name.__contains__("silero"):
+        infer_cls = InferenceSileroVad
     else:
         raise AssertionError
     return infer_cls
 
@@ -111,7 +132,8 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None, engine
 
     probs = vad_info["probs"]
     lsnr = vad_info["lsnr"]
-    lsnr = lsnr / np.max(np.abs(lsnr))
+    # lsnr = lsnr / np.max(np.abs(lsnr))
+    lsnr = lsnr / 30
 
     frame_step = infer_engine.config.hop_size
     probs = process_speech_probs(audio, probs, frame_step)
 
@@ -128,6 +150,18 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None, engine
 def main():
     args = get_args()
 
+    examples_dir = Path(args.examples_dir)
+    trained_model_dir = Path(args.trained_model_dir)
+
+    # download models
+    if not trained_model_dir.exists():
+        trained_model_dir.mkdir(parents=True, exist_ok=True)
+        _ = snapshot_download(
+            repo_id=args.models_repo_id,
+            local_dir=trained_model_dir.as_posix(),
+            token=args.hf_token,
+        )
+
     # engines
     global vad_engines
     vad_engines = {
 
@@ -152,6 +186,25 @@ def main():
     # choices
     vad_engine_choices = list(vad_engines.keys())
 
+    # examples
+    if not examples_dir.exists():
+        example_zip_file = trained_model_dir / "examples.zip"
+        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
+            out_root = examples_dir
+            if out_root.exists():
+                shutil.rmtree(out_root.as_posix())
+            out_root.mkdir(parents=True, exist_ok=True)
+            f_zip.extractall(path=out_root)
+
+    # examples
+    examples = list()
+    for filename in examples_dir.glob("**/*.wav"):
+        examples.append([
+            filename.as_posix(),
+            None,
+            vad_engine_choices[0],
+        ])
+
     # ui
     with gr.Blocks() as blocks:
         gr.Markdown(value="vad.")
 
@@ -175,7 +228,15 @@ def main():
         vad_button.click(
             when_click_vad_button,
             inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
-            outputs=[vad_vad_image, vad_lsnr_image, vad_message]
+            outputs=[vad_vad_image, vad_lsnr_image, vad_message],
+        )
+        gr.Examples(
+            examples=examples,
+            inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
+            outputs=[vad_vad_image, vad_lsnr_image, vad_message],
+            fn=when_click_vad_button,
+            # cache_examples=True,
+            # cache_mode="lazy",
         )
     with gr.TabItem("shell"):
         shell_text = gr.Textbox(label="cmd")
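Not part of the commit: with the dispatch above, any engine name containing "fsmn" now resolves to the ONNX runner. A minimal sketch of using it the same way when_click_vad_button does, with hypothetical local paths:

from scipy.io import wavfile

from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx

# Hypothetical files for illustration; an 8 kHz, 16-bit PCM wav is assumed.
sample_rate, signal = wavfile.read("data/examples/example.wav")
signal = signal / (1 << 15)  # int16 PCM -> float in [-1, 1]

engine = InferenceFSMNVadOnnx(
    pretrained_model_path_or_zip_file="trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip",
)
vad_info = engine.infer(signal)
probs, lsnr = vad_info["probs"], vad_info["lsnr"]
lsnr = lsnr / 30  # same display scaling as when_click_vad_button above
print(probs.shape, lsnr.shape)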
requirements.txt
CHANGED
@@ -12,3 +12,5 @@ overrides==7.7.0
 webrtcvad==2.0.10
 matplotlib==3.10.3
 google-genai
+onnx==1.18.0
+onnxruntime==1.22.1
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py
CHANGED
@@ -183,6 +183,29 @@ class BasicBlock(nn.Module):
         return x4, new_cache
 
 
+class BasicBlockExport(nn.Module):
+    def __init__(self, model: BasicBlock):
+        super(BasicBlockExport, self).__init__()
+        self.linear = model.linear
+        self.fsmn_block = model.fsmn_block
+        self.affine = model.affine
+        self.relu = model.relu
+
+    def forward(self, inputs: torch.Tensor, cache: torch.Tensor):
+        # inputs shape: [b, t, f]
+        x1 = self.linear.forward(inputs)
+        # x1 shape: [b, t, f']
+
+        x2, new_cache = self.fsmn_block.forward(x1, cache=cache)
+        # x2 shape: [b, t, f']
+
+        x3 = self.affine.forward(x2)
+        # x3 shape: [b, t, f]
+
+        x4 = self.relu(x3)
+        return x4, new_cache
+
+
 class FSMN(nn.Module):
     def __init__(
             self,
 
@@ -251,28 +274,193 @@ class FSMN(nn.Module):
         return outputs, new_cache_list
 
 
-…
+class FSMNExport(nn.Module):
+    def __init__(self, model: FSMN):
+        super(FSMNExport, self).__init__()
+        self.in_linear1 = model.in_linear1
+        self.in_linear2 = model.in_linear2
+        self.relu = model.relu
+
+        self.out_linear1 = model.out_linear1
+        self.out_linear2 = model.out_linear2
+
+        self.fsmn_basic_block_list = nn.ModuleList(modules=[])
+        for i, d in enumerate(model.fsmn_basic_block_list):
+            if isinstance(d, BasicBlock):
+                self.fsmn_basic_block_list.append(BasicBlockExport(d))
+
+    def forward(self,
+                inputs: torch.Tensor,
+                cache_list: torch.Tensor,
+                ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # cache_list shape: [basic_block_layers, b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1]
+
+        # inputs shape: [b, t, f]
+        x = self.in_linear1.forward(inputs)
+        # x shape: [b, t, input_affine_dim]
+        x = self.in_linear2.forward(x)
+        # x shape: [b, t, f]
+
+        x = self.relu(x)
+
+        new_cache_list = list()
+        for idx, fsmn_basic_block in enumerate(self.fsmn_basic_block_list):
+            cache = cache_list[idx]
+            x, new_cache = fsmn_basic_block.forward(x, cache)
+            new_cache_list.append(new_cache)
+        new_cache_list = torch.stack(new_cache_list, dim=0)
+
+        # x shape: [b, t, f]
+        x = self.out_linear1.forward(x)
+        outputs = self.out_linear2.forward(x)
+        # outputs shape: [b, t, f]
+
+        return outputs, new_cache_list
+
+
+def main1():
+    import onnx
+    import onnxruntime as ort
+
+    input_size = 32
+    input_affine_size = 16
+    hidden_size = 16
+    basic_block_layers = 3
+    basic_block_hidden_size = 16
+    basic_block_lorder = 3
+    basic_block_rorder = 0
+    basic_block_lstride = 1
+    basic_block_rstride = 1
+    output_affine_size = 16
+    output_size = 32
+
+    basic_block = BasicBlock(
+        input_size=hidden_size,
+        hidden_size=basic_block_hidden_size,
+        lorder=basic_block_lorder,
+        rorder=basic_block_rorder,
+        lstride=basic_block_lstride,
+        rstride=basic_block_rstride,
+    )
+
+    b = 1
+    t = 198
+    f = hidden_size
+    inputs = torch.randn(size=(b, t, f), dtype=torch.float32)
+
+    result, _ = basic_block.forward(inputs)
+    print(f"result.shape: {result.shape}")
+
+    basic_block_export = BasicBlockExport(model=basic_block)
+
+    cache = torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1))
+    result, new_cache = basic_block_export.forward(inputs, cache)
+    print(f"result.shape: {result.shape}")
+    print(f"new_cache.shape: {new_cache.shape}")
+
+    torch.onnx.export(basic_block_export,
+                      args=(inputs, cache),
+                      f="basic_block.onnx",
+                      input_names=["inputs", "cache"],
+                      output_names=["outputs", "new_cache"],
+                      dynamic_axes={
+                          "inputs": {0: "batch_size"},
+                          "cache": {0: "batch_size"},
+                          "outputs": {0: "batch_size"},
+                          "new_cache": {0: "batch_size"},
+                      })
+
+    ort_session = ort.InferenceSession("basic_block.onnx")
+    input_feed = {
+        "inputs": inputs.numpy(),
+        "cache": cache.numpy(),
+    }
+    output_names = [
+        "outputs",
+        "new_cache"
+    ]
+    outputs = ort_session.run(output_names, input_feed)
+    print(outputs)
+    print(len(outputs))
+    return
+
+
+def main2():
+    import onnx
+    import onnxruntime as ort
+
+    input_size = 32
+    input_affine_size = 16
+    hidden_size = 16
+    basic_block_layers = 3
+    basic_block_hidden_size = 16
+    basic_block_lorder = 3
+    basic_block_rorder = 0
+    basic_block_lstride = 1
+    basic_block_rstride = 1
+    output_affine_size = 16
+    output_size = 32
+
     fsmn = FSMN(
-        input_size=…
-        input_affine_size=…
-        hidden_size=…
-        basic_block_layers=…
-        basic_block_hidden_size=…
-        basic_block_lorder=…
-        basic_block_rorder=…
-        basic_block_lstride=…
-        basic_block_rstride=…
-        output_affine_size=…
-        output_size=…
+        input_size=input_size,
+        input_affine_size=input_affine_size,
+        hidden_size=hidden_size,
+        basic_block_layers=basic_block_layers,
+        basic_block_hidden_size=basic_block_hidden_size,
+        basic_block_lorder=basic_block_lorder,
+        basic_block_rorder=basic_block_rorder,
+        basic_block_lstride=basic_block_lstride,
+        basic_block_rstride=basic_block_rstride,
+        output_affine_size=output_affine_size,
+        output_size=output_size,
     )
 
-…
+    b = 1
+    t = 198
+    f = input_size
+    inputs = torch.randn(size=(b, t, f), dtype=torch.float32)
 
     result, _ = fsmn.forward(inputs)
-    print(result.shape)
-…
+    print(f"result.shape: {result.shape}")
+
+    fsmn_export = FSMNExport(model=fsmn)
+
+    cache_list = [
+        torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+        torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+        torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+    ]
+    cache_list = torch.stack(cache_list, dim=0)
+    result, new_cache_list = fsmn_export.forward(inputs, cache_list)
+    print(f"result.shape: {result.shape}")
+    print(f"new_cache_list.shape: {new_cache_list.shape}")
+
+    torch.onnx.export(fsmn_export,
+                      args=(inputs, cache_list),
+                      f="fsmn.onnx",
+                      input_names=["inputs", "cache_list"],
+                      output_names=["outputs", "new_cache_list"],
+                      dynamic_axes={
+                          "inputs": {0: "batch_size"},
+                          "cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                          "outputs": {0: "batch_size"},
+                          "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                      })
+
+    ort_session = ort.InferenceSession("fsmn.onnx")
+    input_feed = {
+        "inputs": inputs.numpy(),
+        "cache_list": cache_list.numpy(),
+    }
+    output_names = [
+        "outputs",
+        "new_cache_list"
+    ]
+    outputs, new_cache_list = ort_session.run(output_names, input_feed)
+    print(f"outputs.shape: {outputs.shape}")
+    print(f"new_cache_list.shape: {new_cache_list.shape}")
     return
 
 
 if __name__ == "__main__":
-…
+    main2()
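Not part of the commit: the per-layer cache exists so the exported encoder can be driven chunk by chunk while keeping the FSMN left context. A minimal streaming sketch over FSMNExport, using the same sizes as main2() above; whether chunked output exactly matches a single full pass depends on the fsmn_block cache semantics defined elsewhere in this file:

import torch

def run_streaming(fsmn_export, inputs, chunk_size=50,
                  basic_block_layers=3, hidden_size=16,
                  basic_block_lorder=3, basic_block_lstride=1):
    # Feed the input in time chunks, carrying cache_list between calls.
    b = inputs.shape[0]
    cache_list = torch.zeros(
        size=(basic_block_layers, b, hidden_size,
              (basic_block_lorder - 1) * basic_block_lstride, 1))
    outputs = []
    for start in range(0, inputs.shape[1], chunk_size):
        chunk = inputs[:, start:start + chunk_size, :]
        out, cache_list = fsmn_export.forward(chunk, cache_list)
        outputs.append(out)
    return torch.cat(outputs, dim=1)

# e.g. run_streaming(fsmn_export, inputs) with the objects built in main2().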
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py
CHANGED
@@ -18,7 +18,7 @@ torch.set_num_threads(1)
 
 from project_settings import project_path
 from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
-from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel, MODEL_FILE
+from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel, MODEL_FILE, FSMNVadModelExport
 from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
 
 
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad_onnx.py
ADDED
@@ -0,0 +1,168 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
from pathlib import Path
import shutil
import tempfile, time
from typing import List
import zipfile

from scipy.io import wavfile
import numpy as np
import torch
import onnxruntime as ort

torch.set_num_threads(1)

from project_settings import project_path
from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization


logger = logging.getLogger("toolbox")


class InferenceFSMNVadOnnx(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.ort_session = ort_session

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "cc_vad"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem

        config = FSMNVadConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        ort_session = ort.InferenceSession(
            path_or_bytes=(model_path / "model.onnx").as_posix()
        )

        shutil.rmtree(model_path)
        return config, ort_session

    def infer(self, signal: np.ndarray) -> dict:
        # signal shape: [num_samples,], value between -1 and 1.

        inputs = torch.tensor(signal, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)
        inputs = torch.unsqueeze(inputs, dim=0)
        # inputs shape: [1, 1, num_samples]

        b = 1
        cache_list = [
            torch.zeros(size=(
                b, self.config.fsmn_basic_block_hidden_size,
                (self.config.fsmn_basic_block_lorder - 1) * self.config.fsmn_basic_block_lstride,
                1
            )),
        ] * self.config.fsmn_basic_block_layers
        cache_list = torch.stack(cache_list, dim=0)

        input_feed = {
            "inputs": inputs.numpy(),
            "cache_list": cache_list.numpy(),
        }
        output_names = [
            "logits", "probs", "lsnr", "new_cache_list"
        ]
        logits, probs, lsnr, new_cache_list = self.ort_session.run(output_names, input_feed)
        # probs shape: [b, t, 1]
        probs = np.squeeze(probs, axis=-1)
        # probs shape: [b, t]
        probs = probs[0]

        # lsnr shape: [b, t, 1]
        lsnr = np.squeeze(lsnr, axis=-1)
        # lsnr shape: [b, t]
        lsnr = lsnr[0]

        result = {
            "probs": probs,
            "lsnr": lsnr,
        }
        return result

    def post_process(self, probs: List[float]):
        return


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
        # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
        # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
        # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
        # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",

        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
        type=str,
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError
    signal = signal / (1 << 15)

    infer = InferenceFSMNVadOnnx(
        # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
    )
    frame_step = infer.config.hop_size

    vad_info = infer.infer(signal)
    speech_probs = vad_info["probs"].tolist()

    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=speech_probs,
        frame_step=frame_step,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE)
    return


if __name__ == "__main__":
    main()
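Not part of the commit: load_models() expects a zip whose top-level folder name matches the zip's stem and which contains the config file plus model.onnx. A hypothetical packaging step that produces that layout from an exported model directory (paths are assumptions):

import zipfile
from pathlib import Path

model_dir = Path("file_dir/best")                      # exported config + model.onnx (assumed)
zip_file = Path("fsmn-vad-by-webrtcvad-nx2-dns3.zip")  # the stem becomes the folder name inside the zip

with zipfile.ZipFile(zip_file.as_posix(), "w", zipfile.ZIP_DEFLATED) as f_zip:
    for path in model_dir.rglob("*"):
        if path.is_file():
            arcname = Path(zip_file.stem) / path.relative_to(model_dir)
            f_zip.write(path, arcname=arcname.as_posix())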
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py
CHANGED
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
 from toolbox.torchaudio.modules.conv_stft import ConvSTFT
-from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
+from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN, FSMNExport
 from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
@@ -243,7 +246,45 @@ class FSMNVadPretrainedModel(FSMNVadModel):
         return save_directory
 
 
-def main():
+class FSMNVadModelExport(nn.Module):
+    def __init__(self, model: FSMNVadModel):
+        super(FSMNVadModelExport, self).__init__()
+        self.stft = model.stft
+        self.fsmn_encoder = FSMNExport(model.fsmn_encoder)
+
+        # lsnr
+        self.lsnr_scale = model.lsnr_scale
+        self.lsnr_offset = model.lsnr_offset
+
+    def forward(self,
+                signal: torch.Tensor,
+                cache_list: torch.Tensor,
+                ):
+        # signal shape [b, 1, num_samples]
+
+        mags = self.stft.forward(signal)
+        # mags shape: [b, f, t]
+
+        x = torch.transpose(mags, dim0=1, dim1=2)
+        # x shape: [b, t, f]
+
+        logits, new_cache_list = self.fsmn_encoder.forward(x, cache_list)
+        # logits shape: [b, t, 2]
+
+        splits = torch.split(logits, split_size_or_sections=[1, 1], dim=-1)
+        vad_logits = splits[0]
+        snr_logits = splits[1]
+        # shape: [b, t, 1]
+        vad_probs = F.sigmoid(vad_logits)
+        # vad_probs shape: [b, t, 1]
+
+        lsnr = F.sigmoid(snr_logits) * self.lsnr_scale + self.lsnr_offset
+        # lsnr shape: [b, t, 1]
+
+        return vad_logits, vad_probs, lsnr, new_cache_list
+
+
+def main1():
     config = FSMNVadConfig()
     model = FSMNVadPretrainedModel(config=config)
 
@@ -253,9 +291,62 @@ def main():
     print(f"logits.shape: {logits.shape}")
     print(f"probs.shape: {probs.shape}")
     print(f"lsnr.shape: {lsnr.shape}")
+    return
+
 
+def main2():
+    import onnx
+    import onnxruntime as ort
+
+    config = FSMNVadConfig()
+    model = FSMNVadPretrainedModel(config=config)
+
+    basic_block_layers = config.fsmn_basic_block_layers
+    hidden_size = config.fsmn_basic_block_hidden_size
+    basic_block_lorder = config.fsmn_basic_block_lorder
+    basic_block_lstride = config.fsmn_basic_block_lstride
+
+    model_export = FSMNVadModelExport(model)
+
+    b = 1
+    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
+    cache_list = [
+        torch.zeros(size=(b, hidden_size, (basic_block_lorder - 1) * basic_block_lstride, 1)),
+    ] * basic_block_layers
+    cache_list = torch.stack(cache_list, dim=0)
+
+    logits, probs, lsnr, new_cache_list = model_export.forward(inputs, cache_list)
+    print(f"logits.shape: {logits.shape}")
+    print(f"new_cache_list.shape: {new_cache_list.shape}")
+
+    torch.onnx.export(model_export,
+                      args=(inputs, cache_list),
+                      f="fsmn_vad.onnx",
+                      input_names=["inputs", "cache_list"],
+                      output_names=["logits", "probs", "lsnr", "new_cache_list"],
+                      dynamic_axes={
+                          "inputs": {0: "batch_size", 2: "num_samples"},
+                          "cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                          "logits": {0: "batch_size"},
+                          "probs": {0: "batch_size"},
+                          "lsnr": {0: "batch_size"},
+                          "new_cache_list": {0: "basic_block_layers", 1: "batch_size"},
+                      })
+
+    ort_session = ort.InferenceSession("fsmn_vad.onnx")
+    input_feed = {
+        "inputs": inputs.numpy(),
+        "cache_list": cache_list.numpy(),
+    }
+    output_names = [
+        "logits", "probs", "lsnr", "new_cache_list"
+    ]
+    logits, probs, lsnr, new_cache_list = ort_session.run(output_names, input_feed)
+    print(f"logits.shape: {logits.shape}")
+    print(f"new_cache_list.shape: {new_cache_list.shape}")
     return
 
 
 if __name__ == "__main__":
-    main()
+    main2()
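Not part of the commit: a parity check between the eager FSMNVadModelExport and the exported graph would catch export regressions early. A minimal sketch, assuming main2() above has just written fsmn_vad.onnx and that model_export, inputs and cache_list are still in scope:

import numpy as np
import onnxruntime as ort
import torch

def check_onnx_parity(model_export, onnx_file, inputs, cache_list, atol=1e-4):
    # Run the same inputs through the eager module and the ONNX graph.
    with torch.no_grad():
        logits, probs, lsnr, _ = model_export.forward(inputs, cache_list)

    session = ort.InferenceSession(onnx_file)
    onnx_logits, onnx_probs, onnx_lsnr, _ = session.run(
        ["logits", "probs", "lsnr", "new_cache_list"],
        {"inputs": inputs.numpy(), "cache_list": cache_list.numpy()},
    )

    np.testing.assert_allclose(logits.numpy(), onnx_logits, atol=atol)
    np.testing.assert_allclose(probs.numpy(), onnx_probs, atol=atol)
    np.testing.assert_allclose(lsnr.numpy(), onnx_lsnr, atol=atol)
    print(f"onnx outputs match eager outputs within atol={atol}")

# e.g. check_onnx_parity(model_export, "fsmn_vad.onnx", inputs, cache_list)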