update

Changed files:

- examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py +2 -2
- examples/fsmn_vad_by_webrtcvad/run.sh +0 -14
- examples/silero_vad_by_webrtcvad/run.sh +0 -8
- examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +3 -3
- examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py +59 -2
- examples/silero_vad_by_webrtcvad/step_4_train_model.py +24 -6
- toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py +0 -1
- toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py +21 -11
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py CHANGED

@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
        for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
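The `recursive=True` fix matters here because all of the dataset patterns in these scripts use `**`. A minimal sketch (hypothetical paths, not from the repo) of the difference in Python's `glob`:

```python
# Without recursive=True, "**" degrades to "*" and matches exactly one path
# segment, so wav files nested deeper (or directly in the root) are skipped.
from glob import glob

pattern = "data/speech/**/*.wav"       # hypothetical dataset layout
shallow = glob(pattern)                # only data/speech/<dir>/<file>.wav
deep = glob(pattern, recursive=True)   # every *.wav anywhere under data/speech/
```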
examples/fsmn_vad_by_webrtcvad/run.sh CHANGED

@@ -2,20 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 1 --stop_stage 1 --system_version windows \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/**/*.wav"
-
-
-bash run.sh --stage 1 --stop_stage 1 --system_version centos \
---file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
examples/silero_vad_by_webrtcvad/run.sh CHANGED

@@ -2,13 +2,6 @@
 
 : <<'END'
 
-bash run.sh --stage 2 --stop_stage 2 --system_version centos \
---file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
---final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
---noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
---speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
-/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
-
 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
@@ -16,7 +9,6 @@ bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
 /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"
 
-
 END
 
 
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED

@@ -35,7 +35,7 @@ def get_args():
     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
     parser.add_argument("--duration", default=8.0, type=float)
-    parser.add_argument("--min_speech_duration", default=
+    parser.add_argument("--min_speech_duration", default=4.0, type=float)
     parser.add_argument("--max_speech_duration", default=8.0, type=float)
     parser.add_argument("--min_snr_db", default=-10, type=float)
     parser.add_argument("--max_snr_db", default=20, type=float)
@@ -56,7 +56,7 @@ def target_second_noise_signal_generator(filename_patterns: List[str],
 
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
 
                 if signal.ndim != 1:
@@ -109,7 +109,7 @@ def target_second_speech_signal_generator(filename_patterns: List[str],
                                           sample_rate: int = 8000, max_epoch: int = 1):
     for epoch_idx in range(max_epoch):
         for filename_pattern in filename_patterns:
-            for filename in glob(filename_pattern):
+            for filename in glob(filename_pattern, recursive=True):
                 signal, _ = librosa.load(filename, sr=sample_rate)
                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED

@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import sys
+from typing import List, Tuple
 
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -42,6 +43,54 @@ def get_args():
     return args
 
 
+def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+    non_silent_intervals = librosa.effects.split(
+        waveform,
+        top_db=40,          # silence threshold (dB)
+        frame_length=512,   # analysis frame length
+        hop_length=128      # hop length
+    )
+
+    # Return the non-silent intervals as time spans in seconds.
+    result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+    return result
+
+
+def get_intersection(non_silence: list[tuple[float, float]],
+                     speech: list[tuple[float, float]]) -> list[tuple[float, float]]:
+    """
+    Compute the intersection of speech segments and non-silent segments.
+    :param non_silence: list of non-silent segments, format [(start1, end1), ...]
+    :param speech: list of detected speech segments, format [(start2, end2), ...]
+    :return: list of intersection segments, format [(start, end), ...]
+    """
+    # Sort by start time (unnecessary if the inputs are already sorted).
+    non_silence = sorted(non_silence, key=lambda x: x[0])
+    speech = sorted(speech, key=lambda x: x[0])
+
+    result = []
+    i = j = 0
+
+    while i < len(non_silence) and j < len(speech):
+        ns_start, ns_end = non_silence[i]
+        sp_start, sp_end = speech[j]
+
+        # Compute the overlapping interval.
+        overlap_start = max(ns_start, sp_start)
+        overlap_end = min(ns_end, sp_end)
+
+        if overlap_start < overlap_end:
+            result.append((overlap_start, overlap_end))
+
+        # Pointer advance: move past whichever interval ends first.
+        if ns_end < sp_end:
+            i += 1  # the non-silent segment ends first
+        else:
+            j += 1  # the speech segment ends first
+
+    return result
+
+
 def main():
     args = get_args()
 
@@ -68,8 +117,8 @@ def main():
         end_ring_rate=0.1,
         frame_size_ms=30,
         frame_step_ms=30,
-        padding_length_ms=
-        max_silence_length_ms=
+        padding_length_ms=30,
+        max_silence_length_ms=0,
         max_speech_length_s=100,
         min_speech_length_s=0.1,
         sample_rate=args.expected_sample_rate,
@@ -114,6 +163,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
         # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -122,6 +174,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
         row["vad_segments"] = vad_segments
 
         row = json.dumps(row, ensure_ascii=False)
@@ -168,6 +221,9 @@ def main():
         )
         waveform = np.array(waveform * (1 << 15), dtype=np.int16)
 
+        # non_silence_segments
+        non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
        # vad
         vad_segments = list()
         segments = w_vad.vad(waveform)
@@ -176,6 +232,7 @@ def main():
         vad_segments += segments
         w_vad.reset()
 
+        vad_segments = get_intersection(non_silence_segments, vad_segments)
         row["vad_segments"] = vad_segments
 
         row = json.dumps(row, ensure_ascii=False)
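The two-pointer sweep in `get_intersection` runs in O(n + m) over the sorted interval lists, so filtering the webrtcvad output against the librosa energy gate adds negligible cost. A quick sanity check (hand-made intervals in seconds, not from the repo's data):

```python
# non_silence as librosa.effects.split would report it, speech as the
# webrtcvad pass would; both are (start, end) pairs in seconds.
non_silence = [(0.0, 1.5), (2.0, 4.0)]
speech = [(0.5, 2.5), (3.5, 5.0)]

assert get_intersection(non_silence, speech) == [(0.5, 1.5), (2.0, 2.5), (3.5, 4.0)]
```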
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED

@@ -255,19 +255,22 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            noisy_audios, batch_vad_segments = train_batch
+            noisy_audios, clean_audios, batch_vad_segments = train_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)
 
             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -284,11 +287,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -303,6 +308,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -322,6 +328,7 @@ def main():
                 total_loss = 0.
                 total_bce_loss = 0.
                 total_dice_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
         progress_bar_train.close()
@@ -329,19 +336,22 @@ def main():
             desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
         )
         for eval_batch in valid_data_loader:
-            noisy_audios, batch_vad_segments = eval_batch
+            noisy_audios, clean_audios, batch_vad_segments = eval_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]
 
-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)
 
             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
 
             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -352,11 +362,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -371,6 +383,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -384,6 +397,7 @@ def main():
             total_loss = 0.
             total_bce_loss = 0.
             total_dice_loss = 0.
+            total_lsnr_loss = 0.
             total_batches = 0.
 
         progress_bar_eval.close()
@@ -425,8 +439,12 @@ def main():
             "loss": average_loss,
             "bce_loss": average_bce_loss,
             "dice_loss": average_dice_loss,
+            "lsnr_loss": average_lsnr_loss,
 
             "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall,
         }
         metrics_filename = save_dir / "metrics_epoch.json"
         with open(metrics_filename, "w", encoding="utf-8") as f:
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED

@@ -197,7 +197,6 @@ class FSMN(nn.Module):
                  basic_block_rstride: int,
                  output_affine_size: int,
                  output_size: int,
-                 use_softmax: bool = True,
                  ):
         super(FSMN, self).__init__()
         self.input_size = input_size
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py CHANGED

@@ -68,7 +68,7 @@ class InferenceFSMNVad(object):
         # inputs shape: [1, num_samples,]
 
         with torch.no_grad():
-            logits, probs = self.model.forward(inputs)
+            logits, probs, lsnr = self.model.forward(inputs)
 
         # probs shape: [b, t, 1]
         probs = torch.squeeze(probs, dim=-1)
@@ -92,15 +92,24 @@ def get_args():
         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
-
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\
-
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\
+        default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
         type=str,
     )
     args = parser.parse_args()
@@ -119,7 +128,8 @@ def main():
     signal = signal / (1 << 15)
 
     infer = InferenceFSMNVad(
-        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix(),
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx2-dns3.zip").as_posix(),
     )
     frame_step = infer.config.hop_size
 
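Any remaining caller of the FSMN VAD forward must move to the three-tuple return; old two-tuple unpacking now raises `ValueError: too many values to unpack`. The stub below (hypothetical shapes, not the repo's model class) shows the updated call-site pattern:

```python
# Hypothetical stand-in for the updated FSMN VAD forward: it now returns an
# extra per-frame LSNR tensor alongside logits and probs.
import torch

def forward(inputs: torch.Tensor):
    b, t = inputs.shape[0], 200   # t: number of frames (assumed)
    return torch.rand(b, t, 1), torch.rand(b, t, 1), torch.rand(b, t, 1)

inputs = torch.rand(1, 16000)            # [1, num_samples]
logits, probs, lsnr = forward(inputs)    # was: logits, probs = ...
probs = torch.squeeze(probs, dim=-1)     # [b, t, 1] -> [b, t]
lsnr = torch.squeeze(lsnr, dim=-1)
```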