HoneyTian committed
Commit: 83fc52b · Parent(s): ff7995b
examples/silero_vad_by_webrtcvad/yaml/config.yaml CHANGED
@@ -8,11 +8,11 @@ hop_size: 80
 win_type: hann
 
 # model
-conv_channels: 32
-hidden_size: 80
-kernel_size:
-- 3
-- 3
+encoder_in_channels: 64
+encoder_kernel_size: 3
+encoder_num_layers: 3
+
+decoder_hidden_size: 64
 
 # lsnr
 n_frame: 3
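The `model` block now describes a Conv1d encoder (`encoder_*`) and an LSTM decoder (`decoder_*`) instead of the old Conv2d keys. Note that `decoder_num_layers` is not set here, so `SileroVadConfig` falls back to its default of 2. As a quick illustration of the frame-context arithmetic these keys imply (the helper below is an illustrative sketch, not part of the repo):

```python
# Illustrative helper (not in the repo): context bookkeeping for a stack of
# "valid"-padded Conv1d layers with the values configured above.
def context_frames(encoder_kernel_size: int = 3, encoder_num_layers: int = 3) -> dict:
    p = (encoder_kernel_size - 1) // 2  # per-layer half-context
    return {
        # frames each streaming cache carries between chunks (per layer)
        "per_layer_cache": 2 * p,
        # frames the offline model zero-pads on each side of the time axis
        "total_pad_each_side": encoder_num_layers * (encoder_kernel_size - 1) // 2,
    }

print(context_frames())  # {'per_layer_cache': 2, 'total_pad_each_side': 3}
```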
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py CHANGED
@@ -13,9 +13,12 @@ class SileroVadConfig(PretrainedConfig):
                  hop_size: int = 80,
                  win_type: str = "hann",
 
-                 conv_channels: int = 32,
-                 hidden_size: int = 80,
-                 kernel_size: Tuple[int, int] = (3, 3),
+                 encoder_in_channels: int = 64,
+                 encoder_kernel_size: int = 3,
+                 encoder_num_layers: int = 3,
+
+                 decoder_hidden_size: int = 64,
+                 decoder_num_layers: int = 2,
 
                  n_frame: int = 3,
                  min_local_snr_db: float = -15,
@@ -48,9 +51,12 @@ class SileroVadConfig(PretrainedConfig):
         self.win_type = win_type
 
         # encoder
-        self.conv_channels = conv_channels
-        self.hidden_size = hidden_size
-        self.kernel_size = kernel_size
+        self.encoder_in_channels = encoder_in_channels
+        self.encoder_kernel_size = encoder_kernel_size
+        self.encoder_num_layers = encoder_num_layers
+
+        self.decoder_hidden_size = decoder_hidden_size
+        self.decoder_num_layers = decoder_num_layers
 
         # lsnr
         self.n_frame = n_frame
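A minimal sketch of the renamed fields in use, with the defaults from this diff (import path as in this repo):

```python
from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig

# The encoder is now described by (in_channels, kernel_size, num_layers)
# and the LSTM decoder by (hidden_size, num_layers).
config = SileroVadConfig()
assert config.encoder_in_channels == 64
assert config.encoder_kernel_size == 3
assert config.encoder_num_layers == 3
assert config.decoder_hidden_size == 64
assert config.decoder_num_layers == 2
```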
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad_onnx.py ADDED
@@ -0,0 +1,185 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import logging
+from pathlib import Path
+import shutil
+import tempfile
+from typing import List
+import zipfile
+
+from scipy.io import wavfile
+import numpy as np
+import torch
+import onnxruntime as ort
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
+from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
+
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceSileroVadOnnx(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, ort_session = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.ort_session = ort_session
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        remove_after_load = False
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "cc_vad"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+            remove_after_load = True
+
+        config = SileroVadConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        ort_session = ort.InferenceSession(
+            path_or_bytes=(model_path / "model.onnx").as_posix()
+        )
+
+        # only remove the temporary extraction directory, never a user-supplied one.
+        if remove_after_load:
+            shutil.rmtree(model_path)
+        return config, ort_session
+
+    def infer(self, signal: np.ndarray) -> dict:
+        # signal shape: [num_samples,], value between -1 and 1.
+
+        inputs = torch.tensor(signal, dtype=torch.float32)
+        inputs = torch.unsqueeze(inputs, dim=0)
+        inputs = torch.unsqueeze(inputs, dim=0)
+        # inputs shape: [1, 1, num_samples]
+
+        b = 1
+
+        # param
+        encoder_num_layers = self.config.encoder_num_layers
+        p = (self.config.encoder_kernel_size - 1) // 2
+        encoder_in_channels = self.config.encoder_in_channels
+
+        decoder_num_layers = self.config.decoder_num_layers
+        decoder_hidden_size = self.config.decoder_hidden_size
+
+        # cache 1: encoder conv history; zeros mean a fresh stream.
+        encoder_cache_list = [
+            torch.zeros(size=(b, 2 * p, encoder_in_channels), dtype=torch.float32)
+        ] * encoder_num_layers
+        encoder_cache_list = torch.stack(encoder_cache_list, dim=0)
+
+        # cache 2: lstm (h, c) state.
+        lstm_hidden_state = [
+            torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
+        ] * 2
+        lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
+
+        input_feed = {
+            "inputs": inputs.numpy(),
+            "encoder_cache_list": encoder_cache_list.numpy(),
+            "lstm_hidden_state": lstm_hidden_state.numpy(),
+        }
+        output_names = [
+            "logits", "probs", "lsnr", "new_encoder_cache_list", "new_lstm_hidden_state"
+        ]
+        logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state = self.ort_session.run(output_names, input_feed)
+        # probs shape: [b, t, 1]
+        probs = np.squeeze(probs, axis=-1)
+        # probs shape: [b, t]
+        probs = probs[0]
+
+        # lsnr shape: [b, t, 1]
+        lsnr = np.squeeze(lsnr, axis=-1)
+        # lsnr shape: [b, t]
+        lsnr = lsnr[0]
+
+        result = {
+            "probs": probs,
+            "lsnr": lsnr,
+        }
+        return result
+
+    def post_process(self, probs: List[float]):
+        return
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wav_file",
+        # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+        # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+
+SAMPLE_RATE = 8000
+
+
+def main():
+    args = get_args()
+
+    sample_rate, signal = wavfile.read(args.wav_file)
+    if SAMPLE_RATE != sample_rate:
+        raise AssertionError
+    signal = signal / (1 << 15)
+
+    infer = InferenceSileroVadOnnx(
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
+    )
+    frame_step = infer.config.hop_size
+
+    result = infer.infer(signal)
+    speech_probs = result["probs"].tolist()
+
+    speech_probs = process_speech_probs(
+        signal=signal,
+        speech_probs=speech_probs,
+        frame_step=frame_step,
+    )
+
+    # plot
+    make_visualization(signal, speech_probs, SAMPLE_RATE)
+    return
+
+
+if __name__ == "__main__":
+    main()
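`infer` above runs the whole signal in one call and zero-initializes both caches each time. Because the ONNX graph also returns `new_encoder_cache_list` and `new_lstm_hidden_state`, chunked streaming only needs to carry those across calls. A minimal sketch (the chunk size and helper name are illustrative; note the in-graph STFT keeps no sample-level history, so frames at chunk boundaries are only approximate):

```python
import numpy as np

def stream_probs(infer: "InferenceSileroVadOnnx", signal: np.ndarray, chunk_size: int = 800) -> np.ndarray:
    # zero caches == fresh stream; shapes mirror InferenceSileroVadOnnx.infer.
    cfg = infer.config
    p = (cfg.encoder_kernel_size - 1) // 2
    encoder_cache = np.zeros(
        (cfg.encoder_num_layers, 1, 2 * p, cfg.encoder_in_channels), dtype=np.float32)
    lstm_state = np.zeros(
        (2, cfg.decoder_num_layers, 1, cfg.decoder_hidden_size), dtype=np.float32)

    chunk_probs_list = []
    for begin in range(0, len(signal), chunk_size):
        chunk = signal[begin:begin + chunk_size].astype(np.float32)
        # feed the previous call's caches back in; request only what we need.
        probs, encoder_cache, lstm_state = infer.ort_session.run(
            ["probs", "new_encoder_cache_list", "new_lstm_hidden_state"],
            {
                "inputs": chunk[None, None, :],
                "encoder_cache_list": encoder_cache,
                "lstm_hidden_state": lstm_state,
            },
        )
        chunk_probs_list.append(probs[0, :, 0])
    return np.concatenate(chunk_probs_list)
```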
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py CHANGED
@@ -2,15 +2,12 @@
 # -*- coding: utf-8 -*-
 """
 https://github.com/snakers4/silero-vad/wiki/Quality-Metrics
-
 https://pytorch.org/hub/snakers4_silero-vad_vad/
 https://github.com/snakers4/silero-vad
-
 https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
 """
-import math
 import os
-from typing import List, Optional, Union, Iterable, Tuple
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -25,156 +22,96 @@ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 MODEL_FILE = "model.pt"
 
 
-norm_layer_dict = {
-    "batch_norm_2d": torch.nn.BatchNorm2d
-}
-
-
-activation_layer_dict = {
-    "relu": torch.nn.ReLU,
-    "identity": torch.nn.Identity,
-    "sigmoid": torch.nn.Sigmoid,
-}
-
-
-class CausalConv2d(nn.Module):
+class EncoderBlock(nn.Module):
     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Iterable[int]],
-                 fstride: int = 1,
-                 dilation: int = 1,
-                 pad_f_dim: bool = True,
-                 bias: bool = True,
-                 separable: bool = False,
-                 norm_layer: str = "batch_norm_2d",
-                 activation_layer: str = "relu",
+                 in_channels: int = 64,
+                 out_channels: int = 128,
+                 kernel_size: int = 3,
                  ):
-        super(CausalConv2d, self).__init__()
-        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-
-        if pad_f_dim:
-            fpad = kernel_size[1] // 2 + dilation - 1
-        else:
-            fpad = 0
-
-        # for last 2 dim, pad (left, right, top, bottom).
-        self.lookback = kernel_size[0] - 1
-        if self.lookback > 0:
-            self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
-        else:
-            self.tpad = nn.Identity()
-
-        groups = math.gcd(in_channels, out_channels) if separable else 1
-        if groups == 1:
-            separable = False
-        if max(kernel_size) == 1:
-            separable = False
-
-        self.conv = nn.Conv2d(
-            in_channels,
-            out_channels,
+        super(EncoderBlock, self).__init__()
+        self.conv1d = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
             kernel_size=kernel_size,
-            padding=(0, fpad),
-            stride=(1, fstride),  # stride over time is always 1
-            dilation=(1, dilation),  # dilation over time is always 1
-            groups=groups,
-            bias=bias,
+            padding="valid",
         )
+        self.activation = nn.ReLU()
+        self.norm = nn.BatchNorm1d(out_channels)
 
-        if separable:
-            self.convp = nn.Conv2d(
-                out_channels,
-                out_channels,
-                kernel_size=1,
-                bias=False,
-            )
-        else:
-            self.convp = nn.Identity()
-
-        if norm_layer is not None:
-            norm_layer = norm_layer_dict[norm_layer]
-            self.norm = norm_layer(out_channels)
-        else:
-            self.norm = nn.Identity()
-
-        if activation_layer is not None:
-            activation_layer = activation_layer_dict[activation_layer]
-            self.activation = activation_layer()
-        else:
-            self.activation = nn.Identity()
-
-    def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
-        """
-        :param inputs: shape: [b, c, t, f]
-        :param cache: shape: [b, c, lookback, f];
-        :return:
-        """
-        x = inputs
-
-        if cache is None:
-            x = self.tpad(x)
-        else:
-            x = torch.concat(tensors=[cache, x], dim=2)
-
-        new_cache = None
-        if self.lookback > 0:
-            new_cache = x[:, :, -self.lookback:, :]
-
-        x = self.conv(x)
+    def forward(self, x: torch.Tensor):
+        # x shape: [b, t, f]
+        x = torch.transpose(x, dim0=1, dim1=2)
+        # x shape: [b, f, t]
 
-        x = self.convp(x)
-        x = self.norm(x)
+        x = self.conv1d.forward(x)
         x = self.activation(x)
+        x = self.norm(x)
+
+        x = torch.transpose(x, dim0=1, dim1=2)
+        # x shape: [b, t, f]
 
-        return x, new_cache
+        return x
 
 
-class CausalEncoder(nn.Module):
+class Encoder(nn.Module):
     def __init__(self,
-                 conv_channels: int,
-                 kernel_size: Tuple[int, int] = (3, 3),
+                 in_channels: int = 64,
+                 hidden_channels: int = 128,
+                 out_channels: int = 64,
+                 kernel_size: int = 3,
                  num_layers: int = 3,
                  ):
-        super(CausalEncoder, self).__init__()
-        self.layers: List[CausalConv2d] = nn.ModuleList(modules=[
-            CausalConv2d(
-                in_channels=1,
-                out_channels=conv_channels,
-                kernel_size=kernel_size,
-                bias=False,
-                separable=True,
-                fstride=1,
-            )
-            if i == 0 else
-            CausalConv2d(
-                in_channels=conv_channels,
-                out_channels=conv_channels,
-                kernel_size=kernel_size,
-                bias=False,
-                separable=True,
-                fstride=1,
-            )
-            for i in range(num_layers)
-        ])
-
-    def forward(self, x: torch.Tensor, cache_list: List[torch.Tensor] = None):
+        super(Encoder, self).__init__()
+
+        self.layers = nn.ModuleList(modules=[])
+        for i in range(num_layers):
+            if i == 0:
+                encoder_block = EncoderBlock(
+                    in_channels=in_channels,
+                    out_channels=hidden_channels,
+                    kernel_size=kernel_size,
+                )
+            elif i == (num_layers - 1):
+                encoder_block = EncoderBlock(
+                    in_channels=hidden_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size,
+                )
+            else:
+                encoder_block = EncoderBlock(
+                    in_channels=hidden_channels,
+                    out_channels=hidden_channels,
+                    kernel_size=kernel_size,
+                )
+            self.layers.append(encoder_block)
+
+    def forward(self, x: torch.Tensor):
+        # x shape: [b, t, f]
+        for layer in self.layers:
+            x = layer.forward(x)
+        return x
+
+
+class EncoderExport(nn.Module):
+    def __init__(self, model: Encoder):
+        super(EncoderExport, self).__init__()
+        self.layers = model.layers
+
+    def forward(self, x: torch.Tensor, cache_list: torch.Tensor):
         # x shape: [b, t, f]
-        x = torch.unsqueeze(x, dim=1)
-        # x shape: [b, 1, t, f]
+        # cache_list shape: [num_layers, b, 2p, f]
 
         new_cache_list = list()
        for idx, layer in enumerate(self.layers):
-            cache = None if cache_list is None else cache_list[idx]
-            x, new_cache = layer.forward(x, cache=cache)
+            cache = cache_list[idx]
+            x_pad = torch.concat(tensors=[cache, x], dim=1)
+            x = layer.forward(x_pad)
+
+            _, twop, _ = cache.shape
+            new_cache = x_pad[:, -twop:, :]
             new_cache_list.append(new_cache)
 
-        # x shape: [b, c, t, f]
-        x = x.permute(0, 2, 1, 3)
-        # x shape: [b, t, c, f]
-        b, t, c, f = x.shape
-        x = torch.reshape(x, shape=(b, t, c*f))
-        # x shape: [b, t, c*f]
+        new_cache_list = torch.stack(tensors=new_cache_list, dim=0)
        return x, new_cache_list
 
 
@@ -185,15 +122,14 @@ class SileroVadModel(nn.Module):
                  win_size: int,
                  hop_size: int,
                  win_type: int,
-
-                 conv_channels: int,
-                 hidden_size: int,
-                 kernel_size: Tuple[int, int],
-
+                 encoder_in_channels: int,
+                 encoder_kernel_size: int,
+                 encoder_num_layers: int,
+                 decoder_hidden_size: int,
+                 decoder_num_layers: int,
                  n_frame: int,
                  min_local_snr_db: float,
                  max_local_snr_db: float,
-
                  ):
         super(SileroVadModel, self).__init__()
         self.sample_rate = sample_rate
@@ -202,9 +138,12 @@ class SileroVadModel(nn.Module):
         self.hop_size = hop_size
         self.win_type = win_type
 
-        self.conv_channels = conv_channels
-        self.hidden_size = hidden_size
-        self.kernel_size = kernel_size
+        self.encoder_in_channels = encoder_in_channels
+        self.encoder_kernel_size = encoder_kernel_size
+        self.encoder_num_layers = encoder_num_layers
+
+        self.decoder_hidden_size = decoder_hidden_size
+        self.decoder_num_layers = decoder_num_layers
 
         self.n_frame = n_frame
         self.min_local_snr_db = min_local_snr_db
@@ -231,24 +170,33 @@ class SileroVadModel(nn.Module):
 
         self.linear = nn.Linear(
             in_features=(self.nfft // 2 + 1),
-            out_features=self.hidden_size,
+            out_features=self.encoder_in_channels,
         )
 
-        self.encoder = CausalEncoder(
-            conv_channels=conv_channels,
-            kernel_size=(3, 3),
+        # for last 2 dim, pad (left, right, top, bottom).
+        # (b, t, f) -> (b, t+2p, f)
+        self.p = self.encoder_num_layers * (self.encoder_kernel_size - 1) // 2
+        self.tpad = nn.ConstantPad2d(padding=(0, 0, self.p, self.p), value=0.0)
+
+        self.encoder = Encoder(
+            in_channels=self.encoder_in_channels,
+            hidden_channels=self.decoder_hidden_size,
+            out_channels=self.decoder_hidden_size,
+            kernel_size=self.encoder_kernel_size,
+            num_layers=self.encoder_num_layers,
         )
 
         self.lstm = nn.LSTM(
-            input_size=self.conv_channels * self.hidden_size,
-            hidden_size=self.hidden_size,
+            input_size=self.decoder_hidden_size,
+            hidden_size=self.decoder_hidden_size,
+            num_layers=self.decoder_num_layers,
             bidirectional=False,
             batch_first=True
         )
 
         # vad
         self.vad_fc = nn.Sequential(
-            nn.Linear(self.hidden_size, 32),
+            nn.Linear(self.decoder_hidden_size, 32),
             nn.ReLU(),
            nn.Linear(32, 1),
         )
@@ -256,7 +204,7 @@ class SileroVadModel(nn.Module):
 
         # lsnr
         self.lsnr_fc = nn.Sequential(
-            nn.Linear(self.hidden_size, 1),
+            nn.Linear(self.decoder_hidden_size, 1),
             nn.Sigmoid()
         )
         self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
@@ -289,10 +237,14 @@ class SileroVadModel(nn.Module):
         x = self.linear.forward(x)
         # x shape: [b, t, f']
 
-        x, _ = self.encoder.forward(x)
+        # pad
+        x = self.tpad.forward(x)
+        # x shape: [b, t+2p, f']
+
+        x = self.encoder.forward(x)
         # x shape: [b, t, f']
 
-        x, _ = self.lstm.forward(x)
+        x, (h, c) = self.lstm.forward(x)
 
         logits = self.vad_fc.forward(x)
         # logits shape: [b, t, 1]
@@ -345,9 +297,11 @@ class SileroVadPretrainedModel(SileroVadModel):
             win_size=config.win_size,
             hop_size=config.hop_size,
             win_type=config.win_type,
-            conv_channels=config.conv_channels,
-            hidden_size=config.hidden_size,
-            kernel_size=config.kernel_size,
+            encoder_in_channels=config.encoder_in_channels,
+            encoder_kernel_size=config.encoder_kernel_size,
+            encoder_num_layers=config.encoder_num_layers,
+            decoder_hidden_size=config.decoder_hidden_size,
+            decoder_num_layers=config.decoder_num_layers,
             n_frame=config.n_frame,
             min_local_snr_db=config.min_local_snr_db,
             max_local_snr_db=config.max_local_snr_db,
@@ -392,7 +346,61 @@ class SileroVadPretrainedModel(SileroVadModel):
         return save_directory
 
 
-def main():
+class SileroVadModelExport(nn.Module):
+    def __init__(self, model: SileroVadModel):
+        super(SileroVadModelExport, self).__init__()
+        self.stft = model.stft
+        self.linear = model.linear
+        self.encoder = EncoderExport(model.encoder)
+        self.lstm = model.lstm
+        self.vad_fc = model.vad_fc
+        self.sigmoid = model.sigmoid
+
+        self.lsnr_fc = model.lsnr_fc
+        self.lsnr_scale = model.lsnr_scale
+        self.lsnr_offset = model.lsnr_offset
+
+    def forward(self,
+                signal: torch.Tensor,
+                encoder_cache_list: torch.Tensor,
+                lstm_hidden_state: torch.Tensor,
+                ):
+        # encoder_cache_list shape: [num_layers, b, 2p, f]
+        # lstm_hidden_state shape: [2, num_layers, b, h]
+
+        # signal shape [b, 1, num_samples]
+        mags = self.stft.forward(signal)
+        # mags shape: [b, f, t]
+
+        x = torch.transpose(mags, dim0=1, dim1=2)
+        # x shape: [b, t, f]
+
+        x = self.linear.forward(x)
+        # x shape: [b, t, f']
+
+        # no tpad in the streaming path; the encoder caches supply the context.
+        # x = self.tpad.forward(x)
+        # x shape: [b, t, f']
+
+        x, new_encoder_cache_list = self.encoder.forward(x, cache_list=encoder_cache_list)
+        # x shape: [b, t, f']
+
+        x, new_lstm_hidden_state = self.lstm.forward(x, (lstm_hidden_state[0], lstm_hidden_state[1]))
+        new_lstm_hidden_state = torch.stack(tensors=new_lstm_hidden_state, dim=0)
+        # new_lstm_hidden_state shape: [2, num_layers, b, h]
+
+        logits = self.vad_fc.forward(x)
+        # logits shape: [b, t, 1]
+        probs = self.sigmoid.forward(logits)
+        # probs shape: [b, t, 1]
+
+        lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
+        # lsnr shape: [b, t, 1]
+
+        return logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state
+
+
+def main1():
     config = SileroVadConfig()
     model = SileroVadPretrainedModel(config=config)
 
@@ -406,5 +414,70 @@ def main():
     return
 
 
+def main2():
+    import onnx
+    import onnxruntime as ort
+
+    config = SileroVadConfig()
+    model = SileroVadPretrainedModel(config=config)
+    model_export = SileroVadModelExport(model)
+
+    encoder_num_layers = config.encoder_num_layers
+    p = (config.encoder_kernel_size - 1) // 2
+    encoder_in_channels = config.encoder_in_channels
+
+    decoder_num_layers = config.decoder_num_layers
+    decoder_hidden_size = config.decoder_hidden_size
+
+    b = 1
+    inputs = torch.randn(size=(b, 1, 16000), dtype=torch.float32)
+
+    encoder_cache_list = [
+        torch.zeros(size=(b, 2*p, encoder_in_channels), dtype=torch.float32)
+    ] * encoder_num_layers
+    encoder_cache_list = torch.stack(encoder_cache_list, dim=0)
+
+    lstm_hidden_state = [
+        torch.zeros(size=(decoder_num_layers, b, decoder_hidden_size), dtype=torch.float32)
+    ] * 2
+    lstm_hidden_state = torch.stack(lstm_hidden_state, dim=0)
+
+    logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state = model_export.forward(inputs, encoder_cache_list, lstm_hidden_state)
+    print(f"logits.shape: {logits.shape}")
+    print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
+    print(f"new_lstm_hidden_state.shape: {new_lstm_hidden_state.shape}")
+
+    torch.onnx.export(model_export,
+                      args=(inputs, encoder_cache_list, lstm_hidden_state),
+                      f="silero_vad.onnx",
+                      input_names=["inputs", "encoder_cache_list", "lstm_hidden_state"],
+                      output_names=["logits", "probs", "lsnr", "new_encoder_cache_list", "new_lstm_hidden_state"],
+                      dynamic_axes={
+                          "inputs": {0: "batch_size", 2: "num_samples"},
+                          "encoder_cache_list": {1: "batch_size"},
+                          "lstm_hidden_state": {2: "batch_size"},
+                          "logits": {0: "batch_size"},
+                          "probs": {0: "batch_size"},
+                          "lsnr": {0: "batch_size"},
+                          "new_encoder_cache_list": {1: "batch_size"},
+                          "new_lstm_hidden_state": {2: "batch_size"},
+                      })
+
+    ort_session = ort.InferenceSession("silero_vad.onnx")
+    input_feed = {
+        "inputs": inputs.numpy(),
+        "encoder_cache_list": encoder_cache_list.numpy(),
+        "lstm_hidden_state": lstm_hidden_state.numpy(),
+    }
+    output_names = [
+        "logits", "probs", "lsnr", "new_encoder_cache_list", "new_lstm_hidden_state"
+    ]
+    logits, probs, lsnr, new_encoder_cache_list, new_lstm_hidden_state = ort_session.run(output_names, input_feed)
+    print(f"probs.shape: {probs.shape}")
+    print(f"new_encoder_cache_list.shape: {new_encoder_cache_list.shape}")
+    return
+
+
 if __name__ == "__main__":
-    main()
+    main2()
+
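One way to sanity-check `EncoderExport` against the padded offline `Encoder` (a sketch, not in the commit): with BatchNorm in eval mode both paths apply the same framewise computation, but only the offline path sees `p_total` frames of future context, so the streaming output should match it shifted by `p_total` frames once the receptive field clears the zero padding and zero caches:

```python
import torch
import torch.nn.functional as F

from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import Encoder, EncoderExport

torch.manual_seed(0)

k, n_layers, f, t = 3, 3, 64, 50
p = (k - 1) // 2                    # per-layer half-context
p_total = n_layers * (k - 1) // 2   # SileroVadModel.p

# hidden == in channels here so one cache width fits every layer,
# mirroring how main2() builds the caches.
encoder = Encoder(in_channels=f, hidden_channels=f, out_channels=f,
                  kernel_size=k, num_layers=n_layers).eval()
encoder_export = EncoderExport(encoder).eval()

x = torch.randn(1, t, f)
with torch.no_grad():
    # offline path: symmetric zero padding over time, as SileroVadModel.forward does.
    y_offline = encoder.forward(F.pad(x, pad=(0, 0, p_total, p_total)))
    # streaming path: zero caches stand in for the left half of that padding.
    zero_cache = torch.zeros(n_layers, 1, 2 * p, f)
    y_stream, _ = encoder_export.forward(x, cache_list=zero_cache)

# frames whose receptive field contains no padding/caches must agree,
# with the streaming output lagging by the p_total frames of lookahead.
assert torch.allclose(y_stream[:, 2 * p_total:, :],
                      y_offline[:, p_total:t - p_total, :], atol=1e-5)
```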
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml CHANGED
@@ -8,11 +8,11 @@ hop_size: 80
 win_type: hann
 
 # model
-conv_channels: 32
-hidden_size: 80
-kernel_size:
-- 3
-- 3
+encoder_in_channels: 64
+encoder_kernel_size: 3
+encoder_num_layers: 3
+
+decoder_hidden_size: 64
 
 # lsnr
 n_frame: 3