update
- Dockerfile +2 -4
- examples/data_annotation/annotation_by_google.py +158 -0
- examples/fsmn_vad_by_webrtcvad/run.sh +174 -0
- examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py +231 -0
- examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py +205 -0
- examples/fsmn_vad_by_webrtcvad/step_3_check_vad.py +68 -0
- examples/fsmn_vad_by_webrtcvad/step_4_train_model.py +453 -0
- examples/fsmn_vad_by_webrtcvad/yaml/config.yaml +42 -0
- examples/silero_vad_by_webrtcvad/run.sh +1 -1
- requirements.txt +1 -0
- toolbox/torchaudio/models/vad/cnn_vad/__init__.py +6 -0
- toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py +6 -0
- toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py +26 -4
- toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py +0 -7
- toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py +97 -13
- toolbox/torchaudio/models/vad/fsmn_vad/yaml/config-sigmoid.yaml +42 -0
- toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py +31 -66
- toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py +39 -0
- toolbox/torchaudio/models/vad/ten_vad/__init__.py +12 -0
- toolbox/torchaudio/models/vad/ten_vad/modeling_ten_vad.py +6 -0
- toolbox/torchaudio/models/vad/wav2vec2_vad/__init__.py +6 -0
- toolbox/torchaudio/models/vad/wav2vec2_vad/modeling_wav2vec2.py +6 -0
- toolbox/torchaudio/utils/__init__.py +6 -0
- toolbox/torchaudio/utils/visualization.py +33 -0
Dockerfile
CHANGED
@@ -10,10 +10,6 @@ RUN apt-get install -y ffmpeg build-essential
 RUN pip install --upgrade pip
 RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
-RUN pip install --upgrade pip
-
-RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
-
 USER user
 
 ENV HOME=/home/user \
@@ -23,4 +19,6 @@ WORKDIR $HOME/app
 
 COPY --chown=user . $HOME/app
 
+RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
+
 CMD ["python3", "main.py"]
examples/data_annotation/annotation_by_google.py
ADDED
@@ -0,0 +1,158 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
from pathlib import Path
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

from google import genai
from google.genai import types

from project_settings import environment, project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--google_application_credentials",
        default=(project_path / "dotenv/potent-veld-462405-t3-8091a29b2894.json").as_posix(),
        type=str
    )
    parser.add_argument(
        "--model_name",
        default="gemini-2.5-pro",
        type=str
    )
    parser.add_argument(
        "--speech_audio_dir",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17",
        type=str
    )
    parser.add_argument(
        "--output_file",
        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\nx-noise\en-SG\2025-06-17\vad.jsonl",
        default=r"vad.jsonl",
        type=str
    )
    parser.add_argument(
        "--gemini_api_key",
        default=environment.get("GEMINI_API_KEY", dtype=str),
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    speech_audio_dir = Path(args.speech_audio_dir)
    output_file = Path(args.output_file)

    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = args.google_application_credentials
    os.environ["gemini_api_key"] = args.gemini_api_key

    developer_client = genai.Client(
        api_key=args.gemini_api_key,
    )
    client = genai.Client(
        vertexai=True,
        project="potent-veld-462405-t3",
        location="global",
    )
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        response_modalities=["TEXT"],
    )

    # finished
    finished_set = set()
    if output_file.exists():
        with open(output_file.as_posix(), "r", encoding="utf-8") as f:
            for row in f:
                row = json.loads(row)
                name = row["name"]
                finished_set.add(name)
    print(f"finished count: {len(finished_set)}")

    with open(output_file.as_posix(), "a+", encoding="utf-8") as f:

        for filename in speech_audio_dir.glob("**/*.wav"):
            name = filename.name
            if name in finished_set:
                continue
            finished_set.add(name)

            # upload
            audio_file = developer_client.files.upload(
                file=filename.as_posix(),
                config=None
            )
            print(f"upload file: {audio_file.name}")

            # prompt (in Chinese): ask for the start/end time of each speech segment in seconds,
            # millisecond precision, returned as a JSON list of [start, end] pairs.
            prompt = f"""
给我这段音频中的语音分段的开始和结束时间,单位为秒,精确到毫秒,并输出JSON格式,
例如:
```json
[[0.254, 1.214], [2.200, 3.100]],
```
如果没有语音段则输出:
```json
[]
```
""".strip()

            try:
                contents = [
                    types.Content(
                        role="user",
                        parts=[
                            types.Part(text=prompt),
                            types.Part.from_uri(
                                file_uri=audio_file.uri,
                                mime_type=audio_file.mime_type,
                            )
                        ]
                    )
                ]
                response: types.GenerateContentResponse = developer_client.models.generate_content(
                    model=args.model_name,
                    contents=contents,
                    config=generate_content_config,
                )
                answer = response.candidates[0].content.parts[0].text
                print(answer)
            finally:
                # delete
                print(f"delete file: {audio_file.name}")
                developer_client.files.delete(name=audio_file.name)

            pattern = "```json(.+?)```"
            match = re.search(pattern=pattern, string=answer, flags=re.DOTALL | re.IGNORECASE)
            if match is None:
                raise AssertionError(f"answer: {answer}")
            vad_segments = match.group(1)
            vad_segments = json.loads(vad_segments)
            row = {
                "name": name,
                "filename": filename.as_posix(),
                "vad_segments": vad_segments
            }
            row = json.dumps(row, ensure_ascii=False)

            f.write(f"{row}\n")
            exit(0)  # stops after the first file is annotated

    return


if __name__ == "__main__":
    main()
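The script above writes one JSON object per line to `vad.jsonl`, with the keys `name`, `filename`, and `vad_segments` (a list of `[start_s, end_s]` pairs). A minimal sketch for reading those annotations back, assuming exactly that layout (the file path is illustrative):

```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Minimal sketch: read the vad.jsonl produced by annotation_by_google.py.
# Assumes one JSON object per line with "name", "filename", "vad_segments".
import json
from pathlib import Path

output_file = Path("vad.jsonl")  # hypothetical location

with open(output_file.as_posix(), "r", encoding="utf-8") as f:
    for line in f:
        row = json.loads(line)
        name = row["name"]
        vad_segments = row["vad_segments"]
        speech_seconds = sum(end - start for start, end in vad_segments)
        print(f"{name}: {len(vad_segments)} segments, {speech_seconds:.3f} s of speech")
```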
examples/fsmn_vad_by_webrtcvad/run.sh
ADDED
@@ -0,0 +1,174 @@
#!/usr/bin/env bash

: <<'END'

bash run.sh --stage 2 --stop_stage 2 --system_version centos \
--file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
--final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"

bash run.sh --stage 3 --stop_stage 3 --system_version centos \
--file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
--final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=-1

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

train_vad_dataset="${file_dir}/train-vad.jsonl"
valid_vad_dataset="${file_dir}/valid-vad.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: make vad segments"
  cd "${work_dir}" || exit 1
  python3 step_2_make_vad_segments.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --train_vad_dataset "${train_vad_dataset}" \
    --valid_vad_dataset "${valid_vad_dataset}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: train model"
  cd "${work_dir}" || exit 1
  python3 step_4_train_model.py \
    --train_dataset "${train_vad_dataset}" \
    --valid_dataset "${valid_vad_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  $verbose && echo "stage 6: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py
ADDED
@@ -0,0 +1,231 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys
import time

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--noise_dir",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--duration", default=8.0, type=float)
    parser.add_argument("--min_speech_duration", default=6.0, type=float)
    parser.add_argument("--max_speech_duration", default=8.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=-1, type=int)

    args = parser.parse_args()
    return args


def target_second_noise_signal_generator(data_dir: str,
                                         duration: int = 4,
                                         sample_rate: int = 8000, max_epoch: int = 20000):
    noise_list = list()
    wait_duration = duration

    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)

            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            offset = 0.
            rest_duration = raw_duration

            for _ in range(1000):
                if rest_duration <= 0:
                    break
                if rest_duration <= wait_duration:
                    noise_list.append({
                        "epoch_idx": epoch_idx,
                        "filename": filename.as_posix(),
                        "raw_duration": round(raw_duration, 4),
                        "offset": round(offset, 4),
                        "duration": None,
                        "duration_": round(rest_duration, 4),
                    })
                    wait_duration -= rest_duration
                    offset = 0
                    rest_duration = 0
                elif rest_duration > wait_duration:
                    noise_list.append({
                        "epoch_idx": epoch_idx,
                        "filename": filename.as_posix(),
                        "raw_duration": round(raw_duration, 4),
                        "offset": round(offset, 4),
                        "duration": round(wait_duration, 4),
                        "duration_": round(wait_duration, 4),
                    })
                    offset += wait_duration
                    rest_duration -= wait_duration
                    wait_duration = 0
                else:
                    raise AssertionError

                if wait_duration <= 0:
                    yield noise_list
                    noise_list = list()
                    wait_duration = duration


def target_second_speech_signal_generator(data_dir: str,
                                          min_duration: int = 4,
                                          max_duration: int = 6,
                                          sample_rate: int = 8000, max_epoch: int = 1):
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            if raw_duration < min_duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue

            if raw_duration < max_duration:
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": 0.,
                    "duration": round(raw_duration, 4),
                }
                yield row

            signal_length = len(signal)
            win_size = int(max_duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(max_duration, 4),
                }
                yield row


def main():
    args = get_args()

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    train_dataset = Path(args.train_dataset)
    valid_dataset = Path(args.valid_dataset)
    train_dataset.parent.mkdir(parents=True, exist_ok=True)
    valid_dataset.parent.mkdir(parents=True, exist_ok=True)

    noise_generator = target_second_noise_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_speech_signal_generator(
        speech_dir.as_posix(),
        min_duration=args.min_speech_duration,
        max_duration=args.max_speech_duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for speech, noise_list in zip(speech_generator, noise_generator):
            if count >= args.max_count > 0:
                break

            # row
            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            noise_list = [
                {
                    "filename": noise["filename"],
                    "raw_duration": noise["raw_duration"],
                    "offset": noise["offset"],
                    "duration": noise["duration"],
                }
                for noise in noise_list
            ]

            # row
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "noise_list": noise_list,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                "duration_hours": round(duration_hours, 4),
            })

    return


if __name__ == "__main__":
    main()
examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py
ADDED
@@ -0,0 +1,205 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
from tqdm import tqdm

from project_settings import project_path
from toolbox.vad.vad import WebRTCVoiceClassifier, SileroVoiceClassifier, CCSoundsClassifier, RingVad


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
    parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)

    parser.add_argument("--duration", default=8.0, type=float)
    parser.add_argument("--expected_sample_rate", default=8000, type=int)

    parser.add_argument(
        "--silero_model_path",
        default=(project_path / "trained_models/silero_vad.jit").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--cc_sounds_model_path",
        default=(project_path / "trained_models/sound-2-ch32.zip").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # silero
    # model = SileroVoiceClassifier(model_path=args.silero_model_path, sample_rate=args.expected_sample_rate)
    # w_vad = RingVad(
    #     model=model,
    #     start_ring_rate=0.2,
    #     end_ring_rate=0.1,
    #     frame_size_ms=32,
    #     frame_step_ms=32,
    #     padding_length_ms=320,
    #     max_silence_length_ms=320,
    #     max_speech_length_s=100,
    #     min_speech_length_s=0.1,
    #     sample_rate=args.expected_sample_rate,
    # )

    # webrtcvad
    model = WebRTCVoiceClassifier(agg=3, sample_rate=args.expected_sample_rate)
    w_vad = RingVad(
        model=model,
        start_ring_rate=0.9,
        end_ring_rate=0.1,
        frame_size_ms=30,
        frame_step_ms=30,
        padding_length_ms=90,
        max_silence_length_ms=100,
        max_speech_length_s=100,
        min_speech_length_s=0.1,
        sample_rate=args.expected_sample_rate,
    )

    # cc sounds
    # model = CCSoundsClassifier(model_path=args.cc_sounds_model_path, sample_rate=args.expected_sample_rate)
    # w_vad = RingVad(
    #     model=model,
    #     start_ring_rate=0.5,
    #     end_ring_rate=0.3,
    #     frame_size_ms=300,
    #     frame_step_ms=300,
    #     padding_length_ms=300,
    #     max_silence_length_ms=100,
    #     max_speech_length_s=100,
    #     min_speech_length_s=0.1,
    #     sample_rate=args.expected_sample_rate,
    # )

    # valid
    va_duration = 0
    raw_duration = 0
    use_duration = 0

    count = 0
    process_bar_valid = tqdm(desc="process valid dataset jsonl")
    with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
          open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
        for row in fvalid:
            row = json.loads(row)

            speech_filename = row["speech_filename"]
            speech_offset = row["speech_offset"]
            speech_duration = row["speech_duration"]

            waveform, _ = librosa.load(
                speech_filename,
                sr=args.expected_sample_rate,
                offset=speech_offset,
                duration=speech_duration,
            )
            waveform = np.array(waveform * (1 << 15), dtype=np.int16)

            # vad
            vad_segments = list()
            segments = w_vad.vad(waveform)
            vad_segments += segments
            segments = w_vad.last_vad_segments()
            vad_segments += segments
            w_vad.reset()

            row["vad_segments"] = vad_segments

            row = json.dumps(row, ensure_ascii=False)
            fvalid_vad.write(f"{row}\n")

            va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
            raw_duration += speech_duration
            use_duration += args.duration

            count += 1

            va_rate = va_duration / use_duration
            va_raw_rate = va_duration / raw_duration
            use_duration_hours = use_duration / 3600

            process_bar_valid.update(n=1)
            process_bar_valid.set_postfix({
                "va_rate": round(va_rate, 4),
                "va_raw_rate": round(va_raw_rate, 4),
                "duration_hours": round(use_duration_hours, 4),
            })

    # train
    va_duration = 0
    raw_duration = 0
    use_duration = 0

    count = 0
    process_bar_train = tqdm(desc="process train dataset jsonl")
    with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
          open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
        for row in ftrain:
            row = json.loads(row)

            speech_filename = row["speech_filename"]
            speech_offset = row["speech_offset"]
            speech_duration = row["speech_duration"]

            waveform, _ = librosa.load(
                speech_filename,
                sr=args.expected_sample_rate,
                offset=speech_offset,
                duration=speech_duration,
            )
            waveform = np.array(waveform * (1 << 15), dtype=np.int16)

            # vad
            vad_segments = list()
            segments = w_vad.vad(waveform)
            vad_segments += segments
            segments = w_vad.last_vad_segments()
            vad_segments += segments
            w_vad.reset()

            row["vad_segments"] = vad_segments

            row = json.dumps(row, ensure_ascii=False)
            ftrain_vad.write(f"{row}\n")

            va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
            raw_duration += speech_duration
            use_duration += args.duration

            count += 1

            va_rate = va_duration / use_duration
            va_raw_rate = va_duration / raw_duration
            use_duration_hours = use_duration / 3600

            process_bar_train.update(n=1)
            process_bar_train.set_postfix({
                "va_rate": round(va_rate, 4),
                "va_raw_rate": round(va_raw_rate, 4),
                "duration_hours": round(use_duration_hours, 4),
            })

    return


if __name__ == "__main__":
    main()
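`RingVad` and `WebRTCVoiceClassifier` come from this repository's `toolbox.vad.vad` module and wrap the `webrtcvad` package pinned in requirements.txt. As a point of reference only, a minimal frame-by-frame check with `webrtcvad` used directly might look like the sketch below; the aggressiveness 3 and 30 ms frame size mirror the settings above, the file name is illustrative, and none of the RingVad smoothing logic is reproduced:

```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Illustrative sketch: run webrtcvad over 30 ms frames of 8 kHz, 16-bit mono PCM
# and print the raw per-frame speech decisions.
import librosa
import numpy as np
import webrtcvad

sample_rate = 8000
frame_ms = 30
frame_size = sample_rate * frame_ms // 1000  # 240 samples per frame

vad = webrtcvad.Vad(3)  # aggressiveness 3, as in WebRTCVoiceClassifier(agg=3)

waveform, _ = librosa.load("example.wav", sr=sample_rate)  # hypothetical file
waveform = np.array(waveform * (1 << 15), dtype=np.int16)

for idx in range(0, len(waveform) - frame_size + 1, frame_size):
    frame = waveform[idx: idx + frame_size].tobytes()
    is_speech = vad.is_speech(frame, sample_rate)
    print(f"frame at {idx / sample_rate:.3f}s: {'speech' if is_speech else 'non-speech'}")
```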
examples/fsmn_vad_by_webrtcvad/step_3_check_vad.py
ADDED
@@ -0,0 +1,68 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
    parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)

    parser.add_argument("--duration", default=8.0, type=float)
    parser.add_argument("--expected_sample_rate", default=8000, type=int)

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    SAMPLE_RATE = 8000

    with open(args.train_vad_dataset, "r", encoding="utf-8") as f:
        for row in f:
            row = json.loads(row)

            speech_filename = row["speech_filename"]
            speech_offset = row["speech_offset"]
            speech_duration = row["speech_duration"]

            vad_segments = row["vad_segments"]

            print(f"speech_filename: {speech_filename}")
            signal, sample_rate = librosa.load(
                speech_filename,
                sr=SAMPLE_RATE,
                offset=speech_offset,
                duration=speech_duration,
            )

            # plot
            time = np.arange(0, len(signal)) / sample_rate
            plt.figure(figsize=(12, 5))
            plt.plot(time, signal, color='b')
            for start, end in vad_segments:
                plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')  # mark segment start
                plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')  # mark segment end

            plt.show()

    return


if __name__ == "__main__":
    main()
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py
ADDED
@@ -0,0 +1,453 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List, Tuple

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.vad_padding_jsonl_dataset import VadPaddingJsonlDataset
from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadModel, FSMNVadPretrainedModel
from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
from toolbox.torchaudio.losses.bce_loss import BCELoss
from toolbox.torchaudio.losses.dice_loss import DiceLoss
from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train-vad.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)

    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=30, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="yaml/config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        noisy_audios = list()
        batch_vad_segments = list()

        for sample in batch:
            noisy_wave: torch.Tensor = sample["noisy_wave"]
            vad_segments: List[Tuple[float, float]] = sample["vad_segments"]

            noisy_audios.append(noisy_wave)
            batch_vad_segments.append(vad_segments)

        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")

        return noisy_audios, batch_vad_segments


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = FSMNVadConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = VadPaddingJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = VadPaddingJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker processes can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux multiple worker processes can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = FSMNVadPretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    bce_loss_fn = BCELoss(reduction="mean").to(device)
    dice_loss_fn = DiceLoss(reduction="mean").to(device)

    vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
    vad_f1_score_metrics_fn = VadF1Score(threshold=0.5)

    # training loop

    # state
    average_loss = 1000000000
    average_bce_loss = 1000000000
    average_dice_loss = 1000000000

    accuracy = -1
    f1 = -1
    precision = -1
    recall = -1

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()
        vad_accuracy_metrics_fn.reset()
        vad_f1_score_metrics_fn.reset()

        total_loss = 0.
        total_bce_loss = 0.
        total_dice_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            noisy_audios, batch_vad_segments = train_batch
            noisy_audios: torch.Tensor = noisy_audios.to(device)
            # noisy_audios shape: [b, num_samples]
            num_samples = noisy_audios.shape[-1]

            logits, probs = model.forward(noisy_audios)

            targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

            bce_loss = bce_loss_fn.forward(probs, targets)
            dice_loss = dice_loss_fn.forward(probs, targets)

            loss = 1.0 * bce_loss + 1.0 * dice_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss. continue.")
                continue

            vad_accuracy_metrics_fn.__call__(probs, targets)
            vad_f1_score_metrics_fn.__call__(probs, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_loss += loss.item()
            total_bce_loss += bce_loss.item()
            total_dice_loss += dice_loss.item()
            total_batches += 1

            average_loss = round(total_loss / total_batches, 4)
            average_bce_loss = round(total_bce_loss / total_batches, 4)
            average_dice_loss = round(total_dice_loss / total_batches, 4)

            metrics = vad_accuracy_metrics_fn.get_metric()
            accuracy = metrics["accuracy"]
            metrics = vad_f1_score_metrics_fn.get_metric()
            f1 = metrics["f1"]
            precision = metrics["precision"]
            recall = metrics["recall"]

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "loss": average_loss,
                "bce_loss": average_bce_loss,
                "dice_loss": average_dice_loss,
                "accuracy": accuracy,
                "f1": f1,
                "precision": precision,
                "recall": recall,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    model.eval()
                    vad_accuracy_metrics_fn.reset()
                    vad_f1_score_metrics_fn.reset()

                    total_loss = 0.
                    total_bce_loss = 0.
                    total_dice_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        noisy_audios, batch_vad_segments = eval_batch
                        noisy_audios: torch.Tensor = noisy_audios.to(device)
                        # noisy_audios shape: [b, num_samples]
                        num_samples = noisy_audios.shape[-1]

                        logits, probs = model.forward(noisy_audios)

                        targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

                        bce_loss = bce_loss_fn.forward(probs, targets)
                        dice_loss = dice_loss_fn.forward(probs, targets)

                        loss = 1.0 * bce_loss + 1.0 * dice_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"find nan or inf in loss. continue.")
                            continue

                        vad_accuracy_metrics_fn.__call__(probs, targets)
                        vad_f1_score_metrics_fn.__call__(probs, targets)

                        total_loss += loss.item()
                        total_bce_loss += bce_loss.item()
                        total_dice_loss += dice_loss.item()
                        total_batches += 1

                        average_loss = round(total_loss / total_batches, 4)
                        average_bce_loss = round(total_bce_loss / total_batches, 4)
                        average_dice_loss = round(total_dice_loss / total_batches, 4)

                        metrics = vad_accuracy_metrics_fn.get_metric()
                        accuracy = metrics["accuracy"]
                        metrics = vad_f1_score_metrics_fn.get_metric()
                        f1 = metrics["f1"]
                        precision = metrics["precision"]
                        recall = metrics["recall"]

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "loss": average_loss,
                            "bce_loss": average_bce_loss,
                            "dice_loss": average_dice_loss,
                            "accuracy": accuracy,
                            "f1": f1,
                            "precision": precision,
                            "recall": recall,
                        })

                    model.train()
                    vad_accuracy_metrics_fn.reset()
                    vad_f1_score_metrics_fn.reset()

                    total_loss = 0.
                    total_bce_loss = 0.
                    total_dice_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                # save path
                save_dir = serialization_dir / "steps-{}".format(step_idx)
                save_dir.mkdir(parents=True, exist_ok=False)

                # save models
                model.save_pretrained(save_dir.as_posix())

                model_list.append(save_dir)
                if len(model_list) >= args.num_serialized_models_to_keep:
                    model_to_delete: Path = model_list.pop(0)
                    shutil.rmtree(model_to_delete.as_posix())

                # save metric
                if best_metric is None:
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = f1
                elif f1 >= best_metric:
                    # greater is better.
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = f1
                else:
                    pass

                metrics = {
                    "epoch_idx": epoch_idx,
                    "best_epoch_idx": best_epoch_idx,
                    "best_step_idx": best_step_idx,
                    "loss": average_loss,
                    "bce_loss": average_bce_loss,
                    "dice_loss": average_dice_loss,

                    "accuracy": accuracy,
                }
                metrics_filename = save_dir / "metrics_epoch.json"
                with open(metrics_filename, "w", encoding="utf-8") as f:
                    json.dump(metrics, f, indent=4, ensure_ascii=False)

                # save best
                best_dir = serialization_dir / "best"
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    if best_dir.exists():
                        shutil.rmtree(best_dir)
                    shutil.copytree(save_dir, best_dir)

                # early stop
                early_stop_flag = False
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    patience_count = 0
                else:
                    patience_count += 1
                    if patience_count >= args.patience:
                        early_stop_flag = True

            # early stop
            if early_stop_flag:
                break

    return


if __name__ == "__main__":
    main()
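`BaseVadLoss.get_targets` turns each utterance's `vad_segments` into a frame-aligned 0/1 target tensor matching the time resolution of `probs`; its implementation lives elsewhere in the toolbox and is not part of this commit. A simplified numpy illustration of that idea (not the repository's actual code; resolution, tensor layout, and edge handling are assumptions):

```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Illustrative only: build a frame-level 0/1 target from (start_s, end_s) speech segments.
import numpy as np

def segments_to_frame_targets(vad_segments, duration: float, num_frames: int) -> np.ndarray:
    targets = np.zeros(shape=(num_frames,), dtype=np.float32)
    frames_per_second = num_frames / duration
    for start_s, end_s in vad_segments:
        begin = int(round(start_s * frames_per_second))
        end = int(round(end_s * frames_per_second))
        targets[begin: end] = 1.0
    return targets

# example: 8 s clip at 100 frames per second
targets = segments_to_frame_targets([[0.254, 1.214], [2.200, 3.100]], duration=8.0, num_frames=800)
print(int(targets.sum()))  # number of speech frames
```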
examples/fsmn_vad_by_webrtcvad/yaml/config.yaml
ADDED
@@ -0,0 +1,42 @@
model_name: "fsmn_vad"

# spec
sample_rate: 8000
nfft: 512
win_size: 240
hop_size: 80
win_type: hann

# model
fsmn_input_size: 257
fsmn_input_affine_size: 140
fsmn_hidden_size: 250
fsmn_basic_block_layers: 4
fsmn_basic_block_hidden_size: 128
fsmn_basic_block_lorder: 20
fsmn_basic_block_rorder: 0
fsmn_basic_block_lstride: 1
fsmn_basic_block_rstride: 0
fsmn_output_affine_size: 140
fsmn_output_size: 1

use_softmax: false

# data
min_snr_db: -10
max_snr_db: 20

# train
lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
clip_grad_norm: 10.0
seed: 1234

num_workers: 4
batch_size: 128
eval_steps: 25000
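step_4_train_model.py loads this file with `FSMNVadConfig.from_pretrained(pretrained_model_name_or_path=args.config_file)`, so each key above becomes an attribute on the config object. A minimal sketch of that round trip, assuming it is run from a directory where the toolbox package is importable:

```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Minimal sketch: load yaml/config.yaml the same way step_4_train_model.py does
# and read a few of the fields defined above.
from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig

config = FSMNVadConfig.from_pretrained(pretrained_model_name_or_path="yaml/config.yaml")

print(config.sample_rate)      # 8000
print(config.fsmn_input_size)  # 257
print(config.use_softmax)      # False
print(config.lr_scheduler, config.lr_scheduler_kwargs)
```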
examples/silero_vad_by_webrtcvad/run.sh
CHANGED
@@ -2,7 +2,7 @@
 
 : <<'END'
 
-bash run.sh --stage
+bash run.sh --stage 2 --stop_stage 2 --system_version centos \
 --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
 --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
requirements.txt
CHANGED
@@ -11,3 +11,4 @@ torchaudio==2.5.1
 overrides==7.7.0
 webrtcvad==2.0.10
 matplotlib==3.10.3
+google-genai
toolbox/torchaudio/models/vad/cnn_vad/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass
toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py
CHANGED
@@ -13,8 +13,19 @@ class FSMNVadConfig(PretrainedConfig):
                  hop_size: int = 80,
                  win_type: str = "hann",
 
-
-
+                 fsmn_input_size: int = 257,
+                 fsmn_input_affine_size: int = 140,
+                 fsmn_hidden_size: int = 250,
+                 fsmn_basic_block_layers: int = 4,
+                 fsmn_basic_block_hidden_size: int = 128,
+                 fsmn_basic_block_lorder: int = 20,
+                 fsmn_basic_block_rorder: int = 0,
+                 fsmn_basic_block_lstride: int = 1,
+                 fsmn_basic_block_rstride: int = 0,
+                 fsmn_output_affine_size: int = 140,
+                 fsmn_output_size: int = 1,
+
+                 use_softmax: bool = False,
 
                  lr: float = 0.001,
                  lr_scheduler: str = "CosineAnnealingLR",
@@ -39,8 +50,19 @@ class FSMNVadConfig(PretrainedConfig):
         self.win_type = win_type
 
         # encoder
-        self.
-        self.
+        self.fsmn_input_size = fsmn_input_size
+        self.fsmn_input_affine_size = fsmn_input_affine_size
+        self.fsmn_hidden_size = fsmn_hidden_size
+        self.fsmn_basic_block_layers = fsmn_basic_block_layers
+        self.fsmn_basic_block_hidden_size = fsmn_basic_block_hidden_size
+        self.fsmn_basic_block_lorder = fsmn_basic_block_lorder
+        self.fsmn_basic_block_rorder = fsmn_basic_block_rorder
+        self.fsmn_basic_block_lstride = fsmn_basic_block_lstride
+        self.fsmn_basic_block_rstride = fsmn_basic_block_rstride
+        self.fsmn_output_affine_size = fsmn_output_affine_size
+        self.fsmn_output_size = fsmn_output_size
+
+        self.use_softmax = use_softmax
 
         # train
         self.lr = lr
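A short usage sketch of the extended config (import path and field names taken from the diff above; it assumes the remaining constructor arguments keep their defaults):

from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig

config = FSMNVadConfig(
    fsmn_input_size=257,
    fsmn_basic_block_layers=4,
    fsmn_output_size=1,
    use_softmax=False,              # keep the sigmoid head, as in config-sigmoid.yaml
)
print(config.fsmn_hidden_size)      # 250, the default declared above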
toolbox/torchaudio/models/vad/fsmn_vad/fsmn_encoder.py
CHANGED
@@ -226,10 +226,6 @@ class FSMN(nn.Module):
         self.out_linear1 = AffineTransform(hidden_size, output_affine_size)
         self.out_linear2 = AffineTransform(output_affine_size, output_size)
 
-        self.use_softmax = use_softmax
-        if self.use_softmax:
-            self.softmax = nn.Softmax(dim=-1)
-
     def forward(self,
                 inputs: torch.Tensor,
                 cache_list: List[torch.Tensor] = None,
@@ -253,8 +249,6 @@
         outputs = self.out_linear2.forward(x)
         # outputs shape: [b, t, f]
 
-        if self.use_softmax:
-            outputs = self.softmax(outputs)
         return outputs, new_cache_list
 
 
@@ -271,7 +265,6 @@ def main():
         basic_block_rstride=1,
         output_affine_size=16,
         output_size=32,
-        use_softmax=True,
     )
 
     inputs = torch.randn(size=(1, 198, 32), dtype=torch.float32)
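With the softmax removed here, FSMN.forward now returns raw logits and the caller applies the activation (see modeling_fsmn_vad.py below). A standalone sketch of the two heads using plain torch ops (shapes are illustrative):

import torch

logits = torch.randn(1, 198, 1)                        # [b, t, output_size], as returned by FSMN
probs_sigmoid = torch.sigmoid(logits)                  # binary VAD head (use_softmax: false)

logits_2class = torch.randn(1, 198, 2)                 # hypothetical two-class head
probs_softmax = torch.softmax(logits_2class, dim=-1)   # multi-class head (use_softmax: true)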
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py
CHANGED
@@ -41,20 +41,104 @@ class FSMNVadModel(nn.Module):
         )
 
         self.fsmn_encoder = FSMN(
-            input_size=
-            input_affine_size=
-            hidden_size=
-            basic_block_layers=
-            basic_block_hidden_size=
-            basic_block_lorder=
-            basic_block_rorder=
-            basic_block_lstride=
-            basic_block_rstride=
-            output_affine_size=
-            output_size=
-            use_softmax=True,
+            input_size=config.fsmn_input_size,
+            input_affine_size=config.fsmn_input_affine_size,
+            hidden_size=config.fsmn_hidden_size,
+            basic_block_layers=config.fsmn_basic_block_layers,
+            basic_block_hidden_size=config.fsmn_basic_block_hidden_size,
+            basic_block_lorder=config.fsmn_basic_block_lorder,
+            basic_block_rorder=config.fsmn_basic_block_rorder,
+            basic_block_lstride=config.fsmn_basic_block_lstride,
+            basic_block_rstride=config.fsmn_basic_block_rstride,
+            output_affine_size=config.fsmn_output_affine_size,
+            output_size=config.fsmn_output_size,
         )
 
+        self.use_softmax = config.use_softmax
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax()
+
+    def forward(self, signal: torch.Tensor):
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, num_samples = signal.shape
+        # signal shape [b, 1, num_samples]
+
+        mags = self.stft.forward(signal)
+        # mags shape: [b, f, t]
+
+        x = torch.transpose(mags, dim0=1, dim1=2)
+        # x shape: [b, t, f]
+
+        logits, _ = self.fsmn_encoder.forward(x)
+
+        if self.use_softmax:
+            probs = self.softmax.forward(logits)
+            # probs shape: [b, t, n]
+        else:
+            probs = self.sigmoid.forward(logits)
+            # probs shape: [b, t, 1]
+        return logits, probs
+
+
+class FSMNVadPretrainedModel(FSMNVadModel):
+    def __init__(self,
+                 config: FSMNVadConfig,
+                 ):
+        super(FSMNVadPretrainedModel, self).__init__(
+            config=config,
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = FSMNVadConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
+def main():
+    config = FSMNVadConfig()
+    model = FSMNVadPretrainedModel(config=config)
+
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+
+    logits, probs = model.forward(noisy)
+    print(f"probs: {probs}")
+    print(f"probs.shape: {logits.shape}")
+    print(f"use_softmax: {config.use_softmax}")
+    return
+
 
 if __name__ == "__main__":
-
+    main()
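A hedged round-trip sketch for the new pretrained-model plumbing above (the save directory is hypothetical, and it assumes FSMNVadConfig.from_pretrained can resolve the yaml written into that directory, mirroring the silero_vad counterpart):

import torch
from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
from toolbox.torchaudio.models.vad.fsmn_vad.modeling_fsmn_vad import FSMNVadPretrainedModel

config = FSMNVadConfig()
model = FSMNVadPretrainedModel(config=config)

save_dir = model.save_pretrained("trained_models/fsmn-vad-demo")   # hypothetical path

restored = FSMNVadPretrainedModel.from_pretrained(save_dir)
logits, probs = restored.forward(torch.randn(1, 16000))
print(probs.shape)   # [1, t, 1] with the default sigmoid head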
toolbox/torchaudio/models/vad/fsmn_vad/yaml/config-sigmoid.yaml
ADDED
@@ -0,0 +1,42 @@
+model_name: "fsmn_vad"
+
+# spec
+sample_rate: 8000
+nfft: 512
+win_size: 240
+hop_size: 80
+win_type: hann
+
+# model
+fsmn_input_size: 257
+fsmn_input_affine_size: 140
+fsmn_hidden_size: 250
+fsmn_basic_block_layers: 4
+fsmn_basic_block_hidden_size: 128
+fsmn_basic_block_lorder: 20
+fsmn_basic_block_rorder: 0
+fsmn_basic_block_lstride: 1
+fsmn_basic_block_rstride: 0
+fsmn_output_affine_size: 140
+fsmn_output_size: 1
+
+use_softmax: false
+
+# data
+min_snr_db: -10
+max_snr_db: 20
+
+# train
+lr: 0.001
+lr_scheduler: "CosineAnnealingLR"
+lr_scheduler_kwargs:
+  T_max: 250000
+  eta_min: 0.0001
+
+max_epochs: 100
+clip_grad_norm: 10.0
+seed: 1234
+
+num_workers: 4
+batch_size: 128
+eval_steps: 25000
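A quick sketch for inspecting this yaml directly (assuming PyYAML is available; the toolbox itself consumes it through FSMNVadConfig.from_pretrained / to_yaml_file):

import yaml

with open("toolbox/torchaudio/models/vad/fsmn_vad/yaml/config-sigmoid.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["use_softmax"])       # False
print(cfg["fsmn_output_size"])  # 1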
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py
CHANGED
@@ -5,6 +5,7 @@ import logging
 from pathlib import Path
 import shutil
 import tempfile, time
+from typing import List
 import zipfile
 
 from scipy.io import wavfile
@@ -18,17 +19,14 @@ torch.set_num_threads(1)
 from project_settings import project_path
 from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
 from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadPretrainedModel, MODEL_FILE
-from toolbox.
+from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
 
 
 logger = logging.getLogger("toolbox")
 
 
-class
-    def __init__(self,
-                 pretrained_model_path_or_zip_file: str,
-                 device: str = "cpu",
-                 ):
+class InferenceSileroVad(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
         self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
         self.device = torch.device(device)
 
@@ -62,72 +60,38 @@ class SileroVadVoiceClassifier(FrameVoiceClassifier):
         shutil.rmtree(model_path)
         return config, model
 
-    def
-
-            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
-
-
-
-        inputs = torch.tensor(chunk, dtype=torch.float32)
+    def infer(self, signal: torch.Tensor) -> float:
+        # signal shape: [num_samples,], value between -1 and 1.
+
+        inputs = torch.tensor(signal, dtype=torch.float32)
         inputs = torch.unsqueeze(inputs, dim=0)
+        # inputs shape: [1, num_samples,]
 
-
-            logits,
-        except RuntimeError as e:
-            print(inputs.shape)
-            raise e
-        # logits shape: [b, t, 1]
-        logits_ = torch.mean(logits, dim=1)
-        # logits_ shape: [b, 1]
-        probs = torch.sigmoid(logits_)
-
-        voice_prob = probs[0][0]
-        return float(voice_prob)
-
-
-class InferenceSileroVad(object):
-    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
-        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
-        self.device = torch.device(device)
-
-        self.voice_classifier = SileroVadVoiceClassifier(pretrained_model_path_or_zip_file, device=device)
-
-        self.ring_vad = RingVad(model=self.voice_classifier,
-                                start_ring_rate=0.2,
-                                end_ring_rate=0.1,
-                                frame_size_ms=30,
-                                frame_step_ms=30,
-                                padding_length_ms=300,
-                                max_silence_length_ms=300,
-                                sample_rate=SAMPLE_RATE,
-                                )
-
-    def vad(self, signal: np.ndarray) -> np.ndarray:
-        self.ring_vad.reset()
-
-        vad_segments = list()
-
-        #
-        segments = self.ring_vad.last_vad_segments()
-        vad_segments += segments
-        return vad_segments
-
-
-
-    def
-        return result
+        with torch.no_grad():
+            logits, probs = self.model.forward(inputs)
+
+        # probs shape: [b, t, 1]
+        probs = torch.squeeze(probs, dim=-1)
+        # probs shape: [b, t]
+
+        probs = probs.numpy()
+        probs = probs[0]
+        probs = probs.tolist()
+        return probs
+
+    def post_process(self, probs: List[float]):
+        return
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--wav_file",
-        default=(project_path / "data/examples/
+        default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+        # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17\active_media_r_0af6bd3a-9aef-4bef-935b-63abfb4d46d8_5.wav",
         type=str,
     )
     args = parser.parse_args()
@@ -143,17 +107,18 @@ def main():
     sample_rate, signal = wavfile.read(args.wav_file)
     if SAMPLE_RATE != sample_rate:
         raise AssertionError
+    signal = signal / (1 << 15)
 
     infer = InferenceSileroVad(
-        pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-by-webrtcvad-nx2-dns3.zip").as_posix()
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
     )
+    frame_step = infer.model.hop_size
 
-
-    frame_step = infer.get_vad_frame_step()
-
-    # speech_probs
+    speech_probs = infer.infer(signal)
+
+    # print(speech_probs)
+
     speech_probs = process_speech_probs(
         signal=signal,
         speech_probs=speech_probs,
@@ -161,7 +126,7 @@
     )
 
     # plot
-    make_visualization(signal, speech_probs, SAMPLE_RATE
+    make_visualization(signal, speech_probs, SAMPLE_RATE)
    return
 
 
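Note on the new `signal = signal / (1 << 15)` line: scipy.io.wavfile returns int16 PCM for these files, and the model's infer() expects floats in [-1, 1). A tiny standalone illustration (values are made up):

import numpy as np

pcm_int16 = np.array([0, 16384, -32768], dtype=np.int16)
normalized = pcm_int16 / (1 << 15)    # -> array([ 0. ,  0.5, -1. ])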
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
CHANGED
@@ -82,6 +82,11 @@ class Encoder(nn.Module):
 class SileroVadModel(nn.Module):
     def __init__(self, config: SileroVadConfig):
         super(SileroVadModel, self).__init__()
+        self.nfft = config.nfft
+        self.win_size = config.win_size
+        self.hop_size = config.hop_size
+        self.win_type = config.win_type
+
         self.config = config
         self.eps = 1e-12
 
@@ -120,6 +125,11 @@
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, signal: torch.Tensor):
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, num_samples = signal.shape
+        # signal shape [b, 1, num_samples]
+
         mags = self.stft.forward(signal)
         # mags shape: [b, f, t]
 
@@ -139,6 +149,35 @@
         # probs shape: [b, t, 1]
         return logits, probs
 
+    def forward_chunk(self, chunk: torch.Tensor):
+        # chunk shape [b, 1, num_samples]
+
+        mags = self.stft.forward(chunk)
+        # mags shape: [b, f, t]
+
+        x = torch.transpose(mags, dim0=1, dim1=2)
+        # x shape: [b, t, f]
+
+        x = self.linear.forward(x)
+        # x shape: [b, t, f']
+
+        return
+
+    def forward_chunk_by_chunk(self, signal: torch.Tensor):
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, num_samples = signal.shape
+        # signal shape [b, 1, num_samples]
+
+        t = (num_samples - self.win_size) // self.hop_size + 1
+        waveform_list = list()
+        for i in range(int(t)):
+            begin = i * self.hop_size
+            end = begin + self.win_size
+            sub_signal = signal[:, :, begin: end]
+
+        return
+
 
 class SileroVadPretrainedModel(SileroVadModel):
     def __init__(self,
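The chunking loop added in forward_chunk_by_chunk uses the usual framing formula t = (num_samples - win_size) // hop_size + 1 (both new methods are still stubs that return nothing). A quick check of the formula, using the win_size/hop_size values from the fsmn config-sigmoid.yaml above as an example (the silero config may use different values):

num_samples = 16000                            # 2 seconds at 8 kHz
win_size, hop_size = 240, 80
t = (num_samples - win_size) // hop_size + 1
print(t)                                       # 198 frames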
toolbox/torchaudio/models/vad/ten_vad/__init__.py
ADDED
@@ -0,0 +1,12 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://huggingface.co/TEN-framework/ten-vad
+https://zhuanlan.zhihu.com/p/1906832842756976909
+https://github.com/TEN-framework/ten-vad
+
+"""
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/ten_vad/modeling_ten_vad.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/wav2vec2_vad/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/models/vad/wav2vec2_vad/modeling_wav2vec2.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/utils/visualization.py
ADDED
@@ -0,0 +1,33 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
+    speech_probs_ = list()
+    for p in speech_probs[1:]:
+        speech_probs_.extend([p] * frame_step)
+
+    pad = (signal.shape[0] - len(speech_probs_))
+    speech_probs_ = speech_probs_ + [0.0] * pad
+    speech_probs_ = np.array(speech_probs_, dtype=np.float32)
+
+    if len(speech_probs_) != len(signal):
+        raise AssertionError
+    return speech_probs_
+
+
+def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int):
+    time = np.arange(0, len(signal)) / sample_rate
+    plt.figure(figsize=(12, 5))
+    plt.plot(time, signal, color='b')
+    plt.plot(time, speech_probs, color='gray')
+    plt.show()
+    return
+
+
+if __name__ == "__main__":
+    pass
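A usage sketch for the two helpers above with synthetic data (the probabilities are made up, and frame_step=80 just mirrors the hop_size used in the fsmn yaml above):

import numpy as np
from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization

signal = np.random.uniform(-1, 1, size=(16000,)).astype(np.float32)
speech_probs = [0.0] + [0.8] * 200        # the helper skips the first entry

expanded = process_speech_probs(signal=signal, speech_probs=speech_probs, frame_step=80)
make_visualization(signal, expanded, sample_rate=8000)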