LAP-DEV committed on
Commit c3f75cb · verified · 1 Parent(s): 1cd78ce

Upload whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +542 -0
modules/whisper/whisper_base.py ADDED
@@ -0,0 +1,542 @@
+ import os
+ import torch
+ import whisper
+ import gradio as gr
+ import torchaudio
+ from abc import ABC, abstractmethod
+ from typing import BinaryIO, Union, Tuple, List, Optional
+ import numpy as np
+ from datetime import datetime
+ from faster_whisper.vad import VadOptions
+ from dataclasses import astuple
+
+ from modules.uvr.music_separator import MusicSeparator
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
+                                  DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR)
+ from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
+ from modules.utils.youtube_manager import get_ytdata, get_ytaudio
+ from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
+ from modules.whisper.whisper_parameter import *
+ from modules.diarize.diarizer import Diarizer
+ from modules.vad.silero_vad import SileroVAD
+
+
+ class WhisperBase(ABC):
+     def __init__(self,
+                  model_dir: str = WHISPER_MODELS_DIR,
+                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                  uvr_model_dir: str = UVR_MODELS_DIR,
+                  output_dir: str = OUTPUT_DIR,
+                  ):
+         self.model_dir = model_dir
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+         os.makedirs(self.model_dir, exist_ok=True)
+         self.diarizer = Diarizer(
+             model_dir=diarization_model_dir
+         )
+         self.vad = SileroVAD()
+         self.music_separator = MusicSeparator(
+             model_dir=uvr_model_dir,
+             output_dir=os.path.join(output_dir, "UVR")
+         )
+
+         self.model = None
+         self.current_model_size = None
+         self.available_models = whisper.available_models()
+         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
+         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
+         self.device = self.get_device()
+         self.available_compute_types = ["float16", "float32"]
+         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+
+     @abstractmethod
+     def transcribe(self,
+                    audio: Union[str, BinaryIO, np.ndarray],
+                    progress: gr.Progress = gr.Progress(),
+                    *whisper_params,
+                    ):
+         """Run inference with the whisper model to transcribe audio"""
+         pass
+
+     @abstractmethod
+     def update_model(self,
+                      model_size: str,
+                      compute_type: str,
+                      progress: gr.Progress = gr.Progress()
+                      ):
+         """Initialize the whisper model"""
+         pass
+
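+     # Both abstract methods are supplied by a concrete backend. A minimal sketch
+     # of such a subclass (hypothetical, using the openai-whisper API already
+     # imported above):
+     #
+     #     class OpenAIWhisperInference(WhisperBase):
+     #         def update_model(self, model_size, compute_type, progress=gr.Progress()):
+     #             self.current_model_size = model_size
+     #             self.model = whisper.load_model(model_size, device=self.device,
+     #                                             download_root=self.model_dir)
+     #
+     #         def transcribe(self, audio, progress=gr.Progress(), *whisper_params):
+     #             start = datetime.now()
+     #             result = self.model.transcribe(audio)
+     #             segments = [{"start": s["start"], "end": s["end"], "text": s["text"]}
+     #                         for s in result["segments"]]
+     #             return segments, (datetime.now() - start).total_seconds()
+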
+     def run(self,
+             audio: Union[str, BinaryIO, np.ndarray],
+             progress: gr.Progress = gr.Progress(),
+             add_timestamp: bool = True,
+             *whisper_params,
+             ) -> Tuple[List[dict], float]:
+         """
+         Run transcription with conditional pre-processing and post-processing.
+         If enabled, VAD is applied in pre-processing to filter non-speech out of the audio input,
+         and diarization is applied in post-processing.
+
+         Parameters
+         ----------
+         audio: Union[str, BinaryIO, np.ndarray]
+             Audio input. This can be a file path, a binary stream, or a numpy array.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         add_timestamp: bool
+             Whether to add a timestamp at the end of the filename.
+         *whisper_params: tuple
+             Parameters related to whisper, handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         segments_result: List[dict]
+             List of dicts with start/end timestamps and transcribed text.
+         elapsed_time: float
+             Elapsed time for running.
+         """
+         params = WhisperParameters.as_value(*whisper_params)
+
+         self.cache_parameters(
+             whisper_params=params,
+             add_timestamp=add_timestamp
+         )
+
+         # Map the UI language name to a whisper language code
+         if params.lang is None:
+             pass
+         elif params.lang == "Automatic Detection":
+             params.lang = None
+         else:
+             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+             params.lang = language_code_dict[params.lang]
+
+         # Pre-processing: separate background music with UVR, if enabled
+         if params.is_bgm_separate:
+             music, audio, _ = self.music_separator.separate(
+                 audio=audio,
+                 model_name=params.uvr_model_size,
+                 device=params.uvr_device,
+                 segment_size=params.uvr_segment_size,
+                 save_file=params.uvr_save_file,
+                 progress=progress
+             )
+
+             # Downmix to mono and resample to 16 kHz for whisper
+             if audio.ndim >= 2:
+                 audio = audio.mean(axis=1)
+                 if self.music_separator.audio_info is None:
+                     origin_sample_rate = 16000
+                 else:
+                     origin_sample_rate = self.music_separator.audio_info.sample_rate
+                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
+
+             if params.uvr_enable_offload:
+                 self.music_separator.offload()
+
+         # Pre-processing: drop non-speech chunks with VAD, if enabled
+         if params.vad_filter:
+             # Explicit value set for float('inf') from gr.Number()
+             if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
+                 params.max_speech_duration_s = float('inf')
+
+             vad_options = VadOptions(
+                 threshold=params.threshold,
+                 min_speech_duration_ms=params.min_speech_duration_ms,
+                 max_speech_duration_s=params.max_speech_duration_s,
+                 min_silence_duration_ms=params.min_silence_duration_ms,
+                 speech_pad_ms=params.speech_pad_ms
+             )
+
+             audio, speech_chunks = self.vad.run(
+                 audio=audio,
+                 vad_parameters=vad_options,
+                 progress=progress
+             )
+
+         result, elapsed_time = self.transcribe(
+             audio,
+             progress,
+             *astuple(params)
+         )
+
+         # Post-processing: map timestamps back onto the original, un-trimmed audio
+         if params.vad_filter:
+             result = self.vad.restore_speech_timestamps(
+                 segments=result,
+                 speech_chunks=speech_chunks,
+             )
+
+         # Post-processing: assign speaker labels, if enabled
+         if params.is_diarize:
+             result, elapsed_time_diarization = self.diarizer.run(
+                 audio=audio,
+                 use_auth_token=params.hf_token,
+                 transcribed_result=result,
+             )
+             elapsed_time += elapsed_time_diarization
+         return result, elapsed_time
+
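+     # The UI wrappers below funnel into `run`. A direct call would look roughly
+     # like this, where `pipeline` is an instance of a concrete subclass and
+     # `default_params` a WhisperValues instance (both hypothetical):
+     #
+     #     segments, elapsed = pipeline.run("audio.wav", gr.Progress(), True, *astuple(default_params))
+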
+     def transcribe_file(self,
+                         files: Optional[List] = None,
+                         input_folder_path: Optional[str] = None,
+                         file_format: str = "SRT",
+                         add_timestamp: bool = True,
+                         progress=gr.Progress(),
+                         *whisper_params,
+                         ) -> list:
+         """
+         Write subtitle files from input files
+
+         Parameters
+         ----------
+         files: list
+             List of files to transcribe from gr.Files()
+         input_folder_path: str
+             Input folder path to transcribe from gr.Textbox(). If provided, `files` is ignored and
+             this is used instead.
+         file_format: str
+             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper, handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file paths to return to gr.Files()
+         total_info:
+             Per-file language detection and timing summary to return to gr.Textbox()
+         """
+         try:
+             if input_folder_path:
+                 files = get_media_files(input_folder_path)
+             if isinstance(files, str):
+                 files = [files]
+             if files and isinstance(files[0], gr.utils.NamedString):
+                 files = [file.name for file in files]
+
+             params = WhisperParameters.as_value(*whisper_params)
+             # Load a model once for per-file language detection
+             detection_model = whisper.load_model(params.model_size)
+             files_info = {}
+             for file in files:
+                 # Detect the language from the first 30 seconds of the file
+                 mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(detection_model.device)
+                 _, probs = detection_model.detect_language(mel)
+                 detected_code = max(probs, key=probs.get)
+                 file_language = "not"
+                 if detected_code in whisper.tokenizer.LANGUAGES:
+                     file_language = whisper.tokenizer.LANGUAGES[detected_code].capitalize()
+
+                 transcribed_segments, time_for_task = self.run(
+                     file,
+                     progress,
+                     add_timestamp,
+                     *whisper_params,
+                 )
+
+                 file_name, file_ext = os.path.splitext(os.path.basename(file))
+                 subtitle, file_path = self.generate_and_write_file(
+                     file_name=file_name,
+                     transcribed_segments=transcribed_segments,
+                     add_timestamp=add_timestamp,
+                     file_format=file_format,
+                     output_dir=self.output_dir
+                 )
+
+                 files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task,
+                                          "path": file_path, "lang": file_language, "ext": file_ext}
+
+             total_result = ''
+             total_info = ''
+             total_time = 0
+             for file_name, info in files_info.items():
+                 total_result += f'{info["subtitle"]}'
+                 total_time += info["time_for_task"]
+                 total_info += f"Language {info['lang']} detected for file '{file_name}{info['ext']}'\n"
+
+             total_info += f"Transcription process done in {self.format_time(total_time)}"
+             result_str = total_result
+             result_file_path = [info['path'] for info in files_info.values()]
+
+             return [result_str, result_file_path, total_info]
+
+         except Exception as e:
+             print(f"Error transcribing file: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     def transcribe_mic(self,
+                        mic_audio: str,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
+                        progress=gr.Progress(),
+                        *whisper_params,
+                        ) -> list:
+         """
+         Write a subtitle file from microphone input
+
+         Parameters
+         ----------
+         mic_audio: str
+             Audio file path from gr.Microphone()
+         file_format: str
+             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper, handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file path to return to gr.Files()
+         """
+         try:
+             progress(0, desc="Loading Audio..")
+             transcribed_segments, time_for_task = self.run(
+                 mic_audio,
+                 progress,
+                 add_timestamp,
+                 *whisper_params,
+             )
+             progress(1, desc="Completed!")
+
+             subtitle, result_file_path = self.generate_and_write_file(
+                 file_name="Mic",
+                 transcribed_segments=transcribed_segments,
+                 add_timestamp=add_timestamp,
+                 file_format=file_format,
+                 output_dir=self.output_dir
+             )
+
+             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+             return [result_str, result_file_path]
+         except Exception as e:
+             print(f"Error transcribing mic audio: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     def transcribe_youtube(self,
+                            youtube_link: str,
+                            file_format: str = "SRT",
+                            add_timestamp: bool = True,
+                            progress=gr.Progress(),
+                            *whisper_params,
+                            ) -> list:
+         """
+         Write a subtitle file from a YouTube video
+
+         Parameters
+         ----------
+         youtube_link: str
+             URL of the YouTube video to transcribe from gr.Textbox()
+         file_format: str
+             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper, handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file path to return to gr.Files()
+         """
+         try:
+             progress(0, desc="Loading Audio from YouTube..")
+             yt = get_ytdata(youtube_link)
+             audio = get_ytaudio(yt)
+
+             transcribed_segments, time_for_task = self.run(
+                 audio,
+                 progress,
+                 add_timestamp,
+                 *whisper_params,
+             )
+
+             progress(1, desc="Completed!")
+
+             file_name = safe_filename(yt.title)
+             subtitle, result_file_path = self.generate_and_write_file(
+                 file_name=file_name,
+                 transcribed_segments=transcribed_segments,
+                 add_timestamp=add_timestamp,
+                 file_format=file_format,
+                 output_dir=self.output_dir
+             )
+             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+
+             # Clean up the downloaded audio file
+             if os.path.exists(audio):
+                 os.remove(audio)
+
+             return [result_str, result_file_path]
+
+         except Exception as e:
+             print(f"Error transcribing YouTube video: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     @staticmethod
+     def generate_and_write_file(file_name: str,
+                                 transcribed_segments: list,
+                                 add_timestamp: bool,
+                                 file_format: str,
+                                 output_dir: str
+                                 ) -> Tuple[str, str]:
+         """
+         Writes a subtitle file
+
+         Parameters
+         ----------
+         file_name: str
+             Output file name
+         transcribed_segments: list
+             Text segments transcribed from audio
+         add_timestamp: bool
+             Determines whether to add a timestamp to the end of the filename.
+         file_format: str
+             File format to write. Supported formats: [SRT, WebVTT, txt]
+         output_dir: str
+             Directory path of the output
+
+         Returns
+         ----------
+         content: str
+             Result of the transcription
+         output_path: str
+             Output file path
+         """
+         if add_timestamp:
+             timestamp = datetime.now().strftime("%m%d%H%M%S")
+             output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
+         else:
+             output_path = os.path.join(output_dir, f"{file_name}")
+
+         file_format = file_format.strip().lower()
+         if file_format == "srt":
+             content = get_srt(transcribed_segments)
+             output_path += '.srt'
+
+         elif file_format == "webvtt":
+             content = get_vtt(transcribed_segments)
+             output_path += '.vtt'
+
+         elif file_format == "txt":
+             content = get_txt(transcribed_segments)
+             output_path += '.txt'
+
+         else:
+             raise ValueError(f"Unsupported file format: {file_format}")
+
+         write_file(content, output_path)
+         return content, output_path
+
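+     # e.g. generate_and_write_file("talk", segments, True, "SRT", "outputs")
+     # writes "outputs/talk - 0412103045.srt" (timestamp shown is illustrative)
+     # and returns the SRT text plus that path.
+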
+     @staticmethod
+     def format_time(elapsed_time: float) -> str:
+         """
+         Get "{hours} hours {minutes} minutes {seconds} seconds" time format string
+
+         Parameters
+         ----------
+         elapsed_time: float
+             Elapsed time for transcription, in seconds
+
+         Returns
+         ----------
+         Time format string
+         """
+         hours, rem = divmod(elapsed_time, 3600)
+         minutes, seconds = divmod(rem, 60)
+
+         time_str = ""
+         if hours:
+             time_str += f"{int(hours)} hours "
+         if minutes:
+             time_str += f"{int(minutes)} minutes "
+         seconds = round(seconds)
+         time_str += f"{seconds} seconds"
+
+         return time_str.strip()
+
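+     # e.g. format_time(3725.4) -> "1 hours 2 minutes 5 seconds"
+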
+     @staticmethod
+     def get_device():
+         if torch.cuda.is_available():
+             return "cuda"
+         elif torch.backends.mps.is_available():
+             if not WhisperBase.is_sparse_api_supported():
+                 # Device `SparseMPS` is not supported for now. See: https://github.com/pytorch/pytorch/issues/87886
+                 return "cpu"
+             return "mps"
+         else:
+             return "cpu"
+
+     @staticmethod
+     def is_sparse_api_supported():
+         if not torch.backends.mps.is_available():
+             return False
+
+         try:
+             # Constructing a sparse tensor on MPS raises RuntimeError when unsupported
+             device = torch.device("mps")
+             sparse_tensor = torch.sparse_coo_tensor(
+                 indices=torch.tensor([[0, 1], [2, 3]]),
+                 values=torch.tensor([1, 2]),
+                 size=(4, 4),
+                 device=device
+             )
+             return True
+         except RuntimeError:
+             return False
+
+     @staticmethod
+     def release_cuda_memory():
+         """Release the CUDA memory cache and reset the peak-allocation counter"""
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.reset_max_memory_allocated()
+
+     @staticmethod
+     def remove_input_files(file_paths: List[str]):
+         """Remove gradio cached files"""
+         if not file_paths:
+             return
+
+         for file_path in file_paths:
+             if file_path and os.path.exists(file_path):
+                 os.remove(file_path)
+
+     @staticmethod
+     def cache_parameters(
+         whisper_params: WhisperValues,
+         add_timestamp: bool
+     ):
+         """Cache parameters to the yaml file"""
+         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+         cached_whisper_param = whisper_params.to_yaml()
+         cached_yaml = {**cached_params, **cached_whisper_param}
+         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+
+         save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
+
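+     # The merged config is keyed by section; `whisper_params.to_yaml()` is assumed
+     # to produce at least a "whisper" section, since "add_timestamp" is written
+     # under it, e.g. (illustrative):
+     #
+     #     whisper:
+     #       add_timestamp: true
+     #       ...
+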
+     @staticmethod
+     def resample_audio(audio: Union[str, np.ndarray],
+                        new_sample_rate: int = 16000,
+                        original_sample_rate: Optional[int] = None,
+                        ) -> np.ndarray:
+         """Resample audio to a 16000 Hz sample rate, the standard input rate for Whisper models"""
+         if isinstance(audio, str):
+             audio, original_sample_rate = torchaudio.load(audio)
+         else:
+             if original_sample_rate is None:
+                 raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
+             audio = torch.from_numpy(audio)
+         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
+         resampled_audio = resampler(audio).numpy()
+         return resampled_audio
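+
+     # Example: resampling an in-memory mono signal to 16 kHz (values illustrative):
+     #
+     #     sr = 44100
+     #     two_seconds = np.random.randn(sr * 2).astype(np.float32)
+     #     audio_16k = WhisperBase.resample_audio(two_seconds, original_sample_rate=sr)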