HoneyTian committed
Commit e006e10 · 1 Parent(s): 35035c8
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
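Context for the change above: Python's glob.glob expands the "**" wildcard across nested directories only when recursive=True is passed; without it, "**" behaves like a single-level "*", so dataset patterns such as .../noise/**/*.wav silently matched far fewer files. A minimal sketch (the data/ paths are illustrative, not from this repo):

    from glob import glob

    # "**" matches only one directory level here, like a plain "*"
    shallow = glob("data/noise/**/*.wav")

    # with recursive=True, "**" matches any depth of subdirectories
    deep = glob("data/noise/**/*.wav", recursive=True)
    assert set(shallow) <= set(deep)
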
examples/fsmn_vad_by_webrtcvad/run.sh CHANGED
@@ -2,20 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 1 --stop_stage 1 --system_version windows \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/**/*.wav"
-
-
-bash run.sh --stage 1 --stop_stage 1 --system_version centos \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
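(For readers unfamiliar with the idiom above: ": <<'END' ... END" feeds a heredoc to the no-op builtin ":", so the invocations between the markers are documentation and never execute. This hunk only prunes stale stage-1 examples from that commented-out block; the same applies to the silero run.sh below.)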
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -2,13 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 2 --stop_stage 2 --system_version centos \
---file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
---final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
@@ -16,7 +9,6 @@ bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
 /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
 
-
 END
 
 
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -35,7 +35,7 @@ def get_args():
     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
     parser.add_argument("--duration", default=8.0, type=float)
-    parser.add_argument("--min_speech_duration", default=6.0, type=float)
+    parser.add_argument("--min_speech_duration", default=4.0, type=float)
     parser.add_argument("--max_speech_duration", default=8.0, type=float)
     parser.add_argument("--min_snr_db", default=-10, type=float)
     parser.add_argument("--max_snr_db", default=20, type=float)
@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED
@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import sys
+from typing import List, Tuple
 
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -42,6 +43,54 @@ def get_args():
     return args
 
 
+def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+    non_silent_intervals = librosa.effects.split(
+        waveform,
+        top_db=40,          # silence threshold (in dB)
+        frame_length=512,   # analysis frame length
+        hop_length=128      # hop length
+    )
+
+    # return the non-silent time intervals (in seconds)
+    result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+    return result
+
+
+def get_intersection(non_silence: List[Tuple[float, float]],
+                     speech: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
+    """
+    Compute the intersection of speech segments and non-silence segments.
+    :param non_silence: list of non-silence segments, format [(start1, end1), ...]
+    :param speech: list of detected speech segments, format [(start2, end2), ...]
+    :return: list of intersection segments, format [(start, end), ...]
+    """
+    # sort by start time (can be skipped if the inputs are already sorted)
+    non_silence = sorted(non_silence, key=lambda x: x[0])
+    speech = sorted(speech, key=lambda x: x[0])
+
+    result = []
+    i = j = 0
+
+    while i < len(non_silence) and j < len(speech):
+        ns_start, ns_end = non_silence[i]
+        sp_start, sp_end = speech[j]
+
+        # compute the overlapping interval
+        overlap_start = max(ns_start, sp_start)
+        overlap_end = min(ns_end, sp_end)
+
+        if overlap_start < overlap_end:
+            result.append((overlap_start, overlap_end))
+
+        # pointer advance strategy: move past whichever interval ends first
+        if ns_end < sp_end:
+            i += 1  # the non-silence segment ends first
+        else:
+            j += 1  # the speech segment ends first
+
+    return result
+
+
 def main():
     args = get_args()
 
@@ -68,8 +117,8 @@ def main():
         end_ring_rate=0.1,
         frame_size_ms=30,
         frame_step_ms=30,
-        padding_length_ms=90,
-        max_silence_length_ms=100,
+        padding_length_ms=30,
+        max_silence_length_ms=0,
         max_speech_length_s=100,
         min_speech_length_s=0.1,
         sample_rate=args.expected_sample_rate,
@@ -114,6 +163,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
         # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -122,6 +174,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
        row["vad_segments"] = vad_segments
 
        row = json.dumps(row, ensure_ascii=False)
@@ -168,6 +221,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
         # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -176,6 +232,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
         row["vad_segments"] = vad_segments
 
         row = json.dumps(row, ensure_ascii=False)
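
A note on the new post-filter above: the webrtcvad segments are clipped against librosa's energy-based non-silence intervals, so only spans where both detectors agree survive. A toy run of get_intersection with hypothetical values:

    # hypothetical inputs, in seconds
    non_silence = [(0.0, 1.0), (2.0, 4.0)]
    speech = [(0.5, 2.5), (3.5, 5.0)]

    # the two-pointer sweep keeps only the overlaps
    print(get_intersection(non_silence, speech))
    # -> [(0.5, 1.0), (2.0, 2.5), (3.5, 4.0)]
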
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -255,19 +255,22 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            noisy_audios, batch_vad_segments = train_batch
+            noisy_audios, clean_audios, batch_vad_segments = train_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)
 
             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -284,11 +287,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -303,6 +308,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -322,6 +328,7 @@ def main():
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.
 
         progress_bar_train.close()
@@ -329,19 +336,22 @@ def main():
             desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
         )
         for eval_batch in valid_data_loader:
-            noisy_audios, batch_vad_segments = eval_batch
+            noisy_audios, clean_audios, batch_vad_segments = eval_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)
 
             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -352,11 +362,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -371,6 +383,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -384,6 +397,7 @@ def main():
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.
 
         progress_bar_eval.close()
@@ -425,8 +439,12 @@ def main():
             "loss": average_loss,
             "bce_loss": average_bce_loss,
             "dice_loss": average_dice_loss,
+            "lsnr_loss": average_lsnr_loss,
 
             "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall,
         }
         metrics_filename = save_dir / "metrics_epoch.json"
         with open(metrics_filename, "w", encoding="utf-8") as f:
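
About the new 0.03 * lsnr_loss term above: the model now also emits a per-frame local SNR estimate (lsnr) supervised from the (clean, noisy) pair. model.lsnr_loss_fn itself is not part of this diff; the sketch below only illustrates one common way such an LSNR target is derived from clean/noisy framing (the frame size and the MSE pairing are assumptions, not this repo's implementation):

    import torch

    def lsnr_target_sketch(clean: torch.Tensor, noisy: torch.Tensor,
                           frame_size: int = 256, eps: float = 1e-8) -> torch.Tensor:
        # assume additive noise: noise = noisy - clean
        noise = noisy - clean
        # chop [b, num_samples] into non-overlapping frames [b, t, frame_size]
        b, n = clean.shape
        t = n // frame_size
        clean_frames = clean[:, :t * frame_size].reshape(b, t, frame_size)
        noise_frames = noise[:, :t * frame_size].reshape(b, t, frame_size)
        # per-frame SNR in dB, shaped like a squeezed lsnr head output
        snr = clean_frames.pow(2).mean(dim=-1) / (noise_frames.pow(2).mean(dim=-1) + eps)
        return 10.0 * torch.log10(snr + eps)

    # a plain MSE against that target would then play the role of lsnr_loss:
    # lsnr_loss = torch.nn.functional.mse_loss(lsnr, lsnr_target_sketch(clean_audios, noisy_audios))
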
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED
@@ -197,7 +197,6 @@ class FSMN(nn.Module):
                  basic_block_rstride: int,
                  output_affine_size: int,
                  output_size: int,
-                 use_softmax: bool = True,
                  ):
         super(FSMN, self).__init__()
         self.input_size = input_size
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py CHANGED
@@ -68,7 +68,7 @@ class InferenceFSMNVad(object):
         # inputs shape: [1, num_samples,]
 
         with torch.no_grad():
-            logits, probs = self.model.forward(inputs)
+            logits, probs, lsnr = self.model.forward(inputs)
 
         # probs shape: [b, t, 1]
         probs = torch.squeeze(probs, dim=-1)
@@ -92,15 +92,24 @@ def get_args():
         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
-        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_1d4edd08-c6db-41a1-a349-7a22ac36f684_6.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_04f6d842-488e-4e34-967b-2980fdd877c7_5.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_7f6670aa-5600-44c0-9bce-77c1d2b739c7_8.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_1187ff81-3a38-4b0b-846f-b81ad6540ce9_5.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_e44bbfaa-f332-4c02-90a3-cc98505d9a1b_3.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_f89cf1af-f556-42fd-9a42-6c9431002a12_11.wav",
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_r_f89cf1af-f556-42fd-9a42-6c9431002a12_15.wav",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-05-29\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
         type=str,
     )
     args = parser.parse_args()
@@ -119,7 +128,8 @@ def main():
     signal = signal / (1 << 15)
 
     infer = InferenceFSMNVad(
-        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
     )
     frame_step = infer.config.hop_size
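
One detail shared by step_2_make_vad_segments.py and this inference script: waveforms cross between float and 16-bit PCM via a 1 << 15 (= 32768) scale factor. A small self-contained check:

    import numpy as np

    # int16 PCM spans [-32768, 32767]; dividing by 1 << 15 maps it into
    # [-1.0, 1.0), the float range the models consume. Multiplying by
    # (1 << 15) and casting back to int16 is the inverse step used before
    # webrtcvad, which expects 16-bit PCM.
    pcm = np.array([-32768, 0, 16384, 32767], dtype=np.int16)
    signal = pcm / (1 << 15)
    print(signal)  # [-1.  0.  0.5  0.99996948]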