HoneyTian committed
Commit 7f331d5 · 1 Parent(s): 5e1cd25
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -74,6 +74,9 @@ evaluation_audio_dir="${file_dir}/evaluation_audio"
 train_dataset="${file_dir}/train.jsonl"
 valid_dataset="${file_dir}/valid.jsonl"
 
+train_vad_dataset="${file_dir}/train-vad.jsonl"
+valid_vad_dataset="${file_dir}/valid-vad.jsonl"
+
 $verbose && echo "system_version: ${system_version}"
 $verbose && echo "file_folder_name: ${file_folder_name}"
 
@@ -89,7 +92,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   $verbose && echo "stage 1: prepare data"
   cd "${work_dir}" || exit 1
   python3 step_1_prepare_data.py \
-  --file_dir "${file_dir}" \
   --noise_dir "${noise_dir}" \
   --speech_dir "${speech_dir}" \
   --train_dataset "${train_dataset}" \
@@ -100,11 +102,23 @@ fi
 
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
+  $verbose && echo "stage 2: make vad segments"
   cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
+  python3 step_2_make_vad_segments.py \
   --train_dataset "${train_dataset}" \
   --valid_dataset "${valid_dataset}" \
+  --train_vad_dataset "${train_vad_dataset}" \
+  --valid_vad_dataset "${valid_vad_dataset}" \
+
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  $verbose && echo "stage 3: train model"
+  cd "${work_dir}" || exit 1
+  python3 step_3_train_model.py \
+  --train_dataset "${train_vad_dataset}" \
+  --valid_dataset "${valid_vad_dataset}" \
   --serialization_dir "${file_dir}" \
   --config_file "${config_file}" \
 
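A minimal sketch of the data flow the new stages introduce: stage 2 copies each row of train.jsonl / valid.jsonl into train-vad.jsonl / valid-vad.jsonl and attaches a "vad_segments" field, which stage 3 then trains on. The field names below appear in the diffs; treating each segment as a [start, end] pair in seconds is an assumption based on how the WebRTC VAD timestamps are used elsewhere in this commit.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json

# inspect one row of the intermediate file produced by stage 2
with open("train-vad.jsonl", "r", encoding="utf-8") as f:
    row = json.loads(next(f))

print(row["speech_filename"], row["speech_offset"], row["speech_duration"])
for start, end in row["vad_segments"]:
    print(f"speech from {start:.2f}s to {end:.2f}s")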
 
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -12,16 +12,11 @@ sys.path.append(os.path.join(pwd, "../../"))
 
 import librosa
 import numpy as np
-from scipy.io import wavfile
 from tqdm import tqdm
 
-from toolbox.webrtcvad.vad import WebRTCVad
-
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
     parser.add_argument(
         "--noise_dir",
         default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
@@ -36,7 +31,7 @@ def get_args():
     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
-    parser.add_argument("--duration", default=4.0, type=float)
+    parser.add_argument("--duration", default=6.0, type=float)
     parser.add_argument("--min_snr_db", default=-10, type=float)
     parser.add_argument("--max_snr_db", default=20, type=float)
 
@@ -44,12 +39,6 @@ def get_args():
 
     parser.add_argument("--max_count", default=-1, type=int)
 
-    # vad
-    parser.add_argument("--agg", default=3, type=int)
-    parser.add_argument("--frame_duration_ms", default=30, type=int)
-    parser.add_argument("--padding_duration_ms", default=30, type=int)
-    parser.add_argument("--silence_duration_threshold", default=0.3, type=float)
-
     args = parser.parse_args()
     return args
 
@@ -85,9 +74,6 @@ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate
 def main():
     args = get_args()
 
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
     noise_dir = Path(args.noise_dir)
     speech_dir = Path(args.speech_dir)
 
@@ -104,14 +90,6 @@ def main():
         max_epoch=1,
     )
 
-    w_vad = WebRTCVad(
-        agg=args.agg,
-        frame_duration_ms=args.frame_duration_ms,
-        padding_duration_ms=args.padding_duration_ms,
-        silence_duration_threshold=args.silence_duration_threshold,
-        sample_rate=args.target_sample_rate,
-    )
-
    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
@@ -130,14 +108,6 @@ def main():
             speech_offset = speech["offset"]
             speech_duration = speech["duration"]
 
-            # vad
-            _, signal = wavfile.read(speech_filename)
-            vad_segments = list()
-            segments = w_vad.vad(signal)
-            vad_segments += segments
-            segments = w_vad.last_vad_segments()
-            vad_segments += segments
-
             # row
             random1 = random.random()
             random2 = random.random()
@@ -157,8 +127,6 @@ def main():
 
                 "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
 
-                "vad_segments": vad_segments,
-
                 "random1": random1,
             }
             row = json.dumps(row, ensure_ascii=False)
@@ -173,9 +141,7 @@ def main():
 
             process_bar.update(n=1)
             process_bar.set_postfix({
-                # "duration_seconds": round(duration_seconds, 4),
                 "duration_hours": round(duration_hours, 4),
-
             })
 
    return
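For reference, this is the shape of a row that step_1_prepare_data.py still writes after the change, restricted to the fields visible in these hunks; the real rows also carry noise-side fields that are not shown here, and the concrete values below are illustrative only. Note that "vad_segments" is no longer produced at this stage.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json
import random

row = {
    "speech_filename": "path/to/speech.wav",   # illustrative value
    "speech_offset": 0.0,
    "speech_duration": 6.0,                    # matches the new --duration default
    "snr_db": random.uniform(-10, 20),         # min_snr_db / max_snr_db defaults
    "random1": random.random(),                # random draw stored with the row
}
print(json.dumps(row, ensure_ascii=False))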
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py ADDED
@@ -0,0 +1,138 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import numpy as np
+from tqdm import tqdm
+
+from toolbox.webrtcvad.vad import WebRTCVad
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+    parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
+    parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
+
+    parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+    # vad
+    parser.add_argument("--agg", default=3, type=int)
+    parser.add_argument("--frame_duration_ms", default=30, type=int)
+    parser.add_argument("--padding_duration_ms", default=30, type=int)
+    parser.add_argument("--silence_duration_threshold", default=0.3, type=float)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    w_vad = WebRTCVad(
+        agg=args.agg,
+        frame_duration_ms=args.frame_duration_ms,
+        padding_duration_ms=args.padding_duration_ms,
+        silence_duration_threshold=args.silence_duration_threshold,
+        sample_rate=args.target_sample_rate,
+    )
+
+    # valid
+    count = 0
+    process_bar = tqdm(desc="process valid dataset jsonl")
+    with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
+          open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
+        for row in fvalid:
+            row = json.loads(row)
+
+            speech_filename = row["speech_filename"]
+            speech_offset = row["speech_offset"]
+            speech_duration = row["speech_duration"]
+
+            waveform, _ = librosa.load(
+                speech_filename,
+                sr=args.target_sample_rate,
+                offset=speech_offset,
+                duration=speech_duration,
+            )
+            waveform = np.array(waveform * (1 << 15), dtype=np.int16)
+
+            # vad
+            vad_segments = list()
+            segments = w_vad.vad(waveform)
+            vad_segments += segments
+            segments = w_vad.last_vad_segments()
+            vad_segments += segments
+            w_vad.reset()
+
+            row["vad_segments"] = vad_segments
+
+            row = json.dumps(row, ensure_ascii=False)
+            fvalid_vad.write(f"{row}\n")
+
+            count += 1
+            duration_seconds = count * speech_duration
+            duration_hours = duration_seconds / 3600
+
+            process_bar.update(n=1)
+            process_bar.set_postfix({
+                "duration_hours": round(duration_hours, 4),
+            })
+
+    # train
+    count = 0
+    process_bar = tqdm(desc="process train dataset jsonl")
+    with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
+          open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
+        for row in ftrain:
+            row = json.loads(row)
+
+            speech_filename = row["speech_filename"]
+            speech_offset = row["speech_offset"]
+            speech_duration = row["speech_duration"]
+
+            waveform, _ = librosa.load(
+                speech_filename,
+                sr=args.target_sample_rate,
+                offset=speech_offset,
+                duration=speech_duration,
+            )
+            waveform = np.array(waveform * (1 << 15), dtype=np.int16)
+
+            # vad
+            vad_segments = list()
+            segments = w_vad.vad(waveform)
+            vad_segments += segments
+            segments = w_vad.last_vad_segments()
+            vad_segments += segments
+            w_vad.reset()
+
+            row["vad_segments"] = vad_segments
+
+            row = json.dumps(row, ensure_ascii=False)
+            ftrain_vad.write(f"{row}\n")
+
+            count += 1
+            duration_seconds = count * speech_duration
+            duration_hours = duration_seconds / 3600
+
+            process_bar.update(n=1)
+            process_bar.set_postfix({
+                "duration_hours": round(duration_hours, 4),
+            })
+
+    return
+
+
+if __name__ == "__main__":
+    main()
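Why the int16 conversion in this script is needed: librosa.load returns float32 samples in [-1.0, 1.0], while the WebRTC VAD consumes 16-bit PCM. A self-contained sketch of the same conversion follows; the explicit clipping is an extra safety measure added here, the script above relies on the samples already being in range.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import numpy as np

def float_to_int16(waveform: np.ndarray) -> np.ndarray:
    # scale to the int16 range; clip so a sample of exactly +1.0 cannot overflow past 32767
    scaled = np.clip(waveform * (1 << 15), -32768, 32767)
    return scaled.astype(np.int16)

print(float_to_int16(np.array([0.0, 0.5, -0.99, 1.0])))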
examples/silero_vad_by_webrtcvad/{step_2_train_model.py → step_3_train_model.py} RENAMED
@@ -246,19 +246,19 @@ def main():
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            predictions = model.forward(noisy_audios)
+            logits, probs = model.forward(noisy_audios)
 
-            targets = BaseVadLoss.get_targets(predictions, batch_vad_segments, duration=num_samples / config.sample_rate)
+            targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
-            bce_loss = bce_loss_fn.forward(predictions, targets)
-            dice_loss = dice_loss_fn.forward(predictions, targets)
+            bce_loss = bce_loss_fn.forward(probs, targets)
+            dice_loss = dice_loss_fn.forward(probs, targets)
 
             loss = 1.0 * bce_loss + 1.0 * dice_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
 
-            vad_accuracy_metrics_fn.__call__(predictions, targets)
+            vad_accuracy_metrics_fn.__call__(probs, targets)
 
             optimizer.zero_grad()
             loss.backward()
@@ -311,19 +311,19 @@ def main():
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            predictions = model.forward(noisy_audios)
+            logits, probs = model.forward(noisy_audios)
 
-            targets = BaseVadLoss.get_targets(predictions, batch_vad_segments, duration=num_samples / config.sample_rate)
+            targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
-            bce_loss = bce_loss_fn.forward(predictions, targets)
-            dice_loss = dice_loss_fn.forward(predictions, targets)
+            bce_loss = bce_loss_fn.forward(probs, targets)
+            dice_loss = dice_loss_fn.forward(probs, targets)
 
             loss = 1.0 * bce_loss + 1.0 * dice_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
 
-            vad_accuracy_metrics_fn.__call__(predictions, targets)
+            vad_accuracy_metrics_fn.__call__(probs, targets)
 
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
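The hunks above keep computing BCE and dice loss on the sigmoid probabilities. Since the model now also returns raw logits, a numerically safer variant, shown here only as a sketch and not what this commit does, would feed the logits to BCEWithLogitsLoss instead:

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

logits = torch.randn(2, 100, 1)                      # [b, t, 1], as returned by the model
targets = torch.randint(0, 2, (2, 100, 1)).float()   # frame-level 0/1 speech labels

bce_on_probs = nn.BCELoss()(torch.sigmoid(logits), targets)
bce_on_logits = nn.BCEWithLogitsLoss()(logits, targets)
print(bce_on_probs.item(), bce_on_logits.item())     # agree up to floating-point error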
toolbox/torchaudio/models/vad/fsmn_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py ADDED
@@ -0,0 +1,285 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import Tuple, Dict, List
+import copy
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LinearTransform(nn.Module):
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int,
+                 ):
+        super(LinearTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        output = self.linear.forward(inputs)
+        return output
+
+
+class AffineTransform(nn.Module):
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int,
+                 ):
+        super(AffineTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        output = self.linear.forward(inputs)
+        return output
+
+
+class RectifiedLinear(nn.Module):
+    def __init__(self,
+                 input_dim: int,
+                 output_dim: int,
+                 ):
+        super(RectifiedLinear, self).__init__()
+        self.dim = input_dim
+
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        output = self.relu(inputs)
+        return output
+
+
+class FSMNBlock(nn.Module):
+    def __init__(self,
+                 hidden_size: int,
+                 lorder: int,
+                 rorder: int = -1,
+                 lstride: int = 1,
+                 rstride: int = 1,
+                 ):
+        super(FSMNBlock, self).__init__()
+        self.hidden_size = hidden_size
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.conv_left = nn.Conv2d(
+            in_channels=self.hidden_size,
+            out_channels=self.hidden_size,
+            kernel_size=[lorder, 1],
+            dilation=[lstride, 1],
+            groups=self.hidden_size,
+            bias=False,
+        )
+
+        self.conv_right = None
+        if self.rorder > 0:
+            self.conv_right = nn.Conv2d(
+                in_channels=self.hidden_size,
+                out_channels=self.hidden_size,
+                kernel_size=[rorder, 1],
+                dilation=[rstride, 1],
+                groups=self.hidden_size,
+                bias=False,
+            )
+
+    def forward(self,
+                inputs: torch.Tensor,
+                cache: torch.Tensor = None,
+                ):
+        # inputs shape: [b, t, f]
+        x = torch.unsqueeze(inputs, dim=1)
+        # x shape: [b, 1, t, f]
+        x_per = x.permute(0, 3, 2, 1)
+        # x shape: [b, f, t, 1] / [b, c, t, 1]
+
+        if cache is None:
+            y_left = F.pad(x_per, pad=[0, 0, (self.lorder - 1) * self.lstride, 0])
+        else:
+            cache = cache.to(x_per.device)
+            y_left = torch.cat(tensors=(cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        # cache shape: [b, f, t_pad, 1]
+        # y_left shape: [b, f, t', 1]
+        y_left = self.conv_left(y_left)
+        # y_left shape: [b, f, t, 1]
+
+        out = x_per + y_left
+        # out shape: [b, f, t, 1]
+
+        if self.conv_right is not None:
+            y_right = F.pad(x_per, pad=[0, 0, 0, self.rorder * self.rstride])
+            # y_right shape: [b, f, t', 1]
+
+            y_right = y_right[:, :, self.rstride:, :]
+            y_right = self.conv_right(y_right)
+            out += y_right
+
+        # out shape: [b, f, t, 1]
+        out_per = out.permute(0, 3, 2, 1)
+        # out_per shape: [b, 1, t, f]
+
+        output = out_per.squeeze(1)
+        # output shape: [b, t, f]
+        return output, cache
+
+
+class BasicBlock(nn.Module):
+    def __init__(self,
+                 input_size: int,
+                 hidden_size: int,
+                 lorder: int,
+                 rorder: int = -1,
+                 lstride: int = 1,
+                 rstride: int = 1,
+                 ):
+        super(BasicBlock, self).__init__()
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.linear = LinearTransform(input_size, hidden_size)
+        self.fsmn_block = FSMNBlock(
+            hidden_size=hidden_size,
+            lorder=lorder,
+            rorder=rorder,
+            lstride=lstride,
+            rstride=rstride,
+        )
+        self.affine = AffineTransform(hidden_size, input_size)
+        self.relu = RectifiedLinear(input_size, input_size)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                cache: torch.Tensor = None,
+                ):
+        # inputs shape: [b, t, f]
+        x1 = self.linear.forward(inputs)
+        # x1 shape: [b, t, f']
+
+        if cache is None:
+            # cache shape: [b, f', t_pad, 1]
+            cache = torch.zeros(size=(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1))
+        x2, new_cache = self.fsmn_block.forward(x1, cache=cache)
+        # x2 shape: [b, t, f']
+
+        x3 = self.affine.forward(x2)
+        # x3 shape: [b, t, f]
+
+        x4 = self.relu(x3)
+        return x4, new_cache
+
+
+class FSMN(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        input_affine_size: int,
+        hidden_size: int,
+        basic_block_layers: int,
+        basic_block_hidden_size: int,
+        basic_block_lorder: int,
+        basic_block_rorder: int,
+        basic_block_lstride: int,
+        basic_block_rstride: int,
+        output_affine_size: int,
+        output_size: int,
+        use_softmax: bool = True,
+    ):
+        super(FSMN, self).__init__()
+        self.input_size = input_size
+        self.input_affine_size = input_affine_size
+        self.hidden_size = hidden_size
+
+        self.basic_block_layers = basic_block_layers
+
+        self.output_affine_size = output_affine_size
+        self.output_size = output_size
+
+        self.in_linear1 = AffineTransform(input_size, input_affine_size)
+        self.in_linear2 = AffineTransform(input_affine_size, hidden_size)
+        self.relu = RectifiedLinear(hidden_size, hidden_size)
+
+        self.fsmn_basic_block_list = nn.ModuleList(modules=[
+            BasicBlock(input_size=hidden_size,
+                       hidden_size=basic_block_hidden_size,
+                       lorder=basic_block_lorder,
+                       rorder=basic_block_rorder,
+                       lstride=basic_block_lstride,
+                       rstride=basic_block_rstride,
+                       )
+            for _ in range(basic_block_layers)
+        ])
+        self.out_linear1 = AffineTransform(hidden_size, output_affine_size)
+        self.out_linear2 = AffineTransform(output_affine_size, output_size)
+
+        self.use_softmax = use_softmax
+        if self.use_softmax:
+            self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                cache_list: List[torch.Tensor] = None,
+                ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        # inputs shape: [b, t, f]
+        x = self.in_linear1.forward(inputs)
+        # x shape: [b, t, input_affine_dim]
+        x = self.in_linear2.forward(x)
+        # x shape: [b, t, f]
+
+        x = self.relu(x)
+
+        new_cache_list = list()
+        for idx, fsmn_basic_block in enumerate(self.fsmn_basic_block_list):
+            cache = None if cache_list is None else cache_list[idx]
+            x, new_cache = fsmn_basic_block.forward(x, cache)
+            new_cache_list.append(new_cache)
+
+        # x shape: [b, t, f]
+        x = self.out_linear1.forward(x)
+        outputs = self.out_linear2.forward(x)
+        # outputs shape: [b, t, f]
+
+        if self.use_softmax:
+            outputs = self.softmax(outputs)
+        return outputs, new_cache_list
+
+
+def main():
+    fsmn = FSMN(
+        input_size=32,
+        input_affine_size=16,
+        hidden_size=16,
+        basic_block_layers=3,
+        basic_block_hidden_size=16,
+        basic_block_lorder=3,
+        basic_block_rorder=0,
+        basic_block_lstride=1,
+        basic_block_rstride=1,
+        output_affine_size=16,
+        output_size=32,
+        use_softmax=True,
+    )
+
+    inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
+
+    result, _ = fsmn.forward(inputs)
+    print(result.shape)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
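A short usage sketch for the streaming interface of the FSMN above: feed the input in chunks and pass each call's returned per-layer cache list into the next call. With no right context (basic_block_rorder=0) the chunked output should match a single full-utterance pass. The toy sizes mirror main(); the import path is the module added by this commit.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch

from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN

fsmn = FSMN(
    input_size=32, input_affine_size=16, hidden_size=16,
    basic_block_layers=3, basic_block_hidden_size=16,
    basic_block_lorder=3, basic_block_rorder=0,
    basic_block_lstride=1, basic_block_rstride=1,
    output_affine_size=16, output_size=32, use_softmax=True,
)

inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)

cache_list = None
chunks = list()
for chunk in torch.split(inputs, 66, dim=1):            # three chunks of 66 frames
    out, cache_list = fsmn.forward(chunk, cache_list)    # reuse the per-layer caches
    chunks.append(out)

streamed = torch.cat(chunks, dim=1)
full, _ = fsmn.forward(inputs)
print(torch.allclose(streamed, full, atol=1e-6))          # left-context only, so outputs agree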
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py ADDED
@@ -0,0 +1,18 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary
+https://huggingface.co/funasr/fsmn-vad
+https://huggingface.co/funasr/fsmn-vad-onnx
+
+https://github.com/lovemefan/fsmn-vad
+
+https://github.com/modelscope/FunASR/blob/main/funasr/models/fsmn_vad_streaming/encoder.py
+
+"""
+
+
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py CHANGED
@@ -115,9 +115,10 @@ class SileroVadModel(nn.Module):
             nn.Linear(config.hidden_size, 32),
             nn.ReLU(),
             nn.Linear(32, 1),
-            nn.Sigmoid()
         )
 
+        self.sigmoid = nn.Sigmoid()
+
     def forward(self, signal: torch.Tensor):
         mags = self.stft.forward(signal)
         # mags shape: [b, f, t]
@@ -132,10 +133,11 @@ class SileroVadModel(nn.Module):
         # x shape: [b, t, f]
 
         x, _ = self.lstm.forward(x)
-        x = self.classifier.forward(x)
-
-        # x shape: [b, t, 1]
-        return x
+        logits = self.classifier.forward(x)
+        # logits shape: [b, t, 1]
+        probs = self.sigmoid.forward(logits)
+        # probs shape: [b, t, 1]
+        return logits, probs
 
 
 class SileroVadPretrainedModel(SileroVadModel):
@@ -190,9 +192,9 @@ def main():
 
     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-    probs = model.forward(noisy)
-    print(f"probs: {probs}")
-    print(f"probs.shape: {probs.shape}")
+    logits, probs = model.forward(noisy)
+    print(f"logits: {logits}")
+    print(f"logits.shape: {logits.shape}")
 
     return
 
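A small sketch of consuming the new (logits, probs) return value, continuing from the model and noisy tensor in main() above: threshold the per-frame probabilities to binary speech decisions. The frame-to-seconds mapping depends on the model's STFT hop size, which is not shown in this hunk, so hop_seconds below is an illustrative assumption.

# continues from main() above; model and noisy are already defined
logits, probs = model.forward(noisy)        # probs shape: [b, t, 1]
decisions = (probs.squeeze(-1) > 0.5)       # [b, t] boolean speech / non-speech per frame

hop_seconds = 0.032                          # assumed hop; derive it from the config in practice
speech_frames = decisions[0].nonzero().squeeze(-1)
print([round(i.item() * hop_seconds, 3) for i in speech_frames[:10]])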
 
toolbox/webrtcvad/vad.py CHANGED
@@ -51,6 +51,24 @@ class WebRTCVad(object):
         self.timestamp_start = 0.0
         self.timestamp_end = 0.0
 
+    def reset(self):
+        # frames
+        self.frame_length = int(self.sample_rate * (self.frame_duration_ms / 1000.0))
+        self.frame_timestamp = 0.0
+        self.signal_cache = None
+
+        # segments
+        self.num_padding_frames = int(self.padding_duration_ms / self.frame_duration_ms)
+        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
+        self.triggered = False
+        self.voiced_frames: List[Frame] = list()
+        self.segments = list()
+
+        # vad segments
+        self.is_first_segment = True
+        self.timestamp_start = 0.0
+        self.timestamp_end = 0.0
+
     def signal_to_frames(self, signal: np.ndarray):
         frames = list()
 
@@ -138,6 +156,7 @@ class WebRTCVad(object):
         self.timestamp_end = end
 
     def vad(self, signal: np.ndarray) -> List[list]:
+        # signal dtype: np.int16
        segments = self.segments_generator(signal)
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)
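A usage sketch of the new reset() method, mirroring step_2_make_vad_segments.py: one WebRTCVad instance is reused across files, and reset() clears the frame cache, ring buffer, and timestamps between utterances. The parameter values are the defaults from step_2's argparse; the file names are illustrative.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import librosa
import numpy as np

from toolbox.webrtcvad.vad import WebRTCVad

w_vad = WebRTCVad(agg=3, frame_duration_ms=30, padding_duration_ms=30,
                  silence_duration_threshold=0.3, sample_rate=8000)

for filename in ["a.wav", "b.wav"]:
    waveform, _ = librosa.load(filename, sr=8000)
    waveform = np.array(waveform * (1 << 15), dtype=np.int16)   # WebRTC VAD expects 16-bit PCM

    vad_segments = list()
    vad_segments += w_vad.vad(waveform)
    vad_segments += w_vad.last_vad_segments()
    w_vad.reset()                                                # clear state before the next file

    print(filename, vad_segments)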