# modules/whisper/data_classes.py
import gradio as gr
import torch
from typing import Optional, Dict, List, Union
from pydantic import BaseModel, Field, field_validator, ConfigDict
from gradio_i18n import Translate, gettext as _
from enum import Enum
from copy import deepcopy
import yaml
from modules.utils.constants import *
class WhisperImpl(Enum):
WHISPER = "whisper"
FASTER_WHISPER = "faster-whisper"
INSANELY_FAST_WHISPER = "insanely_fast_whisper"
class BaseParams(BaseModel):
model_config = ConfigDict(protected_namespaces=())
def to_dict(self) -> Dict:
return self.model_dump()
def to_list(self) -> List:
return list(self.model_dump().values())
@classmethod
def from_list(cls, data_list: List) -> 'BaseParams':
field_names = list(cls.model_fields.keys())
return cls(**dict(zip(field_names, data_list)))
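
# Illustrative sketch (not part of the original module): BaseParams subclasses can be
# round-tripped through plain lists, which is how Gradio passes component values around.
# Assuming a subclass such as VadParams defined below:
#
#     params = VadParams(threshold=0.6)
#     as_list = params.to_list()              # values in field-declaration order
#     restored = VadParams.from_list(as_list)
#     assert restored == params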
class VadParams(BaseParams):
"""Voice Activity Detection parameters"""
vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
)
min_speech_duration_ms: int = Field(
default=250,
ge=0,
description="Final speech chunks shorter than this are discarded"
)
max_speech_duration_s: float = Field(
default=float("inf"),
gt=0,
description="Maximum duration of speech chunks in seconds"
)
min_silence_duration_ms: int = Field(
default=2000,
ge=0,
description="Minimum silence duration between speech chunks"
)
speech_pad_ms: int = Field(
default=400,
ge=0,
description="Padding added to each side of speech chunks"
)
@classmethod
def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
        # Guard against None so the `defaults.get(...)` calls below do not raise AttributeError
        defaults = {} if defaults is None else defaults
        return [
gr.Checkbox(
label=_("Enable Silero VAD Filter"),
value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
interactive=True,
info=_("Enable this to transcribe only detected voice")
),
gr.Slider(
minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
value=defaults.get("threshold", cls.__fields__["threshold"].default),
info="Lower it to be more sensitive to small sounds."
),
gr.Number(
label="Minimum Speech Duration (ms)", precision=0,
value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
info="Final speech chunks shorter than this time are thrown out"
),
gr.Number(
label="Maximum Speech Duration (s)",
value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
info="Maximum duration of speech chunks in \"seconds\"."
),
gr.Number(
label="Minimum Silence Duration (ms)", precision=0,
value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
info="In the end of each speech chunk wait for this time before separating it"
),
gr.Number(
label="Speech Padding (ms)", precision=0,
value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
info="Final speech chunks are padded by this time each side"
)
]
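
# Illustrative sketch (not part of the original module): how the VAD components above
# might be wired into a Blocks UI. `default_params` is a hypothetical dict of per-field
# defaults in the shape that `to_gradio_inputs` expects.
#
#     default_params = {"vad_filter": True, "threshold": 0.5}
#     with gr.Blocks() as demo:
#         vad_inputs = VadParams.to_gradio_inputs(defaults=default_params)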
class DiarizationParams(BaseParams):
"""Speaker diarization parameters"""
is_diarize: bool = Field(default=False, description="Enable speaker diarization")
device: str = Field(default="cuda", description="Device to run Diarization model.")
hf_token: str = Field(
default="",
description="Hugging Face token for downloading diarization models"
)
@classmethod
def to_gradio_inputs(cls,
defaults: Optional[Dict] = None,
available_devices: Optional[List] = None,
device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
        # Guard against None so the `defaults.get(...)` calls below do not raise AttributeError
        defaults = {} if defaults is None else defaults
        return [
gr.Checkbox(
label=_("Enable Diarization"),
value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
),
gr.Dropdown(
label=_("Device"),
choices=["cpu", "cuda"] if available_devices is None else available_devices,
value=defaults.get("device", device),
),
gr.Textbox(
label=_("HuggingFace Token"),
value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
info=_("This is only needed the first time you download the model")
),
]
class BGMSeparationParams(BaseParams):
"""Background music separation parameters"""
is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
model_size: str = Field(
default="UVR-MDX-NET-Inst_HQ_4",
description="UVR model size"
)
device: str = Field(default="cuda", description="Device to run UVR model.")
segment_size: int = Field(
default=256,
gt=0,
description="Segment size for UVR model"
)
save_file: bool = Field(
default=False,
description="Whether to save separated audio files"
)
enable_offload: bool = Field(
default=True,
description="Offload UVR model after transcription"
)
@classmethod
def to_gradio_input(cls,
defaults: Optional[Dict] = None,
available_devices: Optional[List] = None,
device: Optional[str] = None,
available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
        # Guard against None so the `defaults.get(...)` calls below do not raise AttributeError
        defaults = {} if defaults is None else defaults
        return [
gr.Checkbox(
label=_("Enable Background Music Remover Filter"),
value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
interactive=True,
info=_("Enabling this will remove background music")
),
gr.Dropdown(
label=_("Model"),
choices=["UVR-MDX-NET-Inst_HQ_4",
"UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
value=defaults.get("model_size", cls.__fields__["model_size"].default),
),
gr.Dropdown(
label=_("Device"),
choices=["cpu", "cuda"] if available_devices is None else available_devices,
value=defaults.get("device", device),
),
gr.Number(
label="Segment Size",
value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
precision=0,
info="Segment size for UVR model"
),
gr.Checkbox(
label=_("Save separated files to output"),
value=defaults.get("save_file", cls.__fields__["save_file"].default),
),
gr.Checkbox(
label=_("Offload sub model after removing background music"),
value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
)
]
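
# Illustrative sketch (not part of the original module): unspecified fields fall back to
# their Field defaults, so a partial config dict (e.g. loaded from YAML) is enough.
#
#     bgm_cfg = BGMSeparationParams(**{"is_separate_bgm": True, "device": "cpu"})
#     print(bgm_cfg.model_size)   # -> "UVR-MDX-NET-Inst_HQ_4"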
class WhisperParams(BaseParams):
"""Whisper parameters"""
model_size: str = Field(default="large-v2", description="Whisper model size")
lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
log_prob_threshold: float = Field(
default=-1.0,
description="Threshold for average log probability of sampled tokens"
)
no_speech_threshold: float = Field(
default=0.6,
ge=0.0,
le=1.0,
description="Threshold for detecting silence"
)
compute_type: str = Field(default="float16", description="Computation type for transcription")
best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
condition_on_previous_text: bool = Field(
default=True,
description="Use previous output as prompt for next window"
)
prompt_reset_on_temperature: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Temperature threshold for resetting prompt"
)
initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for sampling"
)
compression_ratio_threshold: float = Field(
default=2.4,
gt=0,
description="Threshold for gzip compression ratio"
)
length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
suppress_blank: bool = Field(
default=True,
description="Suppress blank outputs at start of sampling"
)
suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
max_initial_timestamp: float = Field(
default=0.0,
ge=0.0,
description="Maximum initial timestamp"
)
word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
prepend_punctuations: Optional[str] = Field(
default="\"'“¿([{-",
description="Punctuations to merge with next word"
)
append_punctuations: Optional[str] = Field(
default="\"'.。,,!!??::”)]}、",
description="Punctuations to merge with previous word"
)
max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
hallucination_silence_threshold: Optional[float] = Field(
default=None,
description="Threshold for skipping silent periods in hallucination detection"
)
hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
language_detection_threshold: Optional[float] = Field(
default=None,
description="Threshold for language detection probability"
)
language_detection_segments: int = Field(
default=1,
gt=0,
description="Number of segments for language detection"
)
batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
@field_validator('lang')
def validate_lang(cls, v):
from modules.utils.constants import AUTOMATIC_DETECTION
return None if v == AUTOMATIC_DETECTION.unwrap() else v
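
    # Note (added for clarity, not in the original module): the UI's automatic-detection
    # choice is normalized to None here, so downstream backends receive lang=None and run
    # their own language detection.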
@field_validator('suppress_tokens')
    def validate_suppress_tokens(cls, v):
        import ast
        try:
            if isinstance(v, str):
                suppress_tokens = ast.literal_eval(v)
                if not isinstance(suppress_tokens, list):
                    raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
                return suppress_tokens
            if isinstance(v, list):
                return v
        except Exception as e:
            raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
        # Reject values that are neither a list nor a string literal of a list
        raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
@classmethod
def to_gradio_inputs(cls,
defaults: Optional[Dict] = None,
only_advanced: Optional[bool] = True,
whisper_type: Optional[WhisperImpl] = None,
available_models: Optional[List] = None,
available_langs: Optional[List] = None,
available_compute_types: Optional[List] = None,
compute_type: Optional[str] = None):
        whisper_type = WhisperImpl.FASTER_WHISPER if whisper_type is None else whisper_type
        # Guard against None so the `defaults.get(...)` calls below do not raise AttributeError
        defaults = {} if defaults is None else defaults
        inputs = []
if not only_advanced:
inputs += [
gr.Dropdown(
label=_("Model"),
choices=available_models,
value=defaults.get("model_size", cls.__fields__["model_size"].default),
),
gr.Dropdown(
label=_("Language"),
choices=available_langs,
value=defaults.get("lang", AUTOMATIC_DETECTION),
),
gr.Checkbox(
label=_("Translate to English?"),
value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
),
]
inputs += [
gr.Number(
label="Beam Size",
value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
precision=0,
info="Beam size for decoding"
),
gr.Number(
label="Log Probability Threshold",
value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
info="Threshold for average log probability of sampled tokens"
),
gr.Number(
label="No Speech Threshold",
value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
info="Threshold for detecting silence"
),
gr.Dropdown(
label="Compute Type",
choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
value=defaults.get("compute_type", compute_type),
info="Computation type for transcription"
),
gr.Number(
label="Best Of",
value=defaults.get("best_of", cls.__fields__["best_of"].default),
precision=0,
info="Number of candidates when sampling"
),
gr.Number(
label="Patience",
value=defaults.get("patience", cls.__fields__["patience"].default),
info="Beam search patience factor"
),
gr.Checkbox(
label="Condition On Previous Text",
value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
info="Use previous output as prompt for next window"
),
gr.Slider(
label="Prompt Reset On Temperature",
value=defaults.get("prompt_reset_on_temperature",
cls.__fields__["prompt_reset_on_temperature"].default),
minimum=0,
maximum=1,
step=0.01,
info="Temperature threshold for resetting prompt"
),
gr.Textbox(
label="Initial Prompt",
value=defaults.get("initial_prompt", GRADIO_NONE_STR),
info="Initial prompt for first window"
),
gr.Slider(
label="Temperature",
value=defaults.get("temperature", cls.__fields__["temperature"].default),
minimum=0.0,
step=0.01,
maximum=1.0,
info="Temperature for sampling"
),
gr.Number(
label="Compression Ratio Threshold",
value=defaults.get("compression_ratio_threshold",
cls.__fields__["compression_ratio_threshold"].default),
info="Threshold for gzip compression ratio"
)
]
faster_whisper_inputs = [
gr.Number(
label="Length Penalty",
value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
info="Exponential length penalty",
),
gr.Number(
label="Repetition Penalty",
value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
info="Penalty for repeated tokens"
),
gr.Number(
label="No Repeat N-gram Size",
value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
precision=0,
info="Size of n-grams to prevent repetition"
),
gr.Textbox(
label="Prefix",
value=defaults.get("prefix", GRADIO_NONE_STR),
info="Prefix text for first window"
),
gr.Checkbox(
label="Suppress Blank",
value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
info="Suppress blank outputs at start of sampling"
),
gr.Textbox(
label="Suppress Tokens",
value=defaults.get("suppress_tokens", "[-1]"),
info="Token IDs to suppress"
),
gr.Number(
label="Max Initial Timestamp",
value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
info="Maximum initial timestamp"
),
gr.Checkbox(
label="Word Timestamps",
value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
info="Extract word-level timestamps"
),
gr.Textbox(
label="Prepend Punctuations",
value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
info="Punctuations to merge with next word"
),
gr.Textbox(
label="Append Punctuations",
value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
info="Punctuations to merge with previous word"
),
gr.Number(
label="Max New Tokens",
value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
precision=0,
info="Maximum number of new tokens per chunk"
),
gr.Number(
label="Chunk Length (s)",
value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
precision=0,
info="Length of audio segments in seconds"
),
gr.Number(
label="Hallucination Silence Threshold (sec)",
value=defaults.get("hallucination_silence_threshold",
GRADIO_NONE_NUMBER_MIN),
info="Threshold for skipping silent periods in hallucination detection"
),
gr.Textbox(
label="Hotwords",
value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
info="Hotwords/hint phrases for the model"
),
gr.Number(
label="Language Detection Threshold",
value=defaults.get("language_detection_threshold",
GRADIO_NONE_NUMBER_MIN),
info="Threshold for language detection probability"
),
gr.Number(
label="Language Detection Segments",
value=defaults.get("language_detection_segments",
cls.__fields__["language_detection_segments"].default),
precision=0,
info="Number of segments for language detection"
)
]
insanely_fast_whisper_inputs = [
gr.Number(
label="Batch Size",
value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
precision=0,
info="Batch size for processing"
)
]
if whisper_type == WhisperImpl.FASTER_WHISPER:
for input_component in faster_whisper_inputs:
input_component.visible = True
if whisper_type == WhisperImpl.INSANELY_FAST_WHISPER:
for input_component in insanely_fast_whisper_inputs:
input_component.visible = True
inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
return inputs
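
# Illustrative sketch (not part of the original module): building the advanced Whisper
# inputs for the faster-whisper backend. `loaded_defaults` is a hypothetical dict, e.g.
# read from a YAML config, keyed by WhisperParams field names.
#
#     loaded_defaults = {"beam_size": 3, "compute_type": "int8"}
#     whisper_inputs = WhisperParams.to_gradio_inputs(
#         defaults=loaded_defaults,
#         only_advanced=True,
#         whisper_type=WhisperImpl.FASTER_WHISPER,
#     )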
class TranscriptionPipelineParams(BaseModel):
"""Transcription pipeline parameters"""
whisper: WhisperParams = Field(default_factory=WhisperParams)
vad: VadParams = Field(default_factory=VadParams)
diarization: DiarizationParams = Field(default_factory=DiarizationParams)
bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
def to_dict(self) -> Dict:
data = {
"whisper": self.whisper.to_dict(),
"vad": self.vad.to_dict(),
"diarization": self.diarization.to_dict(),
"bgm_separation": self.bgm_separation.to_dict()
}
return data
def to_list(self) -> List:
"""
        Convert the data class to a list, because the parameters have to be passed to Gradio as a list.
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
See more about Gradio pre-processing: https://www.gradio.app/docs/components
"""
whisper_list = self.whisper.to_list()
vad_list = self.vad.to_list()
diarization_list = self.diarization.to_list()
bgm_sep_list = self.bgm_separation.to_list()
return whisper_list + vad_list + diarization_list + bgm_sep_list
@staticmethod
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
"""Convert list to the data class again to use it in a function."""
data_list = deepcopy(pipeline_list)
whisper_list = data_list[0:len(WhisperParams.__annotations__)]
data_list = data_list[len(WhisperParams.__annotations__):]
vad_list = data_list[0:len(VadParams.__annotations__)]
data_list = data_list[len(VadParams.__annotations__):]
diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
data_list = data_list[len(DiarizationParams.__annotations__):]
bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
return TranscriptionPipelineParams(
whisper=WhisperParams.from_list(whisper_list),
vad=VadParams.from_list(vad_list),
diarization=DiarizationParams.from_list(diarization_list),
bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
)
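

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): the pipeline parameters
    # survive the list round-trip that Gradio's positional input passing requires.
    pipeline = TranscriptionPipelineParams(
        whisper=WhisperParams(model_size="large-v3", lang="en"),
        vad=VadParams(vad_filter=True),
    )
    flattened = pipeline.to_list()
    restored = TranscriptionPipelineParams.from_list(flattened)
    assert restored.to_dict() == pipeline.to_dict()
    print(yaml.safe_dump(restored.to_dict(), allow_unicode=True))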