LAP-DEV committed on
Commit c16edd0 · verified · 1 Parent(s): d33be88

Delete modules/whisper/whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +0 -715
modules/whisper/whisper_base.py DELETED
@@ -1,715 +0,0 @@
- import os
- import torch
- import whisper
- import gradio as gr
- import torchaudio
- from abc import ABC, abstractmethod
- from typing import BinaryIO, Union, Tuple, List
- import numpy as np
- from datetime import datetime
- from faster_whisper.vad import VadOptions
- from dataclasses import astuple
- import gc
- from copy import deepcopy
-
- from modules.uvr.music_separator import MusicSeparator
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
-                                  UVR_MODELS_DIR)
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_plaintext, get_csv, write_file, safe_filename
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
- from modules.whisper.whisper_parameter import *
- from modules.diarize.diarizer import Diarizer
- from modules.vad.silero_vad import SileroVAD
- from modules.translation.nllb_inference import NLLBInference
- from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
-
- class WhisperBase(ABC):
-     def __init__(self,
-                  model_dir: str = WHISPER_MODELS_DIR,
-                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
-                  uvr_model_dir: str = UVR_MODELS_DIR,
-                  output_dir: str = OUTPUT_DIR,
-                  ):
-         self.model_dir = model_dir
-         self.output_dir = output_dir
-         os.makedirs(self.output_dir, exist_ok=True)
-         os.makedirs(self.model_dir, exist_ok=True)
-         self.diarizer = Diarizer(
-             model_dir=diarization_model_dir
-         )
-         self.vad = SileroVAD()
-         self.music_separator = MusicSeparator(
-             model_dir=uvr_model_dir,
-             output_dir=os.path.join(output_dir, "UVR")
-         )
-
-         self.model = None
-         self.current_model_size = None
-         self.available_models = whisper.available_models()
-         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
-         #self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
-         self.translatable_models = whisper.available_models()
-         self.device = self.get_device()
-         self.available_compute_types = ["float16", "float32"]
-         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-
-     @abstractmethod
-     def transcribe(self,
-                    audio: Union[str, BinaryIO, np.ndarray],
-                    progress: gr.Progress = gr.Progress(),
-                    *whisper_params,
-                    ):
-         """Inference whisper model to transcribe"""
-         pass
-
-     @abstractmethod
-     def update_model(self,
-                      model_size: str,
-                      compute_type: str,
-                      progress: gr.Progress = gr.Progress()
-                      ):
-         """Initialize whisper model"""
-         pass
-
-     def run(self,
-             audio: Union[str, BinaryIO, np.ndarray],
-             progress: gr.Progress = gr.Progress(),
-             add_timestamp: bool = True,
-             *whisper_params,
-             ) -> Tuple[List[dict], float]:
-         """
-         Run transcription with conditional pre-processing and post-processing.
-         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
-         The diarization will be performed in post-processing, if enabled.
-
-         Parameters
-         ----------
-         audio: Union[str, BinaryIO, np.ndarray]
-             Audio input. This can be file path or binary type.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         add_timestamp: bool
-             Whether to add a timestamp at the end of the filename.
-         *whisper_params: tuple
-             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
-
-         Returns
-         ----------
-         segments_result: List[dict]
-             list of dicts that includes start, end timestamps and transcribed text
-         elapsed_time: float
-             elapsed time for running
-         """
-
-         start_time = datetime.now()
-         params = WhisperParameters.as_value(*whisper_params)
-
-         # Get the offload params
-         default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-         whisper_params = default_params["whisper"]
-         diarization_params = default_params["diarization"]
-         bool_whisper_enable_offload = whisper_params["enable_offload"]
-         bool_diarization_enable_offload = diarization_params["enable_offload"]
-
-         if params.lang is None:
-             pass
-         elif params.lang == "Automatic Detection":
-             params.lang = None
-         else:
-             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-             params.lang = language_code_dict[params.lang]
-
-         if params.is_bgm_separate:
-             music, audio, _ = self.music_separator.separate(
-                 audio=audio,
-                 model_name=params.uvr_model_size,
-                 device=params.uvr_device,
-                 segment_size=params.uvr_segment_size,
-                 save_file=params.uvr_save_file,
-                 progress=progress
-             )
-
-             if audio.ndim >= 2:
-                 audio = audio.mean(axis=1)
-                 if self.music_separator.audio_info is None:
-                     origin_sample_rate = 16000
-                 else:
-                     origin_sample_rate = self.music_separator.audio_info.sample_rate
-                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
-
-             if params.uvr_enable_offload:
-                 self.music_separator.offload()
-             elapsed_time_bgm_sep = datetime.now() - start_time
-
-         origin_audio = deepcopy(audio)
-
-         if params.vad_filter:
-             # Explicit value set for float('inf') from gr.Number()
-             if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                 params.max_speech_duration_s = float('inf')
-
-             progress(0, desc="Filtering silent parts from audio...")
-             vad_options = VadOptions(
-                 threshold=params.threshold,
-                 min_speech_duration_ms=params.min_speech_duration_ms,
-                 max_speech_duration_s=params.max_speech_duration_s,
-                 min_silence_duration_ms=params.min_silence_duration_ms,
-                 speech_pad_ms=params.speech_pad_ms
-             )
-
-             vad_processed, speech_chunks = self.vad.run(
-                 audio=audio,
-                 vad_parameters=vad_options,
-                 progress=progress
-             )
-
-             if vad_processed.size > 0:
-                 audio = vad_processed
-             else:
-                 params.vad_filter = False
-
-         result, elapsed_time = self.transcribe(
-             audio,
-             progress,
-             *astuple(params)
-         )
-         if bool_whisper_enable_offload:
-             self.offload()
-
-         if params.vad_filter:
-             restored_result = self.vad.restore_speech_timestamps(
-                 segments=result,
-                 speech_chunks=speech_chunks,
-             )
-             if restored_result:
-                 result = restored_result
-             else:
-                 print("VAD detected no speech segments in the audio.")
-
-         if params.is_diarize:
-             progress(0.99, desc="Diarizing speakers...")
-             result, elapsed_time_diarization = self.diarizer.run(
-                 audio=origin_audio,
-                 use_auth_token=params.hf_token,
-                 transcribed_result=result,
-                 device=params.diarization_device
-             )
-             if bool_diarization_enable_offload:
-                 self.diarizer.offload()
-
-         if not result:
-             print(f"Whisper did not detected any speech segments in the audio.")
-             result = list()
-
-         progress(1.0, desc="Processing done!")
-         total_elapsed_time = datetime.now() - start_time
-         return result, elapsed_time
-
-     def transcribe_file(self,
-                         files: Optional[List] = None,
-                         input_folder_path: Optional[str] = None,
-                         file_format: str = "SRT",
-                         add_timestamp: bool = True,
-                         translate_output: bool = False,
-                         translate_model: str = "",
-                         target_lang: str = "",
-                         add_timestamp_preview: bool = False,
-                         progress=gr.Progress(),
-                         *whisper_params,
-                         ) -> list:
-         """
-         Write subtitle file from Files
-
-         Parameters
-         ----------
-         files: list
-             List of files to transcribe from gr.Files()
-         input_folder_path: str
-             Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
-             this will be used instead.
-         file_format: str
-             Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
-         translate_output: bool
-             Translate output
-         translate_model: str
-             Translation model to use
-         target_lang: str
-             Target language to use
-         add_timestamp_preview: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp to output preview
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             if input_folder_path:
-                 files = get_media_files(input_folder_path)
-             if isinstance(files, str):
-                 files = [files]
-             if files and isinstance(files[0], gr.utils.NamedString):
-                 files = [file.name for file in files]
-
-             ## Initialization variables & start time
-             files_info = {}
-             files_to_download = {}
-             time_start = datetime.now()
-
-             ## Load parameters related with whisper
-             params = WhisperParameters.as_value(*whisper_params)
-
-             ## Load model to detect language
-             model = whisper.load_model("base")
-
-             for file in files:
-
-                 ## Detect language
-                 mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
-                 _, probs = model.detect_language(mel)
-                 file_language = ""
-                 file_lang_probs = ""
-                 for key,value in whisper.tokenizer.LANGUAGES.items():
-                     if key == str(max(probs, key=probs.get)):
-                         file_language = value.capitalize()
-                         for key_prob,value_prob in probs.items():
-                             if key == key_prob:
-                                 file_lang_probs = str((round(value_prob*100)))
-                                 break
-                         break
-
-                 transcribed_segments, time_for_task = self.run(
-                     file,
-                     progress,
-                     add_timestamp,
-                     *whisper_params,
-                 )
-
-                 # Define source language
-                 source_lang = file_language
-
-                 # Translate to English using Whisper built-in functionality
-                 transcription_note = ""
-                 if params.is_translate:
-                     if source_lang != "English":
-                         transcription_note = "To English"
-                         source_lang = "English"
-                     else:
-                         transcription_note = "Already in English"
-
-                 # Translate the transcribed segments
-                 translation_note = ""
-                 if translate_output:
-                     if source_lang != target_lang:
-                         self.nllb_inf = NLLBInference()
-                         if source_lang in NLLB_AVAILABLE_LANGS.keys():
-                             transcribed_segments = self.nllb_inf.translate_text(
-                                 input_list_dict=transcribed_segments,
-                                 model_size=translate_model,
-                                 src_lang=source_lang,
-                                 tgt_lang=target_lang,
-                                 speaker_diarization=params.is_diarize
-                             )
-                             translation_note = "To " + target_lang
-                         else:
-                             translation_note = source_lang + " not supported"
-                     else:
-                         translation_note = "Already in " + target_lang
-
-                 ## Get preview
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-
-                 ## With or without timestamps
-                 if add_timestamp_preview:
-                     subtitle = get_txt(transcribed_segments)
-                 else:
-                     subtitle = get_plaintext(transcribed_segments)
-
-                 files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "lang": file_language, "lang_prob": file_lang_probs, "input_source_file": (file_name+file_ext), "translation": translation_note, "transcription": transcription_note}
-
-                 ## Add output file as txt
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-                 subtitle, file_path = self.generate_and_write_file(
-                     file_name=file_name,
-                     transcribed_segments=transcribed_segments,
-                     add_timestamp=add_timestamp,
-                     file_format="txt",
-                     output_dir=self.output_dir
-                 )
-                 files_to_download[file_name+"_txt"] = {"path": file_path}
-
-                 ## Add output file as srt
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-                 subtitle, file_path = self.generate_and_write_file(
-                     file_name=file_name,
-                     transcribed_segments=transcribed_segments,
-                     add_timestamp=add_timestamp,
-                     file_format="srt",
-                     output_dir=self.output_dir
-                 )
-                 files_to_download[file_name+"_srt"] = {"path": file_path}
-
-                 ## Add output file as csv
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-                 subtitle, file_path = self.generate_and_write_file(
-                     file_name=file_name,
-                     transcribed_segments=transcribed_segments,
-                     add_timestamp=add_timestamp,
-                     file_format="csv",
-                     output_dir=self.output_dir
-                 )
-                 files_to_download[file_name+"_csv"] = {"path": file_path}
-
-             total_result = ""
-             total_info = ""
-             total_time = 0
-             for file_name, info in files_info.items():
-                 total_result += f'{info["subtitle"]}'
-
-                 total_time += info["time_for_task"]
-                 total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
-
-                 if params.is_translate:
-                     total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'
-
-                 if translate_output:
-                     total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'
-
-             time_end = datetime.now()
-             total_info += f"\nTotal processing time: {self.format_time((time_end-time_start).total_seconds())}"
-
-             result_str = total_result.rstrip("\n")
-             result_file_path = [info['path'] for info in files_to_download.values()]
-
-             return [result_str,result_file_path,total_info]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     def transcribe_mic(self,
-                        mic_audio: str,
-                        file_format: str = "SRT",
-                        add_timestamp: bool = True,
-                        progress=gr.Progress(),
-                        *whisper_params,
-                        ) -> list:
-         """
-         Write subtitle file from microphone
-
-         Parameters
-         ----------
-         mic_audio: str
-             Audio file path from gr.Microphone()
-         file_format: str
-             Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio...")
-             transcribed_segments, time_for_task = self.run(
-                 mic_audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-             progress(1, desc="Completed!")
-
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name="Mic",
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-             return [result_str, result_file_path]
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     def transcribe_youtube(self,
-                            youtube_link: str,
-                            file_format: str = "SRT",
-                            add_timestamp: bool = True,
-                            progress=gr.Progress(),
-                            *whisper_params,
-                            ) -> list:
-         """
-         Write subtitle file from Youtube
-
-         Parameters
-         ----------
-         youtube_link: str
-             URL of the Youtube video to transcribe from gr.Textbox()
-         file_format: str
-             Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio from Youtube...")
-             yt = get_ytdata(youtube_link)
-             audio = get_ytaudio(yt)
-
-             transcribed_segments, time_for_task = self.run(
-                 audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-
-             progress(1, desc="Completed!")
-
-             file_name = safe_filename(yt.title)
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name=file_name,
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-
-             if os.path.exists(audio):
-                 os.remove(audio)
-
-             return [result_str, result_file_path]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     @staticmethod
-     def generate_and_write_file(file_name: str,
-                                 transcribed_segments: list,
-                                 add_timestamp: bool,
-                                 file_format: str,
-                                 output_dir: str
-                                 ) -> str:
-         """
-         Writes subtitle file
-
-         Parameters
-         ----------
-         file_name: str
-             Output file name
-         transcribed_segments: list
-             Text segments transcribed from audio
-         add_timestamp: bool
-             Determines whether to add a timestamp to the end of the filename.
-         file_format: str
-             File format to write. Supported formats: [SRT, WebVTT, txt, csv]
-         output_dir: str
-             Directory path of the output
-
-         Returns
-         ----------
-         content: str
-             Result of the transcription
-         output_path: str
-             output file path
-         """
-         if add_timestamp:
-             #timestamp = datetime.now().strftime("%m%d%H%M%S")
-             timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
-             output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
-         else:
-             output_path = os.path.join(output_dir, f"{file_name}")
-
-         file_format = file_format.strip().lower()
-         if file_format == "srt":
-             content = get_srt(transcribed_segments)
-             output_path += '.srt'
-
-         elif file_format == "webvtt":
-             content = get_vtt(transcribed_segments)
-             output_path += '.vtt'
-
-         elif file_format == "txt":
-             content = get_txt(transcribed_segments)
-             output_path += '.txt'
-
-         elif file_format == "csv":
-             content = get_csv(transcribed_segments)
-             output_path += '.csv'
-
-         write_file(content, output_path)
-         return content, output_path
-
-     def offload(self):
-         """Offload the model and free up the memory"""
-         if self.model is not None:
-             del self.model
-             self.model = None
-         if self.device == "cuda":
-             self.release_cuda_memory()
-         gc.collect()
-
-     @staticmethod
-     def format_time(elapsed_time: float) -> str:
-         """
-         Get {hours} {minutes} {seconds} time format string
-
-         Parameters
-         ----------
-         elapsed_time: str
-             Elapsed time for transcription
-
-         Returns
-         ----------
-         Time format string
-         """
-         hours, rem = divmod(elapsed_time, 3600)
-         minutes, seconds = divmod(rem, 60)
-
-         time_str = ""
-
-         hours = round(hours)
-         if hours:
-             if hours == 1:
-                 time_str += f"{hours} hour "
-             else:
-                 time_str += f"{hours} hours "
-
-         minutes = round(minutes)
-         if minutes:
-             if minutes == 1:
-                 time_str += f"{minutes} minute "
-             else:
-                 time_str += f"{minutes} minutes "
-
-         seconds = round(seconds)
-         if seconds == 1:
-             time_str += f"{seconds} second"
-         else:
-             time_str += f"{seconds} seconds"
-
-         return time_str.strip()
-
-     @staticmethod
-     def get_device():
-         if torch.cuda.is_available():
-             return "cuda"
-         elif torch.backends.mps.is_available():
-             if not WhisperBase.is_sparse_api_supported():
-                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
-                 return "cpu"
-             return "mps"
-         else:
-             return "cpu"
-
-     @staticmethod
-     def is_sparse_api_supported():
-         if not torch.backends.mps.is_available():
-             return False
-
-         try:
-             device = torch.device("mps")
-             sparse_tensor = torch.sparse_coo_tensor(
-                 indices=torch.tensor([[0, 1], [2, 3]]),
-                 values=torch.tensor([1, 2]),
-                 size=(4, 4),
-                 device=device
-             )
-             return True
-         except RuntimeError:
-             return False
-
-     @staticmethod
-     def release_cuda_memory():
-         """Release memory"""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.reset_max_memory_allocated()
-
-     @staticmethod
-     def remove_input_files(file_paths: List[str]):
-         """Remove gradio cached files"""
-         if not file_paths:
-             return
-
-         for file_path in file_paths:
-             if file_path and os.path.exists(file_path):
-                 os.remove(file_path)
-
-     @staticmethod
-     def cache_parameters(
-         params: WhisperValues,
-         file_format: str = "SRT",
-         add_timestamp: bool = True
-     ):
-         """Cache parameters to the yaml file"""
-         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-         param_to_cache = params.to_dict()
-
-         cached_yaml = {**cached_params, **param_to_cache}
-         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
-         cached_yaml["whisper"]["file_format"] = file_format
-
-         suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
-         if suppress_token and isinstance(suppress_token, list):
-             cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
-
-         if cached_yaml["whisper"].get("lang", None) is None:
-             cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
-         else:
-             language_dict = whisper.tokenizer.LANGUAGES
-             cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
-
-         if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
-             cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
-
-         if cached_yaml is not None and cached_yaml:
-             save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
-
-     @staticmethod
-     def resample_audio(audio: Union[str, np.ndarray],
-                        new_sample_rate: int = 16000,
-                        original_sample_rate: Optional[int] = None,) -> np.ndarray:
-         """Resamples audio to 16k sample rate, standard on Whisper model"""
-         if isinstance(audio, str):
-             audio, original_sample_rate = torchaudio.load(audio)
-         else:
-             if original_sample_rate is None:
-                 raise ValueError("original_sample_rate must be provided when audio is numpy array.")
-             audio = torch.from_numpy(audio)
-         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
-         resampled_audio = resampler(audio).numpy()
-         return resampled_audio
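
For context on what this removal takes away: the deleted file defined WhisperBase as an abstract base class, where concrete backends implement transcribe() and update_model() while run(), transcribe_file(), transcribe_mic() and transcribe_youtube() carry the shared pre/post-processing. The sketch below is illustrative only and not code from this repository; it assumes the openai-whisper package, keeps the (segments, elapsed_time) return shape that run() expects, and the subclass name and the "base" fallback model size are hypothetical.

import time
from typing import BinaryIO, Union

import gradio as gr
import numpy as np
import whisper

# Module paths as they existed before this commit.
from modules.whisper.whisper_base import WhisperBase
from modules.whisper.whisper_parameter import WhisperParameters


class SimpleWhisperInference(WhisperBase):  # hypothetical subclass name
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()):
        """Load an openai-whisper checkpoint into self.model."""
        progress(0, desc=f"Loading {model_size} model...")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = whisper.load_model(model_size, device=self.device, download_root=self.model_dir)

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params):
        """Return (segments, elapsed_time) in the shape WhisperBase.run() expects."""
        start = time.time()
        params = WhisperParameters.as_value(*whisper_params)
        if self.model is None:
            # "base" is only a fallback default here; callers normally load a model first.
            self.update_model("base", self.current_compute_type, progress)

        progress(0, desc="Transcribing...")
        result = self.model.transcribe(audio,
                                       language=params.lang,
                                       task="translate" if params.is_translate else "transcribe",
                                       verbose=False)
        segments = [{"start": seg["start"], "end": seg["end"], "text": seg["text"]}
                    for seg in result["segments"]]
        return segments, time.time() - start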