LAP-DEV committed
Commit 73f3097 · verified · 1 Parent(s): c14b8d7

Upload silero_vad.py

Files changed (1)
modules/vad/silero_vad.py +277 -0
modules/vad/silero_vad.py ADDED
@@ -0,0 +1,277 @@
# Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py

from faster_whisper.vad import VadOptions, get_vad_model
import numpy as np
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import bisect
import faster_whisper
from faster_whisper.transcribe import SpeechTimestampsMap
import gradio as gr

from modules.whisper.data_classes import *


class SileroVAD:
    def __init__(self):
        self.sampling_rate = 16000
        self.window_size_samples = 512
        self.model = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            vad_parameters: VadOptions,
            progress: gr.Progress = gr.Progress()
            ) -> Tuple[np.ndarray, List[dict]]:
        """
        Run VAD

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        vad_parameters:
            Options for VAD processing.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        np.ndarray
            Pre-processed audio with VAD
        List[dict]
            Chunks of speeches to be used to restore the timestamps later
        """

        sampling_rate = self.sampling_rate

        if not isinstance(audio, np.ndarray):
            audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)

        duration = audio.shape[0] / sampling_rate
        duration_after_vad = duration

        if vad_parameters is None:
            vad_parameters = VadOptions()
        elif isinstance(vad_parameters, dict):
            vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = self.get_speech_timestamps(
            audio=audio,
            vad_options=vad_parameters,
            progress=progress
        )

        audio = self.collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        return audio, speech_chunks

    def get_speech_timestamps(
        self,
        audio: np.ndarray,
        vad_options: Optional[VadOptions] = None,
        progress: gr.Progress = gr.Progress(),
        **kwargs,
    ) -> List[dict]:
        """This method is used for splitting long audios into speech chunks using silero VAD.

        Args:
            audio: One dimensional float array.
            vad_options: Options for VAD processing.
            kwargs: VAD options passed as keyword arguments for backward compatibility.
            progress: Gradio progress to indicate progress.

        Returns:
            List of dicts containing begin and end samples of each speech chunk.
        """

        if self.model is None:
            self.update_model()

        if vad_options is None:
            vad_options = VadOptions(**kwargs)

        threshold = vad_options.threshold
        neg_threshold = vad_options.neg_threshold
        min_speech_duration_ms = vad_options.min_speech_duration_ms
        max_speech_duration_s = vad_options.max_speech_duration_s
        min_silence_duration_ms = vad_options.min_silence_duration_ms
        window_size_samples = self.window_size_samples
        speech_pad_ms = vad_options.speech_pad_ms
        min_speech_samples = self.sampling_rate * min_speech_duration_ms / 1000
        speech_pad_samples = self.sampling_rate * speech_pad_ms / 1000
        max_speech_samples = (
            self.sampling_rate * max_speech_duration_s
            - window_size_samples
            - 2 * speech_pad_samples
        )
        min_silence_samples = self.sampling_rate * min_silence_duration_ms / 1000
        min_silence_samples_at_max_speech = self.sampling_rate * 98 / 1000

        audio_length_samples = len(audio)

        padded_audio = np.pad(
            audio, (0, window_size_samples - audio.shape[0] % window_size_samples)
        )
        speech_probs = self.model(padded_audio.reshape(1, -1)).squeeze(0)

        triggered = False
        speeches = []
        current_speech = {}
        if neg_threshold is None:
            neg_threshold = max(threshold - 0.15, 0.01)

        # to save potential segment end (and tolerate some silence)
        temp_end = 0
        # to save potential segment limits in case of maximum segment size reached
        prev_end = next_start = 0

        for i, speech_prob in enumerate(speech_probs):
            if (speech_prob >= threshold) and temp_end:
                temp_end = 0
                if next_start < prev_end:
                    next_start = window_size_samples * i

            if (speech_prob >= threshold) and not triggered:
                triggered = True
                current_speech["start"] = window_size_samples * i
                continue

            if (
                triggered
                and (window_size_samples * i) - current_speech["start"] > max_speech_samples
            ):
                if prev_end:
                    current_speech["end"] = prev_end
                    speeches.append(current_speech)
                    current_speech = {}
                    # previously reached silence (< neg_thres) and is still not speech (< thres)
                    if next_start < prev_end:
                        triggered = False
                    else:
                        current_speech["start"] = next_start
                    prev_end = next_start = temp_end = 0
                else:
                    current_speech["end"] = window_size_samples * i
                    speeches.append(current_speech)
                    current_speech = {}
                    prev_end = next_start = temp_end = 0
                    triggered = False
                    continue

            if (speech_prob < neg_threshold) and triggered:
                if not temp_end:
                    temp_end = window_size_samples * i
                # condition to avoid cutting in very short silence
                if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                    prev_end = temp_end
                if (window_size_samples * i) - temp_end < min_silence_samples:
                    continue
                else:
                    current_speech["end"] = temp_end
                    if (
                        current_speech["end"] - current_speech["start"]
                    ) > min_speech_samples:
                        speeches.append(current_speech)
                    current_speech = {}
                    prev_end = next_start = temp_end = 0
                    triggered = False
                    continue

        if (
            current_speech
            and (audio_length_samples - current_speech["start"]) > min_speech_samples
        ):
            current_speech["end"] = audio_length_samples
            speeches.append(current_speech)

        for i, speech in enumerate(speeches):
            if i == 0:
                speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
            if i != len(speeches) - 1:
                silence_duration = speeches[i + 1]["start"] - speech["end"]
                if silence_duration < 2 * speech_pad_samples:
                    speech["end"] += int(silence_duration // 2)
                    speeches[i + 1]["start"] = int(
                        max(0, speeches[i + 1]["start"] - silence_duration // 2)
                    )
                else:
                    speech["end"] = int(
                        min(audio_length_samples, speech["end"] + speech_pad_samples)
                    )
                    speeches[i + 1]["start"] = int(
                        max(0, speeches[i + 1]["start"] - speech_pad_samples)
                    )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )

        return speeches

    def update_model(self):
        self.model = get_vad_model()

    @staticmethod
    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
        """Collects and concatenates audio chunks."""
        if not chunks:
            return np.array([], dtype=np.float32)

        return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])

    @staticmethod
    def format_timestamp(
        seconds: float,
        always_include_hours: bool = False,
        decimal_marker: str = ".",
    ) -> str:
        assert seconds >= 0, "non-negative timestamp expected"
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return (
            f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
        )

    def restore_speech_timestamps(
        self,
        segments: List[Segment],
        speech_chunks: List[dict],
        sampling_rate: Optional[int] = None,
    ) -> List[Segment]:
        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

        for segment in segments:
            if segment.words:
                words = []
                for word in segment.words:
                    # Ensure the word start and end times are resolved to the same chunk.
                    middle = (word.start + word.end) / 2
                    chunk_index = ts_map.get_chunk_index(middle)
                    word.start = ts_map.get_original_time(word.start, chunk_index)
                    word.end = ts_map.get_original_time(word.end, chunk_index)
                    words.append(word)

                segment.start = words[0].start
                segment.end = words[-1].end
                segment.words = words

            else:
                segment.start = ts_map.get_original_time(segment.start)
                segment.end = ts_map.get_original_time(segment.end)

        return segments
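
For reference, a minimal usage sketch of the uploaded class, assuming the module path above and an illustrative audio file name ("sample.wav") and VadOptions values that are not part of this commit:

    # Illustrative sketch, not part of the commit: file name and option values are assumptions.
    from faster_whisper.vad import VadOptions
    from modules.vad.silero_vad import SileroVAD

    vad = SileroVAD()

    # Run VAD: returns speech-only audio plus the chunk map (start/end samples).
    options = VadOptions(threshold=0.5, min_speech_duration_ms=250, speech_pad_ms=400)
    speech_audio, speech_chunks = vad.run("sample.wav", vad_parameters=options)

    print(f"Kept {speech_audio.shape[0] / vad.sampling_rate:.1f}s of speech "
          f"in {len(speech_chunks)} chunk(s)")

    # After transcribing `speech_audio`, the same chunk list maps segment timestamps
    # back to the original timeline:
    # segments = vad.restore_speech_timestamps(segments, speech_chunks)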