HoneyTian committed on
Commit
d87e440
·
1 Parent(s): 51ac2c7
.gitignore CHANGED
@@ -9,6 +9,7 @@
9
  **/log/
10
  **/logs/
11
  **/__pycache__/
 
12
 
13
  /data/
14
  /docs/
@@ -21,3 +22,4 @@
21
 
22
  **/*.wav
23
  **/*.xlsx
 
 
9
  **/log/
10
  **/logs/
11
  **/__pycache__/
12
+ **/serialization_dir/
13
 
14
  /data/
15
  /docs/
 
22
 
23
  **/*.wav
24
  **/*.xlsx
25
+ **/*.jsonl
Dockerfile CHANGED
@@ -10,7 +10,9 @@ RUN apt-get install -y ffmpeg build-essential
10
  RUN pip install --upgrade pip
11
  RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
 
13
- RUN useradd -m -u 1000 user
 
 
14
 
15
  USER user
16
 
 
10
  RUN pip install --upgrade pip
11
  RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
 
13
+ RUN pip install --upgrade pip
14
+
15
+ RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
16
 
17
  USER user
18
 
download_sound_models.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from huggingface_hub import snapshot_download
7
+
8
+ from project_settings import environment, project_path
9
+
10
+
11
+ def get_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument(
14
+ "--trained_model_dir",
15
+ default=(project_path / "trained_models").as_posix(),
16
+ type=str,
17
+ )
18
+ parser.add_argument(
19
+ "--models_repo_id",
20
+ default="qgyd2021/vm_sound_classification",
21
+ type=str,
22
+ )
23
+ parser.add_argument(
24
+ "--model_pattern",
25
+ default="sound-*-ch32.zip",
26
+ type=str,
27
+ )
28
+ parser.add_argument(
29
+ "--hf_token",
30
+ default=environment.get("hf_token"),
31
+ type=str,
32
+ )
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+
37
+ def main():
38
+ args = get_args()
39
+
40
+ trained_model_dir = Path(args.trained_model_dir)
41
+ trained_model_dir.mkdir(parents=True, exist_ok=True)
42
+
43
+ _ = snapshot_download(
44
+ repo_id=args.models_repo_id,
45
+ allow_patterns=[args.model_pattern],
46
+ local_dir=trained_model_dir.as_posix(),
47
+ token=args.hf_token,
48
+ )
49
+ return
50
+
51
+
52
+ if __name__ == '__main__':
53
+ main()
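A minimal standalone sketch of the download this script performs (assumptions: huggingface_hub is installed and the repository is publicly readable; the repo id and file pattern are the script's own defaults):

    from pathlib import Path
    from huggingface_hub import snapshot_download

    trained_model_dir = Path("trained_models")
    trained_model_dir.mkdir(parents=True, exist_ok=True)

    # download only the zipped sound classifiers matching the pattern
    snapshot_download(
        repo_id="qgyd2021/vm_sound_classification",
        allow_patterns=["sound-*-ch32.zip"],
        local_dir=trained_model_dir.as_posix(),
        token=None,  # pass an HF token here if the repo requires authentication
    )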
examples/fsmn_vad/step_1_prepare_data.py ADDED
@@ -0,0 +1,156 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument(
21
+ "--noise_dir",
22
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
23
+ type=str
24
+ )
25
+ parser.add_argument(
26
+ "--speech_dir",
27
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
28
+ type=str
29
+ )
30
+
31
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
32
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
33
+
34
+ parser.add_argument("--duration", default=6.0, type=float)
35
+ parser.add_argument("--min_snr_db", default=-10, type=float)
36
+ parser.add_argument("--max_snr_db", default=20, type=float)
37
+
38
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
39
+
40
+ parser.add_argument("--max_count", default=-1, type=int)
41
+
42
+ args = parser.parse_args()
43
+ return args
44
+
45
+
46
+ def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate: int = 8000, max_epoch: int = 20000):
47
+ data_dir = Path(data_dir)
48
+ for epoch_idx in range(max_epoch):
49
+ for filename in data_dir.glob("**/*.wav"):
50
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
51
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
52
+
53
+ if raw_duration < duration:
54
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
55
+ continue
56
+ if signal.ndim != 1:
57
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
58
+
59
+ signal_length = len(signal)
60
+ win_size = int(duration * sample_rate)
61
+ for begin in range(0, signal_length - win_size, win_size):
62
+ if np.sum(signal[begin: begin+win_size]) == 0:
63
+ continue
64
+ row = {
65
+ "epoch_idx": epoch_idx,
66
+ "filename": filename.as_posix(),
67
+ "raw_duration": round(raw_duration, 4),
68
+ "offset": round(begin / sample_rate, 4),
69
+ "duration": round(duration, 4),
70
+ }
71
+ yield row
72
+
73
+
74
+ def main():
75
+ args = get_args()
76
+
77
+ noise_dir = Path(args.noise_dir)
78
+ speech_dir = Path(args.speech_dir)
79
+
80
+ train_dataset = Path(args.train_dataset)
81
+ valid_dataset = Path(args.valid_dataset)
82
+ train_dataset.parent.mkdir(parents=True, exist_ok=True)
83
+ valid_dataset.parent.mkdir(parents=True, exist_ok=True)
84
+
85
+ noise_generator = target_second_signal_generator(
86
+ noise_dir.as_posix(),
87
+ duration=args.duration,
88
+ sample_rate=args.target_sample_rate,
89
+ max_epoch=100000,
90
+ )
91
+ speech_generator = target_second_signal_generator(
92
+ speech_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=1,
96
+ )
97
+
98
+ count = 0
99
+ process_bar = tqdm(desc="build dataset jsonl")
100
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
101
+ for noise, speech in zip(noise_generator, speech_generator):
102
+ if count >= args.max_count > 0:
103
+ break
104
+
105
+ # row
106
+ noise_filename = noise["filename"]
107
+ noise_raw_duration = noise["raw_duration"]
108
+ noise_offset = noise["offset"]
109
+ noise_duration = noise["duration"]
110
+
111
+ speech_filename = speech["filename"]
112
+ speech_raw_duration = speech["raw_duration"]
113
+ speech_offset = speech["offset"]
114
+ speech_duration = speech["duration"]
115
+
116
+ # row
117
+ random1 = random.random()
118
+ random2 = random.random()
119
+
120
+ row = {
121
+ "count": count,
122
+
123
+ "noise_filename": noise_filename,
124
+ "noise_raw_duration": noise_raw_duration,
125
+ "noise_offset": noise_offset,
126
+ "noise_duration": noise_duration,
127
+
128
+ "speech_filename": speech_filename,
129
+ "speech_raw_duration": speech_raw_duration,
130
+ "speech_offset": speech_offset,
131
+ "speech_duration": speech_duration,
132
+
133
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
134
+
135
+ "random1": random1,
136
+ }
137
+ row = json.dumps(row, ensure_ascii=False)
138
+ if random2 < (1 / 300 / 1):
139
+ fvalid.write(f"{row}\n")
140
+ else:
141
+ ftrain.write(f"{row}\n")
142
+
143
+ count += 1
144
+ duration_seconds = count * args.duration
145
+ duration_hours = duration_seconds / 3600
146
+
147
+ process_bar.update(n=1)
148
+ process_bar.set_postfix({
149
+ "duration_hours": round(duration_hours, 4),
150
+ })
151
+
152
+ return
153
+
154
+
155
+ if __name__ == "__main__":
156
+ main()
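To make the windowing in target_second_signal_generator concrete, here is a small arithmetic check of the offsets it yields (the 15.5 s file length is an assumed example, not a value from any dataset):

    sample_rate = 8000
    duration = 6.0
    raw_duration = 15.5  # assumed example length in seconds

    signal_length = int(raw_duration * sample_rate)  # 124000 samples
    win_size = int(duration * sample_rate)           # 48000 samples
    offsets = [begin / sample_rate for begin in range(0, signal_length - win_size, win_size)]
    print(offsets)  # [0.0, 6.0] -- the trailing 3.5 s does not fill a window and is dropped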
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -122,7 +122,7 @@ fi
122
  if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
123
  $verbose && echo "stage 3: train model"
124
  cd "${work_dir}" || exit 1
125
- python3 step_3_train_model.py \
126
  --train_dataset "${train_vad_dataset}" \
127
  --valid_dataset "${valid_vad_dataset}" \
128
  --serialization_dir "${file_dir}" \
@@ -131,8 +131,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
131
  fi
132
 
133
 
134
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
135
- $verbose && echo "stage 3: test model"
136
  cd "${work_dir}" || exit 1
137
  python3 step_3_evaluation.py \
138
  --valid_dataset "${valid_dataset}" \
@@ -143,8 +143,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
143
  fi
144
 
145
 
146
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
147
- $verbose && echo "stage 4: collect files"
148
  cd "${work_dir}" || exit 1
149
 
150
  mkdir -p ${final_model_dir}
@@ -165,8 +165,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
165
  fi
166
 
167
 
168
- if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
169
- $verbose && echo "stage 5: clear file_dir"
170
  cd "${work_dir}" || exit 1
171
 
172
  rm -rf "${file_dir}";
 
122
  if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
123
  $verbose && echo "stage 3: train model"
124
  cd "${work_dir}" || exit 1
125
+ python3 step_4_train_model.py \
126
  --train_dataset "${train_vad_dataset}" \
127
  --valid_dataset "${valid_vad_dataset}" \
128
  --serialization_dir "${file_dir}" \
 
131
  fi
132
 
133
 
134
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
135
+ $verbose && echo "stage 4: test model"
136
  cd "${work_dir}" || exit 1
137
  python3 step_3_evaluation.py \
138
  --valid_dataset "${valid_dataset}" \
 
143
  fi
144
 
145
 
146
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
147
+ $verbose && echo "stage 5: collect files"
148
  cd "${work_dir}" || exit 1
149
 
150
  mkdir -p ${final_model_dir}
 
165
  fi
166
 
167
 
168
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
169
+ $verbose && echo "stage 6: clear file_dir"
170
  cd "${work_dir}" || exit 1
171
 
172
  rm -rf "${file_dir}";
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  from pathlib import Path
7
  import random
8
  import sys
 
9
 
10
  pwd = os.path.abspath(os.path.dirname(__file__))
11
  sys.path.append(os.path.join(pwd, "../../"))
@@ -19,19 +20,21 @@ def get_args():
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument(
21
  "--noise_dir",
22
- default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
23
  type=str
24
  )
25
  parser.add_argument(
26
  "--speech_dir",
27
- default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
28
  type=str
29
  )
30
 
31
  parser.add_argument("--train_dataset", default="train.jsonl", type=str)
32
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
33
 
34
- parser.add_argument("--duration", default=6.0, type=float)
 
 
35
  parser.add_argument("--min_snr_db", default=-10, type=float)
36
  parser.add_argument("--max_snr_db", default=20, type=float)
37
 
@@ -43,21 +46,90 @@ def get_args():
43
  return args
44
 
45
 
46
- def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate: int = 8000, max_epoch: int = 20000):
47
  data_dir = Path(data_dir)
48
  for epoch_idx in range(max_epoch):
49
  for filename in data_dir.glob("**/*.wav"):
50
  signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
51
  raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
52
 
53
- if raw_duration < duration:
54
- # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
55
- continue
56
  if signal.ndim != 1:
57
  raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
58
 
59
  signal_length = len(signal)
60
- win_size = int(duration * sample_rate)
61
  for begin in range(0, signal_length - win_size, win_size):
62
  if np.sum(signal[begin: begin+win_size]) == 0:
63
  continue
@@ -66,7 +138,7 @@ def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate
66
  "filename": filename.as_posix(),
67
  "raw_duration": round(raw_duration, 4),
68
  "offset": round(begin / sample_rate, 4),
69
- "duration": round(duration, 4),
70
  }
71
  yield row
72
 
@@ -82,15 +154,16 @@ def main():
82
  train_dataset.parent.mkdir(parents=True, exist_ok=True)
83
  valid_dataset.parent.mkdir(parents=True, exist_ok=True)
84
 
85
- noise_generator = target_second_signal_generator(
86
  noise_dir.as_posix(),
87
  duration=args.duration,
88
  sample_rate=args.target_sample_rate,
89
  max_epoch=100000,
90
  )
91
- speech_generator = target_second_signal_generator(
92
  speech_dir.as_posix(),
93
- duration=args.duration,
 
94
  sample_rate=args.target_sample_rate,
95
  max_epoch=1,
96
  )
@@ -98,21 +171,26 @@ def main():
98
  count = 0
99
  process_bar = tqdm(desc="build dataset jsonl")
100
  with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
101
- for noise, speech in zip(noise_generator, speech_generator):
102
  if count >= args.max_count > 0:
103
  break
104
 
105
  # row
106
- noise_filename = noise["filename"]
107
- noise_raw_duration = noise["raw_duration"]
108
- noise_offset = noise["offset"]
109
- noise_duration = noise["duration"]
110
-
111
  speech_filename = speech["filename"]
112
  speech_raw_duration = speech["raw_duration"]
113
  speech_offset = speech["offset"]
114
  speech_duration = speech["duration"]
115
 
 
116
  # row
117
  random1 = random.random()
118
  random2 = random.random()
@@ -120,16 +198,13 @@ def main():
120
  row = {
121
  "count": count,
122
 
123
- "noise_filename": noise_filename,
124
- "noise_raw_duration": noise_raw_duration,
125
- "noise_offset": noise_offset,
126
- "noise_duration": noise_duration,
127
-
128
  "speech_filename": speech_filename,
129
  "speech_raw_duration": speech_raw_duration,
130
  "speech_offset": speech_offset,
131
  "speech_duration": speech_duration,
132
 
 
 
133
  "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
134
 
135
  "random1": random1,
 
6
  from pathlib import Path
7
  import random
8
  import sys
9
+ import time
10
 
11
  pwd = os.path.abspath(os.path.dirname(__file__))
12
  sys.path.append(os.path.join(pwd, "../../"))
 
20
  parser = argparse.ArgumentParser()
21
  parser.add_argument(
22
  "--noise_dir",
23
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
24
  type=str
25
  )
26
  parser.add_argument(
27
  "--speech_dir",
28
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
29
  type=str
30
  )
31
 
32
  parser.add_argument("--train_dataset", default="train.jsonl", type=str)
33
  parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
34
 
35
+ parser.add_argument("--duration", default=8.0, type=float)
36
+ parser.add_argument("--min_speech_duration", default=6.0, type=float)
37
+ parser.add_argument("--max_speech_duration", default=8.0, type=float)
38
  parser.add_argument("--min_snr_db", default=-10, type=float)
39
  parser.add_argument("--max_snr_db", default=20, type=float)
40
 
 
46
  return args
47
 
48
 
49
+ def target_second_noise_signal_generator(data_dir: str,
50
+ duration: int = 4,
51
+ sample_rate: int = 8000, max_epoch: int = 20000):
52
+ noise_list = list()
53
+ wait_duration = duration
54
+
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+
60
+ if signal.ndim != 1:
61
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
62
+
63
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
64
+
65
+ offset = 0.
66
+ rest_duration = raw_duration
67
+
68
+ for _ in range(1000):
69
+ if rest_duration <= 0:
70
+ break
71
+ if rest_duration <= wait_duration:
72
+ noise_list.append({
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(offset, 4),
77
+ "duration": None,
78
+ "duration_": round(rest_duration, 4),
79
+ })
80
+ wait_duration -= rest_duration
81
+ offset = 0
82
+ rest_duration = 0
83
+ elif rest_duration > wait_duration:
84
+ noise_list.append({
85
+ "epoch_idx": epoch_idx,
86
+ "filename": filename.as_posix(),
87
+ "raw_duration": round(raw_duration, 4),
88
+ "offset": round(offset, 4),
89
+ "duration": round(wait_duration, 4),
90
+ "duration_": round(wait_duration, 4),
91
+ })
92
+ offset += wait_duration
93
+ rest_duration -= wait_duration
94
+ wait_duration = 0
95
+ else:
96
+ raise AssertionError
97
+
98
+ if wait_duration <= 0:
99
+ yield noise_list
100
+ noise_list = list()
101
+ wait_duration = duration
102
+
103
+
104
+ def target_second_speech_signal_generator(data_dir: str,
105
+ min_duration: int = 4,
106
+ max_duration: int = 6,
107
+ sample_rate: int = 8000, max_epoch: int = 1):
108
  data_dir = Path(data_dir)
109
  for epoch_idx in range(max_epoch):
110
  for filename in data_dir.glob("**/*.wav"):
111
  signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
112
  raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
113
 
 
 
 
114
  if signal.ndim != 1:
115
  raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
116
 
117
+ if raw_duration < min_duration:
118
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
119
+ continue
120
+
121
+ if raw_duration < max_duration:
122
+ row = {
123
+ "epoch_idx": epoch_idx,
124
+ "filename": filename.as_posix(),
125
+ "raw_duration": round(raw_duration, 4),
126
+ "offset": 0.,
127
+ "duration": round(raw_duration, 4),
128
+ }
129
+ yield row
130
+
131
  signal_length = len(signal)
132
+ win_size = int(max_duration * sample_rate)
133
  for begin in range(0, signal_length - win_size, win_size):
134
  if np.sum(signal[begin: begin+win_size]) == 0:
135
  continue
 
138
  "filename": filename.as_posix(),
139
  "raw_duration": round(raw_duration, 4),
140
  "offset": round(begin / sample_rate, 4),
141
+ "duration": round(max_duration, 4),
142
  }
143
  yield row
144
 
 
154
  train_dataset.parent.mkdir(parents=True, exist_ok=True)
155
  valid_dataset.parent.mkdir(parents=True, exist_ok=True)
156
 
157
+ noise_generator = target_second_noise_signal_generator(
158
  noise_dir.as_posix(),
159
  duration=args.duration,
160
  sample_rate=args.target_sample_rate,
161
  max_epoch=100000,
162
  )
163
+ speech_generator = target_second_speech_signal_generator(
164
  speech_dir.as_posix(),
165
+ min_duration=args.min_speech_duration,
166
+ max_duration=args.max_speech_duration,
167
  sample_rate=args.target_sample_rate,
168
  max_epoch=1,
169
  )
 
171
  count = 0
172
  process_bar = tqdm(desc="build dataset jsonl")
173
  with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
174
+ for speech, noise_list in zip(speech_generator, noise_generator):
175
  if count >= args.max_count > 0:
176
  break
177
 
178
  # row
 
 
 
 
 
179
  speech_filename = speech["filename"]
180
  speech_raw_duration = speech["raw_duration"]
181
  speech_offset = speech["offset"]
182
  speech_duration = speech["duration"]
183
 
184
+ noise_list = [
185
+ {
186
+ "filename": noise["filename"],
187
+ "raw_duration": noise["raw_duration"],
188
+ "offset": noise["offset"],
189
+ "duration": noise["duration"],
190
+ }
191
+ for noise in noise_list
192
+ ]
193
+
194
  # row
195
  random1 = random.random()
196
  random2 = random.random()
 
198
  row = {
199
  "count": count,
200
 
 
 
 
 
 
201
  "speech_filename": speech_filename,
202
  "speech_raw_duration": speech_raw_duration,
203
  "speech_offset": speech_offset,
204
  "speech_duration": speech_duration,
205
 
206
+ "noise_list": noise_list,
207
+
208
  "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
209
 
210
  "random1": random1,
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py CHANGED
@@ -12,7 +12,8 @@ import librosa
12
  import numpy as np
13
  from tqdm import tqdm
14
 
15
- from toolbox.webrtcvad.vad import WebRTCVad
 
16
 
17
 
18
  def get_args():
@@ -24,15 +25,19 @@ def get_args():
24
  parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
25
  parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
26
 
27
- parser.add_argument("--duration", default=6.0, type=float)
28
  parser.add_argument("--expected_sample_rate", default=8000, type=int)
29
 
30
- # vad
31
- parser.add_argument("--agg", default=3, type=int)
32
- parser.add_argument("--frame_duration_ms", default=30, type=int)
33
- parser.add_argument("--padding_duration_ms", default=30, type=int)
34
- parser.add_argument("--silence_duration_threshold", default=0.0, type=float)
35
-
 
 
 
 
36
  args = parser.parse_args()
37
  return args
38
 
@@ -40,17 +45,58 @@ def get_args():
40
  def main():
41
  args = get_args()
42
 
43
- w_vad = WebRTCVad(
44
- agg=args.agg,
45
- frame_duration_ms=args.frame_duration_ms,
46
- padding_duration_ms=args.padding_duration_ms,
47
- silence_duration_threshold=args.silence_duration_threshold,
48
  sample_rate=args.expected_sample_rate,
49
  )
50
 
51
  # valid
 
 
 
 
52
  count = 0
53
- process_bar = tqdm(desc="process valid dataset jsonl")
54
  with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
55
  open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
56
  for row in fvalid:
@@ -81,18 +127,30 @@ def main():
81
  row = json.dumps(row, ensure_ascii=False)
82
  fvalid_vad.write(f"{row}\n")
83
 
 
 
 
 
84
  count += 1
85
- duration_seconds = count * args.duration
86
- duration_hours = duration_seconds / 3600
87
 
88
- process_bar.update(n=1)
89
- process_bar.set_postfix({
90
- "duration_hours": round(duration_hours, 4),
91
  })
92
 
93
  # train
 
 
 
 
94
  count = 0
95
- process_bar = tqdm(desc="process train dataset jsonl")
96
  with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
97
  open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
98
  for row in ftrain:
@@ -123,13 +181,21 @@ def main():
123
  row = json.dumps(row, ensure_ascii=False)
124
  ftrain_vad.write(f"{row}\n")
125
 
 
 
 
 
126
  count += 1
127
- duration_seconds = count * args.duration
128
- duration_hours = duration_seconds / 3600
129
 
130
- process_bar.update(n=1)
131
- process_bar.set_postfix({
132
- "duration_hours": round(duration_hours, 4),
133
  })
134
 
135
  return
 
12
  import numpy as np
13
  from tqdm import tqdm
14
 
15
+ from project_settings import project_path
16
+ from toolbox.vad.vad import WebRTCVoiceClassifier, SileroVoiceClassifier, CCSoundsClassifier, RingVad
17
 
18
 
19
  def get_args():
 
25
  parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
26
  parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
27
 
28
+ parser.add_argument("--duration", default=8.0, type=float)
29
  parser.add_argument("--expected_sample_rate", default=8000, type=int)
30
 
31
+ parser.add_argument(
32
+ "--silero_model_path",
33
+ default=(project_path / "trained_models/silero_vad.jit").as_posix(),
34
+ type=str,
35
+ )
36
+ parser.add_argument(
37
+ "--cc_sounds_model_path",
38
+ default=(project_path / "trained_models/sound-2-ch32.zip").as_posix(),
39
+ type=str,
40
+ )
41
  args = parser.parse_args()
42
  return args
43
 
 
45
  def main():
46
  args = get_args()
47
 
48
+ # webrtcvad
49
+ # model = SileroVoiceClassifier(model_path=args.silero_model_path, sample_rate=args.expected_sample_rate)
50
+ # w_vad = RingVad(
51
+ # model=model,
52
+ # start_ring_rate=0.2,
53
+ # end_ring_rate=0.1,
54
+ # frame_size_ms=32,
55
+ # frame_step_ms=32,
56
+ # padding_length_ms=320,
57
+ # max_silence_length_ms=320,
58
+ # max_speech_length_s=100,
59
+ # min_speech_length_s=0.1,
60
+ # sample_rate=args.expected_sample_rate,
61
+ # )
62
+
63
+ # webrtcvad
64
+ model = WebRTCVoiceClassifier(agg=3, sample_rate=args.expected_sample_rate)
65
+ w_vad = RingVad(
66
+ model=model,
67
+ start_ring_rate=0.9,
68
+ end_ring_rate=0.1,
69
+ frame_size_ms=30,
70
+ frame_step_ms=30,
71
+ padding_length_ms=90,
72
+ max_silence_length_ms=100,
73
+ max_speech_length_s=100,
74
+ min_speech_length_s=0.1,
75
  sample_rate=args.expected_sample_rate,
76
  )
77
 
78
+ # cc sounds
79
+ # model = CCSoundsClassifier(model_path=args.cc_sounds_model_path, sample_rate=args.expected_sample_rate)
80
+ # w_vad = RingVad(
81
+ # model=model,
82
+ # start_ring_rate=0.5,
83
+ # end_ring_rate=0.3,
84
+ # frame_size_ms=300,
85
+ # frame_step_ms=300,
86
+ # padding_length_ms=300,
87
+ # max_silence_length_ms=100,
88
+ # max_speech_length_s=100,
89
+ # min_speech_length_s=0.1,
90
+ # sample_rate=args.expected_sample_rate,
91
+ # )
92
+
93
  # valid
94
+ va_duration = 0
95
+ raw_duration = 0
96
+ use_duration = 0
97
+
98
  count = 0
99
+ process_bar_valid = tqdm(desc="process valid dataset jsonl")
100
  with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
101
  open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
102
  for row in fvalid:
 
127
  row = json.dumps(row, ensure_ascii=False)
128
  fvalid_vad.write(f"{row}\n")
129
 
130
+ va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
131
+ raw_duration += speech_duration
132
+ use_duration += args.duration
133
+
134
  count += 1
 
 
135
 
136
+ va_rate = va_duration / use_duration
137
+ va_raw_rate = va_duration / raw_duration
138
+ use_duration_hours = use_duration / 3600
139
+
140
+ process_bar_valid.update(n=1)
141
+ process_bar_valid.set_postfix({
142
+ "va_rate": round(va_rate, 4),
143
+ "va_raw_rate": round(va_raw_rate, 4),
144
+ "duration_hours": round(use_duration_hours, 4),
145
  })
146
 
147
  # train
148
+ va_duration = 0
149
+ raw_duration = 0
150
+ use_duration = 0
151
+
152
  count = 0
153
+ process_bar_train = tqdm(desc="process train dataset jsonl")
154
  with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
155
  open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
156
  for row in ftrain:
 
181
  row = json.dumps(row, ensure_ascii=False)
182
  ftrain_vad.write(f"{row}\n")
183
 
184
+ va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
185
+ raw_duration += speech_duration
186
+ use_duration += args.duration
187
+
188
  count += 1
 
 
189
 
190
+ va_rate = va_duration / use_duration
191
+ va_raw_rate = va_duration / raw_duration
192
+ use_duration_hours = use_duration / 3600
193
+
194
+ process_bar_train.update(n=1)
195
+ process_bar_train.set_postfix({
196
+ "va_rate": round(va_rate, 4),
197
+ "va_raw_rate": round(va_raw_rate, 4),
198
+ "duration_hours": round(use_duration_hours, 4),
199
  })
200
 
201
  return
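The va_rate and va_raw_rate values shown in the progress-bar postfix are plain duration ratios; a short worked example with made-up numbers (vad_segments as [start_s, end_s] pairs inside one 8 s clip):

    vad_segments = [[0.50, 2.75], [3.10, 6.40]]   # illustrative values
    duration = 8.0          # seconds consumed per row (--duration)
    speech_duration = 7.2   # raw duration of the underlying speech clip

    va_duration = sum(end - start for start, end in vad_segments)  # 5.55 s of detected voice
    va_rate = va_duration / duration             # share of the fixed window that is voiced
    va_raw_rate = va_duration / speech_duration  # share of the source clip that is voiced
    print(round(va_rate, 4), round(va_raw_rate, 4))  # 0.6938 0.7708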
examples/silero_vad_by_webrtcvad/step_3_check_vad.py ADDED
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import librosa
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+
21
+ parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
22
+ parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
23
+
24
+ parser.add_argument("--duration", default=8.0, type=float)
25
+ parser.add_argument("--expected_sample_rate", default=8000, type=int)
26
+
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+
31
+ def main():
32
+ args = get_args()
33
+
34
+ SAMPLE_RATE = 8000
35
+
36
+ with open(args.train_vad_dataset, "r", encoding="utf-8") as f:
37
+ for row in f:
38
+ row = json.loads(row)
39
+
40
+ speech_filename = row["speech_filename"]
41
+ speech_offset = row["speech_offset"]
42
+ speech_duration = row["speech_duration"]
43
+
44
+ vad_segments = row["vad_segments"]
45
+
46
+ print(f"speech_filename: {speech_filename}")
47
+ signal, sample_rate = librosa.load(
48
+ speech_filename,
49
+ sr=SAMPLE_RATE,
50
+ offset=speech_offset,
51
+ duration=speech_duration,
52
+ )
53
+
54
+ # plot
55
+ time = np.arange(0, len(signal)) / sample_rate
56
+ plt.figure(figsize=(12, 5))
57
+ plt.plot(time, signal, color='b')
58
+ for start, end in vad_segments:
59
+ plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start point')  # mark the segment start
60
+ plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end point')  # mark the segment end
61
+
62
+ plt.show()
63
+
64
+ return
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
examples/silero_vad_by_webrtcvad/{step_3_train_model.py → step_4_train_model.py} RENAMED
@@ -22,25 +22,26 @@ from torch.nn import functional as F
22
  from torch.utils.data.dataloader import DataLoader
23
  from tqdm import tqdm
24
 
25
- from toolbox.torch.utils.data.dataset.vad_jsonl_dataset import VadJsonlDataset
26
  from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
27
  from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadPretrainedModel
28
  from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
29
  from toolbox.torchaudio.losses.bce_loss import BCELoss
30
  from toolbox.torchaudio.losses.dice_loss import DiceLoss
31
  from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
 
32
 
33
 
34
  def get_args():
35
  parser = argparse.ArgumentParser()
36
- parser.add_argument("--train_dataset", default="train.jsonl", type=str)
37
- parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
38
 
39
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
40
  parser.add_argument("--patience", default=30, type=int)
41
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
42
 
43
- parser.add_argument("--config_file", default="config.yaml", type=str)
44
 
45
  args = parser.parse_args()
46
  return args
@@ -116,7 +117,7 @@ def main():
116
  logger.info(f"GPU available count: {n_gpu}; device: {device}")
117
 
118
  # datasets
119
- train_dataset = VadJsonlDataset(
120
  jsonl_file=args.train_dataset,
121
  expected_sample_rate=config.sample_rate,
122
  max_wave_value=32768.0,
@@ -124,7 +125,7 @@ def main():
124
  max_snr_db=config.max_snr_db,
125
  # skip=225000,
126
  )
127
- valid_dataset = VadJsonlDataset(
128
  jsonl_file=args.valid_dataset,
129
  expected_sample_rate=config.sample_rate,
130
  max_wave_value=32768.0,
@@ -205,6 +206,7 @@ def main():
205
  dice_loss_fn = DiceLoss(reduction="mean").to(device)
206
 
207
  vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
 
208
 
209
  # training loop
210
 
@@ -213,6 +215,11 @@ def main():
213
  average_bce_loss = 1000000000
214
  average_dice_loss = 1000000000
215
 
 
 
 
 
 
216
  model_list = list()
217
  best_epoch_idx = None
218
  best_step_idx = None
@@ -230,6 +237,7 @@ def main():
230
  # train
231
  model.train()
232
  vad_accuracy_metrics_fn.reset()
 
233
 
234
  total_loss = 0.
235
  total_bce_loss = 0.
@@ -259,6 +267,7 @@ def main():
259
  continue
260
 
261
  vad_accuracy_metrics_fn.__call__(probs, targets)
 
262
 
263
  optimizer.zero_grad()
264
  loss.backward()
@@ -277,14 +286,21 @@ def main():
277
 
278
  metrics = vad_accuracy_metrics_fn.get_metric()
279
  accuracy = metrics["accuracy"]
 
 
 
 
280
 
281
  progress_bar_train.update(1)
282
  progress_bar_train.set_postfix({
283
  "lr": lr_scheduler.get_last_lr()[0],
284
  "loss": average_loss,
285
- "average_bce_loss": average_bce_loss,
286
- "average_dice_loss": average_dice_loss,
287
  "accuracy": accuracy,
 
 
 
288
  })
289
 
290
  # evaluation
@@ -295,6 +311,7 @@ def main():
295
 
296
  model.eval()
297
  vad_accuracy_metrics_fn.reset()
 
298
 
299
  total_loss = 0.
300
  total_bce_loss = 0.
@@ -324,6 +341,7 @@ def main():
324
  continue
325
 
326
  vad_accuracy_metrics_fn.__call__(probs, targets)
 
327
 
328
  total_loss += loss.item()
329
  total_bce_loss += bce_loss.item()
@@ -336,18 +354,26 @@ def main():
336
 
337
  metrics = vad_accuracy_metrics_fn.get_metric()
338
  accuracy = metrics["accuracy"]
 
 
 
 
339
 
340
  progress_bar_eval.update(1)
341
  progress_bar_eval.set_postfix({
342
  "lr": lr_scheduler.get_last_lr()[0],
343
  "loss": average_loss,
344
- "average_bce_loss": average_bce_loss,
345
- "average_dice_loss": average_dice_loss,
346
  "accuracy": accuracy,
 
 
 
347
  })
348
 
349
  model.train()
350
  vad_accuracy_metrics_fn.reset()
 
351
 
352
  total_loss = 0.
353
  total_bce_loss = 0.
@@ -377,12 +403,12 @@ def main():
377
  if best_metric is None:
378
  best_epoch_idx = epoch_idx
379
  best_step_idx = step_idx
380
- best_metric = accuracy
381
- elif accuracy >= best_metric:
382
  # greater is better.
383
  best_epoch_idx = epoch_idx
384
  best_step_idx = step_idx
385
- best_metric = accuracy
386
  else:
387
  pass
388
 
 
22
  from torch.utils.data.dataloader import DataLoader
23
  from tqdm import tqdm
24
 
25
+ from toolbox.torch.utils.data.dataset.vad_padding_jsonl_dataset import VadPaddingJsonlDataset
26
  from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
27
  from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadPretrainedModel
28
  from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
29
  from toolbox.torchaudio.losses.bce_loss import BCELoss
30
  from toolbox.torchaudio.losses.dice_loss import DiceLoss
31
  from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
32
+ from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score
33
 
34
 
35
  def get_args():
36
  parser = argparse.ArgumentParser()
37
+ parser.add_argument("--train_dataset", default="train-vad.jsonl", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
39
 
40
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
41
  parser.add_argument("--patience", default=30, type=int)
42
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
43
 
44
+ parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
45
 
46
  args = parser.parse_args()
47
  return args
 
117
  logger.info(f"GPU available count: {n_gpu}; device: {device}")
118
 
119
  # datasets
120
+ train_dataset = VadPaddingJsonlDataset(
121
  jsonl_file=args.train_dataset,
122
  expected_sample_rate=config.sample_rate,
123
  max_wave_value=32768.0,
 
125
  max_snr_db=config.max_snr_db,
126
  # skip=225000,
127
  )
128
+ valid_dataset = VadPaddingJsonlDataset(
129
  jsonl_file=args.valid_dataset,
130
  expected_sample_rate=config.sample_rate,
131
  max_wave_value=32768.0,
 
206
  dice_loss_fn = DiceLoss(reduction="mean").to(device)
207
 
208
  vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
209
+ vad_f1_score_metrics_fn = VadF1Score(threshold=0.5)
210
 
211
  # training loop
212
 
 
215
  average_bce_loss = 1000000000
216
  average_dice_loss = 1000000000
217
 
218
+ accuracy = -1
219
+ f1 = -1
220
+ precision = -1
221
+ recall = -1
222
+
223
  model_list = list()
224
  best_epoch_idx = None
225
  best_step_idx = None
 
237
  # train
238
  model.train()
239
  vad_accuracy_metrics_fn.reset()
240
+ vad_f1_score_metrics_fn.reset()
241
 
242
  total_loss = 0.
243
  total_bce_loss = 0.
 
267
  continue
268
 
269
  vad_accuracy_metrics_fn.__call__(probs, targets)
270
+ vad_f1_score_metrics_fn.__call__(probs, targets)
271
 
272
  optimizer.zero_grad()
273
  loss.backward()
 
286
 
287
  metrics = vad_accuracy_metrics_fn.get_metric()
288
  accuracy = metrics["accuracy"]
289
+ metrics = vad_f1_score_metrics_fn.get_metric()
290
+ f1 = metrics["f1"]
291
+ precision = metrics["precision"]
292
+ recall = metrics["recall"]
293
 
294
  progress_bar_train.update(1)
295
  progress_bar_train.set_postfix({
296
  "lr": lr_scheduler.get_last_lr()[0],
297
  "loss": average_loss,
298
+ "bce_loss": average_bce_loss,
299
+ "dice_loss": average_dice_loss,
300
  "accuracy": accuracy,
301
+ "f1": f1,
302
+ "precision": precision,
303
+ "recall": recall,
304
  })
305
 
306
  # evaluation
 
311
 
312
  model.eval()
313
  vad_accuracy_metrics_fn.reset()
314
+ vad_f1_score_metrics_fn.reset()
315
 
316
  total_loss = 0.
317
  total_bce_loss = 0.
 
341
  continue
342
 
343
  vad_accuracy_metrics_fn.__call__(probs, targets)
344
+ vad_f1_score_metrics_fn.__call__(probs, targets)
345
 
346
  total_loss += loss.item()
347
  total_bce_loss += bce_loss.item()
 
354
 
355
  metrics = vad_accuracy_metrics_fn.get_metric()
356
  accuracy = metrics["accuracy"]
357
+ metrics = vad_f1_score_metrics_fn.get_metric()
358
+ f1 = metrics["f1"]
359
+ precision = metrics["precision"]
360
+ recall = metrics["recall"]
361
 
362
  progress_bar_eval.update(1)
363
  progress_bar_eval.set_postfix({
364
  "lr": lr_scheduler.get_last_lr()[0],
365
  "loss": average_loss,
366
+ "bce_loss": average_bce_loss,
367
+ "dice_loss": average_dice_loss,
368
  "accuracy": accuracy,
369
+ "f1": f1,
370
+ "precision": precision,
371
+ "recall": recall,
372
  })
373
 
374
  model.train()
375
  vad_accuracy_metrics_fn.reset()
376
+ vad_f1_score_metrics_fn.reset()
377
 
378
  total_loss = 0.
379
  total_bce_loss = 0.
 
403
  if best_metric is None:
404
  best_epoch_idx = epoch_idx
405
  best_step_idx = step_idx
406
+ best_metric = f1
407
+ elif f1 >= best_metric:
408
  # greater is better.
409
  best_epoch_idx = epoch_idx
410
  best_step_idx = step_idx
411
+ best_metric = f1
412
  else:
413
  pass
414
 
install.sh CHANGED
@@ -1,9 +1,9 @@
1
  #!/usr/bin/env bash
2
 
3
- # bash install.sh --stage 2 --stop_stage 2 --system_version centos
4
 
5
 
6
- python_version=3.12.1
7
  system_version="centos";
8
 
9
  verbose=true;
@@ -41,20 +41,42 @@ while true; do
41
  done
42
 
43
  work_dir="$(pwd)"
 
 
 
44
 
45
 
46
  if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
47
- $verbose && echo "stage 1: install python"
48
  cd "${work_dir}" || exit 1;
49
 
50
- sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
 
51
  fi
52
 
53
 
54
  if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
55
- $verbose && echo "stage 2: create virtualenv"
56
 
57
- # /usr/local/python-3.12.1/bin/virtualenv cc_vad
 
58
  # source /data/local/bin/cc_vad/bin/activate
59
  /usr/local/python-${python_version}/bin/pip3 install virtualenv
60
  mkdir -p /data/local/bin
 
1
  #!/usr/bin/env bash
2
 
3
+ # bash install.sh --stage 1 --stop_stage 2 --system_version centos
4
 
5
 
6
+ python_version=3.12.8
7
  system_version="centos";
8
 
9
  verbose=true;
 
41
  done
42
 
43
  work_dir="$(pwd)"
44
+ trained_models_dir="$(pwd)/trained_models"
45
+
46
+ mkdir -p "${trained_models_dir}"
47
 
48
 
49
  if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
50
+ $verbose && echo "stage 1: download sound models"
51
  cd "${work_dir}" || exit 1;
52
 
53
+ python download_sound_models.py
54
+
55
  fi
56
 
57
 
58
  if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
59
+ $verbose && echo "stage 2: download silero vad model"
60
+ cd "${trained_models_dir}" || exit 1;
61
+
62
+ wget https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
63
+
64
+ fi
65
+
66
+
67
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
68
+ $verbose && echo "stage 3: install python"
69
+ cd "${work_dir}" || exit 1;
70
+
71
+ sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
72
+ fi
73
+
74
+
75
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
76
+ $verbose && echo "stage 4: create virtualenv"
77
 
78
+ # /usr/local/python-3.9.9/bin/pip3 install virtualenv
79
+ # /usr/local/python-3.9.9/bin/virtualenv cc_vad
80
  # source /data/local/bin/cc_vad/bin/activate
81
  /usr/local/python-${python_version}/bin/pip3 install virtualenv
82
  mkdir -p /data/local/bin
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py ADDED
@@ -0,0 +1,240 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import random
5
+ from typing import List
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset, IterableDataset
11
+
12
+
13
+ class VadPaddingJsonlDataset(IterableDataset):
14
+ def __init__(self,
15
+ jsonl_file: str,
16
+ expected_sample_rate: int,
17
+ resample: bool = False,
18
+ max_wave_value: float = 1.0,
19
+ buffer_size: int = 1000,
20
+ min_snr_db: float = None,
21
+ max_snr_db: float = None,
22
+ speech_target_duration: float = 8.0,
23
+ eps: float = 1e-8,
24
+ skip: int = 0,
25
+ ):
26
+ self.jsonl_file = jsonl_file
27
+ self.expected_sample_rate = expected_sample_rate
28
+ self.resample = resample
29
+ self.max_wave_value = max_wave_value
30
+ self.min_snr_db = min_snr_db
31
+ self.max_snr_db = max_snr_db
32
+ self.speech_target_duration = speech_target_duration
33
+ self.eps = eps
34
+ self.skip = skip
35
+
36
+ self.buffer_size = buffer_size
37
+ self.buffer_samples: List[dict] = list()
38
+
39
+ def __iter__(self):
40
+ self.buffer_samples = list()
41
+
42
+ iterable_source = self.iterable_source()
43
+
44
+ try:
45
+ for _ in range(self.skip):
46
+ next(iterable_source)
47
+ except StopIteration:
48
+ pass
49
+
50
+ # 初始填充缓冲区
51
+ try:
52
+ for _ in range(self.buffer_size):
53
+ self.buffer_samples.append(next(iterable_source))
54
+ except StopIteration:
55
+ pass
56
+
57
+ # 动态替换逻辑
58
+ while True:
59
+ try:
60
+ item = next(iterable_source)
61
+ # 随机替换缓冲区元素
62
+ replace_idx = random.randint(0, len(self.buffer_samples) - 1)
63
+ sample = self.buffer_samples[replace_idx]
64
+ self.buffer_samples[replace_idx] = item
65
+ yield self.convert_sample(sample)
66
+ except StopIteration:
67
+ break
68
+
69
+ # 清空剩余元素
70
+ random.shuffle(self.buffer_samples)
71
+ for sample in self.buffer_samples:
72
+ yield self.convert_sample(sample)
73
+
74
+ def iterable_source(self):
75
+ last_sample = None
76
+ with open(self.jsonl_file, "r", encoding="utf-8") as f:
77
+ for row in f:
78
+ row = json.loads(row)
79
+
80
+ speech_filename = row["speech_filename"]
81
+ speech_raw_duration = row["speech_raw_duration"]
82
+ speech_offset = row["speech_offset"]
83
+ speech_duration = row["speech_duration"]
84
+
85
+ noise_list = row["noise_list"]
86
+ noise_list = [
87
+ {
88
+ "filename": noise["filename"],
89
+ "raw_duration": noise["raw_duration"],
90
+ "offset": noise["offset"],
91
+ "duration": noise["duration"],
92
+ }
93
+ for noise in noise_list
94
+ ]
95
+
96
+ if self.min_snr_db is None or self.max_snr_db is None:
97
+ snr_db = row["snr_db"]
98
+ else:
99
+ snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
100
+
101
+ vad_segments = row["vad_segments"]
102
+
103
+ sample = {
104
+ "speech_filename": speech_filename,
105
+ "speech_raw_duration": speech_raw_duration,
106
+ "speech_offset": speech_offset,
107
+ "speech_duration": speech_duration,
108
+
109
+ "noise_list": noise_list,
110
+
111
+ "snr_db": snr_db,
112
+
113
+ "vad_segments": vad_segments,
114
+ }
115
+ if last_sample is None:
116
+ last_sample = sample
117
+ continue
118
+ yield sample
119
+ yield last_sample
120
+
121
+ def convert_sample(self, sample: dict):
122
+ speech_filename = sample["speech_filename"]
123
+ speech_offset = sample["speech_offset"]
124
+ speech_duration = sample["speech_duration"]
125
+
126
+ noise_list = sample["noise_list"]
127
+
128
+ snr_db = sample["snr_db"]
129
+
130
+ vad_segments = sample["vad_segments"]
131
+
132
+ speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration)
133
+ speech_wave_np = speech_wave.numpy()
134
+ speech_wave_np, left_pad_duration, _ = self.pad_waveform(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
135
+ speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
136
+
137
+ noise_wave_list = list()
138
+ for noise in noise_list:
139
+ filename = noise["filename"]
140
+ offset = noise["offset"]
141
+ duration = noise["duration"]
142
+ noise_wave_: torch.Tensor = self.filename_to_waveform(filename, offset, duration)
143
+ noise_wave_list.append(noise_wave_)
144
+ noise_wave = torch.cat(noise_wave_list, dim=-1)
145
+ noise_wave_np = noise_wave.numpy()
146
+ noise_wave_np = self.make_sure_duration(noise_wave_np, self.expected_sample_rate, self.speech_target_duration)
147
+
148
+ noisy_wave_np, _ = self.mix_speech_and_noise(
149
+ speech=speech_wave_np,
150
+ noise=noise_wave_np,
151
+ snr_db=snr_db, eps=self.eps,
152
+ )
153
+ noisy_wave = torch.tensor(noisy_wave_np, dtype=torch.float32)
154
+
155
+ vad_segments = [
156
+ [
157
+ vad_segment[0] + left_pad_duration,
158
+ vad_segment[1] + left_pad_duration,
159
+ ]
160
+ for vad_segment in vad_segments
161
+ ]
162
+
163
+ result = {
164
+ "noisy_wave": noisy_wave,
165
+ "vad_segments": vad_segments,
166
+ }
167
+ return result
168
+
169
+ def filename_to_waveform(self, filename: str, offset: float, duration: float):
170
+ try:
171
+ waveform, sample_rate = librosa.load(
172
+ filename,
173
+ sr=self.expected_sample_rate,
174
+ offset=offset,
175
+ duration=duration,
176
+ )
177
+ except ValueError as e:
178
+ print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
179
+ raise e
180
+ waveform = torch.tensor(waveform, dtype=torch.float32)
181
+ return waveform
182
+
183
+ @staticmethod
184
+ def pad_waveform(waveform: np.ndarray, sample_rate: int = 8000, target_duration: float = 8.0):
185
+ num_samples = len(waveform)
186
+ target_num_samples = int(sample_rate * target_duration)
187
+ if target_num_samples < num_samples:
188
+ return waveform, 0, 0
189
+
190
+ left_pad_size = (target_num_samples - num_samples) // 2
191
+ right_pad_size = target_num_samples - num_samples - left_pad_size
192
+ result = np.concat([
193
+ np.zeros(left_pad_size, dtype=waveform.dtype),
194
+ waveform,
195
+ np.zeros(right_pad_size, dtype=waveform.dtype),
196
+ ])
197
+
198
+ left_pad_duration = left_pad_size / sample_rate
199
+ right_pad_duration = right_pad_size / sample_rate
200
+ return result, left_pad_duration, right_pad_duration
201
+
202
+ @staticmethod
203
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8):
204
+ l1 = len(speech)
205
+ l2 = len(noise)
206
+ l = min(l1, l2)
207
+ speech = speech[:l]
208
+ noise = noise[:l]
209
+
210
+ # np.float32, value between (-1, 1).
211
+
212
+ speech_power = np.mean(np.square(speech))
213
+ noise_power = speech_power / (10 ** (snr_db / 10))
214
+
215
+ noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps)
216
+
217
+ noisy_signal = speech + noise_adjusted
218
+
219
+ return noisy_signal, noise_adjusted
220
+
221
+ @staticmethod
222
+ def make_sure_duration(waveform: np.ndarray, sample_rate: int = 8000, target_duration: float = 8.0):
223
+ num_samples = len(waveform)
224
+ target_num_samples = int(sample_rate * target_duration)
225
+
226
+ if target_num_samples < num_samples:
227
+ waveform = waveform[:target_num_samples]
228
+ elif target_num_samples > num_samples:
229
+ pad_size = target_num_samples - num_samples
230
+ waveform = np.concat([
231
+ waveform,
232
+ np.zeros(pad_size, dtype=waveform.dtype),
233
+ ])
234
+ else:
235
+ pass
236
+ return waveform
237
+
238
+
239
+ if __name__ == "__main__":
240
+ pass
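The scaling in mix_speech_and_noise is chosen so that 10 * log10(speech_power / noise_power) equals the requested snr_db; a self-contained sanity check with synthetic signals (assumption: random 1 s signals at 8 kHz, not data from the dataset):

    import numpy as np

    rng = np.random.default_rng(0)
    speech = 0.1 * rng.standard_normal(8000).astype(np.float32)
    noise = 0.5 * rng.standard_normal(8000).astype(np.float32)
    snr_db = 10.0

    # same steps as mix_speech_and_noise above
    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))
    noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + 1e-8)

    achieved_snr_db = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
    print(round(float(achieved_snr_db), 2))  # ~10.0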
toolbox/torch/utils/data/vocabulary.py ADDED
@@ -0,0 +1,211 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from collections import defaultdict, OrderedDict
4
+ import os
5
+ from typing import Any, Callable, Dict, Iterable, List, Set
6
+
7
+
8
+ def namespace_match(pattern: str, namespace: str):
9
+ """
10
+ Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
11
+ ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not
12
+ ``stemmed_tokens``.
13
+ """
14
+ if pattern[0] == '*' and namespace.endswith(pattern[1:]):
15
+ return True
16
+ elif pattern == namespace:
17
+ return True
18
+ return False
19
+
20
+
21
+ class _NamespaceDependentDefaultDict(defaultdict):
22
+ def __init__(self,
23
+ non_padded_namespaces: Set[str],
24
+ padded_function: Callable[[], Any],
25
+ non_padded_function: Callable[[], Any]) -> None:
26
+ self._non_padded_namespaces = set(non_padded_namespaces)
27
+ self._padded_function = padded_function
28
+ self._non_padded_function = non_padded_function
29
+ super(_NamespaceDependentDefaultDict, self).__init__()
30
+
31
+ def __missing__(self, key: str):
32
+ if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
33
+ value = self._non_padded_function()
34
+ else:
35
+ value = self._padded_function()
36
+ dict.__setitem__(self, key, value)
37
+ return value
38
+
39
+ def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
40
+ # add non_padded_namespaces which weren't already present
41
+ self._non_padded_namespaces.update(non_padded_namespaces)
42
+
43
+
44
+ class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
45
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
46
+ super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces,
47
+ lambda: {padding_token: 0, oov_token: 1},
48
+ lambda: {})
49
+
50
+
51
+ class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
52
+ def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
53
+ super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces,
54
+ lambda: {0: padding_token, 1: oov_token},
55
+ lambda: {})
56
+
57
+
58
+ DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
59
+ DEFAULT_PADDING_TOKEN = '[PAD]'
60
+ DEFAULT_OOV_TOKEN = '[UNK]'
61
+ NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
62
+
63
+
64
+ class Vocabulary(object):
65
+ def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
66
+ self._non_padded_namespaces = set(non_padded_namespaces)
67
+ self._padding_token = DEFAULT_PADDING_TOKEN
68
+ self._oov_token = DEFAULT_OOV_TOKEN
69
+ self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
70
+ self._padding_token,
71
+ self._oov_token)
72
+ self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
73
+ self._padding_token,
74
+ self._oov_token)
75
+
76
+ def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
77
+ if token not in self._token_to_index[namespace]:
78
+ index = len(self._token_to_index[namespace])
79
+ self._token_to_index[namespace][token] = index
80
+ self._index_to_token[namespace][index] = token
81
+ return index
82
+ else:
83
+ return self._token_to_index[namespace][token]
84
+
85
+ def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
86
+ return self._index_to_token[namespace]
87
+
88
+ def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
89
+ return self._token_to_index[namespace]
90
+
91
+ def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
92
+ if token in self._token_to_index[namespace]:
93
+ return self._token_to_index[namespace][token]
94
+ else:
95
+ return self._token_to_index[namespace][self._oov_token]
96
+
97
+ def get_token_from_index(self, index: int, namespace: str = 'tokens'):
98
+ return self._index_to_token[namespace][index]
99
+
100
+ def get_vocab_size(self, namespace: str = 'tokens') -> int:
101
+ return len(self._token_to_index[namespace])
102
+
103
+ def save_to_files(self, directory: str):
104
+ os.makedirs(directory, exist_ok=True)
105
+ with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
106
+ for namespace_str in self._non_padded_namespaces:
107
+ f.write('{}\n'.format(namespace_str))
108
+
109
+ for namespace, token_to_index in self._token_to_index.items():
110
+ filename = os.path.join(directory, '{}.txt'.format(namespace))
111
+ with open(filename, 'w', encoding='utf-8') as f:
112
+ for token, _ in token_to_index.items():
113
+ f.write('{}\n'.format(token))
114
+
115
+ @classmethod
116
+ def from_files(cls, directory: str) -> 'Vocabulary':
117
+ with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
118
+ non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
119
+
120
+ vocab = cls(non_padded_namespaces=non_padded_namespaces)
121
+
122
+ for namespace_filename in os.listdir(directory):
123
+ if namespace_filename == NAMESPACE_PADDING_FILE:
124
+ continue
125
+ if namespace_filename.startswith("."):
126
+ continue
127
+ namespace = namespace_filename.replace('.txt', '')
128
+ if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
129
+ is_padded = False
130
+ else:
131
+ is_padded = True
132
+ filename = os.path.join(directory, namespace_filename)
133
+ vocab.set_from_file(filename, is_padded, namespace=namespace)
134
+
135
+ return vocab
136
+
137
+ def set_from_file(self,
138
+ filename: str,
139
+ is_padded: bool = True,
140
+ oov_token: str = DEFAULT_OOV_TOKEN,
141
+ namespace: str = "tokens"
142
+ ):
143
+ if is_padded:
144
+ self._token_to_index[namespace] = {self._padding_token: 0}
145
+ self._index_to_token[namespace] = {0: self._padding_token}
146
+ else:
147
+ self._token_to_index[namespace] = {}
148
+ self._index_to_token[namespace] = {}
149
+
150
+ with open(filename, 'r', encoding='utf-8') as f:
151
+ index = 1 if is_padded else 0
152
+ for row in f:
153
+ token = str(row).strip()
154
+ if token == oov_token:
155
+ token = self._oov_token
156
+ self._token_to_index[namespace][token] = index
157
+ self._index_to_token[namespace][index] = token
158
+ index += 1
159
+
160
+ def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
161
+ result = list()
162
+ for token in tokens:
163
+ idx = self._token_to_index[namespace].get(token)
164
+ if idx is None:
165
+ idx = self._token_to_index[namespace][self._oov_token]
166
+ result.append(idx)
167
+ return result
168
+
169
+ def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
170
+ result = list()
171
+ for idx in ids:
172
+ idx = self._index_to_token[namespace][idx]
173
+ result.append(idx)
174
+ return result
175
+
176
+ def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
177
+ pad_idx = self._token_to_index[namespace][self._padding_token]
178
+
179
+ length = len(ids)
180
+ if length > max_length:
181
+ result = ids[:max_length]
182
+ else:
183
+ result = ids + [pad_idx] * (max_length - length)
184
+ return result
185
+
186
+
187
+ def demo1():
188
+ import jieba
189
+
190
+ vocabulary = Vocabulary()
191
+ vocabulary.add_token_to_namespace('白天', 'tokens')
192
+ vocabulary.add_token_to_namespace('晚上', 'tokens')
193
+
194
+ text = '不是在白天, 就是在晚上'
195
+ tokens = jieba.lcut(text)
196
+
197
+ print(tokens)
198
+
199
+ ids = vocabulary.convert_tokens_to_ids(tokens)
200
+ print(ids)
201
+
202
+ padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
203
+ print(padded_idx)
204
+
205
+ tokens = vocabulary.convert_ids_to_tokens(padded_idx)
206
+ print(tokens)
207
+ return
208
+
209
+
210
+ if __name__ == '__main__':
211
+ demo1()
toolbox/torchaudio/metrics/vad_metrics/vad_f1_score.py ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ class VadF1Score(object):
7
+ def __init__(self, threshold: float = 0.5, epsilon: float = 1e-12) -> None:
8
+ self.threshold = threshold
9
+ self.epsilon = epsilon  # guard against division by zero
10
+
11
+ self.true_positives = 0.0
12
+ self.false_positives = 0.0
13
+ self.false_negatives = 0.0
14
+
15
+ def __call__(self,
16
+ predictions: torch.Tensor,
17
+ gold_labels: torch.Tensor,
18
+ ):
19
+ """
20
+ :param predictions: [b, t, 1], probabilities after the sigmoid
21
+ :param gold_labels: [b, t, 1], binary labels (0 or 1)
22
+ """
23
+ # binarize the predicted probabilities with the threshold
24
+ pred_labels = (predictions > self.threshold).float()
25
+
26
+ # compute TP / FP / FN
27
+ tp = (pred_labels * gold_labels).sum() # True Positives
28
+ fp = (pred_labels * (1 - gold_labels)).sum() # False Positives
29
+ fn = ((1 - pred_labels) * gold_labels).sum() # False Negatives
30
+
31
+ # accumulate the running counts
32
+ self.true_positives += tp.item()
33
+ self.false_positives += fp.item()
34
+ self.false_negatives += fn.item()
35
+
36
+ def get_metric(self, reset: bool = False):
37
+ # compute precision and recall
38
+ precision = self.true_positives / (self.true_positives + self.false_positives + self.epsilon)
39
+ recall = self.true_positives / (self.true_positives + self.false_negatives + self.epsilon)
40
+
41
+ # compute the F1 score
42
+ f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)
43
+
44
+ if reset:
45
+ self.reset()
46
+
47
+ return {
48
+ 'f1': f1,
49
+ 'precision': precision,
50
+ 'recall': recall
51
+ }
52
+
53
+ def reset(self):
54
+ self.true_positives = 0.0
55
+ self.false_positives = 0.0
56
+ self.false_negatives = 0.0
57
+
58
+
59
+ if __name__ == "__main__":
60
+ pass
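
A minimal usage sketch for the streaming F1 accumulator above, with random tensors standing in for sigmoid outputs and binary targets; the import path simply mirrors the file location added in this commit.

import torch

from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score

metric = VadF1Score(threshold=0.5)

for _ in range(3):
    # probabilities after sigmoid and binary targets, both of shape [b, t, 1]
    probs = torch.rand(size=(4, 100, 1))
    targets = (torch.rand(size=(4, 100, 1)) > 0.5).float()
    metric(probs, targets)

scores = metric.get_metric(reset=True)
print(scores["f1"], scores["precision"], scores["recall"])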
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py ADDED
@@ -0,0 +1,169 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ from pathlib import Path
6
+ import shutil
7
+ import tempfile
8
+ import zipfile
9
+
10
+ from scipy.io import wavfile
11
+ import librosa
12
+ import numpy as np
13
+ import torch
14
+ import torchaudio
15
+
16
+ torch.set_num_threads(1)
17
+
18
+ from project_settings import project_path
19
+ from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
20
+ from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadPretrainedModel, MODEL_FILE
21
+ from toolbox.vad.vad import FrameVoiceClassifier, RingVad, process_speech_probs, make_visualization
22
+
23
+
24
+ logger = logging.getLogger("toolbox")
25
+
26
+
27
+ class SileroVadVoiceClassifier(FrameVoiceClassifier):
28
+ def __init__(self,
29
+ pretrained_model_path_or_zip_file: str,
30
+ device: str = "cpu",
31
+ ):
32
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
33
+ self.device = torch.device(device)
34
+
35
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
36
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
37
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
38
+
39
+ self.config = config
40
+ self.model = model
41
+ self.model.to(device)
42
+ self.model.eval()
43
+
44
+ def load_models(self, model_path: str):
45
+ model_path = Path(model_path)
46
+ is_zip_file = model_path.name.endswith(".zip")
+ if is_zip_file:
47
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
48
+ out_root = Path(tempfile.gettempdir()) / "cc_vad"
49
+ out_root.mkdir(parents=True, exist_ok=True)
50
+ f_zip.extractall(path=out_root)
51
+ model_path = out_root / model_path.stem
52
+
53
+ config = SileroVadConfig.from_pretrained(
54
+ pretrained_model_name_or_path=model_path.as_posix(),
55
+ )
56
+ model = SileroVadPretrainedModel.from_pretrained(
57
+ pretrained_model_name_or_path=model_path.as_posix(),
58
+ )
59
+ model.to(self.device)
60
+ model.eval()
61
+
62
+ if is_zip_file:
+ # remove only the temporary extraction directory, never a user-supplied model directory
+ shutil.rmtree(model_path)
63
+ return config, model
64
+
65
+ def predict(self, chunk: np.ndarray) -> float:
66
+ if chunk.dtype != np.int16:
67
+ raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
68
+
69
+ chunk = chunk / 32768
70
+
71
+ inputs = torch.tensor(chunk, dtype=torch.float32)
72
+ inputs = torch.unsqueeze(inputs, dim=0)
73
+
74
+ try:
75
+ logits, _ = self.model.forward(inputs)
76
+ except RuntimeError as e:
77
+ print(inputs.shape)
78
+ raise e
79
+ # logits shape: [b, t, 1]
80
+ logits_ = torch.mean(logits, dim=1)
81
+ # logits_ shape: [b, 1]
82
+ probs = torch.sigmoid(logits_)
83
+
84
+ voice_prob = probs[0][0]
85
+ return float(voice_prob)
86
+
87
+
88
+ class InferenceSileroVad(object):
89
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
90
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
91
+ self.device = torch.device(device)
92
+
93
+ self.voice_classifier = SileroVadVoiceClassifier(pretrained_model_path_or_zip_file, device=device)
94
+
95
+ self.ring_vad = RingVad(model=self.voice_classifier,
96
+ start_ring_rate=0.2,
97
+ end_ring_rate=0.1,
98
+ frame_size_ms=30,
99
+ frame_step_ms=30,
100
+ padding_length_ms=300,
101
+ max_silence_length_ms=300,
102
+ sample_rate=SAMPLE_RATE,
103
+ )
104
+
105
+ def vad(self, signal: np.ndarray) -> list:
106
+ self.ring_vad.reset()
107
+
108
+ vad_segments = list()
109
+
110
+ segments = self.ring_vad.vad(signal)
111
+ vad_segments += segments
112
+ # last vad segment
113
+ segments = self.ring_vad.last_vad_segments()
114
+ vad_segments += segments
115
+ return vad_segments
116
+
117
+ def get_vad_speech_probs(self):
118
+ result = self.ring_vad.speech_probs
119
+ return result
120
+
121
+ def get_vad_frame_step(self):
122
+ result = self.ring_vad.frame_step
123
+ return result
124
+
125
+
126
+ def get_args():
127
+ parser = argparse.ArgumentParser()
128
+ parser.add_argument(
129
+ "--wav_file",
130
+ default=(project_path / "data/examples/hado/2f16ca0b-baec-4601-8a1e-7893eb875623.wav").as_posix(),
131
+ type=str,
132
+ )
133
+ args = parser.parse_args()
134
+ return args
135
+
136
+
137
+ SAMPLE_RATE = 8000
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ sample_rate, signal = wavfile.read(args.wav_file)
144
+ if SAMPLE_RATE != sample_rate:
145
+ raise AssertionError
146
+
147
+ infer = InferenceSileroVad(
148
+ pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
149
+ )
150
+
151
+ vad_segments = infer.vad(signal)
152
+
153
+ speech_probs = infer.get_vad_speech_probs()
154
+ frame_step = infer.get_vad_frame_step()
155
+
156
+ # speech_probs
157
+ speech_probs = process_speech_probs(
158
+ signal=signal,
159
+ speech_probs=speech_probs,
160
+ frame_step=frame_step,
161
+ )
162
+
163
+ # plot
164
+ make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
165
+ return
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()
toolbox/vad/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/vad/vad.py ADDED
@@ -0,0 +1,450 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import collections
5
+ from functools import lru_cache
6
+ import os
7
+ from pathlib import Path
8
+ import shutil
9
+ import tempfile
10
+ import zipfile
11
+ from typing import List
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from scipy.io import wavfile
16
+ import torch
17
+ import webrtcvad
18
+
19
+ from project_settings import project_path
20
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
21
+
22
+
23
+ class FrameVoiceClassifier(object):
24
+ def predict(self, chunk: np.ndarray) -> float:
25
+ raise NotImplementedError
26
+
27
+
28
+ class WebRTCVoiceClassifier(FrameVoiceClassifier):
29
+ def __init__(self,
30
+ agg: int = 3,
31
+ sample_rate: int = 8000
32
+ ):
33
+ self.agg = agg
34
+ self.sample_rate = sample_rate
35
+
36
+ self.model = webrtcvad.Vad(mode=agg)
37
+
38
+ def predict(self, chunk: np.ndarray) -> float:
39
+ if chunk.dtype != np.int16:
40
+ raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
41
+
42
+ audio_bytes = bytes(chunk)
43
+ is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
44
+ return 1.0 if is_speech else 0.0
45
+
46
+
47
+ class SileroVoiceClassifier(FrameVoiceClassifier):
48
+ def __init__(self,
49
+ model_path: str,
50
+ sample_rate: int = 8000):
51
+ self.model_path = model_path
52
+ self.sample_rate = sample_rate
53
+
54
+ with open(self.model_path, "rb") as f:
55
+ model = torch.jit.load(f, map_location="cpu")
56
+ self.model = model
57
+ self.model.reset_states()
58
+
59
+ def predict(self, chunk: np.ndarray) -> float:
60
+ if self.sample_rate / len(chunk) > 31.25:
61
+ raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
62
+ if chunk.dtype != np.int16:
63
+ raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
64
+
65
+ num_samples = len(chunk)
66
+ if self.sample_rate == 8000 and num_samples != 256:
67
+ raise AssertionError("win size must be 256 samples (32 ms) for silero vad at 8000 Hz.")
68
+ if self.sample_rate == 16000 and num_samples != 512:
69
+ raise AssertionError("win size must be 512 samples (32 ms) for silero vad at 16000 Hz.")
70
+
71
+ chunk = chunk / 32768
72
+ chunk = torch.tensor(chunk, dtype=torch.float32)
73
+ speech_prob = self.model(chunk, self.sample_rate).item()
74
+ return float(speech_prob)
75
+
76
+
77
+ class CCSoundsClassifier(FrameVoiceClassifier):
78
+ def __init__(self,
79
+ model_path: str,
80
+ sample_rate: int = 8000):
81
+ self.model_path = model_path
82
+ self.sample_rate = sample_rate
83
+
84
+ d = self.load_model(Path(model_path))
85
+
86
+ model = d["model"]
87
+ vocabulary = d["vocabulary"]
88
+
89
+ self.model = model
90
+ self.vocabulary = vocabulary
91
+
92
+ @staticmethod
93
+ @lru_cache(maxsize=100)
94
+ def load_model(model_file: Path):
95
+ with zipfile.ZipFile(model_file, "r") as f_zip:
96
+ out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
97
+ if out_root.exists():
98
+ shutil.rmtree(out_root.as_posix())
99
+ out_root.mkdir(parents=True, exist_ok=True)
100
+ f_zip.extractall(path=out_root)
101
+
102
+ tgt_path = out_root / model_file.stem
103
+ jit_model_file = tgt_path / "trace_model.zip"
104
+ vocab_path = tgt_path / "vocabulary"
105
+
106
+ vocabulary = Vocabulary.from_files(vocab_path.as_posix())
107
+
108
+ with open(jit_model_file.as_posix(), "rb") as f:
109
+ model = torch.jit.load(f)
110
+ model.eval()
111
+
112
+ shutil.rmtree(tgt_path)
113
+
114
+ d = {
115
+ "model": model,
116
+ "vocabulary": vocabulary
117
+ }
118
+ return d
119
+
120
+ def predict(self, chunk: np.ndarray) -> float:
121
+ if chunk.dtype != np.int16:
122
+ raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
123
+
124
+ chunk = chunk / (1 << 15)
125
+ inputs = torch.tensor(chunk, dtype=torch.float32)
126
+ inputs = torch.unsqueeze(inputs, dim=0)
127
+
128
+ with torch.no_grad():
129
+ logits = self.model(inputs)
130
+ probs = torch.nn.functional.softmax(logits, dim=-1)
131
+
132
+ voice_idx = self.vocabulary.get_token_index(token="voice", namespace="labels")
133
+
134
+ probs = probs.cpu()
135
+
136
+ voice_prob = probs[0][voice_idx]
137
+ return float(voice_prob)
138
+
139
+
140
+ class Frame(object):
141
+ def __init__(self, signal: np.ndarray, timestamp_s: float):
142
+ self.signal = signal
143
+ self.timestamp_s = timestamp_s
144
+
145
+
146
+ class RingVad(object):
147
+ def __init__(self,
148
+ model: FrameVoiceClassifier,
149
+ start_ring_rate: float = 0.5,
150
+ end_ring_rate: float = 0.5,
151
+ frame_size_ms: int = 30,
152
+ frame_step_ms: int = 30,
153
+ padding_length_ms: int = 300,
154
+ max_silence_length_ms: int = 300,
155
+ max_speech_length_s: float = 2.0,
156
+ min_speech_length_s: float = 0.3,
157
+ sample_rate: int = 8000
158
+ ):
159
+ self.model = model
160
+ self.start_ring_rate = start_ring_rate
161
+ self.end_ring_rate = end_ring_rate
162
+ self.frame_size_ms = frame_size_ms
163
+ self.frame_step_ms = frame_step_ms
164
+ self.padding_length_ms = padding_length_ms
165
+ self.max_silence_length_ms = max_silence_length_ms
166
+ self.max_speech_length_s = max_speech_length_s
167
+ self.min_speech_length_s = min_speech_length_s
168
+ self.sample_rate = sample_rate
169
+
170
+ # frames
171
+ self.frame_size = int(sample_rate * (frame_size_ms / 1000.0))
172
+ self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
173
+ self.frame_timestamp_s = 0.0
174
+ self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)
175
+
176
+ # segments
177
+ self.num_padding_frames = int(padding_length_ms / frame_step_ms)
178
+ self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
179
+ self.triggered = False
180
+ self.voiced_frames: List[Frame] = list()
181
+ self.segments = list()
182
+
183
+ # vad segments
184
+ self.is_first_segment = True
185
+ self.timestamp_start_s = 0.0
186
+ self.timestamp_end_s = 0.0
187
+
188
+ # speech probs
189
+ self.speech_probs: List[float] = list()
190
+
191
+ def reset(self):
192
+ # frames
193
+ self.frame_size = int(self.sample_rate * (self.frame_size_ms / 1000.0))
194
+ self.frame_step = int(self.sample_rate * (self.frame_step_ms / 1000.0))
195
+ self.frame_timestamp_s = 0.0
196
+ self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)
197
+
198
+ # segments
199
+ self.num_padding_frames = int(self.padding_length_ms / self.frame_step_ms)
200
+ self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
201
+ self.triggered = False
202
+ self.voiced_frames: List[Frame] = list()
203
+ self.segments = list()
204
+
205
+ # vad segments
206
+ self.is_first_segment = True
207
+ self.timestamp_start_s = 0.0
208
+ self.timestamp_end_s = 0.0
209
+
210
+ # speech probs
211
+ self.speech_probs: List[float] = list()
212
+
213
+ def signal_to_frames(self, signal: np.ndarray):
214
+ frames = list()
215
+
216
+ l = len(signal)
217
+
218
+ duration_s = float(self.frame_step) / self.sample_rate
219
+
220
+ for offset in range(0, l - self.frame_size + 1, self.frame_step):
221
+ sub_signal = signal[offset:offset+self.frame_size]
222
+ frame = Frame(sub_signal, self.frame_timestamp_s)
223
+ self.frame_timestamp_s += duration_s
224
+
225
+ frames.append(frame)
226
+ return frames
227
+
228
+ def segments_generator(self, signal: np.ndarray):
229
+ # signal rounding
230
+ if self.signal_cache is not None:
231
+ signal = np.concatenate([self.signal_cache, signal])
232
+
233
+ # rest
234
+ rest = (len(signal) - self.frame_size) % self.frame_step
235
+
236
+ if rest == 0:
237
+ self.signal_cache = None
238
+ signal_ = signal
239
+ else:
240
+ self.signal_cache = signal[-rest:]
241
+ signal_ = signal[:-rest]
242
+
243
+ # frames
244
+ frames = self.signal_to_frames(signal_)
245
+
246
+ for frame in frames:
247
+ speech_prob = self.model.predict(frame.signal)
248
+ self.speech_probs.append(speech_prob)
249
+
250
+ if not self.triggered:
251
+ self.ring_buffer.append((frame, speech_prob))
252
+ num_voiced = sum([p for _, p in self.ring_buffer])
253
+
254
+ if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
255
+ self.triggered = True
256
+
257
+ for f, _ in self.ring_buffer:
258
+ self.voiced_frames.append(f)
259
+ continue
260
+
261
+ self.voiced_frames.append(frame)
262
+ self.ring_buffer.append((frame, speech_prob))
263
+ num_voiced = sum([p for _, p in self.ring_buffer])
264
+
265
+ if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
266
+ segment = [
267
+ np.concatenate([f.signal for f in self.voiced_frames]),
268
+ self.voiced_frames[0].timestamp_s,
269
+ self.voiced_frames[-1].timestamp_s,
270
+ ]
271
+ yield segment
272
+ self.triggered = False
273
+ self.ring_buffer.clear()
274
+ self.voiced_frames = []
275
+ continue
276
+
277
+ def vad_segments_generator(self, segments_generator):
278
+ segments = list(segments_generator)
279
+
280
+ for i, segment in enumerate(segments):
281
+ start = round(segment[1], 4)
282
+ end = round(segment[2], 4)
283
+
284
+ if self.is_first_segment:  # the timestamps are initialized to 0.0, so a None check can never fire
+ self.is_first_segment = False
285
+ self.timestamp_start_s = start
286
+ self.timestamp_end_s = end
287
+ continue
288
+
289
+ if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
290
+ end_ = self.timestamp_start_s + self.max_speech_length_s
291
+ vad_segment = [self.timestamp_start_s, end_]
292
+ yield vad_segment
293
+ self.timestamp_start_s = end_
294
+
295
+ silence_length_ms = (start - self.timestamp_end_s) * 1000
296
+ if silence_length_ms < self.max_silence_length_ms:
297
+ self.timestamp_end_s = end
298
+ continue
299
+
300
+ if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
301
+ self.timestamp_start_s = start
302
+ self.timestamp_end_s = end
303
+ continue
304
+
305
+ vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
306
+ yield vad_segment
307
+ self.timestamp_start_s = start
308
+ self.timestamp_end_s = end
309
+
310
+ def vad(self, signal: np.ndarray) -> List[list]:
311
+ segments = self.segments_generator(signal)
312
+ vad_segments = self.vad_segments_generator(segments)
313
+ vad_segments = list(vad_segments)
314
+ return vad_segments
315
+
316
+ def last_vad_segments(self) -> List[list]:
317
+ # last segments
318
+ if len(self.voiced_frames) == 0:
319
+ segments = []
320
+ else:
321
+ segment = [
322
+ np.concatenate([f.signal for f in self.voiced_frames]),
323
+ self.voiced_frames[0].timestamp_s,
324
+ self.voiced_frames[-1].timestamp_s
325
+ ]
326
+ segments = [segment]
327
+
328
+ # last vad segments
329
+ vad_segments = self.vad_segments_generator(segments)
330
+ vad_segments = list(vad_segments)
331
+
332
+ if self.timestamp_end_s > 1e-5:
333
+ vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
334
+ return vad_segments
335
+
336
+
337
+ def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
338
+ speech_probs_ = list()
339
+ for p in speech_probs[1:]:
340
+ speech_probs_.extend([p] * frame_step)
341
+
342
+ pad = (signal.shape[0] - len(speech_probs_))
343
+ speech_probs_ = speech_probs_ + [0.0] * pad
344
+ speech_probs_ = np.array(speech_probs_, dtype=np.float32)
345
+
346
+ if len(speech_probs_) != len(signal):
347
+ raise AssertionError
348
+ return speech_probs_
349
+
350
+
351
+ def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
352
+ time = np.arange(0, len(signal)) / sample_rate
353
+ plt.figure(figsize=(12, 5))
354
+ plt.plot(time, signal / 32768, color='b')
355
+ plt.plot(time, speech_probs, color='gray')
356
+ for start, end in vad_segments:
357
+ plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="speech start")
358
+ plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="speech end")
359
+
360
+ plt.show()
361
+ return
362
+
363
+
364
+ def get_args():
365
+ parser = argparse.ArgumentParser()
366
+ parser.add_argument(
367
+ "--wav_file",
368
+ # default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
369
+ default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
370
+ type=str,
371
+ )
372
+ parser.add_argument(
373
+ "--model_path",
374
+ default=(project_path / "trained_models/silero_vad.jit").as_posix(),
375
+ type=str,
376
+ )
377
+ args = parser.parse_args()
378
+ return args
379
+
380
+
381
+ SAMPLE_RATE = 8000
382
+
383
+
384
+ def main():
385
+ args = get_args()
386
+
387
+ sample_rate, signal = wavfile.read(args.wav_file)
388
+ if SAMPLE_RATE != sample_rate:
389
+ raise AssertionError
390
+
391
+ # model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
392
+ model = WebRTCVoiceClassifier(agg=3, sample_rate=SAMPLE_RATE)
393
+ # model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
394
+
395
+ # silero vad
396
+ ring_vad = RingVad(model=model,
397
+ start_ring_rate=0.2,
398
+ end_ring_rate=0.1,
399
+ frame_size_ms=32,
400
+ frame_step_ms=32,
401
+ padding_length_ms=320,
402
+ max_silence_length_ms=320,
403
+ max_speech_length_s=100,
404
+ min_speech_length_s=0.1,
405
+ sample_rate=SAMPLE_RATE,
406
+ )
407
+ # webrtcvad
408
+ ring_vad = RingVad(model=model,
409
+ start_ring_rate=0.9,
410
+ end_ring_rate=0.1,
411
+ frame_size_ms=30,
412
+ frame_step_ms=30,
413
+ padding_length_ms=300,
414
+ max_silence_length_ms=300,
415
+ max_speech_length_s=100,
416
+ min_speech_length_s=0.1,
417
+ sample_rate=SAMPLE_RATE,
418
+ )
419
+ print(ring_vad)
420
+
421
+ vad_segments = list()
422
+
423
+ segments = ring_vad.vad(signal)
424
+ vad_segments += segments
425
+ for segment in segments:
426
+ print(segment)
427
+
428
+ # last vad segment
429
+ segments = ring_vad.last_vad_segments()
430
+ vad_segments += segments
431
+ for segment in segments:
432
+ print(segment)
433
+
434
+ print(ring_vad.speech_probs)
435
+ print(len(ring_vad.speech_probs))
436
+
437
+ # speech_probs
438
+ speech_probs = process_speech_probs(
439
+ signal=signal,
440
+ speech_probs=ring_vad.speech_probs,
441
+ frame_step=ring_vad.frame_step,
442
+ )
443
+
444
+ # plot
445
+ make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
446
+ return
447
+
448
+
449
+ if __name__ == "__main__":
450
+ main()
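
A minimal sketch of the RingVad start/end hysteresis with a stub frame classifier, so the segmentation logic can be exercised without any model weights; the energy threshold and the synthetic noise burst are arbitrary assumptions, and the repository's own dependencies (webrtcvad, torch, matplotlib) must still be importable.

import numpy as np

from toolbox.vad.vad import FrameVoiceClassifier, RingVad


class EnergyVoiceClassifier(FrameVoiceClassifier):
    """Crude energy-based stand-in for a real frame-level model."""
    def predict(self, chunk: np.ndarray) -> float:
        return 1.0 if np.abs(chunk.astype(np.float32)).mean() > 1000.0 else 0.0


sample_rate = 8000
# two seconds of silence with a half-second noise burst from 1.0 s to 1.5 s
signal = np.zeros(shape=(2 * sample_rate,), dtype=np.int16)
signal[8000:12000] = (np.random.uniform(-0.5, 0.5, size=4000) * 32768).astype(np.int16)

ring_vad = RingVad(
    model=EnergyVoiceClassifier(),
    start_ring_rate=0.5,
    end_ring_rate=0.1,
    frame_size_ms=30,
    frame_step_ms=30,
    padding_length_ms=300,
    max_silence_length_ms=300,
    sample_rate=sample_rate,
)
vad_segments = ring_vad.vad(signal) + ring_vad.last_vad_segments()
print(vad_segments)  # one segment spanning the noise burst (padding widens it a little)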
toolbox/webrtcvad/vad.py CHANGED
@@ -107,6 +107,7 @@ class WebRTCVad(object):
107
  for frame in frames:
108
  audio_bytes = bytes(frame.signal)
109
  is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)
 
110
 
111
  if not self.triggered:
112
  self.ring_buffer.append((frame, is_speech))
@@ -189,8 +190,9 @@ def get_args():
189
  parser.add_argument(
190
  "--wav_file",
191
  # default=(project_path / "data/0eeaef67-ea59-4f2d-a5b8-b70c813fd45c.wav").as_posix(),
192
- default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
193
  # default=(project_path / "data/8cbad66f-2c4e-43c2-ad11-ad95bab8bc15.wav").as_posix(),
 
194
  type=str,
195
  )
196
  parser.add_argument(
@@ -206,12 +208,12 @@ def get_args():
206
  )
207
  parser.add_argument(
208
  "--padding_duration_ms",
209
- default=30,
210
  type=int,
211
  )
212
  parser.add_argument(
213
  "--silence_duration_threshold",
214
- default=0.0,
215
  type=float,
216
  help="minimum silence duration, in seconds."
217
  )
 
107
  for frame in frames:
108
  audio_bytes = bytes(frame.signal)
109
  is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)
110
+ print(f"is_speech: {is_speech}")
111
 
112
  if not self.triggered:
113
  self.ring_buffer.append((frame, is_speech))
 
190
  parser.add_argument(
191
  "--wav_file",
192
  # default=(project_path / "data/0eeaef67-ea59-4f2d-a5b8-b70c813fd45c.wav").as_posix(),
193
+ # default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
194
  # default=(project_path / "data/8cbad66f-2c4e-43c2-ad11-ad95bab8bc15.wav").as_posix(),
195
+ default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
196
  type=str,
197
  )
198
  parser.add_argument(
 
208
  )
209
  parser.add_argument(
210
  "--padding_duration_ms",
211
+ default=300,
212
  type=int,
213
  )
214
  parser.add_argument(
215
  "--silence_duration_threshold",
216
+ default=0.3,
217
  type=float,
218
  help="minimum silence duration, in seconds."
219
  )