LAP-DEV committed on
Commit
3eb2f5a
·
verified ·
1 Parent(s): fa14ba9

Upload 4 files

Browse files
modules/translation/deepl_api.py CHANGED
@@ -5,6 +5,7 @@ from datetime import datetime
5
  import gradio as gr
6
 
7
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
 
8
  from modules.utils.subtitle_manager import *
9
  from modules.utils.files_manager import load_yaml, save_yaml
10
 
@@ -50,7 +51,7 @@ DEEPL_AVAILABLE_TARGET_LANGS = {
50
  }
51
 
52
  DEEPL_AVAILABLE_SOURCE_LANGS = {
53
- 'Automatic Detection': None,
54
  'Bulgarian': 'BG',
55
  'Czech': 'CS',
56
  'Danish': 'DA',
@@ -138,37 +139,27 @@ class DeepLAPI:
138
  )
139
 
140
  files_info = {}
141
- for fileobj in fileobjs:
142
- file_path = fileobj
143
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
144
-
145
- if file_ext == ".srt":
146
- parsed_dicts = parse_srt(file_path=file_path)
147
-
148
- elif file_ext == ".vtt":
149
- parsed_dicts = parse_vtt(file_path=file_path)
150
 
151
  batch_size = self.max_text_batch_size
152
- for batch_start in range(0, len(parsed_dicts), batch_size):
153
- batch_end = min(batch_start + batch_size, len(parsed_dicts))
154
- sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
155
  translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
156
  target_lang, is_pro)
157
  for i, translated_text in enumerate(translated_texts):
158
- parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
159
- progress(batch_end / len(parsed_dicts), desc="Translating..")
160
-
161
- if file_ext == ".srt":
162
- subtitle = get_serialized_srt(parsed_dicts)
163
- elif file_ext == ".vtt":
164
- subtitle = get_serialized_vtt(parsed_dicts)
165
-
166
- if add_timestamp:
167
- timestamp = datetime.now().strftime("%m%d%H%M%S")
168
- file_name += f"-{timestamp}"
169
-
170
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
171
- write_file(subtitle, output_path)
172
 
173
  files_info[file_name] = {"subtitle": subtitle, "path": output_path}
174
 
 
5
  import gradio as gr
6
 
7
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
8
+ from modules.utils.constants import AUTOMATIC_DETECTION
9
  from modules.utils.subtitle_manager import *
10
  from modules.utils.files_manager import load_yaml, save_yaml
11
 
 
51
  }
52
 
53
  DEEPL_AVAILABLE_SOURCE_LANGS = {
54
+ AUTOMATIC_DETECTION: None,
55
  'Bulgarian': 'BG',
56
  'Czech': 'CS',
57
  'Danish': 'DA',
 
139
  )
140
 
141
  files_info = {}
142
+ for file_path in fileobjs:
143
+ file_name, file_ext = os.path.splitext(os.path.basename(file_path))
144
+ writer = get_writer(file_ext, self.output_dir)
145
+ segments = writer.to_segments(file_path)
 
 
 
 
 
146
 
147
  batch_size = self.max_text_batch_size
148
+ for batch_start in range(0, len(segments), batch_size):
149
+ progress(batch_start / len(segments), desc="Translating..")
150
+ sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
151
  translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
152
  target_lang, is_pro)
153
  for i, translated_text in enumerate(translated_texts):
154
+ segments[batch_start + i].text = translated_text["text"]
155
+
156
+ subtitle, output_path = generate_file(
157
+ output_dir=self.output_dir,
158
+ output_file_name=file_name,
159
+ output_format=file_ext,
160
+ result=segments,
161
+ add_timestamp=add_timestamp
162
+ )
 
 
 
 
 
163
 
164
  files_info[file_name] = {"subtitle": subtitle, "path": output_path}
165
 
modules/translation/nllb_inference.py CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
3
  import os
4
 
5
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6
- from modules.translation.translation_base import TranslationBase
7
 
8
 
9
- class NLLBInference(TranslationBase):
10
  def __init__(self,
11
  model_dir: str = NLLB_MODELS_DIR,
12
  output_dir: str = TRANSLATION_OUTPUT_DIR
@@ -29,7 +29,7 @@ class NLLBInference(TranslationBase):
29
  text,
30
  max_length=max_length
31
  )
32
- return result[0]['translation_text']
33
 
34
  def update_model(self,
35
  model_size: str,
@@ -41,8 +41,7 @@ class NLLBInference(TranslationBase):
41
  if lang in NLLB_AVAILABLE_LANGS:
42
  return NLLB_AVAILABLE_LANGS[lang]
43
  elif lang not in NLLB_AVAILABLE_LANGS.values():
44
- raise ValueError(
45
- f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
46
  return lang
47
 
48
  src_lang = validate_language(src_lang)
 
3
  import os
4
 
5
  from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
6
+ import modules.translation.translation_base as base
7
 
8
 
9
+ class NLLBInference(base.TranslationBase):
10
  def __init__(self,
11
  model_dir: str = NLLB_MODELS_DIR,
12
  output_dir: str = TRANSLATION_OUTPUT_DIR
 
29
  text,
30
  max_length=max_length
31
  )
32
+ return result[0]["translation_text"]
33
 
34
  def update_model(self,
35
  model_size: str,
 
41
  if lang in NLLB_AVAILABLE_LANGS:
42
  return NLLB_AVAILABLE_LANGS[lang]
43
  elif lang not in NLLB_AVAILABLE_LANGS.values():
44
+ raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
 
45
  return lang
46
 
47
  src_lang = validate_language(src_lang)
modules/translation/translation_base.py CHANGED
@@ -2,14 +2,17 @@ import os
2
  import torch
3
  import gradio as gr
4
  from abc import ABC, abstractmethod
 
5
  from typing import List
6
  from datetime import datetime
7
 
8
- from modules.whisper.whisper_parameter import *
 
9
  from modules.utils.subtitle_manager import *
10
  from modules.utils.files_manager import load_yaml, save_yaml
11
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
12
 
 
13
  class TranslationBase(ABC):
14
  def __init__(self,
15
  model_dir: str = NLLB_MODELS_DIR,
@@ -93,32 +96,22 @@ class TranslationBase(ABC):
93
  files_info = {}
94
  for fileobj in fileobjs:
95
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
96
- if file_ext == ".srt":
97
- parsed_dicts = parse_srt(file_path=fileobj)
98
- total_progress = len(parsed_dicts)
99
- for index, dic in enumerate(parsed_dicts):
100
- progress(index / total_progress, desc="Translating..")
101
- translated_text = self.translate(dic["sentence"], max_length=max_length)
102
- dic["sentence"] = translated_text
103
- subtitle = get_serialized_srt(parsed_dicts)
104
-
105
- elif file_ext == ".vtt":
106
- parsed_dicts = parse_vtt(file_path=fileobj)
107
- total_progress = len(parsed_dicts)
108
- for index, dic in enumerate(parsed_dicts):
109
- progress(index / total_progress, desc="Translating..")
110
- translated_text = self.translate(dic["sentence"], max_length=max_length)
111
- dic["sentence"] = translated_text
112
- subtitle = get_serialized_vtt(parsed_dicts)
113
-
114
- if add_timestamp:
115
- timestamp = datetime.now().strftime("%m%d%H%M%S")
116
- file_name += f"-{timestamp}"
117
-
118
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
119
- write_file(subtitle, output_path)
120
-
121
- files_info[file_name] = {"subtitle": subtitle, "path": output_path}
122
 
123
  total_result = ''
124
  for file_name, info in files_info.items():
@@ -131,61 +124,20 @@ class TranslationBase(ABC):
131
  return [gr_str, output_file_paths]
132
 
133
  except Exception as e:
134
- print(f"Error: {str(e)}")
 
135
  finally:
136
  self.release_cuda_memory()
137
 
138
- def translate_text(self,
139
- input_list_dict: list,
140
- model_size: str,
141
- src_lang: str,
142
- tgt_lang: str,
143
- max_length: int = 200,
144
- add_timestamp: bool = True,
145
- progress=gr.Progress()) -> list:
146
- """
147
- Translate text from source language to target language
148
-
149
- Parameters
150
- ----------
151
- str_text: str
152
- List[dict] to translate
153
- model_size: str
154
- Whisper model size from gr.Dropdown()
155
- src_lang: str
156
- Source language of the file to translate from gr.Dropdown()
157
- tgt_lang: str
158
- Target language of the file to translate from gr.Dropdown()
159
- max_length: int
160
- Max length per line to translate
161
- add_timestamp: bool
162
- Boolean value that determines whether to add a timestamp
163
- progress: gr.Progress
164
- Indicator to show progress directly in gradio.
165
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
166
-
167
- Returns
168
- ----------
169
- A List of
170
- List[dict] with translation
171
- """
172
- try:
173
- self.cache_parameters(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,max_length=max_length,add_timestamp=add_timestamp)
174
- self.update_model(model_size=model_size,src_lang=src_lang,tgt_lang=tgt_lang,progress=progress)
175
-
176
- total_progress = len(input_list_dict)
177
- for index, dic in enumerate(input_list_dict):
178
- progress(index / total_progress, desc="Translating..")
179
- translated_text = self.translate(dic["text"], max_length=max_length)
180
- dic["text"] = translated_text
181
-
182
- return input_list_dict
183
-
184
- except Exception as e:
185
- print(f"Error: {str(e)}")
186
- finally:
187
  self.release_cuda_memory()
188
-
 
189
  @staticmethod
190
  def get_device():
191
  if torch.cuda.is_available():
@@ -216,11 +168,17 @@ class TranslationBase(ABC):
216
  tgt_lang: str,
217
  max_length: int,
218
  add_timestamp: bool):
 
 
 
 
 
 
219
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
220
  cached_params["translation"]["nllb"] = {
221
  "model_size": model_size,
222
- "source_lang": src_lang,
223
- "target_lang": tgt_lang,
224
  "max_length": max_length,
225
  }
226
  cached_params["translation"]["add_timestamp"] = add_timestamp
 
2
  import torch
3
  import gradio as gr
4
  from abc import ABC, abstractmethod
5
+ import gc
6
  from typing import List
7
  from datetime import datetime
8
 
9
+ import modules.translation.nllb_inference as nllb
10
+ from modules.whisper.data_classes import *
11
  from modules.utils.subtitle_manager import *
12
  from modules.utils.files_manager import load_yaml, save_yaml
13
  from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
14
 
15
+
16
  class TranslationBase(ABC):
17
  def __init__(self,
18
  model_dir: str = NLLB_MODELS_DIR,
 
96
  files_info = {}
97
  for fileobj in fileobjs:
98
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
99
+ writer = get_writer(file_ext, self.output_dir)
100
+ segments = writer.to_segments(fileobj)
101
+ for i, segment in enumerate(segments):
102
+ progress(i / len(segments), desc="Translating..")
103
+ translated_text = self.translate(segment.text, max_length=max_length)
104
+ segment.text = translated_text
105
+
106
+ subtitle, file_path = generate_file(
107
+ output_dir=self.output_dir,
108
+ output_file_name=file_name,
109
+ output_format=file_ext,
110
+ result=segments,
111
+ add_timestamp=add_timestamp
112
+ )
113
+
114
+ files_info[file_name] = {"subtitle": subtitle, "path": file_path}
 
 
 
 
 
 
 
 
 
 
115
 
116
  total_result = ''
117
  for file_name, info in files_info.items():
 
124
  return [gr_str, output_file_paths]
125
 
126
  except Exception as e:
127
+ print(f"Error translating file: {e}")
128
+ raise
129
  finally:
130
  self.release_cuda_memory()
131
 
132
+ def offload(self):
133
+ """Offload the model and free up the memory"""
134
+ if self.model is not None:
135
+ del self.model
136
+ self.model = None
137
+ if self.device == "cuda":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  self.release_cuda_memory()
139
+ gc.collect()
140
+
141
  @staticmethod
142
  def get_device():
143
  if torch.cuda.is_available():
 
168
  tgt_lang: str,
169
  max_length: int,
170
  add_timestamp: bool):
171
+ def validate_lang(lang: str):
172
+ if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
173
+ flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
174
+ return flipped[lang]
175
+ return lang
176
+
177
  cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
178
  cached_params["translation"]["nllb"] = {
179
  "model_size": model_size,
180
+ "source_lang": validate_lang(src_lang),
181
+ "target_lang": validate_lang(tgt_lang),
182
  "max_length": max_length,
183
  }
184
  cached_params["translation"]["add_timestamp"] = add_timestamp