update
- examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py +2 -2
- examples/fsmn_vad_by_webrtcvad/run.sh +0 -14
- examples/silero_vad_by_webrtcvad/run.sh +0 -8
- examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +3 -3
- examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py +59 -2
- examples/silero_vad_by_webrtcvad/step_4_train_model.py +24 -6
- toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py +0 -1
- toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py +21 -11
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
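With Python's glob, a "**" wildcard only spans multiple directory levels when recursive=True is set, which is what this change enables for the ".../**/*.wav" patterns used throughout these scripts. A minimal sketch (the directory layout is hypothetical):

from glob import glob

# Hypothetical layout: data/noise/en-SG/2025-05-19/a.wav
pattern = "data/noise/**/*.wav"

glob(pattern)                  # without recursive=True, "**" behaves like "*",
                               # so only files exactly one level below data/noise match
glob(pattern, recursive=True)  # "**" now matches any number of levels (including zero),
                               # picking up data/noise/en-SG/2025-05-19/a.wav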
examples/fsmn_vad_by_webrtcvad/run.sh CHANGED
@@ -2,20 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 1 --stop_stage 1 --system_version windows \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/**/*.wav"
-
-
-bash run.sh --stage 1 --stop_stage 1 --system_version centos \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -2,13 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 2 --stop_stage 2 --system_version centos \
---file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
---final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
@@ -16,7 +9,6 @@ bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
 /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
 
-
 END
 
 
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -35,7 +35,7 @@ def get_args():
     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
     parser.add_argument("--duration", default=8.0, type=float)
-    parser.add_argument("--min_speech_duration", default=
+    parser.add_argument("--min_speech_duration", default=4.0, type=float)
     parser.add_argument("--max_speech_duration", default=8.0, type=float)
     parser.add_argument("--min_snr_db", default=-10, type=float)
     parser.add_argument("--max_snr_db", default=20, type=float)
@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED
@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import sys
+from typing import List, Tuple
 
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -42,6 +43,54 @@ def get_args():
     return args
 
 
+def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+    non_silent_intervals = librosa.effects.split(
+        waveform,
+        top_db=40,  # silence threshold (in dB)
+        frame_length=512,  # analysis frame length
+        hop_length=128  # hop length
+    )
+
+    # non-silent intervals as times (in seconds)
+    result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+    return result
+
+
+def get_intersection(non_silence: list[tuple[float, float]],
+                     speech: list[tuple[float, float]]) -> list[tuple[float, float]]:
+    """
+    Compute the intersection of speech segments and non-silent segments.
+    :param non_silence: list of non-silent segments, format [(start1, end1), ...]
+    :param speech: list of detected speech segments, format [(start2, end2), ...]
+    :return: list of intersection segments, format [(start, end), ...]
+    """
+    # sort by start time (can be skipped if the inputs are already sorted)
+    non_silence = sorted(non_silence, key=lambda x: x[0])
+    speech = sorted(speech, key=lambda x: x[0])
+
+    result = []
+    i = j = 0
+
+    while i < len(non_silence) and j < len(speech):
+        ns_start, ns_end = non_silence[i]
+        sp_start, sp_end = speech[j]
+
+        # compute the overlapping interval
+        overlap_start = max(ns_start, sp_start)
+        overlap_end = min(ns_end, sp_end)
+
+        if overlap_start < overlap_end:
+            result.append((overlap_start, overlap_end))
+
+        # pointer advancement: move past whichever interval ends first
+        if ns_end < sp_end:
+            i += 1  # the non-silent segment ends first
+        else:
+            j += 1  # the speech segment ends first
+
+    return result
+
+
 def main():
     args = get_args()
 
@@ -68,8 +117,8 @@ def main():
         end_ring_rate=0.1,
         frame_size_ms=30,
         frame_step_ms=30,
-        padding_length_ms=
-        max_silence_length_ms=
+        padding_length_ms=30,
+        max_silence_length_ms=0,
         max_speech_length_s=100,
        min_speech_length_s=0.1,
         sample_rate=args.expected_sample_rate,
@@ -114,6 +163,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
         # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -122,6 +174,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
         row["vad_segments"] = vad_segments
 
         row = json.dumps(row, ensure_ascii=False)
@@ -168,6 +221,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
         # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -176,6 +232,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
        row["vad_segments"] = vad_segments
 
         row = json.dumps(row, ensure_ascii=False)
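As a quick sanity check of the two-pointer sweep in get_intersection (with the diff's get_intersection in scope; the segment values below are made up for illustration):

non_silence = [(0.0, 2.5), (3.0, 6.0)]  # hypothetical librosa.effects.split output, in seconds
speech = [(1.0, 4.0), (5.5, 7.0)]       # hypothetical webrtcvad segments

print(get_intersection(non_silence, speech))
# [(1.0, 2.5), (3.0, 4.0), (5.5, 6.0)]  only spans that are both non-silent and speech survive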
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -255,19 +255,22 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            noisy_audios, batch_vad_segments = train_batch
+            noisy_audios, clean_audios, batch_vad_segments = train_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)
 
             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -284,11 +287,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -303,6 +308,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -322,6 +328,7 @@ def main():
                 total_loss = 0.
                 total_bce_loss = 0.
                 total_dice_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
                 progress_bar_train.close()
@@ -329,19 +336,22 @@ def main():
                     desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                 )
                 for eval_batch in valid_data_loader:
-                    noisy_audios, batch_vad_segments = eval_batch
+                    noisy_audios, clean_audios, batch_vad_segments = eval_batch
                     noisy_audios: torch.Tensor = noisy_audios.to(device)
+                    clean_audios: torch.Tensor = clean_audios.to(device)
                     # noisy_audios shape: [b, num_samples]
                     num_samples = noisy_audios.shape[-1]
 
-                    logits, probs = model.forward(noisy_audios)
+                    logits, probs, lsnr = model.forward(noisy_audios)
+                    lsnr = torch.squeeze(lsnr, dim=-1)
 
                     targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
                     bce_loss = bce_loss_fn.forward(probs, targets)
                     dice_loss = dice_loss_fn.forward(probs, targets)
+                    lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-                    loss = 1.0 * bce_loss + 1.0 * dice_loss
+                    loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
                     if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                         logger.info(f"find nan or inf in loss. continue.")
                         continue
@@ -352,11 +362,13 @@ def main():
                     total_loss += loss.item()
                     total_bce_loss += bce_loss.item()
                     total_dice_loss += dice_loss.item()
+                    total_lsnr_loss += lsnr_loss.item()
                    total_batches += 1
 
                     average_loss = round(total_loss / total_batches, 4)
                     average_bce_loss = round(total_bce_loss / total_batches, 4)
                     average_dice_loss = round(total_dice_loss / total_batches, 4)
+                    average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
                     metrics = vad_accuracy_metrics_fn.get_metric()
                     accuracy = metrics["accuracy"]
@@ -371,6 +383,7 @@ def main():
                         "loss": average_loss,
                         "bce_loss": average_bce_loss,
                         "dice_loss": average_dice_loss,
+                        "lsnr_loss": average_lsnr_loss,
                         "accuracy": accuracy,
                         "f1": f1,
                         "precision": precision,
@@ -384,6 +397,7 @@ def main():
                     total_loss = 0.
                     total_bce_loss = 0.
                     total_dice_loss = 0.
+                    total_lsnr_loss = 0.
                     total_batches = 0.
 
                     progress_bar_eval.close()
@@ -425,8 +439,12 @@ def main():
             "loss": average_loss,
             "bce_loss": average_bce_loss,
             "dice_loss": average_dice_loss,
+            "lsnr_loss": average_lsnr_loss,
 
             "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall,
         }
         metrics_filename = save_dir / "metrics_epoch.json"
         with open(metrics_filename, "w", encoding="utf-8") as f:
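The new LSNR head rides along as a third forward output and a third, lightly weighted loss term. A minimal sketch of the shape handling and the weighted sum (the shapes and loss values are invented; only the squeeze and the 1.0/1.0/0.03 weights come from the diff):

import torch

lsnr = torch.randn(4, 100, 1)       # hypothetical [batch, time, 1] LSNR head output
lsnr = torch.squeeze(lsnr, dim=-1)  # -> [batch, time], as done right after model.forward
print(lsnr.shape)                   # torch.Size([4, 100])

# Invented per-batch loss values; LSNR losses presumably sit on a larger
# numeric scale, which would explain the small 0.03 weight in the diff.
bce_loss, dice_loss, lsnr_loss = torch.tensor(0.45), torch.tensor(0.30), torch.tensor(12.0)
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
print(loss.item())                  # approximately 1.11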
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED
@@ -197,7 +197,6 @@ class FSMN(nn.Module):
                  basic_block_rstride: int,
                  output_affine_size: int,
                  output_size: int,
-                 use_softmax: bool = True,
                  ):
         super(FSMN, self).__init__()
         self.input_size = input_size
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py CHANGED
@@ -68,7 +68,7 @@ class InferenceFSMNVad(object):
         # inputs shape: [1, num_samples,]
 
         with torch.no_grad():
-            logits, probs = self.model.forward(inputs)
+            logits, probs, lsnr = self.model.forward(inputs)
 
         # probs shape: [b, t, 1]
         probs = torch.squeeze(probs, dim=-1)
@@ -92,15 +92,24 @@ def get_args():
         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
-
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\
+        default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
         type=str,
     )
     args = parser.parse_args()
@@ -119,7 +128,8 @@ def main():
     signal = signal / (1 << 15)
 
     infer = InferenceFSMNVad(
-        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
     )
     frame_step = infer.config.hop_size
 
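Callers must now unpack three values from forward. A self-contained sketch of that contract (TinyVad, its hop size, and all shapes are hypothetical stand-ins, not the real FSMN model):

import torch
import torch.nn as nn

class TinyVad(nn.Module):
    """Stand-in with the same 3-tuple contract the commit gives the FSMN VAD forward."""
    def __init__(self, hop_size: int = 80):
        super().__init__()
        self.hop_size = hop_size

    def forward(self, x: torch.Tensor):
        b, num_samples = x.shape
        t = num_samples // self.hop_size
        logits = torch.randn(b, t, 1)
        probs = torch.sigmoid(logits)  # probs shape: [b, t, 1], as in the hunk's comment
        lsnr = torch.randn(b, t, 1)    # the new third output
        return logits, probs, lsnr

inputs = torch.randn(1, 16000)         # [1, num_samples]
with torch.no_grad():
    logits, probs, lsnr = TinyVad().forward(inputs)
probs = torch.squeeze(probs, dim=-1)   # [b, t, 1] -> [b, t]
print(probs.shape)                     # torch.Size([1, 200])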