Spaces:
Running
Running
jhj0517
commited on
Commit
·
cee12df
1
Parent(s):
250b9b4
Handle gradio None values
Browse files
modules/whisper/data_classes.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from typing import Optional, Dict, List
|
4 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
5 |
from gradio_i18n import Translate, gettext as _
|
6 |
from enum import Enum
|
@@ -241,7 +241,7 @@ class WhisperParams(BaseParams):
|
|
241 |
default=True,
|
242 |
description="Suppress blank outputs at start of sampling"
|
243 |
)
|
244 |
-
suppress_tokens: Optional[str] = Field(default=
|
245 |
max_initial_timestamp: float = Field(
|
246 |
default=0.0,
|
247 |
ge=0.0,
|
@@ -279,6 +279,20 @@ class WhisperParams(BaseParams):
|
|
279 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
280 |
return None if v == AUTOMATIC_DETECTION.unwrap() else v
|
281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
@classmethod
|
283 |
def to_gradio_inputs(cls,
|
284 |
defaults: Optional[Dict] = None,
|
@@ -301,7 +315,7 @@ class WhisperParams(BaseParams):
|
|
301 |
gr.Dropdown(
|
302 |
label=_("Language"),
|
303 |
choices=available_langs,
|
304 |
-
value=defaults.get("lang",
|
305 |
),
|
306 |
gr.Checkbox(
|
307 |
label=_("Translate to English?"),
|
@@ -407,7 +421,7 @@ class WhisperParams(BaseParams):
|
|
407 |
),
|
408 |
gr.Textbox(
|
409 |
label="Suppress Tokens",
|
410 |
-
value=defaults.get("suppress_tokens",
|
411 |
info="Token IDs to suppress"
|
412 |
),
|
413 |
gr.Number(
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from typing import Optional, Dict, List, Union
|
4 |
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
5 |
from gradio_i18n import Translate, gettext as _
|
6 |
from enum import Enum
|
|
|
241 |
default=True,
|
242 |
description="Suppress blank outputs at start of sampling"
|
243 |
)
|
244 |
+
suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
|
245 |
max_initial_timestamp: float = Field(
|
246 |
default=0.0,
|
247 |
ge=0.0,
|
|
|
279 |
from modules.utils.constants import AUTOMATIC_DETECTION
|
280 |
return None if v == AUTOMATIC_DETECTION.unwrap() else v
|
281 |
|
282 |
+
@field_validator('suppress_tokens')
|
283 |
+
def validate_supress_tokens(cls, v):
|
284 |
+
import ast
|
285 |
+
try:
|
286 |
+
if isinstance(v, str):
|
287 |
+
suppress_tokens = ast.literal_eval(v)
|
288 |
+
if not isinstance(suppress_tokens, list):
|
289 |
+
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
290 |
+
return suppress_tokens
|
291 |
+
if isinstance(v, list):
|
292 |
+
return v
|
293 |
+
except Exception as e:
|
294 |
+
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
|
295 |
+
|
296 |
@classmethod
|
297 |
def to_gradio_inputs(cls,
|
298 |
defaults: Optional[Dict] = None,
|
|
|
315 |
gr.Dropdown(
|
316 |
label=_("Language"),
|
317 |
choices=available_langs,
|
318 |
+
value=defaults.get("lang", AUTOMATIC_DETECTION),
|
319 |
),
|
320 |
gr.Checkbox(
|
321 |
label=_("Translate to English?"),
|
|
|
421 |
),
|
422 |
gr.Textbox(
|
423 |
label="Suppress Tokens",
|
424 |
+
value=defaults.get("suppress_tokens", "[-1]"),
|
425 |
info="Token IDs to suppress"
|
426 |
),
|
427 |
gr.Number(
|