jhj0517 commited on
Commit
b6032b5
·
1 Parent(s): 7d9eec3

Add `cache_parameters()`

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +10 -1
modules/whisper/whisper_base.py CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
12
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
15
  from modules.utils.files_manager import get_media_files, format_gradio_files
@@ -94,6 +94,8 @@ class WhisperBase(ABC):
94
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
95
  params.lang = language_code_dict[params.lang]
96
 
 
 
97
  speech_chunks = None
98
  if params.vad_filter:
99
  # Explicit value set for float('inf') from gr.Number()
@@ -435,3 +437,10 @@ class WhisperBase(ABC):
435
  for file_path in file_paths:
436
  if file_path and os.path.exists(file_path):
437
  os.remove(file_path)
 
 
 
 
 
 
 
 
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
12
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
15
  from modules.utils.files_manager import get_media_files, format_gradio_files
 
94
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
95
  params.lang = language_code_dict[params.lang]
96
 
97
+ self.cache_parameters(params)
98
+
99
  speech_chunks = None
100
  if params.vad_filter:
101
  # Explicit value set for float('inf') from gr.Number()
 
437
  for file_path in file_paths:
438
  if file_path and os.path.exists(file_path):
439
  os.remove(file_path)
440
+
441
+ @staticmethod
442
+ def cache_parameters(whisper_params: WhisperValues):
443
+ cached_yaml = whisper_params.to_yaml()
444
+
445
+ with open(DEFAULT_PARAMETERS_CONFIG_PATH, 'w', encoding='utf-8') as file:
446
+ file.write(cached_yaml)