update

- .gitignore +2 -0
- Dockerfile +3 -1
- download_sound_models.py +53 -0
- examples/fsmn_vad/step_1_prepare_data.py +156 -0
- examples/silero_vad_by_webrtcvad/run.sh +7 -7
- examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +98 -23
- examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py +91 -25
- examples/silero_vad_by_webrtcvad/step_3_check_vad.py +68 -0
- examples/silero_vad_by_webrtcvad/{step_3_train_model.py → step_4_train_model.py} +39 -13
- install.sh +28 -6
- toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py +240 -0
- toolbox/torch/utils/data/vocabulary.py +211 -0
- toolbox/torchaudio/metrics/vad_metrics/vad_f1_score.py +60 -0
- toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py +169 -0
- toolbox/vad/__init__.py +6 -0
- toolbox/vad/vad.py +450 -0
- toolbox/webrtcvad/vad.py +5 -3
.gitignore
CHANGED
@@ -9,6 +9,7 @@
 **/log/
 **/logs/
 **/__pycache__/
+**/serialization_dir/
 
 /data/
 /docs/
@@ -21,3 +22,4 @@
 
 **/*.wav
 **/*.xlsx
+**/*.jsonl
Dockerfile
CHANGED
@@ -10,7 +10,9 @@ RUN apt-get install -y ffmpeg build-essential
 RUN pip install --upgrade pip
 RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
-RUN
+RUN pip install --upgrade pip
+
+RUN bash install.sh --stage 1 --stop_stage 2 --system_version centos
 
 USER user
 
download_sound_models.py
ADDED
@@ -0,0 +1,53 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+from project_settings import environment, project_path
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--trained_model_dir",
+        default=(project_path / "trained_models").as_posix(),
+        type=str,
+    )
+    parser.add_argument(
+        "--models_repo_id",
+        default="qgyd2021/vm_sound_classification",
+        type=str,
+    )
+    parser.add_argument(
+        "--model_pattern",
+        default="sound-*-ch32.zip",
+        type=str,
+    )
+    parser.add_argument(
+        "--hf_token",
+        default=environment.get("hf_token"),
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    trained_model_dir = Path(args.trained_model_dir)
+    trained_model_dir.mkdir(parents=True, exist_ok=True)
+
+    _ = snapshot_download(
+        repo_id=args.models_repo_id,
+        allow_patterns=[args.model_pattern],
+        local_dir=trained_model_dir.as_posix(),
+        token=args.hf_token,
+    )
+    return
+
+
+if __name__ == '__main__':
+    main()
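
For reference only (not part of the commit): the script above is a thin wrapper around huggingface_hub's snapshot_download, fetching only files that match the glob pattern. A minimal standalone sketch of the same call, using the script's default repo id and pattern and a local "trained_models" directory:

from huggingface_hub import snapshot_download

# Download only the zipped sound-classification checkpoints into ./trained_models.
# A token is only needed for private repos; None works for public ones.
snapshot_download(
    repo_id="qgyd2021/vm_sound_classification",
    allow_patterns=["sound-*-ch32.zip"],
    local_dir="trained_models",
    token=None,
)
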
examples/fsmn_vad/step_1_prepare_data.py
ADDED
@@ -0,0 +1,156 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+from pathlib import Path
+import random
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import numpy as np
+from tqdm import tqdm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--noise_dir",
+        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        type=str
+    )
+    parser.add_argument(
+        "--speech_dir",
+        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+        type=str
+    )
+
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+    parser.add_argument("--duration", default=6.0, type=float)
+    parser.add_argument("--min_snr_db", default=-10, type=float)
+    parser.add_argument("--max_snr_db", default=20, type=float)
+
+    parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+    parser.add_argument("--max_count", default=-1, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate: int = 8000, max_epoch: int = 20000):
+    data_dir = Path(data_dir)
+    for epoch_idx in range(max_epoch):
+        for filename in data_dir.glob("**/*.wav"):
+            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+            if raw_duration < duration:
+                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                continue
+            if signal.ndim != 1:
+                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+            signal_length = len(signal)
+            win_size = int(duration * sample_rate)
+            for begin in range(0, signal_length - win_size, win_size):
+                if np.sum(signal[begin: begin+win_size]) == 0:
+                    continue
+                row = {
+                    "epoch_idx": epoch_idx,
+                    "filename": filename.as_posix(),
+                    "raw_duration": round(raw_duration, 4),
+                    "offset": round(begin / sample_rate, 4),
+                    "duration": round(duration, 4),
+                }
+                yield row
+
+
+def main():
+    args = get_args()
+
+    noise_dir = Path(args.noise_dir)
+    speech_dir = Path(args.speech_dir)
+
+    train_dataset = Path(args.train_dataset)
+    valid_dataset = Path(args.valid_dataset)
+    train_dataset.parent.mkdir(parents=True, exist_ok=True)
+    valid_dataset.parent.mkdir(parents=True, exist_ok=True)
+
+    noise_generator = target_second_signal_generator(
+        noise_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=100000,
+    )
+    speech_generator = target_second_signal_generator(
+        speech_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=1,
+    )
+
+    count = 0
+    process_bar = tqdm(desc="build dataset jsonl")
+    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+        for noise, speech in zip(noise_generator, speech_generator):
+            if count >= args.max_count > 0:
+                break
+
+            # row
+            noise_filename = noise["filename"]
+            noise_raw_duration = noise["raw_duration"]
+            noise_offset = noise["offset"]
+            noise_duration = noise["duration"]
+
+            speech_filename = speech["filename"]
+            speech_raw_duration = speech["raw_duration"]
+            speech_offset = speech["offset"]
+            speech_duration = speech["duration"]
+
+            # row
+            random1 = random.random()
+            random2 = random.random()
+
+            row = {
+                "count": count,
+
+                "noise_filename": noise_filename,
+                "noise_raw_duration": noise_raw_duration,
+                "noise_offset": noise_offset,
+                "noise_duration": noise_duration,
+
+                "speech_filename": speech_filename,
+                "speech_raw_duration": speech_raw_duration,
+                "speech_offset": speech_offset,
+                "speech_duration": speech_duration,
+
+                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+                "random1": random1,
+            }
+            row = json.dumps(row, ensure_ascii=False)
+            if random2 < (1 / 300 / 1):
+                fvalid.write(f"{row}\n")
+            else:
+                ftrain.write(f"{row}\n")
+
+            count += 1
+            duration_seconds = count * args.duration
+            duration_hours = duration_seconds / 3600
+
+            process_bar.update(n=1)
+            process_bar.set_postfix({
+                "duration_hours": round(duration_hours, 4),
+            })
+
+    return
+
+
+if __name__ == "__main__":
+    main()
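
As an illustration of the windowing in target_second_signal_generator above: it slides a non-overlapping window of --duration seconds over each wav file and records the offset in seconds, dropping the incomplete tail. A small sketch with made-up numbers:

import numpy as np

# Non-overlapping 6 s windows over a fake 20.5 s signal at 8 kHz,
# emitted as (offset_seconds, duration_seconds); values are illustrative only.
sample_rate = 8000
duration = 6.0
signal = np.random.uniform(-1.0, 1.0, size=int(20.5 * sample_rate))

win_size = int(duration * sample_rate)
windows = [
    (round(begin / sample_rate, 4), duration)
    for begin in range(0, len(signal) - win_size, win_size)
]
print(windows)  # [(0.0, 6.0), (6.0, 6.0), (12.0, 6.0)] -> the trailing 2.5 s is dropped
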
examples/silero_vad_by_webrtcvad/run.sh
CHANGED
@@ -122,7 +122,7 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   $verbose && echo "stage 3: train model"
   cd "${work_dir}" || exit 1
-  python3
+  python3 step_4_train_model.py \
   --train_dataset "${train_vad_dataset}" \
   --valid_dataset "${valid_vad_dataset}" \
   --serialization_dir "${file_dir}" \
@@ -131,8 +131,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 fi
 
 
-if [ ${stage} -le
-  $verbose && echo "stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  $verbose && echo "stage 4: test model"
   cd "${work_dir}" || exit 1
   python3 step_3_evaluation.py \
   --valid_dataset "${valid_dataset}" \
@@ -143,8 +143,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 fi
 
 
-if [ ${stage} -le
-  $verbose && echo "stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  $verbose && echo "stage 5: collect files"
   cd "${work_dir}" || exit 1
 
   mkdir -p ${final_model_dir}
@@ -165,8 +165,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
 fi
 
 
-if [ ${stage} -le
-  $verbose && echo "stage
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  $verbose && echo "stage 6: clear file_dir"
   cd "${work_dir}" || exit 1
 
   rm -rf "${file_dir}";
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py
CHANGED
@@ -6,6 +6,7 @@ import os
 from pathlib import Path
 import random
 import sys
+import time
 
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
@@ -19,19 +20,21 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--noise_dir",
-        default=r"
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
         type=str
     )
     parser.add_argument(
         "--speech_dir",
-        default=r"
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
         type=str
     )
 
     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
 
-    parser.add_argument("--duration", default=
+    parser.add_argument("--duration", default=8.0, type=float)
+    parser.add_argument("--min_speech_duration", default=6.0, type=float)
+    parser.add_argument("--max_speech_duration", default=8.0, type=float)
     parser.add_argument("--min_snr_db", default=-10, type=float)
     parser.add_argument("--max_snr_db", default=20, type=float)
 
@@ -43,21 +46,90 @@ def get_args():
     return args
 
 
-def
+def target_second_noise_signal_generator(data_dir: str,
+                                         duration: int = 4,
+                                         sample_rate: int = 8000, max_epoch: int = 20000):
+    noise_list = list()
+    wait_duration = duration
+
+    data_dir = Path(data_dir)
+    for epoch_idx in range(max_epoch):
+        for filename in data_dir.glob("**/*.wav"):
+            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+
+            if signal.ndim != 1:
+                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+            offset = 0.
+            rest_duration = raw_duration
+
+            for _ in range(1000):
+                if rest_duration <= 0:
+                    break
+                if rest_duration <= wait_duration:
+                    noise_list.append({
+                        "epoch_idx": epoch_idx,
+                        "filename": filename.as_posix(),
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": round(offset, 4),
+                        "duration": None,
+                        "duration_": round(rest_duration, 4),
+                    })
+                    wait_duration -= rest_duration
+                    offset = 0
+                    rest_duration = 0
+                elif rest_duration > wait_duration:
+                    noise_list.append({
+                        "epoch_idx": epoch_idx,
+                        "filename": filename.as_posix(),
+                        "raw_duration": round(raw_duration, 4),
+                        "offset": round(offset, 4),
+                        "duration": round(wait_duration, 4),
+                        "duration_": round(wait_duration, 4),
+                    })
+                    offset += wait_duration
+                    rest_duration -= wait_duration
+                    wait_duration = 0
+                else:
+                    raise AssertionError
+
+                if wait_duration <= 0:
+                    yield noise_list
+                    noise_list = list()
+                    wait_duration = duration
+
+
+def target_second_speech_signal_generator(data_dir: str,
+                                          min_duration: int = 4,
+                                          max_duration: int = 6,
+                                          sample_rate: int = 8000, max_epoch: int = 1):
     data_dir = Path(data_dir)
     for epoch_idx in range(max_epoch):
         for filename in data_dir.glob("**/*.wav"):
             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
 
-            if raw_duration < duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
             if signal.ndim != 1:
                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
 
+            if raw_duration < min_duration:
+                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                continue
+
+            if raw_duration < max_duration:
+                row = {
+                    "epoch_idx": epoch_idx,
+                    "filename": filename.as_posix(),
+                    "raw_duration": round(raw_duration, 4),
+                    "offset": 0.,
+                    "duration": round(raw_duration, 4),
+                }
+                yield row
+
             signal_length = len(signal)
-            win_size = int(
+            win_size = int(max_duration * sample_rate)
             for begin in range(0, signal_length - win_size, win_size):
                 if np.sum(signal[begin: begin+win_size]) == 0:
                     continue
@@ -66,7 +138,7 @@ def target_second_signal_generator(data_dir: str, duration: int = 6, sample_rate
                     "filename": filename.as_posix(),
                     "raw_duration": round(raw_duration, 4),
                     "offset": round(begin / sample_rate, 4),
-                    "duration": round(
+                    "duration": round(max_duration, 4),
                 }
                 yield row
 
@@ -82,15 +154,16 @@ def main():
    train_dataset.parent.mkdir(parents=True, exist_ok=True)
    valid_dataset.parent.mkdir(parents=True, exist_ok=True)
 
-    noise_generator =
+    noise_generator = target_second_noise_signal_generator(
         noise_dir.as_posix(),
         duration=args.duration,
         sample_rate=args.target_sample_rate,
         max_epoch=100000,
     )
-    speech_generator =
+    speech_generator = target_second_speech_signal_generator(
         speech_dir.as_posix(),
-
+        min_duration=args.min_speech_duration,
+        max_duration=args.max_speech_duration,
         sample_rate=args.target_sample_rate,
         max_epoch=1,
     )
@@ -98,21 +171,26 @@ def main():
     count = 0
     process_bar = tqdm(desc="build dataset jsonl")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-        for
+        for speech, noise_list in zip(speech_generator, noise_generator):
             if count >= args.max_count > 0:
                 break
 
             # row
-            noise_filename = noise["filename"]
-            noise_raw_duration = noise["raw_duration"]
-            noise_offset = noise["offset"]
-            noise_duration = noise["duration"]
-
             speech_filename = speech["filename"]
             speech_raw_duration = speech["raw_duration"]
             speech_offset = speech["offset"]
             speech_duration = speech["duration"]
 
+            noise_list = [
+                {
+                    "filename": noise["filename"],
+                    "raw_duration": noise["raw_duration"],
+                    "offset": noise["offset"],
+                    "duration": noise["duration"],
+                }
+                for noise in noise_list
+            ]
+
             # row
             random1 = random.random()
             random2 = random.random()
@@ -120,16 +198,13 @@ def main():
             row = {
                 "count": count,
 
-                "noise_filename": noise_filename,
-                "noise_raw_duration": noise_raw_duration,
-                "noise_offset": noise_offset,
-                "noise_duration": noise_duration,
-
                 "speech_filename": speech_filename,
                 "speech_raw_duration": speech_raw_duration,
                 "speech_offset": speech_offset,
                 "speech_duration": speech_duration,
 
+                "noise_list": noise_list,
+
                 "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
 
                 "random1": random1,
examples/silero_vad_by_webrtcvad/step_2_make_vad_segments.py
CHANGED
@@ -12,7 +12,8 @@ import librosa
 import numpy as np
 from tqdm import tqdm
 
-from
+from project_settings import project_path
+from toolbox.vad.vad import WebRTCVoiceClassifier, SileroVoiceClassifier, CCSoundsClassifier, RingVad
 
 
 def get_args():
@@ -24,15 +25,19 @@ def get_args():
     parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
     parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
 
-    parser.add_argument("--duration", default=
+    parser.add_argument("--duration", default=8.0, type=float)
     parser.add_argument("--expected_sample_rate", default=8000, type=int)
 
-
-
-
-
-
-
+    parser.add_argument(
+        "--silero_model_path",
+        default=(project_path / "trained_models/silero_vad.jit").as_posix(),
+        type=str,
+    )
+    parser.add_argument(
+        "--cc_sounds_model_path",
+        default=(project_path / "trained_models/sound-2-ch32.zip").as_posix(),
+        type=str,
+    )
     args = parser.parse_args()
     return args
 
@@ -40,17 +45,58 @@ def get_args():
 def main():
     args = get_args()
 
-
-
-
-
-
+    # webrtcvad
+    # model = SileroVoiceClassifier(model_path=args.silero_model_path, sample_rate=args.expected_sample_rate)
+    # w_vad = RingVad(
+    #     model=model,
+    #     start_ring_rate=0.2,
+    #     end_ring_rate=0.1,
+    #     frame_size_ms=32,
+    #     frame_step_ms=32,
+    #     padding_length_ms=320,
+    #     max_silence_length_ms=320,
+    #     max_speech_length_s=100,
+    #     min_speech_length_s=0.1,
+    #     sample_rate=args.expected_sample_rate,
+    # )
+
+    # webrtcvad
+    model = WebRTCVoiceClassifier(agg=3, sample_rate=args.expected_sample_rate)
+    w_vad = RingVad(
+        model=model,
+        start_ring_rate=0.9,
+        end_ring_rate=0.1,
+        frame_size_ms=30,
+        frame_step_ms=30,
+        padding_length_ms=90,
+        max_silence_length_ms=100,
+        max_speech_length_s=100,
+        min_speech_length_s=0.1,
         sample_rate=args.expected_sample_rate,
     )
 
+    # cc sounds
+    # model = CCSoundsClassifier(model_path=args.cc_sounds_model_path, sample_rate=args.expected_sample_rate)
+    # w_vad = RingVad(
+    #     model=model,
+    #     start_ring_rate=0.5,
+    #     end_ring_rate=0.3,
+    #     frame_size_ms=300,
+    #     frame_step_ms=300,
+    #     padding_length_ms=300,
+    #     max_silence_length_ms=100,
+    #     max_speech_length_s=100,
+    #     min_speech_length_s=0.1,
+    #     sample_rate=args.expected_sample_rate,
+    # )
+
     # valid
+    va_duration = 0
+    raw_duration = 0
+    use_duration = 0
+
     count = 0
-
+    process_bar_valid = tqdm(desc="process valid dataset jsonl")
     with (open(args.valid_dataset, "r", encoding="utf-8") as fvalid,
           open(args.valid_vad_dataset, "w", encoding="utf-8") as fvalid_vad):
         for row in fvalid:
@@ -81,18 +127,30 @@ def main():
             row = json.dumps(row, ensure_ascii=False)
             fvalid_vad.write(f"{row}\n")
 
+            va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
+            raw_duration += speech_duration
+            use_duration += args.duration
+
             count += 1
-            duration_seconds = count * args.duration
-            duration_hours = duration_seconds / 3600
 
-
-
-
+            va_rate = va_duration / use_duration
+            va_raw_rate = va_duration / raw_duration
+            use_duration_hours = use_duration / 3600
+
+            process_bar_valid.update(n=1)
+            process_bar_valid.set_postfix({
+                "va_rate": round(va_rate, 4),
+                "va_raw_rate": round(va_raw_rate, 4),
+                "duration_hours": round(use_duration_hours, 4),
             })
 
     # train
+    va_duration = 0
+    raw_duration = 0
+    use_duration = 0
+
     count = 0
-
+    process_bar_train = tqdm(desc="process train dataset jsonl")
     with (open(args.train_dataset, "r", encoding="utf-8") as ftrain,
           open(args.train_vad_dataset, "w", encoding="utf-8") as ftrain_vad):
         for row in ftrain:
@@ -123,13 +181,21 @@ def main():
             row = json.dumps(row, ensure_ascii=False)
             ftrain_vad.write(f"{row}\n")
 
+            va_duration += sum([vad_segment[1] - vad_segment[0] for vad_segment in vad_segments])
+            raw_duration += speech_duration
+            use_duration += args.duration
+
             count += 1
-            duration_seconds = count * args.duration
-            duration_hours = duration_seconds / 3600
 
-
-
-
+            va_rate = va_duration / use_duration
+            va_raw_rate = va_duration / raw_duration
+            use_duration_hours = use_duration / 3600
+
+            process_bar_train.update(n=1)
+            process_bar_train.set_postfix({
+                "va_rate": round(va_rate, 4),
+                "va_raw_rate": round(va_raw_rate, 4),
+                "duration_hours": round(use_duration_hours, 4),
            })
 
    return
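
The progress statistics added above track three running totals and two ratios. A worked example of the arithmetic with made-up numbers (one 8 s window, a 6.2 s speech clip, two detected segments):

vad_segments = [[0.5, 2.8], [3.4, 6.0]]   # seconds, within the padded window
speech_duration = 6.2                     # raw speech duration placed into the window
use_duration = 8.0                        # --duration, the fixed window length

va_duration = sum(end - start for start, end in vad_segments)   # 2.3 + 2.6 = 4.9 s of detected voice
va_rate = va_duration / use_duration         # 4.9 / 8.0 ~= 0.6125, voice share of the window
va_raw_rate = va_duration / speech_duration  # 4.9 / 6.2 ~= 0.7903, voice share of the speech clip
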
examples/silero_vad_by_webrtcvad/step_3_check_vad.py
ADDED
@@ -0,0 +1,68 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.io import wavfile
+from tqdm import tqdm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--train_vad_dataset", default="train-vad.jsonl", type=str)
+    parser.add_argument("--valid_vad_dataset", default="valid-vad.jsonl", type=str)
+
+    parser.add_argument("--duration", default=8.0, type=float)
+    parser.add_argument("--expected_sample_rate", default=8000, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_args()
+
+    SAMPLE_RATE = 8000
+
+    with open(args.train_vad_dataset, "r", encoding="utf-8") as f:
+        for row in f:
+            row = json.loads(row)
+
+            speech_filename = row["speech_filename"]
+            speech_offset = row["speech_offset"]
+            speech_duration = row["speech_duration"]
+
+            vad_segments = row["vad_segments"]
+
+            print(f"speech_filename: {speech_filename}")
+            signal, sample_rate = librosa.load(
+                speech_filename,
+                sr=SAMPLE_RATE,
+                offset=speech_offset,
+                duration=speech_duration,
+            )
+
+            # plot
+            time = np.arange(0, len(signal)) / sample_rate
+            plt.figure(figsize=(12, 5))
+            plt.plot(time, signal, color='b')
+            for start, end in vad_segments:
+                plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start endpoint')  # mark the segment start
+                plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end endpoint')  # mark the segment end
+
+            plt.show()
+
+    return
+
+
+if __name__ == "__main__":
+    main()
examples/silero_vad_by_webrtcvad/{step_3_train_model.py → step_4_train_model.py}
RENAMED
@@ -22,25 +22,26 @@ from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
-from toolbox.torch.utils.data.dataset.
+from toolbox.torch.utils.data.dataset.vad_padding_jsonl_dataset import VadPaddingJsonlDataset
 from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
 from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadModel, SileroVadPretrainedModel
 from toolbox.torchaudio.losses.vad_loss.base_vad_loss import BaseVadLoss
 from toolbox.torchaudio.losses.bce_loss import BCELoss
 from toolbox.torchaudio.losses.dice_loss import DiceLoss
 from toolbox.torchaudio.metrics.vad_metrics.vad_accuracy import VadAccuracy
+from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score
 
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+    parser.add_argument("--train_dataset", default="train-vad.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
 
     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
     parser.add_argument("--patience", default=30, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
-    parser.add_argument("--config_file", default="config.yaml", type=str)
+    parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
 
     args = parser.parse_args()
     return args
@@ -116,7 +117,7 @@
     logger.info(f"GPU available count: {n_gpu}; device: {device}")
 
     # datasets
-    train_dataset =
+    train_dataset = VadPaddingJsonlDataset(
         jsonl_file=args.train_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
@@ -124,7 +125,7 @@
         max_snr_db=config.max_snr_db,
         # skip=225000,
     )
-    valid_dataset =
+    valid_dataset = VadPaddingJsonlDataset(
         jsonl_file=args.valid_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
@@ -205,6 +206,7 @@
     dice_loss_fn = DiceLoss(reduction="mean").to(device)
 
     vad_accuracy_metrics_fn = VadAccuracy(threshold=0.5)
+    vad_f1_score_metrics_fn = VadF1Score(threshold=0.5)
 
     # training loop
 
@@ -213,6 +215,11 @@
     average_bce_loss = 1000000000
     average_dice_loss = 1000000000
 
+    accuracy = -1
+    f1 = -1
+    precision = -1
+    recall = -1
+
     model_list = list()
     best_epoch_idx = None
     best_step_idx = None
@@ -230,6 +237,7 @@
         # train
         model.train()
         vad_accuracy_metrics_fn.reset()
+        vad_f1_score_metrics_fn.reset()
 
         total_loss = 0.
         total_bce_loss = 0.
@@ -259,6 +267,7 @@
                 continue
 
             vad_accuracy_metrics_fn.__call__(probs, targets)
+            vad_f1_score_metrics_fn.__call__(probs, targets)
 
             optimizer.zero_grad()
             loss.backward()
@@ -277,14 +286,21 @@
 
             metrics = vad_accuracy_metrics_fn.get_metric()
             accuracy = metrics["accuracy"]
+            metrics = vad_f1_score_metrics_fn.get_metric()
+            f1 = metrics["f1"]
+            precision = metrics["precision"]
+            recall = metrics["recall"]
 
             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
                 "lr": lr_scheduler.get_last_lr()[0],
                 "loss": average_loss,
-                "
-                "
+                "bce_loss": average_bce_loss,
+                "dice_loss": average_dice_loss,
                 "accuracy": accuracy,
+                "f1": f1,
+                "precision": precision,
+                "recall": recall,
             })
 
             # evaluation
@@ -295,6 +311,7 @@
 
                 model.eval()
                 vad_accuracy_metrics_fn.reset()
+                vad_f1_score_metrics_fn.reset()
 
                 total_loss = 0.
                 total_bce_loss = 0.
@@ -324,6 +341,7 @@
                     continue
 
                     vad_accuracy_metrics_fn.__call__(probs, targets)
+                    vad_f1_score_metrics_fn.__call__(probs, targets)
 
                     total_loss += loss.item()
                     total_bce_loss += bce_loss.item()
@@ -336,18 +354,26 @@
 
                     metrics = vad_accuracy_metrics_fn.get_metric()
                     accuracy = metrics["accuracy"]
+                    metrics = vad_f1_score_metrics_fn.get_metric()
+                    f1 = metrics["f1"]
+                    precision = metrics["precision"]
+                    recall = metrics["recall"]
 
                     progress_bar_eval.update(1)
                     progress_bar_eval.set_postfix({
                         "lr": lr_scheduler.get_last_lr()[0],
                         "loss": average_loss,
-                        "
-                        "
+                        "bce_loss": average_bce_loss,
+                        "dice_loss": average_dice_loss,
                         "accuracy": accuracy,
+                        "f1": f1,
+                        "precision": precision,
+                        "recall": recall,
                     })
 
                 model.train()
                 vad_accuracy_metrics_fn.reset()
+                vad_f1_score_metrics_fn.reset()
 
                 total_loss = 0.
                 total_bce_loss = 0.
@@ -377,12 +403,12 @@
                 if best_metric is None:
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
-                    best_metric =
-                elif
+                    best_metric = f1
+                elif f1 >= best_metric:
                     # great is better.
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
-                    best_metric =
+                    best_metric = f1
                 else:
                     pass
 
install.sh
CHANGED
@@ -1,9 +1,9 @@
 #!/usr/bin/env bash
 
-# bash install.sh --stage
+# bash install.sh --stage 1 --stop_stage 2 --system_version centos
 
 
-python_version=3.12.
+python_version=3.12.8
 system_version="centos";
 
 verbose=true;
@@ -41,20 +41,42 @@ while true; do
 done
 
 work_dir="$(pwd)"
+trained_models_dir="$(pwd)/trained_models"
+
+mkdir -p "${trained_models_dir}"
 
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1:
+  $verbose && echo "stage 1: download sound models"
   cd "${work_dir}" || exit 1;
 
-
+  python download_sound_models.py
+
 fi
 
 
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2:
+  $verbose && echo "stage 2: download silero vad model"
+  cd "${trained_models_dir}" || exit 1;
+
+  wget https://github.com/snakers4/silero-vad/blob/master/src/silero_vad/data/silero_vad.jit
+
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  $verbose && echo "stage 3: install python"
+  cd "${work_dir}" || exit 1;
+
+  sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
+fi
+
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  $verbose && echo "stage 4: create virtualenv"
 
-  # /usr/local/python-3.
+  # /usr/local/python-3.9.9/bin/pip3 install virtualenv
+  # /usr/local/python-3.9.9/bin/virtualenv cc_vad
   # source /data/local/bin/cc_vad/bin/activate
   /usr/local/python-${python_version}/bin/pip3 install virtualenv
   mkdir -p /data/local/bin
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py
ADDED
@@ -0,0 +1,240 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import json
+import random
+from typing import List
+
+import librosa
+import numpy as np
+import torch
+from torch.utils.data import Dataset, IterableDataset
+
+
+class VadPaddingJsonlDataset(IterableDataset):
+    def __init__(self,
+                 jsonl_file: str,
+                 expected_sample_rate: int,
+                 resample: bool = False,
+                 max_wave_value: float = 1.0,
+                 buffer_size: int = 1000,
+                 min_snr_db: float = None,
+                 max_snr_db: float = None,
+                 speech_target_duration: float = 8.0,
+                 eps: float = 1e-8,
+                 skip: int = 0,
+                 ):
+        self.jsonl_file = jsonl_file
+        self.expected_sample_rate = expected_sample_rate
+        self.resample = resample
+        self.max_wave_value = max_wave_value
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
+        self.speech_target_duration = speech_target_duration
+        self.eps = eps
+        self.skip = skip
+
+        self.buffer_size = buffer_size
+        self.buffer_samples: List[dict] = list()
+
+    def __iter__(self):
+        self.buffer_samples = list()
+
+        iterable_source = self.iterable_source()
+
+        try:
+            for _ in range(self.skip):
+                next(iterable_source)
+        except StopIteration:
+            pass
+
+        # initially fill the buffer
+        try:
+            for _ in range(self.buffer_size):
+                self.buffer_samples.append(next(iterable_source))
+        except StopIteration:
+            pass
+
+        # dynamic replacement logic
+        while True:
+            try:
+                item = next(iterable_source)
+                # randomly replace an element of the buffer
+                replace_idx = random.randint(0, len(self.buffer_samples) - 1)
+                sample = self.buffer_samples[replace_idx]
+                self.buffer_samples[replace_idx] = item
+                yield self.convert_sample(sample)
+            except StopIteration:
+                break
+
+        # flush the remaining elements
+        random.shuffle(self.buffer_samples)
+        for sample in self.buffer_samples:
+            yield self.convert_sample(sample)
+
+    def iterable_source(self):
+        last_sample = None
+        with open(self.jsonl_file, "r", encoding="utf-8") as f:
+            for row in f:
+                row = json.loads(row)
+
+                speech_filename = row["speech_filename"]
+                speech_raw_duration = row["speech_raw_duration"]
+                speech_offset = row["speech_offset"]
+                speech_duration = row["speech_duration"]
+
+                noise_list = row["noise_list"]
+                noise_list = [
+                    {
+                        "filename": noise["filename"],
+                        "raw_duration": noise["raw_duration"],
+                        "offset": noise["offset"],
+                        "duration": noise["duration"],
+                    }
+                    for noise in noise_list
+                ]
+
+                if self.min_snr_db is None or self.max_snr_db is None:
+                    snr_db = row["snr_db"]
+                else:
+                    snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
+
+                vad_segments = row["vad_segments"]
+
+                sample = {
+                    "speech_filename": speech_filename,
+                    "speech_raw_duration": speech_raw_duration,
+                    "speech_offset": speech_offset,
+                    "speech_duration": speech_duration,
+
+                    "noise_list": noise_list,
+
+                    "snr_db": snr_db,
+
+                    "vad_segments": vad_segments,
+                }
+                if last_sample is None:
+                    last_sample = sample
+                    continue
+                yield sample
+            yield last_sample
+
+    def convert_sample(self, sample: dict):
+        speech_filename = sample["speech_filename"]
+        speech_offset = sample["speech_offset"]
+        speech_duration = sample["speech_duration"]
+
+        noise_list = sample["noise_list"]
+
+        snr_db = sample["snr_db"]
+
+        vad_segments = sample["vad_segments"]
+
+        speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration)
+        speech_wave_np = speech_wave.numpy()
+        speech_wave_np, left_pad_duration, _ = self.pad_waveform(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
+        speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
+
+        noise_wave_list = list()
+        for noise in noise_list:
+            filename = noise["filename"]
+            offset = noise["offset"]
+            duration = noise["duration"]
+            noise_wave_: torch.Tensor = self.filename_to_waveform(filename, offset, duration)
+            noise_wave_list.append(noise_wave_)
+        noise_wave = torch.cat(noise_wave_list, dim=-1)
+        noise_wave_np = noise_wave.numpy()
+        noise_wave_np = self.make_sure_duration(noise_wave_np, self.expected_sample_rate, self.speech_target_duration)
+
+        noisy_wave_np, _ = self.mix_speech_and_noise(
+            speech=speech_wave_np,
+            noise=noise_wave_np,
+            snr_db=snr_db, eps=self.eps,
+        )
+        noisy_wave = torch.tensor(noisy_wave_np, dtype=torch.float32)
+
+        vad_segments = [
+            [
+                vad_segment[0] + left_pad_duration,
+                vad_segment[1] + left_pad_duration,
+            ]
+            for vad_segment in vad_segments
+        ]
+
+        result = {
+            "noisy_wave": noisy_wave,
+            "vad_segments": vad_segments,
+        }
+        return result
+
+    def filename_to_waveform(self, filename: str, offset: float, duration: float):
+        try:
+            waveform, sample_rate = librosa.load(
+                filename,
+                sr=self.expected_sample_rate,
+                offset=offset,
+                duration=duration,
+            )
+        except ValueError as e:
+            print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
+            raise e
+        waveform = torch.tensor(waveform, dtype=torch.float32)
+        return waveform
+
+    @staticmethod
+    def pad_waveform(waveform: np.ndarray, sample_rate: int = 8000, target_duration: float = 8.0):
+        num_samples = len(waveform)
+        target_num_samples = int(sample_rate * target_duration)
+        if target_num_samples < num_samples:
+            return waveform, 0, 0
+
+        left_pad_size = (target_num_samples - num_samples) // 2
+        right_pad_size = target_num_samples - left_pad_size
+        result = np.concat([
+            np.zeros(left_pad_size, dtype=waveform.dtype),
+            waveform,
+            np.zeros(right_pad_size, dtype=waveform.dtype),
+        ])
+
+        left_pad_duration = left_pad_size / sample_rate
+        right_pad_duration = right_pad_size / sample_rate
+        return result, left_pad_duration, right_pad_duration
+
+    @staticmethod
+    def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8):
+        l1 = len(speech)
+        l2 = len(noise)
+        l = min(l1, l2)
+        speech = speech[:l]
+        noise = noise[:l]
+
+        # np.float32, value between (-1, 1).
+
+        speech_power = np.mean(np.square(speech))
+        noise_power = speech_power / (10 ** (snr_db / 10))
+
+        noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps)
+
+        noisy_signal = speech + noise_adjusted
+
+        return noisy_signal, noise_adjusted
+
+    @staticmethod
+    def make_sure_duration(waveform: np.ndarray, sample_rate: int = 8000, target_duration: float = 8.0):
+        num_samples = len(waveform)
+        target_num_samples = int(sample_rate * target_duration)
+
+        if target_num_samples < num_samples:
+            waveform = waveform[:target_num_samples]
+        elif target_num_samples > num_samples:
+            pad_size = target_num_samples - num_samples
+            waveform = np.concat([
+                waveform,
+                np.zeros(pad_size, dtype=waveform.dtype),
+            ])
+        else:
+            pass
+        return waveform
+
+
+if __name__ == "__main__":
+    pass
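
mix_speech_and_noise above rescales the noise so the speech-to-noise power ratio matches the requested SNR, i.e. snr_db = 10 * log10(P_speech / P_noise) with P_noise = P_speech / 10^(snr_db/10). A small numeric check of that relation, with illustrative values only (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
speech = 0.1 * rng.standard_normal(8000).astype(np.float32)
noise = 0.3 * rng.standard_normal(8000).astype(np.float32)
snr_db = 5.0

speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))
noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + 1e-8)

# Measured SNR of the mixture components should come back as ~5.0 dB.
measured = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(measured), 2))
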
toolbox/torch/utils/data/vocabulary.py
ADDED
@@ -0,0 +1,211 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from collections import defaultdict, OrderedDict
+import os
+from typing import Any, Callable, Dict, Iterable, List, Set
+
+
+def namespace_match(pattern: str, namespace: str):
+    """
+    Matches a namespace pattern against a namespace string. For example, ``*tags`` matches
+    ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not
+    ``stemmed_tokens``.
+    """
+    if pattern[0] == '*' and namespace.endswith(pattern[1:]):
+        return True
+    elif pattern == namespace:
+        return True
+    return False
+
+
+class _NamespaceDependentDefaultDict(defaultdict):
+    def __init__(self,
+                 non_padded_namespaces: Set[str],
+                 padded_function: Callable[[], Any],
+                 non_padded_function: Callable[[], Any]) -> None:
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padded_function = padded_function
+        self._non_padded_function = non_padded_function
+        super(_NamespaceDependentDefaultDict, self).__init__()
+
+    def __missing__(self, key: str):
+        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
+            value = self._non_padded_function()
+        else:
+            value = self._padded_function()
+        dict.__setitem__(self, key, value)
+        return value
+
+    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
+        # add non_padded_namespaces which weren't already present
+        self._non_padded_namespaces.update(non_padded_namespaces)
+
+
+class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {padding_token: 0, oov_token: 1},
+                                                       lambda: {})
+
+
+class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
+    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
+        super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces,
+                                                       lambda: {0: padding_token, 1: oov_token},
+                                                       lambda: {})
+
+
+DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
+DEFAULT_PADDING_TOKEN = '[PAD]'
+DEFAULT_OOV_TOKEN = '[UNK]'
+NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt'
+
+
+class Vocabulary(object):
+    def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES):
+        self._non_padded_namespaces = set(non_padded_namespaces)
+        self._padding_token = DEFAULT_PADDING_TOKEN
+        self._oov_token = DEFAULT_OOV_TOKEN
+        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+
+    def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int:
+        if token not in self._token_to_index[namespace]:
+            index = len(self._token_to_index[namespace])
+            self._token_to_index[namespace][token] = index
+            self._index_to_token[namespace][index] = token
+            return index
+        else:
+            return self._token_to_index[namespace][token]
+
+    def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]:
+        return self._index_to_token[namespace]
+
+    def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]:
+        return self._token_to_index[namespace]
+
+    def get_token_index(self, token: str, namespace: str = 'tokens') -> int:
+        if token in self._token_to_index[namespace]:
+            return self._token_to_index[namespace][token]
+        else:
+            return self._token_to_index[namespace][self._oov_token]
+
+    def get_token_from_index(self, index: int, namespace: str = 'tokens'):
+        return self._index_to_token[namespace][index]
+
+    def get_vocab_size(self, namespace: str = 'tokens') -> int:
+        return len(self._token_to_index[namespace])
+
+    def save_to_files(self, directory: str):
+        os.makedirs(directory, exist_ok=True)
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f:
+            for namespace_str in self._non_padded_namespaces:
+                f.write('{}\n'.format(namespace_str))
+
+        for namespace, token_to_index in self._token_to_index.items():
+            filename = os.path.join(directory, '{}.txt'.format(namespace))
+            with open(filename, 'w', encoding='utf-8') as f:
+                for token, _ in token_to_index.items():
+                    f.write('{}\n'.format(token))
+
+    @classmethod
+    def from_files(cls, directory: str) -> 'Vocabulary':
+        with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f:
+            non_padded_namespaces = [namespace_str.strip() for namespace_str in f]
+
+        vocab = cls(non_padded_namespaces=non_padded_namespaces)
+
+        for namespace_filename in os.listdir(directory):
+            if namespace_filename == NAMESPACE_PADDING_FILE:
+                continue
+            if namespace_filename.startswith("."):
+                continue
+            namespace = namespace_filename.replace('.txt', '')
+            if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
+                is_padded = False
+            else:
+                is_padded = True
+            filename = os.path.join(directory, namespace_filename)
+            vocab.set_from_file(filename, is_padded, namespace=namespace)
+
+        return vocab
+
+    def set_from_file(self,
+                      filename: str,
+                      is_padded: bool = True,
+                      oov_token: str = DEFAULT_OOV_TOKEN,
+                      namespace: str = "tokens"
+                      ):
+        if is_padded:
+            self._token_to_index[namespace] = {self._padding_token: 0}
self._token_to_index[namespace] = {self._padding_token: 0}
|
145 |
+
self._index_to_token[namespace] = {0: self._padding_token}
|
146 |
+
else:
|
147 |
+
self._token_to_index[namespace] = {}
|
148 |
+
self._index_to_token[namespace] = {}
|
149 |
+
|
150 |
+
with open(filename, 'r', encoding='utf-8') as f:
|
151 |
+
index = 1 if is_padded else 0
|
152 |
+
for row in f:
|
153 |
+
token = str(row).strip()
|
154 |
+
if token == oov_token:
|
155 |
+
token = self._oov_token
|
156 |
+
self._token_to_index[namespace][token] = index
|
157 |
+
self._index_to_token[namespace][index] = token
|
158 |
+
index += 1
|
159 |
+
|
160 |
+
def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"):
|
161 |
+
result = list()
|
162 |
+
for token in tokens:
|
163 |
+
idx = self._token_to_index[namespace].get(token)
|
164 |
+
if idx is None:
|
165 |
+
idx = self._token_to_index[namespace][self._oov_token]
|
166 |
+
result.append(idx)
|
167 |
+
return result
|
168 |
+
|
169 |
+
def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"):
|
170 |
+
result = list()
|
171 |
+
for idx in ids:
|
172 |
+
idx = self._index_to_token[namespace][idx]
|
173 |
+
result.append(idx)
|
174 |
+
return result
|
175 |
+
|
176 |
+
def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"):
|
177 |
+
pad_idx = self._token_to_index[namespace][self._padding_token]
|
178 |
+
|
179 |
+
length = len(ids)
|
180 |
+
if length > max_length:
|
181 |
+
result = ids[:max_length]
|
182 |
+
else:
|
183 |
+
result = ids + [pad_idx] * (max_length - length)
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
def demo1():
|
188 |
+
import jieba
|
189 |
+
|
190 |
+
vocabulary = Vocabulary()
|
191 |
+
vocabulary.add_token_to_namespace('白天', 'tokens')
|
192 |
+
vocabulary.add_token_to_namespace('晚上', 'tokens')
|
193 |
+
|
194 |
+
text = '不是在白天, 就是在晚上'
|
195 |
+
tokens = jieba.lcut(text)
|
196 |
+
|
197 |
+
print(tokens)
|
198 |
+
|
199 |
+
ids = vocabulary.convert_tokens_to_ids(tokens)
|
200 |
+
print(ids)
|
201 |
+
|
202 |
+
padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10)
|
203 |
+
print(padded_idx)
|
204 |
+
|
205 |
+
tokens = vocabulary.convert_ids_to_tokens(padded_idx)
|
206 |
+
print(tokens)
|
207 |
+
return
|
208 |
+
|
209 |
+
|
210 |
+
if __name__ == '__main__':
|
211 |
+
demo1()
|
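As a quick orientation for the padded versus non-padded namespace behaviour above, here is a minimal, illustrative sketch; the tokens and label names are made up for the example and are not part of the commit.

# Illustrative sketch only: the tokens and label names here are hypothetical.
from toolbox.torch.utils.data.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.add_token_to_namespace("voice", namespace="labels")   # "*labels" is non-padded, so ids start at 0
vocab.add_token_to_namespace("hello", namespace="tokens")   # "tokens" is padded: [PAD]=0, [UNK]=1, hello=2
print(vocab.get_token_index("voice", namespace="labels"))   # 0
print(vocab.get_token_index("hello", namespace="tokens"))   # 2
print(vocab.get_token_index("unseen", namespace="tokens"))  # 1, falls back to [UNK]
print(vocab.convert_tokens_to_ids(["hello", "unseen"]))     # [2, 1]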
toolbox/torchaudio/metrics/vad_metrics/vad_f1_score.py
ADDED
@@ -0,0 +1,60 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch


class VadF1Score(object):
    def __init__(self, threshold: float = 0.5, epsilon: float = 1e-12) -> None:
        self.threshold = threshold
        self.epsilon = epsilon  # guards against division by zero

        self.true_positives = 0.0
        self.false_positives = 0.0
        self.false_negatives = 0.0

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 ):
        """
        :param predictions: [b, t, 1], probabilities after a sigmoid.
        :param gold_labels: [b, t, 1], binary labels (0 or 1).
        """
        # binarize the predicted probabilities
        pred_labels = (predictions > self.threshold).float()

        # compute TP / FP / FN
        tp = (pred_labels * gold_labels).sum()  # True Positives
        fp = (pred_labels * (1 - gold_labels)).sum()  # False Positives
        fn = ((1 - pred_labels) * gold_labels).sum()  # False Negatives

        # accumulate the running statistics
        self.true_positives += tp.item()
        self.false_positives += fp.item()
        self.false_negatives += fn.item()

    def get_metric(self, reset: bool = False):
        # precision and recall
        precision = self.true_positives / (self.true_positives + self.false_positives + self.epsilon)
        recall = self.true_positives / (self.true_positives + self.false_negatives + self.epsilon)

        # F1 score
        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)

        if reset:
            self.reset()

        return {
            'f1': f1,
            'precision': precision,
            'recall': recall
        }

    def reset(self):
        self.true_positives = 0.0
        self.false_positives = 0.0
        self.false_negatives = 0.0


if __name__ == "__main__":
    pass
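A small, self-contained sketch of how the metric above is meant to accumulate over batches; the tensor shapes follow the docstring, and the random tensors are stand-ins for real model outputs and labels.

# Illustrative sketch only: random tensors stand in for real predictions and targets.
import torch

from toolbox.torchaudio.metrics.vad_metrics.vad_f1_score import VadF1Score

metric = VadF1Score(threshold=0.5)
for _ in range(3):                                          # e.g. three validation batches
    probs = torch.rand(size=(4, 100, 1))                    # [b, t, 1] probabilities after sigmoid
    labels = (torch.rand(size=(4, 100, 1)) > 0.5).float()   # [b, t, 1] binary targets
    metric(probs, labels)                                   # accumulates TP / FP / FN

scores = metric.get_metric(reset=True)                      # counters reset for the next epoch
print(scores["f1"], scores["precision"], scores["recall"])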
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py
ADDED
@@ -0,0 +1,169 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
from pathlib import Path
import shutil
import tempfile
import time
import zipfile

from scipy.io import wavfile
import librosa
import numpy as np
import torch
import torchaudio

torch.set_num_threads(1)

from project_settings import project_path
from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
from toolbox.torchaudio.models.vad.silero_vad.modeling_silero_vad import SileroVadPretrainedModel, MODEL_FILE
from toolbox.vad.vad import FrameVoiceClassifier, RingVad, process_speech_probs, make_visualization


logger = logging.getLogger("toolbox")


class SileroVadVoiceClassifier(FrameVoiceClassifier):
    def __init__(self,
                 pretrained_model_path_or_zip_file: str,
                 device: str = "cpu",
                 ):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(device)
        self.model.eval()

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "cc_vad"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem

        config = SileroVadConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = SileroVadPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        shutil.rmtree(model_path)
        return config, model

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        chunk = chunk / 32768

        inputs = torch.tensor(chunk, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)

        try:
            logits, _ = self.model.forward(inputs)
        except RuntimeError as e:
            print(inputs.shape)
            raise e
        # logits shape: [b, t, 1]
        logits_ = torch.mean(logits, dim=1)
        # logits_ shape: [b, 1]
        probs = torch.sigmoid(logits_)

        voice_prob = probs[0][0]
        return float(voice_prob)


class InferenceSileroVad(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        self.voice_classifier = SileroVadVoiceClassifier(pretrained_model_path_or_zip_file, device=device)

        self.ring_vad = RingVad(model=self.voice_classifier,
                                start_ring_rate=0.2,
                                end_ring_rate=0.1,
                                frame_size_ms=30,
                                frame_step_ms=30,
                                padding_length_ms=300,
                                max_silence_length_ms=300,
                                sample_rate=SAMPLE_RATE,
                                )

    def vad(self, signal: np.ndarray) -> np.ndarray:
        self.ring_vad.reset()

        vad_segments = list()

        segments = self.ring_vad.vad(signal)
        vad_segments += segments
        # last vad segment
        segments = self.ring_vad.last_vad_segments()
        vad_segments += segments
        return vad_segments

    def get_vad_speech_probs(self):
        result = self.ring_vad.speech_probs
        return result

    def get_vad_frame_step(self):
        result = self.ring_vad.frame_step
        return result


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/examples/hado/2f16ca0b-baec-4601-8a1e-7893eb875623.wav").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError

    infer = InferenceSileroVad(
        pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
    )

    vad_segments = infer.vad(signal)

    speech_probs = infer.get_vad_speech_probs()
    frame_step = infer.get_vad_frame_step()

    # speech_probs
    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=speech_probs,
        frame_step=frame_step,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
    return


if __name__ == "__main__":
    main()
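For clarity on the pooling in SileroVadVoiceClassifier.predict above: the model emits one logit per frame, and the chunk-level probability is the sigmoid of the time-averaged logits. A toy tensor makes the shapes concrete; the random values are stand-ins for real model output.

# Illustrative sketch only: random logits stand in for the model's [b, t, 1] output.
import torch

logits = torch.randn(size=(1, 10, 1))       # one chunk, 10 frames, 1 logit per frame
pooled = torch.mean(logits, dim=1)          # [1, 1] time-averaged logit
voice_prob = torch.sigmoid(pooled)[0][0]    # scalar probability for the whole chunk
print(float(voice_prob))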
toolbox/vad/__init__.py
ADDED
@@ -0,0 +1,6 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass
toolbox/vad/vad.py
ADDED
@@ -0,0 +1,450 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import collections
from functools import lru_cache
import os
from pathlib import Path
import shutil
import tempfile
import zipfile
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
import webrtcvad

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary


class FrameVoiceClassifier(object):
    def predict(self, chunk: np.ndarray) -> float:
        raise NotImplementedError


class WebRTCVoiceClassifier(FrameVoiceClassifier):
    def __init__(self,
                 agg: int = 3,
                 sample_rate: int = 8000
                 ):
        self.agg = agg
        self.sample_rate = sample_rate

        self.model = webrtcvad.Vad(mode=agg)

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        audio_bytes = bytes(chunk)
        is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
        return 1.0 if is_speech else 0.0


class SileroVoiceClassifier(FrameVoiceClassifier):
    def __init__(self,
                 model_path: str,
                 sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        with open(self.model_path, "rb") as f:
            model = torch.jit.load(f, map_location="cpu")
        self.model = model
        self.model.reset_states()

    def predict(self, chunk: np.ndarray) -> float:
        if self.sample_rate / len(chunk) > 31.25:
            raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        num_samples = len(chunk)
        if self.sample_rate == 8000 and num_samples != 256:
            raise AssertionError("win size must be 32 ms for silero vad. ")
        if self.sample_rate == 16000 and num_samples != 512:
            raise AssertionError("win size must be 32 ms for silero vad. ")

        chunk = chunk / 32768
        chunk = torch.tensor(chunk, dtype=torch.float32)
        speech_prob = self.model(chunk, self.sample_rate).item()
        return float(speech_prob)


class CCSoundsClassifier(FrameVoiceClassifier):
    def __init__(self,
                 model_path: str,
                 sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        d = self.load_model(Path(model_path))

        model = d["model"]
        vocabulary = d["vocabulary"]

        self.model = model
        self.vocabulary = vocabulary

    @staticmethod
    @lru_cache(maxsize=100)
    def load_model(model_file: Path):
        with zipfile.ZipFile(model_file, "r") as f_zip:
            out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
            if out_root.exists():
                shutil.rmtree(out_root.as_posix())
            out_root.mkdir(parents=True, exist_ok=True)
            f_zip.extractall(path=out_root)

        tgt_path = out_root / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        vocabulary = Vocabulary.from_files(vocab_path.as_posix())

        with open(jit_model_file.as_posix(), "rb") as f:
            model = torch.jit.load(f)
        model.eval()

        shutil.rmtree(tgt_path)

        d = {
            "model": model,
            "vocabulary": vocabulary
        }
        return d

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        chunk = chunk / (1 << 15)
        inputs = torch.tensor(chunk, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)

        with torch.no_grad():
            logits = self.model(inputs)
            probs = torch.nn.functional.softmax(logits, dim=-1)

        voice_idx = self.vocabulary.get_token_index(token="voice", namespace="labels")

        probs = probs.cpu()

        voice_prob = probs[0][voice_idx]
        return float(voice_prob)


class Frame(object):
    def __init__(self, signal: np.ndarray, timestamp_s: float):
        self.signal = signal
        self.timestamp_s = timestamp_s


class RingVad(object):
    def __init__(self,
                 model: FrameVoiceClassifier,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 frame_size_ms: int = 30,
                 frame_step_ms: int = 30,
                 padding_length_ms: int = 300,
                 max_silence_length_ms: int = 300,
                 max_speech_length_s: float = 2.0,
                 min_speech_length_s: float = 0.3,
                 sample_rate: int = 8000
                 ):
        self.model = model
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.frame_size_ms = frame_size_ms
        self.frame_step_ms = frame_step_ms
        self.padding_length_ms = padding_length_ms
        self.max_silence_length_ms = max_silence_length_ms
        self.max_speech_length_s = max_speech_length_s
        self.min_speech_length_s = min_speech_length_s
        self.sample_rate = sample_rate

        # frames
        self.frame_size = int(sample_rate * (frame_size_ms / 1000.0))
        self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
        self.frame_timestamp_s = 0.0
        self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)

        # segments
        self.num_padding_frames = int(padding_length_ms / frame_step_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames: List[Frame] = list()
        self.segments = list()

        # vad segments
        self.is_first_segment = True
        self.timestamp_start_s = 0.0
        self.timestamp_end_s = 0.0

        # speech probs
        self.speech_probs: List[float] = list()

    def reset(self):
        # frames
        self.frame_size = int(self.sample_rate * (self.frame_size_ms / 1000.0))
        self.frame_step = int(self.sample_rate * (self.frame_step_ms / 1000.0))
        self.frame_timestamp_s = 0.0
        self.signal_cache = np.zeros(shape=(self.frame_size,), dtype=np.int16)

        # segments
        self.num_padding_frames = int(self.padding_length_ms / self.frame_step_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames: List[Frame] = list()
        self.segments = list()

        # vad segments
        self.is_first_segment = True
        self.timestamp_start_s = 0.0
        self.timestamp_end_s = 0.0

        # speech probs
        self.speech_probs: List[float] = list()

    def signal_to_frames(self, signal: np.ndarray):
        frames = list()

        l = len(signal)

        duration_s = float(self.frame_step) / self.sample_rate

        for offset in range(0, l - self.frame_size + 1, self.frame_step):
            sub_signal = signal[offset:offset+self.frame_size]
            frame = Frame(sub_signal, self.frame_timestamp_s)
            self.frame_timestamp_s += duration_s

            frames.append(frame)
        return frames

    def segments_generator(self, signal: np.ndarray):
        # signal rounding
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])

        # rest
        rest = (len(signal) - self.frame_size) % self.frame_step

        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        # frames
        frames = self.signal_to_frames(signal_)

        for frame in frames:
            speech_prob = self.model.predict(frame.signal)
            self.speech_probs.append(speech_prob)

            if not self.triggered:
                self.ring_buffer.append((frame, speech_prob))
                num_voiced = sum([p for _, p in self.ring_buffer])

                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
                    self.triggered = True

                    for f, _ in self.ring_buffer:
                        self.voiced_frames.append(f)
                continue

            self.voiced_frames.append(frame)
            self.ring_buffer.append((frame, speech_prob))
            num_voiced = sum([p for _, p in self.ring_buffer])

            if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
                segment = [
                    np.concatenate([f.signal for f in self.voiced_frames]),
                    self.voiced_frames[0].timestamp_s,
                    self.voiced_frames[-1].timestamp_s,
                ]
                yield segment
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = []
                continue

    def vad_segments_generator(self, segments_generator):
        segments = list(segments_generator)

        for i, segment in enumerate(segments):
            start = round(segment[1], 4)
            end = round(segment[2], 4)

            if self.timestamp_start_s is None and self.timestamp_end_s is None:
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
                end_ = self.timestamp_start_s + self.max_speech_length_s
                vad_segment = [self.timestamp_start_s, end_]
                yield vad_segment
                self.timestamp_start_s = end_

            silence_length_ms = (start - self.timestamp_end_s) * 1000
            if silence_length_ms < self.max_silence_length_ms:
                self.timestamp_end_s = end
                continue

            if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
            yield vad_segment
            self.timestamp_start_s = start
            self.timestamp_end_s = end

    def vad(self, signal: np.ndarray) -> List[list]:
        segments = self.segments_generator(signal)
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)
        return vad_segments

    def last_vad_segments(self) -> List[list]:
        # last segments
        if len(self.voiced_frames) == 0:
            segments = []
        else:
            segment = [
                np.concatenate([f.signal for f in self.voiced_frames]),
                self.voiced_frames[0].timestamp_s,
                self.voiced_frames[-1].timestamp_s
            ]
            segments = [segment]

        # last vad segments
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)

        if self.timestamp_end_s > 1e-5:
            vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
        return vad_segments


def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
    speech_probs_ = list()
    for p in speech_probs[1:]:
        speech_probs_.extend([p] * frame_step)

    pad = (signal.shape[0] - len(speech_probs_))
    speech_probs_ = speech_probs_ + [0.0] * pad
    speech_probs_ = np.array(speech_probs_, dtype=np.float32)

    if len(speech_probs_) != len(signal):
        raise AssertionError
    return speech_probs_


def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    plt.plot(time, speech_probs, color='gray')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="start endpoint")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="end endpoint")

    plt.show()
    return


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        # default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
        default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
        type=str,
    )
    parser.add_argument(
        "--model_path",
        default=(project_path / "trained_models/silero_vad.jit").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError

    # model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
    model = WebRTCVoiceClassifier(agg=3, sample_rate=SAMPLE_RATE)
    # model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())

    # silero vad
    ring_vad = RingVad(model=model,
                       start_ring_rate=0.2,
                       end_ring_rate=0.1,
                       frame_size_ms=32,
                       frame_step_ms=32,
                       padding_length_ms=320,
                       max_silence_length_ms=320,
                       max_speech_length_s=100,
                       min_speech_length_s=0.1,
                       sample_rate=SAMPLE_RATE,
                       )
    # webrtcvad
    ring_vad = RingVad(model=model,
                       start_ring_rate=0.9,
                       end_ring_rate=0.1,
                       frame_size_ms=30,
                       frame_step_ms=30,
                       padding_length_ms=300,
                       max_silence_length_ms=300,
                       max_speech_length_s=100,
                       min_speech_length_s=0.1,
                       sample_rate=SAMPLE_RATE,
                       )
    print(ring_vad)

    vad_segments = list()

    segments = ring_vad.vad(signal)
    vad_segments += segments
    for segment in segments:
        print(segment)

    # last vad segment
    segments = ring_vad.last_vad_segments()
    vad_segments += segments
    for segment in segments:
        print(segment)

    print(ring_vad.speech_probs)
    print(len(ring_vad.speech_probs))

    # speech_probs
    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=ring_vad.speech_probs,
        frame_step=ring_vad.frame_step,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
    return


if __name__ == "__main__":
    main()
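Because RingVad keeps its frame cache, ring buffer and timestamps as state, it can also be fed a stream chunk by chunk instead of a whole file at once. A minimal streaming sketch with the WebRTC classifier follows; the chunk size and the synthetic silent signal are arbitrary choices for illustration, not part of the commit.

# Illustrative sketch only: a synthetic silent signal stands in for live audio.
import numpy as np

from toolbox.vad.vad import RingVad, WebRTCVoiceClassifier

model = WebRTCVoiceClassifier(agg=3, sample_rate=8000)
ring_vad = RingVad(model=model, start_ring_rate=0.9, end_ring_rate=0.1,
                   frame_size_ms=30, frame_step_ms=30,
                   padding_length_ms=300, max_silence_length_ms=300,
                   max_speech_length_s=100, min_speech_length_s=0.1,
                   sample_rate=8000)

signal = np.zeros(8000 * 4, dtype=np.int16)       # 4 seconds of silence at 8 kHz
chunk_size = 1600                                  # 200 ms per incoming chunk

vad_segments = list()
for offset in range(0, len(signal), chunk_size):
    chunk = signal[offset:offset + chunk_size]
    vad_segments += ring_vad.vad(chunk)            # leftover samples are cached internally
vad_segments += ring_vad.last_vad_segments()       # flush the pending segment, if any
print(vad_segments)                                # pure silence, so no segments are expected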
toolbox/webrtcvad/vad.py
CHANGED
@@ -107,6 +107,7 @@ class WebRTCVad(object):
     for frame in frames:
         audio_bytes = bytes(frame.signal)
         is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)
+        print(f"is_speech: {is_speech}")

         if not self.triggered:
             self.ring_buffer.append((frame, is_speech))

@@ -189,8 +190,9 @@ def get_args():
     parser.add_argument(
         "--wav_file",
         # default=(project_path / "data/0eeaef67-ea59-4f2d-a5b8-b70c813fd45c.wav").as_posix(),
-        default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
+        # default=(project_path / "data/1c998b62-c3aa-4541-b59a-d4a40b79eff3.wav").as_posix(),
         # default=(project_path / "data/8cbad66f-2c4e-43c2-ad11-ad95bab8bc15.wav").as_posix(),
+        default=r"D:/Users/tianx/HuggingDatasets/nx_noise/data/speech/nx-speech/en-PH/2025-05-28/active_media_w_1f650e5c-bd22-4803-bb88-d670b00fccda_30.wav",
         type=str,
     )
     parser.add_argument(

@@ -206,12 +208,12 @@ def get_args():
     )
     parser.add_argument(
         "--padding_duration_ms",
-        default=
+        default=300,
         type=int,
     )
     parser.add_argument(
         "--silence_duration_threshold",
-        default=0.
+        default=0.3,
         type=float,
         help="minimum silence duration, in seconds."
     )