LAP-DEV committed
Commit 35d057a · verified · 1 Parent(s): 48b1606

Update modules/translation/translation_base.py

modules/translation/translation_base.py CHANGED
@@ -136,6 +136,54 @@ class TranslationBase(ABC):
         finally:
             self.release_cuda_memory()
 
+    def translate_text(self,
+                       input_list_dict: List[dict],
+                       model_size: str,
+                       src_lang: str,
+                       tgt_lang: str,
+                       max_length: int = 200,
+                       progress=gr.Progress()) -> list:
+        """
+        Translate text from source language to target language
+
+        Parameters
+        ----------
+        input_list_dict: List[dict]
+            List[dict] of text segments to translate
+        model_size: str
+            Whisper model size from gr.Dropdown()
+        src_lang: str
+            Source language of the file to translate from gr.Dropdown()
+        tgt_lang: str
+            Target language of the file to translate from gr.Dropdown()
+        max_length: int
+            Max length per line to translate
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
+            I use a forked version of whisper for this. To see more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+
+        Returns
+        ----------
+        List[dict]
+            The input List[dict] with each "text" value replaced by its translation
+        """
+        try:
+            self.cache_parameters(model_size=model_size, src_lang=src_lang, tgt_lang=tgt_lang, max_length=max_length)
+            self.update_model(model_size=model_size, src_lang=src_lang, tgt_lang=tgt_lang, progress=progress)
+
+            total_progress = len(input_list_dict)
+            for index, dic in enumerate(input_list_dict):
+                progress(index / total_progress, desc="Translating..")
+                translated_text = self.translate(dic["text"], max_length=max_length)
+                dic["text"] = translated_text
+
+            return input_list_dict
+
+        except Exception as e:
+            print(f"Error: {str(e)}")
+        finally:
+            self.release_cuda_memory()
+
     @staticmethod
     def get_device():
         if torch.cuda.is_available():
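
For context, the new translate_text() mutates the passed-in segment dicts in place and returns the same list object. Below is a minimal, runnable sketch of that flow; EchoTranslator and its tagging behavior are illustrative stand-ins for a concrete TranslationBase subclass, not classes from this repository.

from typing import List


class EchoTranslator:
    """Stand-in that mimics the in-place loop of TranslationBase.translate_text()."""

    def translate(self, text: str, max_length: int = 200) -> str:
        # A real subclass would run a translation model here; the stub just tags the text.
        return f"[translated] {text}"

    def translate_text(self, input_list_dict: List[dict], max_length: int = 200) -> list:
        # Same shape as the committed method, minus model loading and progress reporting.
        for dic in input_list_dict:
            dic["text"] = self.translate(dic["text"], max_length=max_length)
        return input_list_dict


segments = [
    {"start": 0.0, "end": 2.5, "text": "Hello, world."},
    {"start": 2.5, "end": 5.0, "text": "How are you?"},
]

result = EchoTranslator().translate_text(segments)
assert result is segments  # the method returns the same list it mutated
print(result[0]["text"])   # -> "[translated] Hello, world."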