Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,9 @@ options.extend(list(all_langs.keys()))
|
|
13 |
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large",
|
14 |
"facebook/nllb-200-distilled-600M",
|
15 |
"facebook/nllb-200-distilled-1.3B",
|
16 |
-
"facebook/mbart-large-50-many-to-many-mmt"
|
|
|
|
|
17 |
|
18 |
def model_to_cuda(model):
|
19 |
# Move the model to GPU if available
|
@@ -62,6 +64,15 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
|
|
62 |
)
|
63 |
return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
if model_name.startswith('t5'):
|
66 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
67 |
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
|
|
|
13 |
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large",
|
14 |
"facebook/nllb-200-distilled-600M",
|
15 |
"facebook/nllb-200-distilled-1.3B",
|
16 |
+
"facebook/mbart-large-50-many-to-many-mmt",
|
17 |
+
"Unbabel/TowerInstruct-7B-v0.2",
|
18 |
+
"Unbabel/TowerInstruct-Mistral-7B-v0.2"]
|
19 |
|
20 |
def model_to_cuda(model):
|
21 |
# Move the model to GPU if available
|
|
|
64 |
)
|
65 |
return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text
|
66 |
|
67 |
+
if 'Unbabel' in model_name:
|
68 |
+
pipe = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
|
69 |
+
messages = [{"role": "user",
|
70 |
+
"content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"}]
|
71 |
+
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
72 |
+
outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
|
73 |
+
translated_text = outputs[0]["generated_text"]
|
74 |
+
return translated_text, message_text
|
75 |
+
|
76 |
if model_name.startswith('t5'):
|
77 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
78 |
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
|