import os
import time

from scripts.physton_prompt.get_lang import get_lang

# Lazily-loaded mBART-50 model/tokenizer shared by initialize() and translate().
model = None
tokenizer = None
model_name = "facebook/mbart-large-50-many-to-many-mmt"
cache_dir = os.path.normpath(os.path.dirname(os.path.abspath(__file__)) + '/../../models')
loading = False

def initialize(reload=False):
    global model, tokenizer, model_name, cache_dir, loading

    if loading:
        # Another caller is already loading the model: wait until it finishes,
        # then verify that the load actually succeeded.
        while loading:
            time.sleep(0.1)
        if model is None or tokenizer is None:
            raise Exception(get_lang('model_not_initialized'))
        return

    if not reload and model is not None and tokenizer is not None:
        # Already initialized and no reload requested.
        return

    loading = True
    model = None
    tokenizer = None

    # Prefer a previously downloaded copy of the model in the local cache
    # directory; otherwise fall back to the Hugging Face model id.
    model_path = os.path.join(cache_dir, "mbart-large-50-many-to-many-mmt")
    model_file = os.path.join(model_path, "pytorch_model.bin")
    if os.path.exists(model_path) and os.path.exists(model_file):
        model_name = model_path

    try:
        from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
        print(f'[sd-webui-prompt-all-in-one] Loading model {model_name} from {cache_dir}...')
        model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        tokenizer = MBart50TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
        print(f'[sd-webui-prompt-all-in-one] Model {model_name} loaded.')
    finally:
        # Clear the flag on success and on failure so waiting callers are not
        # blocked forever; any exception still propagates to the caller.
        loading = False


def translate(text, src_lang, target_lang):
    global model, tokenizer

    # Nothing to translate: keep the return type consistent with the input.
    if not text:
        if isinstance(text, list):
            return []
        return ''

    if model is None or tokenizer is None:
        raise Exception(get_lang('model_not_initialized'))

    # src_lang / target_lang are mBART-50 language codes such as "en_XX" or "zh_CN".
    if src_lang == target_lang:
        return text

    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(text, return_tensors="pt", padding=True)
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
        max_new_tokens=500
    )
    # batch_decode returns a list of strings, even for a single input string.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
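

# Minimal usage sketch (not part of the extension's call path; assumes the
# transformers/torch dependencies are installed, a cached or downloadable
# model is available, and the module is run from the extension root so the
# scripts package import resolves). Language codes are the mBART-50 codes
# expected by the tokenizer, e.g. "en_XX", "zh_CN", "ja_XX".
if __name__ == "__main__":
    initialize()
    # translate() returns a list of decoded strings, one per input.
    print(translate("a cat sitting on a wooden bench", "en_XX", "zh_CN"))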