LAP-DEV committed on
Commit ac35c43 · verified · 1 Parent(s): c16edd0

Upload whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +738 -0
modules/whisper/whisper_base.py ADDED
@@ -0,0 +1,738 @@
import os
import torch
import whisper
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List, Optional
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple
import gc
from copy import deepcopy
from modules.vad.silero_vad import merge_chunks, Segment
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                 UVR_MODELS_DIR)
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_plaintext, get_csv, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
from modules.translation.nllb_inference import NLLBInference
from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
import faster_whisper

class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)
        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()
        self.music_separator = MusicSeparator(
            model_dir=uvr_model_dir,
            output_dir=os.path.join(output_dir, "UVR")
        )

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        # self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.translatable_models = whisper.available_models()
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ):
        """Run inference with the whisper model to transcribe audio"""
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Initialize the whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        VAD is applied in pre-processing to remove silence from the audio input, if enabled.
        Diarization is applied in post-processing, if enabled.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path or binary type.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the WhisperParameters data class.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that include start/end timestamps and transcribed text
        elapsed_time: float
            elapsed time for running
        """

        start_time = datetime.now()
        params = WhisperParameters.as_value(*whisper_params)

        # Get the offload params
        default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        whisper_params = default_params["whisper"]
        diarization_params = default_params["diarization"]
        bool_whisper_enable_offload = whisper_params["enable_offload"]
        bool_diarization_enable_offload = diarization_params["enable_offload"]

        if params.lang is None:
            pass
        elif params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        if params.is_bgm_separate:
            music, audio, _ = self.music_separator.separate(
                audio=audio,
                model_name=params.uvr_model_size,
                device=params.uvr_device,
                segment_size=params.uvr_segment_size,
                save_file=params.uvr_save_file,
                progress=progress
            )

            if audio.ndim >= 2:
                audio = audio.mean(axis=1)
            if self.music_separator.audio_info is None:
                origin_sample_rate = 16000
            else:
                origin_sample_rate = self.music_separator.audio_info.sample_rate
            audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

            if params.uvr_enable_offload:
                self.music_separator.offload()
            elapsed_time_bgm_sep = datetime.now() - start_time

        origin_audio = deepcopy(audio)

        if params.vad_filter:
            # Explicit value set for float('inf') from gr.Number()
            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            progress(0, desc="Filtering silent parts from audio...")
            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            vad_processed, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

            try:
                if vad_processed.size > 0 and speech_chunks:
                    if not isinstance(audio, np.ndarray):
                        loaded_audio = faster_whisper.decode_audio(audio, sampling_rate=self.vad.sampling_rate)
                    else:
                        loaded_audio = audio
                    # Convert speech_chunks to Segment objects and convert samples to seconds
                    segments = [Segment(start=chunk['start'] / self.vad.sampling_rate, end=chunk['end'] / self.vad.sampling_rate) for chunk in speech_chunks]
                    # merge_chunks only works on segments expressed in seconds
                    merged_chunks = merge_chunks(segments, chunk_size=300, onset=0.0, offset=None)
                    all_segments = []
                    total_elapsed_time = 0.0
                    for merged in merged_chunks:
                        chunk_start = merged['start']
                        chunk_end = merged['end']

                        # To slice audio, convert chunk_start and chunk_end from seconds to samples by multiplying by the sampling rate.
                        start_sample = int(chunk_start * self.vad.sampling_rate)
                        end_sample = int(chunk_end * self.vad.sampling_rate)

                        chunk_audio = loaded_audio[start_sample:end_sample]

                        chunk_result, chunk_time = self.transcribe(
                            chunk_audio,
                            progress,
                            *astuple(params)
                        )
                        # Offset timestamps by the chunk start so they refer to the full audio
                        for seg in chunk_result:
                            seg['start'] += chunk_start
                            seg['end'] += chunk_start
                        all_segments.extend(chunk_result)
                        total_elapsed_time += chunk_time
                    result = all_segments
                    elapsed_time = total_elapsed_time
                else:
                    params.vad_filter = False
            except Exception as e:
                print(f"Error transcribing file: {e}")
                # Fall back to transcribing the whole audio below if chunked transcription failed
                params.vad_filter = False

        if not params.vad_filter:
            result, elapsed_time = self.transcribe(
                audio,
                progress,
                *astuple(params)
            )
        if bool_whisper_enable_offload:
            self.offload()

        if params.is_diarize:
            progress(0.99, desc="Diarizing speakers...")
            result, elapsed_time_diarization = self.diarizer.run(
                audio=origin_audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
                device=params.diarization_device
            )
            if bool_diarization_enable_offload:
                self.diarizer.offload()

        if not result:
            print("Whisper did not detect any speech segments in the audio.")
            result = list()

        progress(1.0, desc="Processing done!")
        total_elapsed_time = datetime.now() - start_time
        return result, elapsed_time

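    # Illustrative sketch of the sample/second bookkeeping in run() above, assuming the
    # VAD's 16 kHz sampling rate: speech chunks come back in samples, are merged in
    # seconds, sliced back in samples, and each chunk's timestamps are shifted by the
    # chunk start.
    #
    #     sr = 16000
    #     chunk = {"start": 32000, "end": 160000}        # VAD output in samples
    #     seg = Segment(start=chunk["start"] / sr,       # 2.0 s
    #                   end=chunk["end"] / sr)           # 10.0 s
    #     start_sample = int(seg.start * sr)             # 32000, for slicing loaded_audio
    #     end_sample = int(seg.end * sr)                 # 160000
    #     # a segment transcribed at 1.5 s inside this chunk lands at 1.5 + seg.start = 3.5 s
    #     # in the full audio, which is what the "Offset timestamps" loop does
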
    def transcribe_file(self,
                        files: Optional[List] = None,
                        input_folder_path: Optional[str] = None,
                        file_format: str = "SRT",
                        add_timestamp: bool = True,
                        translate_output: bool = False,
                        translate_model: str = "",
                        target_lang: str = "",
                        add_timestamp_preview: bool = False,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle file from Files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        translate_output: bool
            Whether to translate the transcribed output with NLLB
        translate_model: str
            Translation model to use
        target_lang: str
            Target language to use
        add_timestamp_preview: bool
            Boolean value from gr.Checkbox() that determines whether to add timestamps to the output preview
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the WhisperParameters data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file paths to return to gr.Files()
        total_info:
            Summary of the processed files to return to gr.Textbox()
        """
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
            if isinstance(files, str):
                files = [files]
            if files and isinstance(files[0], gr.utils.NamedString):
                files = [file.name for file in files]

            ## Initialization variables & start time
            files_info = {}
            files_to_download = {}
            time_start = datetime.now()

            ## Load parameters related with whisper
            params = WhisperParameters.as_value(*whisper_params)

            ## Load model to detect language
            model = whisper.load_model("base")

            for file in files:
                print(file)
                ## Detect language
                mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
                _, probs = model.detect_language(mel)
                file_language = ""
                file_lang_probs = ""
                for key, value in whisper.tokenizer.LANGUAGES.items():
                    if key == str(max(probs, key=probs.get)):
                        file_language = value.capitalize()
                        for key_prob, value_prob in probs.items():
                            if key == key_prob:
                                file_lang_probs = str(round(value_prob * 100))
                                break
                        break
                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )
                # Define source language
                source_lang = file_language

                # Translate to English using Whisper built-in functionality
                transcription_note = ""
                if params.is_translate:
                    if source_lang != "English":
                        transcription_note = "To English"
                        source_lang = "English"
                    else:
                        transcription_note = "Already in English"

                # Translate the transcribed segments
                translation_note = ""
                if translate_output:
                    if source_lang != target_lang:
                        self.nllb_inf = NLLBInference()
                        if source_lang in NLLB_AVAILABLE_LANGS.keys():
                            transcribed_segments = self.nllb_inf.translate_text(
                                input_list_dict=transcribed_segments,
                                model_size=translate_model,
                                src_lang=source_lang,
                                tgt_lang=target_lang,
                                speaker_diarization=params.is_diarize
                            )
                            translation_note = "To " + target_lang
                        else:
                            translation_note = source_lang + " not supported"
                    else:
                        translation_note = "Already in " + target_lang

                ## Get preview
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                ## With or without timestamps
                if add_timestamp_preview:
                    subtitle = get_txt(transcribed_segments)
                else:
                    subtitle = get_plaintext(transcribed_segments)
                files_info[file_name] = {
                    "subtitle": subtitle,
                    "time_for_task": time_for_task,
                    "lang": file_language,
                    "lang_prob": file_lang_probs,
                    "input_source_file": (file_name + file_ext),
                    "translation": translation_note,
                    "transcription": transcription_note
                }

                ## Add output file as txt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="txt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_txt"] = {"path": file_path}

                ## Add output file as srt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="srt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_srt"] = {"path": file_path}

                ## Add output file as csv
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="csv",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_csv"] = {"path": file_path}

            total_result = ""
            total_info = ""
            total_time = 0
            for file_name, info in files_info.items():
                total_result += f'{info["subtitle"]}'

                total_time += info["time_for_task"]
                total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'

                if params.is_translate:
                    total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'

                if translate_output:
                    total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'

            time_end = datetime.now()
            total_info += f"\nTotal processing time: {self.format_time((time_end - time_start).total_seconds())}"

            result_str = total_result.rstrip("\n")
            result_file_path = [info['path'] for info in files_to_download.values()]

            return [result_str, result_file_path, total_info]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the WhisperParameters data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio...")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str = "SRT",
                           add_timestamp: bool = True,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the WhisperParameters data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube...")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                add_timestamp,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            if os.path.exists(audio):
                os.remove(audio)

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt, csv]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            output file path
        """
        if add_timestamp:
            # timestamp = datetime.now().strftime("%m%d%H%M%S")
            timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
            output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        file_format = file_format.strip().lower()
        if file_format == "srt":
            content = get_srt(transcribed_segments)
            output_path += '.srt'

        elif file_format == "webvtt":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'

        elif file_format == "csv":
            content = get_csv(transcribed_segments)
            output_path += '.csv'

        write_file(content, output_path)
        return content, output_path

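    # For example (illustrative values): generate_and_write_file("interview", segments, True, "SRT", "outputs")
    # writes "outputs/interview - 20250101 093000.srt" (timestamped with the current time) and returns the
    # SRT content together with that path; with add_timestamp=False the path is simply "outputs/interview.srt".
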
    def offload(self):
        """Offload the model and free up the memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            self.release_cuda_memory()
        gc.collect()

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription, in seconds

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""

        hours = round(hours)
        if hours:
            if hours == 1:
                time_str += f"{hours} hour "
            else:
                time_str += f"{hours} hours "

        minutes = round(minutes)
        if minutes:
            if minutes == 1:
                time_str += f"{minutes} minute "
            else:
                time_str += f"{minutes} minutes "

        seconds = round(seconds)
        if seconds == 1:
            time_str += f"{seconds} second"
        else:
            time_str += f"{seconds} seconds"

        return time_str.strip()

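    # For example: format_time(3725.0) returns "1 hour 2 minutes 5 seconds",
    # and format_time(59.0) returns "59 seconds".
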
    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            if not WhisperBase.is_sparse_api_supported():
                # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
                return "cpu"
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def is_sparse_api_supported():
        if not torch.backends.mps.is_available():
            return False

        try:
            device = torch.device("mps")
            sparse_tensor = torch.sparse_coo_tensor(
                indices=torch.tensor([[0, 1], [2, 3]]),
                values=torch.tensor([1, 2]),
                size=(4, 4),
                device=device
            )
            return True
        except RuntimeError:
            return False

    @staticmethod
    def release_cuda_memory():
        """Release memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Remove gradio cached files"""
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(
        params: WhisperValues,
        file_format: str = "SRT",
        add_timestamp: bool = True
    ):
        """Cache parameters to the yaml file"""
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        param_to_cache = params.to_dict()

        cached_yaml = {**cached_params, **param_to_cache}
        cached_yaml["whisper"]["add_timestamp"] = add_timestamp
        cached_yaml["whisper"]["file_format"] = file_format

        suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
        if suppress_token and isinstance(suppress_token, list):
            cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)

        if cached_yaml["whisper"].get("lang", None) is None:
            cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
        else:
            language_dict = whisper.tokenizer.LANGUAGES
            cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]

        if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
            cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX

        if cached_yaml:
            save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

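    # Note on the merge above: {**cached_params, **param_to_cache} is a shallow merge, so any
    # top-level section present in both dicts is taken wholesale from param_to_cache, e.g.
    #     {**{"whisper": {"lang": "en"}}, **{"whisper": {"lang": "ko"}}}  # -> {"whisper": {"lang": "ko"}}
    # add_timestamp and file_format are then written back into the merged "whisper" section.
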
    @staticmethod
    def resample_audio(audio: Union[str, np.ndarray],
                       new_sample_rate: int = 16000,
                       original_sample_rate: Optional[int] = None,
                       ) -> np.ndarray:
        """Resample audio to a 16k sample rate, the rate expected by Whisper models"""
        if isinstance(audio, str):
            audio, original_sample_rate = torchaudio.load(audio)
        else:
            if original_sample_rate is None:
                raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
            audio = torch.from_numpy(audio)
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
        resampled_audio = resampler(audio).numpy()
        return resampled_audio
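
# Minimal usage sketch, assuming a concrete subclass (such as the project's faster-whisper
# implementation) that provides transcribe() and update_model(), and assuming the Gradio UI
# supplies the values normally collected into *whisper_params by WhisperParameters:
#
#     inferencer = SomeWhisperSubclass()                    # hypothetical subclass name
#     result_str, file_paths, info = inferencer.transcribe_file(
#         files=["sample.wav"],
#         file_format="SRT",
#         add_timestamp=True,
#         # ... UI values unpacked into *whisper_params would follow here ...
#     )
#     print(info)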