import os
import re
import gc
import torch
import gradio as gr
from abc import ABC, abstractmethod
from typing import List
from datetime import datetime

from modules.whisper.whisper_parameter import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR


class TranslationBase(ABC):
    def __init__(self,
                 model_dir: str = NLLB_MODELS_DIR,
                 output_dir: str = TRANSLATION_OUTPUT_DIR
                 ):
        super().__init__()
        self.model = None
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)
        self.current_model_size = None
        self.device = self.get_device()

    @abstractmethod
    def translate(self,
                  text: str,
                  max_length: int
                  ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        pass

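    # Concrete subclasses (e.g. an NLLB-backed translator) implement the two
    # abstract methods above; translate_file() and translate_text() below
    # drive them over subtitle files and in-memory segment lists.
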
    def translate_file(self,
                       fileobjs: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       max_length: int = 200,
                       add_timestamp: bool = True,
                       progress=gr.Progress()) -> list:
        """
        Translate subtitle files from source language to target language

        Parameters
        ----------
        fileobjs: list
            List of files to translate from gr.Files()
        model_size: str
            Translation model size from gr.Dropdown()
        src_lang: str
            Source language of the file to translate from gr.Dropdown()
        tgt_lang: str
            Target language of the file to translate from gr.Dropdown()
        max_length: int
            Max length per line to translate
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
            I use a forked version of whisper for this. For more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback

        Returns
        ----------
        A List of
            String to return to gr.Textbox()
            Files to return to gr.Files()
        """
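        # Illustrative call (hypothetical subclass and file; in the app this
        # method is normally wired to a gradio event):
        #   result_str, paths = translator.translate_file(
        #       ["outputs/sample.srt"], model_size="...",
        #       src_lang="English", tgt_lang="Korean")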
        try:
            if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
                fileobjs = [file.name for file in fileobjs]

            self.cache_parameters(model_size=model_size,
                                  src_lang=src_lang,
                                  tgt_lang=tgt_lang,
                                  max_length=max_length,
                                  add_timestamp=add_timestamp)

            self.update_model(model_size=model_size,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              progress=progress)

            files_info = {}
            for fileobj in fileobjs:
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
                if file_ext == ".srt":
                    parsed_dicts = parse_srt(file_path=fileobj)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating...")
                        translated_text = self.translate(dic["sentence"], max_length=max_length)
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_srt(parsed_dicts)

                elif file_ext == ".vtt":
                    parsed_dicts = parse_vtt(file_path=fileobj)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating...")
                        translated_text = self.translate(dic["sentence"], max_length=max_length)
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_vtt(parsed_dicts)

                else:
                    # Skip unsupported formats instead of failing with an
                    # undefined `subtitle` below.
                    print(f"Unsupported subtitle format: {file_ext}")
                    continue

                if add_timestamp:
                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    file_name += f"-{timestamp}"

                output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
                write_file(subtitle, output_path)

                files_info[file_name] = {"subtitle": subtitle, "path": output_path}

            total_result = ''
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
            gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"

            output_file_paths = [info["path"] for info in files_info.values()]
            return [gr_str, output_file_paths]

        except Exception as e:
            print(f"Error translating file: {e}")
            raise
        finally:
            self.release_cuda_memory()

    def translate_text(self,
                       input_list_dict: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       speaker_diarization: bool = False,
                       max_length: int = 200,
                       add_timestamp: bool = True,
                       progress=gr.Progress()) -> list:
        """
        Translate text from source language to target language

        Parameters
        ----------
        input_list_dict: list
            List[dict] of segments to translate, each with a "text" key
        model_size: str
            Translation model size from gr.Dropdown()
        src_lang: str
            Source language of the text to translate from gr.Dropdown()
        tgt_lang: str
            Target language of the text to translate from gr.Dropdown()
        speaker_diarization: bool
            Boolean value that determines whether diarization is enabled or not
        max_length: int
            Max length per line to translate
        add_timestamp: bool
            Boolean value that determines whether to add a timestamp
        progress: gr.Progress
            Indicator to show progress directly in gradio.
            I use a forked version of whisper for this. For more info: https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback

        Returns
        ----------
        List[dict] with translated "text" values
        """
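        # Illustrative input shape (hypothetical values): with diarization
        # enabled, each segment's text carries a speaker label prefix, e.g.
        #   [{"text": "SPEAKER_00: Hello there"}, {"text": "SPEAKER_01: Hi"}]
        # and only the part after the first ":" is translated.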
        try:
            if src_lang != tgt_lang:
                self.cache_parameters(model_size=model_size,
                                      src_lang=src_lang,
                                      tgt_lang=tgt_lang,
                                      max_length=max_length,
                                      add_timestamp=add_timestamp)

                self.update_model(model_size=model_size,
                                  src_lang=src_lang,
                                  tgt_lang=tgt_lang,
                                  progress=progress)

                total_progress = len(input_list_dict)
                for index, dic in enumerate(input_list_dict):
                    progress(index / total_progress, desc="Translating...")

                    text = dic["text"]
                    if speaker_diarization and ":" in text:
                        # Keep the speaker label and translate only the sentence.
                        speaker, sentence = text.split(":", 1)
                        translated_text = f"{speaker.strip()}: {self.translate(sentence.strip(), max_length=max_length)}"
                    else:
                        translated_text = self.translate(text, max_length=max_length)

                    dic["text"] = translated_text

            return input_list_dict

        except Exception as e:
            print(f"Error translating text: {e}")
            raise
        finally:
            self.release_cuda_memory()

    def offload(self):
        """Offload the model and free up the memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            self.release_cuda_memory()
        gc.collect()

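    # The del / None / empty_cache sequence above is the standard PyTorch
    # pattern for releasing a model: drop the last reference, garbage-collect
    # it, then return the freed blocks from the CUDA caching allocator.
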
    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

    @staticmethod
    def cache_parameters(model_size: str,
                         src_lang: str,
                         tgt_lang: str,
                         max_length: int,
                         add_timestamp: bool):
        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
        cached_params["translation"]["nllb"] = {
            "model_size": model_size,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "max_length": max_length,
        }
        cached_params["translation"]["add_timestamp"] = add_timestamp
        save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
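
    # For reference, the cached structure in DEFAULT_PARAMETERS_CONFIG_PATH
    # then looks roughly like this (illustrative values):
    #
    #   translation:
    #     add_timestamp: true
    #     nllb:
    #       model_size: facebook/nllb-200-distilled-600M
    #       source_lang: English
    #       target_lang: Korean
    #       max_length: 200


# Minimal usage sketch (illustrative, not part of the module): a hypothetical
# concrete subclass showing the contract TranslationBase expects. A real
# subclass (e.g. NLLB-backed) would load an actual model in update_model().
if __name__ == "__main__":
    class EchoTranslation(TranslationBase):
        def update_model(self, model_size, src_lang, tgt_lang,
                         progress=gr.Progress()):
            # A real implementation would download / load the model here
            # and assign it to self.model.
            self.current_model_size = model_size

        def translate(self, text, max_length):
            # Stand-in "translation": return the text truncated to max_length.
            return text[:max_length]

    translator = EchoTranslation()
    # Assumes the project's default config file exists for cache_parameters().
    segments = translator.translate_text([{"text": "Hello world"}],
                                         model_size="echo-model",  # hypothetical
                                         src_lang="English",
                                         tgt_lang="Korean")
    print(segments)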