HoneyTian committed
Commit 9829721 · 0 Parent(s)

first commit

Files changed (46)
  1. .dockerignore +5 -0
  2. .gitattributes +35 -0
  3. .gitignore +23 -0
  4. Dockerfile +24 -0
  5. README.md +129 -0
  6. examples/silero_vad_by_webrtcvad/run.sh +164 -0
  7. examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +185 -0
  8. examples/silero_vad_by_webrtcvad/step_2_train_model.py +469 -0
  9. examples/silero_vad_by_webrtcvad/yaml/config.yaml +22 -0
  10. install.sh +64 -0
  11. log.py +220 -0
  12. main.py +69 -0
  13. project_settings.py +27 -0
  14. requirements.txt +13 -0
  15. toolbox/__init__.py +6 -0
  16. toolbox/json/__init__.py +6 -0
  17. toolbox/json/misc.py +63 -0
  18. toolbox/os/__init__.py +6 -0
  19. toolbox/os/command.py +59 -0
  20. toolbox/os/environment.py +114 -0
  21. toolbox/os/other.py +9 -0
  22. toolbox/torch/__init__.py +6 -0
  23. toolbox/torch/utils/__init__.py +6 -0
  24. toolbox/torch/utils/data/__init__.py +6 -0
  25. toolbox/torch/utils/data/dataset/__init__.py +6 -0
  26. toolbox/torch/utils/data/dataset/vad_jsonl_dataset.py +179 -0
  27. toolbox/torchaudio/__init__.py +6 -0
  28. toolbox/torchaudio/configuration_utils.py +64 -0
  29. toolbox/torchaudio/losses/__init__.py +6 -0
  30. toolbox/torchaudio/losses/vad_loss/__init__.py +6 -0
  31. toolbox/torchaudio/losses/vad_loss/base_vad_loss.py +43 -0
  32. toolbox/torchaudio/losses/vad_loss/bce_vad_loss.py +52 -0
  33. toolbox/torchaudio/losses/vad_loss/dice_vad_loss.py +70 -0
  34. toolbox/torchaudio/metrics/__init__.py +6 -0
  35. toolbox/torchaudio/metrics/vad_metrics/__init__.py +6 -0
  36. toolbox/torchaudio/metrics/vad_metrics/vad_accuracy.py +60 -0
  37. toolbox/torchaudio/models/__init__.py +6 -0
  38. toolbox/torchaudio/models/vad/__init__.py +6 -0
  39. toolbox/torchaudio/models/vad/silero_vad/__init__.py +6 -0
  40. toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py +66 -0
  41. toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py +151 -0
  42. toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml +22 -0
  43. toolbox/torchaudio/modules/__init__.py +6 -0
  44. toolbox/torchaudio/modules/conv_stft.py +271 -0
  45. toolbox/webrtcvad/__init__.py +5 -0
  46. toolbox/webrtcvad/vad.py +249 -0
.dockerignore ADDED
@@ -0,0 +1,5 @@
+
+ .git/
+ .idea/
+
+ /examples/
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,23 @@
+
+ .gradio/
+ .git/
+ .idea/
+
+ **/evaluation_audio/
+ **/file_dir/
+ **/flagged/
+ **/log/
+ **/logs/
+ **/__pycache__/
+
+ /data/
+ /docs/
+ /dotenv/
+ /hub_datasets/
+ /script/
+ /thirdparty/
+ /trained_models/
+ /temp/
+
+ **/*.wav
+ **/*.xlsx
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.12
+
+ WORKDIR /code
+
+ COPY . /code
+
+ RUN apt-get update
+ RUN apt-get install -y ffmpeg build-essential
+
+ RUN pip install --upgrade pip
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ RUN useradd -m -u 1000 user
+
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "main.py"]
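This Dockerfile is the Hugging Face Space entry point (it ends by launching main.py). For a quick local smoke test it can be built and run roughly as follows; the image tag and the published port are assumptions, not something defined by this commit (main.py falls back to server_port 7860 when the dotenv does not override it):

```text
docker build -t cc_vad .
docker run --rm -p 7860:7860 cc_vad
```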
README.md ADDED
@@ -0,0 +1,129 @@
+ ---
+ title: CC VAD
+ emoji: 🐢
+ colorFrom: purple
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## CC VAD
+
+
+ ### datasets
+
+ ```text
+
+ AISHELL (15G)
+ https://openslr.trmal.net/resources/33/
+
+ AISHELL-3 (19G)
+ http://www.openslr.org/93/
+
+ DNS3
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
+ Noise data comes from DEMAND, FreeSound, and AudioSet.
+
+ MS-SNSD
+ https://github.com/microsoft/MS-SNSD
+ Noise data comes from DEMAND and FreeSound.
+
+ MUSAN
+ https://www.openslr.org/17/
+ Contains music, noise, and speech.
+ The music part is pure music; the noise part includes free-sound and sound-bible, and the sound-bible portion may be usable as supplementary data.
+ Overall, not much of it is directly useful; collecting noise data ourselves is probably still the more reliable approach.
+
+ CHiME-4
+ https://www.chimechallenge.org/challenges/chime4/download.html
+
+ freesound
+ https://freesound.org/
+
+ AudioSet
+ https://research.google.com/audioset/index.html
+ ```
+
+
+ ### Create a training container
+
+ ```text
+ To train the model inside a container, the container must be able to access the GPU. Reference:
+ https://hub.docker.com/r/ollama/ollama
+
+ docker run -itd \
+   --name cc_vad \
+   --network host \
+   --gpus all \
+   --privileged \
+   --ipc=host \
+   -v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
+   -v /data/tianxing/PycharmProjects/cc_vad:/data/tianxing/PycharmProjects/cc_vad \
+   python:3.12
+
+
+ Check the GPU:
+ nvidia-smi
+ watch -n 1 -d nvidia-smi
+
+
+ ```
+
+ ```text
+ Accessing the GPU from inside a container
+
+ Reference:
+ https://blog.csdn.net/footless_bird/article/details/136291344
+ Steps:
+ # install
+ yum install -y nvidia-container-toolkit
+
+ # edit the file /etc/docker/daemon.json
+ cat /etc/docker/daemon.json
+ {
+   "data-root": "/data/lib/docker",
+   "default-runtime": "nvidia",
+   "runtimes": {
+     "nvidia": {
+       "path": "/usr/bin/nvidia-container-runtime",
+       "runtimeArgs": []
+     }
+   },
+   "registry-mirrors": [
+     "https://docker.m.daocloud.io",
+     "https://dockerproxy.com",
+     "https://docker.mirrors.ustc.edu.cn",
+     "https://docker.nju.edu.cn"
+   ]
+ }
+
+ # restart docker
+ systemctl restart docker
+ systemctl daemon-reload
+
+ # test whether a container can access the GPU.
+ docker run --gpus all python:3.12-slim nvidia-smi
+
+ # A container started this way can see the GPU device, but there is no GPU driver inside, so nvidia-smi does not work.
+ docker run -it --privileged python:3.12-slim /bin/bash
+ apt update
+ apt install -y pciutils
+ lspci | grep -i nvidia
+ #00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
+
+ # This is the start-up method found online, but nvidia-smi still does not work after entering the container.
+ docker run \
+   --device /dev/nvidia0:/dev/nvidia0 \
+   --device /dev/nvidiactl:/dev/nvidiactl \
+   --device /dev/nvidia-uvm:/dev/nvidia-uvm \
+   -v /usr/local/nvidia:/usr/local/nvidia \
+   -it --privileged python:3.12-slim /bin/bash
+
+
+ # Started this way, nvidia-smi works inside the container. The key seems to be the --gpus all flag.
+ docker run -itd --gpus all --name open_unsloth python:3.12-slim /bin/bash
+ docker run -itd --gpus all --name Qwen2-7B-Instruct python:3.12-slim /bin/bash
+
+ ```
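For training, the intended flow appears to be to exec into the cc_vad container created above and drive the example's run.sh. A minimal sketch is shown below; the stage numbers, model name and data directories are illustrative, following the usage examples at the top of examples/silero_vad_by_webrtcvad/run.sh:

```text
docker exec -it cc_vad /bin/bash
cd /data/tianxing/PycharmProjects/cc_vad/examples/silero_vad_by_webrtcvad
nohup sh run.sh --stage 1 --stop_stage 2 --system_version centos \
  --file_folder_name file_dir --final_model_name silero-vad-example \
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" > nohup.out &
```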
examples/silero_vad_by_webrtcvad/run.sh ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
6
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
+ --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
+
9
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
+
13
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
14
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
15
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
16
+
17
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \
18
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
19
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
20
+
21
+
22
+ END
23
+
24
+
25
+ # params
26
+ system_version="windows";
27
+ verbose=true;
28
+ stage=0 # start from 0 if you need to start from data preparation
29
+ stop_stage=9
30
+
31
+ work_dir="$(pwd)"
32
+ file_folder_name=file_folder_name
33
+ final_model_name=final_model_name
34
+ config_file="yaml/config.yaml"
35
+ limit=10
36
+
37
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
38
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
39
+
40
+ max_count=-1
41
+
42
+ nohup_name=nohup.out
43
+
44
+ # model params
45
+ batch_size=64
46
+ max_epochs=200
47
+ save_top_k=10
48
+ patience=5
49
+
50
+
51
+ # parse options
52
+ while true; do
53
+ [ -z "${1:-}" ] && break; # break if there are no arguments
54
+ case "$1" in
55
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
56
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
57
+ old_value="$(eval echo \$$name)";
58
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
59
+ was_bool=true;
60
+ else
61
+ was_bool=false;
62
+ fi
63
+
64
+ # Set the variable to the right value-- the escaped quotes make it work if
65
+ # the option had spaces, like --cmd "queue.pl -sync y"
66
+ eval "${name}=\"$2\"";
67
+
68
+ # Check that Boolean-valued arguments are really Boolean.
69
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
70
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
71
+ exit 1;
72
+ fi
73
+ shift 2;
74
+ ;;
75
+
76
+ *) break;
77
+ esac
78
+ done
79
+
80
+ file_dir="${work_dir}/${file_folder_name}"
81
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
82
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
83
+
84
+ train_dataset="${file_dir}/train.jsonl"
85
+ valid_dataset="${file_dir}/valid.jsonl"
86
+
87
+ $verbose && echo "system_version: ${system_version}"
88
+ $verbose && echo "file_folder_name: ${file_folder_name}"
89
+
90
+ if [ $system_version == "windows" ]; then
91
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
92
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
93
+ #source /data/local/bin/nx_denoise/bin/activate
94
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
95
+ fi
96
+
97
+
98
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
99
+ $verbose && echo "stage 1: prepare data"
100
+ cd "${work_dir}" || exit 1
101
+ python3 step_1_prepare_data.py \
102
+ --file_dir "${file_dir}" \
103
+ --noise_dir "${noise_dir}" \
104
+ --speech_dir "${speech_dir}" \
105
+ --train_dataset "${train_dataset}" \
106
+ --valid_dataset "${valid_dataset}" \
107
+ --max_count "${max_count}" \
108
+
109
+ fi
110
+
111
+
112
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
113
+ $verbose && echo "stage 2: train model"
114
+ cd "${work_dir}" || exit 1
115
+ python3 step_2_train_model.py \
116
+ --train_dataset "${train_dataset}" \
117
+ --valid_dataset "${valid_dataset}" \
118
+ --serialization_dir "${file_dir}" \
119
+ --config_file "${config_file}" \
120
+
121
+ fi
122
+
123
+
124
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
125
+ $verbose && echo "stage 3: test model"
126
+ cd "${work_dir}" || exit 1
127
+ python3 step_3_evaluation.py \
128
+ --valid_dataset "${valid_dataset}" \
129
+ --model_dir "${file_dir}/best" \
130
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
131
+ --limit "${limit}" \
132
+
133
+ fi
134
+
135
+
136
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
137
+ $verbose && echo "stage 4: collect files"
138
+ cd "${work_dir}" || exit 1
139
+
140
+ mkdir -p ${final_model_dir}
141
+
142
+ cp "${file_dir}/best"/* "${final_model_dir}"
143
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
144
+
145
+ cd "${final_model_dir}/.." || exit 1;
146
+
147
+ if [ -e "${final_model_name}.zip" ]; then
148
+ rm -rf "${final_model_name}_backup.zip"
149
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
150
+ fi
151
+
152
+ zip -r "${final_model_name}.zip" "${final_model_name}"
153
+ rm -rf "${final_model_name}"
154
+
155
+ fi
156
+
157
+
158
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
159
+ $verbose && echo "stage 5: clear file_dir"
160
+ cd "${work_dir}" || exit 1
161
+
162
+ rm -rf "${file_dir}";
163
+
164
+ fi
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from scipy.io import wavfile
16
+ from tqdm import tqdm
17
+
18
+ from toolbox.webrtcvad.vad import WebRTCVad
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
38
+
39
+ parser.add_argument("--duration", default=4.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--max_count", default=-1, type=int)
46
+
47
+ # vad
48
+ parser.add_argument("--agg", default=3, type=int)
49
+ parser.add_argument("--frame_duration_ms", default=30, type=int)
50
+ parser.add_argument("--padding_duration_ms", default=30, type=int)
51
+ parser.add_argument("--silence_duration_threshold", default=0.3, type=float)
52
+
53
+ args = parser.parse_args()
54
+ return args
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
58
+ data_dir = Path(data_dir)
59
+ for epoch_idx in range(max_epoch):
60
+ for filename in data_dir.glob("**/*.wav"):
61
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
62
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
63
+
64
+ if raw_duration < duration:
65
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
66
+ continue
67
+ if signal.ndim != 1:
68
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
69
+
70
+ signal_length = len(signal)
71
+ win_size = int(duration * sample_rate)
72
+ for begin in range(0, signal_length - win_size, win_size):
73
+ if np.sum(signal[begin: begin+win_size]) == 0:
74
+ continue
75
+ row = {
76
+ "epoch_idx": epoch_idx,
77
+ "filename": filename.as_posix(),
78
+ "raw_duration": round(raw_duration, 4),
79
+ "offset": round(begin / sample_rate, 4),
80
+ "duration": round(duration, 4),
81
+ }
82
+ yield row
83
+
84
+
85
+ def main():
86
+ args = get_args()
87
+
88
+ file_dir = Path(args.file_dir)
89
+ file_dir.mkdir(exist_ok=True)
90
+
91
+ noise_dir = Path(args.noise_dir)
92
+ speech_dir = Path(args.speech_dir)
93
+
94
+ noise_generator = target_second_signal_generator(
95
+ noise_dir.as_posix(),
96
+ duration=args.duration,
97
+ sample_rate=args.target_sample_rate,
98
+ max_epoch=100000,
99
+ )
100
+ speech_generator = target_second_signal_generator(
101
+ speech_dir.as_posix(),
102
+ duration=args.duration,
103
+ sample_rate=args.target_sample_rate,
104
+ max_epoch=1,
105
+ )
106
+
107
+ w_vad = WebRTCVad(
108
+ agg=args.agg,
109
+ frame_duration_ms=args.frame_duration_ms,
110
+ padding_duration_ms=args.padding_duration_ms,
111
+ silence_duration_threshold=args.silence_duration_threshold,
112
+ sample_rate=args.target_sample_rate,
113
+ )
114
+
115
+ count = 0
116
+ process_bar = tqdm(desc="build dataset jsonl")
117
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
118
+ for noise, speech in zip(noise_generator, speech_generator):
119
+ if count >= args.max_count > 0:
120
+ break
121
+
122
+ # row
123
+ noise_filename = noise["filename"]
124
+ noise_raw_duration = noise["raw_duration"]
125
+ noise_offset = noise["offset"]
126
+ noise_duration = noise["duration"]
127
+
128
+ speech_filename = speech["filename"]
129
+ speech_raw_duration = speech["raw_duration"]
130
+ speech_offset = speech["offset"]
131
+ speech_duration = speech["duration"]
132
+
133
+ # vad
134
+ _, signal = wavfile.read(speech_filename)
135
+ vad_segments = list()
136
+ segments = w_vad.vad(signal)
137
+ vad_segments += segments
138
+ segments = w_vad.last_vad_segments()
139
+ vad_segments += segments
140
+
141
+ # row
142
+ random1 = random.random()
143
+ random2 = random.random()
144
+
145
+ row = {
146
+ "count": count,
147
+
148
+ "noise_filename": noise_filename,
149
+ "noise_raw_duration": noise_raw_duration,
150
+ "noise_offset": noise_offset,
151
+ "noise_duration": noise_duration,
152
+
153
+ "speech_filename": speech_filename,
154
+ "speech_raw_duration": speech_raw_duration,
155
+ "speech_offset": speech_offset,
156
+ "speech_duration": speech_duration,
157
+
158
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
159
+
160
+ "vad_segments": vad_segments,
161
+
162
+ "random1": random1,
163
+ }
164
+ row = json.dumps(row, ensure_ascii=False)
165
+ if random2 < (1 / 300 / 1):
166
+ fvalid.write(f"{row}\n")
167
+ else:
168
+ ftrain.write(f"{row}\n")
169
+
170
+ count += 1
171
+ duration_seconds = count * args.duration
172
+ duration_hours = duration_seconds / 3600
173
+
174
+ process_bar.update(n=1)
175
+ process_bar.set_postfix({
176
+ # "duration_seconds": round(duration_seconds, 4),
177
+ "duration_hours": round(duration_hours, 4),
178
+
179
+ })
180
+
181
+ return
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()
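Each jsonl row written above only references file paths, offsets, durations, an SNR and the webrtcvad segments; waveforms are cut and mixed later by the dataset class (toolbox/torch/utils/data/dataset/vad_jsonl_dataset.py). The sketch below shows one way a row could be turned back into a clean/noisy pair; the exact mixing done by the dataset class may differ, so treat this as an assumption-laden illustration rather than the project's implementation:

```python
import json

import librosa
import numpy as np


def mix_from_row(row: dict, sample_rate: int = 8000):
    # Cut the excerpts that the row points at.
    speech, _ = librosa.load(row["speech_filename"], sr=sample_rate,
                             offset=row["speech_offset"], duration=row["speech_duration"])
    noise, _ = librosa.load(row["noise_filename"], sr=sample_rate,
                            offset=row["noise_offset"], duration=row["noise_duration"])
    # Scale the noise so that the mixture reaches the requested SNR (dB).
    speech_power = np.mean(speech ** 2) + 1e-8
    noise_power = np.mean(noise ** 2) + 1e-8
    target_noise_power = speech_power / (10.0 ** (row["snr_db"] / 10.0))
    noise = noise * np.sqrt(target_noise_power / noise_power)
    return speech, speech + noise


with open("train.jsonl", "r", encoding="utf-8") as f:
    first_row = json.loads(next(f))
clean, noisy = mix_from_row(first_row)
```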
examples/silero_vad_by_webrtcvad/step_2_train_model.py ADDED
@@ -0,0 +1,469 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/Rikorose/DeepFilterNet
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
19
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.nn import functional as F
27
+ from torch.utils.data.dataloader import DataLoader
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
31
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
32
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
33
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
34
+ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
35
+ from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
41
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
44
+ parser.add_argument("--patience", default=30, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+ snr_db_list = list()
82
+
83
+ for sample in batch:
84
+ # noise_wave: torch.Tensor = sample["noise_wave"]
85
+ clean_audio: torch.Tensor = sample["speech_wave"]
86
+ noisy_audio: torch.Tensor = sample["mix_wave"]
87
+ # snr_db: float = sample["snr_db"]
88
+
89
+ clean_audios.append(clean_audio)
90
+ noisy_audios.append(noisy_audio)
91
+
92
+ clean_audios = torch.stack(clean_audios)
93
+ noisy_audios = torch.stack(noisy_audios)
94
+
95
+ # assert
96
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
97
+ raise AssertionError("nan or inf in clean_audios")
98
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
99
+ raise AssertionError("nan or inf in noisy_audios")
100
+ return clean_audios, noisy_audios
101
+
102
+
103
+ collate_fn = CollateFunction()
104
+
105
+
106
+ def main():
107
+ args = get_args()
108
+
109
+ config = DfNet2Config.from_pretrained(
110
+ pretrained_model_name_or_path=args.config_file,
111
+ )
112
+
113
+ serialization_dir = Path(args.serialization_dir)
114
+ serialization_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ logger = logging_config(serialization_dir)
117
+
118
+ random.seed(config.seed)
119
+ np.random.seed(config.seed)
120
+ torch.manual_seed(config.seed)
121
+ logger.info(f"set seed: {config.seed}")
122
+
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+ n_gpu = torch.cuda.device_count()
125
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
+
127
+ # datasets
128
+ train_dataset = DenoiseJsonlDataset(
129
+ jsonl_file=args.train_dataset,
130
+ expected_sample_rate=config.sample_rate,
131
+ max_wave_value=32768.0,
132
+ min_snr_db=config.min_snr_db,
133
+ max_snr_db=config.max_snr_db,
134
+ # skip=225000,
135
+ )
136
+ valid_dataset = DenoiseJsonlDataset(
137
+ jsonl_file=args.valid_dataset,
138
+ expected_sample_rate=config.sample_rate,
139
+ max_wave_value=32768.0,
140
+ min_snr_db=config.min_snr_db,
141
+ max_snr_db=config.max_snr_db,
142
+ )
143
+ train_data_loader = DataLoader(
144
+ dataset=train_dataset,
145
+ batch_size=config.batch_size,
146
+ # shuffle=True,
147
+ sampler=None,
148
+ # On Linux the DataLoader can use multiple worker processes; on Windows it cannot.
149
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
150
+ collate_fn=collate_fn,
151
+ pin_memory=False,
152
+ prefetch_factor=None if platform.system() == "Windows" else 2,
153
+ )
154
+ valid_data_loader = DataLoader(
155
+ dataset=valid_dataset,
156
+ batch_size=config.batch_size,
157
+ # shuffle=True,
158
+ sampler=None,
159
+ # On Linux the DataLoader can use multiple worker processes; on Windows it cannot.
160
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
161
+ collate_fn=collate_fn,
162
+ pin_memory=False,
163
+ prefetch_factor=None if platform.system() == "Windows" else 2,
164
+ )
165
+
166
+ # models
167
+ logger.info(f"prepare models. config_file: {args.config_file}")
168
+ model = DfNet2PretrainedModel(config).to(device)
169
+ model.to(device)
170
+ model.train()
171
+
172
+ # optimizer
173
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
174
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
175
+
176
+ # resume training
177
+ last_step_idx = -1
178
+ last_epoch = -1
179
+ for step_idx_str in serialization_dir.glob("steps-*"):
180
+ step_idx_str = Path(step_idx_str)
181
+ step_idx = step_idx_str.stem.split("-")[1]
182
+ step_idx = int(step_idx)
183
+ if step_idx > last_step_idx:
184
+ last_step_idx = step_idx
185
+ # last_epoch = 1
186
+
187
+ if last_step_idx != -1:
188
+ logger.info(f"resume from steps-{last_step_idx}.")
189
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
190
+
191
+ logger.info(f"load state dict for model.")
192
+ with open(model_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ model.load_state_dict(state_dict, strict=True)
195
+
196
+ if config.lr_scheduler == "CosineAnnealingLR":
197
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
198
+ optimizer,
199
+ last_epoch=last_epoch,
200
+ # T_max=10 * config.eval_steps,
201
+ # eta_min=0.01 * config.lr,
202
+ **config.lr_scheduler_kwargs,
203
+ )
204
+ elif config.lr_scheduler == "MultiStepLR":
205
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
206
+ optimizer,
207
+ last_epoch=last_epoch,
208
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
209
+ )
210
+ else:
211
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
212
+
213
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
214
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
215
+ fft_size_list=[256, 512, 1024],
216
+ win_size_list=[256, 512, 1024],
217
+ hop_size_list=[128, 256, 512],
218
+ factor_sc=1.5,
219
+ factor_mag=1.0,
220
+ reduction="mean"
221
+ ).to(device)
222
+
223
+ # training loop
224
+
225
+ # state
226
+ average_pesq_score = 1000000000
227
+ average_loss = 1000000000
228
+ average_mr_stft_loss = 1000000000
229
+ average_neg_si_snr_loss = 1000000000
230
+ average_mask_loss = 1000000000
231
+ average_lsnr_loss = 1000000000
232
+
233
+ model_list = list()
234
+ best_epoch_idx = None
235
+ best_step_idx = None
236
+ best_metric = None
237
+ patience_count = 0
238
+
239
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
240
+
241
+ logger.info("training")
242
+ early_stop_flag = False
243
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
244
+ if early_stop_flag:
245
+ break
246
+
247
+ # train
248
+ model.train()
249
+
250
+ total_pesq_score = 0.
251
+ total_loss = 0.
252
+ total_mr_stft_loss = 0.
253
+ total_neg_si_snr_loss = 0.
254
+ total_mask_loss = 0.
255
+ total_lsnr_loss = 0.
256
+ total_batches = 0.
257
+
258
+ progress_bar_train = tqdm(
259
+ initial=step_idx,
260
+ desc="Training; epoch-{}".format(epoch_idx),
261
+ )
262
+ for train_batch in train_data_loader:
263
+ clean_audios, noisy_audios = train_batch
264
+ clean_audios: torch.Tensor = clean_audios.to(device)
265
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
266
+
267
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
268
+ # est_wav shape: [b, 1, n_samples]
269
+ est_wav = torch.squeeze(est_wav, dim=1)
270
+ # est_wav shape: [b, n_samples]
271
+
272
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
273
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
274
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
275
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
276
+
277
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
278
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
279
+ logger.info("nan or inf found in loss; skipping this batch.")
280
+ continue
281
+
282
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
283
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
284
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
285
+
286
+ optimizer.zero_grad()
287
+ loss.backward()
288
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
289
+ optimizer.step()
290
+ lr_scheduler.step()
291
+
292
+ total_pesq_score += pesq_score
293
+ total_loss += loss.item()
294
+ total_mr_stft_loss += mr_stft_loss.item()
295
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
296
+ total_mask_loss += mask_loss.item()
297
+ total_lsnr_loss += lsnr_loss.item()
298
+ total_batches += 1
299
+
300
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
301
+ average_loss = round(total_loss / total_batches, 4)
302
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
303
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
304
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
305
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
306
+
307
+ progress_bar_train.update(1)
308
+ progress_bar_train.set_postfix({
309
+ "lr": lr_scheduler.get_last_lr()[0],
310
+ "pesq_score": average_pesq_score,
311
+ "loss": average_loss,
312
+ "mr_stft_loss": average_mr_stft_loss,
313
+ "neg_si_snr_loss": average_neg_si_snr_loss,
314
+ "mask_loss": average_mask_loss,
315
+ "lsnr_loss": average_lsnr_loss,
316
+ })
317
+
318
+ # evaluation
319
+ step_idx += 1
320
+ if step_idx % config.eval_steps == 0:
321
+ with torch.no_grad():
322
+ torch.cuda.empty_cache()
323
+
324
+ model.eval()
325
+
326
+ total_pesq_score = 0.
327
+ total_loss = 0.
328
+ total_mr_stft_loss = 0.
329
+ total_neg_si_snr_loss = 0.
330
+ total_mask_loss = 0.
331
+ total_lsnr_loss = 0.
332
+ total_batches = 0.
333
+
334
+ progress_bar_train.close()
335
+ progress_bar_eval = tqdm(
336
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
337
+ )
338
+ for eval_batch in valid_data_loader:
339
+ clean_audios, noisy_audios = eval_batch
340
+ clean_audios: torch.Tensor = clean_audios.to(device)
341
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
342
+
343
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
344
+ # est_wav shape: [b, 1, n_samples]
345
+ est_wav = torch.squeeze(est_wav, dim=1)
346
+ # est_wav shape: [b, n_samples]
347
+
348
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
349
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
350
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
351
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
352
+
353
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
354
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
355
+ logger.info("nan or inf found in loss; skipping this batch.")
356
+ continue
357
+
358
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
359
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
360
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
361
+
362
+ total_pesq_score += pesq_score
363
+ total_loss += loss.item()
364
+ total_mr_stft_loss += mr_stft_loss.item()
365
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
366
+ total_mask_loss += mask_loss.item()
367
+ total_lsnr_loss += lsnr_loss.item()
368
+ total_batches += 1
369
+
370
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
371
+ average_loss = round(total_loss / total_batches, 4)
372
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
373
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
374
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
375
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
376
+
377
+ progress_bar_eval.update(1)
378
+ progress_bar_eval.set_postfix({
379
+ "lr": lr_scheduler.get_last_lr()[0],
380
+ "pesq_score": average_pesq_score,
381
+ "loss": average_loss,
382
+ "mr_stft_loss": average_mr_stft_loss,
383
+ "neg_si_snr_loss": average_neg_si_snr_loss,
384
+ "mask_loss": average_mask_loss,
385
+ "lsnr_loss": average_lsnr_loss,
386
+ })
387
+
388
+ model.train()
389
+
390
+ total_pesq_score = 0.
391
+ total_loss = 0.
392
+ total_mr_stft_loss = 0.
393
+ total_neg_si_snr_loss = 0.
394
+ total_mask_loss = 0.
395
+ total_lsnr_loss = 0.
396
+ total_batches = 0.
397
+
398
+ progress_bar_eval.close()
399
+ progress_bar_train = tqdm(
400
+ initial=progress_bar_train.n,
401
+ postfix=progress_bar_train.postfix,
402
+ desc=progress_bar_train.desc,
403
+ )
404
+
405
+ # save path
406
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
407
+ save_dir.mkdir(parents=True, exist_ok=False)
408
+
409
+ # save models
410
+ model.save_pretrained(save_dir.as_posix())
411
+
412
+ model_list.append(save_dir)
413
+ if len(model_list) >= args.num_serialized_models_to_keep:
414
+ model_to_delete: Path = model_list.pop(0)
415
+ shutil.rmtree(model_to_delete.as_posix())
416
+
417
+ # save metric
418
+ if best_metric is None:
419
+ best_epoch_idx = epoch_idx
420
+ best_step_idx = step_idx
421
+ best_metric = average_pesq_score
422
+ elif average_pesq_score >= best_metric:
423
+ # greater is better.
424
+ best_epoch_idx = epoch_idx
425
+ best_step_idx = step_idx
426
+ best_metric = average_pesq_score
427
+ else:
428
+ pass
429
+
430
+ metrics = {
431
+ "epoch_idx": epoch_idx,
432
+ "best_epoch_idx": best_epoch_idx,
433
+ "best_step_idx": best_step_idx,
434
+ "pesq_score": average_pesq_score,
435
+ "loss": average_loss,
436
+ "mr_stft_loss": average_mr_stft_loss,
437
+ "neg_si_snr_loss": average_neg_si_snr_loss,
438
+ "mask_loss": average_mask_loss,
439
+ "lsnr_loss": average_lsnr_loss,
440
+ }
441
+ metrics_filename = save_dir / "metrics_epoch.json"
442
+ with open(metrics_filename, "w", encoding="utf-8") as f:
443
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
444
+
445
+ # save best
446
+ best_dir = serialization_dir / "best"
447
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
448
+ if best_dir.exists():
449
+ shutil.rmtree(best_dir)
450
+ shutil.copytree(save_dir, best_dir)
451
+
452
+ # early stop
453
+ early_stop_flag = False
454
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
455
+ patience_count = 0
456
+ else:
457
+ patience_count += 1
458
+ if patience_count >= args.patience:
459
+ early_stop_flag = True
460
+
461
+ # early stop
462
+ if early_stop_flag:
463
+ break
464
+
465
+ return
466
+
467
+
468
+ if __name__ == "__main__":
469
+ main()
examples/silero_vad_by_webrtcvad/yaml/config.yaml ADDED
@@ -0,0 +1,22 @@
+ model_name: "silero_vad"
+
+ sample_rate: 8000
+ nfft: 512
+ win_size: 240
+ hop_size: 80
+ win_type: hann
+
+ in_channels: 64
+ hidden_size: 128
+
+ lr: 0.001
+ lr_scheduler: CosineAnnealingLR
+ lr_scheduler_kwargs: {}
+
+ max_epochs: 100
+ clip_grad_norm: 10.0
+ seed: 1234
+
+ num_workers: 4
+ batch_size: 4
+ eval_steps: 25000
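This YAML is passed to step_2_train_model.py via --config_file and is presumably loaded through the config class defined in toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py. The class name and call below are guesses modelled on the DfNet2Config usage in step_2_train_model.py, so verify against that file before relying on them:

```python
# Hypothetical loading sketch -- the class name is not confirmed by this commit.
from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig

config = SileroVadConfig.from_pretrained(pretrained_model_name_or_path="yaml/config.yaml")
print(config.sample_rate, config.batch_size, config.eval_steps)
```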
install.sh ADDED
@@ -0,0 +1,64 @@
+ #!/usr/bin/env bash
+
+ # bash install.sh --stage 2 --stop_stage 2 --system_version centos
+
+
+ python_version=3.12.1
+ system_version="centos";
+
+ verbose=true;
+ stage=-1
+ stop_stage=0
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ work_dir="$(pwd)"
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: install python"
+   cd "${work_dir}" || exit 1;
+
+   sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: create virtualenv"
+
+   # /usr/local/python-3.12.1/bin/virtualenv cc_vad
+   # source /data/local/bin/cc_vad/bin/activate
+   /usr/local/python-${python_version}/bin/pip3 install virtualenv
+   mkdir -p /data/local/bin
+   cd /data/local/bin || exit 1;
+   /usr/local/python-${python_version}/bin/virtualenv cc_vad
+
+ fi
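After stage 2 the virtualenv exists but the project dependencies still have to be installed into it. Following the paths in the script's own comments, that would look roughly like this:

```text
source /data/local/bin/cc_vad/bin/activate
pip3 install -r requirements.txt
```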
log.py ADDED
@@ -0,0 +1,220 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from datetime import datetime
4
+ import logging
5
+ from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
6
+ import os
7
+ from zoneinfo import ZoneInfo  # bundled with Python 3.9+, no extra install needed
8
+
9
+
10
+ def get_converter(tz_info: str = "Asia/Shanghai"):
11
+ def converter(timestamp):
12
+ dt = datetime.fromtimestamp(timestamp, ZoneInfo(tz_info))
13
+ result = dt.timetuple()
14
+ return result
15
+ return converter
16
+
17
+
18
+ def setup_size_rotating(log_directory: str, tz_info: str = "Asia/Shanghai"):
19
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
20
+
21
+ formatter = logging.Formatter(
22
+ fmt=fmt,
23
+ datefmt="%Y-%m-%d %H:%M:%S %z"
24
+ )
25
+ formatter.converter = get_converter(tz_info)
26
+
27
+ stream_handler = logging.StreamHandler()
28
+ stream_handler.setLevel(logging.INFO)
29
+ stream_handler.setFormatter(formatter)
30
+
31
+ # main
32
+ main_logger = logging.getLogger("main")
33
+ main_logger.addHandler(stream_handler)
34
+ main_info_file_handler = RotatingFileHandler(
35
+ filename=os.path.join(log_directory, "main.log"),
36
+ maxBytes=100*1024*1024, # 100MB
37
+ encoding="utf-8",
38
+ backupCount=2,
39
+ )
40
+ main_info_file_handler.setLevel(logging.INFO)
41
+ main_info_file_handler.setFormatter(logging.Formatter(fmt))
42
+ main_logger.addHandler(main_info_file_handler)
43
+
44
+ # http
45
+ http_logger = logging.getLogger("http")
46
+ http_file_handler = RotatingFileHandler(
47
+ filename=os.path.join(log_directory, "http.log"),
48
+ maxBytes=100*1024*1024, # 100MB
49
+ encoding="utf-8",
50
+ backupCount=2,
51
+ )
52
+ http_file_handler.setLevel(logging.DEBUG)
53
+ http_file_handler.setFormatter(logging.Formatter(fmt))
54
+ http_logger.addHandler(http_file_handler)
55
+
56
+ # api
57
+ api_logger = logging.getLogger("api")
58
+ api_file_handler = RotatingFileHandler(
59
+ filename=os.path.join(log_directory, "api.log"),
60
+ maxBytes=10*1024*1024, # 10MB
61
+ encoding="utf-8",
62
+ backupCount=2,
63
+ )
64
+ api_file_handler.setLevel(logging.DEBUG)
65
+ api_file_handler.setFormatter(logging.Formatter(fmt))
66
+ api_logger.addHandler(api_file_handler)
67
+
68
+ # alarm
69
+ alarm_logger = logging.getLogger("alarm")
70
+ alarm_file_handler = RotatingFileHandler(
71
+ filename=os.path.join(log_directory, "alarm.log"),
72
+ maxBytes=1*1024*1024, # 1MB
73
+ encoding="utf-8",
74
+ backupCount=2,
75
+ )
76
+ alarm_file_handler.setLevel(logging.DEBUG)
77
+ alarm_file_handler.setFormatter(logging.Formatter(fmt))
78
+ alarm_logger.addHandler(alarm_file_handler)
79
+
80
+ debug_file_handler = RotatingFileHandler(
81
+ filename=os.path.join(log_directory, "debug.log"),
82
+ maxBytes=1*1024*1024, # 1MB
83
+ encoding="utf-8",
84
+ backupCount=2,
85
+ )
86
+ debug_file_handler.setLevel(logging.DEBUG)
87
+ debug_file_handler.setFormatter(logging.Formatter(fmt))
88
+
89
+ info_file_handler = RotatingFileHandler(
90
+ filename=os.path.join(log_directory, "info.log"),
91
+ maxBytes=1*1024*1024, # 1MB
92
+ encoding="utf-8",
93
+ backupCount=2,
94
+ )
95
+ info_file_handler.setLevel(logging.INFO)
96
+ info_file_handler.setFormatter(logging.Formatter(fmt))
97
+
98
+ error_file_handler = RotatingFileHandler(
99
+ filename=os.path.join(log_directory, "error.log"),
100
+ maxBytes=1*1024*1024, # 1MB
101
+ encoding="utf-8",
102
+ backupCount=2,
103
+ )
104
+ error_file_handler.setLevel(logging.ERROR)
105
+ error_file_handler.setFormatter(logging.Formatter(fmt))
106
+
107
+ logging.basicConfig(
108
+ level=logging.DEBUG,
109
+ datefmt="%a, %d %b %Y %H:%M:%S",
110
+ handlers=[
111
+ debug_file_handler,
112
+ info_file_handler,
113
+ error_file_handler,
114
+ ]
115
+ )
116
+
117
+
118
+ def setup_time_rotating(log_directory: str):
119
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
120
+
121
+ stream_handler = logging.StreamHandler()
122
+ stream_handler.setLevel(logging.INFO)
123
+ stream_handler.setFormatter(logging.Formatter(fmt))
124
+
125
+ # main
126
+ main_logger = logging.getLogger("main")
127
+ main_logger.addHandler(stream_handler)
128
+ main_info_file_handler = TimedRotatingFileHandler(
129
+ filename=os.path.join(log_directory, "main.log"),
130
+ encoding="utf-8",
131
+ when="midnight",
132
+ interval=1,
133
+ backupCount=7
134
+ )
135
+ main_info_file_handler.setLevel(logging.INFO)
136
+ main_info_file_handler.setFormatter(logging.Formatter(fmt))
137
+ main_logger.addHandler(main_info_file_handler)
138
+
139
+ # http
140
+ http_logger = logging.getLogger("http")
141
+ http_file_handler = TimedRotatingFileHandler(
142
+ filename=os.path.join(log_directory, "http.log"),
143
+ encoding='utf-8',
144
+ when="midnight",
145
+ interval=1,
146
+ backupCount=7
147
+ )
148
+ http_file_handler.setLevel(logging.DEBUG)
149
+ http_file_handler.setFormatter(logging.Formatter(fmt))
150
+ http_logger.addHandler(http_file_handler)
151
+
152
+ # api
153
+ api_logger = logging.getLogger("api")
154
+ api_file_handler = TimedRotatingFileHandler(
155
+ filename=os.path.join(log_directory, "api.log"),
156
+ encoding='utf-8',
157
+ when="midnight",
158
+ interval=1,
159
+ backupCount=7
160
+ )
161
+ api_file_handler.setLevel(logging.DEBUG)
162
+ api_file_handler.setFormatter(logging.Formatter(fmt))
163
+ api_logger.addHandler(api_file_handler)
164
+
165
+ # alarm
166
+ alarm_logger = logging.getLogger("alarm")
167
+ alarm_file_handler = TimedRotatingFileHandler(
168
+ filename=os.path.join(log_directory, "alarm.log"),
169
+ encoding="utf-8",
170
+ when="midnight",
171
+ interval=1,
172
+ backupCount=7
173
+ )
174
+ alarm_file_handler.setLevel(logging.DEBUG)
175
+ alarm_file_handler.setFormatter(logging.Formatter(fmt))
176
+ alarm_logger.addHandler(alarm_file_handler)
177
+
178
+ debug_file_handler = TimedRotatingFileHandler(
179
+ filename=os.path.join(log_directory, "debug.log"),
180
+ encoding="utf-8",
181
+ when="D",
182
+ interval=1,
183
+ backupCount=7
184
+ )
185
+ debug_file_handler.setLevel(logging.DEBUG)
186
+ debug_file_handler.setFormatter(logging.Formatter(fmt))
187
+
188
+ info_file_handler = TimedRotatingFileHandler(
189
+ filename=os.path.join(log_directory, "info.log"),
190
+ encoding="utf-8",
191
+ when="D",
192
+ interval=1,
193
+ backupCount=7
194
+ )
195
+ info_file_handler.setLevel(logging.INFO)
196
+ info_file_handler.setFormatter(logging.Formatter(fmt))
197
+
198
+ error_file_handler = TimedRotatingFileHandler(
199
+ filename=os.path.join(log_directory, "error.log"),
200
+ encoding="utf-8",
201
+ when="D",
202
+ interval=1,
203
+ backupCount=7
204
+ )
205
+ error_file_handler.setLevel(logging.ERROR)
206
+ error_file_handler.setFormatter(logging.Formatter(fmt))
207
+
208
+ logging.basicConfig(
209
+ level=logging.DEBUG,
210
+ datefmt="%a, %d %b %Y %H:%M:%S",
211
+ handlers=[
212
+ debug_file_handler,
213
+ info_file_handler,
214
+ error_file_handler,
215
+ ]
216
+ )
217
+
218
+
219
+ if __name__ == "__main__":
220
+ pass
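These helpers only attach handlers; callers still fetch the named loggers themselves. main.py does exactly this, roughly as in the sketch below (the log directory must already exist, which project_settings.py takes care of):

```python
import logging

import log

log.setup_size_rotating(log_directory="logs", tz_info="Asia/Shanghai")
logger = logging.getLogger("main")   # writes to stdout and logs/main.log
logger.info("hello")
```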
main.py ADDED
@@ -0,0 +1,69 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import logging
+ import platform
+
+ import gradio as gr
+
+ import log
+ from project_settings import environment, log_directory, time_zone_info
+ from toolbox.os.command import Command
+
+ log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
+
+ logger = logging.getLogger("main")
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--hf_token",
+         default=environment.get("hf_token"),
+         type=str,
+     )
+     parser.add_argument(
+         "--server_port",
+         default=environment.get("server_port", 7860),
+         type=int
+     )
+
+     args = parser.parse_args()
+     return args
+
+
+ def shell(cmd: str):
+     return Command.popen(cmd)
+
+
+ def main():
+     args = get_args()
+
+     # ui
+     with gr.Blocks() as blocks:
+         gr.Markdown(value="vad.")
+         with gr.Tabs():
+             with gr.TabItem("shell"):
+                 shell_text = gr.Textbox(label="cmd")
+                 shell_button = gr.Button("run")
+                 shell_output = gr.Textbox(label="output")
+
+                 shell_button.click(
+                     shell,
+                     inputs=[shell_text,],
+                     outputs=[shell_output],
+                 )
+
+     # http://127.0.0.1:7866/
+     # http://10.75.27.247:7866/
+     blocks.queue().launch(
+         # share=True,
+         share=False if platform.system() == "Windows" else False,
+         server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
+         server_port=args.server_port
+     )
+     return
+
+
+ if __name__ == "__main__":
+     main()
project_settings.py ADDED
@@ -0,0 +1,27 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+ from pathlib import Path
+
+ from toolbox.os.environment import EnvironmentManager
+
+
+ project_path = os.path.abspath(os.path.dirname(__file__))
+ project_path = Path(project_path)
+
+ time_zone_info = "Asia/Shanghai"
+
+ log_directory = project_path / "logs"
+ log_directory.mkdir(parents=True, exist_ok=True)
+
+ # temp_directory = project_path / "temp"
+ # temp_directory.mkdir(parents=True, exist_ok=True)
+
+ environment = EnvironmentManager(
+     path=os.path.join(project_path, "dotenv"),
+     env=os.environ.get("environment", "dev"),
+ )
+
+
+ if __name__ == '__main__':
+     pass
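EnvironmentManager therefore looks for dotenv/{environment}.env and defaults to dotenv/dev.env (the dotenv directory is git-ignored). A minimal example file with the keys that main.py reads; the values are placeholders:

```text
# dotenv/dev.env
hf_token=hf_xxxxxxxxxxxxxxxx
server_port=7860
```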
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ gradio==5.33.0
+ gradio_client==1.10.2
+ datasets==3.2.0
+ python-dotenv==1.0.1
+ scipy==1.15.1
+ librosa==0.10.2.post1
+ pandas==2.2.3
+ openpyxl==3.1.5
+ torch==2.5.1
+ torchaudio==2.5.1
+ overrides==7.7.0
+ webrtcvad==2.0.10
+ matplotlib==3.10.3
toolbox/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == "__main__":
+     pass
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from typing import Callable
+
+
+ def traverse(js, callback: Callable, *args, **kwargs):
+     if isinstance(js, list):
+         result = list()
+         for l in js:
+             l = traverse(l, callback, *args, **kwargs)
+             result.append(l)
+         return result
+     elif isinstance(js, tuple):
+         result = list()
+         for l in js:
+             l = traverse(l, callback, *args, **kwargs)
+             result.append(l)
+         return tuple(result)
+     elif isinstance(js, dict):
+         result = dict()
+         for k, v in js.items():
+             k = traverse(k, callback, *args, **kwargs)
+             v = traverse(v, callback, *args, **kwargs)
+             result[k] = v
+         return result
+     elif isinstance(js, int):
+         return callback(js, *args, **kwargs)
+     elif isinstance(js, str):
+         return callback(js, *args, **kwargs)
+     else:
+         return js
+
+
+ def demo1():
+     d = {
+         "env": "ppe",
+         "mysql_connect": {
+             "host": "$mysql_connect_host",
+             "port": 3306,
+             "user": "callbot",
+             "password": "NxcloudAI2021!",
+             "database": "callbot_ppe",
+             "charset": "utf8"
+         },
+         "es_connect": {
+             "hosts": ["10.20.251.8"],
+             "http_auth": ["elastic", "ElasticAI2021!"],
+             "port": 9200
+         }
+     }
+
+     def callback(s):
+         if isinstance(s, str) and s.startswith('$'):
+             return s[1:]
+         return s
+
+     result = traverse(d, callback=callback)
+     print(result)
+     return
+
+
+ if __name__ == '__main__':
+     demo1()
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/os/command.py ADDED
@@ -0,0 +1,59 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+
+
+ class Command(object):
+     custom_command = [
+         "cd"
+     ]
+
+     @staticmethod
+     def _get_cmd(command):
+         command = str(command).strip()
+         if command == "":
+             return None
+         cmd_and_args = command.split(sep=" ")
+         cmd = cmd_and_args[0]
+         args = " ".join(cmd_and_args[1:])
+         return cmd, args
+
+     @classmethod
+     def popen(cls, command):
+         cmd, args = cls._get_cmd(command)
+         if cmd in cls.custom_command:
+             method = getattr(cls, cmd)
+             return method(args)
+         else:
+             resp = os.popen(command)
+             result = resp.read()
+             resp.close()
+             return result
+
+     @classmethod
+     def cd(cls, args):
+         if args.startswith("/"):
+             os.chdir(args)
+         else:
+             pwd = os.getcwd()
+             path = os.path.join(pwd, args)
+             os.chdir(path)
+
+     @classmethod
+     def system(cls, command):
+         return os.system(command)
+
+     def __init__(self):
+         pass
+
+
+ def ps_ef_grep(keyword: str):
+     cmd = "ps -ef | grep {}".format(keyword)
+     rows = Command.popen(cmd)
+     rows = str(rows).split("\n")
+     rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
+     return rows
+
+
+ if __name__ == "__main__":
+     pass
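A short usage sketch for these helpers: everything except the registered custom command "cd" is routed through os.popen, and ps_ef_grep simply filters `ps -ef` output (POSIX systems only):

```python
from toolbox.os.command import Command, ps_ef_grep

print(Command.popen("echo hello"))   # runs via os.popen and returns stdout
Command.popen("cd /tmp")             # dispatched to Command.cd, changes this process's cwd
print(ps_ef_grep("python3"))         # ps -ef rows containing "python3", minus the grep itself
```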
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
+ class EnvironmentManager(object):
13
+ def __init__(self, path, env, override=False):
14
+ filename = os.path.join(path, '{}.env'.format(env))
15
+ self.filename = filename
16
+
17
+ load_dotenv(
18
+ dotenv_path=filename,
19
+ override=override
20
+ )
21
+
22
+ self._environ = dict()
23
+
24
+ def open_dotenv(self, filename: str = None):
25
+ filename = filename or self.filename
26
+ dotenv = DotEnv(
27
+ dotenv_path=filename,
28
+ stream=None,
29
+ verbose=False,
30
+ interpolate=False,
31
+ override=False,
32
+ encoding="utf-8",
33
+ )
34
+ result = dotenv.dict()
35
+ return result
36
+
37
+ def get(self, key, default=None, dtype=str):
38
+ result = os.environ.get(key)
39
+ if result is None:
40
+ if default is None:
41
+ result = None
42
+ else:
43
+ result = default
44
+ else:
45
+ result = dtype(result)
46
+ self._environ[key] = result
47
+ return result
48
+
49
+
50
+ _DEFAULT_DTYPE_MAP = {
51
+ 'int': int,
52
+ 'float': float,
53
+ 'str': str,
54
+ 'json.loads': json.loads
55
+ }
56
+
57
+
58
+ class JsonConfig(object):
59
+ """
60
+ 将 json 中, 形如 `$float:threshold` 的值, 处理为:
61
+ 从环境变量中查到 threshold, 再将其转换为 float 类型.
62
+ """
63
+ def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
64
+ self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
65
+ self.environment = environment or os.environ
66
+
67
+ def sanitize_by_filename(self, filename: str):
68
+ with open(filename, 'r', encoding='utf-8') as f:
69
+ js = json.load(f)
70
+
71
+ return self.sanitize_by_json(js)
72
+
73
+ def sanitize_by_json(self, js):
74
+ js = traverse(
75
+ js,
76
+ callback=self.sanitize,
77
+ environment=self.environment
78
+ )
79
+ return js
80
+
81
+ def sanitize(self, string, environment):
82
+ """支持 $ 符开始的, 环境变量配置"""
83
+ if isinstance(string, str) and string.startswith('$'):
84
+ dtype, key = string[1:].split(':')
85
+ dtype = self.dtype_map[dtype]
86
+
87
+ value = environment.get(key)
88
+ if value is None:
89
+ raise AssertionError('environment not exist. key: {}'.format(key))
90
+
91
+ value = dtype(value)
92
+ result = value
93
+ else:
94
+ result = string
95
+ return result
96
+
97
+
98
+ def demo1():
99
+ import json
100
+
101
+ from project_settings import project_path
102
+
103
+ environment = EnvironmentManager(
104
+ path=os.path.join(project_path, 'server/callbot_server/dotenv'),
105
+ env='dev',
106
+ )
107
+ init_scenes = environment.get(key='init_scenes', dtype=json.loads)
108
+ print(init_scenes)
109
+ print(environment._environ)
110
+ return
111
+
112
+
113
+ if __name__ == '__main__':
114
+ demo1()
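A minimal sketch of the `$dtype:key` convention handled by JsonConfig above; the dotenv directory and the `threshold` variable in dev.env are assumptions made for the example:

    from toolbox.os.environment import EnvironmentManager, JsonConfig

    environment = EnvironmentManager(path="dotenv", env="dev")   # loads dotenv/dev.env into os.environ
    config = JsonConfig(environment=environment)
    js = config.sanitize_by_json({"vad": {"threshold": "$float:threshold"}})
    # with threshold=0.5 defined in dev.env, js becomes {"vad": {"threshold": 0.5}}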
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import inspect
3
+
4
+
5
+ def pwd():
6
+ """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
7
+ frame = inspect.stack()[1]
8
+ module = inspect.getmodule(frame[0])
9
+ return os.path.dirname(os.path.abspath(module.__file__))
toolbox/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torch/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torch/utils/data/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torch/utils/data/dataset/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torch/utils/data/dataset/vad_jsonl_dataset.py ADDED
@@ -0,0 +1,179 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import random
5
+ from typing import List
6
+
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset, IterableDataset
11
+
12
+
13
+ class VadJsonlDataset(IterableDataset):
14
+ def __init__(self,
15
+ jsonl_file: str,
16
+ expected_sample_rate: int,
17
+ resample: bool = False,
18
+ max_wave_value: float = 1.0,
19
+ buffer_size: int = 1000,
20
+ min_snr_db: float = None,
21
+ max_snr_db: float = None,
22
+ eps: float = 1e-8,
23
+ skip: int = 0,
24
+ ):
25
+ self.jsonl_file = jsonl_file
26
+ self.expected_sample_rate = expected_sample_rate
27
+ self.resample = resample
28
+ self.max_wave_value = max_wave_value
29
+ self.min_snr_db = min_snr_db
30
+ self.max_snr_db = max_snr_db
31
+ self.eps = eps
32
+ self.skip = skip
33
+
34
+ self.buffer_size = buffer_size
35
+ self.buffer_samples: List[dict] = list()
36
+
37
+ def __iter__(self):
38
+ self.buffer_samples = list()
39
+
40
+ iterable_source = self.iterable_source()
41
+
42
+ try:
43
+ for _ in range(self.skip):
44
+ next(iterable_source)
45
+ except StopIteration:
46
+ pass
47
+
48
+ # initial fill of the shuffle buffer
49
+ try:
50
+ for _ in range(self.buffer_size):
51
+ self.buffer_samples.append(next(iterable_source))
52
+ except StopIteration:
53
+ pass
54
+
55
+ # dynamic replacement: yield a random buffered sample and put the new item in its place
56
+ while True:
57
+ try:
58
+ item = next(iterable_source)
59
+ # randomly replace an element of the buffer
60
+ replace_idx = random.randint(0, len(self.buffer_samples) - 1)
61
+ sample = self.buffer_samples[replace_idx]
62
+ self.buffer_samples[replace_idx] = item
63
+ yield self.convert_sample(sample)
64
+ except StopIteration:
65
+ break
66
+
67
+ # drain the remaining buffered samples
68
+ random.shuffle(self.buffer_samples)
69
+ for sample in self.buffer_samples:
70
+ yield self.convert_sample(sample)
71
+
72
+ def iterable_source(self):
73
+ last_sample = None
74
+ with open(self.jsonl_file, "r", encoding="utf-8") as f:
75
+ for row in f:
76
+ row = json.loads(row)
77
+ noise_filename = row["noise_filename"]
78
+ noise_raw_duration = row["noise_raw_duration"]
79
+ noise_offset = row["noise_offset"]
80
+ noise_duration = row["noise_duration"]
81
+
82
+ speech_filename = row["speech_filename"]
83
+ speech_raw_duration = row["speech_raw_duration"]
84
+ speech_offset = row["speech_offset"]
85
+ speech_duration = row["speech_duration"]
86
+
87
+ if self.min_snr_db is None or self.max_snr_db is None:
88
+ snr_db = row["snr_db"]
89
+ else:
90
+ snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
91
+
92
+ vad_segments = row["vad_segments"]
93
+
94
+ sample = {
95
+ "noise_filename": noise_filename,
96
+ "noise_raw_duration": noise_raw_duration,
97
+ "noise_offset": noise_offset,
98
+ "noise_duration": noise_duration,
99
+
100
+ "speech_filename": speech_filename,
101
+ "speech_raw_duration": speech_raw_duration,
102
+ "speech_offset": speech_offset,
103
+ "speech_duration": speech_duration,
104
+
105
+ "snr_db": snr_db,
106
+
107
+ "vad_segments": vad_segments,
108
+ }
109
+ if last_sample is None:
110
+ last_sample = sample
111
+ continue
112
+ yield sample
113
+ yield last_sample
114
+
115
+ def convert_sample(self, sample: dict):
116
+ noise_filename = sample["noise_filename"]
117
+ noise_offset = sample["noise_offset"]
118
+ noise_duration = sample["noise_duration"]
119
+
120
+ speech_filename = sample["speech_filename"]
121
+ speech_offset = sample["speech_offset"]
122
+ speech_duration = sample["speech_duration"]
123
+
124
+ snr_db = sample["snr_db"]
125
+
126
+ vad_segments = sample["vad_segments"]
127
+
128
+ noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration)
129
+ speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration)
130
+
131
+ noisy_wave, _ = self.mix_speech_and_noise(
132
+ speech=speech_wave.numpy(),
133
+ noise=noise_wave.numpy(),
134
+ snr_db=snr_db, eps=self.eps,
135
+ )
136
+ noisy_wave = torch.tensor(noisy_wave, dtype=torch.float32)
137
+
138
+ result = {
139
+ "noisy_wave": noisy_wave,
140
+ "vad_segments": vad_segments,
141
+ }
142
+ return result
143
+
144
+ def filename_to_waveform(self, filename: str, offset: float, duration: float):
145
+ try:
146
+ waveform, sample_rate = librosa.load(
147
+ filename,
148
+ sr=self.expected_sample_rate,
149
+ offset=offset,
150
+ duration=duration,
151
+ )
152
+ except ValueError as e:
153
+ print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
154
+ raise e
155
+ waveform = torch.tensor(waveform, dtype=torch.float32)
156
+ return waveform
157
+
158
+ @staticmethod
159
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8):
160
+ l1 = len(speech)
161
+ l2 = len(noise)
162
+ l = min(l1, l2)
163
+ speech = speech[:l]
164
+ noise = noise[:l]
165
+
166
+ # np.float32, value between (-1, 1).
167
+
168
+ speech_power = np.mean(np.square(speech))
169
+ noise_power = speech_power / (10 ** (snr_db / 10))
170
+
171
+ noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps)
172
+
173
+ noisy_signal = speech + noise_adjusted
174
+
175
+ return noisy_signal, noise_adjusted
176
+
177
+
178
+ if __name__ == "__main__":
179
+ pass
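For reference, a sketch of the jsonl rows this dataset expects and a minimal consumption loop; the file name and field values are illustrative:

    # one jsonl row (illustrative values):
    # {"noise_filename": "...", "noise_raw_duration": 10.0, "noise_offset": 2.0, "noise_duration": 4.0,
    #  "speech_filename": "...", "speech_raw_duration": 8.0, "speech_offset": 0.0, "speech_duration": 4.0,
    #  "snr_db": 10.0, "vad_segments": [[0.24, 1.15], [2.21, 3.2]]}
    from toolbox.torch.utils.data.dataset.vad_jsonl_dataset import VadJsonlDataset

    dataset = VadJsonlDataset(jsonl_file="train.jsonl", expected_sample_rate=8000)
    for sample in dataset:
        noisy_wave = sample["noisy_wave"]        # float32 tensor, speech mixed with noise at snr_db
        vad_segments = sample["vad_segments"]    # [[start_sec, end_sec], ...]
        break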
toolbox/torchaudio/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/configuration_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import copy
4
+ import os
5
+ from typing import Any, Dict, Union
6
+
7
+ import yaml
8
+
9
+
10
+ CONFIG_FILE = "config.yaml"
11
+ DISCRIMINATOR_CONFIG_FILE = "discriminator_config.yaml"
12
+
13
+
14
+ class PretrainedConfig(object):
15
+ def __init__(self, **kwargs):
16
+ pass
17
+
18
+ @classmethod
19
+ def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
20
+ with open(yaml_file, encoding="utf-8") as f:
21
+ config_dict = yaml.safe_load(f)
22
+ return config_dict
23
+
24
+ @classmethod
25
+ def get_config_dict(
26
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike]
27
+ ) -> Dict[str, Any]:
28
+ if os.path.isdir(pretrained_model_name_or_path):
29
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
30
+ else:
31
+ config_file = pretrained_model_name_or_path
32
+ config_dict = cls._dict_from_yaml_file(config_file)
33
+ return config_dict
34
+
35
+ @classmethod
36
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
37
+ for k, v in kwargs.items():
38
+ if k in config_dict.keys():
39
+ config_dict[k] = v
40
+ config = cls(**config_dict)
41
+ return config
42
+
43
+ @classmethod
44
+ def from_pretrained(
45
+ cls,
46
+ pretrained_model_name_or_path: Union[str, os.PathLike],
47
+ **kwargs,
48
+ ):
49
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path)
50
+ return cls.from_dict(config_dict, **kwargs)
51
+
52
+ def to_dict(self):
53
+ output = copy.deepcopy(self.__dict__)
54
+ return output
55
+
56
+ def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
57
+ config_dict = self.to_dict()
58
+
59
+ with open(yaml_file_path, "w", encoding="utf-8") as writer:
60
+ yaml.safe_dump(config_dict, writer)
61
+
62
+
63
+ if __name__ == '__main__':
64
+ pass
toolbox/torchaudio/losses/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/losses/vad_loss/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/losses/vad_loss/base_vad_loss.py ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class BaseVadLoss(nn.Module):
10
+ def __init__(self):
11
+ super(BaseVadLoss, self).__init__()
12
+
13
+ @staticmethod
14
+ def get_targets(inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
15
+ """
16
+ :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
17
+ :param batch_vad_segments: VAD segment for each audio
18
+ :param duration: float. The total duration of each audio in the batch.
19
+ :return: targets, shape as `inputs`.
20
+ """
21
+ b, t, _ = inputs.shape
22
+
23
+ batch_vad_segments_ = list()
24
+ for vad_segments in batch_vad_segments:
25
+ vad_segments_ = list()
26
+ for start, end in vad_segments:
27
+ start_ = start / duration * t
28
+ end_ = end / duration * t
29
+ start_ = round(start_)
30
+ end_ = round(end_)
31
+ vad_segments_.append([start_, end_])
32
+ batch_vad_segments_.append(vad_segments_)
33
+
34
+ targets = torch.zeros_like(inputs)
35
+ for idx, vad_segments_ in enumerate(batch_vad_segments_):
36
+ for start_, end_ in vad_segments_:
37
+ targets[idx, start_:end_, :] = 1
38
+
39
+ return targets
40
+
41
+
42
+ if __name__ == "__main__":
43
+ pass
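A worked example of the time-to-frame mapping in get_targets above (numbers are illustrative): with duration=4.0 seconds and t=198 frames, the segment [0.24, 1.15] maps to round(0.24 / 4.0 * 198) = 12 and round(1.15 / 4.0 * 198) = 57, so targets[idx, 12:57, :] is set to 1.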
toolbox/torchaudio/losses/vad_loss/bce_vad_loss.py ADDED
@@ -0,0 +1,52 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
9
+
10
+
11
+ class BCEVadLoss(BaseVadLoss):
12
+ """
13
+ Binary Cross-Entropy Loss, BCE Loss
14
+ """
15
+ def __init__(self,
16
+ reduction: str = "mean",
17
+ ):
18
+ super(BCEVadLoss, self).__init__()
19
+ self.reduction = reduction
20
+
21
+ self.bce_loss_fn = nn.BCELoss(reduction=reduction)
22
+
23
+ def forward(self, inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
24
+ """
25
+ :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
26
+ :param batch_vad_segments: VAD segment for each audio
27
+ :param duration: float. The total duration of each audio in the batch.
28
+ :return:
29
+ """
30
+
31
+ targets = self.get_targets(inputs, batch_vad_segments, duration)
32
+
33
+ loss = self.bce_loss_fn.forward(inputs, targets)
34
+ return loss
35
+
36
+
37
+ def main():
38
+ inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
39
+
40
+ vad_segments = [
41
+ [[0.24, 1.15], [2.21, 3.2]],
42
+ ]
43
+
44
+ loss_fn = BCEVadLoss()
45
+
46
+ loss = loss_fn.forward(inputs, vad_segments, duration=4)
47
+ print(loss)
48
+ return
49
+
50
+
51
+ if __name__ == "__main__":
52
+ main()
toolbox/torchaudio/losses/vad_loss/dice_vad_loss.py ADDED
@@ -0,0 +1,70 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
9
+
10
+
11
+ class DiceVadLoss(BaseVadLoss):
12
+ def __init__(self,
13
+ reduction: str = "mean",
14
+ eps: float = 1e-6,
15
+ ):
16
+ super(DiceVadLoss, self).__init__()
17
+ self.reduction = reduction
18
+ self.eps = eps
19
+
20
+ if reduction not in ("sum", "mean"):
21
+ raise AssertionError(f"param reduction must be sum or mean.")
22
+
23
+ def forward(self, inputs: torch.Tensor, batch_vad_segments: List[List[Tuple[float, float]]], duration: float):
24
+ """
25
+ :param inputs: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
26
+ :param batch_vad_segments: VAD segment for each audio
27
+ :param duration: float. The total duration of each audio in the batch.
28
+ :return:
29
+ """
30
+ targets = self.get_targets(inputs, batch_vad_segments, duration)
31
+
32
+ inputs_ = torch.squeeze(inputs, dim=-1)
33
+ targets_ = torch.squeeze(targets, dim=-1)
34
+ # shape: [b, t]
35
+
36
+ intersection = (inputs_ * targets_).sum(dim=-1)
37
+ union = (inputs_ + targets_).sum(dim=-1)
38
+ # shape: [b,]
39
+
40
+ dice = (2. * intersection + self.eps) / (union + self.eps)
41
+ # shape: [b,]
42
+
43
+ loss = 1. - dice
44
+ # shape: [b,]
45
+
46
+ if self.reduction == "mean":
47
+ loss = torch.mean(loss)
48
+ elif self.reduction == "sum":
49
+ loss = torch.sum(loss)
50
+ else:
51
+ raise AssertionError
52
+ return loss
53
+
54
+
55
+ def main():
56
+ inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
57
+
58
+ vad_segments = [
59
+ [[0.24, 1.15], [2.21, 3.2]],
60
+ ]
61
+
62
+ loss_fn = DiceVadLoss()
63
+
64
+ loss = loss_fn.forward(inputs, vad_segments, duration=4)
65
+ print(loss)
66
+ return
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
toolbox/torchaudio/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/metrics/vad_metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/metrics/vad_metrics/vad_accuracy.py ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ class VadAccuracy(object):
7
+ def __init__(self, threshold: float = 0.5) -> None:
8
+ self.threshold = threshold
9
+
10
+ self.correct_count = 0.
11
+ self.total_count = 0.
12
+
13
+ def __call__(self,
14
+ predictions: torch.Tensor,
15
+ gold_labels: torch.Tensor,
16
+ ):
17
+ """
18
+ :param predictions: torch.Tensor, shape: [b, t, 1]. vad prob, after sigmoid activation.
19
+ :param gold_labels: torch.Tensor, shape: [b, t, 1].
20
+ :return:
21
+ """
22
+ predictions = (predictions > self.threshold).float()
23
+ correct = predictions.eq(gold_labels).float()
24
+ self.correct_count += correct.sum()
25
+ self.total_count += gold_labels.numel()
26
+
27
+ def get_metric(self, reset: bool = False):
28
+ """
29
+ Returns
30
+ -------
31
+ The accumulated accuracy.
32
+ """
33
+ if self.total_count > 1e-12:
34
+ accuracy = float(self.correct_count) / float(self.total_count)
35
+ else:
36
+ accuracy = 0.0
37
+ if reset:
38
+ self.reset()
39
+ return {'accuracy': accuracy}
40
+
41
+ def reset(self):
42
+ self.correct_count = 0.0
43
+ self.total_count = 0.0
44
+
45
+
46
+ def main():
47
+ inputs = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
48
+ targets = torch.zeros(size=(1, 198, 1), dtype=torch.float32)
49
+
50
+ metric_fn = VadAccuracy()
51
+
52
+ metric_fn.__call__(inputs, targets)
53
+
54
+ metrics = metric_fn.get_metric()
55
+ print(metrics)
56
+ return
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
toolbox/torchaudio/models/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/silero_vad/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Tuple
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class SileroVadConfig(PretrainedConfig):
9
+ def __init__(self,
10
+ sample_rate: int = 8000,
11
+ nfft: int = 512,
12
+ win_size: int = 240,
13
+ hop_size: int = 80,
14
+ win_type: str = "hann",
15
+
16
+ in_channels: int = 64,
17
+ hidden_size: int = 128,
18
+
19
+ lr: float = 0.001,
20
+ lr_scheduler: str = "CosineAnnealingLR",
21
+ lr_scheduler_kwargs: dict = None,
22
+
23
+ max_epochs: int = 100,
24
+ clip_grad_norm: float = 10.,
25
+ seed: int = 1234,
26
+
27
+ num_workers: int = 4,
28
+ batch_size: int = 4,
29
+ eval_steps: int = 25000,
30
+
31
+ **kwargs
32
+ ):
33
+ super(SileroVadConfig, self).__init__(**kwargs)
34
+ # transform
35
+ self.sample_rate = sample_rate
36
+ self.nfft = nfft
37
+ self.win_size = win_size
38
+ self.hop_size = hop_size
39
+ self.win_type = win_type
40
+
41
+ # encoder
42
+ self.in_channels = in_channels
43
+ self.hidden_size = hidden_size
44
+
45
+ # train
46
+ self.lr = lr
47
+ self.lr_scheduler = lr_scheduler
48
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
49
+
50
+ self.max_epochs = max_epochs
51
+ self.clip_grad_norm = clip_grad_norm
52
+ self.seed = seed
53
+
54
+ self.num_workers = num_workers
55
+ self.batch_size = batch_size
56
+ self.eval_steps = eval_steps
57
+
58
+
59
+ def main():
60
+ config = SileroVadConfig()
61
+ config.to_yaml_file("config.yaml")
62
+ return
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py ADDED
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/snakers4/silero-vad/wiki/Quality-Metrics
5
+
6
+ https://pytorch.org/hub/snakers4_silero-vad_vad/
7
+ https://github.com/snakers4/silero-vad
8
+
9
+ https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
15
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT
16
+
17
+
18
+ MODEL_FILE = "model.pt"
19
+
20
+
21
+ class EncoderBlock(nn.Module):
22
+ def __init__(self,
23
+ in_channels: int = 64,
24
+ out_channels: int = 128,
25
+ ):
26
+ super(EncoderBlock, self).__init__()
27
+ self.conv1d = nn.Conv1d(
28
+ in_channels=in_channels,
29
+ out_channels=out_channels,
30
+ kernel_size=3,
31
+ padding="same",
32
+ )
33
+ self.activation = nn.ReLU()
34
+ self.norm = nn.BatchNorm1d(out_channels)
35
+
36
+ def forward(self, x: torch.Tensor):
37
+ # x shape: [b, t, f]
38
+ x = torch.transpose(x, dim0=1, dim1=2)
39
+ # x shape: [b, f, t]
40
+
41
+ x = self.conv1d.forward(x)
42
+ x = self.activation(x)
43
+ x = self.norm(x)
44
+
45
+ x = torch.transpose(x, dim0=1, dim1=2)
46
+ # x shape: [b, t, f]
47
+
48
+ return x
49
+
50
+
51
+ class Encoder(nn.Module):
52
+ def __init__(self,
53
+ in_channels: int = 64,
54
+ out_channels: int = 128,
55
+ num_layers: int = 3,
56
+ ):
57
+ super(Encoder, self).__init__()
58
+
59
+ self.layers = nn.ModuleList(modules=[
60
+ EncoderBlock(
61
+ in_channels=in_channels,
62
+ out_channels=out_channels,
63
+ )
64
+ if i == 0 else
65
+ EncoderBlock(
66
+ in_channels=out_channels,
67
+ out_channels=out_channels,
68
+ )
69
+ for i in range(num_layers)
70
+ ])
71
+
72
+ def forward(self, x: torch.Tensor):
73
+ for layer in self.layers:
74
+ x = layer.forward(x)
75
+ return x
76
+
77
+
78
+ class SileroVadModel(nn.Module):
79
+ def __init__(self, config: SileroVadConfig):
80
+ super(SileroVadModel, self).__init__()
81
+ self.config = config
82
+ self.eps = 1e-12
83
+
84
+ self.stft = ConvSTFT(
85
+ nfft=config.nfft,
86
+ win_size=config.win_size,
87
+ hop_size=config.hop_size,
88
+ win_type=config.win_type,
89
+ power=1,
90
+ requires_grad=False
91
+ )
92
+
93
+ self.linear = nn.Linear(
94
+ in_features=(config.nfft // 2 + 1),
95
+ out_features=config.in_channels,
96
+ )
97
+
98
+ self.encoder = Encoder(
99
+ in_channels=config.in_channels,
100
+ out_channels=config.hidden_size,
101
+ )
102
+
103
+ self.lstm = nn.LSTM(
104
+ input_size=config.hidden_size,
105
+ hidden_size=config.hidden_size,
106
+ bidirectional=False,
107
+ batch_first=True
108
+ )
109
+
110
+ self.classifier = nn.Sequential(
111
+ nn.Linear(config.hidden_size, 32),
112
+ nn.ReLU(),
113
+ nn.Linear(32, 1),
114
+ nn.Sigmoid()
115
+ )
116
+
117
+ def forward(self, signal: torch.Tensor):
118
+ mags = self.stft.forward(signal)
119
+ # mags shape: [b, f, t]
120
+
121
+ x = torch.transpose(mags, dim0=1, dim1=2)
122
+ # x shape: [b, t, f]
123
+
124
+ x = self.linear.forward(x)
125
+ # x shape: [b, t, f']
126
+
127
+ x = self.encoder.forward(x)
128
+ # x shape: [b, t, f]
129
+
130
+ x, _ = self.lstm.forward(x)
131
+ x = self.classifier.forward(x)
132
+
133
+ # x shape: [b, t, 1]
134
+ return x
135
+
136
+
137
+ def main():
138
+ config = SileroVadConfig()
139
+ model = SileroVadModel(config=config)
140
+
141
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
142
+
143
+ probs = model.forward(noisy)
144
+ print(f"probs: {probs}")
145
+ print(f"probs.shape: {probs.shape}")
146
+
147
+ return
148
+
149
+
150
+ if __name__ == "__main__":
151
+ main()
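As a sanity check on the demo above (assuming the default config): ConvSTFT with win_size=240 and hop_size=80 maps the 16000-sample input to (16000 - 240) // 80 + 1 = 198 frames, so probs has shape [1, 198, 1], which matches the (1, 198, 1) dummy tensors used in the loss and metric demos.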
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml ADDED
@@ -0,0 +1,22 @@
1
+ model_name: "silero_vad"
2
+
3
+ sample_rate: 8000
4
+ nfft: 512
5
+ win_size: 240
6
+ hop_size: 80
7
+ win_type: hann
8
+
9
+ in_channels: 64
10
+ hidden_size: 128
11
+
12
+ lr: 0.001
13
+ lr_scheduler: CosineAnnealingLR
14
+ lr_scheduler_kwargs: {}
15
+
16
+ max_epochs: 100
17
+ clip_grad_norm: 10.0
18
+ seed: 1234
19
+
20
+ num_workers: 4
21
+ batch_size: 4
22
+ eval_steps: 25000
toolbox/torchaudio/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/modules/conv_stft.py ADDED
@@ -0,0 +1,271 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
+ """
6
+ from collections import defaultdict
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from scipy.signal import get_window
12
+
13
+
14
+ def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
15
+ if win_type == "None" or win_type is None:
16
+ window = np.ones(win_size)
17
+ else:
18
+ window = get_window(win_type, win_size, fftbins=True)**0.5
19
+
20
+ fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
21
+ real_kernel = np.real(fourier_basis)
22
+ image_kernel = np.imag(fourier_basis)
23
+ kernel = np.concatenate([real_kernel, image_kernel], 1).T
24
+
25
+ if inverse:
26
+ kernel = np.linalg.pinv(kernel).T
27
+
28
+ kernel = kernel * window
29
+ kernel = kernel[:, None, :]
30
+ result = (
31
+ torch.from_numpy(kernel.astype(np.float32)),
32
+ torch.from_numpy(window[None, :, None].astype(np.float32))
33
+ )
34
+ return result
35
+
36
+
37
+ class ConvSTFT(nn.Module):
38
+
39
+ def __init__(self,
40
+ nfft: int,
41
+ win_size: int,
42
+ hop_size: int,
43
+ win_type: str = "hamming",
44
+ power: int = None,
45
+ requires_grad: bool = False):
46
+ super(ConvSTFT, self).__init__()
47
+
48
+ if nfft is None:
49
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
50
+ else:
51
+ self.nfft = nfft
52
+
53
+ kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
54
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
55
+
56
+ self.win_size = win_size
57
+ self.hop_size = hop_size
58
+
59
+ self.stride = hop_size
60
+ self.dim = self.nfft
61
+ self.power = power
62
+
63
+ def forward(self, waveform: torch.Tensor):
64
+ if waveform.dim() == 2:
65
+ waveform = torch.unsqueeze(waveform, 1)
66
+
67
+ matrix = F.conv1d(waveform, self.weight, stride=self.stride)
68
+ dim = self.dim // 2 + 1
69
+ real = matrix[:, :dim, :]
70
+ imag = matrix[:, dim:, :]
71
+ spec = torch.complex(real, imag)
72
+ # spec shape: [b, f, t], torch.complex64
73
+
74
+ if self.power is None:
75
+ return spec
76
+ elif self.power == 1:
77
+ mags = torch.sqrt(real**2 + imag**2)
78
+ # phase = torch.atan2(imag, real)
79
+ return mags
80
+ elif self.power == 2:
81
+ power = real**2 + imag**2
82
+ return power
83
+ else:
84
+ raise AssertionError
85
+
86
+
87
+ class ConviSTFT(nn.Module):
88
+
89
+ def __init__(self,
90
+ win_size: int,
91
+ hop_size: int,
92
+ nfft: int = None,
93
+ win_type: str = "hamming",
94
+ requires_grad: bool = False):
95
+ super(ConviSTFT, self).__init__()
96
+ if nfft is None:
97
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
98
+ else:
99
+ self.nfft = nfft
100
+
101
+ kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
102
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
103
+ # weight shape: [f*2, 1, nfft]
104
+ # f = nfft // 2 + 1
105
+
106
+ self.win_size = win_size
107
+ self.hop_size = hop_size
108
+ self.win_type = win_type
109
+
110
+ self.stride = hop_size
111
+ self.dim = self.nfft
112
+
113
+ self.register_buffer("window", window)
114
+ self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
115
+ # window shape: [1, nfft, 1]
116
+ # enframe shape: [nfft, 1, nfft]
117
+
118
+ def forward(self,
119
+ spec: torch.Tensor):
120
+ """
121
+ self.weight shape: [f*2, 1, win_size]
122
+ self.window shape: [1, win_size, 1]
123
+ self.enframe shape: [win_size, 1, win_size]
124
+
125
+ :param spec: torch.Tensor, complex dtype, shape: [b, f, t] (viewed as real [b, f, t, 2] internally).
126
+ :return:
127
+ """
128
+ spec = torch.view_as_real(spec)
129
+ # spec shape: [b, f, t, 2]
130
+ matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
131
+ # matrix shape: [b, f*2, t]
132
+
133
+ waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
134
+ # waveform shape: [b, 1, num_samples]
135
+
136
+ # this is from torch-stft: https://github.com/pseeth/torch-stft
137
+ t = self.window.repeat(1, 1, matrix.size(-1))**2
138
+ # t shape: [1, win_size, t]
139
+ coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
140
+ # coff shape: [1, 1, num_samples]
141
+ waveform = waveform / (coff + 1e-8)
142
+ # waveform = waveform / coff
143
+ return waveform
144
+
145
+ @torch.no_grad()
146
+ def forward_chunk(self,
147
+ spec: torch.Tensor,
148
+ cache_dict: dict = None
149
+ ):
150
+ """
151
+ :param spec: shape: [b, f, t]
152
+ :param cache_dict: dict,
153
+ waveform_cache shape: [b, 1, win_size - hop_size]
154
+ coff_cache shape: [b, 1, win_size - hop_size]
155
+ :return:
156
+ """
157
+ if cache_dict is None:
158
+ cache_dict = defaultdict(lambda: None)
159
+ waveform_cache = cache_dict["waveform_cache"]
160
+ coff_cache = cache_dict["coff_cache"]
161
+
162
+ spec = torch.view_as_real(spec)
163
+ matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
164
+
165
+ waveform_current = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
166
+
167
+ t = self.window.repeat(1, 1, matrix.size(-1))**2
168
+ coff_current = F.conv_transpose1d(t, self.enframe, stride=self.stride)
169
+
170
+ overlap_size = self.win_size - self.hop_size
171
+
172
+ if waveform_cache is not None:
173
+ waveform_current[:, :, :overlap_size] += waveform_cache
174
+ waveform_output = waveform_current[:, :, :self.hop_size]
175
+ new_waveform_cache = waveform_current[:, :, self.hop_size:]
176
+
177
+ if coff_cache is not None:
178
+ coff_current[:, :, :overlap_size] += coff_cache
179
+ coff_output = coff_current[:, :, :self.hop_size]
180
+ new_coff_cache = coff_current[:, :, self.hop_size:]
181
+
182
+ waveform_output = waveform_output / (coff_output + 1e-8)
183
+
184
+ new_cache_dict = {
185
+ "waveform_cache": new_waveform_cache,
186
+ "coff_cache": new_coff_cache,
187
+ }
188
+ return waveform_output, new_cache_dict
189
+
190
+
191
+ def main():
192
+ nfft = 512
193
+ win_size = 512
194
+ hop_size = 256
195
+
196
+ stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
197
+ istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
198
+
199
+ mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
200
+ b, num_samples = mixture.shape
201
+ t = (num_samples - win_size) / hop_size + 1
202
+
203
+ spec = stft.forward(mixture)
204
+ b, f, t = spec.shape
205
+
206
+ # If spec was produced by the STFT above, the two waveform reconstructions below are identical; otherwise they differ.
207
+ # spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
208
+ print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
209
+
210
+ waveform = istft.forward(spec)
211
+ # shape: [batch_size, channels, num_samples]
212
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
213
+ print(waveform[:, :, 300: 302])
214
+
215
+ waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
216
+ for i in range(int(t)):
217
+ begin = i * hop_size
218
+ end = begin + win_size
219
+ sub_spec = spec[:, :, i:i+1]
220
+ sub_waveform = istft.forward(sub_spec)
221
+ # (b, 1, win_size)
222
+ waveform[:, :, begin:end] = sub_waveform
223
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
224
+ print(waveform[:, :, 300: 302])
225
+
226
+ return
227
+
228
+
229
+ def main2():
230
+ nfft = 512
231
+ win_size = 512
232
+ hop_size = 256
233
+
234
+ stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
235
+ istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
236
+
237
+ mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
238
+ b, num_samples = mixture.shape
239
+
240
+ spec = stft.forward(mixture)
241
+ b, f, t = spec.shape
242
+
243
+ # If spec was produced by the STFT above, the two waveform reconstructions below are identical; otherwise they differ.
244
+ spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
245
+ print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
246
+
247
+ waveform = istft.forward(spec)
248
+ # shape: [batch_size, channels, num_samples]
249
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
250
+ print(waveform[:, :, 300: 302])
251
+
252
+ cache_dict = None
253
+ waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
254
+ for i in range(int(t)):
255
+ sub_spec = spec[:, :, i:i+1]
256
+ begin = i * hop_size
257
+
258
+ end = begin + win_size - hop_size
259
+ sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict)
260
+ # end = begin + win_size
261
+ # sub_waveform = istft.forward(sub_spec)
262
+
263
+ waveform[:, :, begin:end] = sub_waveform
264
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
265
+ print(waveform[:, :, 300: 302])
266
+
267
+ return
268
+
269
+
270
+ if __name__ == "__main__":
271
+ main2()
toolbox/webrtcvad/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/webrtcvad/vad.py ADDED
@@ -0,0 +1,249 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import collections
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from scipy.io import wavfile
9
+ import webrtcvad
10
+
11
+ from project_settings import project_path
12
+
13
+
14
+ class Frame(object):
15
+ def __init__(self, signal: np.ndarray, timestamp, duration):
16
+ self.signal = signal
17
+ self.timestamp = timestamp
18
+ self.duration = duration
19
+
20
+
21
+ class WebRTCVad(object):
22
+ def __init__(self,
23
+ agg: int = 3,
24
+ frame_duration_ms: int = 30,
25
+ padding_duration_ms: int = 300,
26
+ silence_duration_threshold: float = 0.3,
27
+ sample_rate: int = 8000
28
+ ):
29
+ self.agg = agg
30
+ self.frame_duration_ms = frame_duration_ms
31
+ self.padding_duration_ms = padding_duration_ms
32
+ self.silence_duration_threshold = silence_duration_threshold
33
+ self.sample_rate = sample_rate
34
+
35
+ self._vad = webrtcvad.Vad(mode=agg)
36
+
37
+ # frames
38
+ self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))
39
+ self.frame_timestamp = 0.0
40
+ self.signal_cache = None
41
+
42
+ # segments
43
+ self.num_padding_frames = int(padding_duration_ms / frame_duration_ms)
44
+ self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
45
+ self.triggered = False
46
+ self.voiced_frames: List[Frame] = list()
47
+ self.segments = list()
48
+
49
+ # vad segments
50
+ self.is_first_segment = True
51
+ self.timestamp_start = 0.0
52
+ self.timestamp_end = 0.0
53
+
54
+ def signal_to_frames(self, signal: np.ndarray):
55
+ frames = list()
56
+
57
+ l = len(signal)
58
+
59
+ duration = (float(self.frame_length) / self.sample_rate)
60
+
61
+ for offset in range(0, l, self.frame_length):
62
+ sub_signal = signal[offset:offset+self.frame_length]
63
+
64
+ frame = Frame(sub_signal, self.frame_timestamp, duration)
65
+ self.frame_timestamp += duration
66
+
67
+ frames.append(frame)
68
+ return frames
69
+
70
+ def segments_generator(self, signal: np.ndarray):
71
+ # signal rounding
72
+ if self.signal_cache is not None:
73
+ signal = np.concatenate([self.signal_cache, signal])
74
+
75
+ rest = len(signal) % self.frame_length
76
+
77
+ if rest == 0:
78
+ self.signal_cache = None
79
+ signal_ = signal
80
+ else:
81
+ self.signal_cache = signal[-rest:]
82
+ signal_ = signal[:-rest]
83
+
84
+ # frames
85
+ frames = self.signal_to_frames(signal_)
86
+
87
+ for frame in frames:
88
+ audio_bytes = bytes(frame.signal)
89
+ is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)
90
+
91
+ if not self.triggered:
92
+ self.ring_buffer.append((frame, is_speech))
93
+ num_voiced = len([f for f, speech in self.ring_buffer if speech])
94
+
95
+ if num_voiced > 0.9 * self.ring_buffer.maxlen:
96
+ self.triggered = True
97
+
98
+ for f, _ in self.ring_buffer:
99
+ self.voiced_frames.append(f)
100
+ self.ring_buffer.clear()
101
+ else:
102
+ self.voiced_frames.append(frame)
103
+ self.ring_buffer.append((frame, is_speech))
104
+ num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
105
+ if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
106
+ self.triggered = False
107
+ segment = [
108
+ np.concatenate([f.signal for f in self.voiced_frames]),
109
+ self.voiced_frames[0].timestamp,
110
+ self.voiced_frames[-1].timestamp
111
+ ]
112
+ yield segment
113
+ self.ring_buffer.clear()
114
+ self.voiced_frames: List[Frame] = list()
115
+
116
+ def vad_segments_generator(self, segments_generator):
117
+ segments = list(segments_generator)
118
+
119
+ for i, segment in enumerate(segments):
120
+ start = round(segment[1], 4)
121
+ end = round(segment[2], 4)
122
+
123
+ if self.is_first_segment:
124
+ self.timestamp_start = start
125
+ self.timestamp_end = end
126
+ self.is_first_segment = False
127
+ continue
128
+
129
+ if self.timestamp_start:
130
+ sil_duration = start - self.timestamp_end
131
+ if sil_duration > self.silence_duration_threshold:
132
+ vad_segment = [self.timestamp_start, self.timestamp_end]
133
+ yield vad_segment
134
+
135
+ self.timestamp_start = start
136
+ self.timestamp_end = end
137
+ else:
138
+ self.timestamp_end = end
139
+
140
+ def vad(self, signal: np.ndarray) -> List[list]:
141
+ segments = self.segments_generator(signal)
142
+ vad_segments = self.vad_segments_generator(segments)
143
+ vad_segments = list(vad_segments)
144
+ return vad_segments
145
+
146
+ def last_vad_segments(self) -> List[list]:
147
+ # last segments
148
+ if len(self.voiced_frames) == 0:
149
+ segments = []
150
+ else:
151
+ segment = [
152
+ np.concatenate([f.signal for f in self.voiced_frames]),
153
+ self.voiced_frames[0].timestamp,
154
+ self.voiced_frames[-1].timestamp
155
+ ]
156
+ segments = [segment]
157
+
158
+ # last vad segments
159
+ vad_segments = self.vad_segments_generator(segments)
160
+ vad_segments = list(vad_segments)
161
+
162
+ vad_segments = vad_segments + [[self.timestamp_start, self.timestamp_end]]
163
+ return vad_segments
164
+
165
+
166
+ def get_args():
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument(
169
+ "--wav_file",
170
+ # default=(project_path / "data/0eeaef67-ea59-4f2d-a5b8-b70c813fd45c.wav").as_posix(),
171
+ default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
172
+ # default=(project_path / "data/8cbad66f-2c4e-43c2-ad11-ad95bab8bc15.wav").as_posix(),
173
+ type=str,
174
+ )
175
+ parser.add_argument(
176
+ "--agg",
177
+ default=3,
178
+ type=int,
179
+ help="The level of aggressiveness of the VAD: [0-3]'"
180
+ )
181
+ parser.add_argument(
182
+ "--frame_duration_ms",
183
+ default=30,
184
+ type=int,
185
+ )
186
+ parser.add_argument(
187
+ "--padding_duration_ms",
188
+ default=300,
189
+ type=int,
190
+ )
191
+ parser.add_argument(
192
+ "--silence_duration_threshold",
193
+ default=0.3,
194
+ type=float,
195
+ help="minimum silence duration, in seconds."
196
+ )
197
+ args = parser.parse_args()
198
+ return args
199
+
200
+
201
+ def main():
202
+ import matplotlib.pyplot as plt
203
+
204
+ args = get_args()
205
+
206
+ SAMPLE_RATE = 8000
207
+
208
+ w_vad = WebRTCVad(
209
+ agg=args.agg,
210
+ frame_duration_ms=args.frame_duration_ms,
211
+ padding_duration_ms=args.padding_duration_ms,
212
+ silence_duration_threshold=args.silence_duration_threshold,
213
+ sample_rate=SAMPLE_RATE,
214
+ )
215
+
216
+ sample_rate, signal = wavfile.read(args.wav_file)
217
+ if SAMPLE_RATE != sample_rate:
218
+ raise AssertionError
219
+
220
+ vad_segments = list()
221
+
222
+ segments = w_vad.vad(signal)
223
+ vad_segments += segments
224
+ for segment in segments:
225
+ print(segment)
226
+
227
+ # last vad segment
228
+ segments = w_vad.last_vad_segments()
229
+ vad_segments += segments
230
+ for segment in segments:
231
+ print(segment)
232
+
233
+ # plot
234
+ time = np.arange(0, len(signal)) / sample_rate
235
+ plt.figure(figsize=(12, 5))
236
+ plt.plot(time, signal / 32768, color='b')
237
+ for start, end in vad_segments:
238
+ # start -= (w_vad.padding_duration_ms - 2*w_vad.frame_duration_ms) / 1000
239
+ end -= (w_vad.padding_duration_ms - 0*w_vad.frame_duration_ms) / 1000
240
+
241
+ plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='speech start')  # mark the start of a speech segment
242
+ plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='speech end')  # mark the end of a speech segment
243
+
244
+ plt.show()
245
+ return
246
+
247
+
248
+ if __name__ == '__main__':
249
+ main()
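A streaming-style sketch of the same class: because WebRTCVad keeps signal_cache, the ring buffer, and the segment timestamps as state, the audio can be fed in chunks rather than all at once; the wav path and chunk size below are illustrative:

    from scipy.io import wavfile
    from toolbox.webrtcvad.vad import WebRTCVad

    w_vad = WebRTCVad(sample_rate=8000)
    sample_rate, signal = wavfile.read("call.wav")           # int16 samples, assumed to be 8000 Hz
    vad_segments = list()
    for offset in range(0, len(signal), 2000):               # 0.25 second chunks
        vad_segments += w_vad.vad(signal[offset:offset + 2000])
    vad_segments += w_vad.last_vad_segments()                # flush the final open segment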