Update app.py
Browse files
app.py
CHANGED
@@ -2,14 +2,14 @@ import gradio as gr
|
|
2 |
import spaces
|
3 |
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
4 |
|
5 |
-
langs = {"German":
|
6 |
options = list(langs.keys())
|
7 |
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B"]
|
8 |
|
9 |
@spaces.GPU
|
10 |
def translate_text(input_text, sselected_language, tselected_language, model_name):
|
11 |
-
sl = langs[sselected_language]
|
12 |
-
tl = langs[tselected_language]
|
13 |
message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
|
14 |
if model_name == "Helsinki-NLP":
|
15 |
try:
|
@@ -24,9 +24,10 @@ def translate_text(input_text, sselected_language, tselected_language, model_nam
|
|
24 |
except EnvironmentError as error:
|
25 |
return f"Error finding model: {model_name_full}! Try other available language combination.", error
|
26 |
elif model_name.startswith('facebook/nllb'):
|
|
|
27 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
28 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
29 |
-
translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=
|
30 |
translated_text = translator(input_text, max_length=512)
|
31 |
return translated_text[0]['translation_text'], message_text
|
32 |
else:
|
|
|
2 |
import spaces
|
3 |
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
4 |
|
5 |
+
langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it"}
|
6 |
options = list(langs.keys())
|
7 |
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B"]
|
8 |
|
9 |
@spaces.GPU
|
10 |
def translate_text(input_text, sselected_language, tselected_language, model_name):
|
11 |
+
sl = langs[sselected_language]
|
12 |
+
tl = langs[tselected_language]
|
13 |
message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
|
14 |
if model_name == "Helsinki-NLP":
|
15 |
try:
|
|
|
24 |
except EnvironmentError as error:
|
25 |
return f"Error finding model: {model_name_full}! Try other available language combination.", error
|
26 |
elif model_name.startswith('facebook/nllb'):
|
27 |
+
from languagecode import nllb_language_codes
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
29 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
30 |
+
translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
|
31 |
translated_text = translator(input_text, max_length=512)
|
32 |
return translated_text[0]['translation_text'], message_text
|
33 |
else:
|