HoneyTian committed
Commit 5703a24 · 1 Parent(s): 360c8df
examples/cnn_vad_by_webrtcvad/run.sh ADDED
@@ -0,0 +1,174 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ bash run.sh --stage 2 --stop_stage 2 --system_version centos \
+ --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
+ --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+
+ bash run.sh --stage 3 --stop_stage 3 --system_version centos \
+ --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
+ --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ max_count=-1
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \\$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value -- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ train_dataset="${file_dir}/train.jsonl"
+ valid_dataset="${file_dir}/valid.jsonl"
+
+ train_vad_dataset="${file_dir}/train-vad.jsonl"
+ valid_vad_dataset="${file_dir}/valid-vad.jsonl"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+   #source /data/local/bin/nx_denoise/bin/activate
+   alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --noise_dir "${noise_dir}" \
+     --speech_dir "${speech_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --max_count "${max_count}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: make vad segments"
+   cd "${work_dir}" || exit 1
+   python3 step_2_make_vad_segments.py \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --train_vad_dataset "${train_vad_dataset}" \
+     --valid_vad_dataset "${valid_vad_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_4_train_model.py \
+     --train_dataset "${train_vad_dataset}" \
+     --valid_dataset "${valid_vad_dataset}" \
+     --serialization_dir "${file_dir}" \
+     --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_evaluation.py \
+     --valid_dataset "${valid_dataset}" \
+     --model_dir "${file_dir}/best" \
+     --evaluation_audio_dir "${evaluation_audio_dir}" \
+     --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p "${final_model_dir}"
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+   $verbose && echo "stage 6: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py ADDED
@@ -0,0 +1,231 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+ import time
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--noise_dir",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         type=str
+     )
+     parser.add_argument(
+         "--speech_dir",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+         type=str
+     )
+
+     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+     parser.add_argument("--duration", default=8.0, type=float)
+     parser.add_argument("--min_speech_duration", default=6.0, type=float)
+     parser.add_argument("--max_speech_duration", default=8.0, type=float)
+     parser.add_argument("--min_snr_db", default=-10, type=float)
+     parser.add_argument("--max_snr_db", default=20, type=float)
+
+     parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+     parser.add_argument("--max_count", default=-1, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def target_second_noise_signal_generator(data_dir: str,
+                                          duration: int = 4,
+                                          sample_rate: int = 8000, max_epoch: int = 20000):
+     noise_list = list()
+     wait_duration = duration
+
+     data_dir = Path(data_dir)
+     for epoch_idx in range(max_epoch):
+         for filename in data_dir.glob("**/*.wav"):
+             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+
+             if signal.ndim != 1:
+                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+             offset = 0.
+             rest_duration = raw_duration
+
+             for _ in range(1000):
+                 if rest_duration <= 0:
+                     break
+                 if rest_duration <= wait_duration:
+                     noise_list.append({
+                         "epoch_idx": epoch_idx,
+                         "filename": filename.as_posix(),
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": round(offset, 4),
+                         "duration": None,
+                         "duration_": round(rest_duration, 4),
+                     })
+                     wait_duration -= rest_duration
+                     offset = 0
+                     rest_duration = 0
+                 elif rest_duration > wait_duration:
+                     noise_list.append({
+                         "epoch_idx": epoch_idx,
+                         "filename": filename.as_posix(),
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": round(offset, 4),
+                         "duration": round(wait_duration, 4),
+                         "duration_": round(wait_duration, 4),
+                     })
+                     offset += wait_duration
+                     rest_duration -= wait_duration
+                     wait_duration = 0
+                 else:
+                     raise AssertionError
+
+                 if wait_duration <= 0:
+                     yield noise_list
+                     noise_list = list()
+                     wait_duration = duration
+
+
+ def target_second_speech_signal_generator(data_dir: str,
+                                           min_duration: int = 4,
+                                           max_duration: int = 6,
+                                           sample_rate: int = 8000, max_epoch: int = 1):
+     data_dir = Path(data_dir)
+     for epoch_idx in range(max_epoch):
+         for filename in data_dir.glob("**/*.wav"):
+             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+             if signal.ndim != 1:
+                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+             if raw_duration < min_duration:
+                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                 continue
+
+             if raw_duration < max_duration:
+                 row = {
+                     "epoch_idx": epoch_idx,
+                     "filename": filename.as_posix(),
+                     "raw_duration": round(raw_duration, 4),
+                     "offset": 0.,
+                     "duration": round(raw_duration, 4),
+                 }
+                 yield row
+
+             signal_length = len(signal)
+             win_size = int(max_duration * sample_rate)
+             for begin in range(0, signal_length - win_size, win_size):
+                 if np.sum(signal[begin: begin+win_size]) == 0:
+                     continue
+                 row = {
+                     "epoch_idx": epoch_idx,
+                     "filename": filename.as_posix(),
+                     "raw_duration": round(raw_duration, 4),
+                     "offset": round(begin / sample_rate, 4),
+                     "duration": round(max_duration, 4),
+                 }
+                 yield row
+
+
+ def main():
+     args = get_args()
+
+     noise_dir = Path(args.noise_dir)
+     speech_dir = Path(args.speech_dir)
+
+     train_dataset = Path(args.train_dataset)
+     valid_dataset = Path(args.valid_dataset)
+     train_dataset.parent.mkdir(parents=True, exist_ok=True)
+     valid_dataset.parent.mkdir(parents=True, exist_ok=True)
+
+     noise_generator = target_second_noise_signal_generator(
+         noise_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate,
+         max_epoch=100000,
+     )
+     speech_generator = target_second_speech_signal_generator(
+         speech_dir.as_posix(),
+         min_duration=args.min_speech_duration,
+         max_duration=args.max_speech_duration,
+         sample_rate=args.target_sample_rate,
+         max_epoch=1,
+     )
+
+     count = 0
+     process_bar = tqdm(desc="build dataset jsonl")
+     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+         for speech, noise_list in zip(speech_generator, noise_generator):
+             if count >= args.max_count > 0:
+                 break
+
+             # row
+             speech_filename = speech["filename"]
+             speech_raw_duration = speech["raw_duration"]
+             speech_offset = speech["offset"]
+             speech_duration = speech["duration"]
+
+             noise_list = [
+                 {
+                     "filename": noise["filename"],
+                     "raw_duration": noise["raw_duration"],
+                     "offset": noise["offset"],
+                     "duration": noise["duration"],
+                 }
+                 for noise in noise_list
+             ]
+
+             # row
+             random1 = random.random()
+             random2 = random.random()
+
+             row = {
+                 "count": count,
+
+                 "speech_filename": speech_filename,
+                 "speech_raw_duration": speech_raw_duration,
+                 "speech_offset": speech_offset,
+                 "speech_duration": speech_duration,
+
+                 "noise_list": noise_list,
+
+                 "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+                 "random1": random1,
+             }
+             row = json.dumps(row, ensure_ascii=False)
+             if random2 < (1 / 300 / 1):
+                 fvalid.write(f"{row}\n")
+             else:
+                 ftrain.write(f"{row}\n")
+
+             count += 1
+             duration_seconds = count * args.duration
+             duration_hours = duration_seconds / 3600
+
+             process_bar.update(n=1)
+             process_bar.set_postfix({
+                 "duration_hours": round(duration_hours, 4),
+             })
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
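
Note: each line that step_1_prepare_data.py writes is one JSON object pairing a speech excerpt with the noise slices that will later be mixed into it. A minimal illustration of one row, with made-up field values and hypothetical paths (not taken from a real run):

    row = {
        "count": 0,
        "speech_filename": "/data/speech/example.wav",   # hypothetical path
        "speech_raw_duration": 9.32,
        "speech_offset": 0.0,
        "speech_duration": 8.0,
        "noise_list": [
            {"filename": "/data/noise/example.wav", "raw_duration": 3.5, "offset": 0.0, "duration": 3.5},
        ],
        "snr_db": 5.7,
        "random1": 0.42,
    }

Roughly one row in 300 goes to valid.jsonl (the random2 < 1/300 branch); the rest go to train.jsonl.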
examples/cnn_vad_by_webrtcvad/step_2_make_vad_segments.py ADDED
@@ -0,0 +1,262 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ import sys
+ from typing import List, Tuple
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+ from project_settings import project_path
+ from toolbox.vad.vad import WebRTCVoiceClassifier, SileroVoiceClassifier, CCSoundsClassifier, RingVad
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+     parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
+     parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
+
+     parser.add_argument("--duration", default=8.0, type=float)
+     parser.add_argument("--expected_sample_rate", default=8000, type=int)
+
+     parser.add_argument(
+         "--silero_model_path",
+         default=(project_path / "trained_models/silero_vad.jit").as_posix(),
+         type=str,
+     )
+     parser.add_argument(
+         "--cc_sounds_model_path",
+         default=(project_path / "trained_models/sound-2-ch32.zip").as_posix(),
+         type=str,
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+     non_silent_intervals = librosa.effects.split(
+         waveform,
+         top_db=40,  # silence threshold (dB)
+         frame_length=512,  # analysis frame length
+         hop_length=128  # hop length
+     )
+
+     # return the non-silent intervals in seconds
+     result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+     return result
+
+
+ def get_intersection(non_silence: list[tuple[float, float]],
+                      speech: list[tuple[float, float]]) -> list[tuple[float, float]]:
+     """
+     Compute the intersection of speech segments and non-silence segments.
+     :param non_silence: list of non-silence segments, formatted [(start1, end1), ...]
+     :param speech: list of detected speech segments, formatted [(start2, end2), ...]
+     :return: list of intersection segments, formatted [(start, end), ...]
+     """
+     # sort by start time (can be skipped if the inputs are already sorted)
+     non_silence = sorted(non_silence, key=lambda x: x[0])
+     speech = sorted(speech, key=lambda x: x[0])
+
+     result = []
+     i = j = 0
+
+     while i < len(non_silence) and j < len(speech):
+         ns_start, ns_end = non_silence[i]
+         sp_start, sp_end = speech[j]
+
+         # compute the overlapping interval
+         overlap_start = max(ns_start, sp_start)
+         overlap_end = min(ns_end, sp_end)
+
+         if overlap_start < overlap_end:
+             result.append((overlap_start, overlap_end))
+
+         # pointer advance strategy: move past the interval that ends first
+         if ns_end < sp_end:
+             i += 1  # the non-silence segment ends first
+         else:
+             j += 1  # the speech segment ends first
+
+     return result
+
+
+ def main():
+     args = get_args()
+
+     # silero vad
+     # model = SileroVoiceClassifier(model_path=args.silero_model_path, sample_rate=args.expected_sample_rate)
+     # w_vad = RingVad(
+     #     model=model,
+     #     start_ring_rate=0.2,
+     #     end_ring_rate=0.1,
+     #     frame_size_ms=32,
+     #     frame_step_ms=32,
+     #     padding_length_ms=320,
+     #     max_silence_length_ms=320,
+     #     max_speech_length_s=100,
+     #     min_speech_length_s=0.1,
+     #     sample_rate=args.expected_sample_rate,
+     # )
+
+     # webrtcvad
+     model = WebRTCVoiceClassifier(agg=3, sample_rate=args.expected_sample_rate)
+     w_vad = RingVad(
+         model=model,
+         start_ring_rate=0.9,
+         end_ring_rate=0.1,
+         frame_size_ms=30,
+         frame_step_ms=30,
+         padding_length_ms=30,
+         max_silence_length_ms=0,
+         max_speech_length_s=100,
+         min_speech_length_s=0.1,
+         sample_rate=args.expected_sample_rate,
+     )
+
+     # cc sounds
+     # model = CCSoundsClassifier(model_path=args.cc_sounds_model_path, sample_rate=args.expected_sample_rate)
+     # w_vad = RingVad(
+     #     model=model,
+     #     start_ring_rate=0.5,
+     #     end_ring_rate=0.3,
+     #     frame_size_ms=300,
+     #     frame_step_ms=300,
+     #     padding_length_ms=300,
+     #     max_silence_length_ms=100,
+     #     max_speech_length_s=100,
+     #     min_speech_length_s=0.1,
+     #     sample_rate=args.expected_sample_rate,
+     # )
+
+     # valid
+     va_duration = 0
+     raw_duration = 0
+     use_duration = 0
+
+     count = 0
+     process_bar_valid = tqdm(desc="process valid dataset jsonl")
+     with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
+           open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
+         for row in fvalid:
+             row = json.loads(row)
+
+             speech_filename = row["speech_filename"]
+             speech_offset = row["speech_offset"]
+             speech_duration = row["speech_duration"]
+
+             waveform, _ = librosa.load(
+                 speech_filename,
+                 sr=args.expected_sample_rate,
+                 offset=speech_offset,
+                 duration=speech_duration,
+             )
+             waveform = np.array(waveform * (1 << 15), dtype=np.int16)
+
+             # non_silence_segments
+             non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
+             # vad
+             vad_segments = list()
+             segments = w_vad.vad(waveform)
+             vad_segments += segments
+             segments = w_vad.last_vad_segments()
+             vad_segments += segments
+             w_vad.reset()
+
+             vad_segments = get_intersection(non_silence_segments, vad_segments)
+             row["vad_segments"] = vad_segments
+
+             row = json.dumps(row, ensure_ascii=False)
+             fvalid_vad.write(f"{row}\n")
+
+             va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
+             raw_duration += speech_duration
+             use_duration += args.duration
+
+             count += 1
+
+             va_rate = va_duration / use_duration
+             va_raw_rate = va_duration / raw_duration
+             use_duration_hours = use_duration / 3600
+
+             process_bar_valid.update(n=1)
+             process_bar_valid.set_postfix({
+                 "va_rate": round(va_rate, 4),
+                 "va_raw_rate": round(va_raw_rate, 4),
+                 "duration_hours": round(use_duration_hours, 4),
+             })
+
+     # train
+     va_duration = 0
+     raw_duration = 0
+     use_duration = 0
+
+     count = 0
+     process_bar_train = tqdm(desc="process train dataset jsonl")
+     with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
+           open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
+         for row in ftrain:
+             row = json.loads(row)
+
+             speech_filename = row["speech_filename"]
+             speech_offset = row["speech_offset"]
+             speech_duration = row["speech_duration"]
+
+             waveform, _ = librosa.load(
+                 speech_filename,
+                 sr=args.expected_sample_rate,
+                 offset=speech_offset,
+                 duration=speech_duration,
+             )
+             waveform = np.array(waveform * (1 << 15), dtype=np.int16)
+
+             # non_silence_segments
+             non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
+             # vad
+             vad_segments = list()
+             segments = w_vad.vad(waveform)
+             vad_segments += segments
+             segments = w_vad.last_vad_segments()
+             vad_segments += segments
+             w_vad.reset()
+
+             vad_segments = get_intersection(non_silence_segments, vad_segments)
+             row["vad_segments"] = vad_segments
+
+             row = json.dumps(row, ensure_ascii=False)
+             ftrain_vad.write(f"{row}\n")
+
+             va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
+             raw_duration += speech_duration
+             use_duration += args.duration
+
+             count += 1
+
+             va_rate = va_duration / use_duration
+             va_raw_rate = va_duration / raw_duration
+             use_duration_hours = use_duration / 3600
+
+             process_bar_train.update(n=1)
+             process_bar_train.set_postfix({
+                 "va_rate": round(va_rate, 4),
+                 "va_raw_rate": round(va_raw_rate, 4),
+                 "duration_hours": round(use_duration_hours, 4),
+             })
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
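
Note: get_intersection is a standard two-pointer sweep over two sorted interval lists. A small, self-contained sketch of its behaviour on made-up intervals (values in seconds):

    non_silence = [(0.0, 2.0), (3.0, 6.0)]
    speech = [(1.5, 4.0), (5.0, 7.0)]
    print(get_intersection(non_silence, speech))
    # [(1.5, 2.0), (3.0, 4.0), (5.0, 6.0)]

The script keeps only the parts of the WebRTC VAD segments that also pass librosa's energy-based non-silence check before writing them to the *-vad.jsonl files.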
examples/cnn_vad_by_webrtcvad/step_3_check_vad.py ADDED
@@ -0,0 +1,68 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from scipy.io import wavfile
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
+     parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
+
+     parser.add_argument("--duration", default=8.0, type=float)
+     parser.add_argument("--expected_sample_rate", default=8000, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     SAMPLE_RATE = 8000
+
+     with open(args.valid_vad_dataset, "r", encoding="utf-8") as f:
+         for row in f:
+             row = json.loads(row)
+
+             speech_filename = row["speech_filename"]
+             speech_offset = row["speech_offset"]
+             speech_duration = row["speech_duration"]
+
+             vad_segments = row["vad_segments"]
+
+             print(f"speech_filename: {speech_filename}")
+             signal, sample_rate = librosa.load(
+                 speech_filename,
+                 sr=SAMPLE_RATE,
+                 offset=speech_offset,
+                 duration=speech_duration,
+             )
+
+             # plot
+             time = np.arange(0, len(signal)) / sample_rate
+             plt.figure(figsize=(12, 5))
+             plt.plot(time, signal, color='b')
+             for start, end in vad_segments:
+                 plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='speech start')  # mark the start of a speech segment
+                 plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='speech end')  # mark the end of a speech segment
+
+             plt.show()
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
examples/cnn_vad_by_webrtcvad/step_4_train_model.py ADDED
@@ -0,0 +1,453 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List, Tuple
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.vad_padding_jsonl_dataset import VadPaddingJsonlDataset
+ from toolbox.torchaudio.models.vad.cnn_vad.configuration_cnn_vad import CNNVadConfig
+ from toolbox.torchaudio.models.vad.cnn_vad.modeling_cnn_vad import CNNVadModel, CNNVadPretrainedModel
+ from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
+ from toolbox.torchaudio.losses.bce_loss import BCELoss
+ from toolbox.torchaudio.losses.dice_loss import DiceLoss
+ from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
+ from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--train_dataset", default="train-vad.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
+
+     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
+     parser.add_argument("--patience", default=30, type=int)
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+     parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.INFO)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         noisy_audios = list()
+         batch_vad_segments = list()
+
+         for sample in batch:
+             noisy_wave: torch.Tensor = sample["noisy_wave"]
+             vad_segments: List[Tuple[float, float]] = sample["vad_segments"]
+
+             noisy_audios.append(noisy_wave)
+             batch_vad_segments.append(vad_segments)
+
+         noisy_audios = torch.stack(noisy_audios)
+
+         # assert
+         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+             raise AssertionError("nan or inf in noisy_audios")
+
+         return noisy_audios, batch_vad_segments
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     config = CNNVadConfig.from_pretrained(
+         pretrained_model_name_or_path=args.config_file,
+     )
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(serialization_dir)
+
+     random.seed(config.seed)
+     np.random.seed(config.seed)
+     torch.manual_seed(config.seed)
+     logger.info(f"set seed: {config.seed}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+     # datasets
+     train_dataset = VadPaddingJsonlDataset(
+         jsonl_file=args.train_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+         min_snr_db=config.min_snr_db,
+         max_snr_db=config.max_snr_db,
+         # skip=225000,
+     )
+     valid_dataset = VadPaddingJsonlDataset(
+         jsonl_file=args.valid_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+         # min_snr_db=config.min_snr_db,
+         # max_snr_db=config.max_snr_db,
+     )
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=config.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=None if platform.system() == "Windows" else 2,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=config.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=None if platform.system() == "Windows" else 2,
+     )
+
+     # models
+     logger.info(f"prepare models. config_file: {args.config_file}")
+     model = CNNVadPretrainedModel(config).to(device)
+     model.to(device)
+     model.train()
+
+     # optimizer
+     logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+     optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+     # resume training
+     last_step_idx = -1
+     last_epoch = -1
+     for step_idx_str in serialization_dir.glob("steps-*"):
+         step_idx_str = Path(step_idx_str)
+         step_idx = step_idx_str.stem.split("-")[1]
+         step_idx = int(step_idx)
+         if step_idx > last_step_idx:
+             last_step_idx = step_idx
+             # last_epoch = 1
+
+     if last_step_idx != -1:
+         logger.info(f"resume from steps-{last_step_idx}.")
+         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+
+         logger.info(f"load state dict for model.")
+         with open(model_pt.as_posix(), "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+
+     if config.lr_scheduler == "CosineAnnealingLR":
+         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             last_epoch=last_epoch,
+             # T_max=10 * config.eval_steps,
+             # eta_min=0.01 * config.lr,
+             **config.lr_scheduler_kwargs,
+         )
+     elif config.lr_scheduler == "MultiStepLR":
+         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+             optimizer,
+             last_epoch=last_epoch,
+             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+         )
+     else:
+         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+     bce_loss_fn = BCELoss(reduction="mean").to(device)
+     dice_loss_fn = DiceLoss(reduction="mean").to(device)
+
+     vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
+     vad_f1_score_metrics_fn = VadF1Score(threshold=0.5)
+
+     # training loop
+
+     # state
+     average_loss = 1000000000
+     average_bce_loss = 1000000000
+     average_dice_loss = 1000000000
+
+     accuracy = -1
+     f1 = -1
+     precision = -1
+     recall = -1
+
+     model_list = list()
+     best_epoch_idx = None
+     best_step_idx = None
+     best_metric = None
+     patience_count = 0
+
+     step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+     logger.info("training")
+     early_stop_flag = False
+     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+         if early_stop_flag:
+             break
+
+         # train
+         model.train()
+         vad_accuracy_metrics_fn.reset()
+         vad_f1_score_metrics_fn.reset()
+
+         total_loss = 0.
+         total_bce_loss = 0.
+         total_dice_loss = 0.
+         total_batches = 0.
+
+         progress_bar_train = tqdm(
+             initial=step_idx,
+             desc="Training; epoch-{}".format(epoch_idx),
+         )
+         for train_batch in train_data_loader:
+             noisy_audios, batch_vad_segments = train_batch
+             noisy_audios: torch.Tensor = noisy_audios.to(device)
+             # noisy_audios shape: [b, num_samples]
+             num_samples = noisy_audios.shape[-1]
+
+             logits, probs = model.forward(noisy_audios)
+
+             targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
+
+             bce_loss = bce_loss_fn.forward(probs, targets)
+             dice_loss = dice_loss_fn.forward(probs, targets)
+
+             loss = 1.0 * bce_loss + 1.0 * dice_loss
+             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                 logger.info(f"find nan or inf in loss. continue.")
+                 continue
+
+             vad_accuracy_metrics_fn.__call__(probs, targets)
+             vad_f1_score_metrics_fn.__call__(probs, targets)
+
+             optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+             optimizer.step()
+             lr_scheduler.step()
+
+             total_loss += loss.item()
+             total_bce_loss += bce_loss.item()
+             total_dice_loss += dice_loss.item()
+             total_batches += 1
+
+             average_loss = round(total_loss / total_batches, 4)
+             average_bce_loss = round(total_bce_loss / total_batches, 4)
+             average_dice_loss = round(total_dice_loss / total_batches, 4)
+
+             metrics = vad_accuracy_metrics_fn.get_metric()
+             accuracy = metrics["accuracy"]
+             metrics = vad_f1_score_metrics_fn.get_metric()
+             f1 = metrics["f1"]
+             precision = metrics["precision"]
+             recall = metrics["recall"]
+
+             progress_bar_train.update(1)
+             progress_bar_train.set_postfix({
+                 "lr": lr_scheduler.get_last_lr()[0],
+                 "loss": average_loss,
+                 "bce_loss": average_bce_loss,
+                 "dice_loss": average_dice_loss,
+                 "accuracy": accuracy,
+                 "f1": f1,
+                 "precision": precision,
+                 "recall": recall,
+             })
+
+             # evaluation
+             step_idx += 1
+             if step_idx % config.eval_steps == 0:
+                 with torch.no_grad():
+                     torch.cuda.empty_cache()
+
+                     model.eval()
+                     vad_accuracy_metrics_fn.reset()
+                     vad_f1_score_metrics_fn.reset()
+
+                     total_loss = 0.
+                     total_bce_loss = 0.
+                     total_dice_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_train.close()
+                     progress_bar_eval = tqdm(
+                         desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+                     )
+                     for eval_batch in valid_data_loader:
+                         noisy_audios, batch_vad_segments = eval_batch
+                         noisy_audios: torch.Tensor = noisy_audios.to(device)
+                         # noisy_audios shape: [b, num_samples]
+                         num_samples = noisy_audios.shape[-1]
+
+                         logits, probs = model.forward(noisy_audios)
+
+                         targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
+
+                         bce_loss = bce_loss_fn.forward(probs, targets)
+                         dice_loss = dice_loss_fn.forward(probs, targets)
+
+                         loss = 1.0 * bce_loss + 1.0 * dice_loss
+                         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                             logger.info(f"find nan or inf in loss. continue.")
+                             continue
+
+                         vad_accuracy_metrics_fn.__call__(probs, targets)
+                         vad_f1_score_metrics_fn.__call__(probs, targets)
+
+                         total_loss += loss.item()
+                         total_bce_loss += bce_loss.item()
+                         total_dice_loss += dice_loss.item()
+                         total_batches += 1
+
+                         average_loss = round(total_loss / total_batches, 4)
+                         average_bce_loss = round(total_bce_loss / total_batches, 4)
+                         average_dice_loss = round(total_dice_loss / total_batches, 4)
+
+                         metrics = vad_accuracy_metrics_fn.get_metric()
+                         accuracy = metrics["accuracy"]
+                         metrics = vad_f1_score_metrics_fn.get_metric()
+                         f1 = metrics["f1"]
+                         precision = metrics["precision"]
+                         recall = metrics["recall"]
+
+                         progress_bar_eval.update(1)
+                         progress_bar_eval.set_postfix({
+                             "lr": lr_scheduler.get_last_lr()[0],
+                             "loss": average_loss,
+                             "bce_loss": average_bce_loss,
+                             "dice_loss": average_dice_loss,
+                             "accuracy": accuracy,
+                             "f1": f1,
+                             "precision": precision,
+                             "recall": recall,
+                         })
+
+                     model.train()
+                     vad_accuracy_metrics_fn.reset()
+                     vad_f1_score_metrics_fn.reset()
+
+                     total_loss = 0.
+                     total_bce_loss = 0.
+                     total_dice_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_eval.close()
+                     progress_bar_train = tqdm(
+                         initial=progress_bar_train.n,
+                         postfix=progress_bar_train.postfix,
+                         desc=progress_bar_train.desc,
+                     )
+
+                     # save path
+                     save_dir = serialization_dir / "steps-{}".format(step_idx)
+                     save_dir.mkdir(parents=True, exist_ok=False)
+
+                     # save models
+                     model.save_pretrained(save_dir.as_posix())
+
+                     model_list.append(save_dir)
+                     if len(model_list) >= args.num_serialized_models_to_keep:
+                         model_to_delete: Path = model_list.pop(0)
+                         shutil.rmtree(model_to_delete.as_posix())
+
+                     # save metric
+                     if best_metric is None:
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = f1
+                     elif f1 >= best_metric:
+                         # greater is better.
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = f1
+                     else:
+                         pass
+
+                     metrics = {
+                         "epoch_idx": epoch_idx,
+                         "best_epoch_idx": best_epoch_idx,
+                         "best_step_idx": best_step_idx,
+                         "loss": average_loss,
+                         "bce_loss": average_bce_loss,
+                         "dice_loss": average_dice_loss,
+
+                         "accuracy": accuracy,
+                     }
+                     metrics_filename = save_dir / "metrics_epoch.json"
+                     with open(metrics_filename, "w", encoding="utf-8") as f:
+                         json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                     # save best
+                     best_dir = serialization_dir / "best"
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         if best_dir.exists():
+                             shutil.rmtree(best_dir)
+                         shutil.copytree(save_dir, best_dir)
+
+                     # early stop
+                     early_stop_flag = False
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         patience_count = 0
+                     else:
+                         patience_count += 1
+                         if patience_count >= args.patience:
+                             early_stop_flag = True
+
+                     # early stop
+                     if early_stop_flag:
+                         break
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
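
Note: BaseVadLoss.get_targets (not part of this commit) is what turns each sample's vad_segments, given in seconds, into a frame-level 0/1 target tensor aligned with the model's per-frame probabilities. Its actual implementation is not shown here; a rough sketch of the idea, using a hypothetical make_frame_targets helper, could look like this:

    import torch

    def make_frame_targets(vad_segments, num_frames, duration):
        # hypothetical stand-in: a frame is speech (1.0) if its centre time
        # falls inside any (start, end) segment, otherwise non-speech (0.0).
        targets = torch.zeros(num_frames)
        frame_times = (torch.arange(num_frames) + 0.5) * (duration / num_frames)
        for start, end in vad_segments:
            targets[(frame_times >= start) & (frame_times <= end)] = 1.0
        return targets

The BCE and Dice losses are then computed between these targets and the sigmoid probabilities, and F1 on the validation split drives checkpoint selection and early stopping.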
examples/cnn_vad_by_webrtcvad/yaml/config.yaml ADDED
@@ -0,0 +1,42 @@
+ model_name: "fsmn_vad"
+
+ # spec
+ sample_rate: 8000
+ nfft: 512
+ win_size: 240
+ hop_size: 80
+ win_type: hann
+
+ # model
+ fsmn_input_size: 257
+ fsmn_input_affine_size: 140
+ fsmn_hidden_size: 250
+ fsmn_basic_block_layers: 4
+ fsmn_basic_block_hidden_size: 128
+ fsmn_basic_block_lorder: 20
+ fsmn_basic_block_rorder: 0
+ fsmn_basic_block_lstride: 1
+ fsmn_basic_block_rstride: 0
+ fsmn_output_affine_size: 140
+ fsmn_output_size: 1
+
+ use_softmax: false
+
+ # data
+ min_snr_db: -10
+ max_snr_db: 20
+
+ # train
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
+
+ max_epochs: 100
+ clip_grad_norm: 10.0
+ seed: 1234
+
+ num_workers: 4
+ batch_size: 128
+ eval_steps: 25000
examples/fsmn_vad_by_webrtcvad/run.sh CHANGED
@@ -2,7 +2,7 @@
 
  : <<'END'
 
- bash run.sh --stage 2 --stop_stage 3 --system_version centos \
+ bash run.sh --stage 2 --stop_stage 2 --system_version centos \
  --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
  --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -129,8 +129,8 @@ def main():
          jsonl_file=args.valid_dataset,
          expected_sample_rate=config.sample_rate,
          max_wave_value=32768.0,
-         min_snr_db=config.min_snr_db,
-         max_snr_db=config.max_snr_db,
+         # min_snr_db=config.min_snr_db,
+         # max_snr_db=config.max_snr_db,
      )
      train_data_loader = DataLoader(
          dataset=train_dataset,
examples/silero_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -129,8 +129,8 @@ def main():
          jsonl_file=args.valid_dataset,
          expected_sample_rate=config.sample_rate,
          max_wave_value=32768.0,
-         min_snr_db=config.min_snr_db,
-         max_snr_db=config.max_snr_db,
+         # min_snr_db=config.min_snr_db,
+         # max_snr_db=config.max_snr_db,
      )
      train_data_loader = DataLoader(
          dataset=train_dataset,
toolbox/torchaudio/models/vad/cnn_vad/configuration_cnn_vad.py CHANGED
@@ -5,6 +5,38 @@ from typing import Tuple
  from toolbox.torchaudio.configuration_utils import PretrainedConfig
 
 
+ DEFAULT_CONV2D_BLOCK_PARAM_LIST = [
+     {
+         'batch_norm': True,
+         'in_channels': 1,
+         'out_channels': 4,
+         'kernel_size': 3,
+         'padding': 'same',
+         'dilation': 3,
+         'activation': 'relu',
+         'dropout': 0.1
+     },
+     {
+         'in_channels': 4,
+         'out_channels': 4,
+         'kernel_size': 5,
+         'padding': 'same',
+         'dilation': 3,
+         'activation': 'relu',
+         'dropout': 0.1
+     },
+     {
+         'in_channels': 4,
+         'out_channels': 4,
+         'kernel_size': 3,
+         'padding': 'same',
+         'dilation': 2,
+         'activation': 'relu',
+         'dropout': 0.1
+     }
+ ]
+
+
  class CNNVadConfig(PretrainedConfig):
      def __init__(self,
                   sample_rate: int = 8000,
@@ -14,7 +46,7 @@ class CNNVadConfig(PretrainedConfig):
                   win_type: str = "hann",
 
                   conv2d_block_param_list: list = None,
-                  classifier_hidden_size: int = 128,
+                  classifier_input_size: int = 1028,
 
                   min_snr_db: float = -10,
                   max_snr_db: float = 20,
@@ -42,8 +74,8 @@ class CNNVadConfig(PretrainedConfig):
          self.win_type = win_type
 
          # encoder
-         self.conv2d_block_param_list = conv2d_block_param_list
-         self.classifier_hidden_size = classifier_hidden_size
+         self.conv2d_block_param_list = conv2d_block_param_list or DEFAULT_CONV2D_BLOCK_PARAM_LIST
+         self.classifier_input_size = classifier_input_size
 
          # data snr
          self.min_snr_db = min_snr_db
@@ -64,7 +96,7 @@ class CNNVadConfig(PretrainedConfig):
 
 
  def main():
-     config = SileroVadConfig()
+     config = CNNVadConfig()
      config.to_yaml_file("config.yaml")
      return
 
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py CHANGED
@@ -7,7 +7,7 @@ import torch
  import torch.nn as nn
 
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
- from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
+ from toolbox.torchaudio.models.vad.cnn_vad.configuration_cnn_vad import CNNVadConfig
  from toolbox.torchaudio.modules.conv_stft import ConvSTFT
 
 
@@ -24,7 +24,6 @@ class Conv2dBlock(nn.Module):
                   in_channels: int,
                   out_channels: int,
                   kernel_size: Union[int, Tuple[int, int]],
-                  stride: Tuple[int, int],
                   padding: str = 0,
                   dilation: int = 1,
                   batch_norm: bool = False,
@@ -45,8 +44,7 @@ class Conv2dBlock(nn.Module):
              in_channels,
              out_channels,
              kernel_size=kernel_size,
-             stride=stride,
-             padding=(padding,),
+             padding=padding,
              dilation=(dilation,),
          )
 
@@ -61,6 +59,7 @@ class Conv2dBlock(nn.Module):
          self.dropout = None
 
      def forward(self, x: torch.Tensor):
+         # x: [b, c, t, f]
 
          if self.batch_norm is not None:
              x = self.batch_norm(x)
@@ -83,7 +82,7 @@ class CNNVadModel(nn.Module):
                   hop_size: int,
                   win_type: str,
                   conv2d_block_param_list: List[dict],
-                  classifier_hidden_size: int,
+                  classifier_input_size: int,
                   ):
          super(CNNVadModel, self).__init__()
          self.nfft = nfft
@@ -91,7 +90,7 @@ class CNNVadModel(nn.Module):
          self.hop_size = hop_size
          self.win_type = win_type
          self.conv2d_block_param_list = conv2d_block_param_list
-         self.classifier_hidden_size = classifier_hidden_size
+         self.classifier_input_size = classifier_input_size
 
          self.eps = 1e-12
 
@@ -106,11 +105,11 @@ class CNNVadModel(nn.Module):
 
          self.cnn_encoder_list = nn.ModuleList(modules=[
              Conv2dBlock(
-                 batch_norm=param["batch_norm"],
+                 batch_norm=param.get("batch_norm"),
                  in_channels=param["in_channels"],
                  out_channels=param["out_channels"],
                  kernel_size=param["kernel_size"],
-                 stride=param["stride"],
+                 padding=param["padding"],
                  dilation=param["dilation"],
                  activation=param["activation"],
                  dropout=param["dropout"],
@@ -119,7 +118,7 @@ class CNNVadModel(nn.Module):
              ])
 
          self.classifier = nn.Sequential(
-             nn.Linear(classifier_hidden_size, 32),
+             nn.Linear(classifier_input_size, 32),
              nn.ReLU(),
              nn.Linear(32, 1),
          )
@@ -137,14 +136,18 @@ class CNNVadModel(nn.Module):
 
          x = torch.transpose(mags, dim0=1, dim1=2)
          # x shape: [b, t, f]
+         x = torch.unsqueeze(x, dim=1)
 
-         x = self.linear.forward(x)
-         # x shape: [b, t, f']
+         # x: [b, c, t, f]
+         for cnn_encoder in self.cnn_encoder_list:
+             x = cnn_encoder.forward(x)
+         # x: [b, c, t, d]
 
-         x = self.encoder.forward(x)
-         # x shape: [b, t, f]
+         x = x.permute(0, 2, 1, 3)
+         b, t, c, d = x.shape
+         x = torch.reshape(x, shape=(b, t, c*d))
+         # x: [b, t, c*d]
 
-         x, _ = self.lstm.forward(x)
          logits = self.classifier.forward(x)
          # logits shape: [b, t, 1]
          probs = self.sigmoid.forward(logits)
@@ -152,15 +155,68 @@ class CNNVadModel(nn.Module):
          return logits, probs
 
 
+ class CNNVadPretrainedModel(CNNVadModel):
+     def __init__(self,
+                  config: CNNVadConfig,
+                  ):
+         super(CNNVadPretrainedModel, self).__init__(
+             nfft=config.nfft,
+             win_size=config.win_size,
+             hop_size=config.hop_size,
+             win_type=config.win_type,
+             conv2d_block_param_list=config.conv2d_block_param_list,
+             classifier_input_size=config.classifier_input_size,
+         )
+         self.config = config
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         config = CNNVadConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         model = cls(config)
+
+         if os.path.isdir(pretrained_model_name_or_path):
+             ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+         else:
+             ckpt_file = pretrained_model_name_or_path
+
+         with open(ckpt_file, "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+         return model
+
+     def save_pretrained(self,
+                         save_directory: Union[str, os.PathLike],
+                         state_dict: Optional[dict] = None,
+                         ):
+
+         model = self
+
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # save state dict
+         model_file = os.path.join(save_directory, MODEL_FILE)
+         torch.save(state_dict, model_file)
+
+         # save config
+         config_file = os.path.join(save_directory, CONFIG_FILE)
+         self.config.to_yaml_file(config_file)
+         return save_directory
+
+
  def main():
-     config = SileroVadConfig()
-     model = SileroVadModel(config=config)
+     # config = CNNVadConfig.from_pretrained("yaml/config.yaml")
+     config = CNNVadConfig()
+     model = CNNVadPretrainedModel(config)
 
      noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
      logits, probs = model.forward(noisy)
-     print(f"logits: {probs}")
-     print(f"logits.shape: {logits.shape}")
+     print(f"probs: {probs}")
+     print(f"probs.shape: {probs.shape}")
 
      return
 
toolbox/torchaudio/models/vad/cnn_vad/yaml/config.yaml CHANGED
@@ -13,24 +13,25 @@ conv2d_block_param_list:
      in_channels: 1
      out_channels: 4
      kernel_size: 3
-     stride: 1
+     padding: "same"
      dilation: 3
      activation: relu
      dropout: 0.1
    - in_channels: 4
      out_channels: 4
      kernel_size: 5
-     stride: 2
+     padding: "same"
      dilation: 3
      activation: relu
      dropout: 0.1
    - in_channels: 4
      out_channels: 4
      kernel_size: 3
-     stride: 1
+     padding: "same"
      dilation: 2
      activation: relu
      dropout: 0.1
+ classifier_input_size: 1028
 
  # data
  min_snr_db: -10
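
Note: the new classifier_input_size of 1028 appears to come from flattening the encoder output: with nfft 512 the magnitude spectrogram has nfft // 2 + 1 = 257 frequency bins, the "same"-padded convolutions keep that width, and the last block has 4 output channels, so 4 * 257 = 1028 features per time step feed the classifier. A quick sanity check under that assumption:

    nfft = 512
    out_channels = 4           # last Conv2dBlock in the default param list
    freq_bins = nfft // 2 + 1  # 257 magnitude bins
    print(out_channels * freq_bins)  # 1028, matching classifier_input_size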
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py ADDED
@@ -0,0 +1,135 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import logging
+ from pathlib import Path
+ import shutil
+ import tempfile, time
+ from typing import List
+ import zipfile
+
+ from scipy.io import wavfile
+ import librosa
+ import numpy as np
+ import torch
+ import torchaudio
+
+ torch.set_num_threads(1)
+
+ from project_settings import project_path
+ from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
+ from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel, MODEL_FILE
+ from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
+
+
+ logger = logging.getLogger("toolbox")
+
+
+ class InferenceFSMNVad(object):
+     def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+         self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+         self.device = torch.device(device)
+
+         logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+         config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+         logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+         self.config = config
+         self.model = model
+         self.model.to(device)
+         self.model.eval()
+
+     def load_models(self, model_path: str):
+         model_path = Path(model_path)
+         if model_path.name.endswith(".zip"):
+             with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                 out_root = Path(tempfile.gettempdir()) / "cc_vad"
+                 out_root.mkdir(parents=True, exist_ok=True)
+                 f_zip.extractall(path=out_root)
+             model_path = out_root / model_path.stem
+
+         config = FSMNVadConfig.from_pretrained(
+             pretrained_model_name_or_path=model_path.as_posix(),
+         )
+         model = FSMNVadPretrainedModel.from_pretrained(
+             pretrained_model_name_or_path=model_path.as_posix(),
+         )
+         model.to(self.device)
+         model.eval()
+
+         shutil.rmtree(model_path)
+         return config, model
+
+     def infer(self, signal: torch.Tensor) -> List[float]:
+         # signal shape: [num_samples,], value between -1 and 1.
+
+         inputs = torch.tensor(signal, dtype=torch.float32)
+         inputs = torch.unsqueeze(inputs, dim=0)
+         # inputs shape: [1, num_samples,]
+
+         with torch.no_grad():
+             logits, probs = self.model.forward(inputs)
+
+         # probs shape: [b, t, 1]
+         probs = torch.squeeze(probs, dim=-1)
+         # probs shape: [b, t]
+
+         probs = probs.numpy()
+         probs = probs[0]
+         probs = probs.tolist()
+         return probs
+
+     def post_process(self, probs: List[float]):
+         return
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--wav_file",
+         # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+         # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+         default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17\active_media_r_0af6bd3a-9aef-4bef-935b-63abfb4d46d8_5.wav",
+         type=str,
+     )
+     args = parser.parse_args()
+     return args
+
+
+ SAMPLE_RATE = 8000
+
+
+ def main():
+     args = get_args()
+
+     sample_rate, signal = wavfile.read(args.wav_file)
+     if SAMPLE_RATE != sample_rate:
+         raise AssertionError
+     signal = signal / (1 << 15)
+
+     infer = InferenceFSMNVad(
+         pretrained_model_path_or_zip_file=(project_path / "trained_models/fsmn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
+     )
+     frame_step = infer.config.hop_size
+
+     speech_probs = infer.infer(signal)
+
+     # print(speech_probs)
+
+     speech_probs = process_speech_probs(
+         signal=signal,
+         speech_probs=speech_probs,
+         frame_step=frame_step,
+     )
+
+     # plot
+     make_visualization(signal, speech_probs, SAMPLE_RATE)
+     return
+
+
+ if __name__ == "__main__":
+     main()
toolbox/vad/vad.py CHANGED
@@ -366,7 +366,8 @@ def get_args():
      parser.add_argument(
          "--wav_file",
          # default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
-         default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
+         # default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
+         default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/en-PH/2025-05-15/active_media_r_0617d225-f396-4011-a86e-eaf68cdda5a8_3.wav",
          type=str,
      )
      parser.add_argument(
@@ -410,8 +411,8 @@ def main():
          end_ring_rate=0.1,
          frame_size_ms=30,
          frame_step_ms=30,
-         padding_length_ms=300,
-         max_silence_length_ms=300,
+         padding_length_ms=30,
+         max_silence_length_ms=0,
          max_speech_length_s=100,
          min_speech_length_s=0.1,
          sample_rate=SAMPLE_RATE,