HoneyTian committed on
Commit 19f8ea7 · 1 Parent(s): d87e440
Dockerfile CHANGED
@@ -10,10 +10,6 @@ RUN apt-get install -y ffmpeg build-essential
10
  RUN pip install --upgrade pip
11
  RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
 
13
- RUN pip install --upgrade pip
14
-
15
- RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
16
-
17
  USER user
18
 
19
  ENV HOME=/home/user \
@@ -23,4 +19,6 @@ WORKDIR $HOME/app
23
 
24
  COPY --chown=user . $HOME/app
25
 
 
 
26
  CMD ["python3", "main.py"]
 
10
  RUN pip install --upgrade pip
11
  RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
 
 
 
 
 
13
  USER user
14
 
15
  ENV HOME=/home/user \
 
19
 
20
  COPY --chown=user . $HOME/app
21
 
22
+ RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
23
+
24
  CMD ["python3", "main.py"]
examples/data_annotation/annotation_by_google.py ADDED
@@ -0,0 +1,158 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ from google import genai
14
+ from google.genai import types
15
+
16
+ from project_settings import environment, project_path
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument(
22
+ "--google_application_credentials",
23
+ default=(project_path / "dotenv/potent-veld-462405-t3-8091a29b2894.json").as_posix(),
24
+ type=str
25
+ )
26
+ parser.add_argument(
27
+ "--model_name",
28
+ default="gemini-2.5-pro",
29
+ type=str
30
+ )
31
+ parser.add_argument(
32
+ "--speech_audio_dir",
33
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17",
34
+ type=str
35
+ )
36
+ parser.add_argument(
37
+ "--output_file",
38
+ # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\nx-noise\en-SG\2025-06-17\vad.jsonl",
39
+ default=r"vad.jsonl",
40
+ type=str
41
+ )
42
+ parser.add_argument(
43
+ "--gemini_api_key",
44
+ default=environment.get("GEMINI_API_KEY", dtype=str),
45
+ type=str
46
+ )
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def main():
52
+ args = get_args()
53
+
54
+ speech_audio_dir = Path(args.speech_audio_dir)
55
+ output_file = Path(args.output_file)
56
+
57
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = args.google_application_credentials
58
+ os.environ["gemini_api_key"] = args.gemini_api_key
59
+
60
+
61
+ developer_client = genai.Client(
62
+ api_key=args.gemini_api_key,
63
+ )
64
+ client = genai.Client(
65
+ vertexai=True,
66
+ project="potent-veld-462405-t3",
67
+ location="global",
68
+ )
69
+ generate_content_config = types.GenerateContentConfig(
70
+ temperature=1,
71
+ top_p=0.95,
72
+ max_output_tokens=8192,
73
+ response_modalities=["TEXT"],
74
+ )
75
+
76
+ # finished
77
+ finished_set = set()
78
+ if output_file.exists():
79
+ with open(output_file.as_posix(), "r", encoding="utf-8") as f:
80
+ for row in f:
81
+ row = json.loads(row)
82
+ name = row["name"]
83
+ finished_set.add(name)
84
+ print(f"finished count: {len(finished_set)}")
85
+
86
+ with open(output_file.as_posix(), "a+", encoding="utf-8") as f:
87
+
88
+ for filename in speech_audio_dir.glob("**/*.wav"):
89
+ name = filename.name
90
+ if name in finished_set:
91
+ continue
92
+ finished_set.add(name)
93
+
94
+ # upload
95
+ audio_file = developer_client.files.upload(
96
+ file=filename.as_posix(),
97
+ config=None
98
+ )
99
+ print(f"upload file: {audio_file.name}")
100
+
101
+ prompt = f"""
102
+ Give me the start and end times of each speech segment in this audio, in seconds with millisecond precision, and output the result in JSON format,
103
+ for example:
104
+ ```json
105
+ [[0.254, 1.214], [2.200, 3.100]],
106
+ ```
107
+ If there are no speech segments, output:
108
+ ```json
109
+ []
110
+ ```
111
+ """.strip()
112
+
113
+ try:
114
+ contents = [
115
+ types.Content(
116
+ role="user",
117
+ parts=[
118
+ types.Part(text=prompt),
119
+ types.Part.from_uri(
120
+ file_uri=audio_file.uri,
121
+ mime_type=audio_file.mime_type,
122
+ )
123
+ ]
124
+ )
125
+ ]
126
+ response: types.GenerateContentResponse = developer_client.models.generate_content(
127
+ model=args.model_name,
128
+ contents=contents,
129
+ config=generate_content_config,
130
+ )
131
+ answer = response.candidates[0].content.parts[0].text
132
+ print(answer)
133
+ finally:
134
+ # delete
135
+ print(f"delete file: {audio_file.name}")
136
+ developer_client.files.delete(name=audio_file.name)
137
+
138
+ pattern = "```json(.+?)```"
139
+ match = re.search(pattern=pattern, string=answer, flags=re.DOTALL | re.IGNORECASE)
140
+ if match is None:
141
+ raise AssertionError(f"answer: {answer}")
142
+ vad_segments = match.group(1)
143
+ vad_segments = json.loads(vad_segments)
144
+ row = {
145
+ "name": name,
146
+ "filename": filename.as_posix(),
147
+ "vad_segments": vad_segments
148
+ }
149
+ row = json.dumps(row, ensure_ascii=False)
150
+
151
+ f.write(f"{row}\n")
152
+ exit(0)
153
+
154
+ return
155
+
156
+
157
+ if __name__ == "__main__":
158
+ main()
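For reference, each line the script above appends to vad.jsonl is one JSON object with the keys "name", "filename" and "vad_segments". A minimal read-back sketch, assuming the default output path "vad.jsonl" used above:

import json

with open("vad.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        row = json.loads(line)
        # keys written by the annotation script above
        name = row["name"]                  # wav file name
        vad_segments = row["vad_segments"]  # list of [start_s, end_s] pairs
        print(name, vad_segments)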
examples/fsmn_vad_by_webrtcvad/run.sh ADDED
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ bash run.sh --stage 2 --stop_stage 2 --system_version centos \
6
+ --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
7
+ --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
8
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
9
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
10
+
11
+ bash run.sh --stage 3 --stop_stage 3 --system_version centos \
12
+ --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
13
+ --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
14
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
15
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
16
+
17
+
18
+ END
19
+
20
+
21
+ # params
22
+ system_version="windows";
23
+ verbose=true;
24
+ stage=0 # start from 0 if you need to start from data preparation
25
+ stop_stage=9
26
+
27
+ work_dir="$(pwd)"
28
+ file_folder_name=file_folder_name
29
+ final_model_name=final_model_name
30
+ config_file="yaml/config.yaml"
31
+ limit=10
32
+
33
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
34
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
35
+
36
+ max_count=-1
37
+
38
+ nohup_name=nohup.out
39
+
40
+ # model params
41
+ batch_size=64
42
+ max_epochs=200
43
+ save_top_k=10
44
+ patience=5
45
+
46
+
47
+ # parse options
48
+ while true; do
49
+ [ -z "${1:-}" ] && break; # break if there are no arguments
50
+ case "$1" in
51
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
52
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
53
+ old_value="$(eval echo \$$name)";
54
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
55
+ was_bool=true;
56
+ else
57
+ was_bool=false;
58
+ fi
59
+
60
+ # Set the variable to the right value-- the escaped quotes make it work if
61
+ # the option had spaces, like --cmd "queue.pl -sync y"
62
+ eval "${name}=\"$2\"";
63
+
64
+ # Check that Boolean-valued arguments are really Boolean.
65
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
66
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
67
+ exit 1;
68
+ fi
69
+ shift 2;
70
+ ;;
71
+
72
+ *) break;
73
+ esac
74
+ done
75
+
76
+ file_dir="${work_dir}/${file_folder_name}"
77
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
78
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
79
+
80
+ train_dataset="${file_dir}/train.jsonl"
81
+ valid_dataset="${file_dir}/valid.jsonl"
82
+
83
+ train_vad_dataset="${file_dir}/train-vad.jsonl"
84
+ valid_vad_dataset="${file_dir}/valid-vad.jsonl"
85
+
86
+ $verbose && echo "system_version: ${system_version}"
87
+ $verbose && echo "file_folder_name: ${file_folder_name}"
88
+
89
+ if [ $system_version == "windows" ]; then
90
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
91
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
92
+ #source /data/local/bin/nx_denoise/bin/activate
93
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
94
+ fi
95
+
96
+
97
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
98
+ $verbose && echo "stage 1: prepare data"
99
+ cd "${work_dir}" || exit 1
100
+ python3 step_1_prepare_data.py \
101
+ --noise_dir "${noise_dir}" \
102
+ --speech_dir "${speech_dir}" \
103
+ --train_dataset "${train_dataset}" \
104
+ --valid_dataset "${valid_dataset}" \
105
+ --max_count "${max_count}" \
106
+
107
+ fi
108
+
109
+
110
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
111
+ $verbose && echo "stage 2: make vad segments"
112
+ cd "${work_dir}" || exit 1
113
+ python3 step_2_make_vad_segments.py \
114
+ --train_dataset "${train_dataset}" \
115
+ --valid_dataset "${valid_dataset}" \
116
+ --train_vad_dataset "${train_vad_dataset}" \
117
+ --valid_vad_dataset "${valid_vad_dataset}" \
118
+
119
+ fi
120
+
121
+
122
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
123
+ $verbose && echo "stage 3: train model"
124
+ cd "${work_dir}" || exit 1
125
+ python3 step_4_train_model.py \
126
+ --train_dataset "${train_vad_dataset}" \
127
+ --valid_dataset "${valid_vad_dataset}" \
128
+ --serialization_dir "${file_dir}" \
129
+ --config_file "${config_file}" \
130
+
131
+ fi
132
+
133
+
134
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
135
+ $verbose && echo "stage 4: test model"
136
+ cd "${work_dir}" || exit 1
137
+ python3 step_3_evaluation.py \
138
+ --valid_dataset "${valid_dataset}" \
139
+ --model_dir "${file_dir}/best" \
140
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
141
+ --limit "${limit}" \
142
+
143
+ fi
144
+
145
+
146
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
147
+ $verbose && echo "stage 5: collect files"
148
+ cd "${work_dir}" || exit 1
149
+
150
+ mkdir -p ${final_model_dir}
151
+
152
+ cp "${file_dir}/best"/* "${final_model_dir}"
153
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
154
+
155
+ cd "${final_model_dir}/.." || exit 1;
156
+
157
+ if [ -e "${final_model_name}.zip" ]; then
158
+ rm -rf "${final_model_name}_backup.zip"
159
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
160
+ fi
161
+
162
+ zip -r "${final_model_name}.zip" "${final_model_name}"
163
+ rm -rf "${final_model_name}"
164
+
165
+ fi
166
+
167
+
168
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
169
+ $verbose && echo "stage 6: clear file_dir"
170
+ cd "${work_dir}" || exit 1
171
+
172
+ rm -rf "${file_dir}";
173
+
174
+ fi
examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py ADDED
@@ -0,0 +1,231 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+ import time
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, "../../"))
13
+
14
+ import librosa
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument(
22
+ "--noise_dir",
23
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
24
+ type=str
25
+ )
26
+ parser.add_argument(
27
+ "--speech_dir",
28
+ default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
29
+ type=str
30
+ )
31
+
32
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
33
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
34
+
35
+ parser.add_argument("--duration", default=8.0, type=float)
36
+ parser.add_argument("--min_speech_duration", default=6.0, type=float)
37
+ parser.add_argument("--max_speech_duration", default=8.0, type=float)
38
+ parser.add_argument("--min_snr_db", default=-10, type=float)
39
+ parser.add_argument("--max_snr_db", default=20, type=float)
40
+
41
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
42
+
43
+ parser.add_argument("--max_count", default=-1, type=int)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def target_second_noise_signal_generator(data_dir: str,
50
+ duration: int = 4,
51
+ sample_rate: int = 8000, max_epoch: int = 20000):
52
+ noise_list = list()
53
+ wait_duration = duration
54
+
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+
60
+ if signal.ndim != 1:
61
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
62
+
63
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
64
+
65
+ offset = 0.
66
+ rest_duration = raw_duration
67
+
68
+ for _ in range(1000):
69
+ if rest_duration <= 0:
70
+ break
71
+ if rest_duration <= wait_duration:
72
+ noise_list.append({
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(offset, 4),
77
+ "duration": None,
78
+ "duration_": round(rest_duration, 4),
79
+ })
80
+ wait_duration -= rest_duration
81
+ offset = 0
82
+ rest_duration = 0
83
+ elif rest_duration > wait_duration:
84
+ noise_list.append({
85
+ "epoch_idx": epoch_idx,
86
+ "filename": filename.as_posix(),
87
+ "raw_duration": round(raw_duration, 4),
88
+ "offset": round(offset, 4),
89
+ "duration": round(wait_duration, 4),
90
+ "duration_": round(wait_duration, 4),
91
+ })
92
+ offset += wait_duration
93
+ rest_duration -= wait_duration
94
+ wait_duration = 0
95
+ else:
96
+ raise AssertionError
97
+
98
+ if wait_duration <= 0:
99
+ yield noise_list
100
+ noise_list = list()
101
+ wait_duration = duration
102
+
103
+
104
+ def target_second_speech_signal_generator(data_dir: str,
105
+ min_duration: int = 4,
106
+ max_duration: int = 6,
107
+ sample_rate: int = 8000, max_epoch: int = 1):
108
+ data_dir = Path(data_dir)
109
+ for epoch_idx in range(max_epoch):
110
+ for filename in data_dir.glob("**/*.wav"):
111
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
112
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
113
+
114
+ if signal.ndim != 1:
115
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
116
+
117
+ if raw_duration < min_duration:
118
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
119
+ continue
120
+
121
+ if raw_duration < max_duration:
122
+ row = {
123
+ "epoch_idx": epoch_idx,
124
+ "filename": filename.as_posix(),
125
+ "raw_duration": round(raw_duration, 4),
126
+ "offset": 0.,
127
+ "duration": round(raw_duration, 4),
128
+ }
129
+ yield row
130
+
131
+ signal_length = len(signal)
132
+ win_size = int(max_duration * sample_rate)
133
+ for begin in range(0, signal_length - win_size, win_size):
134
+ if np.sum(signal[begin: begin+win_size]) == 0:
135
+ continue
136
+ row = {
137
+ "epoch_idx": epoch_idx,
138
+ "filename": filename.as_posix(),
139
+ "raw_duration": round(raw_duration, 4),
140
+ "offset": round(begin / sample_rate, 4),
141
+ "duration": round(max_duration, 4),
142
+ }
143
+ yield row
144
+
145
+
146
+ def main():
147
+ args = get_args()
148
+
149
+ noise_dir = Path(args.noise_dir)
150
+ speech_dir = Path(args.speech_dir)
151
+
152
+ train_dataset = Path(args.train_dataset)
153
+ valid_dataset = Path(args.valid_dataset)
154
+ train_dataset.parent.mkdir(parents=True, exist_ok=True)
155
+ valid_dataset.parent.mkdir(parents=True, exist_ok=True)
156
+
157
+ noise_generator = target_second_noise_signal_generator(
158
+ noise_dir.as_posix(),
159
+ duration=args.duration,
160
+ sample_rate=args.target_sample_rate,
161
+ max_epoch=100000,
162
+ )
163
+ speech_generator = target_second_speech_signal_generator(
164
+ speech_dir.as_posix(),
165
+ min_duration=args.min_speech_duration,
166
+ max_duration=args.max_speech_duration,
167
+ sample_rate=args.target_sample_rate,
168
+ max_epoch=1,
169
+ )
170
+
171
+ count = 0
172
+ process_bar = tqdm(desc="build dataset jsonl")
173
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
174
+ for speech, noise_list in zip(speech_generator, noise_generator):
175
+ if count >= args.max_count > 0:
176
+ break
177
+
178
+ # row
179
+ speech_filename = speech["filename"]
180
+ speech_raw_duration = speech["raw_duration"]
181
+ speech_offset = speech["offset"]
182
+ speech_duration = speech["duration"]
183
+
184
+ noise_list = [
185
+ {
186
+ "filename": noise["filename"],
187
+ "raw_duration": noise["raw_duration"],
188
+ "offset": noise["offset"],
189
+ "duration": noise["duration"],
190
+ }
191
+ for noise in noise_list
192
+ ]
193
+
194
+ # row
195
+ random1 = random.random()
196
+ random2 = random.random()
197
+
198
+ row = {
199
+ "count": count,
200
+
201
+ "speech_filename": speech_filename,
202
+ "speech_raw_duration": speech_raw_duration,
203
+ "speech_offset": speech_offset,
204
+ "speech_duration": speech_duration,
205
+
206
+ "noise_list": noise_list,
207
+
208
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
209
+
210
+ "random1": random1,
211
+ }
212
+ row = json.dumps(row, ensure_ascii=False)
213
+ if random2 < (1 / 300 / 1):
214
+ fvalid.write(f"{row}\n")
215
+ else:
216
+ ftrain.write(f"{row}\n")
217
+
218
+ count += 1
219
+ duration_seconds = count * args.duration
220
+ duration_hours = duration_seconds / 3600
221
+
222
+ process_bar.update(n=1)
223
+ process_bar.set_postfix({
224
+ "duration_hours": round(duration_hours, 4),
225
+ })
226
+
227
+ return
228
+
229
+
230
+ if __name__ == "__main__":
231
+ main()
examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py ADDED
@@ -0,0 +1,205 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import librosa
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ from project_settings import project_path
16
+ from toolbox.vad.vad import WebRTCVoiceClassifier, SileroVoiceClassifier, CCSoundsClassifier, RingVad
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+
22
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
23
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
24
+
25
+ parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
26
+ parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
27
+
28
+ parser.add_argument("--duration", default=8.0, type=float)
29
+ parser.add_argument("--expected_sample_rate", default=8000, type=int)
30
+
31
+ parser.add_argument(
32
+ "--silero_model_path",
33
+ default=(project_path / "trained_models/silero_vad.jit").as_posix(),
34
+ type=str,
35
+ )
36
+ parser.add_argument(
37
+ "--cc_sounds_model_path",
38
+ default=(project_path / "trained_models/sound-2-ch32.zip").as_posix(),
39
+ type=str,
40
+ )
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def main():
46
+ args = get_args()
47
+
48
+ # webrtcvad
49
+ # model = SileroVoiceClassifier(model_path=args.silero_model_path, sample_rate=args.expected_sample_rate)
50
+ # w_vad = RingVad(
51
+ # model=model,
52
+ # start_ring_rate=0.2,
53
+ # end_ring_rate=0.1,
54
+ # frame_size_ms=32,
55
+ # frame_step_ms=32,
56
+ # padding_length_ms=320,
57
+ # max_silence_length_ms=320,
58
+ # max_speech_length_s=100,
59
+ # min_speech_length_s=0.1,
60
+ # sample_rate=args.expected_sample_rate,
61
+ # )
62
+
63
+ # webrtcvad
64
+ model = WebRTCVoiceClassifier(agg=3, sample_rate=args.expected_sample_rate)
65
+ w_vad = RingVad(
66
+ model=model,
67
+ start_ring_rate=0.9,
68
+ end_ring_rate=0.1,
69
+ frame_size_ms=30,
70
+ frame_step_ms=30,
71
+ padding_length_ms=90,
72
+ max_silence_length_ms=100,
73
+ max_speech_length_s=100,
74
+ min_speech_length_s=0.1,
75
+ sample_rate=args.expected_sample_rate,
76
+ )
77
+
78
+ # cc sounds
79
+ # model = CCSoundsClassifier(model_path=args.cc_sounds_model_path, sample_rate=args.expected_sample_rate)
80
+ # w_vad = RingVad(
81
+ # model=model,
82
+ # start_ring_rate=0.5,
83
+ # end_ring_rate=0.3,
84
+ # frame_size_ms=300,
85
+ # frame_step_ms=300,
86
+ # padding_length_ms=300,
87
+ # max_silence_length_ms=100,
88
+ # max_speech_length_s=100,
89
+ # min_speech_length_s=0.1,
90
+ # sample_rate=args.expected_sample_rate,
91
+ # )
92
+
93
+ # valid
94
+ va_duration = 0
95
+ raw_duration = 0
96
+ use_duration = 0
97
+
98
+ count = 0
99
+ process_bar_valid = tqdm(desc="process valid dataset jsonl")
100
+ with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
101
+ open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
102
+ for row in fvalid:
103
+ row = json.loads(row)
104
+
105
+ speech_filename = row["speech_filename"]
106
+ speech_offset = row["speech_offset"]
107
+ speech_duration = row["speech_duration"]
108
+
109
+ waveform, _ = librosa.load(
110
+ speech_filename,
111
+ sr=args.expected_sample_rate,
112
+ offset=speech_offset,
113
+ duration=speech_duration,
114
+ )
115
+ waveform = np.array(waveform * (1 << 15), dtype=np.int16)
116
+
117
+ # vad
118
+ vad_segments = list()
119
+ segments = w_vad.vad(waveform)
120
+ vad_segments += segments
121
+ segments = w_vad.last_vad_segments()
122
+ vad_segments += segments
123
+ w_vad.reset()
124
+
125
+ row["vad_segments"] = vad_segments
126
+
127
+ row = json.dumps(row, ensure_ascii=False)
128
+ fvalid_vad.write(f"{row}\n")
129
+
130
+ va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
131
+ raw_duration += speech_duration
132
+ use_duration += args.duration
133
+
134
+ count += 1
135
+
136
+ va_rate = va_duration / use_duration
137
+ va_raw_rate = va_duration / raw_duration
138
+ use_duration_hours = use_duration / 3600
139
+
140
+ process_bar_valid.update(n=1)
141
+ process_bar_valid.set_postfix({
142
+ "va_rate": round(va_rate, 4),
143
+ "va_raw_rate": round(va_raw_rate, 4),
144
+ "duration_hours": round(use_duration_hours, 4),
145
+ })
146
+
147
+ # train
148
+ va_duration = 0
149
+ raw_duration = 0
150
+ use_duration = 0
151
+
152
+ count = 0
153
+ process_bar_train = tqdm(desc="process train dataset jsonl")
154
+ with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
155
+ open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
156
+ for row in ftrain:
157
+ row = json.loads(row)
158
+
159
+ speech_filename = row["speech_filename"]
160
+ speech_offset = row["speech_offset"]
161
+ speech_duration = row["speech_duration"]
162
+
163
+ waveform, _ = librosa.load(
164
+ speech_filename,
165
+ sr=args.expected_sample_rate,
166
+ offset=speech_offset,
167
+ duration=speech_duration,
168
+ )
169
+ waveform = np.array(waveform * (1 << 15), dtype=np.int16)
170
+
171
+ # vad
172
+ vad_segments = list()
173
+ segments = w_vad.vad(waveform)
174
+ vad_segments += segments
175
+ segments = w_vad.last_vad_segments()
176
+ vad_segments += segments
177
+ w_vad.reset()
178
+
179
+ row["vad_segments"] = vad_segments
180
+
181
+ row = json.dumps(row, ensure_ascii=False)
182
+ ftrain_vad.write(f"{row}\n")
183
+
184
+ va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
185
+ raw_duration += speech_duration
186
+ use_duration += args.duration
187
+
188
+ count += 1
189
+
190
+ va_rate = va_duration / use_duration
191
+ va_raw_rate = va_duration / raw_duration
192
+ use_duration_hours = use_duration / 3600
193
+
194
+ process_bar_train.update(n=1)
195
+ process_bar_train.set_postfix({
196
+ "va_rate": round(va_rate, 4),
197
+ "va_raw_rate": round(va_raw_rate, 4),
198
+ "duration_hours": round(use_duration_hours, 4),
199
+ })
200
+
201
+ return
202
+
203
+
204
+ if __name__ == "__main__":
205
+ main()
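A note on the int16 conversion used in both loops above: librosa.load returns float32 samples in [-1, 1], while the webrtcvad-based frame classifier expects 16-bit PCM, hence the `waveform * (1 << 15)` scaling. A minimal sketch of that step in isolation (toy values, not taken from the dataset):

import numpy as np

float_wave = np.array([0.0, 0.5, -0.25], dtype=np.float32)    # librosa-style float samples
int16_wave = np.array(float_wave * (1 << 15), dtype=np.int16) # 16-bit PCM as expected by webrtcvad
print(int16_wave)  # [     0  16384  -8192]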
examples/fsmn_vad_by_webrtcvad/step_3_check_vad.py ADDED
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import librosa
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+
21
+ parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
22
+ parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
23
+
24
+ parser.add_argument("--duration", default=8.0, type=float)
25
+ parser.add_argument("--expected_sample_rate", default=8000, type=int)
26
+
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+
31
+ def main():
32
+ args = get_args()
33
+
34
+ SAMPLE_RATE = 8000
35
+
36
+ with open(args.train_vad_dataset, "r", encoding="utf-8") as f:
37
+ for row in f:
38
+ row = json.loads(row)
39
+
40
+ speech_filename = row["speech_filename"]
41
+ speech_offset = row["speech_offset"]
42
+ speech_duration = row["speech_duration"]
43
+
44
+ vad_segments = row["vad_segments"]
45
+
46
+ print(f"speech_filename: {speech_filename}")
47
+ signal, sample_rate = librosa.load(
48
+ speech_filename,
49
+ sr=SAMPLE_RATE,
50
+ offset=speech_offset,
51
+ duration=speech_duration,
52
+ )
53
+
54
+ # plot
55
+ time = np.arange(0, len(signal)) / sample_rate
56
+ plt.figure(figsize=(12, 5))
57
+ plt.plot(time, signal, color='b')
58
+ for start, end in vad_segments:
59
+ plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='segment start')  # mark segment start
60
+ plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='segment end')  # mark segment end
61
+
62
+ plt.show()
63
+
64
+ return
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py ADDED
@@ -0,0 +1,453 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import logging
6
+ from logging.handlers import TimedRotatingFileHandler
7
+ import os
8
+ import platform
9
+ from pathlib import Path
10
+ import random
11
+ import sys
12
+ import shutil
13
+ from typing import List, Tuple
14
+
15
+ pwd = os.path.abspath(os.path.dirname(__file__))
16
+ sys.path.append(os.path.join(pwd, "../../"))
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn import functional as F
22
+ from torch.utils.data.dataloader import DataLoader
23
+ from tqdm import tqdm
24
+
25
+ from toolbox.torch.utils.data.dataset.vad_padding_jsonl_dataset import VadPaddingJsonlDataset
26
+ from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
27
+ from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadModel, FSMNVadPretrainedModel
28
+ from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
29
+ from toolbox.torchaudio.losses.bce_loss import BCELoss
30
+ from toolbox.torchaudio.losses.dice_loss import DiceLoss
31
+ from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
32
+ from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--train_dataset", default="train-vad.jsonl", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
39
+
40
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
41
+ parser.add_argument("--patience", default=30, type=int)
42
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
43
+
44
+ parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
45
+
46
+ args = parser.parse_args()
47
+ return args
48
+
49
+
50
+ def logging_config(file_dir: str):
51
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
52
+
53
+ logging.basicConfig(format=fmt,
54
+ datefmt="%m/%d/%Y %H:%M:%S",
55
+ level=logging.INFO)
56
+ file_handler = TimedRotatingFileHandler(
57
+ filename=os.path.join(file_dir, "main.log"),
58
+ encoding="utf-8",
59
+ when="D",
60
+ interval=1,
61
+ backupCount=7
62
+ )
63
+ file_handler.setLevel(logging.INFO)
64
+ file_handler.setFormatter(logging.Formatter(fmt))
65
+ logger = logging.getLogger(__name__)
66
+ logger.addHandler(file_handler)
67
+
68
+ return logger
69
+
70
+
71
+ class CollateFunction(object):
72
+ def __init__(self):
73
+ pass
74
+
75
+ def __call__(self, batch: List[dict]):
76
+ noisy_audios = list()
77
+ batch_vad_segments = list()
78
+
79
+ for sample in batch:
80
+ noisy_wave: torch.Tensor = sample["noisy_wave"]
81
+ vad_segments: List[Tuple[float, float]] = sample["vad_segments"]
82
+
83
+ noisy_audios.append(noisy_wave)
84
+ batch_vad_segments.append(vad_segments)
85
+
86
+ noisy_audios = torch.stack(noisy_audios)
87
+
88
+ # assert
89
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
90
+ raise AssertionError("nan or inf in noisy_audios")
91
+
92
+ return noisy_audios, batch_vad_segments
93
+
94
+
95
+ collate_fn = CollateFunction()
96
+
97
+
98
+ def main():
99
+ args = get_args()
100
+
101
+ config = FSMNVadConfig.from_pretrained(
102
+ pretrained_model_name_or_path=args.config_file,
103
+ )
104
+
105
+ serialization_dir = Path(args.serialization_dir)
106
+ serialization_dir.mkdir(parents=True, exist_ok=True)
107
+
108
+ logger = logging_config(serialization_dir)
109
+
110
+ random.seed(config.seed)
111
+ np.random.seed(config.seed)
112
+ torch.manual_seed(config.seed)
113
+ logger.info(f"set seed: {config.seed}")
114
+
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+ n_gpu = torch.cuda.device_count()
117
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
118
+
119
+ # datasets
120
+ train_dataset = VadPaddingJsonlDataset(
121
+ jsonl_file=args.train_dataset,
122
+ expected_sample_rate=config.sample_rate,
123
+ max_wave_value=32768.0,
124
+ min_snr_db=config.min_snr_db,
125
+ max_snr_db=config.max_snr_db,
126
+ # skip=225000,
127
+ )
128
+ valid_dataset = VadPaddingJsonlDataset(
129
+ jsonl_file=args.valid_dataset,
130
+ expected_sample_rate=config.sample_rate,
131
+ max_wave_value=32768.0,
132
+ min_snr_db=config.min_snr_db,
133
+ max_snr_db=config.max_snr_db,
134
+ )
135
+ train_data_loader = DataLoader(
136
+ dataset=train_dataset,
137
+ batch_size=config.batch_size,
138
+ # shuffle=True,
139
+ sampler=None,
140
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
141
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
142
+ collate_fn=collate_fn,
143
+ pin_memory=False,
144
+ prefetch_factor=None if platform.system() == "Windows" else 2,
145
+ )
146
+ valid_data_loader = DataLoader(
147
+ dataset=valid_dataset,
148
+ batch_size=config.batch_size,
149
+ # shuffle=True,
150
+ sampler=None,
151
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
152
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
153
+ collate_fn=collate_fn,
154
+ pin_memory=False,
155
+ prefetch_factor=None if platform.system() == "Windows" else 2,
156
+ )
157
+
158
+ # models
159
+ logger.info(f"prepare models. config_file: {args.config_file}")
160
+ model = FSMNVadPretrainedModel(config).to(device)
161
+ model.to(device)
162
+ model.train()
163
+
164
+ # optimizer
165
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
166
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
167
+
168
+ # resume training
169
+ last_step_idx = -1
170
+ last_epoch = -1
171
+ for step_idx_str in serialization_dir.glob("steps-*"):
172
+ step_idx_str = Path(step_idx_str)
173
+ step_idx = step_idx_str.stem.split("-")[1]
174
+ step_idx = int(step_idx)
175
+ if step_idx > last_step_idx:
176
+ last_step_idx = step_idx
177
+ # last_epoch = 1
178
+
179
+ if last_step_idx != -1:
180
+ logger.info(f"resume from steps-{last_step_idx}.")
181
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
182
+
183
+ logger.info(f"load state dict for model.")
184
+ with open(model_pt.as_posix(), "rb") as f:
185
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
186
+ model.load_state_dict(state_dict, strict=True)
187
+
188
+ if config.lr_scheduler == "CosineAnnealingLR":
189
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
190
+ optimizer,
191
+ last_epoch=last_epoch,
192
+ # T_max=10 * config.eval_steps,
193
+ # eta_min=0.01 * config.lr,
194
+ **config.lr_scheduler_kwargs,
195
+ )
196
+ elif config.lr_scheduler == "MultiStepLR":
197
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
198
+ optimizer,
199
+ last_epoch=last_epoch,
200
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
201
+ )
202
+ else:
203
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
204
+
205
+ bce_loss_fn = BCELoss(reduction="mean").to(device)
206
+ dice_loss_fn = DiceLoss(reduction="mean").to(device)
207
+
208
+ vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
209
+ vad_f1_score_metrics_fn = VadF1Score(threshold=0.5)
210
+
211
+ # training loop
212
+
213
+ # state
214
+ average_loss = 1000000000
215
+ average_bce_loss = 1000000000
216
+ average_dice_loss = 1000000000
217
+
218
+ accuracy = -1
219
+ f1 = -1
220
+ precision = -1
221
+ recall = -1
222
+
223
+ model_list = list()
224
+ best_epoch_idx = None
225
+ best_step_idx = None
226
+ best_metric = None
227
+ patience_count = 0
228
+
229
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
230
+
231
+ logger.info("training")
232
+ early_stop_flag = False
233
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
234
+ if early_stop_flag:
235
+ break
236
+
237
+ # train
238
+ model.train()
239
+ vad_accuracy_metrics_fn.reset()
240
+ vad_f1_score_metrics_fn.reset()
241
+
242
+ total_loss = 0.
243
+ total_bce_loss = 0.
244
+ total_dice_loss = 0.
245
+ total_batches = 0.
246
+
247
+ progress_bar_train = tqdm(
248
+ initial=step_idx,
249
+ desc="Training; epoch-{}".format(epoch_idx),
250
+ )
251
+ for train_batch in train_data_loader:
252
+ noisy_audios, batch_vad_segments = train_batch
253
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
254
+ # noisy_audios shape: [b, num_samples]
255
+ num_samples = noisy_audios.shape[-1]
256
+
257
+ logits, probs = model.forward(noisy_audios)
258
+
259
+ targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
260
+
261
+ bce_loss = bce_loss_fn.forward(probs, targets)
262
+ dice_loss = dice_loss_fn.forward(probs, targets)
263
+
264
+ loss = 1.0 * bce_loss + 1.0 * dice_loss
265
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
266
+ logger.info(f"find nan or inf in loss. continue.")
267
+ continue
268
+
269
+ vad_accuracy_metrics_fn.__call__(probs, targets)
270
+ vad_f1_score_metrics_fn.__call__(probs, targets)
271
+
272
+ optimizer.zero_grad()
273
+ loss.backward()
274
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
275
+ optimizer.step()
276
+ lr_scheduler.step()
277
+
278
+ total_loss += loss.item()
279
+ total_bce_loss += bce_loss.item()
280
+ total_dice_loss += dice_loss.item()
281
+ total_batches += 1
282
+
283
+ average_loss = round(total_loss / total_batches, 4)
284
+ average_bce_loss = round(total_bce_loss / total_batches, 4)
285
+ average_dice_loss = round(total_dice_loss / total_batches, 4)
286
+
287
+ metrics = vad_accuracy_metrics_fn.get_metric()
288
+ accuracy = metrics["accuracy"]
289
+ metrics = vad_f1_score_metrics_fn.get_metric()
290
+ f1 = metrics["f1"]
291
+ precision = metrics["precision"]
292
+ recall = metrics["recall"]
293
+
294
+ progress_bar_train.update(1)
295
+ progress_bar_train.set_postfix({
296
+ "lr": lr_scheduler.get_last_lr()[0],
297
+ "loss": average_loss,
298
+ "bce_loss": average_bce_loss,
299
+ "dice_loss": average_dice_loss,
300
+ "accuracy": accuracy,
301
+ "f1": f1,
302
+ "precision": precision,
303
+ "recall": recall,
304
+ })
305
+
306
+ # evaluation
307
+ step_idx += 1
308
+ if step_idx % config.eval_steps == 0:
309
+ with torch.no_grad():
310
+ torch.cuda.empty_cache()
311
+
312
+ model.eval()
313
+ vad_accuracy_metrics_fn.reset()
314
+ vad_f1_score_metrics_fn.reset()
315
+
316
+ total_loss = 0.
317
+ total_bce_loss = 0.
318
+ total_dice_loss = 0.
319
+ total_batches = 0.
320
+
321
+ progress_bar_train.close()
322
+ progress_bar_eval = tqdm(
323
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
324
+ )
325
+ for eval_batch in valid_data_loader:
326
+ noisy_audios, batch_vad_segments = eval_batch
327
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
328
+ # noisy_audios shape: [b, num_samples]
329
+ num_samples = noisy_audios.shape[-1]
330
+
331
+ logits, probs = model.forward(noisy_audios)
332
+
333
+ targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)
334
+
335
+ bce_loss = bce_loss_fn.forward(probs, targets)
336
+ dice_loss = dice_loss_fn.forward(probs, targets)
337
+
338
+ loss = 1.0 * bce_loss + 1.0 * dice_loss
339
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
340
+ logger.info(f"find nan or inf in loss. continue.")
341
+ continue
342
+
343
+ vad_accuracy_metrics_fn.__call__(probs, targets)
344
+ vad_f1_score_metrics_fn.__call__(probs, targets)
345
+
346
+ total_loss += loss.item()
347
+ total_bce_loss += bce_loss.item()
348
+ total_dice_loss += dice_loss.item()
349
+ total_batches += 1
350
+
351
+ average_loss = round(total_loss / total_batches, 4)
352
+ average_bce_loss = round(total_bce_loss / total_batches, 4)
353
+ average_dice_loss = round(total_dice_loss / total_batches, 4)
354
+
355
+ metrics = vad_accuracy_metrics_fn.get_metric()
356
+ accuracy = metrics["accuracy"]
357
+ metrics = vad_f1_score_metrics_fn.get_metric()
358
+ f1 = metrics["f1"]
359
+ precision = metrics["precision"]
360
+ recall = metrics["recall"]
361
+
362
+ progress_bar_eval.update(1)
363
+ progress_bar_eval.set_postfix({
364
+ "lr": lr_scheduler.get_last_lr()[0],
365
+ "loss": average_loss,
366
+ "bce_loss": average_bce_loss,
367
+ "dice_loss": average_dice_loss,
368
+ "accuracy": accuracy,
369
+ "f1": f1,
370
+ "precision": precision,
371
+ "recall": recall,
372
+ })
373
+
374
+ model.train()
375
+ vad_accuracy_metrics_fn.reset()
376
+ vad_f1_score_metrics_fn.reset()
377
+
378
+ total_loss = 0.
379
+ total_bce_loss = 0.
380
+ total_dice_loss = 0.
381
+ total_batches = 0.
382
+
383
+ progress_bar_eval.close()
384
+ progress_bar_train = tqdm(
385
+ initial=progress_bar_train.n,
386
+ postfix=progress_bar_train.postfix,
387
+ desc=progress_bar_train.desc,
388
+ )
389
+
390
+ # save path
391
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
392
+ save_dir.mkdir(parents=True, exist_ok=False)
393
+
394
+ # save models
395
+ model.save_pretrained(save_dir.as_posix())
396
+
397
+ model_list.append(save_dir)
398
+ if len(model_list) >= args.num_serialized_models_to_keep:
399
+ model_to_delete: Path = model_list.pop(0)
400
+ shutil.rmtree(model_to_delete.as_posix())
401
+
402
+ # save metric
403
+ if best_metric is None:
404
+ best_epoch_idx = epoch_idx
405
+ best_step_idx = step_idx
406
+ best_metric = f1
407
+ elif f1 >= best_metric:
408
+ # greater is better.
409
+ best_epoch_idx = epoch_idx
410
+ best_step_idx = step_idx
411
+ best_metric = f1
412
+ else:
413
+ pass
414
+
415
+ metrics = {
416
+ "epoch_idx": epoch_idx,
417
+ "best_epoch_idx": best_epoch_idx,
418
+ "best_step_idx": best_step_idx,
419
+ "loss": average_loss,
420
+ "bce_loss": average_bce_loss,
421
+ "dice_loss": average_dice_loss,
422
+
423
+ "accuracy": accuracy,
424
+ }
425
+ metrics_filename = save_dir / "metrics_epoch.json"
426
+ with open(metrics_filename, "w", encoding="utf-8") as f:
427
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
428
+
429
+ # save best
430
+ best_dir = serialization_dir / "best"
431
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
432
+ if best_dir.exists():
433
+ shutil.rmtree(best_dir)
434
+ shutil.copytree(save_dir, best_dir)
435
+
436
+ # early stop
437
+ early_stop_flag = False
438
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
439
+ patience_count = 0
440
+ else:
441
+ patience_count += 1
442
+ if patience_count >= args.patience:
443
+ early_stop_flag = True
444
+
445
+ # early stop
446
+ if early_stop_flag:
447
+ break
448
+
449
+ return
450
+
451
+
452
+ if __name__ == "__main__":
453
+ main()
examples/fsmn_vad_by_webrtcvad/yaml/config.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model_name: "fsmn_vad"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ nfft: 512
6
+ win_size: 240
7
+ hop_size: 80
8
+ win_type: hann
9
+
10
+ # model
11
+ fsmn_input_size: 257
12
+ fsmn_input_affine_size: 140
13
+ fsmn_hidden_size: 250
14
+ fsmn_basic_block_layers: 4
15
+ fsmn_basic_block_hidden_size: 128
16
+ fsmn_basic_block_lorder: 20
17
+ fsmn_basic_block_rorder: 0
18
+ fsmn_basic_block_lstride: 1
19
+ fsmn_basic_block_rstride: 0
20
+ fsmn_output_affine_size: 140
21
+ fsmn_output_size: 1
22
+
23
+ use_softmax: false
24
+
25
+ # data
26
+ min_snr_db: -10
27
+ max_snr_db: 20
28
+
29
+ # train
30
+ lr: 0.001
31
+ lr_scheduler: "CosineAnnealingLR"
32
+ lr_scheduler_kwargs:
33
+ T_max: 250000
34
+ eta_min: 0.0001
35
+
36
+ max_epochs: 100
37
+ clip_grad_norm: 10.0
38
+ seed: 1234
39
+
40
+ num_workers: 4
41
+ batch_size: 128
42
+ eval_steps: 25000
examples/silero_vad_by_webrtcvad/run.sh CHANGED
@@ -2,7 +2,7 @@
2
 
3
  : <<'END'
4
 
5
- bash run.sh --stage 1 --stop_stage 1 --system_version centos \
6
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
7
  --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
 
2
 
3
  : <<'END'
4
 
5
+ bash run.sh --stage 2 --stop_stage 2 --system_version centos \
6
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
7
  --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
requirements.txt CHANGED
@@ -11,3 +11,4 @@ torchaudio==2.5.1
11
  overrides==7.7.0
12
  webrtcvad==2.0.10
13
  matplotlib==3.10.3
 
 
11
  overrides==7.7.0
12
  webrtcvad==2.0.10
13
  matplotlib==3.10.3
14
+ google-genai
toolbox/torchaudio/models/vad/cnn_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py CHANGED
@@ -13,8 +13,19 @@ class FSMNVadConfig(PretrainedConfig):
13
  hop_size: int = 80,
14
  win_type: str = "hann",
15
 
16
- in_channels: int = 64,
17
- hidden_size: int = 128,
 
 
18
 
19
  lr: float = 0.001,
20
  lr_scheduler: str = "CosineAnnealingLR",
@@ -39,8 +50,19 @@ class FSMNVadConfig(PretrainedConfig):
39
  self.win_type = win_type
40
 
41
  # encoder
42
- self.in_channels = in_channels
43
- self.hidden_size = hidden_size
 
 
44
 
45
  # train
46
  self.lr = lr
 
13
  hop_size: int = 80,
14
  win_type: str = "hann",
15
 
16
+ fsmn_input_size: int = 257,
17
+ fsmn_input_affine_size: int = 140,
18
+ fsmn_hidden_size: int = 250,
19
+ fsmn_basic_block_layers: int = 4,
20
+ fsmn_basic_block_hidden_size: int = 128,
21
+ fsmn_basic_block_lorder: int = 20,
22
+ fsmn_basic_block_rorder: int = 0,
23
+ fsmn_basic_block_lstride: int = 1,
24
+ fsmn_basic_block_rstride: int = 0,
25
+ fsmn_output_affine_size: int = 140,
26
+ fsmn_output_size: int = 1,
27
+
28
+ use_softmax: bool = False,
29
 
30
  lr: float = 0.001,
31
  lr_scheduler: str = "CosineAnnealingLR",
 
50
  self.win_type = win_type
51
 
52
  # encoder
53
+ self.fsmn_input_size = fsmn_input_size
54
+ self.fsmn_input_affine_size = fsmn_input_affine_size
55
+ self.fsmn_hidden_size = fsmn_hidden_size
56
+ self.fsmn_basic_block_layers = fsmn_basic_block_layers
57
+ self.fsmn_basic_block_hidden_size = fsmn_basic_block_hidden_size
58
+ self.fsmn_basic_block_lorder = fsmn_basic_block_lorder
59
+ self.fsmn_basic_block_rorder = fsmn_basic_block_rorder
60
+ self.fsmn_basic_block_lstride = fsmn_basic_block_lstride
61
+ self.fsmn_basic_block_rstride = fsmn_basic_block_rstride
62
+ self.fsmn_output_affine_size = fsmn_output_affine_size
63
+ self.fsmn_output_size = fsmn_output_size
64
+
65
+ self.use_softmax = use_softmax
66
 
67
  # train
68
  self.lr = lr
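The new fsmn_* fields mirror the keys in the example YAML added under examples/fsmn_vad_by_webrtcvad/yaml/config.yaml. A small sketch of loading them the same way step_4_train_model.py does (the relative path here is assumed for illustration):

from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig

# hypothetical path; step_4_train_model.py passes its --config_file argument here
config = FSMNVadConfig.from_pretrained("examples/fsmn_vad_by_webrtcvad/yaml/config.yaml")
print(config.fsmn_input_size, config.fsmn_output_size, config.use_softmax)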
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py CHANGED
@@ -226,10 +226,6 @@ class FSMN(nn.Module):
226
  self.out_linear1 = AffineTransform(hidden_size, output_affine_size)
227
  self.out_linear2 = AffineTransform(output_affine_size, output_size)
228
 
229
- self.use_softmax = use_softmax
230
- if self.use_softmax:
231
- self.softmax = nn.Softmax(dim=-1)
232
-
233
  def forward(self,
234
  inputs: torch.Tensor,
235
  cache_list: List[torch.Tensor] = None,
@@ -253,8 +249,6 @@ class FSMN(nn.Module):
253
  outputs = self.out_linear2.forward(x)
254
  # outputs shape: [b, t, f]
255
 
256
- if self.use_softmax:
257
- outputs = self.softmax(outputs)
258
  return outputs, new_cache_list
259
 
260
 
@@ -271,7 +265,6 @@ def main():
271
  basic_block_rstride=1,
272
  output_affine_size=16,
273
  output_size=32,
274
- use_softmax=True,
275
  )
276
 
277
  inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
 
226
  self.out_linear1 = AffineTransform(hidden_size, output_affine_size)
227
  self.out_linear2 = AffineTransform(output_affine_size, output_size)
228
 
 
 
 
 
229
  def forward(self,
230
  inputs: torch.Tensor,
231
  cache_list: List[torch.Tensor] = None,
 
249
  outputs = self.out_linear2.forward(x)
250
  # outputs shape: [b, t, f]
251
 
 
 
252
  return outputs, new_cache_list
253
 
254
 
 
265
  basic_block_rstride=1,
266
  output_affine_size=16,
267
  output_size=32,
 
268
  )
269
 
270
  inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py CHANGED
@@ -41,20 +41,104 @@ class FSMNVadModel(nn.Module):
41
  )
42
 
43
  self.fsmn_encoder = FSMN(
44
- input_size=400,
45
- input_affine_size=140,
46
- hidden_size=250,
47
- basic_block_layers=4,
48
- basic_block_hidden_size=128,
49
- basic_block_lorder=20,
50
- basic_block_rorder=0,
51
- basic_block_lstride=1,
52
- basic_block_rstride=0,
53
- output_affine_size=140,
54
- output_size=248,
55
- use_softmax=True,
56
  )
57
 
 
 
58
 
59
  if __name__ == "__main__":
60
- pass
 
41
  )
42
 
43
  self.fsmn_encoder = FSMN(
44
+ input_size=config.fsmn_input_size,
45
+ input_affine_size=config.fsmn_input_affine_size,
46
+ hidden_size=config.fsmn_hidden_size,
47
+ basic_block_layers=config.fsmn_basic_block_layers,
48
+ basic_block_hidden_size=config.fsmn_basic_block_hidden_size,
49
+ basic_block_lorder=config.fsmn_basic_block_lorder,
50
+ basic_block_rorder=config.fsmn_basic_block_rorder,
51
+ basic_block_lstride=config.fsmn_basic_block_lstride,
52
+ basic_block_rstride=config.fsmn_basic_block_rstride,
53
+ output_affine_size=config.fsmn_output_affine_size,
54
+ output_size=config.fsmn_output_size,
 
55
  )
56
 
57
+ self.use_softmax = config.use_softmax
58
+ self.sigmoid = nn.Sigmoid()
59
+ self.softmax = nn.Softmax(dim=-1)
60
+
61
+ def forward(self, signal: torch.Tensor):
62
+ if signal.dim() == 2:
63
+ signal = torch.unsqueeze(signal, dim=1)
64
+ _, _, num_samples = signal.shape
65
+ # signal shape [b, 1, num_samples]
66
+
67
+ mags = self.stft.forward(signal)
68
+ # mags shape: [b, f, t]
69
+
70
+ x = torch.transpose(mags, dim0=1, dim1=2)
71
+ # x shape: [b, t, f]
72
+
73
+ logits, _ = self.fsmn_encoder.forward(x)
74
+
75
+ if self.use_softmax:
76
+ probs = self.softmax.forward(logits)
77
+ # probs shape: [b, t, n]
78
+ else:
79
+ probs = self.sigmoid.forward(logits)
80
+ # probs shape: [b, t, 1]
81
+ return logits, probs
82
+
83
+
84
+ class FSMNVadPretrainedModel(FSMNVadModel):
85
+ def __init__(self,
86
+ config: FSMNVadConfig,
87
+ ):
88
+ super(FSMNVadPretrainedModel, self).__init__(
89
+ config=config,
90
+ )
91
+
92
+ @classmethod
93
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
94
+ config = FSMNVadConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
95
+
96
+ model = cls(config)
97
+
98
+ if os.path.isdir(pretrained_model_name_or_path):
99
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
100
+ else:
101
+ ckpt_file = pretrained_model_name_or_path
102
+
103
+ with open(ckpt_file, "rb") as f:
104
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
105
+ model.load_state_dict(state_dict, strict=True)
106
+ return model
107
+
108
+ def save_pretrained(self,
109
+ save_directory: Union[str, os.PathLike],
110
+ state_dict: Optional[dict] = None,
111
+ ):
112
+
113
+ model = self
114
+
115
+ if state_dict is None:
116
+ state_dict = model.state_dict()
117
+
118
+ os.makedirs(save_directory, exist_ok=True)
119
+
120
+ # save state dict
121
+ model_file = os.path.join(save_directory, MODEL_FILE)
122
+ torch.save(state_dict, model_file)
123
+
124
+ # save config
125
+ config_file = os.path.join(save_directory, CONFIG_FILE)
126
+ self.config.to_yaml_file(config_file)
127
+ return save_directory
128
+
129
+
130
+ def main():
131
+ config = FSMNVadConfig()
132
+ model = FSMNVadPretrainedModel(config=config)
133
+
134
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
135
+
136
+ logits, probs = model.forward(noisy)
137
+ print(f"probs: {probs}")
138
+ print(f"probs.shape: {logits.shape}")
139
+ print(f"use_softmax: {config.use_softmax}")
140
+ return
141
+
142
 
143
  if __name__ == "__main__":
144
+ main()
toolbox/torchaudio/models/vad/fsmn_vad/yaml/config-sigmoid.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model_name: "fsmn_vad"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ nfft: 512
6
+ win_size: 240
7
+ hop_size: 80
8
+ win_type: hann
9
+
10
+ # model
11
+ fsmn_input_size: 257
12
+ fsmn_input_affine_size: 140
13
+ fsmn_hidden_size: 250
14
+ fsmn_basic_block_layers: 4
15
+ fsmn_basic_block_hidden_size: 128
16
+ fsmn_basic_block_lorder: 20
17
+ fsmn_basic_block_rorder: 0
18
+ fsmn_basic_block_lstride: 1
19
+ fsmn_basic_block_rstride: 0
20
+ fsmn_output_affine_size: 140
21
+ fsmn_output_size: 1
22
+
23
+ use_softmax: false
24
+
25
+ # data
26
+ min_snr_db: -10
27
+ max_snr_db: 20
28
+
29
+ # train
30
+ lr: 0.001
31
+ lr_scheduler: "CosineAnnealingLR"
32
+ lr_scheduler_kwargs:
33
+ T_max: 250000
34
+ eta_min: 0.0001
35
+
36
+ max_epochs: 100
37
+ clip_grad_norm: 10.0
38
+ seed: 1234
39
+
40
+ num_workers: 4
41
+ batch_size: 128
42
+ eval_steps: 25000
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py CHANGED
@@ -5,6 +5,7 @@ import logging
  from pathlib import Path
  import shutil
  import tempfile, time
+ from typing import List
  import zipfile
 
  from scipy.io import wavfile
@@ -18,17 +19,14 @@ torch.set_num_threads(1)
  from project_settings import project_path
  from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
  from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadPretrainedModel, MODEL_FILE
- from toolbox.vad.vad import FrameVoiceClassifier, RingVad, process_speech_probs, make_visualization
+ from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
 
 
  logger = logging.getLogger("toolbox")
 
 
- class SileroVadVoiceClassifier(FrameVoiceClassifier):
-     def __init__(self,
-                  pretrained_model_path_or_zip_file: str,
-                  device: str = "cpu",
-                  ):
+ class InferenceSileroVad(object):
+     def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
          self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
          self.device = torch.device(device)
 
@@ -62,72 +60,38 @@ class SileroVadVoiceClassifier(FrameVoiceClassifier):
          shutil.rmtree(model_path)
          return config, model
 
-     def predict(self, chunk: np.ndarray) -> float:
-         if chunk.dtype != np.int16:
-             raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
-
-         chunk = chunk / 32768
-
-         inputs = torch.tensor(chunk, dtype=torch.float32)
+     def infer(self, signal: torch.Tensor) -> List[float]:
+         # signal shape: [num_samples,], value between -1 and 1.
+
+         inputs = torch.tensor(signal, dtype=torch.float32)
          inputs = torch.unsqueeze(inputs, dim=0)
+         # inputs shape: [1, num_samples,]
 
-         try:
-             logits, _ = self.model.forward(inputs)
-         except RuntimeError as e:
-             print(inputs.shape)
-             raise e
-         # logits shape: [b, t, 1]
-         logits_ = torch.mean(logits, dim=1)
-         # logits_ shape: [b, 1]
-         probs = torch.sigmoid(logits_)
-
-         voice_prob = probs[0][0]
-         return float(voice_prob)
-
-
- class InferenceSileroVad(object):
-     def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
-         self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
-         self.device = torch.device(device)
-
-         self.voice_classifier = SileroVadVoiceClassifier(pretrained_model_path_or_zip_file, device=device)
-
-         self.ring_vad = RingVad(model=self.voice_classifier,
-                                 start_ring_rate=0.2,
-                                 end_ring_rate=0.1,
-                                 frame_size_ms=30,
-                                 frame_step_ms=30,
-                                 padding_length_ms=300,
-                                 max_silence_length_ms=300,
-                                 sample_rate=SAMPLE_RATE,
-                                 )
-
-     def vad(self, signal: np.ndarray) -> np.ndarray:
-         self.ring_vad.reset()
-
-         vad_segments = list()
+         with torch.no_grad():
+             logits, probs = self.model.forward(inputs)
 
-         segments = self.ring_vad.vad(signal)
-         vad_segments += segments
-         # last vad segment
-         segments = self.ring_vad.last_vad_segments()
-         vad_segments += segments
-         return vad_segments
+         # probs shape: [b, t, 1]
+         probs = torch.squeeze(probs, dim=-1)
+         # probs shape: [b, t]
 
-     def get_vad_speech_probs(self):
-         result = self.ring_vad.speech_probs
-         return result
+         probs = probs.numpy()
+         probs = probs[0]
+         probs = probs.tolist()
+         return probs
 
-     def get_vad_frame_step(self):
-         result = self.ring_vad.frame_step
-         return result
+     def post_process(self, probs: List[float]):
+         return
 
 
  def get_args():
      parser = argparse.ArgumentParser()
      parser.add_argument(
          "--wav_file",
-         default=(project_path / "data/examples/hado/2f16ca0b-baec-4601-8a1e-7893eb875623.wav").as_posix(),
+         default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+         # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17\active_media_r_0af6bd3a-9aef-4bef-935b-63abfb4d46d8_5.wav",
          type=str,
      )
      args = parser.parse_args()
@@ -143,17 +107,18 @@ def main():
      sample_rate, signal = wavfile.read(args.wav_file)
      if SAMPLE_RATE != sample_rate:
          raise AssertionError
+     signal = signal / (1 << 15)
 
      infer = InferenceSileroVad(
-         pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
+         pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-by-webrtcvad-nx2-dns3.zip").as_posix()
+         # pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
      )
+     frame_step = infer.model.hop_size
 
-     vad_segments = infer.vad(signal)
+     speech_probs = infer.infer(signal)
 
-     speech_probs = infer.get_vad_speech_probs()
-     frame_step = infer.get_vad_frame_step()
+     # print(speech_probs)
 
-     # speech_probs
      speech_probs = process_speech_probs(
          signal=signal,
          speech_probs=speech_probs,
@@ -161,7 +126,7 @@ def main():
      )
 
      # plot
-     make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
+     make_visualization(signal, speech_probs, SAMPLE_RATE)
      return
 
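Note on the interface change: the old predict() took int16 samples and divided by 32768 internally, while the new infer() expects a float signal in [-1, 1], hence the new signal / (1 << 15) line in main(). A minimal sketch of the new call sequence, assuming a placeholder wav path; the import paths and zip name are taken from this diff:

# Sketch only: "example.wav" is a placeholder; the sample-rate check from main() is omitted.
from scipy.io import wavfile
from project_settings import project_path
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization

sample_rate, signal = wavfile.read("example.wav")   # int16 PCM
signal = signal / (1 << 15)                         # scale int16 to floats in [-1, 1]

infer = InferenceSileroVad(
    pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-by-webrtcvad-nx2-dns3.zip").as_posix()
)
speech_probs = infer.infer(signal)                  # one probability per STFT frame
frame_step = infer.model.hop_size                   # samples covered by each frame
speech_probs = process_speech_probs(signal=signal, speech_probs=speech_probs, frame_step=frame_step)
make_visualization(signal, speech_probs, sample_rate)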
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py CHANGED
@@ -82,6 +82,11 @@ class Encoder(nn.Module):
  class SileroVadModel(nn.Module):
      def __init__(self, config: SileroVadConfig):
          super(SileroVadModel, self).__init__()
+         self.nfft = config.nfft
+         self.win_size = config.win_size
+         self.hop_size = config.hop_size
+         self.win_type = config.win_type
+
          self.config = config
          self.eps = 1e-12
 
@@ -120,6 +125,11 @@
          self.sigmoid = nn.Sigmoid()
 
      def forward(self, signal: torch.Tensor):
+         if signal.dim() == 2:
+             signal = torch.unsqueeze(signal, dim=1)
+         _, _, num_samples = signal.shape
+         # signal shape [b, 1, num_samples]
+
          mags = self.stft.forward(signal)
          # mags shape: [b, f, t]
 
@@ -139,6 +149,35 @@
          # probs shape: [b, t, 1]
          return logits, probs
 
+     def forward_chunk(self, chunk: torch.Tensor):
+         # chunk shape [b, 1, num_samples]
+
+         mags = self.stft.forward(chunk)
+         # mags shape: [b, f, t]
+
+         x = torch.transpose(mags, dim0=1, dim1=2)
+         # x shape: [b, t, f]
+
+         x = self.linear.forward(x)
+         # x shape: [b, t, f']
+
+         return
+
+     def forward_chunk_by_chunk(self, signal: torch.Tensor):
+         if signal.dim() == 2:
+             signal = torch.unsqueeze(signal, dim=1)
+         _, _, num_samples = signal.shape
+         # signal shape [b, 1, num_samples]
+
+         t = (num_samples - self.win_size) // self.hop_size + 1
+         waveform_list = list()
+         for i in range(int(t)):
+             begin = i * self.hop_size
+             end = begin + self.win_size
+             sub_signal = signal[:, :, begin: end]
+
+         return
+
 
  class SileroVadPretrainedModel(SileroVadModel):
      def __init__(self,
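forward_chunk() and forward_chunk_by_chunk() are still stubs in this commit (both return None). For reference, a minimal sketch of the framing arithmetic they set up, using placeholder win_size / hop_size values rather than the ones in SileroVadConfig:

# Sketch only: standalone framing helper mirroring the loop in forward_chunk_by_chunk.
import torch

def iter_chunks(signal: torch.Tensor, win_size: int = 512, hop_size: int = 128):
    # signal shape: [b, 1, num_samples]
    _, _, num_samples = signal.shape
    t = (num_samples - win_size) // hop_size + 1
    for i in range(int(t)):
        begin = i * hop_size
        yield signal[:, :, begin: begin + win_size]

signal = torch.randn(1, 1, 8000)
chunks = list(iter_chunks(signal))
print(len(chunks), chunks[0].shape)   # 59 overlapping windows of shape [1, 1, 512]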
toolbox/torchaudio/models/vad/ten_vad/__init__.py ADDED
@@ -0,0 +1,12 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://huggingface.co/TEN-framework/ten-vad
+ https://zhuanlan.zhihu.com/p/1906832842756976909
+ https://github.com/TEN-framework/ten-vad
+
+ """
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/vad/ten_vad/modeling_ten_vad.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/vad/wav2vec2_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/models/vad/wav2vec2_vad/modeling_wav2vec2.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/torchaudio/utils/visualization.py ADDED
@@ -0,0 +1,33 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from typing import List
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
+     speech_probs_ = list()
+     for p in speech_probs[1:]:
+         speech_probs_.extend([p] * frame_step)
+
+     pad = (signal.shape[0] - len(speech_probs_))
+     speech_probs_ = speech_probs_ + [0.0] * pad
+     speech_probs_ = np.array(speech_probs_, dtype=np.float32)
+
+     if len(speech_probs_) != len(signal):
+         raise AssertionError
+     return speech_probs_
+
+
+ def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int):
+     time = np.arange(0, len(signal)) / sample_rate
+     plt.figure(figsize=(12, 5))
+     plt.plot(time, signal, color='b')
+     plt.plot(time, speech_probs, color='gray')
+     plt.show()
+     return
+
+
+ if __name__ == "__main__":
+     pass
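process_speech_probs() repeats each frame probability frame_step times (dropping the first frame) and zero-pads up to the signal length, so the probabilities can be drawn on the same time axis as the waveform. A minimal sketch with synthetic data (the signal, probabilities, and sizes here are made up):

# Sketch only: synthetic signal and probabilities, not output from a real model.
import numpy as np
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization

signal = np.random.uniform(-1.0, 1.0, size=(8000,)).astype(np.float32)
frame_step = 80
speech_probs = [0.0] + [0.9 if 20 < i < 60 else 0.1 for i in range(99)]  # 1 + 99 frames

aligned = process_speech_probs(signal=signal, speech_probs=speech_probs, frame_step=frame_step)
# 99 frames * 80 samples = 7920 values, zero-padded to the 8000-sample signal
make_visualization(signal, aligned, sample_rate=8000)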