LAP-DEV committed
Commit 4adf2ca · verified · 1 Parent(s): 64ceef0

Update modules/translation/translation_base.py

modules/translation/translation_base.py CHANGED
@@ -136,6 +136,56 @@ class TranslationBase(ABC):
     finally:
         self.release_cuda_memory()
 
+    def translate_text(self,
+                       input_list_dict: list,
+                       model_size: str,
+                       src_lang: str,
+                       tgt_lang: str,
+                       max_length: int = 200,
+                       add_timestamp: bool = True,
+                       progress=gr.Progress()) -> list:
+        """
+        Translate text from the source language to the target language
+        Parameters
+        ----------
+        input_list_dict: list
+            List of dicts to translate, each with a "text" key
+        model_size: str
+            Whisper model size from gr.Dropdown()
+        src_lang: str
+            Source language of the file to translate from gr.Dropdown()
+        tgt_lang: str
+            Target language of the file to translate from gr.Dropdown()
+        max_length: int
+            Max length per line to translate
+        add_timestamp: bool
+            Boolean value that determines whether to add a timestamp
+        progress: gr.Progress
+            Indicator to show progress directly in Gradio.
+            A forked version of Whisper is used for this. For more info, see: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+        Returns
+        ----------
+        list
+            The input List[dict] with each "text" value translated
+        """
+        try:
+            self.cache_parameters(model_size=model_size, src_lang=src_lang, tgt_lang=tgt_lang, max_length=max_length, add_timestamp=add_timestamp)
+            self.update_model(model_size=model_size, src_lang=src_lang, tgt_lang=tgt_lang, progress=progress)
+
+            total_progress = len(input_list_dict)
+            for index, dic in enumerate(input_list_dict):
+                progress(index / total_progress, desc="Translating..")
+                translated_text = self.translate(dic["text"], max_length=max_length)
+                dic["text"] = translated_text
+
+            return input_list_dict
+
+        except Exception as e:
+            print(f"Error translating text: {e}")
+            raise
+        finally:
+            self.release_cuda_memory()
+
     @staticmethod
     def get_device():
         if torch.cuda.is_available():
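
For reference, below is a minimal usage sketch of the new method. It assumes `translate` and `update_model` are the overridable hooks of `TranslationBase` and that the base class can be subclassed and instantiated this way (only `translate_text` itself is confirmed by this commit); `EchoTranslator`, the segment values, and the language names are illustrative stand-ins for a real subclass and real data:

from modules.translation.translation_base import TranslationBase


class EchoTranslator(TranslationBase):
    """Hypothetical stand-in: tags text instead of running a model."""

    def translate(self, text: str, max_length: int) -> str:
        return f"[{text}]"  # a real subclass would run the translation model

    def update_model(self, model_size: str, src_lang: str, tgt_lang: str, progress):
        pass  # a real subclass would load or switch the model here


# Whisper-style segments: each dict carries a "text" key plus timings.
segments = [
    {"start": 0.0, "end": 2.5, "text": "Hello, world."},
    {"start": 2.5, "end": 5.0, "text": "How are you?"},
]

# translate_text() replaces each dict's "text" in place and returns the list;
# timings and other keys are untouched.
translated = EchoTranslator().translate_text(
    input_list_dict=segments,
    model_size="dummy-model",  # normally a gr.Dropdown() value
    src_lang="English",
    tgt_lang="Korean",
)
print(translated[0]["text"])  # -> "[Hello, world.]"

Note that the method returns the same list object it was given and overwrites each "text" value, so callers that still need the original text should copy the dicts before calling it.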