update
Browse files- examples/silero_vad_by_webrtcvad/yaml/config.yaml +5 -5
- toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py +12 -6
- toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py +181 -0
- toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py +232 -159
- toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml +5 -5
examples/silero_vad_by_webrtcvad/yaml/config.yaml
CHANGED
@@ -8,11 +8,11 @@ hop_size: 80
|
|
8 |
win_type: hann
|
9 |
|
10 |
# model
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
# lsnr
|
18 |
n_frame: 3
|
|
|
8 |
win_type: hann
|
9 |
|
10 |
# model
|
11 |
+
encoder_in_channels: 64
|
12 |
+
encoder_kernel_size: 3
|
13 |
+
encoder_num_layers: 3
|
14 |
+
|
15 |
+
decoder_hidden_size: 64
|
16 |
|
17 |
# lsnr
|
18 |
n_frame: 3
|
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py
CHANGED
@@ -13,9 +13,12 @@ class SileroVadConfig(PretrainedConfig):
|
|
13 |
hop_size: int = 80,
|
14 |
win_type: str = "hann",
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
19 |
|
20 |
n_frame: int = 3,
|
21 |
min_local_snr_db: float = -15,
|
@@ -48,9 +51,12 @@ class SileroVadConfig(PretrainedConfig):
|
|
48 |
self.win_type = win_type
|
49 |
|
50 |
# encoder
|
51 |
-
self.
|
52 |
-
self.
|
53 |
-
self.
|
|
|
|
|
|
|
54 |
|
55 |
# lsnr
|
56 |
self.n_frame = n_frame
|
|
|
13 |
hop_size: int = 80,
|
14 |
win_type: str = "hann",
|
15 |
|
16 |
+
encoder_in_channels: int = 64,
|
17 |
+
encoder_kernel_size: int = 3,
|
18 |
+
encoder_num_layers: int = 3,
|
19 |
+
|
20 |
+
decoder_hidden_size: int = 64,
|
21 |
+
decoder_num_layers: int = 2,
|
22 |
|
23 |
n_frame: int = 3,
|
24 |
min_local_snr_db: float = -15,
|
|
|
51 |
self.win_type = win_type
|
52 |
|
53 |
# encoder
|
54 |
+
self.encoder_in_channels = encoder_in_channels
|
55 |
+
self.encoder_kernel_size = encoder_kernel_size
|
56 |
+
self.encoder_num_layers = encoder_num_layers
|
57 |
+
|
58 |
+
self.decoder_hidden_size = decoder_hidden_size
|
59 |
+
self.decoder_num_layers = decoder_num_layers
|
60 |
|
61 |
# lsnr
|
62 |
self.n_frame = n_frame
|
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
import tempfile, time
|
8 |
+
from typing import List
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
from scipy.io import wavfile
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import onnxruntime as ort
|
15 |
+
|
16 |
+
torch.set_num_threads(1)
|
17 |
+
|
18 |
+
from project_settings import project_path
|
19 |
+
from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
|
20 |
+
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("toolbox")
|
24 |
+
|
25 |
+
|
26 |
+
class InferenceSileroVadOnnx(object):
    """Run an exported VAD model (``model.onnx``) through ONNX Runtime.

    The model is loaded either from a directory or from a ``.zip`` archive.
    The directory (or the archive's top-level folder, named after the archive
    stem) must contain the config consumed by ``SileroVadConfig.from_pretrained``
    and a ``model.onnx`` file.
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """
        :param pretrained_model_path_or_zip_file: model directory or ``.zip`` archive.
        :param device: stored as ``torch.device``; it is not used by the ONNX session.
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.ort_session = ort_session

    def load_models(self, model_path: str):
        """Load the config and create the ONNX Runtime session.

        :param model_path: model directory or ``.zip`` archive.
        :return: tuple ``(config, ort_session)``.
        """
        model_path = Path(model_path)

        # Remember whether we created a temporary copy: only that copy may be
        # deleted afterwards, never a directory supplied by the caller.
        extracted = False
        if model_path.name.endswith(".zip"):
            out_root = Path(tempfile.gettempdir()) / "cc_vad"
            out_root.mkdir(parents=True, exist_ok=True)
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem
            extracted = True

        try:
            config = SileroVadConfig.from_pretrained(
                pretrained_model_name_or_path=model_path.as_posix(),
            )
            ort_session = ort.InferenceSession(
                path_or_bytes=(model_path / "model.onnx").as_posix()
            )
        finally:
            # Clean up the extracted files even when loading fails.
            if extracted and model_path.exists():
                shutil.rmtree(model_path)
        return config, ort_session

    def infer(self, signal: np.ndarray) -> dict:
        """Run the model once over a whole signal, starting from zeroed caches.

        :param signal: shape: [num_samples,], value between -1 and 1.
        :return: dict with ``"probs"`` and ``"lsnr"``, each of shape [t,].
        """
        inputs = torch.tensor(signal, dtype=torch.float32)
        inputs = inputs[None, None, :]
        # inputs shape: [1, 1, num_samples]

        b = 1

        # Every encoder layer keeps (kernel_size - 1) frames of left context.
        p = (self.config.encoder_kernel_size - 1) // 2

        # encoder_cache_list shape: [num_layers, b, 2p, f]
        encoder_cache_list = torch.zeros(
            size=(self.config.encoder_num_layers, b, 2 * p, self.config.encoder_in_channels),
            dtype=torch.float32,
        )
        # lstm_hidden_state shape: [2, num_layers, b, h]; index 0 is h, index 1 is c.
        lstm_hidden_state = torch.zeros(
            size=(2, self.config.decoder_num_layers, b, self.config.decoder_hidden_size),
            dtype=torch.float32,
        )

        input_feed = {
            "inputs": inputs.numpy(),
            "encoder_cache_list": encoder_cache_list.numpy(),
            "lstm_hidden_state": lstm_hidden_state.numpy(),
        }
        output_names = [
            "logits", "probs", "lsnr", "new_encoder_cache_list", "new_lstm_hidden_state"
        ]
        # The updated caches are returned by the model but not needed for one-shot inference.
        _, probs, lsnr, _, _ = self.ort_session.run(output_names, input_feed)

        # probs / lsnr shape: [b, t, 1] -> [b, t] -> [t]
        probs = np.squeeze(probs, axis=-1)[0]
        lsnr = np.squeeze(lsnr, axis=-1)[0]

        result = {
            "probs": probs,
            "lsnr": lsnr,
        }
        return result

    def post_process(self, probs: List[float]):
        """Placeholder; no post-processing is implemented, returns ``None``."""
        return
|
114 |
+
|
115 |
+
|
116 |
+
def get_args():
    """Parse the command-line arguments of the demo.

    ``--wav_file`` is the audio file to analyse. Its default is an absolute
    path on the original author's machine, so other users will normally have
    to pass the option explicitly.
    """
    default_wav_file = r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav"

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--wav_file", type=str, default=default_wav_file)
    return arg_parser.parse_args()
|
147 |
+
|
148 |
+
|
149 |
+
SAMPLE_RATE = 8000
|
150 |
+
|
151 |
+
|
152 |
+
def main():
    """Demo: run the ONNX VAD on one wav file and plot the speech probabilities."""
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError(
            f"expected sample rate {SAMPLE_RATE}, got {sample_rate}; wav_file: {args.wav_file}"
        )
    # Scale to [-1, 1); presumably the wav file holds 16-bit PCM samples -- TODO confirm.
    signal = signal / (1 << 15)

    # NOTE(review): the archive name still refers to an "fsmn-vad" model;
    # verify that it is the intended checkpoint for this Silero VAD wrapper.
    infer = InferenceSileroVadOnnx(
        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
    )
    frame_step = infer.config.hop_size

    # infer() returns a dict; the per-frame speech probabilities are under "probs".
    result = infer.infer(signal)
    speech_probs = result["probs"].tolist()

    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=speech_probs,
        frame_step=frame_step,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE)
    return
|
178 |
+
|
179 |
+
|
180 |
+
if __name__ == "__main__":
|
181 |
+
main()
|
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
CHANGED
@@ -2,15 +2,12 @@
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/snakers4/silero-vad/wiki/Quality-Metrics
|
5 |
-
|
6 |
https://pytorch.org/hub/snakers4_silero-vad_vad/
|
7 |
https://github.com/snakers4/silero-vad
|
8 |
-
|
9 |
https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
|
10 |
"""
|
11 |
-
import math
|
12 |
import os
|
13 |
-
from typing import
|
14 |
|
15 |
import torch
|
16 |
import torch.nn as nn
|
@@ -25,156 +22,96 @@ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
|
|
25 |
MODEL_FILE = "model.pt"
|
26 |
|
27 |
|
28 |
-
|
29 |
-
"batch_norm_2d": torch.nn.BatchNorm2d
|
30 |
-
}
|
31 |
-
|
32 |
-
|
33 |
-
activation_layer_dict = {
|
34 |
-
"relu": torch.nn.ReLU,
|
35 |
-
"identity": torch.nn.Identity,
|
36 |
-
"sigmoid": torch.nn.Sigmoid,
|
37 |
-
}
|
38 |
-
|
39 |
-
|
40 |
-
class CausalConv2d(nn.Module):
|
41 |
def __init__(self,
|
42 |
-
in_channels: int,
|
43 |
-
out_channels: int,
|
44 |
-
kernel_size:
|
45 |
-
fstride: int = 1,
|
46 |
-
dilation: int = 1,
|
47 |
-
pad_f_dim: bool = True,
|
48 |
-
bias: bool = True,
|
49 |
-
separable: bool = False,
|
50 |
-
norm_layer: str = "batch_norm_2d",
|
51 |
-
activation_layer: str = "relu",
|
52 |
):
|
53 |
-
super(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
fpad = kernel_size[1] // 2 + dilation - 1
|
58 |
-
else:
|
59 |
-
fpad = 0
|
60 |
-
|
61 |
-
# for last 2 dim, pad (left, right, top, bottom).
|
62 |
-
self.lookback = kernel_size[0] - 1
|
63 |
-
if self.lookback > 0:
|
64 |
-
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
|
65 |
-
else:
|
66 |
-
self.tpad = nn.Identity()
|
67 |
-
|
68 |
-
groups = math.gcd(in_channels, out_channels) if separable else 1
|
69 |
-
if groups == 1:
|
70 |
-
separable = False
|
71 |
-
if max(kernel_size) == 1:
|
72 |
-
separable = False
|
73 |
-
|
74 |
-
self.conv = nn.Conv2d(
|
75 |
-
in_channels,
|
76 |
-
out_channels,
|
77 |
kernel_size=kernel_size,
|
78 |
-
padding=
|
79 |
-
stride=(1, fstride), # stride over time is always 1
|
80 |
-
dilation=(1, dilation), # dilation over time is always 1
|
81 |
-
groups=groups,
|
82 |
-
bias=bias,
|
83 |
)
|
|
|
|
|
84 |
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
kernel_size=1,
|
90 |
-
bias=False,
|
91 |
-
)
|
92 |
-
else:
|
93 |
-
self.convp = nn.Identity()
|
94 |
-
|
95 |
-
if norm_layer is not None:
|
96 |
-
norm_layer = norm_layer_dict[norm_layer]
|
97 |
-
self.norm = norm_layer(out_channels)
|
98 |
-
else:
|
99 |
-
self.norm = nn.Identity()
|
100 |
-
|
101 |
-
if activation_layer is not None:
|
102 |
-
activation_layer = activation_layer_dict[activation_layer]
|
103 |
-
self.activation = activation_layer()
|
104 |
-
else:
|
105 |
-
self.activation = nn.Identity()
|
106 |
-
|
107 |
-
def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
|
108 |
-
"""
|
109 |
-
:param inputs: shape: [b, c, t, f]
|
110 |
-
:param cache: shape: [b, c, lookback, f];
|
111 |
-
:return:
|
112 |
-
"""
|
113 |
-
x = inputs
|
114 |
-
|
115 |
-
if cache is None:
|
116 |
-
x = self.tpad(x)
|
117 |
-
else:
|
118 |
-
x = torch.concat(tensors=[cache, x], dim=2)
|
119 |
-
|
120 |
-
new_cache = None
|
121 |
-
if self.lookback > 0:
|
122 |
-
new_cache = x[:, :, -self.lookback:, :]
|
123 |
-
|
124 |
-
x = self.conv(x)
|
125 |
|
126 |
-
x = self.
|
127 |
-
x = self.norm(x)
|
128 |
x = self.activation(x)
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
return x
|
131 |
|
132 |
|
133 |
-
class
|
134 |
def __init__(self,
|
135 |
-
|
136 |
-
|
|
|
|
|
137 |
num_layers: int = 3,
|
138 |
):
|
139 |
-
super(
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
# x shape: [b, t, f]
|
163 |
-
|
164 |
-
# x shape: [b, 1, t, f]
|
165 |
|
166 |
new_cache_list = list()
|
167 |
for idx, layer in enumerate(self.layers):
|
168 |
-
cache =
|
169 |
-
|
|
|
|
|
|
|
|
|
170 |
new_cache_list.append(new_cache)
|
171 |
|
172 |
-
|
173 |
-
x = x.permute(0, 2, 1, 3)
|
174 |
-
# x shape: [b, t, c, f]
|
175 |
-
b, t, c, f = x.shape
|
176 |
-
x = torch.reshape(x, shape=(b, t, c*f))
|
177 |
-
# x shape: [b, t, c*f]
|
178 |
return x, new_cache_list
|
179 |
|
180 |
|
@@ -185,15 +122,14 @@ class SileroVadModel(nn.Module):
|
|
185 |
win_size: int,
|
186 |
hop_size: int,
|
187 |
win_type: int,
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
n_frame: int,
|
194 |
min_local_snr_db: float,
|
195 |
max_local_snr_db: float,
|
196 |
-
|
197 |
):
|
198 |
super(SileroVadModel, self).__init__()
|
199 |
self.sample_rate = sample_rate
|
@@ -202,9 +138,12 @@ class SileroVadModel(nn.Module):
|
|
202 |
self.hop_size = hop_size
|
203 |
self.win_type = win_type
|
204 |
|
205 |
-
self.
|
206 |
-
self.
|
207 |
-
self.
|
|
|
|
|
|
|
208 |
|
209 |
self.n_frame = n_frame
|
210 |
self.min_local_snr_db = min_local_snr_db
|
@@ -231,24 +170,33 @@ class SileroVadModel(nn.Module):
|
|
231 |
|
232 |
self.linear = nn.Linear(
|
233 |
in_features=(self.nfft // 2 + 1),
|
234 |
-
out_features=self.
|
235 |
)
|
236 |
|
237 |
-
|
238 |
-
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
)
|
241 |
|
242 |
self.lstm = nn.LSTM(
|
243 |
-
input_size=self.
|
244 |
-
hidden_size=self.
|
|
|
245 |
bidirectional=False,
|
246 |
batch_first=True
|
247 |
)
|
248 |
|
249 |
# vad
|
250 |
self.vad_fc = nn.Sequential(
|
251 |
-
nn.Linear(self.
|
252 |
nn.ReLU(),
|
253 |
nn.Linear(32, 1),
|
254 |
)
|
@@ -256,7 +204,7 @@ class SileroVadModel(nn.Module):
|
|
256 |
|
257 |
# lsnr
|
258 |
self.lsnr_fc = nn.Sequential(
|
259 |
-
nn.Linear(self.
|
260 |
nn.Sigmoid()
|
261 |
)
|
262 |
self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
|
@@ -289,10 +237,14 @@ class SileroVadModel(nn.Module):
|
|
289 |
x = self.linear.forward(x)
|
290 |
# x shape: [b, t, f']
|
291 |
|
292 |
-
|
|
|
|
|
|
|
|
|
293 |
# x shape: [b, t, f']
|
294 |
|
295 |
-
x,
|
296 |
|
297 |
logits = self.vad_fc.forward(x)
|
298 |
# logits shape: [b, t, 1]
|
@@ -345,9 +297,11 @@ class SileroVadPretrainedModel(SileroVadModel):
|
|
345 |
win_size=config.win_size,
|
346 |
hop_size=config.hop_size,
|
347 |
win_type=config.win_type,
|
348 |
-
|
349 |
-
|
350 |
-
|
|
|
|
|
351 |
n_frame=config.n_frame,
|
352 |
min_local_snr_db=config.min_local_snr_db,
|
353 |
max_local_snr_db=config.max_local_snr_db,
|
@@ -392,7 +346,61 @@ class SileroVadPretrainedModel(SileroVadModel):
|
|
392 |
return save_directory
|
393 |
|
394 |
|
395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
config = SileroVadConfig()
|
397 |
model = SileroVadPretrainedModel(config=config)
|
398 |
|
@@ -406,5 +414,70 @@ def main():
|
|
406 |
return
|
407 |
|
408 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
if __name__ == "__main__":
|
410 |
-
|
|
|
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
https://github.com/snakers4/silero-vad/wiki/Quality-Metrics
|
|
|
5 |
https://pytorch.org/hub/snakers4_silero-vad_vad/
|
6 |
https://github.com/snakers4/silero-vad
|
|
|
7 |
https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
|
8 |
"""
|
|
|
9 |
import os
|
10 |
+
from typing import Optional, Union
|
11 |
|
12 |
import torch
|
13 |
import torch.nn as nn
|
|
|
22 |
MODEL_FILE = "model.pt"
|
23 |
|
24 |
|
25 |
+
class EncoderBlock(nn.Module):
    """Unpadded 1-D convolution over time, followed by ReLU and batch norm.

    Input and output are laid out as [b, t, f]. Because the convolution uses
    ``padding="valid"``, the time axis shrinks by ``kernel_size - 1`` frames.
    """

    def __init__(self,
                 in_channels: int = 64,
                 out_channels: int = 128,
                 kernel_size: int = 3,
                 ):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding="valid",
        )
        self.activation = nn.ReLU()
        self.norm = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor):
        # [b, t, f] -> [b, f, t]: Conv1d and BatchNorm1d expect channels on dim 1.
        channels_first = x.transpose(1, 2)

        hidden = self.conv1d.forward(channels_first)
        hidden = self.norm(self.activation(hidden))

        # [b, f, t] -> [b, t, f]
        return hidden.transpose(1, 2)
|
54 |
|
55 |
|
56 |
+
class Encoder(nn.Module):
    """Stack of ``EncoderBlock`` layers applied one after another.

    The first block maps ``in_channels`` to ``hidden_channels``. The last
    block (when there is more than one) maps ``hidden_channels`` to
    ``out_channels``. Every block in between keeps ``hidden_channels``.
    """

    def __init__(self,
                 in_channels: int = 64,
                 hidden_channels: int = 128,
                 out_channels: int = 64,
                 kernel_size: int = 3,
                 num_layers: int = 3,
                 ):
        super().__init__()

        blocks = []
        for index in range(num_layers):
            is_first = index == 0
            # A single-layer encoder counts as "first", so it outputs hidden_channels.
            is_last = (not is_first) and index == (num_layers - 1)
            blocks.append(
                EncoderBlock(
                    in_channels=in_channels if is_first else hidden_channels,
                    out_channels=out_channels if is_last else hidden_channels,
                    kernel_size=kernel_size,
                )
            )
        self.layers = nn.ModuleList(modules=blocks)

    def forward(self, x: torch.Tensor):
        # x shape: [b, t, f]
        hidden = x
        for block in self.layers:
            hidden = block.forward(hidden)
        return hidden
|
93 |
+
|
94 |
+
|
95 |
+
class EncoderExport(nn.Module):
    """Streaming (cache-based) view of an ``Encoder`` used for ONNX export.

    Instead of padding the input, each layer receives the trailing frames of
    its previous input as left context and returns the new trailing frames.
    The layers are shared with the wrapped encoder, not copied.
    """

    def __init__(self, model: "Encoder"):
        super(EncoderExport, self).__init__()
        self.layers = model.layers

    def forward(self, x: torch.Tensor, cache_list: torch.Tensor):
        """
        :param x: shape: [b, t, f]
        :param cache_list: shape: [num_layers, b, 2p, f]; left context per layer.
        :return: tuple ``(x, new_cache_list)``; ``new_cache_list`` has the
            same shape as ``cache_list``.
        """
        new_cache_list = list()
        for idx, layer in enumerate(self.layers):
            cache = cache_list[idx]
            x_pad = torch.concat(tensors=[cache, x], dim=1)
            x = layer.forward(x_pad)

            _, twop, _ = cache.shape
            if twop > 0:
                new_cache = x_pad[:, -twop:, :]
            else:
                # `-0:` would select the whole sequence; an empty cache must stay empty.
                new_cache = x_pad[:, :0, :]
            new_cache_list.append(new_cache)

        new_cache_list = torch.stack(tensors=new_cache_list, dim=0)
        return x, new_cache_list
|
116 |
|
117 |
|
|
|
122 |
win_size: int,
|
123 |
hop_size: int,
|
124 |
win_type: int,
|
125 |
+
encoder_in_channels: int,
|
126 |
+
encoder_kernel_size: int,
|
127 |
+
encoder_num_layers: int,
|
128 |
+
decoder_hidden_size: int,
|
129 |
+
decoder_num_layers: int,
|
130 |
n_frame: int,
|
131 |
min_local_snr_db: float,
|
132 |
max_local_snr_db: float,
|
|
|
133 |
):
|
134 |
super(SileroVadModel, self).__init__()
|
135 |
self.sample_rate = sample_rate
|
|
|
138 |
self.hop_size = hop_size
|
139 |
self.win_type = win_type
|
140 |
|
141 |
+
self.encoder_in_channels = encoder_in_channels
|
142 |
+
self.encoder_kernel_size = encoder_kernel_size
|
143 |
+
self.encoder_num_layers = encoder_num_layers
|
144 |
+
|
145 |
+
self.decoder_hidden_size = decoder_hidden_size
|
146 |
+
self.decoder_num_layers = decoder_num_layers
|
147 |
|
148 |
self.n_frame = n_frame
|
149 |
self.min_local_snr_db = min_local_snr_db
|
|
|
170 |
|
171 |
self.linear = nn.Linear(
|
172 |
in_features=(self.nfft // 2 + 1),
|
173 |
+
out_features=self.encoder_in_channels,
|
174 |
)
|
175 |
|
176 |
+
# for last 2 dim, pad (left, right, top, bottom).
|
177 |
+
# (b, t, f) -> (b, t+p, f)
|
178 |
+
self.p = self.encoder_num_layers * (self.encoder_kernel_size - 1) // 2
|
179 |
+
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.p, self.p), value=0.0)
|
180 |
+
|
181 |
+
self.encoder = Encoder(
|
182 |
+
in_channels=self.encoder_in_channels,
|
183 |
+
hidden_channels=self.decoder_hidden_size,
|
184 |
+
out_channels=self.decoder_hidden_size,
|
185 |
+
kernel_size=self.encoder_kernel_size,
|
186 |
+
num_layers=self.encoder_num_layers,
|
187 |
)
|
188 |
|
189 |
self.lstm = nn.LSTM(
|
190 |
+
input_size=self.decoder_hidden_size,
|
191 |
+
hidden_size=self.decoder_hidden_size,
|
192 |
+
num_layers=self.decoder_num_layers,
|
193 |
bidirectional=False,
|
194 |
batch_first=True
|
195 |
)
|
196 |
|
197 |
# vad
|
198 |
self.vad_fc = nn.Sequential(
|
199 |
+
nn.Linear(self.decoder_hidden_size, 32),
|
200 |
nn.ReLU(),
|
201 |
nn.Linear(32, 1),
|
202 |
)
|
|
|
204 |
|
205 |
# lsnr
|
206 |
self.lsnr_fc = nn.Sequential(
|
207 |
+
nn.Linear(self.decoder_hidden_size, 1),
|
208 |
nn.Sigmoid()
|
209 |
)
|
210 |
self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
|
|
|
237 |
x = self.linear.forward(x)
|
238 |
# x shape: [b, t, f']
|
239 |
|
240 |
+
# pad
|
241 |
+
x = self.tpad.forward(x)
|
242 |
+
# x shape: [b, t+2p, f']
|
243 |
+
|
244 |
+
x = self.encoder.forward(x)
|
245 |
# x shape: [b, t, f']
|
246 |
|
247 |
+
x, (h, c) = self.lstm.forward(x)
|
248 |
|
249 |
logits = self.vad_fc.forward(x)
|
250 |
# logits shape: [b, t, 1]
|
|
|
297 |
win_size=config.win_size,
|
298 |
hop_size=config.hop_size,
|
299 |
win_type=config.win_type,
|
300 |
+
encoder_in_channels=config.encoder_in_channels,
|
301 |
+
encoder_kernel_size=config.encoder_kernel_size,
|
302 |
+
encoder_num_layers=config.encoder_num_layers,
|
303 |
+
decoder_hidden_size=config.decoder_hidden_size,
|
304 |
+
decoder_num_layers=config.decoder_num_layers,
|
305 |
n_frame=config.n_frame,
|
306 |
min_local_snr_db=config.min_local_snr_db,
|
307 |
max_local_snr_db=config.max_local_snr_db,
|
|
|
346 |
return save_directory
|
347 |
|
348 |
|
349 |
+
class SileroVadModelExport(nn.Module):
    """Export wrapper around ``SileroVadModel`` with explicit state tensors.

    It reuses the sub-modules of the wrapped model. The encoder is wrapped in
    ``EncoderExport``, so the encoder caches and the LSTM state are passed in
    and returned as tensors instead of padding the input.
    """

    def __init__(self, model: SileroVadModel):
        super().__init__()
        # feature extraction
        self.stft = model.stft
        self.linear = model.linear

        # sequence model
        self.encoder = EncoderExport(model.encoder)
        self.lstm = model.lstm

        # vad head
        self.vad_fc = model.vad_fc
        self.sigmoid = model.sigmoid

        # lsnr head
        self.lsnr_fc = model.lsnr_fc
        self.lsnr_scale = model.lsnr_scale
        self.lsnr_offset = model.lsnr_offset

    def forward(self,
                signal: torch.Tensor,
                encoder_cache_list: torch.Tensor,
                lstm_hidden_state: torch.Tensor,
                ):
        """
        :param signal: shape: [b, 1, num_samples]
        :param encoder_cache_list: shape: [num_layers, b, 2p, f]
        :param lstm_hidden_state: shape: [2, num_layers, b, h]; index 0 is h, index 1 is c.
        :return: tuple ``(logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state)``;
            logits, probs and lsnr have shape [b, t, 1], and the two state
            tensors have the same layout as the corresponding inputs.
        """
        spectrum = self.stft.forward(signal)
        # spectrum shape: [b, f, t]

        features = self.linear.forward(torch.transpose(spectrum, dim0=1, dim1=2))
        # features shape: [b, t, f']

        # No time padding here: the encoder caches supply the left context.
        features, new_encoder_cache_list = self.encoder.forward(features, cache_list=encoder_cache_list)
        # features shape: [b, t, f']

        previous_state = (lstm_hidden_state[0], lstm_hidden_state[1])
        features, next_state = self.lstm.forward(features, previous_state)
        new_lstm_hidden_state = torch.stack(tensors=next_state, dim=0)
        # new_lstm_hidden_state shape: [2, num_layers, b, h]

        logits = self.vad_fc.forward(features)
        probs = self.sigmoid.forward(logits)
        # logits / probs shape: [b, t, 1]

        lsnr = self.lsnr_fc.forward(features) * self.lsnr_scale + self.lsnr_offset
        # lsnr shape: [b, t, 1]

        return logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state
|
401 |
+
|
402 |
+
|
403 |
+
def main1():
|
404 |
config = SileroVadConfig()
|
405 |
model = SileroVadPretrainedModel(config=config)
|
406 |
|
|
|
414 |
return
|
415 |
|
416 |
|
417 |
+
def main2():
    """Smoke test of the ONNX export.

    Builds a model from the default config, runs the export wrapper once in
    PyTorch, exports it to ``silero_vad.onnx`` in the current working
    directory, then runs the exported graph with ONNX Runtime and prints the
    resulting shapes.
    """
    import onnx
    import onnxruntime as ort

    config = SileroVadConfig()
    model = SileroVadPretrainedModel(config=config)
    model_export = SileroVadModelExport(model)

    onnx_file = "silero_vad.onnx"
    input_names = ["inputs", "encoder_cache_list", "lstm_hidden_state"]
    output_names = [
        "logits", "probs", "lsnr", "new_encoder_cache_list", "new_lstm_hidden_state"
    ]

    b = 1
    # Every encoder layer keeps (kernel_size - 1) frames of left context.
    p = (config.encoder_kernel_size - 1) // 2

    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
    # encoder_cache_list shape: [num_layers, b, 2p, f]
    encoder_cache_list = torch.zeros(
        size=(config.encoder_num_layers, b, 2 * p, config.encoder_in_channels),
        dtype=torch.float32,
    )
    # lstm_hidden_state shape: [2, num_layers, b, h]
    lstm_hidden_state = torch.zeros(
        size=(2, config.decoder_num_layers, b, config.decoder_hidden_size),
        dtype=torch.float32,
    )

    # 1) reference run in PyTorch
    logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state = model_export.forward(inputs, encoder_cache_list, lstm_hidden_state)
    print(f"logits.shape: {logits.shape}")
    print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
    print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")

    # 2) export; the batch axis of every tensor and the sample axis of the input are dynamic
    dynamic_axes = {
        "inputs": {0: "batch_size", 2: "num_samples"},
        "encoder_cache_list": {1: "batch_size"},
        "lstm_hidden_state": {2: "batch_size"},
        "logits": {0: "batch_size"},
        "probs": {0: "batch_size"},
        "lsnr": {0: "batch_size"},
        "new_encoder_cache_list": {1: "batch_size"},
        "new_lstm_hidden_state": {2: "batch_size"},
    }
    torch.onnx.export(model_export,
                      args=(inputs, encoder_cache_list, lstm_hidden_state),
                      f=onnx_file,
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes)

    # 3) run the exported graph with the same inputs
    ort_session = ort.InferenceSession(onnx_file)
    input_feed = {
        "inputs": inputs.numpy(),
        "encoder_cache_list": encoder_cache_list.numpy(),
        "lstm_hidden_state": lstm_hidden_state.numpy(),
    }
    logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state = ort_session.run(output_names, input_feed)
    print(f"probs.shape: {probs.shape}")
    print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
    return
|
479 |
+
|
480 |
+
|
481 |
if __name__ == "__main__":
|
482 |
+
main2()
|
483 |
+
|
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml
CHANGED
@@ -8,11 +8,11 @@ hop_size: 80
|
|
8 |
win_type: hann
|
9 |
|
10 |
# model
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
|
17 |
# lsnr
|
18 |
n_frame: 3
|
|
|
8 |
win_type: hann
|
9 |
|
10 |
# model
|
11 |
+
encoder_in_channels: 64
|
12 |
+
encoder_kernel_size: 3
|
13 |
+
encoder_num_layers: 3
|
14 |
+
|
15 |
+
decoder_hidden_size: 64
|
16 |
|
17 |
# lsnr
|
18 |
n_frame: 3
|