LAP-DEV committed on
Commit fe1e730 · verified · 1 Parent(s): 9bd1fc9

Upload 4 files

modules/translation/deepl_api.py CHANGED
@@ -5,7 +5,6 @@ from datetime import datetime
 import gradio as gr
 
 from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
-from modules.utils.constants import AUTOMATIC_DETECTION
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 
@@ -51,7 +50,7 @@ DEEPL_AVAILABLE_TARGET_LANGS = {
 }
 
 DEEPL_AVAILABLE_SOURCE_LANGS = {
-    AUTOMATIC_DETECTION: None,
+    'Automatic Detection': None,
     'Bulgarian': 'BG',
     'Czech': 'CS',
     'Danish': 'DA',
@@ -139,27 +138,37 @@ class DeepLAPI:
         )
 
         files_info = {}
-        for file_path in fileobjs:
-            file_name, file_ext = os.path.splitext(os.path.basename(file_path))
-            writer = get_writer(file_ext, self.output_dir)
-            segments = writer.to_segments(file_path)
+        for fileobj in fileobjs:
+            file_path = fileobj
+            file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
+
+            if file_ext == ".srt":
+                parsed_dicts = parse_srt(file_path=file_path)
+
+            elif file_ext == ".vtt":
+                parsed_dicts = parse_vtt(file_path=file_path)
 
             batch_size = self.max_text_batch_size
-            for batch_start in range(0, len(segments), batch_size):
-                progress(batch_start / len(segments), desc="Translating..")
-                sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
+            for batch_start in range(0, len(parsed_dicts), batch_size):
+                batch_end = min(batch_start + batch_size, len(parsed_dicts))
+                sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                 translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                 target_lang, is_pro)
                 for i, translated_text in enumerate(translated_texts):
-                    segments[batch_start + i].text = translated_text["text"]
-
-            subtitle, output_path = generate_file(
-                output_dir=self.output_dir,
-                output_file_name=file_name,
-                output_format=file_ext,
-                result=segments,
-                add_timestamp=add_timestamp
-            )
+                    parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
+                progress(batch_end / len(parsed_dicts), desc="Translating..")
+
+            if file_ext == ".srt":
+                subtitle = get_serialized_srt(parsed_dicts)
+            elif file_ext == ".vtt":
+                subtitle = get_serialized_vtt(parsed_dicts)
+
+            if add_timestamp:
+                timestamp = datetime.now().strftime("%m%d%H%M%S")
+                file_name += f"-{timestamp}"
+
+            output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
+            write_file(subtitle, output_path)
 
             files_info[file_name] = {"subtitle": subtitle, "path": output_path}
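
For reference, request_deepl_translate itself is not touched by this commit. Below is a minimal sketch of such a helper against DeepL's documented v2 REST endpoint; the endpoint URLs and the {"translations": [{"text": ...}]} response shape come from DeepL's public API, the function name and signature are taken from the call site above, and everything else is an assumption:

import requests

def request_deepl_translate(auth_key, sentences, source_lang, target_lang, is_pro):
    # Sketch only: the project's real implementation lives elsewhere in deepl_api.py.
    url = "https://api.deepl.com/v2/translate" if is_pro else "https://api-free.deepl.com/v2/translate"
    data = {"auth_key": auth_key, "text": sentences, "target_lang": target_lang}
    if source_lang is not None:  # None is mapped from 'Automatic Detection' above
        data["source_lang"] = source_lang
    response = requests.post(url, data=data)  # requests repeats text= for each list item
    response.raise_for_status()
    # Each entry carries a "text" key, matching translated_text["text"] in the batch loop
    return response.json()["translations"]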
 
modules/translation/nllb_inference.py CHANGED
@@ -3,10 +3,10 @@ import gradio as gr
 import os
 
 from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
-import modules.translation.translation_base as base
+from modules.translation.translation_base import TranslationBase
 
 
-class NLLBInference(base.TranslationBase):
+class NLLBInference(TranslationBase):
     def __init__(self,
                  model_dir: str = NLLB_MODELS_DIR,
                  output_dir: str = TRANSLATION_OUTPUT_DIR
@@ -29,7 +29,7 @@ class NLLBInference(base.TranslationBase):
                                text,
                                max_length=max_length
                                )
-        return result[0]["translation_text"]
+        return result[0]['translation_text']
 
     def update_model(self,
                      model_size: str,
@@ -41,7 +41,8 @@ class NLLBInference(base.TranslationBase):
             if lang in NLLB_AVAILABLE_LANGS:
                 return NLLB_AVAILABLE_LANGS[lang]
             elif lang not in NLLB_AVAILABLE_LANGS.values():
-                raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+                raise ValueError(
+                    f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
             return lang
 
         src_lang = validate_language(src_lang)
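
The result[0]['translation_text'] access mirrors how the Hugging Face transformers translation pipeline returns output: a list with one dict per input. A standalone sketch follows; the model checkpoint and FLORES-200 language codes are illustrative, not taken from this diff:

from transformers import pipeline

# NLLB checkpoints expect FLORES-200 codes such as "eng_Latn" or "fra_Latn"
translator = pipeline("translation",
                      model="facebook/nllb-200-distilled-600M",
                      src_lang="eng_Latn",
                      tgt_lang="fra_Latn")

result = translator("Subtitles are translated sentence by sentence.", max_length=200)
print(result[0]["translation_text"])  # list of dicts, hence the [0] index
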
modules/translation/translation_base.py CHANGED
@@ -2,11 +2,10 @@ import os
 import torch
 import gradio as gr
 from abc import ABC, abstractmethod
-import gc
 from typing import List
 from datetime import datetime
 
-import modules.translation.nllb_inference as nllb
+from modules.whisper.whisper_parameter import *
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
@@ -95,22 +94,32 @@ class TranslationBase(ABC):
             files_info = {}
             for fileobj in fileobjs:
                 file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-                writer = get_writer(file_ext, self.output_dir)
-                segments = writer.to_segments(fileobj)
-                for i, segment in enumerate(segments):
-                    progress(i / len(segments), desc="Translating..")
-                    translated_text = self.translate(segment.text, max_length=max_length)
-                    segment.text = translated_text
-
-                subtitle, file_path = generate_file(
-                    output_dir=self.output_dir,
-                    output_file_name=file_name,
-                    output_format=file_ext,
-                    result=segments,
-                    add_timestamp=add_timestamp
-                )
-
-                files_info[file_name] = {"subtitle": subtitle, "path": file_path}
+                if file_ext == ".srt":
+                    parsed_dicts = parse_srt(file_path=fileobj)
+                    total_progress = len(parsed_dicts)
+                    for index, dic in enumerate(parsed_dicts):
+                        progress(index / total_progress, desc="Translating..")
+                        translated_text = self.translate(dic["sentence"], max_length=max_length)
+                        dic["sentence"] = translated_text
+                    subtitle = get_serialized_srt(parsed_dicts)
+
+                elif file_ext == ".vtt":
+                    parsed_dicts = parse_vtt(file_path=fileobj)
+                    total_progress = len(parsed_dicts)
+                    for index, dic in enumerate(parsed_dicts):
+                        progress(index / total_progress, desc="Translating..")
+                        translated_text = self.translate(dic["sentence"], max_length=max_length)
+                        dic["sentence"] = translated_text
+                    subtitle = get_serialized_vtt(parsed_dicts)
+
+                if add_timestamp:
+                    timestamp = datetime.now().strftime("%m%d%H%M%S")
+                    file_name += f"-{timestamp}"
+
+                output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
+                write_file(subtitle, output_path)
+
+                files_info[file_name] = {"subtitle": subtitle, "path": output_path}
 
             total_result = ''
             for file_name, info in files_info.items():
@@ -123,20 +132,10 @@ class TranslationBase(ABC):
             return [gr_str, output_file_paths]
 
         except Exception as e:
-            print(f"Error translating file: {e}")
-            raise
+            print(f"Error: {str(e)}")
         finally:
             self.release_cuda_memory()
 
-    def offload(self):
-        """Offload the model and free up the memory"""
-        if self.model is not None:
-            del self.model
-            self.model = None
-        if self.device == "cuda":
-            self.release_cuda_memory()
-            gc.collect()
-
     @staticmethod
     def get_device():
         if torch.cuda.is_available():
@@ -167,17 +166,11 @@ class TranslationBase(ABC):
                          tgt_lang: str,
                          max_length: int,
                          add_timestamp: bool):
-        def validate_lang(lang: str):
-            if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
-                flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
-                return flipped[lang]
-            return lang
-
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
         cached_params["translation"]["nllb"] = {
             "model_size": model_size,
-            "source_lang": validate_lang(src_lang),
-            "target_lang": validate_lang(tgt_lang),
+            "source_lang": src_lang,
+            "target_lang": tgt_lang,
             "max_length": max_length,
         }
         cached_params["translation"]["add_timestamp"] = add_timestamp
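
Both translators lean on parse_srt/parse_vtt and get_serialized_srt/get_serialized_vtt from modules.utils.subtitle_manager, which this commit does not include. A rough sketch of the round trip the loops above assume; this is a hypothetical implementation, and only the "sentence" key and the parse/serialize pairing are implied by the diff:

import re

def parse_srt(file_path):
    # Hypothetical: split an .srt file into blocks, keeping index and timestamp
    # intact so the translation loops only rewrite "sentence".
    with open(file_path, "r", encoding="utf-8") as f:
        blocks = re.split(r"\n\s*\n", f.read().strip())
    parsed_dicts = []
    for block in blocks:
        lines = block.splitlines()
        parsed_dicts.append({"index": lines[0],
                             "timestamp": lines[1],
                             "sentence": "\n".join(lines[2:])})
    return parsed_dicts

def get_serialized_srt(parsed_dicts):
    # Hypothetical inverse of parse_srt: rebuild the .srt text after translation.
    return "\n\n".join(f'{d["index"]}\n{d["timestamp"]}\n{d["sentence"]}'
                       for d in parsed_dicts)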