Update modules/translation/translation_base.py
Browse files
modules/translation/translation_base.py
CHANGED
@@ -9,7 +9,7 @@ from modules.whisper.whisper_parameter import *
|
|
9 |
from modules.utils.subtitle_manager import *
|
10 |
from modules.utils.files_manager import load_yaml, save_yaml
|
11 |
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
|
12 |
-
|
13 |
|
14 |
class TranslationBase(ABC):
|
15 |
def __init__(self,
|
@@ -171,13 +171,16 @@ class TranslationBase(ABC):
|
|
171 |
List[dict] with translation
|
172 |
"""
|
173 |
try:
|
174 |
-
|
175 |
-
|
176 |
|
|
|
|
|
|
|
177 |
total_progress = len(input_list_dict)
|
178 |
for index, dic in enumerate(input_list_dict):
|
179 |
progress(index / total_progress, desc="Translating..")
|
180 |
-
translated_text =
|
181 |
dic["text"] = translated_text
|
182 |
|
183 |
return input_list_dict
|
@@ -185,7 +188,7 @@ class TranslationBase(ABC):
|
|
185 |
except Exception as e:
|
186 |
print(f"Error: {str(e)}")
|
187 |
finally:
|
188 |
-
|
189 |
|
190 |
@staticmethod
|
191 |
def get_device():
|
|
|
9 |
from modules.utils.subtitle_manager import *
|
10 |
from modules.utils.files_manager import load_yaml, save_yaml
|
11 |
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
|
12 |
+
from modules.translation.inference import NLLBInference
|
13 |
|
14 |
class TranslationBase(ABC):
|
15 |
def __init__(self,
|
|
|
171 |
List[dict] with translation
|
172 |
"""
|
173 |
try:
|
174 |
+
self.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
|
175 |
+
#self.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
|
176 |
|
177 |
+
NLLBInferenceInstance = NLLBInference()
|
178 |
+
NLLBInferenceInstance.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
|
179 |
+
|
180 |
total_progress = len(input_list_dict)
|
181 |
for index, dic in enumerate(input_list_dict):
|
182 |
progress(index / total_progress, desc="Translating..")
|
183 |
+
translated_text = self.translate(dic["text"], max_length=max_length)
|
184 |
dic["text"] = translated_text
|
185 |
|
186 |
return input_list_dict
|
|
|
188 |
except Exception as e:
|
189 |
print(f"Error: {str(e)}")
|
190 |
finally:
|
191 |
+
self.release_cuda_memory()
|
192 |
|
193 |
@staticmethod
|
194 |
def get_device():
|