update
Browse files
.gitignore
CHANGED
@@ -23,3 +23,4 @@
|
|
23 |
**/*.wav
|
24 |
**/*.xlsx
|
25 |
**/*.jsonl
|
|
|
|
23 |
**/*.wav
|
24 |
**/*.xlsx
|
25 |
**/*.jsonl
|
26 |
+
**/*.onnx
|
examples/silero_vad_by_webrtcvad/run.sh
CHANGED
@@ -126,13 +126,11 @@ fi
|
|
126 |
|
127 |
|
128 |
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
129 |
-
$verbose && echo "stage 4:
|
130 |
cd "${work_dir}" || exit 1
|
131 |
-
python3
|
132 |
-
--valid_dataset "${valid_dataset}" \
|
133 |
--model_dir "${file_dir}/best" \
|
134 |
-
--
|
135 |
-
--limit "${limit}" \
|
136 |
|
137 |
fi
|
138 |
|
@@ -144,7 +142,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
144 |
mkdir -p ${final_model_dir}
|
145 |
|
146 |
cp "${file_dir}/best"/* "${final_model_dir}"
|
147 |
-
cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
|
148 |
|
149 |
cd "${final_model_dir}/.." || exit 1;
|
150 |
|
|
|
126 |
|
127 |
|
128 |
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
129 |
+
$verbose && echo "stage 4: export model"
|
130 |
cd "${work_dir}" || exit 1
|
131 |
+
python3 step_5_export_model.py \
|
|
|
132 |
--model_dir "${file_dir}/best" \
|
133 |
+
--output_dir "${file_dir}/best" \
|
|
|
134 |
|
135 |
fi
|
136 |
|
|
|
142 |
mkdir -p ${final_model_dir}
|
143 |
|
144 |
cp "${file_dir}/best"/* "${final_model_dir}"
|
|
|
145 |
|
146 |
cd "${final_model_dir}/.." || exit 1;
|
147 |
|
examples/silero_vad_by_webrtcvad/step_5_export_model.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
import onnxruntime as ort
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadModelExport, SileroVadPretrainedModel
|
15 |
+
|
16 |
+
|
17 |
+
def get_args():
    """Parse command-line arguments for the ONNX export step.

    Returns:
        argparse.Namespace with:
            model_dir:  directory containing the trained model to load.
            output_dir: directory where the exported ONNX model is written.

    Both default to ``file_dir/best`` (the layout run.sh uses); run.sh always
    passes both flags explicitly, so the defaults only matter for ad-hoc runs.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the original defaults were machine-specific Windows paths
    # pointing at an fsmn-vad model; restored the generic defaults intended here.
    parser.add_argument("--model_dir", default="file_dir/best", type=str)
    parser.add_argument("--output_dir", default="file_dir/best", type=str)
    args = parser.parse_args()
    return args
|
34 |
+
|
35 |
+
|
36 |
+
def main():
    """Export the trained SileroVad model to ONNX, then sanity-check the file
    by loading it with onnxruntime and running one forward pass.

    The export wrapper (SileroVadModelExport) takes the streaming caches as
    explicit inputs/outputs so the ONNX graph is usable for chunked inference.
    """
    args = get_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "model.onnx"

    model = SileroVadPretrainedModel.from_pretrained(args.model_dir)
    model.eval()
    config = model.config

    model_export = SileroVadModelExport(model)

    # Cache geometry is derived from the model config.
    encoder_num_layers = config.encoder_num_layers
    p = (config.encoder_kernel_size - 1) // 2  # per-side conv padding -> cache length 2p
    encoder_in_channels = config.encoder_in_channels
    encoder_hidden_channels = config.encoder_hidden_channels

    decoder_num_layers = config.decoder_num_layers
    decoder_hidden_size = config.decoder_hidden_size

    b = 1
    # Dummy audio for tracing; presumably one second at 16 kHz — TODO confirm
    # against config.sample_rate.
    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)

    # Zero-initialized streaming caches.
    # encoder_in_cache shape: [b, 2p, f]
    encoder_in_cache = torch.zeros(size=(b, 2 * p, encoder_in_channels), dtype=torch.float32)
    # encoder_hidden_cache_list shape: [num_layers, b, 2p, f]
    encoder_hidden_cache_list = torch.stack(
        [torch.zeros(size=(b, 2 * p, encoder_hidden_channels), dtype=torch.float32)] * encoder_num_layers,
        dim=0,
    )
    # lstm_hidden_state shape: [2, num_layers, b, h]  (h0 and c0 stacked)
    lstm_hidden_state = torch.stack(
        [torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)] * 2,
        dim=0,
    )

    # Sanity run of the export wrapper before tracing.
    logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = model_export.forward(
        inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state
    )

    # BUG FIX: the original exported to a hard-coded "silero_vad.onnx" in the
    # CWD and never used `output_file`; run.sh expects --output_dir/model.onnx.
    # BUG FIX: encoder_in_cache is [b, 2p, f], so its batch axis is 0, not 1
    # (the original dynamic_axes marked axis 1, the cache-length axis).
    torch.onnx.export(model_export,
                      args=(inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state),
                      f=output_file.as_posix(),
                      input_names=["inputs", "encoder_in_cache", "encoder_hidden_cache_list", "lstm_hidden_state"],
                      output_names=[
                          "logits", "probs", "lsnr",
                          "new_encoder_in_cache",
                          "new_encoder_hidden_cache_list",
                          "new_lstm_hidden_state"
                      ],
                      dynamic_axes={
                          "inputs": {0: "batch_size", 2: "num_samples"},
                          "encoder_in_cache": {0: "batch_size"},
                          "encoder_hidden_cache_list": {1: "batch_size"},
                          "lstm_hidden_state": {2: "batch_size"},
                          "logits": {0: "batch_size"},
                          "probs": {0: "batch_size"},
                          "lsnr": {0: "batch_size"},
                          "new_encoder_in_cache": {0: "batch_size"},
                          "new_encoder_hidden_cache_list": {1: "batch_size"},
                          "new_lstm_hidden_state": {2: "batch_size"},
                      })

    # Verify the exported graph loads and runs under onnxruntime.
    ort_session = ort.InferenceSession(output_file.as_posix())
    input_feed = {
        "inputs": inputs.numpy(),
        "encoder_in_cache": encoder_in_cache.numpy(),
        "encoder_hidden_cache_list": encoder_hidden_cache_list.numpy(),
        "lstm_hidden_state": lstm_hidden_state.numpy(),
    }
    output_names = [
        "logits", "probs", "lsnr", "new_encoder_in_cache", "new_encoder_hidden_cache_list", "new_lstm_hidden_state"
    ]
    logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = ort_session.run(output_names, input_feed)
    return
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == "__main__":
    # Script entry point: run the ONNX export when executed directly
    # (run.sh invokes this file with python3); no effect on import.
    main()
|
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py
CHANGED
@@ -14,6 +14,8 @@ class SileroVadConfig(PretrainedConfig):
|
|
14 |
win_type: str = "hann",
|
15 |
|
16 |
encoder_in_channels: int = 64,
|
|
|
|
|
17 |
encoder_kernel_size: int = 3,
|
18 |
encoder_num_layers: int = 3,
|
19 |
|
@@ -52,6 +54,8 @@ class SileroVadConfig(PretrainedConfig):
|
|
52 |
|
53 |
# encoder
|
54 |
self.encoder_in_channels = encoder_in_channels
|
|
|
|
|
55 |
self.encoder_kernel_size = encoder_kernel_size
|
56 |
self.encoder_num_layers = encoder_num_layers
|
57 |
|
|
|
14 |
win_type: str = "hann",
|
15 |
|
16 |
encoder_in_channels: int = 64,
|
17 |
+
encoder_hidden_channels: int = 128,
|
18 |
+
encoder_out_channels: int = 64,
|
19 |
encoder_kernel_size: int = 3,
|
20 |
encoder_num_layers: int = 3,
|
21 |
|
|
|
54 |
|
55 |
# encoder
|
56 |
self.encoder_in_channels = encoder_in_channels
|
57 |
+
self.encoder_hidden_channels = encoder_hidden_channels
|
58 |
+
self.encoder_out_channels = encoder_out_channels
|
59 |
self.encoder_kernel_size = encoder_kernel_size
|
60 |
self.encoder_num_layers = encoder_num_layers
|
61 |
|
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
CHANGED
@@ -62,6 +62,7 @@ class Encoder(nn.Module):
|
|
62 |
num_layers: int = 3,
|
63 |
):
|
64 |
super(Encoder, self).__init__()
|
|
|
65 |
|
66 |
self.layers = nn.ModuleList(modules=[])
|
67 |
for i in range(num_layers):
|
@@ -96,23 +97,33 @@ class EncoderExport(nn.Module):
|
|
96 |
def __init__(self, model: Encoder):
|
97 |
super(EncoderExport, self).__init__()
|
98 |
self.layers = model.layers
|
|
|
99 |
|
100 |
-
def forward(self, x: torch.Tensor,
|
101 |
# x shape: [b, t, f]
|
102 |
-
#
|
|
|
103 |
|
104 |
-
|
|
|
105 |
for idx, layer in enumerate(self.layers):
|
106 |
-
|
|
|
|
|
|
|
|
|
107 |
x_pad = torch.concat(tensors=[cache, x], dim=1)
|
108 |
x = layer.forward(x_pad)
|
109 |
|
110 |
_, twop, _ = cache.shape
|
111 |
new_cache = x_pad[:, -twop:, :]
|
112 |
-
|
|
|
|
|
|
|
113 |
|
114 |
-
|
115 |
-
return x,
|
116 |
|
117 |
|
118 |
class SileroVadModel(nn.Module):
|
@@ -123,6 +134,8 @@ class SileroVadModel(nn.Module):
|
|
123 |
hop_size: int,
|
124 |
win_type: int,
|
125 |
encoder_in_channels: int,
|
|
|
|
|
126 |
encoder_kernel_size: int,
|
127 |
encoder_num_layers: int,
|
128 |
decoder_hidden_size: int,
|
@@ -139,6 +152,8 @@ class SileroVadModel(nn.Module):
|
|
139 |
self.win_type = win_type
|
140 |
|
141 |
self.encoder_in_channels = encoder_in_channels
|
|
|
|
|
142 |
self.encoder_kernel_size = encoder_kernel_size
|
143 |
self.encoder_num_layers = encoder_num_layers
|
144 |
|
@@ -180,8 +195,8 @@ class SileroVadModel(nn.Module):
|
|
180 |
|
181 |
self.encoder = Encoder(
|
182 |
in_channels=self.encoder_in_channels,
|
183 |
-
hidden_channels=self.
|
184 |
-
out_channels=self.
|
185 |
kernel_size=self.encoder_kernel_size,
|
186 |
num_layers=self.encoder_num_layers,
|
187 |
)
|
@@ -298,6 +313,8 @@ class SileroVadPretrainedModel(SileroVadModel):
|
|
298 |
hop_size=config.hop_size,
|
299 |
win_type=config.win_type,
|
300 |
encoder_in_channels=config.encoder_in_channels,
|
|
|
|
|
301 |
encoder_kernel_size=config.encoder_kernel_size,
|
302 |
encoder_num_layers=config.encoder_num_layers,
|
303 |
decoder_hidden_size=config.decoder_hidden_size,
|
@@ -362,10 +379,12 @@ class SileroVadModelExport(nn.Module):
|
|
362 |
|
363 |
def forward(self,
|
364 |
signal: torch.Tensor,
|
365 |
-
|
|
|
366 |
lstm_hidden_state: torch.Tensor,
|
367 |
):
|
368 |
-
#
|
|
|
369 |
# lstm_hidden_state shape: [2, num_layers, b, h]
|
370 |
|
371 |
# signal shape [b, 1, num_samples]
|
@@ -382,7 +401,9 @@ class SileroVadModelExport(nn.Module):
|
|
382 |
# x = self.tpad.forward(x)
|
383 |
# x shape: [b, t+p, f']
|
384 |
|
385 |
-
x,
|
|
|
|
|
386 |
# x shape: [b, t, f']
|
387 |
|
388 |
x, new_lstm_hidden_state = self.lstm.forward(x, (lstm_hidden_state[0], lstm_hidden_state[1]))
|
@@ -397,7 +418,7 @@ class SileroVadModelExport(nn.Module):
|
|
397 |
lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
|
398 |
# lsnr shape: [b, t, 1]
|
399 |
|
400 |
-
return logits, probs, lsnr,
|
401 |
|
402 |
|
403 |
def main1():
|
@@ -425,6 +446,7 @@ def main2():
|
|
425 |
encoder_num_layers = config.encoder_num_layers
|
426 |
p = (config.encoder_kernel_size - 1) // 2
|
427 |
encoder_in_channels = config.encoder_in_channels
|
|
|
428 |
|
429 |
decoder_num_layers = config.decoder_num_layers
|
430 |
decoder_hidden_size = config.decoder_hidden_size
|
@@ -432,49 +454,60 @@ def main2():
|
|
432 |
b = 1
|
433 |
inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
|
434 |
|
435 |
-
|
436 |
-
|
|
|
437 |
] * encoder_num_layers
|
438 |
-
|
439 |
|
440 |
lstm_hidden_state = [
|
441 |
torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
|
442 |
] * 2
|
443 |
lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
|
444 |
|
445 |
-
logits, probs, lsnr,
|
|
|
|
|
446 |
print(f"logits.shape: {logits.shape}")
|
447 |
-
print(f"
|
|
|
448 |
print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")
|
449 |
|
450 |
torch.onnx.export(model_export,
|
451 |
-
args=(inputs,
|
452 |
f="silero_vad.onnx",
|
453 |
-
input_names=["inputs", "
|
454 |
-
output_names=[
|
|
|
|
|
|
|
|
|
|
|
455 |
dynamic_axes={
|
456 |
"inputs": {0: "batch_size", 2: "num_samples"},
|
457 |
-
"
|
|
|
458 |
"lstm_hidden_state": {2: "batch_size"},
|
459 |
"logits": {0: "batch_size"},
|
460 |
"probs": {0: "batch_size"},
|
461 |
"lsnr": {0: "batch_size"},
|
462 |
-
"
|
|
|
463 |
"new_lstm_hidden_state": {2: "batch_size"},
|
464 |
})
|
465 |
|
466 |
ort_session = ort.InferenceSession("silero_vad.onnx")
|
467 |
input_feed = {
|
468 |
"inputs": inputs.numpy(),
|
469 |
-
"
|
|
|
470 |
"lstm_hidden_state": lstm_hidden_state.numpy(),
|
471 |
}
|
472 |
output_names = [
|
473 |
-
"logits", "probs", "lsnr", "
|
474 |
]
|
475 |
-
logits, probs, lsnr,
|
476 |
print(f"probs.shape: {probs.shape}")
|
477 |
-
print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
|
478 |
return
|
479 |
|
480 |
|
|
|
62 |
num_layers: int = 3,
|
63 |
):
|
64 |
super(Encoder, self).__init__()
|
65 |
+
self.num_layers = num_layers
|
66 |
|
67 |
self.layers = nn.ModuleList(modules=[])
|
68 |
for i in range(num_layers):
|
|
|
97 |
def __init__(self, model: Encoder):
|
98 |
super(EncoderExport, self).__init__()
|
99 |
self.layers = model.layers
|
100 |
+
self.num_layers = model.num_layers
|
101 |
|
102 |
+
def forward(self, x: torch.Tensor, in_cache: torch.Tensor, hidden_cache_list: torch.Tensor):
|
103 |
# x shape: [b, t, f]
|
104 |
+
# in_cache shape: [b, 2p, f1]
|
105 |
+
# hidden_cache_list shape: [num_layers, b, 2p, fi]
|
106 |
|
107 |
+
new_in_cache = None
|
108 |
+
new_hidden_cache_list = list()
|
109 |
for idx, layer in enumerate(self.layers):
|
110 |
+
if idx == 0:
|
111 |
+
cache = in_cache
|
112 |
+
else:
|
113 |
+
cache = hidden_cache_list[idx]
|
114 |
+
|
115 |
x_pad = torch.concat(tensors=[cache, x], dim=1)
|
116 |
x = layer.forward(x_pad)
|
117 |
|
118 |
_, twop, _ = cache.shape
|
119 |
new_cache = x_pad[:, -twop:, :]
|
120 |
+
if idx == 0:
|
121 |
+
new_in_cache = new_cache
|
122 |
+
else:
|
123 |
+
new_hidden_cache_list.append(new_cache)
|
124 |
|
125 |
+
new_hidden_cache_list = torch.stack(tensors=new_hidden_cache_list, dim=0)
|
126 |
+
return x, new_in_cache, new_hidden_cache_list
|
127 |
|
128 |
|
129 |
class SileroVadModel(nn.Module):
|
|
|
134 |
hop_size: int,
|
135 |
win_type: int,
|
136 |
encoder_in_channels: int,
|
137 |
+
encoder_hidden_channels: int,
|
138 |
+
encoder_out_channels: int,
|
139 |
encoder_kernel_size: int,
|
140 |
encoder_num_layers: int,
|
141 |
decoder_hidden_size: int,
|
|
|
152 |
self.win_type = win_type
|
153 |
|
154 |
self.encoder_in_channels = encoder_in_channels
|
155 |
+
self.encoder_hidden_channels = encoder_hidden_channels
|
156 |
+
self.encoder_out_channels = encoder_out_channels
|
157 |
self.encoder_kernel_size = encoder_kernel_size
|
158 |
self.encoder_num_layers = encoder_num_layers
|
159 |
|
|
|
195 |
|
196 |
self.encoder = Encoder(
|
197 |
in_channels=self.encoder_in_channels,
|
198 |
+
hidden_channels=self.encoder_hidden_channels,
|
199 |
+
out_channels=self.encoder_out_channels,
|
200 |
kernel_size=self.encoder_kernel_size,
|
201 |
num_layers=self.encoder_num_layers,
|
202 |
)
|
|
|
313 |
hop_size=config.hop_size,
|
314 |
win_type=config.win_type,
|
315 |
encoder_in_channels=config.encoder_in_channels,
|
316 |
+
encoder_hidden_channels=config.encoder_hidden_channels,
|
317 |
+
encoder_out_channels=config.encoder_out_channels,
|
318 |
encoder_kernel_size=config.encoder_kernel_size,
|
319 |
encoder_num_layers=config.encoder_num_layers,
|
320 |
decoder_hidden_size=config.decoder_hidden_size,
|
|
|
379 |
|
380 |
def forward(self,
|
381 |
signal: torch.Tensor,
|
382 |
+
encoder_in_cache: torch.Tensor,
|
383 |
+
encoder_hidden_cache_list: torch.Tensor,
|
384 |
lstm_hidden_state: torch.Tensor,
|
385 |
):
|
386 |
+
# encoder_in_cache shape: [b, 2p, f]
|
387 |
+
# encoder_hidden_cache_list shape: [num_layers, b, 2p, f]
|
388 |
# lstm_hidden_state shape: [2, num_layers, b, h]
|
389 |
|
390 |
# signal shape [b, 1, num_samples]
|
|
|
401 |
# x = self.tpad.forward(x)
|
402 |
# x shape: [b, t+p, f']
|
403 |
|
404 |
+
x, new_encoder_in_cache, new_encoder_hidden_cache_list = self.encoder.forward(
|
405 |
+
x, in_cache=encoder_in_cache, hidden_cache_list=encoder_hidden_cache_list
|
406 |
+
)
|
407 |
# x shape: [b, t, f']
|
408 |
|
409 |
x, new_lstm_hidden_state = self.lstm.forward(x, (lstm_hidden_state[0], lstm_hidden_state[1]))
|
|
|
418 |
lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
|
419 |
# lsnr shape: [b, t, 1]
|
420 |
|
421 |
+
return logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state
|
422 |
|
423 |
|
424 |
def main1():
|
|
|
446 |
encoder_num_layers = config.encoder_num_layers
|
447 |
p = (config.encoder_kernel_size - 1) // 2
|
448 |
encoder_in_channels = config.encoder_in_channels
|
449 |
+
encoder_hidden_channels = config.encoder_hidden_channels
|
450 |
|
451 |
decoder_num_layers = config.decoder_num_layers
|
452 |
decoder_hidden_size = config.decoder_hidden_size
|
|
|
454 |
b = 1
|
455 |
inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
|
456 |
|
457 |
+
encoder_in_cache = torch.zeros(size=(b, 2*p, encoder_in_channels), dtype=torch.float32)
|
458 |
+
encoder_hidden_cache_list = [
|
459 |
+
torch.zeros(size=(b, 2*p, encoder_hidden_channels), dtype=torch.float32)
|
460 |
] * encoder_num_layers
|
461 |
+
encoder_hidden_cache_list = torch.stack(encoder_hidden_cache_list, dim=0)
|
462 |
|
463 |
lstm_hidden_state = [
|
464 |
torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
|
465 |
] * 2
|
466 |
lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
|
467 |
|
468 |
+
logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = model_export.forward(
|
469 |
+
inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state
|
470 |
+
)
|
471 |
print(f"logits.shape: {logits.shape}")
|
472 |
+
print(f"new_encoder_in_cache.shape: {new_encoder_in_cache.shape}")
|
473 |
+
print(f"new_encoder_hidden_cache_list.shape: {new_encoder_hidden_cache_list.shape}")
|
474 |
print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")
|
475 |
|
476 |
torch.onnx.export(model_export,
|
477 |
+
args=(inputs, encoder_in_cache, encoder_hidden_cache_list, lstm_hidden_state),
|
478 |
f="silero_vad.onnx",
|
479 |
+
input_names=["inputs", "encoder_in_cache", "encoder_hidden_cache_list", "lstm_hidden_state"],
|
480 |
+
output_names=[
|
481 |
+
"logits", "probs", "lsnr",
|
482 |
+
"new_encoder_in_cache",
|
483 |
+
"new_encoder_hidden_cache_list",
|
484 |
+
"new_lstm_hidden_state"
|
485 |
+
],
|
486 |
dynamic_axes={
|
487 |
"inputs": {0: "batch_size", 2: "num_samples"},
|
488 |
+
"encoder_in_cache": {1: "batch_size"},
|
489 |
+
"encoder_hidden_cache_list": {1: "batch_size"},
|
490 |
"lstm_hidden_state": {2: "batch_size"},
|
491 |
"logits": {0: "batch_size"},
|
492 |
"probs": {0: "batch_size"},
|
493 |
"lsnr": {0: "batch_size"},
|
494 |
+
"new_encoder_in_cache": {1: "batch_size"},
|
495 |
+
"new_encoder_hidden_cache_list": {1: "batch_size"},
|
496 |
"new_lstm_hidden_state": {2: "batch_size"},
|
497 |
})
|
498 |
|
499 |
ort_session = ort.InferenceSession("silero_vad.onnx")
|
500 |
input_feed = {
|
501 |
"inputs": inputs.numpy(),
|
502 |
+
"encoder_in_cache": encoder_in_cache.numpy(),
|
503 |
+
"encoder_hidden_cache_list": encoder_hidden_cache_list.numpy(),
|
504 |
"lstm_hidden_state": lstm_hidden_state.numpy(),
|
505 |
}
|
506 |
output_names = [
|
507 |
+
"logits", "probs", "lsnr", "new_encoder_in_cache", "new_encoder_hidden_cache_list", "new_lstm_hidden_state"
|
508 |
]
|
509 |
+
logits, probs, lsnr, new_encoder_in_cache, new_encoder_hidden_cache_list, new_lstm_hidden_state = ort_session.run(output_names, input_feed)
|
510 |
print(f"probs.shape: {probs.shape}")
|
|
|
511 |
return
|
512 |
|
513 |
|