Commit · 9829721
Parent(s): first commit
Browse files
- .dockerignore +5 -0
- .gitattributes +35 -0
- .gitignore +23 -0
- Dockerfile +24 -0
- README.md +129 -0
- examples/silero_vad_by_webrtcvad/run.sh +164 -0
- examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +185 -0
- examples/silero_vad_by_webrtcvad/step_2_train_model.py +469 -0
- examples/silero_vad_by_webrtcvad/yaml/config.yaml +22 -0
- install.sh +64 -0
- log.py +220 -0
- main.py +69 -0
- project_settings.py +27 -0
- requirements.txt +13 -0
- toolbox/__init__.py +6 -0
- toolbox/json/__init__.py +6 -0
- toolbox/json/misc.py +63 -0
- toolbox/os/__init__.py +6 -0
- toolbox/os/command.py +59 -0
- toolbox/os/environment.py +114 -0
- toolbox/os/other.py +9 -0
- toolbox/torch/__init__.py +6 -0
- toolbox/torch/utils/__init__.py +6 -0
- toolbox/torch/utils/data/__init__.py +6 -0
- toolbox/torch/utils/data/dataset/__init__.py +6 -0
- toolbox/torch/utils/data/dataset/vad_jsonl_dataset.py +179 -0
- toolbox/torchaudio/__init__.py +6 -0
- toolbox/torchaudio/configuration_utils.py +64 -0
- toolbox/torchaudio/losses/__init__.py +6 -0
- toolbox/torchaudio/losses/vad_loss/__init__.py +6 -0
- toolbox/torchaudio/losses/vad_loss/base_vad_loss.py +43 -0
- toolbox/torchaudio/losses/vad_loss/bce_vad_loss.py +52 -0
- toolbox/torchaudio/losses/vad_loss/dice_vad_loss.py +70 -0
- toolbox/torchaudio/metrics/__init__.py +6 -0
- toolbox/torchaudio/metrics/vad_metrics/__init__.py +6 -0
- toolbox/torchaudio/metrics/vad_metrics/vad_accuracy.py +60 -0
- toolbox/torchaudio/models/__init__.py +6 -0
- toolbox/torchaudio/models/vad/__init__.py +6 -0
- toolbox/torchaudio/models/vad/silero_vad/__init__.py +6 -0
- toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py +66 -0
- toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py +151 -0
- toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml +22 -0
- toolbox/torchaudio/modules/__init__.py +6 -0
- toolbox/torchaudio/modules/conv_stft.py +271 -0
- toolbox/webrtcvad/__init__.py +5 -0
- toolbox/webrtcvad/vad.py +249 -0
.dockerignore
ADDED
@@ -0,0 +1,5 @@

.git/
.idea/

/examples/
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,23 @@

.gradio/
.git/
.idea/

**/evaluation_audio/
**/file_dir/
**/flagged/
**/log/
**/logs/
**/__pycache__/

/data/
/docs/
/dotenv/
/hub_datasets/
/script/
/thirdparty/
/trained_models/
/temp/

**/*.wav
**/*.xlsx
Dockerfile
ADDED
@@ -0,0 +1,24 @@
FROM python:3.12

WORKDIR /code

COPY . /code

RUN apt-get update
RUN apt-get install -y ffmpeg build-essential

RUN pip install --upgrade pip
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

RUN useradd -m -u 1000 user

USER user

ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

WORKDIR $HOME/app

COPY --chown=user . $HOME/app

CMD ["python3", "main.py"]
README.md
ADDED
@@ -0,0 +1,129 @@
---
title: CC VAD
emoji: 🐢
colorFrom: purple
colorTo: blue
sdk: docker
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

## CC VAD


### datasets

```text

AISHELL (15G)
https://openslr.trmal.net/resources/33/

AISHELL-3 (19G)
http://www.openslr.org/93/

DNS3
https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
The noise data comes from DEMAND, FreeSound, and AudioSet.

MS-SNSD
https://github.com/microsoft/MS-SNSD
The noise data comes from DEMAND and FreeSound.

MUSAN
https://www.openslr.org/17/
It contains music, noise, and speech.
music is pure instrumental music; noise contains free-sound and sound-bible, and the sound-bible part might serve as supplementary material.
Overall, not much of it is useful; noise data probably still needs to be collected mainly by ourselves to be reliable.

CHiME-4
https://www.chimechallenge.org/challenges/chime4/download.html

freesound
https://freesound.org/

AudioSet
https://research.google.com/audioset/index.html
```


### Create the training container

```text
Training the model inside a container requires GPU access from within the container; see:
https://hub.docker.com/r/ollama/ollama

docker run -itd \
  --name cc_vad \
  --network host \
  --gpus all \
  --privileged \
  --ipc=host \
  -v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
  -v /data/tianxing/PycharmProjects/cc_vad:/data/tianxing/PycharmProjects/cc_vad \
  python:3.12


Check the GPU:
nvidia-smi
watch -n 1 -d nvidia-smi


```

```text
Accessing the GPU from inside a container

Reference:
https://blog.csdn.net/footless_bird/article/details/136291344
Steps:
# install
yum install -y nvidia-container-toolkit

# edit the file /etc/docker/daemon.json
cat /etc/docker/daemon.json
{
    "data-root": "/data/lib/docker",
    "default-runtime": "nvidia",
    "runtimes": {
        "nvidia": {
            "path": "/usr/bin/nvidia-container-runtime",
            "runtimeArgs": []
        }
    },
    "registry-mirrors": [
        "https://docker.m.daocloud.io",
        "https://dockerproxy.com",
        "https://docker.mirrors.ustc.edu.cn",
        "https://docker.nju.edu.cn"
    ]
}

# restart docker
systemctl restart docker
systemctl daemon-reload

# test whether the GPU is accessible inside a container.
docker run --gpus all python:3.12-slim nvidia-smi

# A container started this way can see the GPU device, but there is no GPU driver inside, so nvidia-smi does not work.
docker run -it --privileged python:3.12-slim /bin/bash
apt update
apt install -y pciutils
lspci | grep -i nvidia
#00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)

# This launch method is often suggested online, but nvidia-smi still does not work inside the container.
docker run \
  --device /dev/nvidia0:/dev/nvidia0 \
  --device /dev/nvidiactl:/dev/nvidiactl \
  --device /dev/nvidia-uvm:/dev/nvidia-uvm \
  -v /usr/local/nvidia:/usr/local/nvidia \
  -it --privileged python:3.12-slim /bin/bash


# Entering the container this way, nvidia-smi works. The key seems to be the --gpus all flag.
docker run -itd --gpus all --name open_unsloth python:3.12-slim /bin/bash
docker run -itd --gpus all --name Qwen2-7B-Instruct python:3.12-slim /bin/bash

```
examples/silero_vad_by_webrtcvad/run.sh
ADDED
@@ -0,0 +1,164 @@
#!/usr/bin/env bash

: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"

sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

max_count=-1

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

train_dataset="${file_dir}/train.jsonl"
valid_dataset="${file_dir}/valid.jsonl"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --max_count "${max_count}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py
ADDED
@@ -0,0 +1,185 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
from pathlib import Path
import random
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm

from toolbox.webrtcvad.vad import WebRTCVad


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--duration", default=4.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    parser.add_argument("--max_count", default=-1, type=int)

    # vad
    parser.add_argument("--agg", default=3, type=int)
    parser.add_argument("--frame_duration_ms", default=30, type=int)
    parser.add_argument("--padding_duration_ms", default=30, type=int)
    parser.add_argument("--silence_duration_threshold", default=0.3, type=float)

    args = parser.parse_args()
    return args


def target_second_signal_generator(data_dir: str, duration: float = 2, sample_rate: int = 8000, max_epoch: int = 20000):
    data_dir = Path(data_dir)
    for epoch_idx in range(max_epoch):
        for filename in data_dir.glob("**/*.wav"):
            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

            if raw_duration < duration:
                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
                continue
            if signal.ndim != 1:
                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

            signal_length = len(signal)
            win_size = int(duration * sample_rate)
            for begin in range(0, signal_length - win_size, win_size):
                if np.sum(signal[begin: begin+win_size]) == 0:
                    continue
                row = {
                    "epoch_idx": epoch_idx,
                    "filename": filename.as_posix(),
                    "raw_duration": round(raw_duration, 4),
                    "offset": round(begin / sample_rate, 4),
                    "duration": round(duration, 4),
                }
                yield row


def main():
    args = get_args()

    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=100000,
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate,
        max_epoch=1,
    )

    w_vad = WebRTCVad(
        agg=args.agg,
        frame_duration_ms=args.frame_duration_ms,
        padding_duration_ms=args.padding_duration_ms,
        silence_duration_threshold=args.silence_duration_threshold,
        sample_rate=args.target_sample_rate,
    )

    count = 0
    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
        for noise, speech in zip(noise_generator, speech_generator):
            if count >= args.max_count > 0:
                break

            # row
            noise_filename = noise["filename"]
            noise_raw_duration = noise["raw_duration"]
            noise_offset = noise["offset"]
            noise_duration = noise["duration"]

            speech_filename = speech["filename"]
            speech_raw_duration = speech["raw_duration"]
            speech_offset = speech["offset"]
            speech_duration = speech["duration"]

            # vad
            _, signal = wavfile.read(speech_filename)
            vad_segments = list()
            segments = w_vad.vad(signal)
            vad_segments += segments
            segments = w_vad.last_vad_segments()
            vad_segments += segments

            # row
            random1 = random.random()
            random2 = random.random()

            row = {
                "count": count,

                "noise_filename": noise_filename,
                "noise_raw_duration": noise_raw_duration,
                "noise_offset": noise_offset,
                "noise_duration": noise_duration,

                "speech_filename": speech_filename,
                "speech_raw_duration": speech_raw_duration,
                "speech_offset": speech_offset,
                "speech_duration": speech_duration,

                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

                "vad_segments": vad_segments,

                "random1": random1,
            }
            row = json.dumps(row, ensure_ascii=False)
            if random2 < (1 / 300 / 1):
                fvalid.write(f"{row}\n")
            else:
                ftrain.write(f"{row}\n")

            count += 1
            duration_seconds = count * args.duration
            duration_hours = duration_seconds / 3600

            process_bar.update(n=1)
            process_bar.set_postfix({
                # "duration_seconds": round(duration_seconds, 4),
                "duration_hours": round(duration_hours, 4),
            })

    return


if __name__ == "__main__":
    main()
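A quick way to sanity-check the manifest that step_1_prepare_data.py writes is to read a few rows back; a minimal sketch, where the field names follow the `row` dict above and the "file_dir/train.jsonl" path is just an example:

```python
# Sketch: inspect a few rows of the train.jsonl manifest produced by step_1_prepare_data.py.
import json

with open("file_dir/train.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        row = json.loads(line)
        # each row pairs one noise excerpt with one speech excerpt plus its VAD segments
        print(row["speech_filename"], row["snr_db"], len(row["vad_segments"]))
        if i >= 4:
            break
```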
examples/silero_vad_by_webrtcvad/step_2_train_model.py
ADDED
@@ -0,0 +1,469 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/Rikorose/DeepFilterNet
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=30, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()
        snr_db_list = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = DfNet2Config.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DfNet2PretrainedModel(config).to(device)
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)

    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000
    average_lsnr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_lsnr_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
            # est_wav shape: [b, 1, n_samples]
            est_wav = torch.squeeze(est_wav, dim=1)
            # est_wav shape: [b, n_samples]

            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"found nan or inf in loss. continue.")
                continue

            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_mr_stft_loss += mr_stft_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_lsnr_loss += lsnr_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)
            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "mr_stft_loss": average_mr_stft_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
                "lsnr_loss": average_lsnr_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()

                    model.eval()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
                        # est_wav shape: [b, 1, n_samples]
                        est_wav = torch.squeeze(est_wav, dim=1)
                        # est_wav shape: [b, n_samples]

                        mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
                        lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info(f"found nan or inf in loss. continue.")
                            continue

                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_lsnr_loss += lsnr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)
                        average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                            "lsnr_loss": average_lsnr_loss,
                        })

                    model.train()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_lsnr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save metric
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score >= best_metric:
                        # greater is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "mr_stft_loss": average_mr_stft_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "mask_loss": average_mask_loss,
                        "lsnr_loss": average_lsnr_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop
                    if early_stop_flag:
                        break

    return


if __name__ == "__main__":
    main()
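After training, run.sh stage 4 collects "${file_dir}/best", and the metrics file written above can be inspected directly. A minimal sketch, where the keys mirror the `metrics` dict in the script and "file_dir/best" is an example path:

```python
# Sketch: read the metrics_epoch.json that step_2_train_model.py writes next to each checkpoint.
import json

with open("file_dir/best/metrics_epoch.json", "r", encoding="utf-8") as f:
    metrics = json.load(f)

# the training loop treats a higher PESQ score as better when choosing the "best" checkpoint
print(metrics["best_step_idx"], metrics["pesq_score"], metrics["loss"])
```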
examples/silero_vad_by_webrtcvad/yaml/config.yaml
ADDED
@@ -0,0 +1,22 @@
model_name: "silero_vad"

sample_rate: 8000
nfft: 512
win_size: 240
hop_size: 80
win_type: hann

in_channels: 64
hidden_size: 128

lr: 0.001
lr_scheduler: CosineAnnealingLR
lr_scheduler_kwargs: {}

max_epochs: 100
clip_grad_norm: 10.0
seed: 1234

num_workers: 4
batch_size: 4
eval_steps: 25000
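The training script reads these hyper-parameters through the project's config loader; the loader itself (toolbox/torchaudio/configuration_utils.py) is not shown in this commit view. A minimal sketch that inspects the same file with plain YAML, assuming PyYAML is available (it is not pinned in requirements.txt):

```python
# Sketch: load the training hyper-parameters from yaml/config.yaml with PyYAML.
import yaml

with open("yaml/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# fields referenced by step_2_train_model.py as config.sample_rate, config.eval_steps, etc.
print(config["model_name"], config["sample_rate"], config["eval_steps"])
```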
install.sh
ADDED
@@ -0,0 +1,64 @@
#!/usr/bin/env bash

# bash install.sh --stage 2 --stop_stage 2 --system_version centos


python_version=3.12.1
system_version="centos";

verbose=true;
stage=-1
stop_stage=0


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

work_dir="$(pwd)"


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: install python"
  cd "${work_dir}" || exit 1;

  sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: create virtualenv"

  # /usr/local/python-3.12.1/bin/virtualenv cc_vad
  # source /data/local/bin/cc_vad/bin/activate
  /usr/local/python-${python_version}/bin/pip3 install virtualenv
  mkdir -p /data/local/bin
  cd /data/local/bin || exit 1;
  /usr/local/python-${python_version}/bin/virtualenv cc_vad

fi
log.py
ADDED
@@ -0,0 +1,220 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from datetime import datetime
import logging
from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
import os
from zoneinfo import ZoneInfo  # bundled with Python 3.9+, no extra install needed


def get_converter(tz_info: str = "Asia/Shanghai"):
    def converter(timestamp):
        dt = datetime.fromtimestamp(timestamp, ZoneInfo(tz_info))
        result = dt.timetuple()
        return result
    return converter


def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    formatter = logging.Formatter(
        fmt=fmt,
        datefmt="%Y-%m-%d %H:%M:%S %z"
    )
    formatter.converter = get_converter(tz_info)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)

    # main
    main_logger = logging.getLogger("main")
    main_logger.addHandler(stream_handler)
    main_info_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "main.log"),
        maxBytes=100*1024*1024,  # 100MB
        encoding="utf-8",
        backupCount=2,
    )
    main_info_file_handler.setLevel(logging.INFO)
    main_info_file_handler.setFormatter(logging.Formatter(fmt))
    main_logger.addHandler(main_info_file_handler)

    # http
    http_logger = logging.getLogger("http")
    http_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "http.log"),
        maxBytes=100*1024*1024,  # 100MB
        encoding="utf-8",
        backupCount=2,
    )
    http_file_handler.setLevel(logging.DEBUG)
    http_file_handler.setFormatter(logging.Formatter(fmt))
    http_logger.addHandler(http_file_handler)

    # api
    api_logger = logging.getLogger("api")
    api_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "api.log"),
        maxBytes=10*1024*1024,  # 10MB
        encoding="utf-8",
        backupCount=2,
    )
    api_file_handler.setLevel(logging.DEBUG)
    api_file_handler.setFormatter(logging.Formatter(fmt))
    api_logger.addHandler(api_file_handler)

    # alarm
    alarm_logger = logging.getLogger("alarm")
    alarm_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "alarm.log"),
        maxBytes=1*1024*1024,  # 1MB
        encoding="utf-8",
        backupCount=2,
    )
    alarm_file_handler.setLevel(logging.DEBUG)
    alarm_file_handler.setFormatter(logging.Formatter(fmt))
    alarm_logger.addHandler(alarm_file_handler)

    debug_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "debug.log"),
        maxBytes=1*1024*1024,  # 1MB
        encoding="utf-8",
        backupCount=2,
    )
    debug_file_handler.setLevel(logging.DEBUG)
    debug_file_handler.setFormatter(logging.Formatter(fmt))

    info_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "info.log"),
        maxBytes=1*1024*1024,  # 1MB
        encoding="utf-8",
        backupCount=2,
    )
    info_file_handler.setLevel(logging.INFO)
    info_file_handler.setFormatter(logging.Formatter(fmt))

    error_file_handler = RotatingFileHandler(
        filename=os.path.join(log_directory, "error.log"),
        maxBytes=1*1024*1024,  # 1MB
        encoding="utf-8",
        backupCount=2,
    )
    error_file_handler.setLevel(logging.ERROR)
    error_file_handler.setFormatter(logging.Formatter(fmt))

    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[
            debug_file_handler,
            info_file_handler,
            error_file_handler,
        ]
    )


def setup_time_rotating(log_directory: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt))

    # main
    main_logger = logging.getLogger("main")
    main_logger.addHandler(stream_handler)
    main_info_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "main.log"),
        encoding="utf-8",
        when="midnight",
        interval=1,
        backupCount=7
    )
    main_info_file_handler.setLevel(logging.INFO)
    main_info_file_handler.setFormatter(logging.Formatter(fmt))
    main_logger.addHandler(main_info_file_handler)

    # http
    http_logger = logging.getLogger("http")
    http_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "http.log"),
        encoding='utf-8',
        when="midnight",
        interval=1,
        backupCount=7
    )
    http_file_handler.setLevel(logging.DEBUG)
    http_file_handler.setFormatter(logging.Formatter(fmt))
    http_logger.addHandler(http_file_handler)

    # api
    api_logger = logging.getLogger("api")
    api_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "api.log"),
        encoding='utf-8',
        when="midnight",
        interval=1,
        backupCount=7
    )
    api_file_handler.setLevel(logging.DEBUG)
    api_file_handler.setFormatter(logging.Formatter(fmt))
    api_logger.addHandler(api_file_handler)

    # alarm
    alarm_logger = logging.getLogger("alarm")
    alarm_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "alarm.log"),
        encoding="utf-8",
        when="midnight",
        interval=1,
        backupCount=7
    )
    alarm_file_handler.setLevel(logging.DEBUG)
    alarm_file_handler.setFormatter(logging.Formatter(fmt))
    alarm_logger.addHandler(alarm_file_handler)

    debug_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "debug.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    debug_file_handler.setLevel(logging.DEBUG)
    debug_file_handler.setFormatter(logging.Formatter(fmt))

    info_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "info.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    info_file_handler.setLevel(logging.INFO)
    info_file_handler.setFormatter(logging.Formatter(fmt))

    error_file_handler = TimedRotatingFileHandler(
        filename=os.path.join(log_directory, "error.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    error_file_handler.setLevel(logging.ERROR)
    error_file_handler.setFormatter(logging.Formatter(fmt))

    logging.basicConfig(
        level=logging.DEBUG,
        datefmt="%a, %d %b %Y %H:%M:%S",
        handlers=[
            debug_file_handler,
            info_file_handler,
            error_file_handler,
        ]
    )


if __name__ == "__main__":
    pass
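A minimal usage sketch for the helpers above, mirroring what main.py below does with the project paths; the "logs" directory here is an example and must exist before the call:

```python
# Sketch: set up size-based rotating logs and write through the "main" logger.
import logging
import log

log.setup_size_rotating(log_directory="logs", tz_info="Asia/Shanghai")
logger = logging.getLogger("main")
logger.info("hello from the main logger")  # goes to stdout and logs/main.log
```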
main.py
ADDED
@@ -0,0 +1,69 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import platform

import gradio as gr

import log
from project_settings import environment, log_directory, time_zone_info
from toolbox.os.command import Command

log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)

logger = logging.getLogger("main")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hf_token",
        default=environment.get("hf_token"),
        type=str,
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )

    args = parser.parse_args()
    return args


def shell(cmd: str):
    return Command.popen(cmd)


def main():
    args = get_args()

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            with gr.TabItem("shell"):
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run")
                shell_output = gr.Textbox(label="output")

                shell_button.click(
                    shell,
                    inputs=[shell_text,],
                    outputs=[shell_output],
                )

    # http://127.0.0.1:7866/
    # http://10.75.27.247:7866/
    blocks.queue().launch(
        # share=True,
        share=False if platform.system() == "Windows" else False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return


if __name__ == "__main__":
    main()
project_settings.py
ADDED
@@ -0,0 +1,27 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from toolbox.os.environment import EnvironmentManager
|
7 |
+
|
8 |
+
|
9 |
+
project_path = os.path.abspath(os.path.dirname(__file__))
|
10 |
+
project_path = Path(project_path)
|
11 |
+
|
12 |
+
time_zone_info = "Asia/Shanghai"
|
13 |
+
|
14 |
+
log_directory = project_path / "logs"
|
15 |
+
log_directory.mkdir(parents=True, exist_ok=True)
|
16 |
+
|
17 |
+
# temp_directory = project_path / "temp"
|
18 |
+
# temp_directory.mkdir(parents=True, exist_ok=True)
|
19 |
+
|
20 |
+
environment = EnvironmentManager(
|
21 |
+
path=os.path.join(project_path, "dotenv"),
|
22 |
+
env=os.environ.get("environment", "dev"),
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
pass
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
1 |
+
gradio==5.33.0
|
2 |
+
gradio_client==1.10.2
|
3 |
+
datasets==3.2.0
|
4 |
+
python-dotenv==1.0.1
|
5 |
+
scipy==1.15.1
|
6 |
+
librosa==0.10.2.post1
|
7 |
+
pandas==2.2.3
|
8 |
+
openpyxl==3.1.5
|
9 |
+
torch==2.5.1
|
10 |
+
torchaudio==2.5.1
|
11 |
+
overrides==7.7.0
|
12 |
+
webrtcvad==2.0.10
|
13 |
+
matplotlib==3.10.3
|
toolbox/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/json/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/json/misc.py
ADDED
@@ -0,0 +1,63 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import Callable
|
4 |
+
|
5 |
+
|
6 |
+
def traverse(js, callback: Callable, *args, **kwargs):
|
7 |
+
if isinstance(js, list):
|
8 |
+
result = list()
|
9 |
+
for l in js:
|
10 |
+
l = traverse(l, callback, *args, **kwargs)
|
11 |
+
result.append(l)
|
12 |
+
return result
|
13 |
+
elif isinstance(js, tuple):
|
14 |
+
result = list()
|
15 |
+
for l in js:
|
16 |
+
l = traverse(l, callback, *args, **kwargs)
|
17 |
+
result.append(l)
|
18 |
+
return tuple(result)
|
19 |
+
elif isinstance(js, dict):
|
20 |
+
result = dict()
|
21 |
+
for k, v in js.items():
|
22 |
+
k = traverse(k, callback, *args, **kwargs)
|
23 |
+
v = traverse(v, callback, *args, **kwargs)
|
24 |
+
result[k] = v
|
25 |
+
return result
|
26 |
+
elif isinstance(js, int):
|
27 |
+
return callback(js, *args, **kwargs)
|
28 |
+
elif isinstance(js, str):
|
29 |
+
return callback(js, *args, **kwargs)
|
30 |
+
else:
|
31 |
+
return js
|
32 |
+
|
33 |
+
|
34 |
+
def demo1():
|
35 |
+
d = {
|
36 |
+
"env": "ppe",
|
37 |
+
"mysql_connect": {
|
38 |
+
"host": "$mysql_connect_host",
|
39 |
+
"port": 3306,
|
40 |
+
"user": "callbot",
|
41 |
+
"password": "NxcloudAI2021!",
|
42 |
+
"database": "callbot_ppe",
|
43 |
+
"charset": "utf8"
|
44 |
+
},
|
45 |
+
"es_connect": {
|
46 |
+
"hosts": ["10.20.251.8"],
|
47 |
+
"http_auth": ["elastic", "ElasticAI2021!"],
|
48 |
+
"port": 9200
|
49 |
+
}
|
50 |
+
}
|
51 |
+
|
52 |
+
def callback(s):
|
53 |
+
if isinstance(s, str) and s.startswith('$'):
|
54 |
+
return s[1:]
|
55 |
+
return s
|
56 |
+
|
57 |
+
result = traverse(d, callback=callback)
|
58 |
+
print(result)
|
59 |
+
return
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == '__main__':
|
63 |
+
demo1()
|
toolbox/os/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
pass
|
toolbox/os/command.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class Command(object):
|
7 |
+
custom_command = [
|
8 |
+
"cd"
|
9 |
+
]
|
10 |
+
|
11 |
+
@staticmethod
|
12 |
+
def _get_cmd(command):
|
13 |
+
command = str(command).strip()
|
14 |
+
if command == "":
|
15 |
+
return None
|
16 |
+
cmd_and_args = command.split(sep=" ")
|
17 |
+
cmd = cmd_and_args[0]
|
18 |
+
args = " ".join(cmd_and_args[1:])
|
19 |
+
return cmd, args
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def popen(cls, command):
|
23 |
+
cmd, args = cls._get_cmd(command)
|
24 |
+
if cmd in cls.custom_command:
|
25 |
+
method = getattr(cls, cmd)
|
26 |
+
return method(args)
|
27 |
+
else:
|
28 |
+
resp = os.popen(command)
|
29 |
+
result = resp.read()
|
30 |
+
resp.close()
|
31 |
+
return result
|
32 |
+
|
33 |
+
@classmethod
|
34 |
+
def cd(cls, args):
|
35 |
+
if args.startswith("/"):
|
36 |
+
os.chdir(args)
|
37 |
+
else:
|
38 |
+
pwd = os.getcwd()
|
39 |
+
path = os.path.join(pwd, args)
|
40 |
+
os.chdir(path)
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def system(cls, command):
|
44 |
+
return os.system(command)
|
45 |
+
|
46 |
+
def __init__(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
def ps_ef_grep(keyword: str):
|
51 |
+
cmd = "ps -ef | grep {}".format(keyword)
|
52 |
+
rows = Command.popen(cmd)
|
53 |
+
rows = str(rows).split("\n")
|
54 |
+
rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
|
55 |
+
return rows
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
pass
|
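Note: a small usage sketch of the Command helper above, assuming a unix-like host; the commands shown are examples only.

from toolbox.os.command import Command, ps_ef_grep

# A normal command is run with os.popen and its stdout is returned as a string.
print(Command.popen("ls -l"))

# "cd" is in custom_command, so it changes the working directory of this process.
Command.popen("cd /tmp")

# List processes whose command line contains "python"; the grep line itself is filtered out.
for row in ps_ef_grep("python"):
    print(row)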
toolbox/os/environment.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
from dotenv.main import DotEnv
|
8 |
+
|
9 |
+
from toolbox.json.misc import traverse
|
10 |
+
|
11 |
+
|
12 |
+
class EnvironmentManager(object):
|
13 |
+
def __init__(self, path, env, override=False):
|
14 |
+
filename = os.path.join(path, '{}.env'.format(env))
|
15 |
+
self.filename = filename
|
16 |
+
|
17 |
+
load_dotenv(
|
18 |
+
dotenv_path=filename,
|
19 |
+
override=override
|
20 |
+
)
|
21 |
+
|
22 |
+
self._environ = dict()
|
23 |
+
|
24 |
+
def open_dotenv(self, filename: str = None):
|
25 |
+
filename = filename or self.filename
|
26 |
+
dotenv = DotEnv(
|
27 |
+
dotenv_path=filename,
|
28 |
+
stream=None,
|
29 |
+
verbose=False,
|
30 |
+
interpolate=False,
|
31 |
+
override=False,
|
32 |
+
encoding="utf-8",
|
33 |
+
)
|
34 |
+
result = dotenv.dict()
|
35 |
+
return result
|
36 |
+
|
37 |
+
def get(self, key, default=None, dtype=str):
|
38 |
+
result = os.environ.get(key)
|
39 |
+
if result is None:
|
40 |
+
if default is None:
|
41 |
+
result = None
|
42 |
+
else:
|
43 |
+
result = default
|
44 |
+
else:
|
45 |
+
result = dtype(result)
|
46 |
+
self._environ[key] = result
|
47 |
+
return result
|
48 |
+
|
49 |
+
|
50 |
+
_DEFAULT_DTYPE_MAP = {
|
51 |
+
'int': int,
|
52 |
+
'float': float,
|
53 |
+
'str': str,
|
54 |
+
'json.loads': json.loads
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
class JsonConfig(object):
|
59 |
+
"""
|
60 |
+
Treat values of the form `$float:threshold` in the json as follows:
|
61 |
+
look up `threshold` in the environment variables, then convert the value to float.
|
62 |
+
"""
|
63 |
+
def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
|
64 |
+
self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
|
65 |
+
self.environment = environment or os.environ
|
66 |
+
|
67 |
+
def sanitize_by_filename(self, filename: str):
|
68 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
69 |
+
js = json.load(f)
|
70 |
+
|
71 |
+
return self.sanitize_by_json(js)
|
72 |
+
|
73 |
+
def sanitize_by_json(self, js):
|
74 |
+
js = traverse(
|
75 |
+
js,
|
76 |
+
callback=self.sanitize,
|
77 |
+
environment=self.environment
|
78 |
+
)
|
79 |
+
return js
|
80 |
+
|
81 |
+
def sanitize(self, string, environment):
|
82 |
+
"""支持 $ 符开始的, 环境变量配置"""
|
83 |
+
if isinstance(string, str) and string.startswith('$'):
|
84 |
+
dtype, key = string[1:].split(':')
|
85 |
+
dtype = self.dtype_map[dtype]
|
86 |
+
|
87 |
+
value = environment.get(key)
|
88 |
+
if value is None:
|
89 |
+
raise AssertionError('environment not exist. key: {}'.format(key))
|
90 |
+
|
91 |
+
value = dtype(value)
|
92 |
+
result = value
|
93 |
+
else:
|
94 |
+
result = string
|
95 |
+
return result
|
96 |
+
|
97 |
+
|
98 |
+
def demo1():
|
99 |
+
import json
|
100 |
+
|
101 |
+
from project_settings import project_path
|
102 |
+
|
103 |
+
environment = EnvironmentManager(
|
104 |
+
path=os.path.join(project_path, 'server/callbot_server/dotenv'),
|
105 |
+
env='dev',
|
106 |
+
)
|
107 |
+
init_scenes = environment.get(key='init_scenes', dtype=json.loads)
|
108 |
+
print(init_scenes)
|
109 |
+
print(environment._environ)
|
110 |
+
return
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
demo1()
|
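Note: a minimal sketch of the JsonConfig convention above. The dotenv directory, env name, and the server_port / threshold keys are hypothetical and must exist in the loaded environment, otherwise sanitize raises AssertionError.

from toolbox.os.environment import EnvironmentManager, JsonConfig

environment = EnvironmentManager(path="dotenv", env="dev")  # loads dotenv/dev.env if present

config = JsonConfig(environment=environment)
js = {
    "name": "vad",                       # plain values pass through unchanged
    "server_port": "$int:server_port",   # looked up in the environment, cast to int
    "threshold": "$float:threshold",     # looked up in the environment, cast to float
}
print(config.sanitize_by_json(js))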
toolbox/os/other.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
import os
|
2 |
+
import inspect
|
3 |
+
|
4 |
+
|
5 |
+
def pwd():
|
6 |
+
"""你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
|
7 |
+
frame = inspect.stack()[1]
|
8 |
+
module = inspect.getmodule(frame[0])
|
9 |
+
return os.path.dirname(os.path.abspath(module.__file__))
|
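Note: a one-line usage sketch; pwd() reports the directory of the calling file, not of toolbox/os/other.py.

from toolbox.os.other import pwd

print(pwd())  # directory of the file that contains this call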
toolbox/torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torch/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torch/utils/data/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torch/utils/data/dataset/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torch/utils/data/dataset/vad_jsonl_dataset.py
ADDED
@@ -0,0 +1,179 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import json
|
4 |
+
import random
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch.utils.data import Dataset, IterableDataset
|
11 |
+
|
12 |
+
|
13 |
+
class VadJsonlDataset(IterableDataset):
|
14 |
+
def __init__(self,
|
15 |
+
jsonl_file: str,
|
16 |
+
expected_sample_rate: int,
|
17 |
+
resample: bool = False,
|
18 |
+
max_wave_value: float = 1.0,
|
19 |
+
buffer_size: int = 1000,
|
20 |
+
min_snr_db: float = None,
|
21 |
+
max_snr_db: float = None,
|
22 |
+
eps: float = 1e-8,
|
23 |
+
skip: int = 0,
|
24 |
+
):
|
25 |
+
self.jsonl_file = jsonl_file
|
26 |
+
self.expected_sample_rate = expected_sample_rate
|
27 |
+
self.resample = resample
|
28 |
+
self.max_wave_value = max_wave_value
|
29 |
+
self.min_snr_db = min_snr_db
|
30 |
+
self.max_snr_db = max_snr_db
|
31 |
+
self.eps = eps
|
32 |
+
self.skip = skip
|
33 |
+
|
34 |
+
self.buffer_size = buffer_size
|
35 |
+
self.buffer_samples: List[dict] = list()
|
36 |
+
|
37 |
+
def __iter__(self):
|
38 |
+
self.buffer_samples = list()
|
39 |
+
|
40 |
+
iterable_source = self.iterable_source()
|
41 |
+
|
42 |
+
try:
|
43 |
+
for _ in range(self.skip):
|
44 |
+
next(iterable_source)
|
45 |
+
except StopIteration:
|
46 |
+
pass
|
47 |
+
|
48 |
+
# initially fill the shuffle buffer
|
49 |
+
try:
|
50 |
+
for _ in range(self.buffer_size):
|
51 |
+
self.buffer_samples.append(next(iterable_source))
|
52 |
+
except StopIteration:
|
53 |
+
pass
|
54 |
+
|
55 |
+
# dynamic replacement: for each new sample, emit one randomly chosen buffered sample
|
56 |
+
while True:
|
57 |
+
try:
|
58 |
+
item = next(iterable_source)
|
59 |
+
# replace a random element of the buffer
|
60 |
+
replace_idx = random.randint(0, len(self.buffer_samples) - 1)
|
61 |
+
sample = self.buffer_samples[replace_idx]
|
62 |
+
self.buffer_samples[replace_idx] = item
|
63 |
+
yield self.convert_sample(sample)
|
64 |
+
except StopIteration:
|
65 |
+
break
|
66 |
+
|
67 |
+
# drain the remaining buffered samples
|
68 |
+
random.shuffle(self.buffer_samples)
|
69 |
+
for sample in self.buffer_samples:
|
70 |
+
yield self.convert_sample(sample)
|
71 |
+
|
72 |
+
def iterable_source(self):
|
73 |
+
last_sample = None
|
74 |
+
with open(self.jsonl_file, "r", encoding="utf-8") as f:
|
75 |
+
for row in f:
|
76 |
+
row = json.loads(row)
|
77 |
+
noise_filename = row["noise_filename"]
|
78 |
+
noise_raw_duration = row["noise_raw_duration"]
|
79 |
+
noise_offset = row["noise_offset"]
|
80 |
+
noise_duration = row["noise_duration"]
|
81 |
+
|
82 |
+
speech_filename = row["speech_filename"]
|
83 |
+
speech_raw_duration = row["speech_raw_duration"]
|
84 |
+
speech_offset = row["speech_offset"]
|
85 |
+
speech_duration = row["speech_duration"]
|
86 |
+
|
87 |
+
if self.min_snr_db is None or self.max_snr_db is None:
|
88 |
+
snr_db = row["snr_db"]
|
89 |
+
else:
|
90 |
+
snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
|
91 |
+
|
92 |
+
vad_segments = row["vad_segments"]
|
93 |
+
|
94 |
+
sample = {
|
95 |
+
"noise_filename": noise_filename,
|
96 |
+
"noise_raw_duration": noise_raw_duration,
|
97 |
+
"noise_offset": noise_offset,
|
98 |
+
"noise_duration": noise_duration,
|
99 |
+
|
100 |
+
"speech_filename": speech_filename,
|
101 |
+
"speech_raw_duration": speech_raw_duration,
|
102 |
+
"speech_offset": speech_offset,
|
103 |
+
"speech_duration": speech_duration,
|
104 |
+
|
105 |
+
"snr_db": snr_db,
|
106 |
+
|
107 |
+
"vad_segments": vad_segments,
|
108 |
+
}
|
109 |
+
if last_sample is None:
|
110 |
+
last_sample = sample
|
111 |
+
continue
|
112 |
+
yield sample
|
113 |
+
yield last_sample
|
114 |
+
|
115 |
+
def convert_sample(self, sample: dict):
|
116 |
+
noise_filename = sample["noise_filename"]
|
117 |
+
noise_offset = sample["noise_offset"]
|
118 |
+
noise_duration = sample["noise_duration"]
|
119 |
+
|
120 |
+
speech_filename = sample["speech_filename"]
|
121 |
+
speech_offset = sample["speech_offset"]
|
122 |
+
speech_duration = sample["speech_duration"]
|
123 |
+
|
124 |
+
snr_db = sample["snr_db"]
|
125 |
+
|
126 |
+
vad_segments = sample["vad_segments"]
|
127 |
+
|
128 |
+
noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration)
|
129 |
+
speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration)
|
130 |
+
|
131 |
+
noisy_wave, _ = self.mix_speech_and_noise(
|
132 |
+
speech=speech_wave.numpy(),
|
133 |
+
noise=noise_wave.numpy(),
|
134 |
+
snr_db=snr_db, eps=self.eps,
|
135 |
+
)
|
136 |
+
noisy_wave = torch.tensor(noisy_wave, dtype=torch.float32)
|
137 |
+
|
138 |
+
result = {
|
139 |
+
"noisy_wave": noisy_wave,
|
140 |
+
"vad_segments": vad_segments,
|
141 |
+
}
|
142 |
+
return result
|
143 |
+
|
144 |
+
def filename_to_waveform(self, filename: str, offset: float, duration: float):
|
145 |
+
try:
|
146 |
+
waveform, sample_rate = librosa.load(
|
147 |
+
filename,
|
148 |
+
sr=self.expected_sample_rate,
|
149 |
+
offset=offset,
|
150 |
+
duration=duration,
|
151 |
+
)
|
152 |
+
except ValueError as e:
|
153 |
+
print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
|
154 |
+
raise e
|
155 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
156 |
+
return waveform
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8):
|
160 |
+
l1 = len(speech)
|
161 |
+
l2 = len(noise)
|
162 |
+
l = min(l1, l2)
|
163 |
+
speech = speech[:l]
|
164 |
+
noise = noise[:l]
|
165 |
+
|
166 |
+
# np.float32, value between (-1, 1).
|
167 |
+
|
168 |
+
speech_power = np.mean(np.square(speech))
|
169 |
+
noise_power = speech_power / (10 ** (snr_db / 10))
|
170 |
+
|
171 |
+
noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps)
|
172 |
+
|
173 |
+
noisy_signal = speech + noise_adjusted
|
174 |
+
|
175 |
+
return noisy_signal, noise_adjusted
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
pass
|
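Note: a minimal iteration sketch for the dataset above. The jsonl path is a placeholder; each line must provide the keys read in iterable_source(), and the referenced audio files must exist. The training script presumably wraps this in a DataLoader with a custom collate_fn, since waveform lengths vary; here the dataset is iterated directly.

from toolbox.torch.utils.data.dataset.vad_jsonl_dataset import VadJsonlDataset

dataset = VadJsonlDataset(
    jsonl_file="train.jsonl",      # placeholder path
    expected_sample_rate=8000,
    buffer_size=1000,
    min_snr_db=-10,
    max_snr_db=20,
)

for sample in dataset:
    noisy_wave = sample["noisy_wave"]      # torch.float32, shape: [num_samples]
    vad_segments = sample["vad_segments"]  # list of [start_seconds, end_seconds]
    print(noisy_wave.shape, vad_segments)
    break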
toolbox/torchaudio/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/configuration_utils.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import copy
|
4 |
+
import os
|
5 |
+
from typing import Any, Dict, Union
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
|
9 |
+
|
10 |
+
CONFIG_FILE = "config.yaml"
|
11 |
+
DISCRIMINATOR_CONFIG_FILE = "discriminator_config.yaml"
|
12 |
+
|
13 |
+
|
14 |
+
class PretrainedConfig(object):
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
pass
|
17 |
+
|
18 |
+
@classmethod
|
19 |
+
def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
|
20 |
+
with open(yaml_file, encoding="utf-8") as f:
|
21 |
+
config_dict = yaml.safe_load(f)
|
22 |
+
return config_dict
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def get_config_dict(
|
26 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike]
|
27 |
+
) -> Dict[str, Any]:
|
28 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
29 |
+
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
|
30 |
+
else:
|
31 |
+
config_file = pretrained_model_name_or_path
|
32 |
+
config_dict = cls._dict_from_yaml_file(config_file)
|
33 |
+
return config_dict
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
|
37 |
+
for k, v in kwargs.items():
|
38 |
+
if k in config_dict.keys():
|
39 |
+
config_dict[k] = v
|
40 |
+
config = cls(**config_dict)
|
41 |
+
return config
|
42 |
+
|
43 |
+
@classmethod
|
44 |
+
def from_pretrained(
|
45 |
+
cls,
|
46 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
47 |
+
**kwargs,
|
48 |
+
):
|
49 |
+
config_dict = cls.get_config_dict(pretrained_model_name_or_path)
|
50 |
+
return cls.from_dict(config_dict, **kwargs)
|
51 |
+
|
52 |
+
def to_dict(self):
|
53 |
+
output = copy.deepcopy(self.__dict__)
|
54 |
+
return output
|
55 |
+
|
56 |
+
def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
|
57 |
+
config_dict = self.to_dict()
|
58 |
+
|
59 |
+
with open(yaml_file_path, "w", encoding="utf-8") as writer:
|
60 |
+
yaml.safe_dump(config_dict, writer)
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == '__main__':
|
64 |
+
pass
|
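Note: a short sketch of the loading path above, using the config.yaml shipped in this commit; it assumes the script runs from the repository root. from_dict only overrides keys that already exist in the yaml.

from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig

config = SileroVadConfig.from_pretrained(
    "toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml",
    batch_size=8,  # overrides the value from the yaml
)
print(config.sample_rate, config.batch_size)  # 8000 8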
toolbox/torchaudio/losses/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/losses/vad_loss/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/losses/vad_loss/base_vad_loss.py
ADDED
@@ -0,0 +1,43 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class BaseVadLoss(nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super(BaseVadLoss, self).__init__()
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def get_targets(inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
|
15 |
+
"""
|
16 |
+
:param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
|
17 |
+
:param batch_vad_segments: VAD segment for each audio
|
18 |
+
:param duration: float. The total duration of each audio in the batch.
|
19 |
+
:return: targets, shape as `inputs`.
|
20 |
+
"""
|
21 |
+
b, t, _ = inputs.shape
|
22 |
+
|
23 |
+
batch_vad_segments_ = list()
|
24 |
+
for vad_segments in batch_vad_segments:
|
25 |
+
vad_segments_ = list()
|
26 |
+
for start, end in vad_segments:
|
27 |
+
start_ = start / duration * t
|
28 |
+
end_ = end / duration * t
|
29 |
+
start_ = round(start_)
|
30 |
+
end_ = round(end_)
|
31 |
+
vad_segments_.append([start_, end_])
|
32 |
+
batch_vad_segments_.append(vad_segments_)
|
33 |
+
|
34 |
+
targets = torch.zeros_like(inputs)
|
35 |
+
for idx, vad_segments_ in enumerate(batch_vad_segments_):
|
36 |
+
for start_, end_ in vad_segments_:
|
37 |
+
targets[idx, start_:end_, :] = 1
|
38 |
+
|
39 |
+
return targets
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
pass
|
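Note: a worked example of the frame mapping in get_targets above. A segment [start, end] in seconds maps to frame indices round(start / duration * t) .. round(end / duration * t); with t = 198 frames over a 4 second clip, [0.24, 1.15] becomes 12..57 and [2.21, 3.2] becomes 109..158.

import torch

from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss

inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
targets = BaseVadLoss.get_targets(inputs, [[(0.24, 1.15), (2.21, 3.2)]], duration=4)

print(targets[0, 10:14, 0])  # tensor([0., 0., 1., 1.]) -> speech starts at frame 12
print(int(targets.sum()))    # (57 - 12) + (158 - 109) = 94 speech frames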
toolbox/torchaudio/losses/vad_loss/bce_vad_loss.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
|
9 |
+
|
10 |
+
|
11 |
+
class BCEVadLoss(BaseVadLoss):
|
12 |
+
"""
|
13 |
+
Binary Cross-Entropy Loss, BCE Loss
|
14 |
+
"""
|
15 |
+
def __init__(self,
|
16 |
+
reduction: str = "mean",
|
17 |
+
):
|
18 |
+
super(BCEVadLoss, self).__init__()
|
19 |
+
self.reduction = reduction
|
20 |
+
|
21 |
+
self.bce_loss_fn = nn.BCELoss(reduction=reduction)
|
22 |
+
|
23 |
+
def forward(self, inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
|
24 |
+
"""
|
25 |
+
:param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
|
26 |
+
:param batch_vad_segments: VAD segment for each audio
|
27 |
+
:param duration: float. The total duration of each audio in the batch.
|
28 |
+
:return:
|
29 |
+
"""
|
30 |
+
|
31 |
+
targets = self.get_targets(inputs, batch_vad_segments, duration)
|
32 |
+
|
33 |
+
loss = self.bce_loss_fn.forward(inputs, targets)
|
34 |
+
return loss
|
35 |
+
|
36 |
+
|
37 |
+
def main():
|
38 |
+
inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
|
39 |
+
|
40 |
+
vad_segments = [
|
41 |
+
[[0.24, 1.15], [2.21, 3.2]],
|
42 |
+
]
|
43 |
+
|
44 |
+
loss_fn = BCEVadLoss()
|
45 |
+
|
46 |
+
loss = loss_fn.forward(inputs, vad_segments, duration=4)
|
47 |
+
print(loss)
|
48 |
+
return
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
main()
|
toolbox/torchaudio/losses/vad_loss/dice_vad_loss.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
|
9 |
+
|
10 |
+
|
11 |
+
class DiceVadLoss(BaseVadLoss):
|
12 |
+
def __init__(self,
|
13 |
+
reduction: str = "mean",
|
14 |
+
eps: float = 1e-6,
|
15 |
+
):
|
16 |
+
super(DiceVadLoss, self).__init__()
|
17 |
+
self.reduction = reduction
|
18 |
+
self.eps = eps
|
19 |
+
|
20 |
+
if reduction not in ("sum", "mean"):
|
21 |
+
raise AssertionError(f"param reduction must be sum or mean.")
|
22 |
+
|
23 |
+
def forward(self, inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
|
24 |
+
"""
|
25 |
+
:param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
|
26 |
+
:param batch_vad_segments: VAD segment for each audio
|
27 |
+
:param duration: float. The total duration of each audio in the batch.
|
28 |
+
:return:
|
29 |
+
"""
|
30 |
+
targets = self.get_targets(inputs, batch_vad_segments, duration)
|
31 |
+
|
32 |
+
inputs_ = torch.squeeze(inputs, dim=-1)
|
33 |
+
targets_ = torch.squeeze(targets, dim=-1)
|
34 |
+
# shape: [b, t]
|
35 |
+
|
36 |
+
intersection = (inputs_ * targets_).sum(dim=-1)
|
37 |
+
union = (inputs_ + targets_).sum(dim=-1)
|
38 |
+
# shape: [b,]
|
39 |
+
|
40 |
+
dice = (2. * intersection + self.eps) / (union + self.eps)
|
41 |
+
# shape: [b,]
|
42 |
+
|
43 |
+
loss = 1. - dice
|
44 |
+
# shape: [b,]
|
45 |
+
|
46 |
+
if self.reduction == "mean":
|
47 |
+
loss = torch.mean(loss)
|
48 |
+
elif self.reduction == "sum":
|
49 |
+
loss = torch.sum(loss)
|
50 |
+
else:
|
51 |
+
raise AssertionError
|
52 |
+
return loss
|
53 |
+
|
54 |
+
|
55 |
+
def main():
|
56 |
+
inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
|
57 |
+
|
58 |
+
vad_segments = [
|
59 |
+
[[0.24, 1.15], [2.21, 3.2]],
|
60 |
+
]
|
61 |
+
|
62 |
+
loss_fn = DiceVadLoss()
|
63 |
+
|
64 |
+
loss = loss_fn.forward(inputs, vad_segments, duration=4)
|
65 |
+
print(loss)
|
66 |
+
return
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
main()
|
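Note: for reference, the per-sample quantity computed in forward() above is the standard soft Dice loss over the t output frames, with $p_t$ the predicted probability and $y_t$ the frame target built by get_targets:

$$\mathcal{L}_{\text{dice}} = 1 - \frac{2\sum_{t} p_t\, y_t + \varepsilon}{\sum_{t} p_t + \sum_{t} y_t + \varepsilon}$$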
toolbox/torchaudio/metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/metrics/vad_metrics/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/metrics/vad_metrics/vad_accuracy.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class VadAccuracy(object):
|
7 |
+
def __init__(self, threshold: float = 0.5) -> None:
|
8 |
+
self.threshold = threshold
|
9 |
+
|
10 |
+
self.correct_count = 0.
|
11 |
+
self.total_count = 0.
|
12 |
+
|
13 |
+
def __call__(self,
|
14 |
+
predictions: torch.Tensor,
|
15 |
+
gold_labels: torch.Tensor,
|
16 |
+
):
|
17 |
+
"""
|
18 |
+
:param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
|
19 |
+
:param gold_labels: torch.Tensor, shape: [b, t, 1].
|
20 |
+
:return:
|
21 |
+
"""
|
22 |
+
predictions = (predictions > self.threshold).float()
|
23 |
+
correct = predictions.eq(gold_labels).float()
|
24 |
+
self.correct_count += correct.sum()
|
25 |
+
self.total_count += gold_labels.numel()
|
26 |
+
|
27 |
+
def get_metric(self, reset: bool = False):
|
28 |
+
"""
|
29 |
+
Returns
|
30 |
+
-------
|
31 |
+
The accumulated accuracy.
|
32 |
+
"""
|
33 |
+
if self.total_count > 1e-12:
|
34 |
+
accuracy = float(self.correct_count) / float(self.total_count)
|
35 |
+
else:
|
36 |
+
accuracy = 0.0
|
37 |
+
if reset:
|
38 |
+
self.reset()
|
39 |
+
return {'accuracy': accuracy}
|
40 |
+
|
41 |
+
def reset(self):
|
42 |
+
self.correct_count = 0.0
|
43 |
+
self.total_count = 0.0
|
44 |
+
|
45 |
+
|
46 |
+
def main():
|
47 |
+
inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
|
48 |
+
targets = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
|
49 |
+
|
50 |
+
metric_fn = VadAccuracy()
|
51 |
+
|
52 |
+
metric_fn.__call__(inputs, targets)
|
53 |
+
|
54 |
+
metrics = metric_fn.get_metric()
|
55 |
+
print(metrics)
|
56 |
+
return
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
main()
|
toolbox/torchaudio/models/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/models/vad/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/models/vad/silero_vad/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
from toolbox.torchaudio.configuration_utils import PretrainedConfig
|
6 |
+
|
7 |
+
|
8 |
+
class SileroVadConfig(PretrainedConfig):
|
9 |
+
def __init__(self,
|
10 |
+
sample_rate: int = 8000,
|
11 |
+
nfft: int = 512,
|
12 |
+
win_size: int = 240,
|
13 |
+
hop_size: int = 80,
|
14 |
+
win_type: str = "hann",
|
15 |
+
|
16 |
+
in_channels: int = 64,
|
17 |
+
hidden_size: int = 128,
|
18 |
+
|
19 |
+
lr: float = 0.001,
|
20 |
+
lr_scheduler: str = "CosineAnnealingLR",
|
21 |
+
lr_scheduler_kwargs: dict = None,
|
22 |
+
|
23 |
+
max_epochs: int = 100,
|
24 |
+
clip_grad_norm: float = 10.,
|
25 |
+
seed: int = 1234,
|
26 |
+
|
27 |
+
num_workers: int = 4,
|
28 |
+
batch_size: int = 4,
|
29 |
+
eval_steps: int = 25000,
|
30 |
+
|
31 |
+
**kwargs
|
32 |
+
):
|
33 |
+
super(SileroVadConfig, self).__init__(**kwargs)
|
34 |
+
# transform
|
35 |
+
self.sample_rate = sample_rate
|
36 |
+
self.nfft = nfft
|
37 |
+
self.win_size = win_size
|
38 |
+
self.hop_size = hop_size
|
39 |
+
self.win_type = win_type
|
40 |
+
|
41 |
+
# encoder
|
42 |
+
self.in_channels = in_channels
|
43 |
+
self.hidden_size = hidden_size
|
44 |
+
|
45 |
+
# train
|
46 |
+
self.lr = lr
|
47 |
+
self.lr_scheduler = lr_scheduler
|
48 |
+
self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
|
49 |
+
|
50 |
+
self.max_epochs = max_epochs
|
51 |
+
self.clip_grad_norm = clip_grad_norm
|
52 |
+
self.seed = seed
|
53 |
+
|
54 |
+
self.num_workers = num_workers
|
55 |
+
self.batch_size = batch_size
|
56 |
+
self.eval_steps = eval_steps
|
57 |
+
|
58 |
+
|
59 |
+
def main():
|
60 |
+
config = SileroVadConfig()
|
61 |
+
config.to_yaml_file("config.yaml")
|
62 |
+
return
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
main()
|
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://github.com/snakers4/silero-vad/wiki/Quality-Metrics
|
5 |
+
|
6 |
+
https://pytorch.org/hub/snakers4_silero-vad_vad/
|
7 |
+
https://github.com/snakers4/silero-vad
|
8 |
+
|
9 |
+
https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
|
10 |
+
"""
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
|
15 |
+
from toolbox.torchaudio.modules.conv_stft import ConvSTFT
|
16 |
+
|
17 |
+
|
18 |
+
MODEL_FILE = "model.pt"
|
19 |
+
|
20 |
+
|
21 |
+
class EncoderBlock(nn.Module):
|
22 |
+
def __init__(self,
|
23 |
+
in_channels: int = 64,
|
24 |
+
out_channels: int = 128,
|
25 |
+
):
|
26 |
+
super(EncoderBlock, self).__init__()
|
27 |
+
self.conv1d = nn.Conv1d(
|
28 |
+
in_channels=in_channels,
|
29 |
+
out_channels=out_channels,
|
30 |
+
kernel_size=3,
|
31 |
+
padding="same",
|
32 |
+
)
|
33 |
+
self.activation = nn.ReLU()
|
34 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
35 |
+
|
36 |
+
def forward(self, x: torch.Tensor):
|
37 |
+
# x shape: [b, t, f]
|
38 |
+
x = torch.transpose(x, dim0=1, dim1=2)
|
39 |
+
# x shape: [b, f, t]
|
40 |
+
|
41 |
+
x = self.conv1d.forward(x)
|
42 |
+
x = self.activation(x)
|
43 |
+
x = self.norm(x)
|
44 |
+
|
45 |
+
x = torch.transpose(x, dim0=1, dim1=2)
|
46 |
+
# x shape: [b, t, f]
|
47 |
+
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class Encoder(nn.Module):
|
52 |
+
def __init__(self,
|
53 |
+
in_channels: int = 64,
|
54 |
+
out_channels: int = 128,
|
55 |
+
num_layers: int = 3,
|
56 |
+
):
|
57 |
+
super(Encoder, self).__init__()
|
58 |
+
|
59 |
+
self.layers = nn.ModuleList(modules=[
|
60 |
+
EncoderBlock(
|
61 |
+
in_channels=in_channels,
|
62 |
+
out_channels=out_channels,
|
63 |
+
)
|
64 |
+
if i == 0 else
|
65 |
+
EncoderBlock(
|
66 |
+
in_channels=out_channels,
|
67 |
+
out_channels=out_channels,
|
68 |
+
)
|
69 |
+
for i in range(num_layers)
|
70 |
+
])
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor):
|
73 |
+
for layer in self.layers:
|
74 |
+
x = layer.forward(x)
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
class SileroVadModel(nn.Module):
|
79 |
+
def __init__(self, config: SileroVadConfig):
|
80 |
+
super(SileroVadModel, self).__init__()
|
81 |
+
self.config = config
|
82 |
+
self.eps = 1e-12
|
83 |
+
|
84 |
+
self.stft = ConvSTFT(
|
85 |
+
nfft=config.nfft,
|
86 |
+
win_size=config.win_size,
|
87 |
+
hop_size=config.hop_size,
|
88 |
+
win_type=config.win_type,
|
89 |
+
power=1,
|
90 |
+
requires_grad=False
|
91 |
+
)
|
92 |
+
|
93 |
+
self.linear = nn.Linear(
|
94 |
+
in_features=(config.nfft // 2 + 1),
|
95 |
+
out_features=config.in_channels,
|
96 |
+
)
|
97 |
+
|
98 |
+
self.encoder = Encoder(
|
99 |
+
in_channels=config.in_channels,
|
100 |
+
out_channels=config.hidden_size,
|
101 |
+
)
|
102 |
+
|
103 |
+
self.lstm = nn.LSTM(
|
104 |
+
input_size=config.hidden_size,
|
105 |
+
hidden_size=config.hidden_size,
|
106 |
+
bidirectional=False,
|
107 |
+
batch_first=True
|
108 |
+
)
|
109 |
+
|
110 |
+
self.classifier = nn.Sequential(
|
111 |
+
nn.Linear(config.hidden_size, 32),
|
112 |
+
nn.ReLU(),
|
113 |
+
nn.Linear(32, 1),
|
114 |
+
nn.Sigmoid()
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, signal: torch.Tensor):
|
118 |
+
mags = self.stft.forward(signal)
|
119 |
+
# mags shape: [b, f, t]
|
120 |
+
|
121 |
+
x = torch.transpose(mags, dim0=1, dim1=2)
|
122 |
+
# x shape: [b, t, f]
|
123 |
+
|
124 |
+
x = self.linear.forward(x)
|
125 |
+
# x shape: [b, t, f']
|
126 |
+
|
127 |
+
x = self.encoder.forward(x)
|
128 |
+
# x shape: [b, t, f]
|
129 |
+
|
130 |
+
x, _ = self.lstm.forward(x)
|
131 |
+
x = self.classifier.forward(x)
|
132 |
+
|
133 |
+
# x shape: [b, t, 1]
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
def main():
|
138 |
+
config = SileroVadConfig()
|
139 |
+
model = SileroVadModel(config=config)
|
140 |
+
|
141 |
+
noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
|
142 |
+
|
143 |
+
probs = model.forward(noisy)
|
144 |
+
print(f"probs: {probs}")
|
145 |
+
print(f"probs.shape: {probs.shape}")
|
146 |
+
|
147 |
+
return
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
main()
|
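Note: a short check of where the 198-frame shape used in the loss and metric demos comes from. ConvSTFT slides a win_size kernel with stride hop_size, so the number of output frames is floor((num_samples - win_size) / hop_size) + 1; with the defaults (win_size=240, hop_size=80) a 2-second clip at 8 kHz gives 198 frames, i.e. probs.shape == [1, 198, 1] in main().

num_samples = 16000          # 2 seconds at 8000 Hz
win_size, hop_size = 240, 80
t = (num_samples - win_size) // hop_size + 1
print(t)                     # 198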
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml
ADDED
@@ -0,0 +1,22 @@
1 |
+
model_name: "silero_vad"
|
2 |
+
|
3 |
+
sample_rate: 8000
|
4 |
+
nfft: 512
|
5 |
+
win_size: 240
|
6 |
+
hop_size: 80
|
7 |
+
win_type: hann
|
8 |
+
|
9 |
+
in_channels: 64
|
10 |
+
hidden_size: 128
|
11 |
+
|
12 |
+
lr: 0.001
|
13 |
+
lr_scheduler: CosineAnnealingLR
|
14 |
+
lr_scheduler_kwargs: {}
|
15 |
+
|
16 |
+
max_epochs: 100
|
17 |
+
clip_grad_norm: 10.0
|
18 |
+
seed: 1234
|
19 |
+
|
20 |
+
num_workers: 4
|
21 |
+
batch_size: 4
|
22 |
+
eval_steps: 25000
|
toolbox/torchaudio/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/modules/conv_stft.py
ADDED
@@ -0,0 +1,271 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
|
5 |
+
"""
|
6 |
+
from collections import defaultdict
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from scipy.signal import get_window
|
12 |
+
|
13 |
+
|
14 |
+
def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
|
15 |
+
if win_type == "None" or win_type is None:
|
16 |
+
window = np.ones(win_size)
|
17 |
+
else:
|
18 |
+
window = get_window(win_type, win_size, fftbins=True)**0.5
|
19 |
+
|
20 |
+
fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
|
21 |
+
real_kernel = np.real(fourier_basis)
|
22 |
+
image_kernel = np.imag(fourier_basis)
|
23 |
+
kernel = np.concatenate([real_kernel, image_kernel], 1).T
|
24 |
+
|
25 |
+
if inverse:
|
26 |
+
kernel = np.linalg.pinv(kernel).T
|
27 |
+
|
28 |
+
kernel = kernel * window
|
29 |
+
kernel = kernel[:, None, :]
|
30 |
+
result = (
|
31 |
+
torch.from_numpy(kernel.astype(np.float32)),
|
32 |
+
torch.from_numpy(window[None, :, None].astype(np.float32))
|
33 |
+
)
|
34 |
+
return result
|
35 |
+
|
36 |
+
|
37 |
+
class ConvSTFT(nn.Module):
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
nfft: int,
|
41 |
+
win_size: int,
|
42 |
+
hop_size: int,
|
43 |
+
win_type: str = "hamming",
|
44 |
+
power: int = None,
|
45 |
+
requires_grad: bool = False):
|
46 |
+
super(ConvSTFT, self).__init__()
|
47 |
+
|
48 |
+
if nfft is None:
|
49 |
+
self.nfft = int(2**np.ceil(np.log2(win_size)))
|
50 |
+
else:
|
51 |
+
self.nfft = nfft
|
52 |
+
|
53 |
+
kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
|
54 |
+
self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
|
55 |
+
|
56 |
+
self.win_size = win_size
|
57 |
+
self.hop_size = hop_size
|
58 |
+
|
59 |
+
self.stride = hop_size
|
60 |
+
self.dim = self.nfft
|
61 |
+
self.power = power
|
62 |
+
|
63 |
+
def forward(self, waveform: torch.Tensor):
|
64 |
+
if waveform.dim() == 2:
|
65 |
+
waveform = torch.unsqueeze(waveform, 1)
|
66 |
+
|
67 |
+
matrix = F.conv1d(waveform, self.weight, stride=self.stride)
|
68 |
+
dim = self.dim // 2 + 1
|
69 |
+
real = matrix[:, :dim, :]
|
70 |
+
imag = matrix[:, dim:, :]
|
71 |
+
spec = torch.complex(real, imag)
|
72 |
+
# spec shape: [b, f, t], torch.complex64
|
73 |
+
|
74 |
+
if self.power is None:
|
75 |
+
return spec
|
76 |
+
elif self.power == 1:
|
77 |
+
mags = torch.sqrt(real**2 + imag**2)
|
78 |
+
# phase = torch.atan2(imag, real)
|
79 |
+
return mags
|
80 |
+
elif self.power == 2:
|
81 |
+
power = real**2 + imag**2
|
82 |
+
return power
|
83 |
+
else:
|
84 |
+
raise AssertionError
|
85 |
+
|
86 |
+
|
87 |
+
class ConviSTFT(nn.Module):
|
88 |
+
|
89 |
+
def __init__(self,
|
90 |
+
win_size: int,
|
91 |
+
hop_size: int,
|
92 |
+
nfft: int = None,
|
93 |
+
win_type: str = "hamming",
|
94 |
+
requires_grad: bool = False):
|
95 |
+
super(ConviSTFT, self).__init__()
|
96 |
+
if nfft is None:
|
97 |
+
self.nfft = int(2**np.ceil(np.log2(win_size)))
|
98 |
+
else:
|
99 |
+
self.nfft = nfft
|
100 |
+
|
101 |
+
kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
|
102 |
+
self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
|
103 |
+
# weight shape: [f*2, 1, nfft]
|
104 |
+
# f = nfft // 2 + 1
|
105 |
+
|
106 |
+
self.win_size = win_size
|
107 |
+
self.hop_size = hop_size
|
108 |
+
self.win_type = win_type
|
109 |
+
|
110 |
+
self.stride = hop_size
|
111 |
+
self.dim = self.nfft
|
112 |
+
|
113 |
+
self.register_buffer("window", window)
|
114 |
+
self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
|
115 |
+
# window shape: [1, nfft, 1]
|
116 |
+
# enframe shape: [nfft, 1, nfft]
|
117 |
+
|
118 |
+
def forward(self,
|
119 |
+
spec: torch.Tensor):
|
120 |
+
"""
|
121 |
+
self.weight shape: [f*2, 1, win_size]
|
122 |
+
self.window shape: [1, win_size, 1]
|
123 |
+
self.enframe shape: [win_size, 1, win_size]
|
124 |
+
|
125 |
+
:param spec: torch.Tensor, shape: [b, f, t, 2]
|
126 |
+
:return:
|
127 |
+
"""
|
128 |
+
spec = torch.view_as_real(spec)
|
129 |
+
# spec shape: [b, f, t, 2]
|
130 |
+
matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
|
131 |
+
# matrix shape: [b, f*2, t]
|
132 |
+
|
133 |
+
waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
|
134 |
+
# waveform shape: [b, 1, num_samples]
|
135 |
+
|
136 |
+
# this is from torch-stft: https://github.com/pseeth/torch-stft
|
137 |
+
t = self.window.repeat(1, 1, matrix.size(-1))**2
|
138 |
+
# t shape: [1, win_size, t]
|
139 |
+
coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
|
140 |
+
# coff shape: [1, 1, num_samples]
|
141 |
+
waveform = waveform / (coff + 1e-8)
|
142 |
+
# waveform = waveform / coff
|
143 |
+
return waveform
|
144 |
+
|
145 |
+
@torch.no_grad()
|
146 |
+
def forward_chunk(self,
|
147 |
+
spec: torch.Tensor,
|
148 |
+
cache_dict: dict = None
|
149 |
+
):
|
150 |
+
"""
|
151 |
+
:param spec: shape: [b, f, t]
|
152 |
+
:param cache_dict: dict,
|
153 |
+
waveform_cache shape: [b, 1, win_size - hop_size]
|
154 |
+
coff_cache shape: [b, 1, win_size - hop_size]
|
155 |
+
:return:
|
156 |
+
"""
|
157 |
+
if cache_dict is None:
|
158 |
+
cache_dict = defaultdict(lambda: None)
|
159 |
+
waveform_cache = cache_dict["waveform_cache"]
|
160 |
+
coff_cache = cache_dict["coff_cache"]
|
161 |
+
|
162 |
+
spec = torch.view_as_real(spec)
|
163 |
+
matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
|
164 |
+
|
165 |
+
waveform_current = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
|
166 |
+
|
167 |
+
t = self.window.repeat(1, 1, matrix.size(-1))**2
|
168 |
+
coff_current = F.conv_transpose1d(t, self.enframe, stride=self.stride)
|
169 |
+
|
170 |
+
overlap_size = self.win_size - self.hop_size
|
171 |
+
|
172 |
+
if waveform_cache is not None:
|
173 |
+
waveform_current[:, :, :overlap_size] += waveform_cache
|
174 |
+
waveform_output = waveform_current[:, :, :self.hop_size]
|
175 |
+
new_waveform_cache = waveform_current[:, :, self.hop_size:]
|
176 |
+
|
177 |
+
if coff_cache is not None:
|
178 |
+
coff_current[:, :, :overlap_size] += coff_cache
|
179 |
+
coff_output = coff_current[:, :, :self.hop_size]
|
180 |
+
new_coff_cache = coff_current[:, :, self.hop_size:]
|
181 |
+
|
182 |
+
waveform_output = waveform_output / (coff_output + 1e-8)
|
183 |
+
|
184 |
+
new_cache_dict = {
|
185 |
+
"waveform_cache": new_waveform_cache,
|
186 |
+
"coff_cache": new_coff_cache,
|
187 |
+
}
|
188 |
+
return waveform_output, new_cache_dict
|
189 |
+
|
190 |
+
|
191 |
+
def main():
|
192 |
+
nfft = 512
|
193 |
+
win_size = 512
|
194 |
+
hop_size = 256
|
195 |
+
|
196 |
+
stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
|
197 |
+
istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
|
198 |
+
|
199 |
+
mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
|
200 |
+
b, num_samples = mixture.shape
|
201 |
+
t = (num_samples - win_size) / hop_size + 1
|
202 |
+
|
203 |
+
spec = stft.forward(mixture)
|
204 |
+
b, f, t = spec.shape
|
205 |
+
|
206 |
+
# If spec was produced by the stft above, the two waveform reconstruction methods below are identical; otherwise the reconstructed waveforms will differ.
|
207 |
+
# spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
|
208 |
+
print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
|
209 |
+
|
210 |
+
waveform = istft.forward(spec)
|
211 |
+
# shape: [batch_size, channels, num_samples]
|
212 |
+
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
213 |
+
print(waveform[:, :, 300: 302])
|
214 |
+
|
215 |
+
waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
|
216 |
+
for i in range(int(t)):
|
217 |
+
begin = i * hop_size
|
218 |
+
end = begin + win_size
|
219 |
+
sub_spec = spec[:, :, i:i+1]
|
220 |
+
sub_waveform = istft.forward(sub_spec)
|
221 |
+
# (b, 1, win_size)
|
222 |
+
waveform[:, :, begin:end] = sub_waveform
|
223 |
+
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
224 |
+
print(waveform[:, :, 300: 302])
|
225 |
+
|
226 |
+
return
|
227 |
+
|
228 |
+
|
229 |
+
def main2():
|
230 |
+
nfft = 512
|
231 |
+
win_size = 512
|
232 |
+
hop_size = 256
|
233 |
+
|
234 |
+
stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
|
235 |
+
istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
|
236 |
+
|
237 |
+
mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
|
238 |
+
b, num_samples = mixture.shape
|
239 |
+
|
240 |
+
spec = stft.forward(mixture)
|
241 |
+
b, f, t = spec.shape
|
242 |
+
|
243 |
+
# If spec was produced by the stft above, the two waveform reconstruction methods below are identical; otherwise the reconstructed waveforms will differ.
|
244 |
+
spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
|
245 |
+
print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
|
246 |
+
|
247 |
+
waveform = istft.forward(spec)
|
248 |
+
# shape: [batch_size, channels, num_samples]
|
249 |
+
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
250 |
+
print(waveform[:, :, 300: 302])
|
251 |
+
|
252 |
+
cache_dict = None
|
253 |
+
waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
|
254 |
+
for i in range(int(t)):
|
255 |
+
sub_spec = spec[:, :, i:i+1]
|
256 |
+
begin = i * hop_size
|
257 |
+
|
258 |
+
end = begin + win_size - hop_size
|
259 |
+
sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict)
|
260 |
+
# end = begin + win_size
|
261 |
+
# sub_waveform = istft.forward(sub_spec)
|
262 |
+
|
263 |
+
waveform[:, :, begin:end] = sub_waveform
|
264 |
+
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
265 |
+
print(waveform[:, :, 300: 302])
|
266 |
+
|
267 |
+
return
|
268 |
+
|
269 |
+
|
270 |
+
if __name__ == "__main__":
|
271 |
+
main2()
|
toolbox/webrtcvad/__init__.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
if __name__ == '__main__':
|
5 |
+
pass
|
toolbox/webrtcvad/vad.py
ADDED
@@ -0,0 +1,249 @@
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import collections
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from scipy.io import wavfile
|
9 |
+
import webrtcvad
|
10 |
+
|
11 |
+
from project_settings import project_path
|
12 |
+
|
13 |
+
|
14 |
+
class Frame(object):
|
15 |
+
def __init__(self, signal: np.ndarray, timestamp, duration):
|
16 |
+
self.signal = signal
|
17 |
+
self.timestamp = timestamp
|
18 |
+
self.duration = duration
|
19 |
+
|
20 |
+
|
21 |
+
class WebRTCVad(object):
|
22 |
+
def __init__(self,
|
23 |
+
agg: int = 3,
|
24 |
+
frame_duration_ms: int = 30,
|
25 |
+
padding_duration_ms: int = 300,
|
26 |
+
silence_duration_threshold: float = 0.3,
|
27 |
+
sample_rate: int = 8000
|
28 |
+
):
|
29 |
+
self.agg = agg
|
30 |
+
self.frame_duration_ms = frame_duration_ms
|
31 |
+
self.padding_duration_ms = padding_duration_ms
|
32 |
+
self.silence_duration_threshold = silence_duration_threshold
|
33 |
+
self.sample_rate = sample_rate
|
34 |
+
|
35 |
+
self._vad = webrtcvad.Vad(mode=agg)
|
36 |
+
|
37 |
+
# frames
|
38 |
+
self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))
|
39 |
+
self.frame_timestamp = 0.0
|
40 |
+
self.signal_cache = None
|
41 |
+
|
42 |
+
# segments
|
43 |
+
self.num_padding_frames = int(padding_duration_ms / frame_duration_ms)
|
44 |
+
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
45 |
+
self.triggered = False
|
46 |
+
self.voiced_frames: List[Frame] = list()
|
47 |
+
self.segments = list()
|
48 |
+
|
49 |
+
# vad segments
|
50 |
+
self.is_first_segment = True
|
51 |
+
self.timestamp_start = 0.0
|
52 |
+
self.timestamp_end = 0.0
|
53 |
+
|
54 |
+
def signal_to_frames(self, signal: np.ndarray):
|
55 |
+
frames = list()
|
56 |
+
|
57 |
+
l = len(signal)
|
58 |
+
|
59 |
+
duration = (float(self.frame_length) / self.sample_rate)
|
60 |
+
|
61 |
+
for offset in range(0, l, self.frame_length):
|
62 |
+
sub_signal = signal[offset:offset+self.frame_length]
|
63 |
+
|
64 |
+
frame = Frame(sub_signal, self.frame_timestamp, duration)
|
65 |
+
self.frame_timestamp += duration
|
66 |
+
|
67 |
+
frames.append(frame)
|
68 |
+
return frames
|
69 |
+
|
70 |
+
def segments_generator(self, signal: np.ndarray):
|
71 |
+
# signal rounding
|
72 |
+
if self.signal_cache is not None:
|
73 |
+
signal = np.concatenate([self.signal_cache, signal])
|
74 |
+
|
75 |
+
rest = len(signal) % self.frame_length
|
76 |
+
|
77 |
+
if rest == 0:
|
78 |
+
self.signal_cache = None
|
79 |
+
signal_ = signal
|
80 |
+
else:
|
81 |
+
self.signal_cache = signal[-rest:]
|
82 |
+
signal_ = signal[:-rest]
|
83 |
+
|
84 |
+
# frames
|
85 |
+
frames = self.signal_to_frames(signal_)
|
86 |
+
|
87 |
+
for frame in frames:
|
88 |
+
audio_bytes = bytes(frame.signal)
|
89 |
+
is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)
|
90 |
+
|
91 |
+
if not self.triggered:
|
92 |
+
self.ring_buffer.append((frame, is_speech))
|
93 |
+
num_voiced = len([f for f, speech in self.ring_buffer if speech])
|
94 |
+
|
95 |
+
if num_voiced > 0.9 * self.ring_buffer.maxlen:
|
96 |
+
self.triggered = True
|
97 |
+
|
98 |
+
for f, _ in self.ring_buffer:
|
99 |
+
self.voiced_frames.append(f)
|
100 |
+
self.ring_buffer.clear()
|
101 |
+
else:
|
102 |
+
self.voiced_frames.append(frame)
|
103 |
+
self.ring_buffer.append((frame, is_speech))
|
104 |
+
num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
|
105 |
+
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
|
106 |
+
self.triggered = False
|
107 |
+
segment = [
|
108 |
+
np.concatenate([f.signal for f in self.voiced_frames]),
|
109 |
+
self.voiced_frames[0].timestamp,
|
110 |
+
self.voiced_frames[-1].timestamp
|
111 |
+
]
|
112 |
+
yield segment
|
113 |
+
self.ring_buffer.clear()
|
114 |
+
self.voiced_frames: List[Frame] = list()
|
115 |
+
|
116 |
+
def vad_segments_generator(self, segments_generator):
|
117 |
+
segments = list(segments_generator)
|
118 |
+
|
119 |
+
for i, segment in enumerate(segments):
|
120 |
+
start = round(segment[1], 4)
|
121 |
+
end = round(segment[2], 4)
|
122 |
+
|
123 |
+
if self.is_first_segment:
|
124 |
+
self.timestamp_start = start
|
125 |
+
self.timestamp_end = end
|
126 |
+
self.is_first_segment = False
|
127 |
+
continue
|
128 |
+
|
129 |
+
if self.timestamp_start:
|
130 |
+
sil_duration = start - self.timestamp_end
|
131 |
+
if sil_duration > self.silence_duration_threshold:
|
132 |
+
vad_segment = [self.timestamp_start, self.timestamp_end]
|
133 |
+
yield vad_segment
|
134 |
+
|
135 |
+
self.timestamp_start = start
|
136 |
+
self.timestamp_end = end
|
137 |
+
else:
|
138 |
+
self.timestamp_end = end
|
139 |
+
|
140 |
+
def vad(self, signal: np.ndarray) -> List[list]:
|
141 |
+
segments = self.segments_generator(signal)
|
142 |
+
vad_segments = self.vad_segments_generator(segments)
|
143 |
+
vad_segments = list(vad_segments)
|
144 |
+
return vad_segments
|
145 |
+
|
146 |
+
def last_vad_segments(self) -> List[list]:
|
147 |
+
# last segments
|
148 |
+
if len(self.voiced_frames) == 0:
|
149 |
+
segments = []
|
150 |
+
else:
|
151 |
+
segment = [
|
152 |
+
np.concatenate([f.signal for f in self.voiced_frames]),
|
153 |
+
self.voiced_frames[0].timestamp,
|
154 |
+
self.voiced_frames[-1].timestamp
|
155 |
+
]
|
156 |
+
segments = [segment]
|
157 |
+
|
158 |
+
# last vad segments
|
159 |
+
vad_segments = self.vad_segments_generator(segments)
|
160 |
+
vad_segments = list(vad_segments)
|
161 |
+
|
162 |
+
vad_segments = vad_segments + [[self.timestamp_start, self.timestamp_end]]
|
163 |
+
return vad_segments
|
164 |
+
|
165 |
+
|
166 |
+
def get_args():
|
167 |
+
parser = argparse.ArgumentParser()
|
168 |
+
parser.add_argument(
|
169 |
+
"--wav_file",
|
170 |
+
# default=(project_path / "data/0eeaef67-ea59-4f2d-a5b8-b70c813fd45c.wav").as_posix(),
|
171 |
+
default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
|
172 |
+
# default=(project_path / "data/8cbad66f-2c4e-43c2-ad11-ad95bab8bc15.wav").as_posix(),
|
173 |
+
type=str,
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--agg",
|
177 |
+
default=3,
|
178 |
+
type=int,
|
179 |
+
help="The level of aggressiveness of the VAD: [0-3]'"
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--frame_duration_ms",
|
183 |
+
default=30,
|
184 |
+
type=int,
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--padding_duration_ms",
|
188 |
+
default=300,
|
189 |
+
type=int,
|
190 |
+
)
|
191 |
+
parser.add_argument(
|
192 |
+
"--silence_duration_threshold",
|
193 |
+
default=0.3,
|
194 |
+
type=float,
|
195 |
+
help="minimum silence duration, in seconds."
|
196 |
+
)
|
197 |
+
args = parser.parse_args()
|
198 |
+
return args
|
199 |
+
|
200 |
+
|
201 |
+
def main():
|
202 |
+
import matplotlib.pyplot as plt
|
203 |
+
|
204 |
+
args = get_args()
|
205 |
+
|
206 |
+
SAMPLE_RATE = 8000
|
207 |
+
|
208 |
+
w_vad = WebRTCVad(
|
209 |
+
agg=args.agg,
|
210 |
+
frame_duration_ms=args.frame_duration_ms,
|
211 |
+
padding_duration_ms=args.padding_duration_ms,
|
212 |
+
silence_duration_threshold=args.silence_duration_threshold,
|
213 |
+
sample_rate=SAMPLE_RATE,
|
214 |
+
)
|
215 |
+
|
216 |
+
sample_rate, signal = wavfile.read(args.wav_file)
|
217 |
+
if SAMPLE_RATE != sample_rate:
|
218 |
+
raise AssertionError
|
219 |
+
|
220 |
+
vad_segments = list()
|
221 |
+
|
222 |
+
segments = w_vad.vad(signal)
|
223 |
+
vad_segments += segments
|
224 |
+
for segment in segments:
|
225 |
+
print(segment)
|
226 |
+
|
227 |
+
# last vad segment
|
228 |
+
segments = w_vad.last_vad_segments()
|
229 |
+
vad_segments += segments
|
230 |
+
for segment in segments:
|
231 |
+
print(segment)
|
232 |
+
|
233 |
+
# plot
|
234 |
+
time = np.arange(0, len(signal)) / sample_rate
|
235 |
+
plt.figure(figsize=(12, 5))
|
236 |
+
plt.plot(time, signal / 32768, color='b')
|
237 |
+
for start, end in vad_segments:
|
238 |
+
# start -= (w_vad.padding_duration_ms - 2*w_vad.frame_duration_ms) / 1000
|
239 |
+
end -= (w_vad.padding_duration_ms - 0*w_vad.frame_duration_ms) / 1000
|
240 |
+
|
241 |
+
plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start')  # mark the start of a speech segment
|
242 |
+
plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end')  # mark the end of a speech segment
|
243 |
+
|
244 |
+
plt.show()
|
245 |
+
return
|
246 |
+
|
247 |
+
|
248 |
+
if __name__ == '__main__':
|
249 |
+
main()
|
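Note: WebRTCVad keeps signal_cache, ring_buffer and the current segment as state between calls, so it can also be fed a stream of chunks. A minimal streaming sketch; the wav path and chunk size are placeholders, and the file must be mono 16-bit PCM at 8000 Hz.

from scipy.io import wavfile

from toolbox.webrtcvad.vad import WebRTCVad

w_vad = WebRTCVad(agg=3, frame_duration_ms=30, padding_duration_ms=300,
                  silence_duration_threshold=0.3, sample_rate=8000)

sample_rate, signal = wavfile.read("example_8k.wav")  # placeholder file

vad_segments = list()
chunk_size = 1600  # 0.2 second chunks, to mimic a streaming source
for offset in range(0, len(signal), chunk_size):
    vad_segments += w_vad.vad(signal[offset:offset + chunk_size])

# flush the segment that may still be open at the end of the stream
vad_segments += w_vad.last_vad_segments()
print(vad_segments)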