HoneyTian committed on
Commit 48776cd · 1 Parent(s): 5eb1356
examples/cnn_vad_by_webrtcvad/run.sh CHANGED
@@ -5,14 +5,16 @@
 bash run.sh --stage 2 --stop_stage 2 --system_version centos \
 --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
 --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
 --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


 END
@@ -30,8 +32,8 @@ final_model_name=final_model_name
 config_file="yaml/config.yaml"
 limit=10

-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

 max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   $verbose && echo "stage 1: prepare data"
   cd "${work_dir}" || exit 1
   python3 step_1_prepare_data.py \
-    --noise_dir "${noise_dir}" \
-    --speech_dir "${speech_dir}" \
+    --noise_patterns "${noise_patterns}" \
+    --speech_patterns "${speech_patterns}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \
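The refactor above replaces single-directory arguments with space-separated glob patterns, so one run can mix several speech corpora (here dns3-speech and nx-speech2). A minimal sketch of the expansion this implies on the Python side, with an illustrative pattern string rather than a repository path; note that splitting on spaces means the patterns themselves must not contain spaces:

    from glob import glob

    # illustrative, space-separated pattern string as passed by run.sh
    speech_patterns = "./speech/dns3-speech/**/*.wav ./speech/nx-speech2/**/*.wav"
    for pattern in speech_patterns.split(" "):
        for filename in glob(pattern, recursive=True):  # recursive=True expands "**"
            print(filename)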
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -1,12 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+from glob import glob
 import json
 import os
 from pathlib import Path
 import random
 import sys
 import time
+from typing import List

 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -19,13 +21,13 @@ from tqdm import tqdm
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--noise_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        "--noise_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
         type=str
     )
     parser.add_argument(
-        "--speech_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+        "--speech_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
         type=str
     )
@@ -46,108 +48,112 @@ def get_args():
     return args


-def target_second_noise_signal_generator(data_dir: str,
+def target_second_noise_signal_generator(filename_patterns: List[str],
                                          duration: int = 4,
                                          sample_rate: int = 8000, max_epoch: int = 20000):
     noise_list = list()
     wait_duration = duration

-    data_dir = Path(data_dir)
     for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            offset = 0.
-            rest_duration = raw_duration
-
-            for _ in range(1000):
-                if rest_duration <= 0:
-                    break
-                if rest_duration <= wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": None,
-                        "duration_": round(rest_duration, 4),
-                    })
-                    wait_duration -= rest_duration
-                    offset = 0
-                    rest_duration = 0
-                elif rest_duration > wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": round(wait_duration, 4),
-                        "duration_": round(wait_duration, 4),
-                    })
-                    offset += wait_duration
-                    rest_duration -= wait_duration
-                    wait_duration = 0
-                else:
-                    raise AssertionError
-
-            if wait_duration <= 0:
-                yield noise_list
-                noise_list = list()
-                wait_duration = duration
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                offset = 0.
+                rest_duration = raw_duration
+
+                for _ in range(1000):
+                    if rest_duration <= 0:
+                        break
+                    if rest_duration <= wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": None,
+                            "duration_": round(rest_duration, 4),
+                        })
+                        wait_duration -= rest_duration
+                        offset = 0
+                        rest_duration = 0
+                    elif rest_duration > wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": round(wait_duration, 4),
+                            "duration_": round(wait_duration, 4),
+                        })
+                        offset += wait_duration
+                        rest_duration -= wait_duration
+                        wait_duration = 0
+                    else:
+                        raise AssertionError
+
+                if wait_duration <= 0:
+                    yield noise_list
+                    noise_list = list()
+                    wait_duration = duration


-def target_second_speech_signal_generator(data_dir: str,
-                                          min_duration: int = 4,
-                                          max_duration: int = 6,
-                                          sample_rate: int = 8000, max_epoch: int = 1):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            if raw_duration < min_duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-
-            if raw_duration < max_duration:
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": 0.,
-                    "duration": round(raw_duration, 4),
-                }
-                yield row
-
-            signal_length = len(signal)
-            win_size = int(max_duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(max_duration, 4),
-                }
-                yield row
+def target_second_speech_signal_generator(filename_patterns: List[str],
+                                          min_duration: int = 4,
+                                          max_duration: int = 6,
+                                          sample_rate: int = 8000, max_epoch: int = 1):
+    for epoch_idx in range(max_epoch):
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                if raw_duration < min_duration:
+                    # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                    continue
+
+                if raw_duration < max_duration:
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": 0.,
+                        "duration": round(raw_duration, 4),
+                    }
+                    yield row
+
+                signal_length = len(signal)
+                win_size = int(max_duration * sample_rate)
+                for begin in range(0, signal_length - win_size, win_size):
+                    if np.sum(signal[begin: begin+win_size]) == 0:
+                        continue
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": round(begin / sample_rate, 4),
+                        "duration": round(max_duration, 4),
+                    }
+                    yield row


 def main():
     args = get_args()

-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
+    noise_patterns = args.noise_patterns
+    noise_patterns = noise_patterns.split(" ")
+    print(f"noise_patterns: {noise_patterns}")
+    speech_patterns = args.speech_patterns
+    speech_patterns = speech_patterns.split(" ")
+    print(f"speech_patterns: {speech_patterns}")

     train_dataset = Path(args.train_dataset)
     valid_dataset = Path(args.valid_dataset)
@@ -155,13 +161,13 @@ def main():
     valid_dataset.parent.mkdir(parents=True, exist_ok=True)

     noise_generator = target_second_noise_signal_generator(
-        noise_dir.as_posix(),
+        noise_patterns,
         duration=args.duration,
         sample_rate=args.target_sample_rate,
         max_epoch=100000,
     )
     speech_generator = target_second_speech_signal_generator(
-        speech_dir.as_posix(),
+        speech_patterns,
         min_duration=args.min_speech_duration,
         max_duration=args.max_speech_duration,
         sample_rate=args.target_sample_rate,
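One behavioral detail of this rewrite: Path(data_dir).glob("**/*.wav") recurses into subdirectories by default, whereas glob.glob() treats "**" as recursive only when recursive=True is passed. A small sketch of the difference, using an assumed ./noise directory:

    from glob import glob
    from pathlib import Path

    wavs_pathlib = list(Path("./noise").glob("**/*.wav"))    # recursive by default
    wavs_glob = glob("./noise/**/*.wav")                     # "**" acts like "*" here
    wavs_glob_r = glob("./noise/**/*.wav", recursive=True)   # matches wavs_pathlib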
examples/cnn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))

 import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

examples/fsmn_vad_by_webrtcvad/run.sh CHANGED
@@ -2,17 +2,19 @@

 : <<'END'

-bash run.sh --stage 2 --stop_stage 2 --system_version centos \
+bash run.sh --stage 1 --stop_stage 1 --system_version centos \
 --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


 END
@@ -30,8 +32,8 @@ final_model_name=final_model_name
 config_file="yaml/config.yaml"
 limit=10

-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

 max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   $verbose && echo "stage 1: prepare data"
   cd "${work_dir}" || exit 1
   python3 step_1_prepare_data.py \
-    --noise_dir "${noise_dir}" \
-    --speech_dir "${speech_dir}" \
+    --noise_patterns "${noise_patterns}" \
+    --speech_patterns "${speech_patterns}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \
examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -1,12 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+from glob import glob
 import json
 import os
 from pathlib import Path
 import random
 import sys
 import time
+from typing import List

 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -19,13 +21,13 @@ from tqdm import tqdm
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--noise_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        "--noise_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
         type=str
     )
     parser.add_argument(
-        "--speech_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+        "--speech_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
         type=str
     )
@@ -46,108 +48,112 @@ def get_args():
     return args


-def target_second_noise_signal_generator(data_dir: str,
+def target_second_noise_signal_generator(filename_patterns: List[str],
                                          duration: int = 4,
                                          sample_rate: int = 8000, max_epoch: int = 20000):
     noise_list = list()
     wait_duration = duration

-    data_dir = Path(data_dir)
     for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            offset = 0.
-            rest_duration = raw_duration
-
-            for _ in range(1000):
-                if rest_duration <= 0:
-                    break
-                if rest_duration <= wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": None,
-                        "duration_": round(rest_duration, 4),
-                    })
-                    wait_duration -= rest_duration
-                    offset = 0
-                    rest_duration = 0
-                elif rest_duration > wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": round(wait_duration, 4),
-                        "duration_": round(wait_duration, 4),
-                    })
-                    offset += wait_duration
-                    rest_duration -= wait_duration
-                    wait_duration = 0
-                else:
-                    raise AssertionError
-
-            if wait_duration <= 0:
-                yield noise_list
-                noise_list = list()
-                wait_duration = duration
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                offset = 0.
+                rest_duration = raw_duration
+
+                for _ in range(1000):
+                    if rest_duration <= 0:
+                        break
+                    if rest_duration <= wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": None,
+                            "duration_": round(rest_duration, 4),
+                        })
+                        wait_duration -= rest_duration
+                        offset = 0
+                        rest_duration = 0
+                    elif rest_duration > wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": round(wait_duration, 4),
+                            "duration_": round(wait_duration, 4),
+                        })
+                        offset += wait_duration
+                        rest_duration -= wait_duration
+                        wait_duration = 0
+                    else:
+                        raise AssertionError
+
+                if wait_duration <= 0:
+                    yield noise_list
+                    noise_list = list()
+                    wait_duration = duration


-def target_second_speech_signal_generator(data_dir: str,
-                                          min_duration: int = 4,
-                                          max_duration: int = 6,
-                                          sample_rate: int = 8000, max_epoch: int = 1):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            if raw_duration < min_duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-
-            if raw_duration < max_duration:
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": 0.,
-                    "duration": round(raw_duration, 4),
-                }
-                yield row
-
-            signal_length = len(signal)
-            win_size = int(max_duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(max_duration, 4),
-                }
-                yield row
+def target_second_speech_signal_generator(filename_patterns: List[str],
+                                          min_duration: int = 4,
+                                          max_duration: int = 6,
+                                          sample_rate: int = 8000, max_epoch: int = 1):
+    for epoch_idx in range(max_epoch):
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                if raw_duration < min_duration:
+                    # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                    continue
+
+                if raw_duration < max_duration:
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": 0.,
+                        "duration": round(raw_duration, 4),
+                    }
+                    yield row
+
+                signal_length = len(signal)
+                win_size = int(max_duration * sample_rate)
+                for begin in range(0, signal_length - win_size, win_size):
+                    if np.sum(signal[begin: begin+win_size]) == 0:
+                        continue
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": round(begin / sample_rate, 4),
+                        "duration": round(max_duration, 4),
+                    }
+                    yield row


 def main():
     args = get_args()

-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
+    noise_patterns = args.noise_patterns
+    noise_patterns = noise_patterns.split(" ")
+    print(f"noise_patterns: {noise_patterns}")
+    speech_patterns = args.speech_patterns
+    speech_patterns = speech_patterns.split(" ")
+    print(f"speech_patterns: {speech_patterns}")

     train_dataset = Path(args.train_dataset)
     valid_dataset = Path(args.valid_dataset)
@@ -155,13 +161,13 @@ def main():
     valid_dataset.parent.mkdir(parents=True, exist_ok=True)

     noise_generator = target_second_noise_signal_generator(
-        noise_dir.as_posix(),
+        noise_patterns,
         duration=args.duration,
         sample_rate=args.target_sample_rate,
         max_epoch=100000,
     )
     speech_generator = target_second_speech_signal_generator(
-        speech_dir.as_posix(),
+        speech_patterns,
         min_duration=args.min_speech_duration,
         max_duration=args.max_speech_duration,
         sample_rate=args.target_sample_rate,
@@ -210,7 +216,7 @@ def main():
                 "random1": random1,
             }
             row = json.dumps(row, ensure_ascii=False)
-            if random2 < (1 / 300 / 1):
+            if random2 < (2 / 300):
                 fvalid.write(f"{row}\n")
             else:
                 ftrain.write(f"{row}\n")
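The validation-split threshold moves from 1/300 (about 0.33%) to 2/300 (about 0.67%), doubling the expected share of rows written to the validation file. A quick illustrative calculation (the row count is made up):

    expected_valid_rows_old = 1_000_000 * (1 / 300 / 1)  # ~3333 rows
    expected_valid_rows_new = 1_000_000 * (2 / 300)      # ~6667 rows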
examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED
@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import sys
+from typing import List, Tuple

 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -42,6 +43,54 @@ def get_args():
     return args


+def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+    non_silent_intervals = librosa.effects.split(
+        waveform,
+        top_db=40,  # silence threshold (dB)
+        frame_length=512,  # analysis frame length
+        hop_length=128  # hop length
+    )
+
+    # non-silent intervals as time ranges (seconds)
+    result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+    return result
+
+
+def get_intersection(non_silence: list[tuple[float, float]],
+                     speech: list[tuple[float, float]]) -> list[tuple[float, float]]:
+    """
+    Compute the intersection of speech segments and non-silent segments.
+    :param non_silence: non-silent segments, as [(start1, end1), ...]
+    :param speech: detected speech segments, as [(start2, end2), ...]
+    :return: intersection segments, as [(start, end), ...]
+    """
+    # sort by start time (can be skipped if the inputs are already sorted)
+    non_silence = sorted(non_silence, key=lambda x: x[0])
+    speech = sorted(speech, key=lambda x: x[0])
+
+    result = []
+    i = j = 0
+
+    while i < len(non_silence) and j < len(speech):
+        ns_start, ns_end = non_silence[i]
+        sp_start, sp_end = speech[j]
+
+        # compute the overlapping interval
+        overlap_start = max(ns_start, sp_start)
+        overlap_end = min(ns_end, sp_end)
+
+        if overlap_start < overlap_end:
+            result.append((overlap_start, overlap_end))
+
+        # advance whichever interval ends first
+        if ns_end < sp_end:
+            i += 1  # the non-silent segment ends first
+        else:
+            j += 1  # the speech segment ends first
+
+    return result
+
+
 def main():
     args = get_args()

@@ -68,8 +117,8 @@ def main():
         end_ring_rate=0.1,
         frame_size_ms=30,
         frame_step_ms=30,
-        padding_length_ms=90,
-        max_silence_length_ms=100,
+        padding_length_ms=30,
+        max_silence_length_ms=0,
         max_speech_length_s=100,
         min_speech_length_s=0.1,
         sample_rate=args.expected_sample_rate,
@@ -114,6 +163,9 @@ def main():
             )
             waveform = np.array(waveform * (1 << 15), dtype=np.int16)

+            # non_silence_segments
+            non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
             # vad
             vad_segments = list()
             segments = w_vad.vad(waveform)
@@ -122,6 +174,7 @@ def main():
             vad_segments += segments
             w_vad.reset()

+            vad_segments = get_intersection(non_silence_segments, vad_segments)
             row["vad_segments"] = vad_segments

             row = json.dumps(row, ensure_ascii=False)
@@ -168,6 +221,9 @@ def main():
             )
             waveform = np.array(waveform * (1 << 15), dtype=np.int16)

+            # non_silence_segments
+            non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
             # vad
             vad_segments = list()
             segments = w_vad.vad(waveform)
@@ -176,6 +232,7 @@ def main():
             vad_segments += segments
             w_vad.reset()

+            vad_segments = get_intersection(non_silence_segments, vad_segments)
             row["vad_segments"] = vad_segments

             row = json.dumps(row, ensure_ascii=False)
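A hedged usage sketch of the two new helpers: librosa's energy-based split proposes non-silent regions, webrtcvad proposes speech regions, and their intersection keeps only speech that is also energetically non-silent. The waveform and the webrtcvad output below are illustrative:

    import numpy as np

    waveform = np.random.randn(8000 * 4).astype(np.float32)  # 4 s of fake audio
    non_silence_segments = get_non_silence_segments(waveform, sample_rate=8000)
    vad_segments = [(0.5, 1.8), (2.4, 3.9)]  # e.g. output of a webrtcvad pass
    vad_segments = get_intersection(non_silence_segments, vad_segments)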
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))

 import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

@@ -38,7 +36,7 @@ def get_args():
     parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)

     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-    parser.add_argument("--patience", default=30, type=int)
+    parser.add_argument("--patience", default=10, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

     parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
@@ -74,22 +72,28 @@ class CollateFunction(object):

     def __call__(self, batch: List[dict]):
         noisy_audios = list()
+        clean_audios = list()
         batch_vad_segments = list()

         for sample in batch:
             noisy_wave: torch.Tensor = sample["noisy_wave"]
+            clean_wave: torch.Tensor = sample["clean_wave"]
             vad_segments: List[Tuple[float, float]] = sample["vad_segments"]

             noisy_audios.append(noisy_wave)
+            clean_audios.append(clean_wave)
             batch_vad_segments.append(vad_segments)

         noisy_audios = torch.stack(noisy_audios)
+        clean_audios = torch.stack(clean_audios)

         # assert
         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
             raise AssertionError("nan or inf in noisy_audios")
+        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+            raise AssertionError("nan or inf in clean_audios")

-        return noisy_audios, batch_vad_segments
+        return noisy_audios, clean_audios, batch_vad_segments


 collate_fn = CollateFunction()
@@ -214,6 +218,7 @@ def main():
     average_loss = 1000000000
     average_bce_loss = 1000000000
     average_dice_loss = 1000000000
+    average_lsnr_loss = 1000000000

     accuracy = -1
     f1 = -1
@@ -242,6 +247,7 @@ def main():
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.

         progress_bar_train = tqdm(
@@ -249,19 +255,22 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            noisy_audios, batch_vad_segments = train_batch
+            noisy_audios, clean_audios, batch_vad_segments = train_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]

-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)

             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -278,11 +287,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1

             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -297,6 +308,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -316,6 +328,7 @@ def main():
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.

         progress_bar_train.close()
@@ -323,19 +336,22 @@ def main():
             desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
         )
         for eval_batch in valid_data_loader:
-            noisy_audios, batch_vad_segments = eval_batch
+            noisy_audios, clean_audios, batch_vad_segments = eval_batch
             noisy_audios: torch.Tensor = noisy_audios.to(device)
+            clean_audios: torch.Tensor = clean_audios.to(device)
             # noisy_audios shape: [b, num_samples]
             num_samples = noisy_audios.shape[-1]

-            logits, probs = model.forward(noisy_audios)
+            logits, probs, lsnr = model.forward(noisy_audios)
+            lsnr = torch.squeeze(lsnr, dim=-1)

             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

             bce_loss = bce_loss_fn.forward(probs, targets)
             dice_loss = dice_loss_fn.forward(probs, targets)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

-            loss = 1.0 * bce_loss + 1.0 * dice_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -346,11 +362,13 @@ def main():
             total_loss += loss.item()
             total_bce_loss += bce_loss.item()
             total_dice_loss += dice_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1

             average_loss = round(total_loss / total_batches, 4)
             average_bce_loss = round(total_bce_loss / total_batches, 4)
             average_dice_loss = round(total_dice_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
@@ -365,6 +383,7 @@ def main():
                 "loss": average_loss,
                 "bce_loss": average_bce_loss,
                 "dice_loss": average_dice_loss,
+                "lsnr_loss": average_lsnr_loss,
                 "accuracy": accuracy,
                 "f1": f1,
                 "precision": precision,
@@ -378,6 +397,7 @@ def main():
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.

         progress_bar_eval.close()
@@ -419,8 +439,12 @@ def main():
             "loss": average_loss,
             "bce_loss": average_bce_loss,
             "dice_loss": average_dice_loss,
+            "lsnr_loss": average_lsnr_loss,

             "accuracy": accuracy,
+            "f1": f1,
+            "precision": precision,
+            "recall": recall,
         }
         metrics_filename = save_dir / "metrics_epoch.json"
         with open(metrics_filename, "w", encoding="utf-8") as f:
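The training and evaluation losses now add an LSNR (local signal-to-noise ratio) regression term on top of BCE and Dice, down-weighted by 0.03. The small weight presumably compensates for the larger magnitude of an SNR-scale loss; an illustrative calculation with made-up loss values:

    bce_loss, dice_loss, lsnr_loss = 0.40, 0.25, 12.0  # made-up magnitudes
    loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
    # 0.40 + 0.25 + 0.36 -> the lsnr term contributes on the same scale as bce/dice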
examples/fsmn_vad_by_webrtcvad/yaml/config.yaml CHANGED
@@ -18,9 +18,13 @@ fsmn_basic_block_rorder: 0
 fsmn_basic_block_lstride: 1
 fsmn_basic_block_rstride: 0
 fsmn_output_affine_size: 140
-fsmn_output_size: 1
+fsmn_output_size: 2

-use_softmax: false
+# lsnr
+n_frame: 3
+min_local_snr_db: -15
+max_local_snr_db: 30
+norm_tau: 1.

 # data
 min_snr_db: -10
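The new lsnr block bounds the local SNR targets to [-15, 30] dB over 3-frame windows. A minimal sketch of the clamping this configuration suggests; clamp_lsnr is a hypothetical helper, not a function from the repository:

    import torch

    def clamp_lsnr(lsnr_db: torch.Tensor, min_db: float = -15.0, max_db: float = 30.0) -> torch.Tensor:
        # keep regression targets inside the configured dynamic range
        return torch.clamp(lsnr_db, min=min_db, max=max_db)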
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -5,14 +5,16 @@
 bash run.sh --stage 2 --stop_stage 2 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

 bash run.sh --stage 3 --stop_stage 3 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+--speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


 END
@@ -30,8 +32,8 @@ final_model_name=final_model_name
 config_file="yaml/config.yaml"
 limit=10

-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

 max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   $verbose && echo "stage 1: prepare data"
   cd "${work_dir}" || exit 1
   python3 step_1_prepare_data.py \
-    --noise_dir "${noise_dir}" \
-    --speech_dir "${speech_dir}" \
+    --noise_patterns "${noise_patterns}" \
+    --speech_patterns "${speech_patterns}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -1,12 +1,14 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+from glob import glob
 import json
 import os
 from pathlib import Path
 import random
 import sys
 import time
+from typing import List

 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -19,13 +21,13 @@ from tqdm import tqdm
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--noise_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        "--noise_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
         type=str
     )
     parser.add_argument(
-        "--speech_dir",
-        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+        "--speech_patterns",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
         type=str
    )
@@ -46,108 +48,112 @@ def get_args():
     return args


-def target_second_noise_signal_generator(data_dir: str,
+def target_second_noise_signal_generator(filename_patterns: List[str],
                                          duration: int = 4,
                                          sample_rate: int = 8000, max_epoch: int = 20000):
     noise_list = list()
     wait_duration = duration

-    data_dir = Path(data_dir)
     for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            offset = 0.
-            rest_duration = raw_duration
-
-            for _ in range(1000):
-                if rest_duration <= 0:
-                    break
-                if rest_duration <= wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": None,
-                        "duration_": round(rest_duration, 4),
-                    })
-                    wait_duration -= rest_duration
-                    offset = 0
-                    rest_duration = 0
-                elif rest_duration > wait_duration:
-                    noise_list.append({
-                        "epoch_idx": epoch_idx,
-                        "filename": filename.as_posix(),
-                        "raw_duration": round(raw_duration, 4),
-                        "offset": round(offset, 4),
-                        "duration": round(wait_duration, 4),
-                        "duration_": round(wait_duration, 4),
-                    })
-                    offset += wait_duration
-                    rest_duration -= wait_duration
-                    wait_duration = 0
-                else:
-                    raise AssertionError
-
-            if wait_duration <= 0:
-                yield noise_list
-                noise_list = list()
-                wait_duration = duration
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                offset = 0.
+                rest_duration = raw_duration
+
+                for _ in range(1000):
+                    if rest_duration <= 0:
+                        break
+                    if rest_duration <= wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": None,
+                            "duration_": round(rest_duration, 4),
+                        })
+                        wait_duration -= rest_duration
+                        offset = 0
+                        rest_duration = 0
+                    elif rest_duration > wait_duration:
+                        noise_list.append({
+                            "epoch_idx": epoch_idx,
+                            "filename": filename,
+                            "raw_duration": round(raw_duration, 4),
+                            "offset": round(offset, 4),
+                            "duration": round(wait_duration, 4),
+                            "duration_": round(wait_duration, 4),
+                        })
+                        offset += wait_duration
+                        rest_duration -= wait_duration
+                        wait_duration = 0
+                    else:
+                        raise AssertionError
+
+                if wait_duration <= 0:
+                    yield noise_list
+                    noise_list = list()
+                    wait_duration = duration


-def target_second_speech_signal_generator(data_dir: str,
-                                          min_duration: int = 4,
-                                          max_duration: int = 6,
-                                          sample_rate: int = 8000, max_epoch: int = 1):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            if raw_duration < min_duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-
-            if raw_duration < max_duration:
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": 0.,
-                    "duration": round(raw_duration, 4),
-                }
-                yield row
-
-            signal_length = len(signal)
-            win_size = int(max_duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(max_duration, 4),
-                }
-                yield row
+def target_second_speech_signal_generator(filename_patterns: List[str],
+                                          min_duration: int = 4,
+                                          max_duration: int = 6,
+                                          sample_rate: int = 8000, max_epoch: int = 1):
+    for epoch_idx in range(max_epoch):
+        for filename_pattern in filename_patterns:
+            for filename in glob(filename_pattern):
+                signal, _ = librosa.load(filename, sr=sample_rate)
+                raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                if signal.ndim != 1:
+                    raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                if raw_duration < min_duration:
+                    # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                    continue
+
+                if raw_duration < max_duration:
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": 0.,
+                        "duration": round(raw_duration, 4),
+                    }
+                    yield row
+
+                signal_length = len(signal)
+                win_size = int(max_duration * sample_rate)
+                for begin in range(0, signal_length - win_size, win_size):
+                    if np.sum(signal[begin: begin+win_size]) == 0:
+                        continue
+                    row = {
+                        "epoch_idx": epoch_idx,
+                        "filename": filename,
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": round(begin / sample_rate, 4),
+                        "duration": round(max_duration, 4),
+                    }
+                    yield row


 def main():
     args = get_args()

-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
+    noise_patterns = args.noise_patterns
+    noise_patterns = noise_patterns.split(" ")
+    print(f"noise_patterns: {noise_patterns}")
+    speech_patterns = args.speech_patterns
+    speech_patterns = speech_patterns.split(" ")
+    print(f"speech_patterns: {speech_patterns}")

     train_dataset = Path(args.train_dataset)
     valid_dataset = Path(args.valid_dataset)
@@ -155,13 +161,13 @@ def main():
     valid_dataset.parent.mkdir(parents=True, exist_ok=True)

     noise_generator = target_second_noise_signal_generator(
-        noise_dir.as_posix(),
+        noise_patterns,
         duration=args.duration,
         sample_rate=args.target_sample_rate,
         max_epoch=100000,
     )
     speech_generator = target_second_speech_signal_generator(
-        speech_dir.as_posix(),
+        speech_patterns,
         min_duration=args.min_speech_duration,
         max_duration=args.max_speech_duration,
         sample_rate=args.target_sample_rate,
@@ -210,7 +216,7 @@ def main():
                 "random1": random1,
             }
             row = json.dumps(row, ensure_ascii=False)
-            if random2 < (1 / 300 / 1):
+            if random2 < (2 / 300):
                 fvalid.write(f"{row}\n")
             else:
                 ftrain.write(f"{row}\n")
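For context on the generator refactored here: it packs consecutive noise chunks until they cover exactly `duration` seconds, then yields the list, so each yielded list backs one training example. A hedged usage sketch (the pattern list is illustrative):

    noise_generator = target_second_noise_signal_generator(
        ["./noise/**/*.wav"],  # illustrative pattern list
        duration=4,
        sample_rate=8000,
        max_epoch=1,
    )
    noise_chunks = next(noise_generator)
    # "duration_" records the seconds consumed from each file, so one yielded
    # list always sums to the requested window
    total = sum(chunk["duration_"] for chunk in noise_chunks)
    assert abs(total - 4.0) < 1e-2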
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))
17
 
18
  import numpy as np
19
  import torch
20
- import torch.nn as nn
21
- from torch.nn import functional as F
22
  from torch.utils.data.dataloader import DataLoader
23
  from tqdm import tqdm
24
 
@@ -38,7 +36,7 @@ def get_args():
38
  parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
39
 
40
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
41
- parser.add_argument("--patience", default=30, type=int)
42
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
43
 
44
  parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
@@ -74,22 +72,28 @@ class CollateFunction(object):
74
 
75
  def __call__(self, batch: List[dict]):
76
  noisy_audios = list()
 
77
  batch_vad_segments = list()
78
 
79
  for sample in batch:
80
  noisy_wave: torch.Tensor = sample["noisy_wave"]
 
81
  vad_segments: List[Tuple[float, float]] = sample["vad_segments"]
82
 
83
  noisy_audios.append(noisy_wave)
 
84
  batch_vad_segments.append(vad_segments)
85
 
86
  noisy_audios = torch.stack(noisy_audios)
 
87
 
88
  # assert
89
  if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
90
  raise AssertionError("nan or inf in noisy_audios")
 
 
91
 
92
- return noisy_audios, batch_vad_segments
93
 
94
 
95
  collate_fn = CollateFunction()
@@ -214,6 +218,7 @@ def main():
214
  average_loss = 1000000000
215
  average_bce_loss = 1000000000
216
  average_dice_loss = 1000000000
 
217
 
218
  accuracy = -1
219
  f1 = -1
@@ -242,6 +247,7 @@ def main():
242
  total_loss = 0.
243
  total_bce_loss = 0.
244
  total_dice_loss = 0.
 
245
  total_batches = 0.
246
 
247
  progress_bar_train = tqdm(
 
17
 
18
  import numpy as np
19
  import torch
 
 
20
  from torch.utils.data.dataloader import DataLoader
21
  from tqdm import tqdm
22
 
 
36
  parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
37
 
38
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
39
+ parser.add_argument("--patience", default=10, type=int)
40
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
41
 
42
  parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
 
72
 
73
  def __call__(self, batch: List[dict]):
74
  noisy_audios = list()
75
+ clean_audios = list()
76
  batch_vad_segments = list()
77
 
78
  for sample in batch:
79
  noisy_wave: torch.Tensor = sample["noisy_wave"]
80
+ clean_wave: torch.Tensor = sample["clean_wave"]
81
  vad_segments: List[Tuple[float, float]] = sample["vad_segments"]
82
 
83
  noisy_audios.append(noisy_wave)
84
+ clean_audios.append(clean_wave)
85
  batch_vad_segments.append(vad_segments)
86
 
87
  noisy_audios = torch.stack(noisy_audios)
88
+ clean_audios = torch.stack(clean_audios)
89
 
90
  # assert
91
  if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
92
  raise AssertionError("nan or inf in noisy_audios")
93
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
94
+ raise AssertionError("nan or inf in clean_audios")
95
 
96
+ return noisy_audios, clean_audios, batch_vad_segments
97
 
98
 
99
  collate_fn = CollateFunction()
 
218
  average_loss = 1000000000
219
  average_bce_loss = 1000000000
220
  average_dice_loss = 1000000000
221
+ average_lsnr_loss = 1000000000
222
 
223
  accuracy = -1
224
  f1 = -1
 
247
  total_loss = 0.
248
  total_bce_loss = 0.
249
  total_dice_loss = 0.
250
+ total_lsnr_loss = 0.
251
  total_batches = 0.
252
 
253
  progress_bar_train = tqdm(
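
Note: CollateFunction now returns a three-tuple (noisy_audios, clean_audios, batch_vad_segments), and the loop accumulates a new total_lsnr_loss. A hedged sketch of how a training step could combine the losses; the make_frame_targets helper, its frame_seconds default, and the unit loss weights are assumptions, not taken from this repo:

import torch


def make_frame_targets(batch_vad_segments, shape, frame_seconds: float = 0.01):
    # hypothetical helper: mark frames that fall inside a (begin, end) speech segment
    b, t, _ = shape
    targets = torch.zeros(size=(b, t, 1), dtype=torch.float32)
    for i, segments in enumerate(batch_vad_segments):
        for begin, end in segments:
            begin_f = max(int(begin / frame_seconds), 0)
            end_f = min(int(end / frame_seconds), t - 1)
            targets[i, begin_f: end_f + 1, 0] = 1.0
    return targets


def train_step(model, batch, bce_loss_fn, dice_loss_fn, lsnr_weight: float = 1.0):
    # batch layout matches the updated CollateFunction
    noisy_audios, clean_audios, batch_vad_segments = batch

    logits, probs, lsnr = model.forward(noisy_audios)

    targets = make_frame_targets(batch_vad_segments, probs.shape)
    bce_loss = bce_loss_fn(probs, targets)
    dice_loss = dice_loss_fn(probs, targets)
    lsnr_loss = model.lsnr_loss_fn(lsnr, clean=clean_audios, noisy=noisy_audios)

    loss = bce_loss + dice_loss + lsnr_weight * lsnr_loss
    return loss, bce_loss, dice_loss, lsnr_loss
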
examples/silero_vad_by_webrtcvad/yaml/config.yaml CHANGED
@@ -11,6 +11,12 @@ win_type: hann
11
  in_channels: 64
12
  hidden_size: 128
13
 
 
14
  # data
15
  min_snr_db: -10
16
  max_snr_db: 20
 
11
  in_channels: 64
12
  hidden_size: 128
13
 
14
+ # lsnr
15
+ n_frame: 3
16
+ min_local_snr_db: -15
17
+ max_local_snr_db: 30
18
+ norm_tau: 1.
19
+
20
  # data
21
  min_snr_db: -10
22
  max_snr_db: 20
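
Note: the new lsnr block bounds the local-SNR head. The network emits a sigmoid in [0, 1] that is affinely mapped into [min_local_snr_db, max_local_snr_db], exactly as lsnr_scale and lsnr_offset are used in the modeling files below; a one-screen sketch of that mapping:

import torch

min_local_snr_db, max_local_snr_db = -15.0, 30.0
lsnr_scale = max_local_snr_db - min_local_snr_db  # 45 dB range
lsnr_offset = min_local_snr_db                    # lower bound

snr_logits = torch.randn(size=(1, 10, 1))
lsnr = torch.sigmoid(snr_logits) * lsnr_scale + lsnr_offset
assert float(lsnr.min()) >= min_local_snr_db
assert float(lsnr.max()) <= max_local_snr_db
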
toolbox/torchaudio/models/vad/cnn_vad/inference_cnn_vad.py ADDED
@@ -0,0 +1,138 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ from pathlib import Path
6
+ import shutil
7
+ import tempfile
8
+ from typing import List
9
+ import zipfile
10
+
11
+ from scipy.io import wavfile
12
+ import librosa
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+
17
+ torch.set_num_threads(1)
18
+
19
+ from project_settings import project_path
20
+ from toolbox.torchaudio.models.vad.cnn_vad.configuration_cnn_vad import CNNVadConfig
21
+ from toolbox.torchaudio.models.vad.cnn_vad.modeling_cnn_vad import CNNVadPretrainedModel, MODEL_FILE
22
+ from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
23
+
24
+
25
+ logger = logging.getLogger("toolbox")
26
+
27
+
28
+ class InferenceCNNVad(object):
29
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
30
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
31
+ self.device = torch.device(device)
32
+
33
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
34
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
35
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
36
+
37
+ self.config = config
38
+ self.model = model
39
+ self.model.to(device)
40
+ self.model.eval()
41
+
42
+ def load_models(self, model_path: str):
43
+ model_path = Path(model_path)
44
+ if model_path.name.endswith(".zip"):
45
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
46
+ out_root = Path(tempfile.gettempdir()) / "cc_vad"
47
+ out_root.mkdir(parents=True, exist_ok=True)
48
+ f_zip.extractall(path=out_root)
49
+ model_path = out_root / model_path.stem
50
+
51
+ config = CNNVadConfig.from_pretrained(
52
+ pretrained_model_name_or_path=model_path.as_posix(),
53
+ )
54
+ model = CNNVadPretrainedModel.from_pretrained(
55
+ pretrained_model_name_or_path=model_path.as_posix(),
56
+ )
57
+ model.to(self.device)
58
+ model.eval()
59
+
60
+ # remove only the temporary extraction dir, never a user-supplied directory
+ if Path(tempfile.gettempdir()) in model_path.parents:
+ shutil.rmtree(model_path)
61
+ return config, model
62
+
63
+ def infer(self, signal: np.ndarray) -> List[float]:
64
+ # signal shape: [num_samples,], value between -1 and 1.
65
+
66
+ inputs = torch.tensor(signal, dtype=torch.float32)
67
+ inputs = torch.unsqueeze(inputs, dim=0)
68
+ # inputs shape: [1, num_samples,]
69
+
70
+ with torch.no_grad():
71
+ logits, probs, lsnr = self.model.forward(inputs)
72
+
73
+ # probs shape: [b, t, 1]
74
+ probs = torch.squeeze(probs, dim=-1)
75
+ # probs shape: [b, t]
76
+
77
+ probs = probs.numpy()
78
+ probs = probs[0]
79
+ probs = probs.tolist()
80
+ return probs
81
+
82
+ def post_process(self, probs: List[float]):
83
+ return
84
+
85
+
86
+ def get_args():
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument(
89
+ "--wav_file",
90
+ # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
91
+ # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
92
+ # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
93
+ # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
94
+ # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
95
+ # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
96
+ # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
97
+ # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
98
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
99
+ type=str,
100
+ )
101
+ args = parser.parse_args()
102
+ return args
103
+
104
+
105
+ SAMPLE_RATE = 8000
106
+
107
+
108
+ def main():
109
+ args = get_args()
110
+
111
+ sample_rate, signal = wavfile.read(args.wav_file)
112
+ if SAMPLE_RATE != sample_rate:
113
+ raise AssertionError
114
+ signal = signal / (1 << 15)
115
+
116
+ infer = InferenceCNNVad(
117
+ pretrained_model_path_or_zip_file=(project_path / "trained_models/cnn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
118
+ # pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
119
+ )
120
+ frame_step = infer.model.hop_size
121
+
122
+ speech_probs = infer.infer(signal)
123
+
124
+ # print(speech_probs)
125
+
126
+ speech_probs = process_speech_probs(
127
+ signal=signal,
128
+ speech_probs=speech_probs,
129
+ frame_step=frame_step,
130
+ )
131
+
132
+ # plot
133
+ make_visualization(signal, speech_probs, SAMPLE_RATE)
134
+ return
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()
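
Note: main() above converts frame-level probabilities to per-sample values via process_speech_probs before plotting. A plausible sketch of that expansion, assuming it simply repeats each frame probability frame_step (hop_size) times and aligns the result to the waveform; the repo's actual utility may differ:

import numpy as np


def process_speech_probs(signal: np.ndarray, speech_probs: list, frame_step: int) -> np.ndarray:
    # repeat each frame-level probability frame_step times
    probs = np.repeat(np.asarray(speech_probs, dtype=np.float32), frame_step)
    # pad with the last value or trim so the curve matches the waveform length
    num_samples = len(signal)
    if len(probs) < num_samples:
        probs = np.pad(probs, (0, num_samples - len(probs)), mode="edge")
    return probs[:num_samples]
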
toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py CHANGED
@@ -23,9 +23,12 @@ class FSMNVadConfig(PretrainedConfig):
23
  fsmn_basic_block_lstride: int = 1,
24
  fsmn_basic_block_rstride: int = 0,
25
  fsmn_output_affine_size: int = 140,
26
- fsmn_output_size: int = 1,
27
 
28
- use_softmax: bool = False,
 
 
 
29
 
30
  min_snr_db: float = -10,
31
  max_snr_db: float = 20,
@@ -65,7 +68,11 @@ class FSMNVadConfig(PretrainedConfig):
65
  self.fsmn_output_affine_size = fsmn_output_affine_size
66
  self.fsmn_output_size = fsmn_output_size
67
 
68
- self.use_softmax = use_softmax
 
 
 
 
69
 
70
  # data snr
71
  self.min_snr_db = min_snr_db
 
23
  fsmn_basic_block_lstride: int = 1,
24
  fsmn_basic_block_rstride: int = 0,
25
  fsmn_output_affine_size: int = 140,
26
+ fsmn_output_size: int = 2,
27
 
28
+ n_frame: int = 3,
29
+ min_local_snr_db: float = -15,
30
+ max_local_snr_db: float = 30,
31
+ norm_tau: float = 1.,
32
 
33
  min_snr_db: float = -10,
34
  max_snr_db: float = 20,
 
68
  self.fsmn_output_affine_size = fsmn_output_affine_size
69
  self.fsmn_output_size = fsmn_output_size
70
 
71
+ # lsnr
72
+ self.n_frame = n_frame
73
+ self.min_local_snr_db = min_local_snr_db
74
+ self.max_local_snr_db = max_local_snr_db
75
+ self.norm_tau = norm_tau
76
 
77
  # data snr
78
  self.min_snr_db = min_snr_db
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py CHANGED
@@ -15,48 +15,111 @@ from typing import Optional, Union
15
 
16
  import torch
17
  import torch.nn as nn
 
18
 
19
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
20
  from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
21
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT
22
  from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
 
23
 
24
 
25
  MODEL_FILE = "model.pt"
26
 
27
 
28
  class FSMNVadModel(nn.Module):
29
- def __init__(self, config: FSMNVadConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  super(FSMNVadModel, self).__init__()
31
- self.config = config
32
  self.eps = 1e-12
33
 
34
  self.stft = ConvSTFT(
35
- nfft=config.nfft,
36
- win_size=config.win_size,
37
- hop_size=config.hop_size,
38
- win_type=config.win_type,
39
  power=1,
40
  requires_grad=False
41
  )
42
 
43
  self.fsmn_encoder = FSMN(
44
- input_size=config.fsmn_input_size,
45
- input_affine_size=config.fsmn_input_affine_size,
46
- hidden_size=config.fsmn_hidden_size,
47
- basic_block_layers=config.fsmn_basic_block_layers,
48
- basic_block_hidden_size=config.fsmn_basic_block_hidden_size,
49
- basic_block_lorder=config.fsmn_basic_block_lorder,
50
- basic_block_rorder=config.fsmn_basic_block_rorder,
51
- basic_block_lstride=config.fsmn_basic_block_lstride,
52
- basic_block_rstride=config.fsmn_basic_block_rstride,
53
- output_affine_size=config.fsmn_output_affine_size,
54
- output_size=config.fsmn_output_size,
 
55
  )
56
 
57
- self.use_softmax = config.use_softmax
58
- self.sigmoid = nn.Sigmoid()
59
- self.softmax = nn.Softmax()
60
 
61
  def forward(self, signal: torch.Tensor):
62
  if signal.dim() == 2:
@@ -71,14 +134,49 @@ class FSMNVadModel(nn.Module):
71
  # x shape: [b, t, f]
72
 
73
  logits, _ = self.fsmn_encoder.forward(x)
 
74
 
75
- if self.use_softmax:
76
- probs = self.softmax.forward(logits)
77
- # probs shape: [b, t, n]
78
- else:
79
- probs = self.sigmoid.forward(logits)
80
- # probs shape: [b, t, 1]
81
- return logits, probs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  class FSMNVadPretrainedModel(FSMNVadModel):
@@ -86,8 +184,26 @@ class FSMNVadPretrainedModel(FSMNVadModel):
86
  config: FSMNVadConfig,
87
  ):
88
  super(FSMNVadPretrainedModel, self).__init__(
89
- config=config,
90
  )
 
91
 
92
  @classmethod
93
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -133,10 +249,11 @@ def main():
133
 
134
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
135
 
136
- logits, probs = model.forward(noisy)
137
- print(f"probs: {probs}")
138
- print(f"probs.shape: {logits.shape}")
139
- print(f"use_softmax: {config.use_softmax}")
 
140
  return
141
 
142
 
 
15
 
16
  import torch
17
  import torch.nn as nn
18
+ from torch.nn import functional as F
19
 
20
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
21
  from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
22
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT
23
  from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
24
+ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
25
 
26
 
27
  MODEL_FILE = "model.pt"
28
 
29
 
30
  class FSMNVadModel(nn.Module):
31
+ def __init__(self,
32
+ sample_rate: int,
33
+ nfft: int,
34
+ win_size: int,
35
+ hop_size: int,
36
+ win_type: str,
37
+
38
+ fsmn_input_size: int,
39
+ fsmn_input_affine_size: int,
40
+ fsmn_hidden_size: int,
41
+ fsmn_basic_block_layers: int,
42
+ fsmn_basic_block_hidden_size: int,
43
+ fsmn_basic_block_lorder: int,
44
+ fsmn_basic_block_rorder: int,
45
+ fsmn_basic_block_lstride: int,
46
+ fsmn_basic_block_rstride: int,
47
+ fsmn_output_affine_size: int,
48
+
49
+ n_frame: int,
50
+ min_local_snr_db: float,
51
+ max_local_snr_db: float,
52
+ ):
53
  super(FSMNVadModel, self).__init__()
54
+ self.sample_rate = sample_rate
55
+ self.nfft = nfft
56
+ self.win_size = win_size
57
+ self.hop_size = hop_size
58
+ self.win_type = win_type
59
+
60
+ self.fsmn_input_size = fsmn_input_size
61
+ self.fsmn_input_affine_size = fsmn_input_affine_size
62
+ self.fsmn_hidden_size = fsmn_hidden_size
63
+ self.fsmn_basic_block_layers = fsmn_basic_block_layers
64
+ self.fsmn_basic_block_hidden_size = fsmn_basic_block_hidden_size
65
+ self.fsmn_basic_block_lorder = fsmn_basic_block_lorder
66
+ self.fsmn_basic_block_rorder = fsmn_basic_block_rorder
67
+ self.fsmn_basic_block_lstride = fsmn_basic_block_lstride
68
+ self.fsmn_basic_block_rstride = fsmn_basic_block_rstride
69
+ self.fsmn_output_affine_size = fsmn_output_affine_size
70
+
71
+ self.n_frame = n_frame
72
+ self.min_local_snr_db = min_local_snr_db
73
+ self.max_local_snr_db = max_local_snr_db
74
+
75
  self.eps = 1e-12
76
 
77
  self.stft = ConvSTFT(
78
+ nfft=self.nfft,
79
+ win_size=self.win_size,
80
+ hop_size=self.hop_size,
81
+ win_type=self.win_type,
82
  power=1,
83
  requires_grad=False
84
  )
85
+ self.complex_stft = ConvSTFT(
86
+ nfft=self.nfft,
87
+ win_size=self.win_size,
88
+ hop_size=self.hop_size,
89
+ win_type=self.win_type,
90
+ power=None,
91
+ requires_grad=False
92
+ )
93
 
94
  self.fsmn_encoder = FSMN(
95
+ input_size=self.fsmn_input_size,
96
+ input_affine_size=self.fsmn_input_affine_size,
97
+ hidden_size=self.fsmn_hidden_size,
98
+ basic_block_layers=self.fsmn_basic_block_layers,
99
+ basic_block_hidden_size=self.fsmn_basic_block_hidden_size,
100
+ basic_block_lorder=self.fsmn_basic_block_lorder,
101
+ basic_block_rorder=self.fsmn_basic_block_rorder,
102
+ basic_block_lstride=self.fsmn_basic_block_lstride,
103
+ basic_block_rstride=self.fsmn_basic_block_rstride,
104
+ output_affine_size=self.fsmn_output_affine_size,
105
+ output_size=2,
106
+ # output_size=self.fsmn_output_size,
107
  )
108
 
109
+ # lsnr
110
+ self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
111
+ self.lsnr_offset = self.min_local_snr_db
112
+
113
+ self.lsnr_fn = LocalSnrTarget(
114
+ sample_rate=self.sample_rate,
115
+ nfft=self.nfft,
116
+ win_size=self.win_size,
117
+ hop_size=self.hop_size,
118
+ n_frame=self.n_frame,
119
+ min_local_snr=self.min_local_snr_db,
120
+ max_local_snr=self.max_local_snr_db,
121
+ db=True,
122
+ )
123
 
124
  def forward(self, signal: torch.Tensor):
125
  if signal.dim() == 2:
 
134
  # x shape: [b, t, f]
135
 
136
  logits, _ = self.fsmn_encoder.forward(x)
137
+ # logits shape: [b, t, 2]
138
 
139
+ splits = torch.split(logits, split_size_or_sections=[1, 1], dim=-1)
140
+ vad_logits = splits[0]
141
+ snr_logits = splits[1]
142
+ # shape: [b, t, 1]
143
+ vad_probs = F.sigmoid(vad_logits)
144
+ # vad_probs shape: [b, t, 1]
145
+
146
+ lsnr = F.sigmoid(snr_logits) * self.lsnr_scale + self.lsnr_offset
147
+ # lsnr shape: [b, t, 1]
148
+
149
+ return vad_logits, vad_probs, lsnr
150
+
151
+ def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
152
+ if noisy.shape != clean.shape:
153
+ raise AssertionError("Input signals must have the same shape")
154
+ noise = noisy - clean
155
+
156
+ if clean.dim() == 2:
157
+ clean = torch.unsqueeze(clean, dim=1)
158
+ if noise.dim() == 2:
159
+ noise = torch.unsqueeze(noise, dim=1)
160
+
161
+ stft_clean = self.complex_stft.forward(clean)
162
+ stft_noise = self.complex_stft.forward(noise)
163
+ # shape: [b, f, t]
164
+ stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
165
+ stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
166
+ # shape: [b, t, f]
167
+ stft_clean = torch.unsqueeze(stft_clean, dim=1)
168
+ stft_noise = torch.unsqueeze(stft_noise, dim=1)
169
+ # shape: [b, 1, t, f]
170
+
171
+ # lsnr shape: [b, t, 1]
172
+ lsnr = lsnr.squeeze(-1)
173
+ # lsnr shape: [b, t]
174
+
175
+ lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
176
+ # lsnr_gth shape: [b, t]
177
+
178
+ loss = F.mse_loss(lsnr, lsnr_gth)
179
+ return loss
180
 
181
 
182
  class FSMNVadPretrainedModel(FSMNVadModel):
 
184
  config: FSMNVadConfig,
185
  ):
186
  super(FSMNVadPretrainedModel, self).__init__(
187
+ sample_rate=config.sample_rate,
188
+ nfft=config.nfft,
189
+ win_size=config.win_size,
190
+ hop_size=config.hop_size,
191
+ win_type=config.win_type,
192
+ fsmn_input_size=config.fsmn_input_size,
193
+ fsmn_input_affine_size=config.fsmn_input_affine_size,
194
+ fsmn_hidden_size=config.fsmn_hidden_size,
195
+ fsmn_basic_block_layers=config.fsmn_basic_block_layers,
196
+ fsmn_basic_block_hidden_size=config.fsmn_basic_block_hidden_size,
197
+ fsmn_basic_block_lorder=config.fsmn_basic_block_lorder,
198
+ fsmn_basic_block_rorder=config.fsmn_basic_block_rorder,
199
+ fsmn_basic_block_lstride=config.fsmn_basic_block_lstride,
200
+ fsmn_basic_block_rstride=config.fsmn_basic_block_rstride,
201
+ fsmn_output_affine_size=config.fsmn_output_affine_size,
202
+ n_frame=config.n_frame,
203
+ min_local_snr_db=config.min_local_snr_db,
204
+ max_local_snr_db=config.max_local_snr_db,
205
  )
206
+ self.config = config
207
 
208
  @classmethod
209
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
 
249
 
250
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
251
 
252
+ logits, probs, lsnr = model.forward(noisy)
253
+ print(f"logits.shape: {logits.shape}")
254
+ print(f"probs.shape: {probs.shape}")
255
+ print(f"lsnr.shape: {lsnr.shape}")
256
+
257
  return
258
 
259
 
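Note: a minimal usage sketch for the new two-headed FSMN model and its lsnr loss, with random tensors standing in for a real clean/noisy pair (default config values assumed):

import torch

from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel

config = FSMNVadConfig()
model = FSMNVadPretrainedModel(config=config)

clean = torch.randn(size=(2, 16000), dtype=torch.float32)
noisy = clean + 0.1 * torch.randn(size=(2, 16000), dtype=torch.float32)

logits, probs, lsnr = model.forward(noisy)
# lsnr is a per-frame local-SNR regression; supervise it against the clean/noisy pair
loss = model.lsnr_loss_fn(lsnr, clean=clean, noisy=noisy)
print(loss)
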
toolbox/torchaudio/models/vad/fsmn_vad/yaml/{config-sigmoid.yaml → config.yaml} RENAMED
@@ -18,9 +18,13 @@ fsmn_basic_block_rorder: 0
18
  fsmn_basic_block_lstride: 1
19
  fsmn_basic_block_rstride: 0
20
  fsmn_output_affine_size: 140
21
- fsmn_output_size: 1
22
 
23
- use_softmax: false
 
 
 
 
24
 
25
  # data
26
  min_snr_db: -10
 
18
  fsmn_basic_block_lstride: 1
19
  fsmn_basic_block_rstride: 0
20
  fsmn_output_affine_size: 140
21
+ fsmn_output_size: 2
22
 
23
+ # lsnr
24
+ n_frame: 3
25
+ min_local_snr_db: -15
26
+ max_local_snr_db: 30
27
+ norm_tau: 1.
28
 
29
  # data
30
  min_snr_db: -10
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py CHANGED
@@ -16,6 +16,11 @@ class SileroVadConfig(PretrainedConfig):
16
  in_channels: int = 64,
17
  hidden_size: int = 128,
18
 
 
19
  min_snr_db: float = -10,
20
  max_snr_db: float = 20,
21
 
@@ -45,6 +50,12 @@ class SileroVadConfig(PretrainedConfig):
45
  self.in_channels = in_channels
46
  self.hidden_size = hidden_size
47
 
 
48
  # data snr
49
  self.min_snr_db = min_snr_db
50
  self.max_snr_db = max_snr_db
 
16
  in_channels: int = 64,
17
  hidden_size: int = 128,
18
 
19
+ n_frame: int = 3,
20
+ min_local_snr_db: float = -15,
21
+ max_local_snr_db: float = 30,
22
+ norm_tau: float = 1.,
23
+
24
  min_snr_db: float = -10,
25
  max_snr_db: float = 20,
26
 
 
50
  self.in_channels = in_channels
51
  self.hidden_size = hidden_size
52
 
53
+ # lsnr
54
+ self.n_frame = n_frame
55
+ self.min_local_snr_db = min_local_snr_db
56
+ self.max_local_snr_db = max_local_snr_db
57
+ self.norm_tau = norm_tau
58
+
59
  # data snr
60
  self.min_snr_db = min_snr_db
61
  self.max_snr_db = max_snr_db
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py CHANGED
@@ -13,10 +13,12 @@ from typing import Optional, Union
13
 
14
  import torch
15
  import torch.nn as nn
 
16
 
17
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
18
  from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
19
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT
 
20
 
21
 
22
  MODEL_FILE = "model.pt"
@@ -80,50 +82,99 @@ class Encoder(nn.Module):
80
 
81
 
82
  class SileroVadModel(nn.Module):
83
- def __init__(self, config: SileroVadConfig):
84
  super(SileroVadModel, self).__init__()
85
- self.nfft = config.nfft
86
- self.win_size = config.win_size
87
- self.hop_size = config.hop_size
88
- self.win_type = config.win_type
89
 
90
- self.config = config
91
  self.eps = 1e-12
92
 
93
  self.stft = ConvSTFT(
94
- nfft=config.nfft,
95
- win_size=config.win_size,
96
- hop_size=config.hop_size,
97
- win_type=config.win_type,
98
  power=1,
99
  requires_grad=False
100
  )
101
 
102
  self.linear = nn.Linear(
103
- in_features=(config.nfft // 2 + 1),
104
- out_features=config.in_channels,
105
  )
106
 
107
  self.encoder = Encoder(
108
- in_channels=config.in_channels,
109
- out_channels=config.hidden_size,
110
  )
111
 
112
  self.lstm = nn.LSTM(
113
- input_size=config.hidden_size,
114
- hidden_size=config.hidden_size,
115
  bidirectional=False,
116
  batch_first=True
117
  )
118
 
119
- self.classifier = nn.Sequential(
120
- nn.Linear(config.hidden_size, 32),
 
121
  nn.ReLU(),
122
  nn.Linear(32, 1),
123
  )
124
-
125
  self.sigmoid = nn.Sigmoid()
126
 
 
 
127
  def forward(self, signal: torch.Tensor):
128
  if signal.dim() == 2:
129
  signal = torch.unsqueeze(signal, dim=1)
@@ -143,40 +194,46 @@ class SileroVadModel(nn.Module):
143
  # x shape: [b, t, f]
144
 
145
  x, _ = self.lstm.forward(x)
146
- logits = self.classifier.forward(x)
 
147
  # logits shape: [b, t, 1]
148
  probs = self.sigmoid.forward(logits)
149
  # probs shape: [b, t, 1]
150
- return logits, probs
151
 
152
- def forward_chunk(self, chunk: torch.Tensor):
153
- # chunk shape [b, 1, num_samples]
154
 
155
- mags = self.stft.forward(chunk)
156
- # mags shape: [b, f, t]
157
 
158
- x = torch.transpose(mags, dim0=1, dim1=2)
159
- # x shape: [b, t, f]
 
 
160
 
161
- x = self.linear.forward(x)
162
- # x shape: [b, t, f']
 
 
163
 
164
- return
 
 
165
 
166
- def forward_chunk_by_chunk(self, signal: torch.Tensor):
167
- if signal.dim() == 2:
168
- signal = torch.unsqueeze(signal, dim=1)
169
- _, _, num_samples = signal.shape
170
- # signal shape [b, 1, num_samples]
171
 
172
- t = (num_samples - self.win_size) // self.hop_size + 1
173
- waveform_list = list()
174
- for i in range(int(t)):
175
- begin = i * self.hop_size
176
- end = begin + self.win_size
177
- sub_signal = signal[:, :, begin: end]
178
 
179
- return
 
180
 
181
 
182
  class SileroVadPretrainedModel(SileroVadModel):
@@ -184,8 +241,18 @@ class SileroVadPretrainedModel(SileroVadModel):
184
  config: SileroVadConfig,
185
  ):
186
  super(SileroVadPretrainedModel, self).__init__(
187
- config=config,
 
 
 
 
188
  )
 
189
 
190
  @classmethod
191
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -227,13 +294,14 @@ class SileroVadPretrainedModel(SileroVadModel):
227
 
228
  def main():
229
  config = SileroVadConfig()
230
- model = SileroVadModel(config=config)
231
 
232
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
233
 
234
- logits, probs = model.forward(noisy)
235
- print(f"logits: {probs}")
236
  print(f"logits.shape: {logits.shape}")
 
 
237
 
238
  return
239
 
 
13
 
14
  import torch
15
  import torch.nn as nn
16
+ from torch.nn import functional as F
17
 
18
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
19
  from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
20
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT
21
+ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
22
 
23
 
24
  MODEL_FILE = "model.pt"
 
82
 
83
 
84
  class SileroVadModel(nn.Module):
85
+ def __init__(self,
86
+ sample_rate: int,
87
+ nfft: int,
88
+ win_size: int,
89
+ hop_size: int,
90
+ win_type: str,
91
+
92
+ in_channels: int,
93
+ hidden_size: int,
94
+
95
+ n_frame: int,
96
+ min_local_snr_db: float,
97
+ max_local_snr_db: float,
98
+
99
+ ):
100
  super(SileroVadModel, self).__init__()
101
+ self.sample_rate = sample_rate
102
+ self.nfft = nfft
103
+ self.win_size = win_size
104
+ self.hop_size = hop_size
105
+ self.win_type = win_type
106
+
107
+ self.in_channels = in_channels
108
+ self.hidden_size = hidden_size
109
+
110
+ self.n_frame = n_frame
111
+ self.min_local_snr_db = min_local_snr_db
112
+ self.max_local_snr_db = max_local_snr_db
113
 
 
114
  self.eps = 1e-12
115
 
116
  self.stft = ConvSTFT(
117
+ nfft=nfft,
118
+ win_size=win_size,
119
+ hop_size=hop_size,
120
+ win_type=win_type,
121
  power=1,
122
  requires_grad=False
123
  )
124
+ self.complex_stft = ConvSTFT(
125
+ nfft=nfft,
126
+ win_size=win_size,
127
+ hop_size=hop_size,
128
+ win_type=win_type,
129
+ power=None,
130
+ requires_grad=False
131
+ )
132
 
133
  self.linear = nn.Linear(
134
+ in_features=(self.nfft // 2 + 1),
135
+ out_features=self.in_channels,
136
  )
137
 
138
  self.encoder = Encoder(
139
+ in_channels=self.in_channels,
140
+ out_channels=self.hidden_size,
141
  )
142
 
143
  self.lstm = nn.LSTM(
144
+ input_size=self.hidden_size,
145
+ hidden_size=self.hidden_size,
146
  bidirectional=False,
147
  batch_first=True
148
  )
149
 
150
+ # vad
151
+ self.vad_fc = nn.Sequential(
152
+ nn.Linear(self.hidden_size, 32),
153
  nn.ReLU(),
154
  nn.Linear(32, 1),
155
  )
 
156
  self.sigmoid = nn.Sigmoid()
157
 
158
+ # lsnr
159
+ self.lsnr_fc = nn.Sequential(
160
+ nn.Linear(self.hidden_size, 1),
161
+ nn.Sigmoid()
162
+ )
163
+ self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
164
+ self.lsnr_offset = self.min_local_snr_db
165
+
166
+ # lsnr
167
+ self.lsnr_fn = LocalSnrTarget(
168
+ sample_rate=self.sample_rate,
169
+ nfft=self.nfft,
170
+ win_size=self.win_size,
171
+ hop_size=self.hop_size,
172
+ n_frame=self.n_frame,
173
+ min_local_snr=self.min_local_snr_db,
174
+ max_local_snr=self.max_local_snr_db,
175
+ db=True,
176
+ )
177
+
178
  def forward(self, signal: torch.Tensor):
179
  if signal.dim() == 2:
180
  signal = torch.unsqueeze(signal, dim=1)
 
194
  # x shape: [b, t, f]
195
 
196
  x, _ = self.lstm.forward(x)
197
+
198
+ logits = self.vad_fc.forward(x)
199
  # logits shape: [b, t, 1]
200
  probs = self.sigmoid.forward(logits)
201
  # probs shape: [b, t, 1]
 
202
 
203
+ lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
204
+ # lsnr shape: [b, t, 1]
205
 
206
+ return logits, probs, lsnr
 
207
 
208
+ def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
209
+ if noisy.shape != clean.shape:
210
+ raise AssertionError("Input signals must have the same shape")
211
+ noise = noisy - clean
212
 
213
+ if clean.dim() == 2:
214
+ clean = torch.unsqueeze(clean, dim=1)
215
+ if noise.dim() == 2:
216
+ noise = torch.unsqueeze(noise, dim=1)
217
 
218
+ stft_clean = self.complex_stft.forward(clean)
219
+ stft_noise = self.complex_stft.forward(noise)
220
+ # shape: [b, f, t]
221
+ stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
222
+ stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
223
+ # shape: [b, t, f]
224
+ stft_clean = torch.unsqueeze(stft_clean, dim=1)
225
+ stft_noise = torch.unsqueeze(stft_noise, dim=1)
226
+ # shape: [b, 1, t, f]
227
 
228
+ # lsnr shape: [b, t, 1]
229
+ lsnr = lsnr.squeeze(-1)
230
+ # lsnr shape: [b, t]
 
 
231
 
232
+ lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
233
+ # lsnr_gth shape: [b, t]
 
 
 
 
234
 
235
+ loss = F.mse_loss(lsnr, lsnr_gth)
236
+ return loss
237
 
238
 
239
  class SileroVadPretrainedModel(SileroVadModel):
 
241
  config: SileroVadConfig,
242
  ):
243
  super(SileroVadPretrainedModel, self).__init__(
244
+ sample_rate=config.sample_rate,
245
+ nfft=config.nfft,
246
+ win_size=config.win_size,
247
+ hop_size=config.hop_size,
248
+ win_type=config.win_type,
249
+ in_channels=config.in_channels,
250
+ hidden_size=config.hidden_size,
251
+ n_frame=config.n_frame,
252
+ min_local_snr_db=config.min_local_snr_db,
253
+ max_local_snr_db=config.max_local_snr_db,
254
  )
255
+ self.config = config
256
 
257
  @classmethod
258
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
 
294
 
295
  def main():
296
  config = SileroVadConfig()
297
+ model = SileroVadPretrainedModel(config=config)
298
 
299
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
300
 
301
+ logits, probs, lsnr = model.forward(noisy)
 
302
  print(f"logits.shape: {logits.shape}")
303
+ print(f"probs.shape: {probs.shape}")
304
+ print(f"lsnr.shape: {lsnr.shape}")
305
 
306
  return
307
 
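Note: the local-SNR target that both lsnr_loss_fn implementations regress against is, conceptually, the per-frame clean-to-noise energy ratio in dB, smoothed over n_frame frames and clamped to [min_local_snr_db, max_local_snr_db]. A standalone sketch of that idea (not the repo's LocalSnrTarget implementation):

import torch
import torch.nn.functional as F


def local_snr_db(stft_clean: torch.Tensor,
                 stft_noise: torch.Tensor,
                 n_frame: int = 3,
                 min_db: float = -15.0,
                 max_db: float = 30.0,
                 eps: float = 1e-12) -> torch.Tensor:
    # stft_* shape: [b, t, f]; per-frame energies summed over frequency bins
    e_clean = torch.sum(torch.abs(stft_clean) ** 2, dim=-1)
    e_noise = torch.sum(torch.abs(stft_noise) ** 2, dim=-1)
    # smooth both energies with a moving average over n_frame frames
    kernel = torch.ones(size=(1, 1, n_frame)) / n_frame
    e_clean = F.conv1d(e_clean.unsqueeze(1), kernel, padding=n_frame // 2).squeeze(1)
    e_noise = F.conv1d(e_noise.unsqueeze(1), kernel, padding=n_frame // 2).squeeze(1)
    snr_db = 10.0 * torch.log10((e_clean + eps) / (e_noise + eps))
    # clamp to the configured range, matching min_local_snr_db / max_local_snr_db
    return torch.clamp(snr_db, min=min_db, max=max_db)
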
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml CHANGED
@@ -11,6 +11,12 @@ win_type: hann
11
  in_channels: 64
12
  hidden_size: 128
13
 
 
14
  # data
15
  min_snr_db: -10
16
  max_snr_db: 20
 
11
  in_channels: 64
12
  hidden_size: 128
13
 
14
+ # lsnr
15
+ n_frame: 3
16
+ min_local_snr_db: -15
17
+ max_local_snr_db: 30
18
+ norm_tau: 1.
19
+
20
  # data
21
  min_snr_db: -10
22
  max_snr_db: 20