HoneyTian committed on
Commit
bebc2b8
·
1 Parent(s): 12ce37e
examples/evaluation/step_1_run_evaluation.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+
9
+ pwd = os.path.abspath(os.path.dirname(__file__))
10
+ sys.path.append(os.path.join(pwd, "../../"))
11
+
12
+ import librosa
13
+ from gradio_client import Client
14
+ import numpy as np
15
+ from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
16
+ from tqdm import tqdm
17
+
18
+
19
def get_args():
    """Parse the command-line options of the evaluation script.

    Options (all optional, each with a default):
      --test_set              directory of the test set (string)
      --output_file           path of the JSONL result file (string)
      --expected_sample_rate  sample rate as an integer

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # (flag, default, type) triples, registered in this order.
    option_specs = (
        ("--test_set", r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\vad", str),
        ("--output_file", r"fsmn-vad.jsonl", str),
        ("--expected_sample_rate", 8000, int),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        arg_parser.add_argument(flag, default=default_value, type=value_type)

    return arg_parser.parse_args()
36
+
37
+
38
def get_metrics(ground_truth, predictions, total_duration, step=0.01):
    """
    Evaluate VAD output by discretizing time into evenly spaced points.

    :param ground_truth: list of reference intervals, [[start1, end1], [start2, end2], ...]
    :param predictions: list of predicted intervals, same format as ground_truth
    :param total_duration: total audio duration (seconds)
    :param step: time discretization step (default 0.01, i.e. 10 ms)
    :return: dict with the keys "accuracy", "precision", "recall" and "f1"
    """
    # Sample points covering [0, total_duration) at the given step.
    grid = np.arange(0, total_duration, step)

    def rasterize(intervals):
        # 1 for every sample point inside any interval (both ends inclusive), else 0.
        labels = np.zeros_like(grid, dtype=int)
        for seg_start, seg_end in intervals:
            labels[(grid >= seg_start) & (grid <= seg_end)] = 1
        return labels

    y_true = rasterize(ground_truth)
    y_pred = rasterize(predictions)

    # Point-wise binary classification metrics; zero_division=0 yields 0
    # instead of a warning when a denominator is empty.
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }
72
+
73
+
74
def main():
    """Evaluate a VAD engine served by a local Gradio app over an annotated test set.

    Reads ``vad.json`` from the ``--test_set`` directory, sends every listed
    audio file to the ``/when_click_vad_button`` endpoint, scores the returned
    segments against the annotated ones with ``get_metrics`` and writes one
    JSON line per file to ``--output_file``. Running averages of the metrics
    are shown on the progress bar.
    """
    args = get_args()

    client = Client("http://127.0.0.1:7866/")

    test_set = Path(args.test_set)
    output_file = Path(args.output_file)

    annotation_file = test_set / "vad.json"

    with open(annotation_file.as_posix(), "r", encoding="utf-8") as f:
        annotation = json.load(f)

    # Running sums used to report averages over the files processed so far.
    total = 0
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_f1 = 0
    total_duration = 0

    # The progress bar gets a known total (so it can show percentage / ETA) and
    # is managed by ``with`` so it is always closed, even if a request fails.
    with open(output_file.as_posix(), "w", encoding="utf-8") as f, \
            tqdm(total=len(annotation), desc="evaluation") as progress_bar:
        for row in annotation:
            filename = row["filename"]
            ground_truth_vad_segments = row["vad_segments"]

            filename = test_set / filename

            # Only the last returned value (the JSON message) is used here.
            _, _, _, message = client.predict(
                audio_file_t={
                    "path": filename.as_posix(),
                    "meta": {"_type": "gradio.FileData"}
                },
                audio_microphone_t=None,
                start_ring_rate=0.5,
                end_ring_rate=0.5,
                ring_max_length=1,
                min_silence_length=6,
                max_speech_length=100000,
                min_speech_length=15,
                engine="fsmn-vad-by-webrtcvad-nx2-dns3",
                api_name="/when_click_vad_button"
            )
            js = json.loads(message)
            prediction_vad_segments = js["vad_segments"]
            duration = js["duration"]

            metrics = get_metrics(ground_truth_vad_segments, prediction_vad_segments, duration)
            accuracy = metrics["accuracy"]
            precision = metrics["precision"]
            recall = metrics["recall"]
            f1 = metrics["f1"]

            # One JSON object per line (JSONL) for each evaluated file.
            row_ = {
                "filename": filename.as_posix(),
                "duration": duration,
                "ground_truth": ground_truth_vad_segments,
                "prediction": prediction_vad_segments,

                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            f.write(f"{row_}\n")

            total += 1
            total_accuracy += accuracy
            total_precision += precision
            total_recall += recall
            total_f1 += f1
            total_duration += duration

            average_accuracy = total_accuracy / total
            average_precision = total_precision / total
            average_recall = total_recall / total
            average_f1 = total_f1 / total

            progress_bar.update(1)
            progress_bar.set_postfix({
                "total": total,
                "accuracy": average_accuracy,
                "precision": average_precision,
                "recall": average_recall,
                "f1": average_f1,
                "total_duration": f"{round(total_duration / 60, 4)}min",
            })

    return
163
+
164
+
165
+ if __name__ == "__main__":
166
+ main()
main.py CHANGED
@@ -101,6 +101,7 @@ def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: in
101
 
102
  def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
103
  start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
 
104
  min_silence_length: int = 2,
105
  max_speech_length: int = 10000, min_speech_length: int = 10,
106
  engine: str = None,
@@ -112,7 +113,7 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
112
  audio_t: Tuple = audio_file_t or audio_microphone_t
113
 
114
  sample_rate, signal = audio_t
115
- audio_duration = signal.shape[-1] // 8000
116
  audio = np.array(signal / (1 << 15), dtype=np.float32)
117
 
118
  infer_engine_param = vad_engines.get(engine)
@@ -128,38 +129,55 @@ def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
128
  vad_info = infer_engine.infer(audio)
129
  time_cost = time.time() - begin
130
 
131
- fpr = time_cost / audio_duration
132
- info = {
133
- "time_cost": round(time_cost, 4),
134
- "audio_duration": round(audio_duration, 4),
135
- "fpr": round(fpr, 4)
136
- }
137
- message = json.dumps(info, ensure_ascii=False, indent=4)
138
-
139
  probs = vad_info["probs"]
140
  lsnr = vad_info["lsnr"]
141
  # lsnr = lsnr / np.max(np.abs(lsnr))
142
  lsnr = lsnr / 30
143
 
144
  frame_step = infer_engine.config.hop_size
145
- probs_ = process_speech_probs(audio, probs, frame_step)
146
- probs_image = generate_image(audio, probs_)
147
-
148
- lsnr_ = process_speech_probs(audio, lsnr, frame_step)
149
- lsnr_image = generate_image(audio, lsnr_)
150
 
151
  # post process
152
  vad_post_process = PostProcess(
153
  start_ring_rate=start_ring_rate,
154
  end_ring_rate=end_ring_rate,
 
155
  min_silence_length=min_silence_length,
156
  max_speech_length=max_speech_length,
157
  min_speech_length=min_speech_length
158
  )
159
- vad = vad_post_process.post_process(probs)
160
- vad_ = process_speech_probs(audio, vad, frame_step)
 
 
 
161
  vad_image = generate_image(audio, vad_)
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  except Exception as e:
164
  raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")
165
 
@@ -240,10 +258,12 @@ def main():
240
  with gr.Row():
241
  vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
242
  vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
243
- vad_min_silence_length = gr.Number(value=30, label="min_silence_length")
244
  with gr.Row():
245
- vad_max_speech_length = gr.Number(value=100000, label="max_speech_length")
246
- vad_min_speech_length = gr.Number(value=15, label="min_speech_length")
 
 
 
247
  vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
248
  vad_button = gr.Button(variant="primary")
249
  with gr.Column(variant="panel", scale=5):
@@ -257,6 +277,7 @@ def main():
257
  inputs=[
258
  vad_audio_file, vad_audio_microphone,
259
  vad_start_ring_rate, vad_end_ring_rate,
 
260
  vad_min_silence_length,
261
  vad_max_speech_length, vad_min_speech_length,
262
  vad_engine,
@@ -288,7 +309,8 @@ def main():
288
  # share=True,
289
  share=False if platform.system() == "Windows" else False,
290
  server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
291
- server_port=args.server_port
 
292
  )
293
  return
294
 
 
101
 
102
  def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
103
  start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
104
+ ring_max_length: int = 10,
105
  min_silence_length: int = 2,
106
  max_speech_length: int = 10000, min_speech_length: int = 10,
107
  engine: str = None,
 
113
  audio_t: Tuple = audio_file_t or audio_microphone_t
114
 
115
  sample_rate, signal = audio_t
116
+ audio_duration = signal.shape[-1] // sample_rate
117
  audio = np.array(signal / (1 << 15), dtype=np.float32)
118
 
119
  infer_engine_param = vad_engines.get(engine)
 
129
  vad_info = infer_engine.infer(audio)
130
  time_cost = time.time() - begin
131
 
 
 
 
 
 
 
 
 
132
  probs = vad_info["probs"]
133
  lsnr = vad_info["lsnr"]
134
  # lsnr = lsnr / np.max(np.abs(lsnr))
135
  lsnr = lsnr / 30
136
 
137
  frame_step = infer_engine.config.hop_size
 
 
 
 
 
138
 
139
  # post process
140
  vad_post_process = PostProcess(
141
  start_ring_rate=start_ring_rate,
142
  end_ring_rate=end_ring_rate,
143
+ ring_max_length=ring_max_length,
144
  min_silence_length=min_silence_length,
145
  max_speech_length=max_speech_length,
146
  min_speech_length=min_speech_length
147
  )
148
+ vad_segments = vad_post_process.get_vad_segments(probs)
149
+ vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)
150
+
151
+ # vad_image
152
+ vad_ = process_speech_probs(audio, vad_flags, frame_step)
153
  vad_image = generate_image(audio, vad_)
154
 
155
+ # probs_image
156
+ probs_ = process_speech_probs(audio, probs, frame_step)
157
+ probs_image = generate_image(audio, probs_)
158
+
159
+ # lsnr_image
160
+ lsnr_ = process_speech_probs(audio, lsnr, frame_step)
161
+ lsnr_image = generate_image(audio, lsnr_)
162
+
163
+ # vad segment
164
+ vad_segments = [
165
+ [
166
+ v[0] * frame_step / sample_rate,
167
+ v[1] * frame_step / sample_rate
168
+ ] for v in vad_segments
169
+ ]
170
+
171
+ # message
172
+ rtf = time_cost / audio_duration
173
+ info = {
174
+ "vad_segments": vad_segments,
175
+ "time_cost": round(time_cost, 4),
176
+ "duration": round(audio_duration, 4),
177
+ "rtf": round(rtf, 4)
178
+ }
179
+ message = json.dumps(info, ensure_ascii=False, indent=4)
180
+
181
  except Exception as e:
182
  raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")
183
 
 
258
  with gr.Row():
259
  vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
260
  vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
 
261
  with gr.Row():
262
+ vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
263
+ vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
264
+ with gr.Row():
265
+ vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
266
+ vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")
267
  vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
268
  vad_button = gr.Button(variant="primary")
269
  with gr.Column(variant="panel", scale=5):
 
277
  inputs=[
278
  vad_audio_file, vad_audio_microphone,
279
  vad_start_ring_rate, vad_end_ring_rate,
280
+ vad_ring_max_length,
281
  vad_min_silence_length,
282
  vad_max_speech_length, vad_min_speech_length,
283
  vad_engine,
 
309
  # share=True,
310
  share=False if platform.system() == "Windows" else False,
311
  server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
312
+ server_port=args.server_port,
313
+ show_error=True
314
  )
315
  return
316
 
toolbox/vad/utils.py CHANGED
@@ -9,18 +9,20 @@ class PostProcess(object):
9
  def __init__(self,
10
  start_ring_rate: float = 0.5,
11
  end_ring_rate: float = 0.5,
12
- min_silence_length: int = 1,
13
- max_speech_length: float = 10,
14
- min_speech_length: float = 2,
 
15
  ):
16
  self.start_ring_rate = start_ring_rate
17
  self.end_ring_rate = end_ring_rate
 
18
  self.max_speech_length = max_speech_length
19
  self.min_speech_length = min_speech_length
20
  self.min_silence_length = min_silence_length
21
 
22
  # segments
23
- self.ring_buffer = collections.deque(maxlen=10)
24
  self.triggered = False
25
 
26
  # vad segments
@@ -117,19 +119,27 @@ class PostProcess(object):
117
  vad_segments = vad_segments + [[self.start_idx, self.end_idx]]
118
  return vad_segments
119
 
120
- def post_process(self, probs: List[float]):
121
  vad_segments = list()
122
  segments = self.vad(probs)
123
  vad_segments += segments
124
  segments = self.last_vad_segments()
125
  vad_segments += segments
126
 
 
 
 
127
  result = [0] * len(probs)
128
  for begin, end in vad_segments:
129
  result[begin: end] = [1] * (end - begin)
130
 
131
  return result
132
 
 
 
 
 
 
133
 
134
  if __name__ == "__main__":
135
  pass
 
9
  def __init__(self,
10
  start_ring_rate: float = 0.5,
11
  end_ring_rate: float = 0.5,
12
+ ring_max_length: int = 10,
13
+ min_silence_length: int = 6,
14
+ max_speech_length: float = 100000,
15
+ min_speech_length: float = 15,
16
  ):
17
  self.start_ring_rate = start_ring_rate
18
  self.end_ring_rate = end_ring_rate
19
+ self.ring_max_length = ring_max_length
20
  self.max_speech_length = max_speech_length
21
  self.min_speech_length = min_speech_length
22
  self.min_silence_length = min_silence_length
23
 
24
  # segments
25
+ self.ring_buffer = collections.deque(maxlen=self.ring_max_length)
26
  self.triggered = False
27
 
28
  # vad segments
 
119
  vad_segments = vad_segments + [[self.start_idx, self.end_idx]]
120
  return vad_segments
121
 
122
def get_vad_segments(self, probs: List[float]):
    """Collect the speech segments detected for ``probs``.

    The result is a new list holding the segments returned by
    ``self.vad(probs)`` followed by those returned by
    ``self.last_vad_segments()``.

    :param probs: per-frame speech probabilities passed on to ``self.vad``.
    :return: list of segments, in the order they were produced.
    """
    collected = [*self.vad(probs)]
    # Append whatever last_vad_segments() reports after the main pass.
    collected += self.last_vad_segments()
    return collected
130
+
131
def get_vad_flags(self, probs: List[float], vad_segments: List[Tuple[int, int]]):
    """Expand segments into a per-frame 0/1 flag list.

    :param probs: per-frame values; only their count is used, to size the output.
    :param vad_segments: ``(begin, end)`` index pairs, treated as half-open
        ranges ``[begin, end)`` over the frames of ``probs``.
    :return: list of ``len(probs)`` integers, 1 inside any segment and 0 elsewhere.

    Segment bounds are clamped to ``[0, len(probs)]``. Without the clamp the
    slice assignment would change the list length for out-of-range segments
    (e.g. ``end > len(probs)`` would append extra flags), so the result would
    no longer line up frame-for-frame with ``probs``.
    """
    frame_count = len(probs)
    flags = [0] * frame_count
    for begin, end in vad_segments:
        begin = max(begin, 0)
        end = min(end, frame_count)
        # Empty or inverted ranges mark nothing.
        if begin < end:
            flags[begin:end] = [1] * (end - begin)

    return flags
137
 
138
def post_process(self, probs: List[float]):
    """Return the per-frame flags for ``probs``.

    Kept as a one-call entry point: it obtains the segments from
    ``get_vad_segments`` and converts them with ``get_vad_flags``.

    :param probs: per-frame speech probabilities.
    :return: whatever ``get_vad_flags`` returns for the detected segments.
    """
    return self.get_vad_flags(probs, self.get_vad_segments(probs))
142
+
143
 
144
  if __name__ == "__main__":
145
  pass