LAP-DEV commited on
Commit
179876d
·
verified ·
1 Parent(s): bc8e7de

Update modules/whisper/whisper_base.py

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +23 -6
modules/whisper/whisper_base.py CHANGED
@@ -671,16 +671,33 @@ class WhisperBase(ABC):
671
 
672
  @staticmethod
673
  def cache_parameters(
674
- whisper_params: WhisperValues,
675
- add_timestamp: bool
 
676
  ):
677
- """cache parameters to the yaml file"""
678
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
679
- cached_whisper_param = whisper_params.to_yaml()
680
- cached_yaml = {**cached_params, **cached_whisper_param}
 
681
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
 
683
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
 
684
 
685
  @staticmethod
686
  def resample_audio(audio: Union[str, np.ndarray],
 
671
 
672
  @staticmethod
673
  def cache_parameters(
674
+ params: WhisperValues,
675
+ file_format: str = "SRT",
676
+ add_timestamp: bool = True
677
  ):
678
+ """Cache parameters to the yaml file"""
679
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
680
+ param_to_cache = params.to_dict()
681
+
682
+ cached_yaml = {**cached_params, **param_to_cache}
683
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
684
+ cached_yaml["whisper"]["file_format"] = file_format
685
+
686
+ suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
687
+ if suppress_token and isinstance(suppress_token, list):
688
+ cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
689
+
690
+ if cached_yaml["whisper"].get("lang", None) is None:
691
+ cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
692
+ else:
693
+ language_dict = whisper.tokenizer.LANGUAGES
694
+ cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
695
+
696
+ if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
697
+ cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
698
 
699
+ if cached_yaml is not None and cached_yaml:
700
+ save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
701
 
702
  @staticmethod
703
  def resample_audio(audio: Union[str, np.ndarray],