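"""Local mBART-50 translation backend for sd-webui-prompt-all-in-one.

Lazily loads facebook/mbart-large-50-many-to-many-mmt (downloading the
weights into a local models directory on first use) and exposes
translate() for text translation between mBART-50 language codes.
"""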
import os
import time
from scripts.physton_prompt.get_lang import get_lang

# Lazily initialized model/tokenizer shared across calls.
model = None
tokenizer = None
model_name = "facebook/mbart-large-50-many-to-many-mmt"
# Cache downloaded weights under ../../models relative to this file.
cache_dir = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'models'))
loading = False

def initialize(reload=False):
    global model, tokenizer, model_name, cache_dir, loading
    if loading:
        # Another caller is already loading the model; wait until it finishes.
        while loading:
            time.sleep(0.1)
        if model is None or tokenizer is None:
            raise Exception(get_lang('model_not_initialized'))
        return
    if not reload and model is not None:
        return
    loading = True
    model = None
    tokenizer = None

    # Prefer a previously downloaded local copy of the weights if it exists.
    model_path = os.path.join(cache_dir, "mbart-large-50-many-to-many-mmt")
    model_file = os.path.join(model_path, "pytorch_model.bin")
    if os.path.exists(model_path) and os.path.exists(model_file):
        model_name = model_path

    try:
        # Import lazily so transformers is only required when this backend is used.
        from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
        print(f'[sd-webui-prompt-all-in-one] Loading model {model_name} from {cache_dir}...')
        model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
        tokenizer = MBart50TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
        print(f'[sd-webui-prompt-all-in-one] Model {model_name} loaded.')
    finally:
        loading = False

def translate(text, src_lang, target_lang):
    global model, tokenizer

    # Preserve the input's shape on empty input: [] for a list, '' for a string.
    if not text:
        return [] if isinstance(text, list) else ''

    if model is None or tokenizer is None:
        raise Exception(get_lang('model_not_initialized'))

    if src_lang == target_lang:
        return text

    # mBART-50 takes the source language on the tokenizer and the target
    # language as the forced first generated token.
    tokenizer.src_lang = src_lang
    encoded_input = tokenizer(text, return_tensors="pt", padding=True)
    generated_tokens = model.generate(
        **encoded_input,
        forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
        max_new_tokens=500
    )
    # Note: batch_decode returns a list even when the input was a single string.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
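

# A minimal usage sketch (an illustration, not part of the module's public
# flow): language codes are mBART-50 codes such as "en_XX" and "zh_CN", and
# this assumes the weights can be downloaded or are already cached under
# cache_dir.
if __name__ == '__main__':
    initialize()
    print(translate(["a cat sitting on a windowsill"], "en_XX", "zh_CN"))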