update
Browse files- examples/evaluation/step_1_run_evaluation.py +14 -13
- examples/evaluation/step_2_show_metrics.py +70 -0
- examples/evaluation/step_3_show_vad.py +105 -0
- examples/fsmn_vad_by_webrtcvad/run.sh +1 -1
- examples/fsmn_vad_by_webrtcvad/step_4_train_model.py +1 -0
- examples/silero_vad_by_webrtcvad/step_4_train_model.py +1 -0
- requirements.txt +1 -0
- toolbox/pydub/__init__.py +6 -0
- toolbox/pydub/volume.py +106 -0
- toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py +8 -0
examples/evaluation/step_1_run_evaluation.py
CHANGED
@@ -26,7 +26,7 @@ def get_args():
|
|
26 |
)
|
27 |
parser.add_argument(
|
28 |
"--output_file",
|
29 |
-
default=r"
|
30 |
type=str
|
31 |
)
|
32 |
parser.add_argument("--expected_sample_rate", default=8000, type=int)
|
@@ -105,12 +105,13 @@ def main():
|
|
105 |
},
|
106 |
audio_microphone_t=None,
|
107 |
start_ring_rate=0.5,
|
108 |
-
end_ring_rate=0.
|
109 |
-
ring_max_length=
|
110 |
min_silence_length=6,
|
111 |
max_speech_length=100000,
|
112 |
min_speech_length=15,
|
113 |
-
engine="fsmn-vad-by-webrtcvad-nx2-dns3",
|
|
|
114 |
api_name="/when_click_vad_button"
|
115 |
)
|
116 |
js = json.loads(message)
|
@@ -138,16 +139,16 @@ def main():
|
|
138 |
f.write(f"{row_}\n")
|
139 |
|
140 |
total += 1
|
141 |
-
total_accuracy += accuracy
|
142 |
-
total_precision += precision
|
143 |
-
total_recall += recall
|
144 |
-
total_f1 += f1
|
145 |
total_duration += duration
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
progress_bar.update(1)
|
153 |
progress_bar.set_postfix({
|
|
|
26 |
)
|
27 |
parser.add_argument(
|
28 |
"--output_file",
|
29 |
+
default=r"evaluation.jsonl",
|
30 |
type=str
|
31 |
)
|
32 |
parser.add_argument("--expected_sample_rate", default=8000, type=int)
|
|
|
105 |
},
|
106 |
audio_microphone_t=None,
|
107 |
start_ring_rate=0.5,
|
108 |
+
end_ring_rate=0.3,
|
109 |
+
ring_max_length=10,
|
110 |
min_silence_length=6,
|
111 |
max_speech_length=100000,
|
112 |
min_speech_length=15,
|
113 |
+
# engine="fsmn-vad-by-webrtcvad-nx2-dns3",
|
114 |
+
engine="silero-vad-by-webrtcvad-nx2-dns3",
|
115 |
api_name="/when_click_vad_button"
|
116 |
)
|
117 |
js = json.loads(message)
|
|
|
139 |
f.write(f"{row_}\n")
|
140 |
|
141 |
total += 1
|
|
|
|
|
|
|
|
|
142 |
total_duration += duration
|
143 |
+
total_accuracy += accuracy * duration
|
144 |
+
total_precision += precision * duration
|
145 |
+
total_recall += recall * duration
|
146 |
+
total_f1 += f1 * duration
|
147 |
+
|
148 |
+
average_accuracy = total_accuracy / total_duration
|
149 |
+
average_precision = total_precision / total_duration
|
150 |
+
average_recall = total_recall / total_duration
|
151 |
+
average_f1 = total_f1 / total_duration
|
152 |
|
153 |
progress_bar.update(1)
|
154 |
progress_bar.set_postfix({
|
examples/evaluation/step_2_show_metrics.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def get_args():
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
|
17 |
+
parser.add_argument(
|
18 |
+
"--eval_file",
|
19 |
+
default=r"evaluation.jsonl",
|
20 |
+
type=str
|
21 |
+
)
|
22 |
+
args = parser.parse_args()
|
23 |
+
return args
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
args = get_args()
|
28 |
+
|
29 |
+
total = 0
|
30 |
+
total_duration = 0
|
31 |
+
total_accuracy = 0
|
32 |
+
total_precision = 0
|
33 |
+
total_recall = 0
|
34 |
+
total_f1 = 0
|
35 |
+
progress_bar = tqdm(desc="evaluation")
|
36 |
+
with open(args.eval_file, "r", encoding="utf-8") as f:
|
37 |
+
for row in f:
|
38 |
+
row = json.loads(row)
|
39 |
+
duration = row["duration"]
|
40 |
+
accuracy = row["accuracy"]
|
41 |
+
precision = row["precision"]
|
42 |
+
recall = row["recall"]
|
43 |
+
f1 = row["f1"]
|
44 |
+
|
45 |
+
total += 1
|
46 |
+
total_duration += duration
|
47 |
+
total_accuracy += accuracy * duration
|
48 |
+
total_precision += precision * duration
|
49 |
+
total_recall += recall * duration
|
50 |
+
total_f1 += f1 * duration
|
51 |
+
|
52 |
+
average_accuracy = total_accuracy / total_duration
|
53 |
+
average_precision = total_precision / total_duration
|
54 |
+
average_recall = total_recall / total_duration
|
55 |
+
average_f1 = total_f1 / total_duration
|
56 |
+
|
57 |
+
progress_bar.update(1)
|
58 |
+
progress_bar.set_postfix({
|
59 |
+
"total": total,
|
60 |
+
"accuracy": average_accuracy,
|
61 |
+
"precision": average_precision,
|
62 |
+
"recall": average_recall,
|
63 |
+
"f1": average_f1,
|
64 |
+
"total_duration": f"{round(total_duration / 60, 4)}min",
|
65 |
+
})
|
66 |
+
return
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
main()
|
examples/evaluation/step_3_show_vad.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import tempfile
|
8 |
+
|
9 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
10 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
11 |
+
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import numpy as np
|
14 |
+
from scipy.io import wavfile
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
|
18 |
+
def get_args():
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
|
21 |
+
parser.add_argument(
|
22 |
+
"--eval_file",
|
23 |
+
default=r"evaluation.jsonl",
|
24 |
+
type=str
|
25 |
+
)
|
26 |
+
args = parser.parse_args()
|
27 |
+
return args
|
28 |
+
|
29 |
+
|
30 |
+
def show_image(signal: np.ndarray,
|
31 |
+
ground_truth_probs: np.ndarray,
|
32 |
+
prediction_probs: np.ndarray,
|
33 |
+
sample_rate: int = 8000,
|
34 |
+
):
|
35 |
+
duration = np.arange(0, len(signal)) / sample_rate
|
36 |
+
plt.figure(figsize=(12, 5))
|
37 |
+
|
38 |
+
plt.subplot(2, 1, 1) # 2行1列,第1个位置
|
39 |
+
plt.plot(duration, signal, color="b")
|
40 |
+
plt.plot(duration, ground_truth_probs, color="gray")
|
41 |
+
plt.title("ground_truth")
|
42 |
+
|
43 |
+
plt.subplot(2, 1, 2) # 2行1列,第2个位置
|
44 |
+
plt.plot(duration, signal, color="b")
|
45 |
+
plt.plot(duration, prediction_probs, color="gray")
|
46 |
+
plt.title("prediction")
|
47 |
+
|
48 |
+
# plt.tight_layout()
|
49 |
+
plt.subplots_adjust(hspace=0.5) # 调整上下间距
|
50 |
+
|
51 |
+
plt.show()
|
52 |
+
|
53 |
+
|
54 |
+
def main():
|
55 |
+
args = get_args()
|
56 |
+
|
57 |
+
with open(args.eval_file, "r", encoding="utf-8") as f:
|
58 |
+
for row in f:
|
59 |
+
row = json.loads(row)
|
60 |
+
filename = row["filename"]
|
61 |
+
duration = row["duration"]
|
62 |
+
ground_truth = row["ground_truth"]
|
63 |
+
prediction = row["prediction"]
|
64 |
+
|
65 |
+
accuracy = row["accuracy"]
|
66 |
+
precision = row["precision"]
|
67 |
+
recall = row["recall"]
|
68 |
+
f1 = row["f1"]
|
69 |
+
|
70 |
+
sample_rate, signal = wavfile.read(
|
71 |
+
filename=filename,
|
72 |
+
)
|
73 |
+
signal = np.array(signal / (1 << 15), dtype=np.float32)
|
74 |
+
signal_length = len(signal)
|
75 |
+
ground_truth_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
|
76 |
+
for begin, end in ground_truth:
|
77 |
+
begin = int(begin * sample_rate)
|
78 |
+
end = int(end * sample_rate)
|
79 |
+
ground_truth_probs[begin:end] = 1
|
80 |
+
prediction_probs = np.zeros(shape=(signal_length,), dtype=np.float32)
|
81 |
+
for begin, end in prediction:
|
82 |
+
begin = int(begin * sample_rate)
|
83 |
+
end = int(end * sample_rate)
|
84 |
+
prediction_probs[begin:end] = 1
|
85 |
+
|
86 |
+
# p = encoder_num_layers * (encoder_kernel_size - 1) // 2 * hop_size * sample_rate
|
87 |
+
p = 3 * (3 - 1) // 2 * 80
|
88 |
+
p = int(p)
|
89 |
+
print(f"p: {p}")
|
90 |
+
prediction_probs = np.concat(
|
91 |
+
[
|
92 |
+
prediction_probs[p:], prediction_probs[-p:]
|
93 |
+
],
|
94 |
+
axis=-1
|
95 |
+
)
|
96 |
+
|
97 |
+
show_image(signal,
|
98 |
+
ground_truth_probs, prediction_probs,
|
99 |
+
sample_rate=sample_rate,
|
100 |
+
)
|
101 |
+
return
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == "__main__":
|
105 |
+
main()
|
examples/fsmn_vad_by_webrtcvad/run.sh
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
|
3 |
: <<'END'
|
4 |
|
5 |
-
bash run.sh --stage 3 --stop_stage
|
6 |
--file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
|
7 |
--final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
|
8 |
--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
|
|
|
2 |
|
3 |
: <<'END'
|
4 |
|
5 |
+
bash run.sh --stage 3 --stop_stage 5 --system_version centos \
|
6 |
--file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
|
7 |
--final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
|
8 |
--noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
|
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py
CHANGED
@@ -127,6 +127,7 @@ def main():
|
|
127 |
max_wave_value=32768.0,
|
128 |
min_snr_db=config.min_snr_db,
|
129 |
max_snr_db=config.max_snr_db,
|
|
|
130 |
# skip=225000,
|
131 |
)
|
132 |
valid_dataset = VadPaddingJsonlDataset(
|
|
|
127 |
max_wave_value=32768.0,
|
128 |
min_snr_db=config.min_snr_db,
|
129 |
max_snr_db=config.max_snr_db,
|
130 |
+
do_volume_enhancement=True,
|
131 |
# skip=225000,
|
132 |
)
|
133 |
valid_dataset = VadPaddingJsonlDataset(
|
examples/silero_vad_by_webrtcvad/step_4_train_model.py
CHANGED
@@ -127,6 +127,7 @@ def main():
|
|
127 |
max_wave_value=32768.0,
|
128 |
min_snr_db=config.min_snr_db,
|
129 |
max_snr_db=config.max_snr_db,
|
|
|
130 |
# skip=225000,
|
131 |
)
|
132 |
valid_dataset = VadPaddingJsonlDataset(
|
|
|
127 |
max_wave_value=32768.0,
|
128 |
min_snr_db=config.min_snr_db,
|
129 |
max_snr_db=config.max_snr_db,
|
130 |
+
do_volume_enhancement=True,
|
131 |
# skip=225000,
|
132 |
)
|
133 |
valid_dataset = VadPaddingJsonlDataset(
|
requirements.txt
CHANGED
@@ -4,6 +4,7 @@ datasets==3.2.0
|
|
4 |
python-dotenv==1.0.1
|
5 |
scipy==1.15.1
|
6 |
librosa==0.10.2.post1
|
|
|
7 |
pandas==2.2.3
|
8 |
openpyxl==3.1.5
|
9 |
torch==2.5.1
|
|
|
4 |
python-dotenv==1.0.1
|
5 |
scipy==1.15.1
|
6 |
librosa==0.10.2.post1
|
7 |
+
pydub==0.25.1
|
8 |
pandas==2.2.3
|
9 |
openpyxl==3.1.5
|
10 |
torch==2.5.1
|
toolbox/pydub/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/pydub/volume.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from scipy.io import wavfile
|
10 |
+
|
11 |
+
from project_settings import project_path
|
12 |
+
|
13 |
+
|
14 |
+
def score_transform(x: float, stages: List[float], scores: List[float], ndigits: int = 4):
|
15 |
+
last_stage = stages[0]
|
16 |
+
last_score = scores[0]
|
17 |
+
stages = stages[1:]
|
18 |
+
scores = scores[1:]
|
19 |
+
for stage, score in zip(stages, scores):
|
20 |
+
if x >= stage:
|
21 |
+
result = score + (x - stage) / (last_stage - stage + 1e-7) * (last_score - score)
|
22 |
+
return round(result, ndigits)
|
23 |
+
last_stage = stage
|
24 |
+
last_score = score
|
25 |
+
raise ValueError(f"values of x, stages and scores should between 0 and 1, "
|
26 |
+
f"stages and scores should be same length and decreased. "
|
27 |
+
f"x: {x}, stages: {stages}, scores: {scores}")
|
28 |
+
|
29 |
+
|
30 |
+
def set_volume(waveform: np.ndarray, sample_rate: int = 8000, volume: int = 0):
|
31 |
+
if np.min(waveform) < -1 or np.max(waveform) > 1:
|
32 |
+
raise AssertionError(f"waveform type: {type(waveform)}, dtype: {waveform.dtype}")
|
33 |
+
waveform = np.array(waveform * (1 << 15), dtype=np.int16)
|
34 |
+
raw_data = waveform.tobytes()
|
35 |
+
|
36 |
+
audio_segment = AudioSegment(
|
37 |
+
data=raw_data,
|
38 |
+
sample_width=2,
|
39 |
+
frame_rate=sample_rate,
|
40 |
+
channels=1
|
41 |
+
)
|
42 |
+
|
43 |
+
map_list = [
|
44 |
+
[0, -150],
|
45 |
+
[10, -60],
|
46 |
+
[50, -35],
|
47 |
+
[100, -20],
|
48 |
+
]
|
49 |
+
stages = [a for a, b in map_list]
|
50 |
+
scores = [b for a, b in map_list]
|
51 |
+
|
52 |
+
# 计算目标 dBFS
|
53 |
+
target_db = score_transform(
|
54 |
+
x=volume,
|
55 |
+
stages=list(reversed(stages)),
|
56 |
+
scores=list(reversed(scores)),
|
57 |
+
)
|
58 |
+
|
59 |
+
audio_segment = audio_segment.apply_gain(target_db - audio_segment.dBFS)
|
60 |
+
|
61 |
+
samples = np.array(audio_segment.get_array_of_samples())
|
62 |
+
|
63 |
+
if audio_segment.sample_width == 2:
|
64 |
+
samples = samples.astype(np.float32) / (1 << (2*8-1))
|
65 |
+
elif audio_segment.sample_width == 3:
|
66 |
+
samples = samples.astype(np.float32) / (1 << (3*8-1))
|
67 |
+
elif audio_segment.sample_width == 4:
|
68 |
+
samples = samples.astype(np.float32) / (1 << (4*8-1))
|
69 |
+
else:
|
70 |
+
raise AssertionError
|
71 |
+
return samples
|
72 |
+
|
73 |
+
|
74 |
+
def get_args():
|
75 |
+
parser = argparse.ArgumentParser()
|
76 |
+
parser.add_argument(
|
77 |
+
"--filename",
|
78 |
+
default=(project_path / "data/examples/ai_agent/chinese-1.wav").as_posix(),
|
79 |
+
type=str
|
80 |
+
)
|
81 |
+
args = parser.parse_args()
|
82 |
+
return args
|
83 |
+
|
84 |
+
|
85 |
+
def main():
|
86 |
+
args = get_args()
|
87 |
+
|
88 |
+
waveform, sample_rate = librosa.load(args.filename, sr=8000)
|
89 |
+
|
90 |
+
waveform = set_volume(
|
91 |
+
waveform=waveform,
|
92 |
+
sample_rate=sample_rate,
|
93 |
+
volume=10
|
94 |
+
)
|
95 |
+
waveform = np.array(waveform * (1 << 15), dtype=np.int16)
|
96 |
+
|
97 |
+
wavfile.write(
|
98 |
+
"temp.wav",
|
99 |
+
rate=8000,
|
100 |
+
data=waveform,
|
101 |
+
)
|
102 |
+
return
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == "__main__":
|
106 |
+
main()
|
toolbox/torch/utils/data/dataset/vad_padding_jsonl_dataset.py
CHANGED
@@ -9,6 +9,8 @@ import numpy as np
|
|
9 |
import torch
|
10 |
from torch.utils.data import Dataset, IterableDataset
|
11 |
|
|
|
|
|
12 |
|
13 |
class VadPaddingJsonlDataset(IterableDataset):
|
14 |
def __init__(self,
|
@@ -19,6 +21,7 @@ class VadPaddingJsonlDataset(IterableDataset):
|
|
19 |
buffer_size: int = 1000,
|
20 |
min_snr_db: float = None,
|
21 |
max_snr_db: float = None,
|
|
|
22 |
speech_target_duration: float = 8.0,
|
23 |
eps: float = 1e-8,
|
24 |
skip: int = 0,
|
@@ -29,6 +32,7 @@ class VadPaddingJsonlDataset(IterableDataset):
|
|
29 |
self.max_wave_value = max_wave_value
|
30 |
self.min_snr_db = min_snr_db
|
31 |
self.max_snr_db = max_snr_db
|
|
|
32 |
self.speech_target_duration = speech_target_duration
|
33 |
self.eps = eps
|
34 |
self.skip = skip
|
@@ -134,6 +138,10 @@ class VadPaddingJsonlDataset(IterableDataset):
|
|
134 |
speech_wave_np, left_pad_duration, _ = self.pad_waveform(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
135 |
speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
136 |
|
|
|
|
|
|
|
|
|
137 |
noise_wave_list = list()
|
138 |
for noise in noise_list:
|
139 |
filename = noise["filename"]
|
|
|
9 |
import torch
|
10 |
from torch.utils.data import Dataset, IterableDataset
|
11 |
|
12 |
+
from toolbox.pydub.volume import set_volume
|
13 |
+
|
14 |
|
15 |
class VadPaddingJsonlDataset(IterableDataset):
|
16 |
def __init__(self,
|
|
|
21 |
buffer_size: int = 1000,
|
22 |
min_snr_db: float = None,
|
23 |
max_snr_db: float = None,
|
24 |
+
do_volume_enhancement: bool = False,
|
25 |
speech_target_duration: float = 8.0,
|
26 |
eps: float = 1e-8,
|
27 |
skip: int = 0,
|
|
|
32 |
self.max_wave_value = max_wave_value
|
33 |
self.min_snr_db = min_snr_db
|
34 |
self.max_snr_db = max_snr_db
|
35 |
+
self.do_volume_enhancement = do_volume_enhancement
|
36 |
self.speech_target_duration = speech_target_duration
|
37 |
self.eps = eps
|
38 |
self.skip = skip
|
|
|
138 |
speech_wave_np, left_pad_duration, _ = self.pad_waveform(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
139 |
speech_wave_np = self.make_sure_duration(speech_wave_np, self.expected_sample_rate, self.speech_target_duration)
|
140 |
|
141 |
+
# volume enhancement
|
142 |
+
volume = random.randint(0, 100)
|
143 |
+
speech_wave_np = set_volume(speech_wave_np, sample_rate=self.expected_sample_rate, volume=volume)
|
144 |
+
|
145 |
noise_wave_list = list()
|
146 |
for noise in noise_list:
|
147 |
filename = noise["filename"]
|