LAP-DEV commited on
Commit
1cd78ce
·
verified ·
1 Parent(s): df70fdf

Delete modules/whisper/whisper_base.py

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +0 -567
modules/whisper/whisper_base.py DELETED
@@ -1,567 +0,0 @@
1
- import os
2
- import torch
3
- import whisper
4
- import gradio as gr
5
- import torchaudio
6
- from abc import ABC, abstractmethod
7
- from typing import BinaryIO, Union, Tuple, List
8
- import numpy as np
9
- from datetime import datetime
10
- from faster_whisper.vad import VadOptions
11
- from dataclasses import astuple
12
-
13
- from modules.uvr.music_separator import MusicSeparator
14
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
15
- UVR_MODELS_DIR)
16
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
17
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
18
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
19
- from modules.whisper.whisper_parameter import *
20
- from modules.diarize.diarizer import Diarizer
21
- from modules.vad.silero_vad import SileroVAD
22
-
23
-
24
- class WhisperBase(ABC):
25
- def __init__(self,
26
- model_dir: str = WHISPER_MODELS_DIR,
27
- diarization_model_dir: str = DIARIZATION_MODELS_DIR,
28
- uvr_model_dir: str = UVR_MODELS_DIR,
29
- output_dir: str = OUTPUT_DIR,
30
- ):
31
- self.model_dir = model_dir
32
- self.output_dir = output_dir
33
- os.makedirs(self.output_dir, exist_ok=True)
34
- os.makedirs(self.model_dir, exist_ok=True)
35
- self.diarizer = Diarizer(
36
- model_dir=diarization_model_dir
37
- )
38
- self.vad = SileroVAD()
39
- self.music_separator = MusicSeparator(
40
- model_dir=uvr_model_dir,
41
- output_dir=os.path.join(output_dir, "UVR")
42
- )
43
-
44
- self.model = None
45
- self.current_model_size = None
46
- self.available_models = whisper.available_models()
47
- self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
48
- #self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
49
- self.translatable_models = whisper.available_models()
50
- self.device = self.get_device()
51
- self.available_compute_types = ["float16", "float32"]
52
- self.current_compute_type = "float16" if self.device == "cuda" else "float32"
53
-
54
- @abstractmethod
55
- def transcribe(self,
56
- audio: Union[str, BinaryIO, np.ndarray],
57
- progress: gr.Progress = gr.Progress(),
58
- *whisper_params,
59
- ):
60
- """Inference whisper model to transcribe"""
61
- pass
62
-
63
- @abstractmethod
64
- def update_model(self,
65
- model_size: str,
66
- compute_type: str,
67
- progress: gr.Progress = gr.Progress()
68
- ):
69
- """Initialize whisper model"""
70
- pass
71
-
72
- def run(self,
73
- audio: Union[str, BinaryIO, np.ndarray],
74
- progress: gr.Progress = gr.Progress(),
75
- add_timestamp: bool = True,
76
- *whisper_params,
77
- ) -> Tuple[List[dict], float]:
78
- """
79
- Run transcription with conditional pre-processing and post-processing.
80
- The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
81
- The diarization will be performed in post-processing, if enabled.
82
-
83
- Parameters
84
- ----------
85
- audio: Union[str, BinaryIO, np.ndarray]
86
- Audio input. This can be file path or binary type.
87
- progress: gr.Progress
88
- Indicator to show progress directly in gradio.
89
- add_timestamp: bool
90
- Whether to add a timestamp at the end of the filename.
91
- *whisper_params: tuple
92
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
93
-
94
- Returns
95
- ----------
96
- segments_result: List[dict]
97
- list of dicts that includes start, end timestamps and transcribed text
98
- elapsed_time: float
99
- elapsed time for running
100
- """
101
- params = WhisperParameters.as_value(*whisper_params)
102
-
103
- self.cache_parameters(
104
- whisper_params=params,
105
- add_timestamp=add_timestamp
106
- )
107
-
108
- if params.lang is None:
109
- pass
110
- elif params.lang == "Automatic Detection":
111
- params.lang = None
112
- else:
113
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
114
- params.lang = language_code_dict[params.lang]
115
-
116
- if params.is_bgm_separate:
117
- music, audio, _ = self.music_separator.separate(
118
- audio=audio,
119
- model_name=params.uvr_model_size,
120
- device=params.uvr_device,
121
- segment_size=params.uvr_segment_size,
122
- save_file=params.uvr_save_file,
123
- progress=progress
124
- )
125
-
126
- if audio.ndim >= 2:
127
- audio = audio.mean(axis=1)
128
- if self.music_separator.audio_info is None:
129
- origin_sample_rate = 16000
130
- else:
131
- origin_sample_rate = self.music_separator.audio_info.sample_rate
132
- audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
133
-
134
- if params.uvr_enable_offload:
135
- self.music_separator.offload()
136
-
137
- if params.vad_filter:
138
- # Explicit value set for float('inf') from gr.Number()
139
- if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
140
- params.max_speech_duration_s = float('inf')
141
-
142
- vad_options = VadOptions(
143
- threshold=params.threshold,
144
- min_speech_duration_ms=params.min_speech_duration_ms,
145
- max_speech_duration_s=params.max_speech_duration_s,
146
- min_silence_duration_ms=params.min_silence_duration_ms,
147
- speech_pad_ms=params.speech_pad_ms
148
- )
149
-
150
- audio, speech_chunks = self.vad.run(
151
- audio=audio,
152
- vad_parameters=vad_options,
153
- progress=progress
154
- )
155
-
156
- result, elapsed_time = self.transcribe(
157
- audio,
158
- progress,
159
- *astuple(params)
160
- )
161
-
162
- if params.vad_filter:
163
- result = self.vad.restore_speech_timestamps(
164
- segments=result,
165
- speech_chunks=speech_chunks,
166
- )
167
-
168
- if params.is_diarize:
169
- result, elapsed_time_diarization = self.diarizer.run(
170
- audio=audio,
171
- use_auth_token=params.hf_token,
172
- transcribed_result=result,
173
- )
174
- elapsed_time += elapsed_time_diarization
175
- return result, elapsed_time
176
-
177
- def transcribe_file(self,
178
- files: Optional[List] = None,
179
- input_folder_path: Optional[str] = None,
180
- file_format: str = "SRT",
181
- add_timestamp: bool = True,
182
- progress=gr.Progress(),
183
- *whisper_params,
184
- ) -> list:
185
- """
186
- Write subtitle file from Files
187
-
188
- Parameters
189
- ----------
190
- files: list
191
- List of files to transcribe from gr.Files()
192
- input_folder_path: str
193
- Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
194
- this will be used instead.
195
- file_format: str
196
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
197
- add_timestamp: bool
198
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
199
- progress: gr.Progress
200
- Indicator to show progress directly in gradio.
201
- *whisper_params: tuple
202
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
203
-
204
- Returns
205
- ----------
206
- result_str:
207
- Result of transcription to return to gr.Textbox()
208
- result_file_path:
209
- Output file path to return to gr.Files()
210
- """
211
- try:
212
- if input_folder_path:
213
- files = get_media_files(input_folder_path)
214
- if isinstance(files, str):
215
- files = [files]
216
- if files and isinstance(files[0], gr.utils.NamedString):
217
- files = [file.name for file in files]
218
-
219
- ## Load model to detect language
220
- model = whisper.load_model("base")
221
-
222
- files_info = {}
223
- files_to_download = {}
224
- for file in files:
225
-
226
- ## Detect language
227
- #params = WhisperParameters.as_value(*whisper_params)
228
- #model = whisper.load_model(params.model_size)
229
- mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
230
- _, probs = model.detect_language(mel)
231
- file_language = ""
232
- for key,value in whisper.tokenizer.LANGUAGES.items():
233
- if key == str(max(probs, key=probs.get)):
234
- file_language = value.capitalize()
235
- break
236
-
237
- transcribed_segments, time_for_task = self.run(
238
- file,
239
- progress,
240
- add_timestamp,
241
- *whisper_params,
242
- )
243
-
244
- file_name, file_ext = os.path.splitext(os.path.basename(file))
245
- subtitle, file_path = self.generate_and_write_file(
246
- file_name=file_name,
247
- transcribed_segments=transcribed_segments,
248
- add_timestamp=add_timestamp,
249
- file_format=file_format,
250
- output_dir=self.output_dir
251
- )
252
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path, "lang": file_language, "input": (file_name+file_ext)}
253
-
254
- ## Add output file as txt
255
- file_name, file_ext = os.path.splitext(os.path.basename(file))
256
- subtitle, file_path = self.generate_and_write_file(
257
- file_name=file_name,
258
- transcribed_segments=transcribed_segments,
259
- add_timestamp=add_timestamp,
260
- file_format="txt",
261
- output_dir=self.output_dir
262
- )
263
- files_to_download[file_name+"_txt"] = {"path": file_path}
264
-
265
- ## Add output file as srt
266
- file_name, file_ext = os.path.splitext(os.path.basename(file))
267
- subtitle, file_path = self.generate_and_write_file(
268
- file_name=file_name,
269
- transcribed_segments=transcribed_segments,
270
- add_timestamp=add_timestamp,
271
- file_format="srt",
272
- output_dir=self.output_dir
273
- )
274
- files_to_download[file_name+"_srt"] = {"path": file_path}
275
-
276
- total_result = ''
277
- total_info = ''
278
- total_time = 0
279
- for file_name, info in files_info.items():
280
- total_result += f'{info["subtitle"]}'
281
- total_time += info["time_for_task"]
282
- #total_info += f'{info["lang"]}'
283
- total_info += f"Input file: {info['input']}\nLanguage prediction: {info['lang']}\n"
284
-
285
- #result_str = f"Processing of file '{file_name}{file_ext}' done in {self.format_time(total_time)}:\n\n{total_result}"
286
- total_info += f"\nTranscription duration: {self.format_time(total_time)}"
287
- result_str = total_result
288
- result_file_path = [info['path'] for info in files_to_download.values()]
289
-
290
- return [result_str, result_file_path, total_info]
291
-
292
- except Exception as e:
293
- print(f"Error transcribing file: {e}")
294
- finally:
295
- self.release_cuda_memory()
296
-
297
- def transcribe_mic(self,
298
- mic_audio: str,
299
- file_format: str = "SRT",
300
- add_timestamp: bool = True,
301
- progress=gr.Progress(),
302
- *whisper_params,
303
- ) -> list:
304
- """
305
- Write subtitle file from microphone
306
-
307
- Parameters
308
- ----------
309
- mic_audio: str
310
- Audio file path from gr.Microphone()
311
- file_format: str
312
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
313
- add_timestamp: bool
314
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
315
- progress: gr.Progress
316
- Indicator to show progress directly in gradio.
317
- *whisper_params: tuple
318
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
319
-
320
- Returns
321
- ----------
322
- result_str:
323
- Result of transcription to return to gr.Textbox()
324
- result_file_path:
325
- Output file path to return to gr.Files()
326
- """
327
- try:
328
- progress(0, desc="Loading Audio..")
329
- transcribed_segments, time_for_task = self.run(
330
- mic_audio,
331
- progress,
332
- add_timestamp,
333
- *whisper_params,
334
- )
335
- progress(1, desc="Completed!")
336
-
337
- subtitle, result_file_path = self.generate_and_write_file(
338
- file_name="Mic",
339
- transcribed_segments=transcribed_segments,
340
- add_timestamp=add_timestamp,
341
- file_format=file_format,
342
- output_dir=self.output_dir
343
- )
344
-
345
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
346
- return [result_str, result_file_path]
347
- except Exception as e:
348
- print(f"Error transcribing file: {e}")
349
- finally:
350
- self.release_cuda_memory()
351
-
352
- def transcribe_youtube(self,
353
- youtube_link: str,
354
- file_format: str = "SRT",
355
- add_timestamp: bool = True,
356
- progress=gr.Progress(),
357
- *whisper_params,
358
- ) -> list:
359
- """
360
- Write subtitle file from Youtube
361
-
362
- Parameters
363
- ----------
364
- youtube_link: str
365
- URL of the Youtube video to transcribe from gr.Textbox()
366
- file_format: str
367
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
368
- add_timestamp: bool
369
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
370
- progress: gr.Progress
371
- Indicator to show progress directly in gradio.
372
- *whisper_params: tuple
373
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
374
-
375
- Returns
376
- ----------
377
- result_str:
378
- Result of transcription to return to gr.Textbox()
379
- result_file_path:
380
- Output file path to return to gr.Files()
381
- """
382
- try:
383
- progress(0, desc="Loading Audio from Youtube..")
384
- yt = get_ytdata(youtube_link)
385
- audio = get_ytaudio(yt)
386
-
387
- transcribed_segments, time_for_task = self.run(
388
- audio,
389
- progress,
390
- add_timestamp,
391
- *whisper_params,
392
- )
393
-
394
- progress(1, desc="Completed!")
395
-
396
- file_name = safe_filename(yt.title)
397
- subtitle, result_file_path = self.generate_and_write_file(
398
- file_name=file_name,
399
- transcribed_segments=transcribed_segments,
400
- add_timestamp=add_timestamp,
401
- file_format=file_format,
402
- output_dir=self.output_dir
403
- )
404
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
405
-
406
- if os.path.exists(audio):
407
- os.remove(audio)
408
-
409
- return [result_str, result_file_path]
410
-
411
- except Exception as e:
412
- print(f"Error transcribing file: {e}")
413
- finally:
414
- self.release_cuda_memory()
415
-
416
- @staticmethod
417
- def generate_and_write_file(file_name: str,
418
- transcribed_segments: list,
419
- add_timestamp: bool,
420
- file_format: str,
421
- output_dir: str
422
- ) -> str:
423
- """
424
- Writes subtitle file
425
-
426
- Parameters
427
- ----------
428
- file_name: str
429
- Output file name
430
- transcribed_segments: list
431
- Text segments transcribed from audio
432
- add_timestamp: bool
433
- Determines whether to add a timestamp to the end of the filename.
434
- file_format: str
435
- File format to write. Supported formats: [SRT, WebVTT, txt]
436
- output_dir: str
437
- Directory path of the output
438
-
439
- Returns
440
- ----------
441
- content: str
442
- Result of the transcription
443
- output_path: str
444
- output file path
445
- """
446
- if add_timestamp:
447
- timestamp = datetime.now().strftime("%m%d%H%M%S")
448
- output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
449
- else:
450
- output_path = os.path.join(output_dir, f"{file_name}")
451
-
452
- file_format = file_format.strip().lower()
453
- if file_format == "srt":
454
- content = get_srt(transcribed_segments)
455
- output_path += '.srt'
456
-
457
- elif file_format == "webvtt":
458
- content = get_vtt(transcribed_segments)
459
- output_path += '.vtt'
460
-
461
- elif file_format == "txt":
462
- content = get_txt(transcribed_segments)
463
- output_path += '.txt'
464
-
465
- write_file(content, output_path)
466
- return content, output_path
467
-
468
- @staticmethod
469
- def format_time(elapsed_time: float) -> str:
470
- """
471
- Get {hours} {minutes} {seconds} time format string
472
-
473
- Parameters
474
- ----------
475
- elapsed_time: str
476
- Elapsed time for transcription
477
-
478
- Returns
479
- ----------
480
- Time format string
481
- """
482
- hours, rem = divmod(elapsed_time, 3600)
483
- minutes, seconds = divmod(rem, 60)
484
-
485
- time_str = ""
486
- if hours:
487
- time_str += f"{hours} hours "
488
- if minutes:
489
- time_str += f"{minutes} minutes "
490
- seconds = round(seconds)
491
- time_str += f"{seconds} seconds"
492
-
493
- return time_str.strip()
494
-
495
- @staticmethod
496
- def get_device():
497
- if torch.cuda.is_available():
498
- return "cuda"
499
- elif torch.backends.mps.is_available():
500
- if not WhisperBase.is_sparse_api_supported():
501
- # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
502
- return "cpu"
503
- return "mps"
504
- else:
505
- return "cpu"
506
-
507
- @staticmethod
508
- def is_sparse_api_supported():
509
- if not torch.backends.mps.is_available():
510
- return False
511
-
512
- try:
513
- device = torch.device("mps")
514
- sparse_tensor = torch.sparse_coo_tensor(
515
- indices=torch.tensor([[0, 1], [2, 3]]),
516
- values=torch.tensor([1, 2]),
517
- size=(4, 4),
518
- device=device
519
- )
520
- return True
521
- except RuntimeError:
522
- return False
523
-
524
- @staticmethod
525
- def release_cuda_memory():
526
- """Release memory"""
527
- if torch.cuda.is_available():
528
- torch.cuda.empty_cache()
529
- torch.cuda.reset_max_memory_allocated()
530
-
531
- @staticmethod
532
- def remove_input_files(file_paths: List[str]):
533
- """Remove gradio cached files"""
534
- if not file_paths:
535
- return
536
-
537
- for file_path in file_paths:
538
- if file_path and os.path.exists(file_path):
539
- os.remove(file_path)
540
-
541
- @staticmethod
542
- def cache_parameters(
543
- whisper_params: WhisperValues,
544
- add_timestamp: bool
545
- ):
546
- """cache parameters to the yaml file"""
547
- cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
548
- cached_whisper_param = whisper_params.to_yaml()
549
- cached_yaml = {**cached_params, **cached_whisper_param}
550
- cached_yaml["whisper"]["add_timestamp"] = add_timestamp
551
-
552
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
553
-
554
- @staticmethod
555
- def resample_audio(audio: Union[str, np.ndarray],
556
- new_sample_rate: int = 16000,
557
- original_sample_rate: Optional[int] = None,) -> np.ndarray:
558
- """Resamples audio to 16k sample rate, standard on Whisper model"""
559
- if isinstance(audio, str):
560
- audio, original_sample_rate = torchaudio.load(audio)
561
- else:
562
- if original_sample_rate is None:
563
- raise ValueError("original_sample_rate must be provided when audio is numpy array.")
564
- audio = torch.from_numpy(audio)
565
- resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
566
- resampled_audio = resampler(audio).numpy()
567
- return resampled_audio