HoneyTian committed
Commit 5e1cd25 · 1 Parent(s): 9829721
README.md CHANGED
@@ -61,7 +61,7 @@ docker run -itd \
   --ipc=host \
   -v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
   -v /data/tianxing/PycharmProjects/cc_vad:/data/tianxing/PycharmProjects/cc_vad \
-  python:3.12
+  python:3.12 /bin/bash


 Check GPU
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -2,19 +2,9 @@

 : <<'END'

-sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
---noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \
+bash run.sh --stage 1 --stop_stage 1 --system_version centos \
+--file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
+--final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"

examples/silero_vad_by_webrtcvad/step_2_train_model.py CHANGED
@@ -1,8 +1,5 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
-"""
-https://github.com/Rikorose/DeepFilterNet
-"""
 import argparse
 import json
 import logging
@@ -13,9 +10,7 @@ from pathlib import Path
 import random
 import sys
 import shutil
-from typing import List
-
-from fontTools.varLib.plot import stops
+from typing import List, Tuple

 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -27,12 +22,13 @@ from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

-from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
-from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
-from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
-from toolbox.torchaudio.metrics.pesq import run_pesq_score
-from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
-from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel
+from toolbox.torch.utils.data.dataset.vad_jsonl_dataset import VadJsonlDataset
+from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
+from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadPretrainedModel
+from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
+from toolbox.torchaudio.losses.bce_loss import BCELoss
+from toolbox.torchaudio.losses.dice_loss import DiceLoss
+from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy


 def get_args():
@@ -76,28 +72,23 @@ class CollateFunction(object):
         pass

     def __call__(self, batch: List[dict]):
-        clean_audios = list()
         noisy_audios = list()
-        snr_db_list = list()
+        batch_vad_segments = list()

         for sample in batch:
-            # noise_wave: torch.Tensor = sample["noise_wave"]
-            clean_audio: torch.Tensor = sample["speech_wave"]
-            noisy_audio: torch.Tensor = sample["mix_wave"]
-            # snr_db: float = sample["snr_db"]
+            noisy_wave: torch.Tensor = sample["noisy_wave"]
+            vad_segments: List[Tuple[float, float]] = sample["vad_segments"]

-            clean_audios.append(clean_audio)
-            noisy_audios.append(noisy_audio)
+            noisy_audios.append(noisy_wave)
+            batch_vad_segments.append(vad_segments)

-        clean_audios = torch.stack(clean_audios)
         noisy_audios = torch.stack(noisy_audios)

         # assert
-        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
-            raise AssertionError("nan or inf in clean_audios")
         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
             raise AssertionError("nan or inf in noisy_audios")
-        return clean_audios, noisy_audios
+
+        return noisy_audios, batch_vad_segments


 collate_fn = CollateFunction()
@@ -106,7 +97,7 @@ collate_fn = CollateFunction()
 def main():
     args = get_args()

-    config = DfNet2Config.from_pretrained(
+    config = SileroVadConfig.from_pretrained(
         pretrained_model_name_or_path=args.config_file,
     )

@@ -125,7 +116,7 @@ def main():
     logger.info(f"GPU available count: {n_gpu}; device: {device}")

     # datasets
-    train_dataset = DenoiseJsonlDataset(
+    train_dataset = VadJsonlDataset(
         jsonl_file=args.train_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
@@ -133,7 +124,7 @@ def main():
         max_snr_db=config.max_snr_db,
         # skip=225000,
     )
-    valid_dataset = DenoiseJsonlDataset(
+    valid_dataset = VadJsonlDataset(
         jsonl_file=args.valid_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
@@ -165,7 +156,7 @@ def main():

     # models
     logger.info(f"prepare models. config_file: {args.config_file}")
-    model = DfNet2PretrainedModel(config).to(device)
+    model = SileroVadPretrainedModel(config).to(device)
     model.to(device)
     model.train()

@@ -210,25 +201,17 @@ def main():
     else:
         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

-    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-    mr_stft_loss_fn = MultiResolutionSTFTLoss(
-        fft_size_list=[256, 512, 1024],
-        win_size_list=[256, 512, 1024],
-        hop_size_list=[128, 256, 512],
-        factor_sc=1.5,
-        factor_mag=1.0,
-        reduction="mean"
-    ).to(device)
+    bce_loss_fn = BCELoss(reduction="mean").to(device)
+    dice_loss_fn = DiceLoss(reduction="mean").to(device)
+
+    vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)

     # training loop

     # state
-    average_pesq_score = 1000000000
     average_loss = 1000000000
-    average_mr_stft_loss = 1000000000
-    average_neg_si_snr_loss = 1000000000
-    average_mask_loss = 1000000000
-    average_lsnr_loss = 1000000000
+    average_bce_loss = 1000000000
+    average_dice_loss = 1000000000

     model_list = list()
     best_epoch_idx = None
@@ -246,13 +229,11 @@ def main():

         # train
         model.train()
+        vad_accuracy_metrics_fn.reset()

-        total_pesq_score = 0.
         total_loss = 0.
-        total_mr_stft_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_mask_loss = 0.
-        total_lsnr_loss = 0.
+        total_bce_loss = 0.
+        total_dice_loss = 0.
         total_batches = 0.

         progress_bar_train = tqdm(
@@ -260,28 +241,24 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
        for train_batch in train_data_loader:
-            clean_audios, noisy_audios = train_batch
-            clean_audios: torch.Tensor = clean_audios.to(device)
+            noisy_audios, batch_vad_segments = train_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            # noisy_audios shape: [b, num_samples]
+            num_samples = noisy_audios.shape[-1]
+
+            predictions = model.forward(noisy_audios)

-            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
-            # est_wav shape: [b, 1, n_samples]
-            est_wav = torch.squeeze(est_wav, dim=1)
-            # est_wav shape: [b, n_samples]
+            targets = BaseVadLoss.get_targets(predictions, batch_vad_segments, duration=num_samples / config.sample_rate)

-            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
-            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
-            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
+            bce_loss = bce_loss_fn.forward(predictions, targets)
+            dice_loss = dice_loss_fn.forward(predictions, targets)

-            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue

-            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
-            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+            vad_accuracy_metrics_fn.__call__(predictions, targets)

             optimizer.zero_grad()
             loss.backward()
@@ -289,30 +266,25 @@ def main():
             optimizer.step()
             lr_scheduler.step()

-            total_pesq_score += pesq_score
             total_loss += loss.item()
-            total_mr_stft_loss += mr_stft_loss.item()
-            total_neg_si_snr_loss += neg_si_snr_loss.item()
-            total_mask_loss += mask_loss.item()
-            total_lsnr_loss += lsnr_loss.item()
+            total_bce_loss += bce_loss.item()
+            total_dice_loss += dice_loss.item()
             total_batches += 1

-            average_pesq_score = round(total_pesq_score / total_batches, 4)
             average_loss = round(total_loss / total_batches, 4)
-            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-            average_mask_loss = round(total_mask_loss / total_batches, 4)
-            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
+            average_bce_loss = round(total_bce_loss / total_batches, 4)
+            average_dice_loss = round(total_dice_loss / total_batches, 4)
+
+            metrics = vad_accuracy_metrics_fn.get_metric()
+            accuracy = metrics["accuracy"]

             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
                 "lr": lr_scheduler.get_last_lr()[0],
-                "pesq_score": average_pesq_score,
                 "loss": average_loss,
-                "mr_stft_loss": average_mr_stft_loss,
-                "neg_si_snr_loss": average_neg_si_snr_loss,
-                "mask_loss": average_mask_loss,
-                "lsnr_loss": average_lsnr_loss,
+                "average_bce_loss": average_bce_loss,
+                "average_dice_loss": average_dice_loss,
+                "accuracy": accuracy,
             })

         # evaluation
@@ -322,13 +294,11 @@ def main():
         torch.cuda.empty_cache()

         model.eval()
+        vad_accuracy_metrics_fn.reset()

-        total_pesq_score = 0.
         total_loss = 0.
-        total_mr_stft_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_mask_loss = 0.
-        total_lsnr_loss = 0.
+        total_bce_loss = 0.
+        total_dice_loss = 0.
         total_batches = 0.

         progress_bar_train.close()
@@ -336,63 +306,52 @@ def main():
             desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
         )
         for eval_batch in valid_data_loader:
-            clean_audios, noisy_audios = eval_batch
-            clean_audios: torch.Tensor = clean_audios.to(device)
+            noisy_audios, batch_vad_segments = eval_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            # noisy_audios shape: [b, num_samples]
+            num_samples = noisy_audios.shape[-1]
+
+            predictions = model.forward(noisy_audios)

-            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
-            # est_wav shape: [b, 1, n_samples]
-            est_wav = torch.squeeze(est_wav, dim=1)
-            # est_wav shape: [b, n_samples]
+            targets = BaseVadLoss.get_targets(predictions, batch_vad_segments, duration=num_samples / config.sample_rate)

-            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
-            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
-            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
+            bce_loss = bce_loss_fn.forward(predictions, targets)
+            dice_loss = dice_loss_fn.forward(predictions, targets)

-            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue

-            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
-            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+            vad_accuracy_metrics_fn.__call__(predictions, targets)

-            total_pesq_score += pesq_score
             total_loss += loss.item()
-            total_mr_stft_loss += mr_stft_loss.item()
-            total_neg_si_snr_loss += neg_si_snr_loss.item()
-            total_mask_loss += mask_loss.item()
-            total_lsnr_loss += lsnr_loss.item()
+            total_bce_loss += bce_loss.item()
+            total_dice_loss += dice_loss.item()
             total_batches += 1

-            average_pesq_score = round(total_pesq_score / total_batches, 4)
             average_loss = round(total_loss / total_batches, 4)
-            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-            average_mask_loss = round(total_mask_loss / total_batches, 4)
-            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
+            average_bce_loss = round(total_bce_loss / total_batches, 4)
+            average_dice_loss = round(total_dice_loss / total_batches, 4)
+
+            metrics = vad_accuracy_metrics_fn.get_metric()
+            accuracy = metrics["accuracy"]

             progress_bar_eval.update(1)
             progress_bar_eval.set_postfix({
                 "lr": lr_scheduler.get_last_lr()[0],
-                "pesq_score": average_pesq_score,
                 "loss": average_loss,
-                "mr_stft_loss": average_mr_stft_loss,
-                "neg_si_snr_loss": average_neg_si_snr_loss,
-                "mask_loss": average_mask_loss,
-                "lsnr_loss": average_lsnr_loss,
+                "average_bce_loss": average_bce_loss,
+                "average_dice_loss": average_dice_loss,
+                "accuracy": accuracy,
             })

         model.train()
+        vad_accuracy_metrics_fn.reset()

-        total_pesq_score = 0.
         total_loss = 0.
-        total_mr_stft_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_mask_loss = 0.
-        total_lsnr_loss = 0.
+        total_bce_loss = 0.
+        total_dice_loss = 0.
         total_batches = 0.

         progress_bar_eval.close()
@@ -418,12 +377,12 @@ def main():
         if best_metric is None:
             best_epoch_idx = epoch_idx
             best_step_idx = step_idx
-            best_metric = average_pesq_score
-        elif average_pesq_score >= best_metric:
+            best_metric = accuracy
+        elif accuracy >= best_metric:
             # great is better.
             best_epoch_idx = epoch_idx
             best_step_idx = step_idx
-            best_metric = average_pesq_score
+            best_metric = accuracy
         else:
             pass
@@ -431,12 +390,11 @@ def main():
             "epoch_idx": epoch_idx,
             "best_epoch_idx": best_epoch_idx,
             "best_step_idx": best_step_idx,
-            "pesq_score": average_pesq_score,
             "loss": average_loss,
-            "mr_stft_loss": average_mr_stft_loss,
-            "neg_si_snr_loss": average_neg_si_snr_loss,
-            "mask_loss": average_mask_loss,
-            "lsnr_loss": average_lsnr_loss,
+            "bce_loss": average_bce_loss,
+            "dice_loss": average_dice_loss,
+
+            "accuracy": accuracy,
         }
         metrics_filename = save_dir / "metrics_epoch.json"
         with open(metrics_filename, "w", encoding="utf-8") as f:
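
Note on the new training loop: `BaseVadLoss.get_targets` (imported above but not shown in this commit) is what turns each sample's `vad_segments`, a list of `(start, end)` times in seconds, into a frame-level 0/1 target tensor aligned with the model's predictions. A minimal sketch of what such a rasterization presumably looks like, assuming predictions of shape [b, t, 1] with frames spread uniformly over the clip duration; the helper name `make_frame_targets` is illustrative, not the repo's actual implementation:

    #!/usr/bin/python3
    # -*- coding: utf-8 -*-
    # Hypothetical sketch of segment-to-frame target rasterization.
    from typing import List, Tuple

    import torch


    def make_frame_targets(predictions: torch.Tensor,
                           batch_vad_segments: List[List[Tuple[float, float]]],
                           duration: float,
                           ) -> torch.Tensor:
        # predictions shape: [b, t, 1]; one VAD probability per frame.
        b, t, _ = predictions.shape
        # center time (in seconds) of each of the t frames.
        frame_times = (torch.arange(t, dtype=torch.float32) + 0.5) * (duration / t)
        targets = torch.zeros(size=(b, t, 1), dtype=predictions.dtype)
        for i, segments in enumerate(batch_vad_segments):
            for begin, end in segments:
                # frames whose center falls inside a speech segment get label 1.
                mask = (frame_times >= begin) & (frame_times < end)
                targets[i, mask, 0] = 1.0
        return targets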
toolbox/torchaudio/losses/bce_loss.py ADDED
@@ -0,0 +1,44 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
+
+
+class BCELoss(BaseVadLoss):
+    """
+    Binary Cross-Entropy Loss, BCE Loss
+    """
+    def __init__(self,
+                 reduction: str = "mean",
+                 ):
+        super(BCELoss, self).__init__()
+        self.reduction = reduction
+
+        self.bce_loss_fn = nn.BCELoss(reduction=reduction)
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
+        """
+        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
+        :param targets: shape as `inputs`.
+        :return:
+        """
+        loss = self.bce_loss_fn.forward(inputs, targets)
+        return loss
+
+
+def main():
+    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
+
+    loss_fn = BCELoss()
+
+    loss = loss_fn.forward(inputs, inputs)
+    print(loss)
+    return
+
+
+if __name__ == "__main__":
+    main()
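
Note on BCELoss: this wrapper operates on probabilities, not logits. `nn.BCELoss` assumes `inputs` already lie in [0, 1] (the docstring notes they come after a sigmoid), and `targets` must be float tensors of the same shape. A short usage sketch with random labels; the shapes are illustrative:

    #!/usr/bin/python3
    # -*- coding: utf-8 -*-
    # Usage sketch: inputs are post-sigmoid probabilities, targets are 0/1 floats.
    import torch

    from toolbox.torchaudio.losses.bce_loss import BCELoss


    def main():
        inputs = torch.rand(size=(2, 198, 1), dtype=torch.float32)
        targets = torch.randint(low=0, high=2, size=(2, 198, 1)).float()

        loss_fn = BCELoss(reduction="mean")

        loss = loss_fn.forward(inputs, targets)
        print(loss)  # scalar tensor
        return


    if __name__ == "__main__":
        main()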
toolbox/torchaudio/losses/dice_loss.py ADDED
@@ -0,0 +1,61 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+
+class DiceLoss(nn.Module):
+    def __init__(self,
+                 reduction: str = "mean",
+                 eps: float = 1e-6,
+                 ):
+        super(DiceLoss, self).__init__()
+        self.reduction = reduction
+        self.eps = eps
+
+        if reduction not in ("sum", "mean"):
+            raise AssertionError(f"param reduction must be sum or mean.")
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
+        """
+        :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
+        :param targets: shape as `inputs`.
+        :return:
+        """
+        inputs_ = torch.squeeze(inputs, dim=-1)
+        targets_ = torch.squeeze(targets, dim=-1)
+        # shape: [b, t]
+
+        intersection = (inputs_ * targets_).sum(dim=-1)
+        union = (inputs_ + targets_).sum(dim=-1)
+        # shape: [b,]
+
+        dice = (2. * intersection + self.eps) / (union + self.eps)
+        # shape: [b,]
+
+        loss = 1. - dice
+        # shape: [b,]
+
+        if self.reduction == "mean":
+            loss = torch.mean(loss)
+        elif self.reduction == "sum":
+            loss = torch.sum(loss)
+        else:
+            raise AssertionError
+        return loss
+
+
+def main():
+    inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
+
+    loss_fn = DiceLoss()
+
+    loss = loss_fn.forward(inputs, inputs)
+    print(loss)
+    return
+
+
+if __name__ == "__main__":
+    main()
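
Note on DiceLoss: per batch item it computes dice = (2·Σ(p·y) + eps) / (Σp + Σy + eps) and returns 1 − dice, which rewards overlap between predicted and target speech regions and is far less dominated by the many easy non-speech frames than BCE; this is presumably why the trainer sums the two terms. A tiny worked check (values are illustrative):

    #!/usr/bin/python3
    # -*- coding: utf-8 -*-
    # Worked check of the Dice computation on a hand-sized example.
    import torch

    from toolbox.torchaudio.losses.dice_loss import DiceLoss


    def main():
        # one clip, four frames: prediction overlaps 1 of 2 target frames.
        inputs = torch.tensor([[[1.], [1.], [0.], [0.]]])   # shape: [1, 4, 1]
        targets = torch.tensor([[[0.], [1.], [1.], [0.]]])  # shape: [1, 4, 1]

        loss_fn = DiceLoss(reduction="mean")

        # intersection = 1, union = 2 + 2 = 4,
        # dice = (2 * 1 + eps) / (4 + eps) ≈ 0.5, so loss ≈ 0.5
        loss = loss_fn.forward(inputs, targets)
        print(loss)  # tensor(0.5000)
        return


    if __name__ == "__main__":
        main()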
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py CHANGED
@@ -8,9 +8,13 @@ https://github.com/snakers4/silero-vad

 https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
 """
+import os
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn

+from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
 from toolbox.torchaudio.modules.conv_stft import ConvSTFT

@@ -134,6 +138,52 @@ class SileroVadModel(nn.Module):
         return x


+class SileroVadPretrainedModel(SileroVadModel):
+    def __init__(self,
+                 config: SileroVadConfig,
+                 ):
+        super(SileroVadPretrainedModel, self).__init__(
+            config=config,
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = SileroVadConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
 def main():
     config = SileroVadConfig()
     model = SileroVadModel(config=config)
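
Note on the pretrained wrapper: it mirrors the transformers-style API used elsewhere in this repo. `save_pretrained` writes the state dict to `MODEL_FILE` and the config yaml to `CONFIG_FILE` in a directory, and `from_pretrained` rebuilds the config and model, then loads the state dict with `strict=True`. (`MODEL_FILE` is referenced but not imported in the hunk above; it is presumably already defined in the unchanged part of the file.) A round-trip sketch, assuming `SileroVadConfig.from_pretrained` accepts a directory as the checkpoint layout suggests; the directory name is illustrative:

    #!/usr/bin/python3
    # -*- coding: utf-8 -*-
    # Round-trip sketch for SileroVadPretrainedModel.
    from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
    from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadPretrainedModel


    def main():
        config = SileroVadConfig()
        model = SileroVadPretrainedModel(config)

        # writes MODEL_FILE (state dict) and CONFIG_FILE (yaml) into the directory.
        save_directory = "file_dir/silero-vad-demo"
        model.save_pretrained(save_directory)

        # rebuilds config + model, then loads the state dict with strict=True.
        model2 = SileroVadPretrainedModel.from_pretrained(save_directory)
        model2.eval()
        return


    if __name__ == "__main__":
        main()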