jhj0517 committed on
Commit
e862b08
·
1 Parent(s): cee12df

Handle gradio none values

Browse files
modules/whisper/faster_whisper_inference.py CHANGED
@@ -67,16 +67,6 @@ class FasterWhisperInference(WhisperBase):
67
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
68
  self.update_model(params.model_size, params.compute_type, progress)
69
 
70
- # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
71
- if not params.initial_prompt:
72
- params.initial_prompt = None
73
- if not params.prefix:
74
- params.prefix = None
75
- if not params.hotwords:
76
- params.hotwords = None
77
-
78
- params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
79
-
80
  segments, info = self.model.transcribe(
81
  audio=audio,
82
  language=params.lang,
 
67
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
68
  self.update_model(params.model_size, params.compute_type, progress)
69
 
 
 
 
 
 
 
 
 
 
 
70
  segments, info = self.model.transcribe(
71
  audio=audio,
72
  language=params.lang,
modules/whisper/whisper_base.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import torch
 
3
  import whisper
4
  import ctranslate2
5
  import gradio as gr
@@ -14,7 +15,7 @@ from dataclasses import astuple
14
  from modules.uvr.music_separator import MusicSeparator
15
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
16
  UVR_MODELS_DIR)
17
- from modules.utils.constants import AUTOMATIC_DETECTION
18
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
19
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
20
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
@@ -101,16 +102,9 @@ class WhisperBase(ABC):
101
  elapsed time for running
102
  """
103
  params = TranscriptionPipelineParams.from_list(list(pipeline_params))
 
104
  bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
105
 
106
- if whisper_params.lang is None:
107
- pass
108
- elif whisper_params.lang == AUTOMATIC_DETECTION:
109
- whisper_params.lang = None
110
- else:
111
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
112
- whisper_params.lang = language_code_dict[params.lang]
113
-
114
  if bgm_params.is_separate_bgm:
115
  music, audio, _ = self.music_separator.separate(
116
  audio=audio,
@@ -515,25 +509,57 @@ class WhisperBase(ABC):
515
  if file_path and os.path.exists(file_path):
516
  os.remove(file_path)
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  @staticmethod
519
  def cache_parameters(
520
  params: TranscriptionPipelineParams,
521
  add_timestamp: bool
522
  ):
523
- """cache parameters to the yaml file"""
524
-
525
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
526
  param_to_cache = params.to_dict()
527
 
528
- print(param_to_cache)
529
-
530
  cached_yaml = {**cached_params, **param_to_cache}
531
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
532
 
 
 
 
 
533
  if cached_yaml["whisper"].get("lang", None) is None:
534
- cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION
535
 
536
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
 
537
 
538
  @staticmethod
539
  def resample_audio(audio: Union[str, np.ndarray],
 
1
  import os
2
  import torch
3
+ import ast
4
  import whisper
5
  import ctranslate2
6
  import gradio as gr
 
15
  from modules.uvr.music_separator import MusicSeparator
16
  from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
17
  UVR_MODELS_DIR)
18
+ from modules.utils.constants import AUTOMATIC_DETECTION, GRADIO_NONE_VALUES
19
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
20
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
21
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
 
102
  elapsed time for running
103
  """
104
  params = TranscriptionPipelineParams.from_list(list(pipeline_params))
105
+ params = self.handle_gradio_values(params)
106
  bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
107
 
 
 
 
 
 
 
 
 
108
  if bgm_params.is_separate_bgm:
109
  music, audio, _ = self.music_separator.separate(
110
  audio=audio,
 
509
  if file_path and os.path.exists(file_path):
510
  os.remove(file_path)
511
 
512
+ @staticmethod
513
+ def handle_gradio_values(params: TranscriptionPipelineParams):
514
+ """
515
+ Handle gradio specific values that can't be displayed as None in the UI.
516
+ Related issue : https://github.com/gradio-app/gradio/issues/8723
517
+ """
518
+ if params.whisper.lang is None:
519
+ pass
520
+ elif params.whisper.lang == AUTOMATIC_DETECTION:
521
+ params.whisper.lang = None
522
+ else:
523
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
524
+ params.whisper.lang = language_code_dict[params.lang]
525
+
526
+ if not params.whisper.initial_prompt:
527
+ params.whisper.initial_prompt = None
528
+ if not params.whisper.prefix:
529
+ params.whisper.prefix = None
530
+ if not params.whisper.hotwords:
531
+ params.whisper.hotwords = None
532
+ if params.whisper.max_new_tokens == 0:
533
+ params.whisper.max_new_tokens = None
534
+ if params.whisper.hallucination_silence_threshold == 0:
535
+ params.whisper.hallucination_silence_threshold = None
536
+ if params.whisper.language_detection_threshold == 0:
537
+ params.whisper.language_detection_threshold = None
538
+ if params.whisper.max_speech_duration_s >= 9999:
539
+ params.whisper.max_speech_duration_s = float('inf')
540
+ return params
541
+
542
  @staticmethod
543
  def cache_parameters(
544
  params: TranscriptionPipelineParams,
545
  add_timestamp: bool
546
  ):
547
+ """Cache parameters to the yaml file"""
 
548
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
549
  param_to_cache = params.to_dict()
550
 
 
 
551
  cached_yaml = {**cached_params, **param_to_cache}
552
  cached_yaml["whisper"]["add_timestamp"] = add_timestamp
553
 
554
+ supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
555
+ if supress_token and isinstance(supress_token, list):
556
+ cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
557
+
558
  if cached_yaml["whisper"].get("lang", None) is None:
559
+ cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
560
 
561
+ if cached_yaml is not None and cached_yaml:
562
+ save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
563
 
564
  @staticmethod
565
  def resample_audio(audio: Union[str, np.ndarray],