LAP-DEV committed on
Commit 904a73a · verified · 1 Parent(s): c3f75cb

Delete modules/whisper/whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +0 -542
modules/whisper/whisper_base.py DELETED
@@ -1,542 +0,0 @@
- import os
- import torch
- import whisper
- import gradio as gr
- import torchaudio
- from abc import ABC, abstractmethod
- from typing import BinaryIO, Union, Tuple, List
- import numpy as np
- from datetime import datetime
- from faster_whisper.vad import VadOptions
- from dataclasses import astuple
-
- from modules.uvr.music_separator import MusicSeparator
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
-                                  UVR_MODELS_DIR)
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
- from modules.whisper.whisper_parameter import *
- from modules.diarize.diarizer import Diarizer
- from modules.vad.silero_vad import SileroVAD
-
-
- class WhisperBase(ABC):
-     def __init__(self,
-                  model_dir: str = WHISPER_MODELS_DIR,
-                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
-                  uvr_model_dir: str = UVR_MODELS_DIR,
-                  output_dir: str = OUTPUT_DIR,
-                  ):
-         self.model_dir = model_dir
-         self.output_dir = output_dir
-         os.makedirs(self.output_dir, exist_ok=True)
-         os.makedirs(self.model_dir, exist_ok=True)
-         self.diarizer = Diarizer(
-             model_dir=diarization_model_dir
-         )
-         self.vad = SileroVAD()
-         self.music_separator = MusicSeparator(
-             model_dir=uvr_model_dir,
-             output_dir=os.path.join(output_dir, "UVR")
-         )
-
-         self.model = None
-         self.current_model_size = None
-         self.available_models = whisper.available_models()
-         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
-         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
-         self.device = self.get_device()
-         self.available_compute_types = ["float16", "float32"]
-         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-
-     @abstractmethod
-     def transcribe(self,
-                    audio: Union[str, BinaryIO, np.ndarray],
-                    progress: gr.Progress = gr.Progress(),
-                    *whisper_params,
-                    ):
-         """Run Whisper inference to transcribe audio"""
-         pass
-
-     @abstractmethod
-     def update_model(self,
-                      model_size: str,
-                      compute_type: str,
-                      progress: gr.Progress = gr.Progress()
-                      ):
-         """Initialize the Whisper model"""
-         pass
-
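# Illustrative sketch, not part of the deleted file: a minimal subclass showing how the two
# abstract methods above are meant to be filled in. The class name and the use of openai-whisper
# as the backend are hypothetical (the real backends live elsewhere in this repo); it assumes the
# imports defined at the top of the file and a file path or float32 numpy array as `audio`.
import time

class ToyWhisperImpl(WhisperBase):
    def update_model(self, model_size: str, compute_type: str, progress: gr.Progress = gr.Progress()):
        # Load (or reload) the model only when the requested size changes.
        if self.model is None or self.current_model_size != model_size:
            self.model = whisper.load_model(model_size, device=self.device, download_root=self.model_dir)
            self.current_model_size = model_size
            self.current_compute_type = compute_type

    def transcribe(self, audio, progress: gr.Progress = gr.Progress(), *whisper_params):
        params = WhisperParameters.as_value(*whisper_params)
        self.update_model(params.model_size, self.current_compute_type, progress)
        start = time.time()
        result = self.model.transcribe(audio, language=params.lang)
        # run() expects (list of segment dicts with start/end/text, elapsed seconds).
        return result["segments"], time.time() - start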
-     def run(self,
-             audio: Union[str, BinaryIO, np.ndarray],
-             progress: gr.Progress = gr.Progress(),
-             add_timestamp: bool = True,
-             *whisper_params,
-             ) -> Tuple[List[dict], float]:
-         """
-         Run transcription with conditional pre-processing and post-processing.
-         If enabled, VAD is applied as pre-processing to remove non-speech from the audio input,
-         and diarization is applied as post-processing.
-
-         Parameters
-         ----------
-         audio: Union[str, BinaryIO, np.ndarray]
-             Audio input. This can be a file path, binary stream, or numpy array.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         add_timestamp: bool
-             Whether to add a timestamp at the end of the filename.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         segments_result: List[dict]
-             List of dicts that include start/end timestamps and the transcribed text.
-         elapsed_time: float
-             Elapsed time for the run.
-         """
-         params = WhisperParameters.as_value(*whisper_params)
-
-         self.cache_parameters(
-             whisper_params=params,
-             add_timestamp=add_timestamp
-         )
-
-         if params.lang is None:
-             pass
-         elif params.lang == "Automatic Detection":
-             params.lang = None
-         else:
-             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-             params.lang = language_code_dict[params.lang]
-
-         if params.is_bgm_separate:
-             music, audio, _ = self.music_separator.separate(
-                 audio=audio,
-                 model_name=params.uvr_model_size,
-                 device=params.uvr_device,
-                 segment_size=params.uvr_segment_size,
-                 save_file=params.uvr_save_file,
-                 progress=progress
-             )
-
-             if audio.ndim >= 2:
-                 audio = audio.mean(axis=1)
-                 if self.music_separator.audio_info is None:
-                     origin_sample_rate = 16000
-                 else:
-                     origin_sample_rate = self.music_separator.audio_info.sample_rate
-                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
-
-             if params.uvr_enable_offload:
-                 self.music_separator.offload()
-
-         if params.vad_filter:
-             # Explicit value set for float('inf') from gr.Number()
-             if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
-                 params.max_speech_duration_s = float('inf')
-
-             vad_options = VadOptions(
-                 threshold=params.threshold,
-                 min_speech_duration_ms=params.min_speech_duration_ms,
-                 max_speech_duration_s=params.max_speech_duration_s,
-                 min_silence_duration_ms=params.min_silence_duration_ms,
-                 speech_pad_ms=params.speech_pad_ms
-             )
-
-             audio, speech_chunks = self.vad.run(
-                 audio=audio,
-                 vad_parameters=vad_options,
-                 progress=progress
-             )
-
-         result, elapsed_time = self.transcribe(
-             audio,
-             progress,
-             *astuple(params)
-         )
-
-         if params.vad_filter:
-             result = self.vad.restore_speech_timestamps(
-                 segments=result,
-                 speech_chunks=speech_chunks,
-             )
-
-         if params.is_diarize:
-             result, elapsed_time_diarization = self.diarizer.run(
-                 audio=audio,
-                 use_auth_token=params.hf_token,
-                 transcribed_result=result,
-             )
-             elapsed_time += elapsed_time_diarization
-         return result, elapsed_time
-
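# Illustrative sketch, not part of the deleted file: how run() maps the UI language name to a
# Whisper language code. whisper.tokenizer.LANGUAGES maps codes to lowercase names
# (e.g. "en" -> "english"), which is also where available_langs comes from, so the dict is
# inverted before the lookup.
import whisper

language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
print(language_code_dict["english"])   # -> "en"
print(language_code_dict["korean"])    # -> "ko"
# "Automatic Detection" is handled separately above: it becomes lang=None,
# which tells the backend to auto-detect the language.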
-     def transcribe_file(self,
-                         files: Optional[List] = None,
-                         input_folder_path: Optional[str] = None,
-                         file_format: str = "SRT",
-                         add_timestamp: bool = True,
-                         progress=gr.Progress(),
-                         *whisper_params,
-                         ) -> list:
-         """
-         Write subtitle files from uploaded files
-
-         Parameters
-         ----------
-         files: list
-             List of files to transcribe from gr.Files()
-         input_folder_path: str
-             Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
-             this will be used instead.
-         file_format: str
-             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             if input_folder_path:
-                 files = get_media_files(input_folder_path)
-             if isinstance(files, str):
-                 files = [files]
-             if files and isinstance(files[0], gr.utils.NamedString):
-                 files = [file.name for file in files]
-
-             files_info = {}
-             for file in files:
-
-                 ## Detect language
-                 #model = whisper.load_model("base")
-                 params = WhisperParameters.as_value(*whisper_params)
-                 model = whisper.load_model(params.model_size)
-                 mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
-                 _, probs = model.detect_language(mel)
-                 file_language = "not"
-                 for key, value in whisper.tokenizer.LANGUAGES.items():
-                     if key == str(max(probs, key=probs.get)):
-                         file_language = value.capitalize()
-                         break
-
-                 transcribed_segments, time_for_task = self.run(
-                     file,
-                     progress,
-                     add_timestamp,
-                     *whisper_params,
-                 )
-
-                 file_name, file_ext = os.path.splitext(os.path.basename(file))
-                 subtitle, file_path = self.generate_and_write_file(
-                     file_name=file_name,
-                     transcribed_segments=transcribed_segments,
-                     add_timestamp=add_timestamp,
-                     file_format=file_format,
-                     output_dir=self.output_dir
-                 )
-
-                 # Store the extension per file so the summary loop below reports the correct one
-                 files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path, "lang": file_language, "ext": file_ext}
-
-             total_result = ''
-             total_info = ''
-             total_time = 0
-             for file_name, info in files_info.items():
-                 total_result += f'{info["subtitle"]}'
-                 total_time += info["time_for_task"]
-                 #total_info += f'{info["lang"]}'
-                 total_info += f"Language {info['lang']} detected for file '{file_name}{info['ext']}'"
-
-             #result_str = f"Processing of file '{file_name}{file_ext}' done in {self.format_time(total_time)}:\n\n{total_result}"
-             total_info += f"\nTranscription process done in {self.format_time(total_time)}"
-             result_str = total_result
-             result_file_path = [info['path'] for info in files_info.values()]
-
-             return [result_str, result_file_path, total_info]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
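# Illustrative sketch, not part of the deleted file: the per-file language detection used above,
# as a standalone snippet. detect_language() returns per-language probabilities for a single mel
# segment; the most probable code is then mapped back to a readable name. "audio.mp3" is a
# placeholder path.
import whisper

model = whisper.load_model("base")
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio("audio.mp3"))).to(model.device)
_, probs = model.detect_language(mel)
best_code = max(probs, key=probs.get)                           # e.g. "en"
language = whisper.tokenizer.LANGUAGES[best_code].capitalize()  # e.g. "English"
print(best_code, language)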
-     def transcribe_mic(self,
-                        mic_audio: str,
-                        file_format: str = "SRT",
-                        add_timestamp: bool = True,
-                        progress=gr.Progress(),
-                        *whisper_params,
-                        ) -> list:
-         """
-         Write subtitle file from microphone
-
-         Parameters
-         ----------
-         mic_audio: str
-             Audio file path from gr.Microphone()
-         file_format: str
-             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio..")
-             transcribed_segments, time_for_task = self.run(
-                 mic_audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-             progress(1, desc="Completed!")
-
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name="Mic",
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-             return [result_str, result_file_path]
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     def transcribe_youtube(self,
-                            youtube_link: str,
-                            file_format: str = "SRT",
-                            add_timestamp: bool = True,
-                            progress=gr.Progress(),
-                            *whisper_params,
-                            ) -> list:
-         """
-         Write subtitle file from a YouTube video
-
-         Parameters
-         ----------
-         youtube_link: str
-             URL of the YouTube video to transcribe from gr.Textbox()
-         file_format: str
-             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
-         add_timestamp: bool
-             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-         progress: gr.Progress
-             Indicator to show progress directly in gradio.
-         *whisper_params: tuple
-             Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
-
-         Returns
-         ----------
-         result_str:
-             Result of transcription to return to gr.Textbox()
-         result_file_path:
-             Output file path to return to gr.Files()
-         """
-         try:
-             progress(0, desc="Loading Audio from Youtube..")
-             yt = get_ytdata(youtube_link)
-             audio = get_ytaudio(yt)
-
-             transcribed_segments, time_for_task = self.run(
-                 audio,
-                 progress,
-                 add_timestamp,
-                 *whisper_params,
-             )
-
-             progress(1, desc="Completed!")
-
-             file_name = safe_filename(yt.title)
-             subtitle, result_file_path = self.generate_and_write_file(
-                 file_name=file_name,
-                 transcribed_segments=transcribed_segments,
-                 add_timestamp=add_timestamp,
-                 file_format=file_format,
-                 output_dir=self.output_dir
-             )
-             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-
-             if os.path.exists(audio):
-                 os.remove(audio)
-
-             return [result_str, result_file_path]
-
-         except Exception as e:
-             print(f"Error transcribing file: {e}")
-         finally:
-             self.release_cuda_memory()
-
-     @staticmethod
-     def generate_and_write_file(file_name: str,
-                                 transcribed_segments: list,
-                                 add_timestamp: bool,
-                                 file_format: str,
-                                 output_dir: str
-                                 ) -> Tuple[str, str]:
-         """
-         Writes subtitle file
-
-         Parameters
-         ----------
-         file_name: str
-             Output file name
-         transcribed_segments: list
-             Text segments transcribed from audio
-         add_timestamp: bool
-             Determines whether to add a timestamp to the end of the filename.
-         file_format: str
-             File format to write. Supported formats: [SRT, WebVTT, txt]
-         output_dir: str
-             Directory path of the output
-
-         Returns
-         ----------
-         content: str
-             Result of the transcription
-         output_path: str
-             Output file path
-         """
-         if add_timestamp:
-             timestamp = datetime.now().strftime("%m%d%H%M%S")
-             output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
-         else:
-             output_path = os.path.join(output_dir, f"{file_name}")
-
-         file_format = file_format.strip().lower()
-         if file_format == "srt":
-             content = get_srt(transcribed_segments)
-             output_path += '.srt'
-
-         elif file_format == "webvtt":
-             content = get_vtt(transcribed_segments)
-             output_path += '.vtt'
-
-         elif file_format == "txt":
-             content = get_txt(transcribed_segments)
-             output_path += '.txt'
-
-         write_file(content, output_path)
-         return content, output_path
-
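# Illustrative sketch, not part of the deleted file: calling generate_and_write_file() directly.
# The single segment below is synthetic and assumes get_srt() accepts segments shaped like
# {"start", "end", "text"}; the timestamp suffix comes from strftime("%m%d%H%M%S").
subtitle_text, subtitle_path = WhisperBase.generate_and_write_file(
    file_name="interview",
    transcribed_segments=[{"start": 0.0, "end": 2.5, "text": "Hello there."}],
    add_timestamp=True,
    file_format="SRT",
    output_dir="outputs",
)
print(subtitle_path)   # e.g. "outputs/interview - 0131142530.srt"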
-     @staticmethod
-     def format_time(elapsed_time: float) -> str:
-         """
-         Get {hours} {minutes} {seconds} time format string
-
-         Parameters
-         ----------
-         elapsed_time: float
-             Elapsed time for transcription
-
-         Returns
-         ----------
-         Time format string
-         """
-         hours, rem = divmod(elapsed_time, 3600)
-         minutes, seconds = divmod(rem, 60)
-
-         time_str = ""
-         if hours:
-             time_str += f"{hours} hours "
-         if minutes:
-             time_str += f"{minutes} minutes "
-         seconds = round(seconds)
-         time_str += f"{seconds} seconds"
-
-         return time_str.strip()
-
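# Illustrative sketch, not part of the deleted file: what format_time() produces.
# Zero-valued units are skipped and seconds are rounded; note that float inputs would
# print e.g. "1.0 hours" since hours/minutes are not cast to int.
print(WhisperBase.format_time(42.7))   # -> "43 seconds"
print(WhisperBase.format_time(125))    # -> "2 minutes 5 seconds"
print(WhisperBase.format_time(3725))   # -> "1 hours 2 minutes 5 seconds"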
-     @staticmethod
-     def get_device():
-         if torch.cuda.is_available():
-             return "cuda"
-         elif torch.backends.mps.is_available():
-             if not WhisperBase.is_sparse_api_supported():
-                 # Device `SparseMPS` is not supported for now. See: https://github.com/pytorch/pytorch/issues/87886
-                 return "cpu"
-             return "mps"
-         else:
-             return "cpu"
-
-     @staticmethod
-     def is_sparse_api_supported():
-         if not torch.backends.mps.is_available():
-             return False
-
-         try:
-             device = torch.device("mps")
-             sparse_tensor = torch.sparse_coo_tensor(
-                 indices=torch.tensor([[0, 1], [2, 3]]),
-                 values=torch.tensor([1, 2]),
-                 size=(4, 4),
-                 device=device
-             )
-             return True
-         except RuntimeError:
-             return False
-
-     @staticmethod
-     def release_cuda_memory():
-         """Release memory"""
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-             torch.cuda.reset_max_memory_allocated()
-
-     @staticmethod
-     def remove_input_files(file_paths: List[str]):
-         """Remove gradio cached files"""
-         if not file_paths:
-             return
-
-         for file_path in file_paths:
-             if file_path and os.path.exists(file_path):
-                 os.remove(file_path)
-
-     @staticmethod
-     def cache_parameters(
-         whisper_params: WhisperValues,
-         add_timestamp: bool
-     ):
-         """cache parameters to the yaml file"""
-         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
-         cached_whisper_param = whisper_params.to_yaml()
-         cached_yaml = {**cached_params, **cached_whisper_param}
-         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
-
-         save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
-
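# Illustrative sketch, not part of the deleted file: the merge used in cache_parameters().
# `{**cached_params, **cached_whisper_param}` is a shallow merge where the current run's values
# win on key collisions, and add_timestamp is then written into the "whisper" sub-dict.
# The keys below are hypothetical placeholders, not the real config schema.
cached_params = {"whisper": {"model_size": "base", "add_timestamp": False}, "vad": {"threshold": 0.5}}
cached_whisper_param = {"whisper": {"model_size": "large-v2"}}
cached_yaml = {**cached_params, **cached_whisper_param}
cached_yaml["whisper"]["add_timestamp"] = True
print(cached_yaml)
# -> {'whisper': {'model_size': 'large-v2', 'add_timestamp': True}, 'vad': {'threshold': 0.5}}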
-     @staticmethod
-     def resample_audio(audio: Union[str, np.ndarray],
-                        new_sample_rate: int = 16000,
-                        original_sample_rate: Optional[int] = None,
-                        ) -> np.ndarray:
-         """Resamples audio to a 16k sample rate, the standard for Whisper models"""
-         if isinstance(audio, str):
-             audio, original_sample_rate = torchaudio.load(audio)
-         else:
-             if original_sample_rate is None:
-                 raise ValueError("original_sample_rate must be provided when audio is numpy array.")
-             audio = torch.from_numpy(audio)
-         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
-         resampled_audio = resampler(audio).numpy()
-         return resampled_audio
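# Illustrative sketch, not part of the deleted file: resampling a numpy waveform to the 16 kHz
# rate Whisper expects. The one-second 44.1 kHz sine wave below is synthetic test input.
import numpy as np

sr = 44100
t = np.linspace(0, 1.0, sr, endpoint=False, dtype=np.float32)
waveform = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)   # 440 Hz tone
resampled = WhisperBase.resample_audio(waveform, original_sample_rate=sr)
print(resampled.shape)   # -> (16000,)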