jhj0517 commited on
Commit
cee12df
·
1 Parent(s): 250b9b4

Handle gradio None values

Browse files
Files changed (1) hide show
  1. modules/whisper/data_classes.py +18 -4
modules/whisper/data_classes.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from typing import Optional, Dict, List
4
  from pydantic import BaseModel, Field, field_validator, ConfigDict
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
@@ -241,7 +241,7 @@ class WhisperParams(BaseParams):
241
  default=True,
242
  description="Suppress blank outputs at start of sampling"
243
  )
244
- suppress_tokens: Optional[str] = Field(default="[-1]", description="Token IDs to suppress")
245
  max_initial_timestamp: float = Field(
246
  default=0.0,
247
  ge=0.0,
@@ -279,6 +279,20 @@ class WhisperParams(BaseParams):
279
  from modules.utils.constants import AUTOMATIC_DETECTION
280
  return None if v == AUTOMATIC_DETECTION.unwrap() else v
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  @classmethod
283
  def to_gradio_inputs(cls,
284
  defaults: Optional[Dict] = None,
@@ -301,7 +315,7 @@ class WhisperParams(BaseParams):
301
  gr.Dropdown(
302
  label=_("Language"),
303
  choices=available_langs,
304
- value=defaults.get("lang", cls.__fields__["lang"].default),
305
  ),
306
  gr.Checkbox(
307
  label=_("Translate to English?"),
@@ -407,7 +421,7 @@ class WhisperParams(BaseParams):
407
  ),
408
  gr.Textbox(
409
  label="Suppress Tokens",
410
- value=defaults.get("suppress_tokens", cls.__fields__["suppress_tokens"].default),
411
  info="Token IDs to suppress"
412
  ),
413
  gr.Number(
 
1
  import gradio as gr
2
  import torch
3
+ from typing import Optional, Dict, List, Union
4
  from pydantic import BaseModel, Field, field_validator, ConfigDict
5
  from gradio_i18n import Translate, gettext as _
6
  from enum import Enum
 
241
  default=True,
242
  description="Suppress blank outputs at start of sampling"
243
  )
244
+ suppress_tokens: Optional[Union[List, str]] = Field(default=[-1], description="Token IDs to suppress")
245
  max_initial_timestamp: float = Field(
246
  default=0.0,
247
  ge=0.0,
 
279
  from modules.utils.constants import AUTOMATIC_DETECTION
280
  return None if v == AUTOMATIC_DETECTION.unwrap() else v
281
 
282
+ @field_validator('suppress_tokens')
283
+ def validate_supress_tokens(cls, v):
284
+ import ast
285
+ try:
286
+ if isinstance(v, str):
287
+ suppress_tokens = ast.literal_eval(v)
288
+ if not isinstance(suppress_tokens, list):
289
+ raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
290
+ return suppress_tokens
291
+ if isinstance(v, list):
292
+ return v
293
+ except Exception as e:
294
+ raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
295
+
296
  @classmethod
297
  def to_gradio_inputs(cls,
298
  defaults: Optional[Dict] = None,
 
315
  gr.Dropdown(
316
  label=_("Language"),
317
  choices=available_langs,
318
+ value=defaults.get("lang", AUTOMATIC_DETECTION),
319
  ),
320
  gr.Checkbox(
321
  label=_("Translate to English?"),
 
421
  ),
422
  gr.Textbox(
423
  label="Suppress Tokens",
424
+ value=defaults.get("suppress_tokens", "[-1]"),
425
  info="Token IDs to suppress"
426
  ),
427
  gr.Number(