jhj0517 committed
Commit df6fc6f · 1 Parent(s): a88b526

Refactor dataclasses

Files changed (1)
  1. modules/whisper/data_classes.py +462 -339
modules/whisper/data_classes.py CHANGED
@@ -1,371 +1,494 @@
- from dataclasses import dataclass, fields
  import gradio as gr
- from typing import Optional, Dict
  import yaml

  from modules.utils.constants import AUTOMATIC_DETECTION


- @dataclass
- class WhisperParameters:
-     model_size: gr.Dropdown
-     lang: gr.Dropdown
-     is_translate: gr.Checkbox
-     beam_size: gr.Number
-     log_prob_threshold: gr.Number
-     no_speech_threshold: gr.Number
-     compute_type: gr.Dropdown
-     best_of: gr.Number
-     patience: gr.Number
-     condition_on_previous_text: gr.Checkbox
-     prompt_reset_on_temperature: gr.Slider
-     initial_prompt: gr.Textbox
-     temperature: gr.Slider
-     compression_ratio_threshold: gr.Number
-     vad_filter: gr.Checkbox
-     threshold: gr.Slider
-     min_speech_duration_ms: gr.Number
-     max_speech_duration_s: gr.Number
-     min_silence_duration_ms: gr.Number
-     speech_pad_ms: gr.Number
-     batch_size: gr.Number
-     is_diarize: gr.Checkbox
-     hf_token: gr.Textbox
-     diarization_device: gr.Dropdown
-     length_penalty: gr.Number
-     repetition_penalty: gr.Number
-     no_repeat_ngram_size: gr.Number
-     prefix: gr.Textbox
-     suppress_blank: gr.Checkbox
-     suppress_tokens: gr.Textbox
-     max_initial_timestamp: gr.Number
-     word_timestamps: gr.Checkbox
-     prepend_punctuations: gr.Textbox
-     append_punctuations: gr.Textbox
-     max_new_tokens: gr.Number
-     chunk_length: gr.Number
-     hallucination_silence_threshold: gr.Number
-     hotwords: gr.Textbox
-     language_detection_threshold: gr.Number
-     language_detection_segments: gr.Number
-     is_bgm_separate: gr.Checkbox
-     uvr_model_size: gr.Dropdown
-     uvr_device: gr.Dropdown
-     uvr_segment_size: gr.Number
-     uvr_save_file: gr.Checkbox
-     uvr_enable_offload: gr.Checkbox
- """
58
- A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
59
- This data class is used to mitigate the key-value problem between Gradio components and function parameters.
60
- Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
61
- See more about Gradio pre-processing: https://www.gradio.app/docs/components
62
 
63
- Attributes
64
- ----------
65
- model_size: gr.Dropdown
66
- Whisper model size.
67
-
68
- lang: gr.Dropdown
69
- Source language of the file to transcribe.
70
-
71
- is_translate: gr.Checkbox
72
- Boolean value that determines whether to translate to English.
73
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
74
-
75
- beam_size: gr.Number
76
- Int value that is used for decoding option.
77
-
78
- log_prob_threshold: gr.Number
79
- If the average log probability over sampled tokens is below this value, treat as failed.
80
-
81
- no_speech_threshold: gr.Number
82
- If the no_speech probability is higher than this value AND
83
- the average log probability over sampled tokens is below `log_prob_threshold`,
84
- consider the segment as silent.
85
-
86
- compute_type: gr.Dropdown
87
- compute type for transcription.
88
- see more info : https://opennmt.net/CTranslate2/quantization.html
89
-
90
- best_of: gr.Number
91
- Number of candidates when sampling with non-zero temperature.
92
-
93
- patience: gr.Number
94
- Beam search patience factor.
95
-
96
- condition_on_previous_text: gr.Checkbox
97
- if True, the previous output of the model is provided as a prompt for the next window;
98
- disabling may make the text inconsistent across windows, but the model becomes less prone to
99
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
100
-
101
- initial_prompt: gr.Textbox
102
- Optional text to provide as a prompt for the first window. This can be used to provide, or
103
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
104
- to make it more likely to predict those word correctly.
105
-
106
- temperature: gr.Slider
107
- Temperature for sampling. It can be a tuple of temperatures,
108
- which will be successively used upon failures according to either
109
- `compression_ratio_threshold` or `log_prob_threshold`.
110
-
111
- compression_ratio_threshold: gr.Number
112
- If the gzip compression ratio is above this value, treat as failed
113
-
114
- vad_filter: gr.Checkbox
115
- Enable the voice activity detection (VAD) to filter out parts of the audio
116
- without speech. This step is using the Silero VAD model
117
- https://github.com/snakers4/silero-vad.
118
-
119
- threshold: gr.Slider
120
- This parameter is related with Silero VAD. Speech threshold.
121
- Silero VAD outputs speech probabilities for each audio chunk,
122
- probabilities ABOVE this value are considered as SPEECH. It is better to tune this
123
- parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
124
-
125
- min_speech_duration_ms: gr.Number
126
- This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
127
-
128
- max_speech_duration_s: gr.Number
129
- This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
130
- than max_speech_duration_s will be split at the timestamp of the last silence that
131
- lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
132
- split aggressively just before max_speech_duration_s.
133
-
134
- min_silence_duration_ms: gr.Number
135
- This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
136
- before separating it
137
-
138
- speech_pad_ms: gr.Number
139
- This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
140
-
141
- batch_size: gr.Number
142
- This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
143
-
144
- is_diarize: gr.Checkbox
145
- This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
146
-
147
- hf_token: gr.Textbox
148
- This parameter is related with whisperx. Huggingface token is needed to download diarization models.
149
- Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
150
-
151
- diarization_device: gr.Dropdown
152
- This parameter is related with whisperx. Device to run diarization model
153
-
154
- length_penalty: gr.Number
155
- This parameter is related to faster-whisper. Exponential length penalty constant.
156
-
157
- repetition_penalty: gr.Number
158
- This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
159
- (set > 1 to penalize).
160
 
161
- no_repeat_ngram_size: gr.Number
162
- This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- prefix: gr.Textbox
165
- This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
166
 
167
- suppress_blank: gr.Checkbox
168
- This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- suppress_tokens: gr.Textbox
171
- This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
172
- of symbols as defined in the model config.json file.
173
 
174
- max_initial_timestamp: gr.Number
175
- This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
176
 
177
- word_timestamps: gr.Checkbox
178
- This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
179
- and dynamic time warping, and include the timestamps for each word in each segment.
 
 
 
 
180
 
181
- prepend_punctuations: gr.Textbox
182
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
183
- with the next word.
184
 
185
- append_punctuations: gr.Textbox
186
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
187
- with the previous word.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- max_new_tokens: gr.Number
190
- This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
191
- the maximum will be set by the default max_length.
192
 
193
- chunk_length: gr.Number
194
- This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
195
- If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- hallucination_silence_threshold: gr.Number
198
- This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
199
- (in seconds) when a possible hallucination is detected.
200
 
201
- hotwords: gr.Textbox
202
- This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- language_detection_threshold: gr.Number
205
- This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
206
 
207
- language_detection_segments: gr.Number
208
- This parameter is related to faster-whisper. Number of segments to consider for the language detection.
209
-
210
- is_separate_bgm: gr.Checkbox
211
- This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
212
-
213
- uvr_model_size: gr.Dropdown
214
- This parameter is related to UVR. UVR model size.
215
-
216
- uvr_device: gr.Dropdown
217
- This parameter is related to UVR. Device to run UVR model.
218
-
219
- uvr_segment_size: gr.Number
220
- This parameter is related to UVR. Segment size for UVR model.
221
-
222
- uvr_save_file: gr.Checkbox
223
- This parameter is related to UVR. Boolean value that determines whether to save the file or not.
224
-
225
- uvr_enable_offload: gr.Checkbox
226
- This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
227
- after each transcription.
228
- """

-     def as_list(self) -> list:
-         """
-         Converts the data class attributes into a list. Used in the Gradio UI before Gradio pre-processing.
-         See more about Gradio pre-processing: https://www.gradio.app/docs/components
-
-         Returns
-         ----------
-         A list of Gradio components
-         """
-         return [getattr(self, f.name) for f in fields(self)]

-     @staticmethod
-     def as_value(*args) -> 'WhisperValues':
-         """
-         To use Whisper parameters in a function after Gradio post-processing.
-         See more about Gradio post-processing: https://www.gradio.app/docs/components
-
-         Returns
-         ----------
-         WhisperValues
-             A data class that has the values of the parameters
-         """
-         return WhisperValues(*args)
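For context, the removed pattern worked like this: the dataclass held one Gradio component per parameter, as_list() registered the components as event-listener inputs in declaration order, and as_value(*args) rebuilt a typed object from the positional values Gradio handed to the handler. A minimal sketch, assuming the components were already built and collected into whisper_params; the handler, button, and run_pipeline names are illustrative, not part of this commit:

    def transcribe(audio_path, *args):
        # Gradio passes raw component values positionally, in dataclass field order.
        values = WhisperParameters.as_value(*args)   # -> WhisperValues
        return run_pipeline(audio_path, values)      # run_pipeline is hypothetical

    btn.click(fn=transcribe,
              inputs=[audio_input] + whisper_params.as_list(),  # components, not values
              outputs=result_box)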
- @dataclass
- class WhisperValues:
-     model_size: str = "large-v2"
-     lang: Optional[str] = None
-     is_translate: bool = False
-     beam_size: int = 5
-     log_prob_threshold: float = -1.0
-     no_speech_threshold: float = 0.6
-     compute_type: str = "float16"
-     best_of: int = 5
-     patience: float = 1.0
-     condition_on_previous_text: bool = True
-     prompt_reset_on_temperature: float = 0.5
-     initial_prompt: Optional[str] = None
-     temperature: float = 0.0
-     compression_ratio_threshold: float = 2.4
-     vad_filter: bool = False
-     threshold: float = 0.5
-     min_speech_duration_ms: int = 250
-     max_speech_duration_s: float = float("inf")
-     min_silence_duration_ms: int = 2000
-     speech_pad_ms: int = 400
-     batch_size: int = 24
-     is_diarize: bool = False
-     hf_token: str = ""
-     diarization_device: str = "cuda"
-     length_penalty: float = 1.0
-     repetition_penalty: float = 1.0
-     no_repeat_ngram_size: int = 0
-     prefix: Optional[str] = None
-     suppress_blank: bool = True
-     suppress_tokens: Optional[str] = "[-1]"
-     max_initial_timestamp: float = 0.0
-     word_timestamps: bool = False
-     prepend_punctuations: Optional[str] = "\"'“¿([{-"
-     append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
-     max_new_tokens: Optional[int] = None
-     chunk_length: Optional[int] = 30
-     hallucination_silence_threshold: Optional[float] = None
-     hotwords: Optional[str] = None
-     language_detection_threshold: Optional[float] = None
-     language_detection_segments: int = 1
-     is_bgm_separate: bool = False
-     uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
-     uvr_device: str = "cuda"
-     uvr_segment_size: int = 256
-     uvr_save_file: bool = False
-     uvr_enable_offload: bool = True
-     """
-     A data class to use Whisper parameters.
-     """

-     def to_yaml(self) -> Dict:
          data = {
-             "whisper": {
-                 "model_size": self.model_size,
-                 "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
-                 "is_translate": self.is_translate,
-                 "beam_size": self.beam_size,
-                 "log_prob_threshold": self.log_prob_threshold,
-                 "no_speech_threshold": self.no_speech_threshold,
-                 "best_of": self.best_of,
-                 "patience": self.patience,
-                 "condition_on_previous_text": self.condition_on_previous_text,
-                 "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
-                 "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
-                 "temperature": self.temperature,
-                 "compression_ratio_threshold": self.compression_ratio_threshold,
-                 "batch_size": self.batch_size,
-                 "length_penalty": self.length_penalty,
-                 "repetition_penalty": self.repetition_penalty,
-                 "no_repeat_ngram_size": self.no_repeat_ngram_size,
-                 "prefix": None if not self.prefix else self.prefix,
-                 "suppress_blank": self.suppress_blank,
-                 "suppress_tokens": self.suppress_tokens,
-                 "max_initial_timestamp": self.max_initial_timestamp,
-                 "word_timestamps": self.word_timestamps,
-                 "prepend_punctuations": self.prepend_punctuations,
-                 "append_punctuations": self.append_punctuations,
-                 "max_new_tokens": self.max_new_tokens,
-                 "chunk_length": self.chunk_length,
-                 "hallucination_silence_threshold": self.hallucination_silence_threshold,
-                 "hotwords": None if not self.hotwords else self.hotwords,
-                 "language_detection_threshold": self.language_detection_threshold,
-                 "language_detection_segments": self.language_detection_segments,
-             },
-             "vad": {
-                 "vad_filter": self.vad_filter,
-                 "threshold": self.threshold,
-                 "min_speech_duration_ms": self.min_speech_duration_ms,
-                 "max_speech_duration_s": self.max_speech_duration_s,
-                 "min_silence_duration_ms": self.min_silence_duration_ms,
-                 "speech_pad_ms": self.speech_pad_ms,
-             },
-             "diarization": {
-                 "is_diarize": self.is_diarize,
-                 "hf_token": self.hf_token
-             },
-             "bgm_separation": {
-                 "is_separate_bgm": self.is_bgm_separate,
-                 "model_size": self.uvr_model_size,
-                 "segment_size": self.uvr_segment_size,
-                 "save_file": self.uvr_save_file,
-                 "enable_offload": self.uvr_enable_offload
-             },
          }
          return data

-     def as_list(self) -> list:
-         """
-         Converts the data class attributes into a list
-
-         Returns
-         ----------
-         A list of Whisper parameters
-         """
-         return [getattr(self, f.name) for f in fields(self)]

  import gradio as gr
+ import torch
+ from typing import Optional, Dict, List
+ from pydantic import BaseModel, Field, field_validator
+ from gradio_i18n import Translate, gettext as _
+ from enum import Enum
  import yaml

  from modules.utils.constants import AUTOMATIC_DETECTION


+ class WhisperImpl(Enum):
+     WHISPER = "whisper"
+     FASTER_WHISPER = "faster-whisper"
+     INSANELY_FAST_WHISPER = "insanely_fast_whisper"


+ class VadParams(BaseModel):
+     """Voice Activity Detection parameters"""
+     vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
+     threshold: float = Field(
+         default=0.5,
+         ge=0.0,
+         le=1.0,
+         description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
+     )
+     min_speech_duration_ms: int = Field(
+         default=250,
+         ge=0,
+         description="Final speech chunks shorter than this are discarded"
+     )
+     max_speech_duration_s: float = Field(
+         default=float("inf"),
+         gt=0,
+         description="Maximum duration of speech chunks in seconds"
+     )
+     min_silence_duration_ms: int = Field(
+         default=2000,
+         ge=0,
+         description="Minimum silence duration between speech chunks"
+     )
+     speech_pad_ms: int = Field(
+         default=400,
+         ge=0,
+         description="Padding added to each side of speech chunks"
+     )

+     def to_dict(self) -> Dict:
+         return self.model_dump()

+     @classmethod
+     def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
+         defaults = defaults or {}
+         return [
+             gr.Checkbox(label=_("Enable Silero VAD Filter"), value=defaults.get("vad_filter", cls.vad_filter),
+                         interactive=True,
+                         info=_("Enable this to transcribe only detected voice")),
+             gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
+                       value=defaults.get("threshold", cls.threshold),
+                       info="Lower it to be more sensitive to small sounds."),
+             gr.Number(label="Minimum Speech Duration (ms)", precision=0,
+                       value=defaults.get("min_speech_duration_ms", cls.min_speech_duration_ms),
+                       info="Final speech chunks shorter than this time are thrown out"),
+             gr.Number(label="Maximum Speech Duration (s)",
+                       value=defaults.get("max_speech_duration_s", cls.max_speech_duration_s),
+                       info="Maximum duration of speech chunks in \"seconds\"."),
+             gr.Number(label="Minimum Silence Duration (ms)", precision=0,
+                       value=defaults.get("min_silence_duration_ms", cls.min_silence_duration_ms),
+                       info="In the end of each speech chunk wait for this time"
+                            " before separating it"),
+             gr.Number(label="Speech Padding (ms)", precision=0,
+                       value=defaults.get("speech_pad_ms", cls.speech_pad_ms),
+                       info="Final speech chunks are padded by this time each side")
+         ]
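Because these are now Pydantic models, the ge/le/gt bounds are enforced at construction time instead of being trusted from the UI. A quick sketch with illustrative values:

    from pydantic import ValidationError

    vad = VadParams(threshold=0.35)           # within [0.0, 1.0], accepted
    print(vad.to_dict()["speech_pad_ms"])     # 400, the declared default

    try:
        VadParams(threshold=1.5)              # violates le=1.0
    except ValidationError as err:
        print(err.errors()[0]["loc"])         # ('threshold',)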


+ class DiarizationParams(BaseModel):
+     """Speaker diarization parameters"""
+     is_diarize: bool = Field(default=False, description="Enable speaker diarization")
+     hf_token: str = Field(
+         default="",
+         description="Hugging Face token for downloading diarization models"
+     )

+     def to_dict(self) -> Dict:
+         return self.model_dump()

+     @classmethod
+     def to_gradio_inputs(cls,
+                          defaults: Optional[Dict] = None,
+                          available_devices: Optional[List] = None,
+                          device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
+         defaults = defaults or {}
+         return [
+             gr.Checkbox(
+                 label=_("Enable Diarization"),
+                 value=defaults.get("is_diarize", cls.is_diarize),
+                 info=_("Enable speaker diarization")
+             ),
+             gr.Textbox(
+                 label=_("HuggingFace Token"),
+                 value=defaults.get("hf_token", cls.hf_token),
+                 info=_("This is only needed the first time you download the model")
+             ),
+             gr.Dropdown(
+                 label=_("Device"),
+                 choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                 value="cuda" if device is None else device,
+                 info=_("Device to run diarization model")
+             )
+         ]
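to_gradio_inputs() doubles as a factory for pre-filled components: defaults would typically come from the saved configuration, and the device list from a torch probe. A sketch under those assumptions (the saved dict is illustrative):

    import torch

    saved = {"is_diarize": True, "hf_token": ""}   # e.g. the "diarization" section of a config
    devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
    checkbox, token_box, device_dropdown = DiarizationParams.to_gradio_inputs(
        defaults=saved,
        available_devices=devices,
        device=devices[0],
    )
    # The three components are ready to be placed in a Blocks layout.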


+ class BGMSeparationParams(BaseModel):
+     """Background music separation parameters"""
+     is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
+     model_size: str = Field(
+         default="UVR-MDX-NET-Inst_HQ_4",
+         description="UVR model size"
+     )
+     segment_size: int = Field(
+         default=256,
+         gt=0,
+         description="Segment size for UVR model"
+     )
+     save_file: bool = Field(
+         default=False,
+         description="Whether to save separated audio files"
+     )
+     enable_offload: bool = Field(
+         default=True,
+         description="Offload UVR model after transcription"
+     )

+     def to_dict(self) -> Dict:
+         return self.model_dump()

+     @classmethod
+     def to_gradio_input(cls,
+                         defaults: Optional[Dict] = None,
+                         available_devices: Optional[List] = None,
+                         device: Optional[str] = None,
+                         available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
+         defaults = defaults or {}
+         return [
+             gr.Checkbox(
+                 label=_("Enable Background Music Remover Filter"),
+                 value=defaults.get("is_separate_bgm", cls.is_separate_bgm),
+                 interactive=True,
+                 info=_("Enabling this will remove background music")
+             ),
+             gr.Dropdown(
+                 label=_("Device"),
+                 choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                 value="cuda" if device is None else device,
+                 info=_("Device to run UVR model")
+             ),
+             gr.Dropdown(
+                 label=_("Model"),
+                 choices=["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
+                 value=defaults.get("model_size", cls.model_size),
+                 info=_("UVR model size")
+             ),
+             gr.Number(
+                 label="Segment Size",
+                 value=defaults.get("segment_size", cls.segment_size),
+                 precision=0,
+                 info="Segment size for UVR model"
+             ),
+             gr.Checkbox(
+                 label=_("Save separated files to output"),
+                 value=defaults.get("save_file", cls.save_file),
+                 info=_("Whether to save separated audio files")
+             ),
+             gr.Checkbox(
+                 label=_("Offload sub model after removing background music"),
+                 value=defaults.get("enable_offload", cls.enable_offload),
+                 info=_("Offload UVR model after transcription")
+             )
+         ]
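The field names line up with the keys of the "bgm_separation" section that the old WhisperValues.to_yaml() wrote (is_separate_bgm, model_size, segment_size, save_file, enable_offload), so a saved section rehydrates directly; a sketch:

    section = {"is_separate_bgm": True, "model_size": "UVR-MDX-NET-Inst_HQ_4",
               "segment_size": 256, "save_file": False, "enable_offload": True}
    bgm = BGMSeparationParams(**section)
    assert bgm.to_dict() == section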


+ class WhisperParams(BaseModel):
+     """Whisper parameters"""
+     model_size: str = Field(default="large-v2", description="Whisper model size")
+     lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
+     is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
+     beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
+     log_prob_threshold: float = Field(
+         default=-1.0,
+         description="Threshold for average log probability of sampled tokens"
+     )
+     no_speech_threshold: float = Field(
+         default=0.6,
+         ge=0.0,
+         le=1.0,
+         description="Threshold for detecting silence"
+     )
+     compute_type: str = Field(default="float16", description="Computation type for transcription")
+     best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
+     patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
+     condition_on_previous_text: bool = Field(
+         default=True,
+         description="Use previous output as prompt for next window"
+     )
+     prompt_reset_on_temperature: float = Field(
+         default=0.5,
+         ge=0.0,
+         le=1.0,
+         description="Temperature threshold for resetting prompt"
+     )
+     initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
+     temperature: float = Field(
+         default=0.0,
+         ge=0.0,
+         description="Temperature for sampling"
+     )
+     compression_ratio_threshold: float = Field(
+         default=2.4,
+         gt=0,
+         description="Threshold for gzip compression ratio"
+     )
+     batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
+     length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
+     repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
+     no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
+     prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
+     suppress_blank: bool = Field(
+         default=True,
+         description="Suppress blank outputs at start of sampling"
+     )
+     suppress_tokens: Optional[str] = Field(default="[-1]", description="Token IDs to suppress")
+     max_initial_timestamp: float = Field(
+         default=0.0,
+         ge=0.0,
+         description="Maximum initial timestamp"
+     )
+     word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
+     prepend_punctuations: Optional[str] = Field(
+         default="\"'“¿([{-",
+         description="Punctuations to merge with next word"
+     )
+     append_punctuations: Optional[str] = Field(
+         default="\"'.。,,!!??::”)]}、",
+         description="Punctuations to merge with previous word"
+     )
+     max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
+     chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
+     hallucination_silence_threshold: Optional[float] = Field(
+         default=None,
+         description="Threshold for skipping silent periods in hallucination detection"
+     )
+     hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
+     language_detection_threshold: Optional[float] = Field(
+         default=None,
+         description="Threshold for language detection probability"
+     )
+     language_detection_segments: int = Field(
+         default=1,
+         gt=0,
+         description="Number of segments for language detection"
+     )

+     def to_dict(self):
+         return self.model_dump()

+     @field_validator('lang')
+     def validate_lang(cls, v):
+         from modules.utils.constants import AUTOMATIC_DETECTION
+         return None if v == AUTOMATIC_DETECTION.unwrap() else v
+     @classmethod
+     def to_gradio_inputs(cls,
+                          defaults: Optional[Dict] = None,
+                          only_advanced: Optional[bool] = True,
+                          whisper_type: Optional[WhisperImpl] = None):
+         defaults = {} if defaults is None else defaults
+         whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type

+         inputs = []
+         if not only_advanced:
+             inputs += [
+                 gr.Dropdown(
+                     label="Model Size",
+                     choices=["small", "medium", "large-v2"],
+                     value=defaults.get("model_size", cls.model_size),
+                     info="Whisper model size"
+                 ),
+                 gr.Textbox(
+                     label="Language",
+                     value=defaults.get("lang", cls.lang),
+                     info="Source language of the file to transcribe"
+                 ),
+                 gr.Checkbox(
+                     label="Translate to English",
+                     value=defaults.get("is_translate", cls.is_translate),
+                     info="Translate speech to English end-to-end"
+                 ),
+             ]

+         inputs += [
+             gr.Number(
+                 label="Beam Size",
+                 value=defaults.get("beam_size", cls.beam_size),
+                 precision=0,
+                 info="Beam size for decoding"
+             ),
+             gr.Number(
+                 label="Log Probability Threshold",
+                 value=defaults.get("log_prob_threshold", cls.log_prob_threshold),
+                 info="Threshold for average log probability of sampled tokens"
+             ),
+             gr.Number(
+                 label="No Speech Threshold",
+                 value=defaults.get("no_speech_threshold", cls.no_speech_threshold),
+                 info="Threshold for detecting silence"
+             ),
+             gr.Dropdown(
+                 label="Compute Type",
+                 choices=["float16", "int8", "int16"],
+                 value=defaults.get("compute_type", cls.compute_type),
+                 info="Computation type for transcription"
+             ),
+             gr.Number(
+                 label="Best Of",
+                 value=defaults.get("best_of", cls.best_of),
+                 precision=0,
+                 info="Number of candidates when sampling"
+             ),
+             gr.Number(
+                 label="Patience",
+                 value=defaults.get("patience", cls.patience),
+                 info="Beam search patience factor"
+             ),
+             gr.Checkbox(
+                 label="Condition On Previous Text",
+                 value=defaults.get("condition_on_previous_text", cls.condition_on_previous_text),
+                 info="Use previous output as prompt for next window"
+             ),
+             gr.Slider(
+                 label="Prompt Reset On Temperature",
+                 value=defaults.get("prompt_reset_on_temperature", cls.prompt_reset_on_temperature),
+                 minimum=0,
+                 maximum=1,
+                 step=0.01,
+                 info="Temperature threshold for resetting prompt"
+             ),
+             gr.Textbox(
+                 label="Initial Prompt",
+                 value=defaults.get("initial_prompt", cls.initial_prompt),
+                 info="Initial prompt for first window"
+             ),
+             gr.Slider(
+                 label="Temperature",
+                 value=defaults.get("temperature", cls.temperature),
+                 minimum=0.0,
+                 step=0.01,
+                 maximum=1.0,
+                 info="Temperature for sampling"
+             ),
+             gr.Number(
+                 label="Compression Ratio Threshold",
+                 value=defaults.get("compression_ratio_threshold", cls.compression_ratio_threshold),
+                 info="Threshold for gzip compression ratio"
+             )
+         ]
+         if whisper_type == WhisperImpl.FASTER_WHISPER:
+             inputs += [
+                 gr.Number(
+                     label="Length Penalty",
+                     value=defaults.get("length_penalty", cls.length_penalty),
+                     info="Exponential length penalty"
+                 ),
+                 gr.Number(
+                     label="Repetition Penalty",
+                     value=defaults.get("repetition_penalty", cls.repetition_penalty),
+                     info="Penalty for repeated tokens"
+                 ),
+                 gr.Number(
+                     label="No Repeat N-gram Size",
+                     value=defaults.get("no_repeat_ngram_size", cls.no_repeat_ngram_size),
+                     precision=0,
+                     info="Size of n-grams to prevent repetition"
+                 ),
+                 gr.Textbox(
+                     label="Prefix",
+                     value=defaults.get("prefix", cls.prefix),
+                     info="Prefix text for first window"
+                 ),
+                 gr.Checkbox(
+                     label="Suppress Blank",
+                     value=defaults.get("suppress_blank", cls.suppress_blank),
+                     info="Suppress blank outputs at start of sampling"
+                 ),
+                 gr.Textbox(
+                     label="Suppress Tokens",
+                     value=defaults.get("suppress_tokens", cls.suppress_tokens),
+                     info="Token IDs to suppress"
+                 ),
+                 gr.Number(
+                     label="Max Initial Timestamp",
+                     value=defaults.get("max_initial_timestamp", cls.max_initial_timestamp),
+                     info="Maximum initial timestamp"
+                 ),
+                 gr.Checkbox(
+                     label="Word Timestamps",
+                     value=defaults.get("word_timestamps", cls.word_timestamps),
+                     info="Extract word-level timestamps"
+                 ),
+                 gr.Textbox(
+                     label="Prepend Punctuations",
+                     value=defaults.get("prepend_punctuations", cls.prepend_punctuations),
+                     info="Punctuations to merge with next word"
+                 ),
+                 gr.Textbox(
+                     label="Append Punctuations",
+                     value=defaults.get("append_punctuations", cls.append_punctuations),
+                     info="Punctuations to merge with previous word"
+                 ),
+                 gr.Number(
+                     label="Max New Tokens",
+                     value=defaults.get("max_new_tokens", cls.max_new_tokens),
+                     precision=0,
+                     info="Maximum number of new tokens per chunk"
+                 ),
+                 gr.Number(
+                     label="Chunk Length (s)",
+                     value=defaults.get("chunk_length", cls.chunk_length),
+                     precision=0,
+                     info="Length of audio segments in seconds"
+                 ),
+                 gr.Number(
+                     label="Hallucination Silence Threshold (sec)",
+                     value=defaults.get("hallucination_silence_threshold", cls.hallucination_silence_threshold),
+                     info="Threshold for skipping silent periods in hallucination detection"
+                 ),
+                 gr.Textbox(
+                     label="Hotwords",
+                     value=defaults.get("hotwords", cls.hotwords),
+                     info="Hotwords/hint phrases for the model"
+                 ),
+                 gr.Number(
+                     label="Language Detection Threshold",
+                     value=defaults.get("language_detection_threshold", cls.language_detection_threshold),
+                     info="Threshold for language detection probability"
+                 ),
+                 gr.Number(
+                     label="Language Detection Segments",
+                     value=defaults.get("language_detection_segments", cls.language_detection_segments),
+                     precision=0,
+                     info="Number of segments for language detection"
+                 )
+             ]

+         if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER:
+             inputs += [
+                 gr.Number(
+                     label="Batch Size",
+                     value=defaults.get("batch_size", cls.batch_size),
+                     precision=0,
+                     info="Batch size for processing"
+                 )
+             ]
+         return inputs
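Two behaviors are worth noting here: the lang validator folds the UI's automatic-detection sentinel back to None, and the returned component list grows with the chosen backend. A sketch, assuming AUTOMATIC_DETECTION unwraps to the string shown in the language dropdown:

    params = WhisperParams(lang=AUTOMATIC_DETECTION.unwrap())
    assert params.lang is None                 # sentinel normalized by validate_lang

    defaults = WhisperParams().to_dict()       # fill every key so no class fallback is needed
    fw = WhisperParams.to_gradio_inputs(defaults=defaults, whisper_type=WhisperImpl.FASTER_WHISPER)
    ifw = WhisperParams.to_gradio_inputs(defaults=defaults, whisper_type=WhisperImpl.INSANELY_FAST_WHISPER)
    assert len(fw) > len(ifw)                  # faster-whisper exposes its decoder options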

+
+ class TranscriptionPipelineParams(BaseModel):
+     """Transcription pipeline parameters"""
+     whisper: WhisperParams = Field(default_factory=WhisperParams)
+     vad: VadParams = Field(default_factory=VadParams)
+     diarization: DiarizationParams = Field(default_factory=DiarizationParams)
+     bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
+
+     def to_dict(self) -> Dict:
          data = {
+             "whisper": self.whisper.to_dict(),
+             "vad": self.vad.to_dict(),
+             "diarization": self.diarization.to_dict(),
+             "bgm_separation": self.bgm_separation.to_dict()
          }
          return data

+     # def as_list(self) -> list:
+     #     """
+     #     Converts the data class attributes into a list
+     #
+     #     Returns
+     #     ----------
+     #     A list of Whisper parameters
+     #     """
+     #     return [getattr(self, f.name) for f in fields(self)]
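With the nested model in place, the whole pipeline configuration serializes through the existing yaml import in one step, and a saved file validates back into typed objects on load. A sketch of the round trip (the file path is illustrative):

    params = TranscriptionPipelineParams()
    params.whisper.model_size = "medium"
    params.vad.vad_filter = True

    with open("config.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(params.to_dict(), f)

    with open("config.yaml", encoding="utf-8") as f:
        restored = TranscriptionPipelineParams(**yaml.safe_load(f))
    assert restored.whisper.model_size == "medium"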