LAP-DEV committed on
Commit
8aeae3c
·
verified ·
1 Parent(s): c571c60

Upload silero_vad.py

Browse files
Files changed (1) hide show
  1. modules/vad/silero_vad.py +302 -0
modules/vad/silero_vad.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
2
+
3
+ from faster_whisper.vad import VadOptions, get_vad_model
4
+ import numpy as np
5
+ from typing import BinaryIO, Union, List, Optional, Tuple
6
+ import warnings
7
+ import bisect
8
+ import faster_whisper
9
+ from faster_whisper.transcribe import SpeechTimestampsMap
10
+ import gradio as gr
11
+
12
+ class SileroVAD:
13
+ def __init__(self):
14
+ self.sampling_rate = 16000
15
+ self.window_size_samples = 512
16
+ self.model = None
17
+
18
+ def run(self,
19
+ audio: Union[str, BinaryIO, np.ndarray],
20
+ vad_parameters: VadOptions,
21
+ progress: gr.Progress = gr.Progress()
22
+ ) -> Tuple[np.ndarray, List[dict]]:
23
+ """
24
+ Run VAD
25
+
26
+ Parameters
27
+ ----------
28
+ audio: Union[str, BinaryIO, np.ndarray]
29
+ Audio path or file binary or Audio numpy array
30
+ vad_parameters:
31
+ Options for VAD processing.
32
+ progress: gr.Progress
33
+ Indicator to show progress directly in gradio.
34
+
35
+ Returns
36
+ ----------
37
+ np.ndarray
38
+ Pre-processed audio with VAD
39
+ List[dict]
40
+ Chunks of speeches to be used to restore the timestamps later
41
+ """
42
+
43
+ sampling_rate = self.sampling_rate
44
+
45
+ if not isinstance(audio, np.ndarray):
46
+ audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)
47
+
48
+ duration = audio.shape[0] / sampling_rate
49
+ duration_after_vad = duration
50
+
51
+ if vad_parameters is None:
52
+ vad_parameters = VadOptions()
53
+ elif isinstance(vad_parameters, dict):
54
+ vad_parameters = VadOptions(**vad_parameters)
55
+ speech_chunks = self.get_speech_timestamps(
56
+ audio=audio,
57
+ vad_options=vad_parameters,
58
+ progress=progress
59
+ )
60
+
61
+ audio = self.collect_chunks(audio, speech_chunks)
62
+ duration_after_vad = audio.shape[0] / sampling_rate
63
+
64
+ return audio, speech_chunks
65
+
66
+ def get_speech_timestamps(
67
+ self,
68
+ audio: np.ndarray,
69
+ vad_options: Optional[VadOptions] = None,
70
+ progress: gr.Progress = gr.Progress(),
71
+ **kwargs,
72
+ ) -> List[dict]:
73
+ """This method is used for splitting long audios into speech chunks using silero VAD.
74
+
75
+ Args:
76
+ audio: One dimensional float array.
77
+ vad_options: Options for VAD processing.
78
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
79
+ progress: Gradio progress to indicate progress.
80
+
81
+ Returns:
82
+ List of dicts containing begin and end samples of each speech chunk.
83
+ """
84
+
85
+ if self.model is None:
86
+ self.update_model()
87
+
88
+ if vad_options is None:
89
+ vad_options = VadOptions(**kwargs)
90
+
91
+ threshold = vad_options.threshold
92
+ neg_threshold = vad_options.neg_threshold
93
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
94
+ max_speech_duration_s = vad_options.max_speech_duration_s
95
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
96
+ window_size_samples = self.window_size_samples
97
+ speech_pad_ms = vad_options.speech_pad_ms
98
+ min_speech_samples = self.sampling_rate * min_speech_duration_ms / 1000
99
+ speech_pad_samples = self.sampling_rate * speech_pad_ms / 1000
100
+ max_speech_samples = (
101
+ self.sampling_rate * max_speech_duration_s
102
+ - window_size_samples
103
+ - 2 * speech_pad_samples
104
+ )
105
+ min_silence_samples = self.sampling_rate * min_silence_duration_ms / 1000
106
+ min_silence_samples_at_max_speech = self.sampling_rate * 98 / 1000
107
+
108
+ audio_length_samples = len(audio)
109
+
110
+ padded_audio = np.pad(
111
+ audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
112
+ )
113
+ speech_probs = self.model(padded_audio.reshape(1, -1)).squeeze(0)
114
+
115
+ triggered = False
116
+ speeches = []
117
+ current_speech = {}
118
+ if neg_threshold is None:
119
+ neg_threshold = max(threshold - 0.15, 0.01)
120
+
121
+ # to save potential segment end (and tolerate some silence)
122
+ temp_end = 0
123
+ # to save potential segment limits in case of maximum segment size reached
124
+ prev_end = next_start = 0
125
+
126
+ for i, speech_prob in enumerate(speech_probs):
127
+ if (speech_prob >= threshold) and temp_end:
128
+ temp_end = 0
129
+ if next_start < prev_end:
130
+ next_start = window_size_samples * i
131
+
132
+ if (speech_prob >= threshold) and not triggered:
133
+ triggered = True
134
+ current_speech["start"] = window_size_samples * i
135
+ continue
136
+
137
+ if (
138
+ triggered
139
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
140
+ ):
141
+ if prev_end:
142
+ current_speech["end"] = prev_end
143
+ speeches.append(current_speech)
144
+ current_speech = {}
145
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
146
+ if next_start < prev_end:
147
+ triggered = False
148
+ else:
149
+ current_speech["start"] = next_start
150
+ prev_end = next_start = temp_end = 0
151
+ else:
152
+ current_speech["end"] = window_size_samples * i
153
+ speeches.append(current_speech)
154
+ current_speech = {}
155
+ prev_end = next_start = temp_end = 0
156
+ triggered = False
157
+ continue
158
+
159
+ if (speech_prob < neg_threshold) and triggered:
160
+ if not temp_end:
161
+ temp_end = window_size_samples * i
162
+ # condition to avoid cutting in very short silence
163
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
164
+ prev_end = temp_end
165
+ if (window_size_samples * i) - temp_end < min_silence_samples:
166
+ continue
167
+ else:
168
+ current_speech["end"] = temp_end
169
+ if (
170
+ current_speech["end"] - current_speech["start"]
171
+ ) > min_speech_samples:
172
+ speeches.append(current_speech)
173
+ current_speech = {}
174
+ prev_end = next_start = temp_end = 0
175
+ triggered = False
176
+ continue
177
+
178
+ if (
179
+ current_speech
180
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
181
+ ):
182
+ current_speech["end"] = audio_length_samples
183
+ speeches.append(current_speech)
184
+
185
+ for i, speech in enumerate(speeches):
186
+ if i == 0:
187
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
188
+ if i != len(speeches) - 1:
189
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
190
+ if silence_duration < 2 * speech_pad_samples:
191
+ speech["end"] += int(silence_duration // 2)
192
+ speeches[i + 1]["start"] = int(
193
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
194
+ )
195
+ else:
196
+ speech["end"] = int(
197
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
198
+ )
199
+ speeches[i + 1]["start"] = int(
200
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
201
+ )
202
+ else:
203
+ speech["end"] = int(
204
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
205
+ )
206
+
207
+ return speeches
208
+
209
+ def update_model(self):
210
+ self.model = get_vad_model()
211
+
212
+ @staticmethod
213
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
214
+ """Collects and concatenates audio chunks."""
215
+ if not chunks:
216
+ return np.array([], dtype=np.float32)
217
+
218
+ return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])
219
+
220
+ @staticmethod
221
+ def format_timestamp(
222
+ seconds: float,
223
+ always_include_hours: bool = False,
224
+ decimal_marker: str = ".",
225
+ ) -> str:
226
+ assert seconds >= 0, "non-negative timestamp expected"
227
+ milliseconds = round(seconds * 1000.0)
228
+
229
+ hours = milliseconds // 3_600_000
230
+ milliseconds -= hours * 3_600_000
231
+
232
+ minutes = milliseconds // 60_000
233
+ milliseconds -= minutes * 60_000
234
+
235
+ seconds = milliseconds // 1_000
236
+ milliseconds -= seconds * 1_000
237
+
238
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
239
+ return (
240
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
241
+ )
242
+
243
+ def restore_speech_timestamps(
244
+ self,
245
+ segments: List[dict],
246
+ speech_chunks: List[dict],
247
+ sampling_rate: Optional[int] = None,
248
+ ) -> List[dict]:
249
+ if sampling_rate is None:
250
+ sampling_rate = self.sampling_rate
251
+
252
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
253
+
254
+ for segment in segments:
255
+ segment["start"] = ts_map.get_original_time(segment["start"])
256
+ segment["end"] = ts_map.get_original_time(segment["end"])
257
+
258
+ return segments
259
+
260
+
261
+ # Copied from: https://github.com/m-bain/whisperX/blob/main/whisperx/vads/vad.py
262
+ def merge_chunks(segments,
263
+ chunk_size,
264
+ onset: float,
265
+ offset: Optional[float]):
266
+ """
267
+ Merge operation described in paper
268
+ """
269
+ curr_end = 0
270
+ merged_segments = []
271
+ seg_idxs: list[tuple]= []
272
+ speaker_idxs: list[Optional[str]] = []
273
+
274
+ curr_start = segments[0].start
275
+ for seg in segments:
276
+ if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
277
+ merged_segments.append({
278
+ "start": curr_start,
279
+ "end": curr_end,
280
+ "segments": seg_idxs,
281
+ })
282
+ curr_start = seg.start
283
+ seg_idxs = []
284
+ speaker_idxs = []
285
+ curr_end = seg.end
286
+ seg_idxs.append((seg.start, seg.end))
287
+ speaker_idxs.append(seg.speaker)
288
+ # add final
289
+ merged_segments.append({
290
+ "start": curr_start,
291
+ "end": curr_end,
292
+ "segments": seg_idxs,
293
+ })
294
+
295
+ return merged_segments
296
+
297
+ # Copied from: https://github.com/m-bain/whisperX/blob/main/whisperx/types.py
298
+ class Segment:
299
+ def __init__(self, start:int, end:int, speaker:Optional[str]=None):
300
+ self.start = start
301
+ self.end = end
302
+ self.speaker = speaker