jhj0517 commited on
Commit
296b5e1
·
1 Parent(s): aa178ad

rename data class and add `post_process()`

Browse files
modules/faster_whisper_inference.py CHANGED
@@ -52,7 +52,7 @@ class FasterWhisperInference(WhisperBase):
52
  """
53
  start_time = time.time()
54
 
55
- params = WhisperValues(*whisper_params)
56
 
57
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
58
  self.update_model(params.model_size, params.compute_type, progress)
 
52
  """
53
  start_time = time.time()
54
 
55
+ params = WhisperParameters.post_process(*whisper_params)
56
 
57
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
58
  self.update_model(params.model_size, params.compute_type, progress)
modules/insanely_fast_whisper_inference.py CHANGED
@@ -8,7 +8,6 @@ from transformers.utils import is_flash_attn_2_available
8
  import gradio as gr
9
  from huggingface_hub import hf_hub_download
10
  import whisper
11
-
12
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
13
 
14
  from modules.whisper_parameter import *
@@ -50,7 +49,7 @@ class InsanelyFastWhisperInference(WhisperBase):
50
  elapsed time for transcription
51
  """
52
  start_time = time.time()
53
- params = WhisperValues(*whisper_params)
54
 
55
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
56
  self.update_model(params.model_size, params.compute_type, progress)
 
8
  import gradio as gr
9
  from huggingface_hub import hf_hub_download
10
  import whisper
 
11
  from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
12
 
13
  from modules.whisper_parameter import *
 
49
  elapsed time for transcription
50
  """
51
  start_time = time.time()
52
+ params = WhisperParameters.post_process(*whisper_params)
53
 
54
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
55
  self.update_model(params.model_size, params.compute_type, progress)
modules/whisper_Inference.py CHANGED
@@ -41,7 +41,7 @@ class WhisperInference(WhisperBase):
41
  elapsed time for transcription
42
  """
43
  start_time = time.time()
44
- params = WhisperValues(*whisper_params)
45
 
46
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
47
  self.update_model(params.model_size, params.compute_type, progress)
 
41
  elapsed time for transcription
42
  """
43
  start_time = time.time()
44
+ params = WhisperParameters.post_process(*whisper_params)
45
 
46
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
47
  self.update_model(params.model_size, params.compute_type, progress)
modules/whisper_parameter.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional
4
 
5
 
6
  @dataclass
7
- class WhisperGradioComponents:
8
  model_size: gr.Dropdown
9
  lang: gr.Dropdown
10
  is_translate: gr.Checkbox
@@ -115,7 +115,7 @@ class WhisperGradioComponents:
115
 
116
  def to_list(self) -> list:
117
  """
118
- Converts the data class attributes into a list. Use "before" Gradio pre-processing.
119
  See more about Gradio pre-processing: : https://www.gradio.app/docs/components
120
 
121
  Returns
@@ -124,6 +124,40 @@ class WhisperGradioComponents:
124
  """
125
  return [getattr(self, f.name) for f in fields(self)]
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  @dataclass
129
  class WhisperValues:
@@ -148,6 +182,5 @@ class WhisperValues:
148
  window_size_samples: int
149
  speech_pad_ms: int
150
  """
151
- A data class to use Whisper parameters. Use "after" Gradio pre-processing.
152
- See more about Gradio pre-processing: : https://www.gradio.app/docs/components
153
  """
 
4
 
5
 
6
  @dataclass
7
+ class WhisperParameters:
8
  model_size: gr.Dropdown
9
  lang: gr.Dropdown
10
  is_translate: gr.Checkbox
 
115
 
116
  def to_list(self) -> list:
117
  """
118
+ Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
119
  See more about Gradio pre-processing: : https://www.gradio.app/docs/components
120
 
121
  Returns
 
124
  """
125
  return [getattr(self, f.name) for f in fields(self)]
126
 
127
+ @staticmethod
128
+ def post_process(*args) -> 'WhisperValues':
129
+ """
130
+ To use Whisper parameters in function after Gradio post-processing.
131
+ See more about Gradio post-processing: : https://www.gradio.app/docs/components
132
+
133
+ Returns
134
+ ----------
135
+ WhisperValues
136
+ Data class that has values of parameters
137
+ """
138
+ return WhisperValues(
139
+ model_size=args[0],
140
+ lang=args[1],
141
+ is_translate=args[2],
142
+ beam_size=args[3],
143
+ log_prob_threshold=args[4],
144
+ no_speech_threshold=args[5],
145
+ compute_type=args[6],
146
+ best_of=args[7],
147
+ patience=args[8],
148
+ condition_on_previous_text=args[9],
149
+ initial_prompt=args[10],
150
+ temperature=args[11],
151
+ compression_ratio_threshold=args[12],
152
+ vad_filter=args[13],
153
+ threshold=args[14],
154
+ min_speech_duration_ms=args[15],
155
+ max_speech_duration_s=args[16],
156
+ min_silence_duration_ms=args[17],
157
+ window_size_samples=args[18],
158
+ speech_pad_ms=args[19]
159
+ )
160
+
161
 
162
  @dataclass
163
  class WhisperValues:
 
182
  window_size_samples: int
183
  speech_pad_ms: int
184
  """
185
+ A data class to use Whisper parameters.
 
186
  """