LAP-DEV commited on
Commit
6e87bac
·
verified ·
1 Parent(s): 6bdfbd4

Update modules/translation/translation_base.py

Browse files
modules/translation/translation_base.py CHANGED
@@ -9,7 +9,7 @@ from modules.whisper.whisper_parameter import *
9
  from modules.utils.subtitle_manager import *
10
  from modules.utils.files_manager import load_yaml, save_yaml
11
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
12
-
13
 
14
  class TranslationBase(ABC):
15
  def __init__(self,
@@ -171,13 +171,16 @@ class TranslationBase(ABC):
171
  List[dict] with translation
172
  """
173
  try:
174
- TranslationBase.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
175
- TranslationBase.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
176
 
 
 
 
177
  total_progress = len(input_list_dict)
178
  for index, dic in enumerate(input_list_dict):
179
  progress(index / total_progress, desc="Translating..")
180
- translated_text = TranslationBase.translate(dic["text"], max_length=max_length)
181
  dic["text"] = translated_text
182
 
183
  return input_list_dict
@@ -185,7 +188,7 @@ class TranslationBase(ABC):
185
  except Exception as e:
186
  print(f"Error: {str(e)}")
187
  finally:
188
- TranslationBase.release_cuda_memory()
189
 
190
  @staticmethod
191
  def get_device():
 
9
  from modules.utils.subtitle_manager import *
10
  from modules.utils.files_manager import load_yaml, save_yaml
11
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
12
+ from modules.translation.inference import NLLBInference
13
 
14
  class TranslationBase(ABC):
15
  def __init__(self,
 
171
  List[dict] with translation
172
  """
173
  try:
174
+ self.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
175
+ #self.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
176
 
177
+ NLLBInferenceInstance = NLLBInference()
178
+ NLLBInferenceInstance.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
179
+
180
  total_progress = len(input_list_dict)
181
  for index, dic in enumerate(input_list_dict):
182
  progress(index / total_progress, desc="Translating..")
183
+ translated_text = self.translate(dic["text"], max_length=max_length)
184
  dic["text"] = translated_text
185
 
186
  return input_list_dict
 
188
  except Exception as e:
189
  print(f"Error: {str(e)}")
190
  finally:
191
+ self.release_cuda_memory()
192
 
193
  @staticmethod
194
  def get_device():