LAP-DEV committed on
Commit b58c6d8 · verified · Parent: 1e31f50

Upload whisper_base_old.py

Files changed (1):
  modules/whisper/whisper_base_old.py  +648 -0
modules/whisper/whisper_base_old.py ADDED
@@ -0,0 +1,648 @@
import os
import torch
import whisper
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List, Optional
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                 UVR_MODELS_DIR)
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_csv, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
from modules.translation.nllb_inference import NLLBInference
from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS

class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str = WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)
        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()
        self.music_separator = MusicSeparator(
            model_dir=uvr_model_dir,
            output_dir=os.path.join(output_dir, "UVR")
        )

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        #self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.translatable_models = whisper.available_models()
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ):
        """Run inference with the whisper model to transcribe audio"""
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Initialize the whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress = gr.Progress(),
            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
        The diarization will be performed in post-processing, if enabled.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path or binary data.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for running
        """
        params = WhisperParameters.as_value(*whisper_params)

        self.cache_parameters(
            whisper_params=params,
            add_timestamp=add_timestamp
        )

        if params.lang is None:
            pass
        elif params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        if params.is_bgm_separate:
            music, audio, _ = self.music_separator.separate(
                audio=audio,
                model_name=params.uvr_model_size,
                device=params.uvr_device,
                segment_size=params.uvr_segment_size,
                save_file=params.uvr_save_file,
                progress=progress
            )

            if audio.ndim >= 2:
                audio = audio.mean(axis=1)
                if self.music_separator.audio_info is None:
                    origin_sample_rate = 16000
                else:
                    origin_sample_rate = self.music_separator.audio_info.sample_rate
                audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

            if params.uvr_enable_offload:
                self.music_separator.offload()

        if params.vad_filter:
            # Explicit value set for float('inf') from gr.Number()
            if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            audio, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

        result, elapsed_time = self.transcribe(
            audio,
            progress,
            *astuple(params)
        )

        if params.vad_filter:
            result = self.vad.restore_speech_timestamps(
                segments=result,
                speech_chunks=speech_chunks,
            )

        if params.is_diarize:
            result, elapsed_time_diarization = self.diarizer.run(
                audio=audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
            )
            elapsed_time += elapsed_time_diarization
        return result, elapsed_time

    def transcribe_file(self,
                        files: Optional[List] = None,
                        input_folder_path: Optional[str] = None,
                        file_format: str = "SRT",
                        add_timestamp: bool = True,
                        translate_output: bool = False,
                        translate_model: str = "",
                        target_lang: str = "",
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle file from Files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        translate_output: bool
            Whether to translate the output (with NLLB)
        translate_model: str
            Translation model to use
        target_lang: str
            Target language to use
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file paths to return to gr.Files()
        total_info:
            Summary of the run (language, probability, translation notes, total processing time)
        """
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
            if isinstance(files, str):
                files = [files]
            if files and isinstance(files[0], gr.utils.NamedString):
                files = [file.name for file in files]

            ## Initialize variables & start time
            files_info = {}
            files_to_download = {}
            time_start = datetime.now()

            ## Load parameters related to whisper
            params = WhisperParameters.as_value(*whisper_params)

            ## Load model to detect language
            model = whisper.load_model("base")

            for file in files:

                ## Detect language
                mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
                _, probs = model.detect_language(mel)
                file_language = ""
                file_lang_probs = ""
                for key, value in whisper.tokenizer.LANGUAGES.items():
                    if key == str(max(probs, key=probs.get)):
                        file_language = value.capitalize()
                        for key_prob, value_prob in probs.items():
                            if key == key_prob:
                                file_lang_probs = str(round(value_prob * 100))
                                break
                        break

                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )

                # Define source language
                source_lang = file_language

                # Translate to English using Whisper built-in functionality
                transcription_note = ""
                if params.is_translate:
                    if source_lang != "English":
                        transcription_note = "To English"
                        source_lang = "English"
                    else:
                        transcription_note = "Already in English"

                # Translate the transcribed segments
                translation_note = ""
                if translate_output:
                    if source_lang != target_lang:
                        self.nllb_inf = NLLBInference()
                        if source_lang in NLLB_AVAILABLE_LANGS.keys():
                            transcribed_segments = self.nllb_inf.translate_text(
                                input_list_dict=transcribed_segments,
                                model_size=translate_model,
                                src_lang=source_lang,
                                tgt_lang=target_lang,
                                speaker_diarization=params.is_diarize
                            )
                            translation_note = "To " + target_lang
                        else:
                            translation_note = source_lang + " not supported"
                    else:
                        translation_note = "Already in " + target_lang

                ## Get preview as txt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle = get_txt(transcribed_segments)
                files_info[file_name] = {
                    "subtitle": subtitle,
                    "time_for_task": time_for_task,
                    "lang": file_language,
                    "lang_prob": file_lang_probs,
                    "input_source_file": (file_name + file_ext),
                    "translation": translation_note,
                    "transcription": transcription_note
                }

                ## Add output file as txt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="txt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_txt"] = {"path": file_path}

                ## Add output file as srt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="srt",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_srt"] = {"path": file_path}

                ## Add output file as csv
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format="csv",
                    output_dir=self.output_dir
                )
                files_to_download[file_name + "_csv"] = {"path": file_path}

            total_result = ''
            total_info = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]
                total_info += f'Input file:\t\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'

                if params.is_translate:
                    total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'

                if translate_output:
                    total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'

            time_end = datetime.now()
            total_info += f"\nTotal processing time: {self.format_time((time_end - time_start).total_seconds())}"

            result_str = total_result.rstrip("\n")
            result_file_path = [info['path'] for info in files_to_download.values()]

            return [result_str, result_file_path, total_info]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio...")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                add_timestamp,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str = "SRT",
                           add_timestamp: bool = True,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from YouTube

        Parameters
        ----------
        youtube_link: str
            URL of the YouTube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to whisper. These are handled by the "WhisperParameters" data class.

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube...")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                add_timestamp,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            if os.path.exists(audio):
                os.remove(audio)

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt, csv]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            Output file path
        """
        if add_timestamp:
            #timestamp = datetime.now().strftime("%m%d%H%M%S")
            timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
            output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        file_format = file_format.strip().lower()
        if file_format == "srt":
            content = get_srt(transcribed_segments)
            output_path += '.srt'

        elif file_format == "webvtt":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'

        elif file_format == "csv":
            content = get_csv(transcribed_segments)
            output_path += '.csv'

        write_file(content, output_path)
        return content, output_path

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription, in seconds

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""

        hours = round(hours)
        if hours:
            if hours == 1:
                time_str += f"{hours} hour "
            else:
                time_str += f"{hours} hours "

        minutes = round(minutes)
        if minutes:
            if minutes == 1:
                time_str += f"{minutes} minute "
            else:
                time_str += f"{minutes} minutes "

        seconds = round(seconds)
        if seconds == 1:
            time_str += f"{seconds} second"
        else:
            time_str += f"{seconds} seconds"

        return time_str.strip()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            if not WhisperBase.is_sparse_api_supported():
                # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
                return "cpu"
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def is_sparse_api_supported():
        if not torch.backends.mps.is_available():
            return False

        try:
            device = torch.device("mps")
            sparse_tensor = torch.sparse_coo_tensor(
                indices=torch.tensor([[0, 1], [2, 3]]),
                values=torch.tensor([1, 2]),
                size=(4, 4),
                device=device
            )
            return True
        except RuntimeError:
            return False

    @staticmethod
    def release_cuda_memory():
        """Release memory"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        """Remove gradio cached files"""
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(
        whisper_params: WhisperValues,
        add_timestamp: bool
    ):
        """Cache parameters to the yaml file"""
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        cached_whisper_param = whisper_params.to_yaml()
        cached_yaml = {**cached_params, **cached_whisper_param}
        cached_yaml["whisper"]["add_timestamp"] = add_timestamp

        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

    @staticmethod
    def resample_audio(audio: Union[str, np.ndarray],
                       new_sample_rate: int = 16000,
                       original_sample_rate: Optional[int] = None,
                       ) -> np.ndarray:
        """Resamples audio to a 16000 Hz sample rate, the standard input rate for Whisper models"""
        if isinstance(audio, str):
            audio, original_sample_rate = torchaudio.load(audio)
        else:
            if original_sample_rate is None:
                raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
            audio = torch.from_numpy(audio)
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
        resampled_audio = resampler(audio).numpy()
        return resampled_audio
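
For reference, below is a minimal sketch of how a concrete backend might satisfy the two abstract methods (`update_model` and `transcribe`), since the uploaded file only defines the abstract base class. It is not part of the committed file: the class name `SimpleWhisperImpl` and the use of the openai-whisper API (`whisper.load_model` / `model.transcribe`) are illustrative assumptions, and it is meant to live alongside the module above so that `WhisperParameters` and `gr` resolve.

import time

class SimpleWhisperImpl(WhisperBase):
    """Illustrative subclass sketch; assumes the openai-whisper package imported above."""

    def update_model(self, model_size: str, compute_type: str, progress: gr.Progress = gr.Progress()):
        # Load (or reload) an openai-whisper checkpoint into self.model
        progress(0, desc="Initializing Model...")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = whisper.load_model(name=model_size, device=self.device, download_root=self.model_dir)

    def transcribe(self, audio, progress: gr.Progress = gr.Progress(), *whisper_params):
        # Return (segments, elapsed_time) in the shape WhisperBase.run() expects
        start_time = time.time()
        if self.model is None:
            self.update_model(model_size="base", compute_type=self.current_compute_type, progress=progress)
        params = WhisperParameters.as_value(*whisper_params)
        result = self.model.transcribe(audio=audio, language=params.lang, verbose=False)
        segments = [{"start": seg["start"], "end": seg["end"], "text": seg["text"]}
                    for seg in result["segments"]]
        return segments, time.time() - start_time

Under these assumptions, the Gradio handlers above (`transcribe_file`, `transcribe_mic`, `transcribe_youtube`) would work unchanged against such a subclass.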