Upload 4 files
modules/translation/deepl_api.py
CHANGED
@@ -5,6 +5,7 @@ from datetime import datetime
 import gradio as gr
 
 from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
+from modules.utils.constants import AUTOMATIC_DETECTION
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 
@@ -50,7 +51,7 @@ DEEPL_AVAILABLE_TARGET_LANGS = {
 }
 
 DEEPL_AVAILABLE_SOURCE_LANGS = {
-    …
+    AUTOMATIC_DETECTION: None,
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
@@ -138,37 +139,27 @@ class DeepLAPI:
         )
 
         files_info = {}
-        for …
-            …
-            if file_ext == ".srt":
-                parsed_dicts = parse_srt(file_path=file_path)
-            elif file_ext == ".vtt":
-                parsed_dicts = parse_vtt(file_path=file_path)
+        for file_path in fileobjs:
+            file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+            writer = get_writer(file_ext, self.output_dir)
+            segments = writer.to_segments(file_path)
 
             batch_size = self.max_text_batch_size
-            for batch_start in range(0, len(…
-                sentences_to_translate = […
+            for batch_start in range(0, len(segments), batch_size):
+                progress(batch_start / len(segments), desc="Translating..")
+                sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
                 translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                 target_lang, is_pro)
                 for i, translated_text in enumerate(translated_texts):
-                    …
-            timestamp = datetime.now().strftime("%m%d%H%M%S")
-            file_name += f"-{timestamp}"
-            output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-            write_file(subtitle, output_path)
+                    segments[batch_start + i].text = translated_text["text"]
+
+            subtitle, output_path = generate_file(
+                output_dir=self.output_dir,
+                output_file_name=file_name,
+                output_format=file_ext,
+                result=segments,
+                add_timestamp=add_timestamp
+            )
 
             files_info[file_name] = {"subtitle": subtitle, "path": output_path}
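The reworked DeepL path drops the per-format parse_srt/parse_vtt branches: every input file is converted to a list of segments, translated in batches of max_text_batch_size, and the results are written back by index before generate_file() produces the output subtitle. A minimal, self-contained sketch of that batch-and-writeback pattern; the Segment dataclass and translate_batch stub below are illustrative stand-ins, not the project's own types:

from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    start: float
    end: float
    text: str


def translate_batch(texts: List[str]) -> List[str]:
    # Stand-in for DeepLAPI.request_deepl_translate(); the real call returns one
    # dict per entry shaped like {"text": "..."}, hence translated_text["text"] in the diff.
    return [f"<translated> {t}" for t in texts]


def translate_segments(segments: List[Segment], batch_size: int) -> List[Segment]:
    # Translate in fixed-size batches and write results back by absolute index,
    # so timestamps stay attached to their original segments.
    for batch_start in range(0, len(segments), batch_size):
        batch = segments[batch_start:batch_start + batch_size]
        translated = translate_batch([seg.text for seg in batch])
        for i, text in enumerate(translated):
            segments[batch_start + i].text = text
    return segments


if __name__ == "__main__":
    segs = [Segment(0.0, 1.5, "Hello"), Segment(1.5, 3.0, "World")]
    print(translate_segments(segs, batch_size=2))

Writing back by absolute index (batch_start + i) means only the text changes; start and end times pass through untouched to the writer.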
modules/translation/nllb_inference.py
CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
 import os
 
 from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
-…
+import modules.translation.translation_base as base
 
 
-class NLLBInference(TranslationBase):
+class NLLBInference(base.TranslationBase):
     def __init__(self,
                  model_dir: str = NLLB_MODELS_DIR,
                  output_dir: str = TRANSLATION_OUTPUT_DIR
@@ -29,7 +29,7 @@ class NLLBInference(TranslationBase):
             text,
             max_length=max_length
         )
-        return result[0][…
+        return result[0]["translation_text"]
 
     def update_model(self,
                      model_size: str,
@@ -41,8 +41,7 @@ class NLLBInference(TranslationBase):
             if lang in NLLB_AVAILABLE_LANGS:
                 return NLLB_AVAILABLE_LANGS[lang]
             elif lang not in NLLB_AVAILABLE_LANGS.values():
-                raise ValueError(
-                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+                raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
             return lang
 
         src_lang = validate_language(src_lang)
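For context on the result[0]["translation_text"] change: a Hugging Face transformers translation pipeline, which NLLBInference appears to wrap, returns a list of dicts keyed by "translation_text". A small self-contained sketch; the checkpoint name and NLLB language codes are illustrative and not taken from this diff:

from transformers import pipeline

# Illustrative checkpoint and language codes; NLLBInference builds its own
# pipeline from the selected model size and the NLLB_AVAILABLE_LANGS mapping.
translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    src_lang="eng_Latn",
    tgt_lang="fra_Latn",
)

result = translator("Subtitles are translated line by line.", max_length=200)
# The pipeline returns a list of dicts, e.g. [{"translation_text": "..."}],
# which is why the method now returns result[0]["translation_text"].
print(result[0]["translation_text"])

Switching to "import modules.translation.translation_base as base" (instead of importing the class directly) is the usual way to tolerate the circular dependency introduced by translation_base.py now importing nllb_inference, since attribute lookup is deferred until the class is actually used.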
modules/translation/translation_base.py
CHANGED
@@ -2,14 +2,17 @@ import os
 import torch
 import gradio as gr
 from abc import ABC, abstractmethod
+import gc
 from typing import List
 from datetime import datetime
 
-…
+import modules.translation.nllb_inference as nllb
+from modules.whisper.data_classes import *
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
 
+
 class TranslationBase(ABC):
     def __init__(self,
                  model_dir: str = NLLB_MODELS_DIR,
@@ -93,32 +96,22 @@ class TranslationBase(ABC):
             files_info = {}
             for fileobj in fileobjs:
                 file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-                …
-                subtitle = get_serialized_vtt(parsed_dicts)
-                if add_timestamp:
-                    timestamp = datetime.now().strftime("%m%d%H%M%S")
-                    file_name += f"-{timestamp}"
-                output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-                write_file(subtitle, output_path)
-                files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+                writer = get_writer(file_ext, self.output_dir)
+                segments = writer.to_segments(fileobj)
+                for i, segment in enumerate(segments):
+                    progress(i / len(segments), desc="Translating..")
+                    translated_text = self.translate(segment.text, max_length=max_length)
+                    segment.text = translated_text
+
+                subtitle, file_path = generate_file(
+                    output_dir=self.output_dir,
+                    output_file_name=file_name,
+                    output_format=file_ext,
+                    result=segments,
+                    add_timestamp=add_timestamp
+                )
+
+                files_info[file_name] = {"subtitle": subtitle, "path": file_path}
 
             total_result = ''
             for file_name, info in files_info.items():
@@ -131,61 +124,20 @@ class TranslationBase(ABC):
             return [gr_str, output_file_paths]
 
         except Exception as e:
-            print(f"Error: {…
+            print(f"Error translating file: {e}")
+            raise
         finally:
             self.release_cuda_memory()
 
-    def …
-            …
-            add_timestamp: bool = True,
-            progress=gr.Progress()) -> list:
-        """
-        Translate text from source language to target language
-
-        Parameters
-        ----------
-        str_text: str
-            List[dict] to translate
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        src_lang: str
-            Source language of the file to translate from gr.Dropdown()
-        tgt_lang: str
-            Target language of the file to translate from gr.Dropdown()
-        max_length: int
-            Max length per line to translate
-        add_timestamp: bool
-            Boolean value that determines whether to add a timestamp
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
-
-        Returns
-        ----------
-        A List of …
-            List[dict] with translation
-        """
-        try:
-            self.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
-            self.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
-
-            total_progress = len(input_list_dict)
-            for index, dic in enumerate(input_list_dict):
-                progress(index / total_progress, desc="Translating..")
-                translated_text = self.translate(dic["text"], max_length=max_length)
-                dic["text"] = translated_text
-
-            return input_list_dict
-
-        except Exception as e:
-            print(f"Error: {str(e)}")
-        finally:
+    def offload(self):
+        """Offload the model and free up the memory"""
+        if self.model is not None:
+            del self.model
+            self.model = None
+        if self.device == "cuda":
             self.release_cuda_memory()
-
+        gc.collect()
+
     @staticmethod
     def get_device():
         if torch.cuda.is_available():
@@ -216,11 +168,17 @@ class TranslationBase(ABC):
                          tgt_lang: str,
                          max_length: int,
                          add_timestamp: bool):
+        def validate_lang(lang: str):
+            if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
+                flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
+                return flipped[lang]
+            return lang
+
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
         cached_params["translation"]["nllb"] = {
             "model_size": model_size,
-            "source_lang": src_lang,
-            "target_lang": tgt_lang,
+            "source_lang": validate_lang(src_lang),
+            "target_lang": validate_lang(tgt_lang),
             "max_length": max_length,
         }
         cached_params["translation"]["add_timestamp"] = add_timestamp
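The new offload() follows the usual PyTorch teardown sequence: drop the Python reference to the model, empty the CUDA cache, then force a garbage-collection pass. A minimal sketch of the same pattern outside the class, assuming release_cuda_memory() amounts to torch.cuda.empty_cache() (the diff does not show its body):

import gc

import torch


class ModelHolder:
    def __init__(self, model=None):
        self.model = model
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def offload(self):
        """Drop the model reference and free as much memory as possible."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            torch.cuda.empty_cache()  # assumed equivalent of release_cuda_memory()
        gc.collect()


holder = ModelHolder(model=torch.nn.Linear(4, 4))
holder.offload()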
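validate_lang keeps the cached YAML human-readable: if the UI hands over an NLLB code it is flipped back to its display name via an inverted NLLB_AVAILABLE_LANGS, while display names pass through unchanged. A self-contained sketch with a trimmed-down, illustrative mapping:

# Trimmed-down, illustrative stand-in for nllb_inference.NLLB_AVAILABLE_LANGS.
NLLB_AVAILABLE_LANGS = {
    "English": "eng_Latn",
    "Korean": "kor_Hang",
}


def validate_lang(lang: str) -> str:
    """Map an NLLB code back to its display name; pass display names through."""
    if lang in NLLB_AVAILABLE_LANGS.values():
        flipped = {code: name for name, code in NLLB_AVAILABLE_LANGS.items()}
        return flipped[lang]
    return lang


assert validate_lang("kor_Hang") == "Korean"
assert validate_lang("Korean") == "Korean"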