LAP-DEV committed on
Commit
59cd625
·
verified ·
1 Parent(s): 0058cc2

Delete modules/whisper/faster_whisper_inference.py

Browse files
modules/whisper/faster_whisper_inference.py DELETED
@@ -1,209 +0,0 @@
1
- import os
2
- import time
3
- import huggingface_hub
4
- import numpy as np
5
- import torch
6
- from typing import BinaryIO, Union, Tuple, List
7
- import faster_whisper
8
- from faster_whisper.vad import VadOptions
9
- import ast
10
- import ctranslate2
11
- import whisper
12
- import gradio as gr
13
- from argparse import Namespace
14
-
15
- from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
16
- from modules.whisper.whisper_parameter import *
17
- from modules.whisper.whisper_base import WhisperBase
18
-
19
class FasterWhisperInference(WhisperBase):
    """Whisper inference backend that runs transcription through faster-whisper."""

    def __init__(self,
                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        """
        Set up model directories and discover the models that can be selected.

        Parameters
        ----------
        model_dir: str
            Directory where faster-whisper models are stored. Created if missing.
        diarization_model_dir: str
            Directory forwarded to the base class for diarization models.
        uvr_model_dir: str
            Directory forwarded to the base class for UVR models.
        output_dir: str
            Directory forwarded to the base class for outputs.
        """
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of Segment that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription

        Raises
        ----------
        ValueError
            If the suppress-tokens parameter is not a string literal of List[int].
        """
        start_time = time.time()

        params = WhisperParameters.as_value(*whisper_params)
        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)

        # update_model() stores the resolved entry of model_paths in current_model_size, so the
        # requested name is resolved the same way before comparing. Comparing the raw name against
        # a stored path would reload locally stored custom models on every call.
        requested_model = self.model_paths.get(params.model_size.replace("/", "--"), params.model_size)
        needs_reload = (
            self.model is None
            or requested_model != self.current_model_size
            or self.current_compute_type != params.compute_type
        )
        if needs_reload:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Loading audio...")

        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )

        # `segments` is a lazy generator: it must be consumed exactly once. Counting it first
        # (e.g. len(list(segments))) would exhaust it and leave nothing to collect, so progress is
        # derived from the position in the audio instead of from a segment count.
        total_duration = info.duration or 0
        segments_result = []
        for segment in segments:
            if total_duration > 0:
                fraction = min(max(segment.end / total_duration, 0.0), 1.0)
            else:
                fraction = 0
            progress(fraction, desc="Transcribing...")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model. If you enter the huggingface repo id, it will try to download the model
            automatically from huggingface.
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model...")

        # A huggingface repo id ("owner/name") is stored locally as "owner--name".
        model_size_dirname = model_size.replace("/", "--")
        if model_size not in self.model_paths and model_size_dirname not in self.model_paths:
            print(f"Model is not detected. Trying to download \"{model_size}\" from huggingface to "
                  f"\"{os.path.join(self.model_dir, model_size_dirname)} ...")
            huggingface_hub.snapshot_download(
                model_size,
                local_dir=os.path.join(self.model_dir, model_size_dirname),
            )
            self.model_paths = self.get_model_paths()
            # Keep the advertised model names in sync with the refreshed mapping.
            self.available_models = self.model_paths.keys()
            gr.Info(f"Model is downloaded with the name \"{model_size_dirname}\"")

        self.current_model_size = self.model_paths[model_size_dirname]

        # Skip any network lookup when the model is already on disk, either as a plain directory
        # or inside the huggingface cache layout used for the official models.
        hf_prefix = "models--Systran--faster-whisper-"
        official_model_path = os.path.join(self.model_dir, hf_prefix + model_size)
        is_local_dir = os.path.isdir(self.current_model_size)
        is_cached_official = (model_size in faster_whisper.available_models()
                              and os.path.exists(official_model_path))
        local_files_only = is_local_dir or is_cached_official

        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type,
            local_files_only=local_files_only
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Returns
        ----------
        Dict mapping each model name to the value passed to faster-whisper: the name itself for
        the built-in models, or a path inside ``self.model_dir`` for models found on disk.
        """
        model_paths = {model: model for model in faster_whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        # Entries of the model directory that are not models.
        wrong_dirs = {".locks", "faster_whisper_models_will_be_saved_here"}
        existing_models = set(os.listdir(self.model_dir)) - wrong_dirs

        # Hoisted: the list of openai-whisper model names does not change inside the loop.
        whisper_models = whisper.available_models()
        for model_name in existing_models:
            # Only strip the cache prefix when it really is a prefix of the directory name.
            if model_name.startswith(faster_whisper_prefix):
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper_models:
                model_paths[model_name] = os.path.join(self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
        """Return "cuda" when torch reports CUDA as available, otherwise "auto"."""
        return "cuda" if torch.cuda.is_available() else "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
        """
        Parse the textual suppress-tokens setting into a list of token ids.

        Parameters
        ----------
        suppress_tokens_str: str
            Python literal of a list of ints, e.g. "[-1]".

        Returns
        ----------
        List[int]
            The parsed token ids.

        Raises
        ----------
        ValueError
            If the text cannot be parsed or is not a list made only of ints.
        """
        error_message = "Invalid Suppress Tokens. The value must be type of List[int]"
        try:
            suppress_tokens = ast.literal_eval(suppress_tokens_str)
        except Exception as e:
            # Any parsing failure is reported uniformly, keeping the cause attached.
            raise ValueError(error_message) from e

        if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
            raise ValueError(error_message)
        return suppress_tokens