File size: 4,173 Bytes
56f497c 70ebea3 fe93b05 56f497c fc650b8 d1d386f 4acf86a e13fba3 70ebea3 56f497c fc650b8 60b55c4 56f497c e402119 e13fba3 60b55c4 452443d fc650b8 b8a0208 fc650b8 e3f013c 60b55c4 56f497c bf6322b 56f497c 72e8644 9b57b81 72e8644 60b55c4 4fd4915 b98b60d 56f497c 2169b22 56f497c e402119 56f497c 2169b22 56f497c a4d39dd c6657b6 2169b22 28c6232 56f497c ba3b9f6 56f497c 4acf86a 56f497c b98b60d 56f497c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Language display name -> two-letter code; the codes are interpolated into
# Helsinki-NLP checkpoint names by the translation function.
langs = {
    "German": "de",
    "Romanian": "ro",
    "English": "en",
    "French": "fr",
    "Spanish": "es",
    "Italian": "it",
}
# Dropdown choices: the language names in the insertion order of ``langs``.
options = list(langs)
# Model choices offered in the UI; "Helsinki-NLP" is a family selector rather
# than a concrete checkpoint name.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
    "facebook/nllb-200-distilled-600M",
    "facebook/nllb-200-distilled-1.3B",
]
@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate ``input_text`` from one supported language to another.

    Args:
        input_text: Text to translate.
        sselected_language: Source language name; must be a key of ``langs``.
        tselected_language: Target language name; must be a key of ``langs``.
        model_name: ``"Helsinki-NLP"`` selects a per-language-pair OPUS
            checkpoint, a name starting with ``"facebook/nllb"`` is run
            through a ``translation`` pipeline, and any other value is
            loaded as a T5 checkpoint.

    Returns:
        A ``(translated_text, message_text)`` tuple of strings. If no
        Helsinki-NLP checkpoint can be loaded for the language pair, the
        first element is an error notice and the second is the text of the
        underlying exception. Blank input yields an empty translation and
        a hint message without loading any model.

    Raises:
        KeyError: If either language name is not a key of ``langs``.
    """
    sl = langs[sselected_language]
    tl = langs[tselected_language]
    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'

    # Nothing to translate: avoid downloading/loading a model for no input.
    if not input_text or not input_text.strip():
        return "", "Please enter some text to translate."

    if model_name == "Helsinki-NLP":
        load_error = None
        # Try the regular OPUS-MT checkpoint first, then the Tatoeba variant.
        candidates = (
            f"Helsinki-NLP/opus-mt-{sl}-{tl}",
            f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}",
        )
        for model_name_full in candidates:
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_name_full)
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
            except EnvironmentError as error:
                load_error = error
                continue
            break
        else:
            # Neither checkpoint exists for this pair; report the last name
            # tried and hand back the exception as text, not as an object.
            return (
                f"Error finding model: {model_name_full}! Try other available language combination.",
                str(load_error),
            )
    elif model_name.startswith('facebook/nllb'):
        # Project-local table; looked up by the language display names.
        from languagecode import nllb_language_codes
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        translator = pipeline(
            'translation',
            model=model,
            tokenizer=tokenizer,
            src_lang=nllb_language_codes[sselected_language],
            tgt_lang=nllb_language_codes[tselected_language],
        )
        translated_text = translator(input_text, max_length=512)
        return translated_text[0]['translation_text'], message_text
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

    # Helsinki checkpoints are pair-specific and take the raw text; T5 needs
    # the task spelled out as a text prefix.
    if model_name.startswith("Helsinki-NLP"):
        prompt = input_text
    else:
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"

    # device_map="auto" may place the model on an accelerator, so the inputs
    # must be moved to the model's device before generation.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, max_length=512)
    translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} = {translated_text}', sep='\n')
    return translated_text, message_text
# Define a function to swap dropdown values
def swap_languages(src_lang, tgt_lang):
    """Return the two given values in exchanged order.

    Args:
        src_lang: Current value of the source-language dropdown.
        tgt_lang: Current value of the target-language dropdown.

    Returns:
        A ``(tgt_lang, src_lang)`` tuple, i.e. the inputs swapped.
    """
    swapped = (tgt_lang, src_lang)
    return swapped
def create_interface():
    """Build the Gradio Blocks UI for the translator and return it.

    The layout holds a text input, source/target language dropdowns with a
    swap button, a model selector, a translate button, and two read-only
    text boxes for the translation and for status/error messages.

    Returns:
        The assembled ``gr.Blocks`` object (not yet launched).
    """
    # NOTE(review): the row grouping below is reconstructed; confirm which
    # components are meant to share a row.
    with gr.Blocks() as demo:
        gr.Markdown("## Machine Text Translation")

        with gr.Row():
            source_box = gr.Textbox(
                label="Enter text to translate:",
                placeholder="Type your text here...",
            )

        with gr.Row():
            source_dropdown = gr.Dropdown(
                choices=options,
                value=options[0],
                label="Source language",
                interactive=True,
            )
            target_dropdown = gr.Dropdown(
                choices=options,
                value=options[1],
                label="Target language",
                interactive=True,
            )
            swap_btn = gr.Button("Swap Languages")
            # The same two dropdowns are both read and written, so a click
            # exchanges their current values.
            language_pair = [source_dropdown, target_dropdown]
            swap_btn.click(fn=swap_languages, inputs=language_pair, outputs=language_pair)

        model_dropdown = gr.Dropdown(
            choices=models,
            label="Select a model",
            value=models[4],
            interactive=True,
        )
        translate_btn = gr.Button("Translate")

        result_box = gr.Textbox(label="Translated text:", interactive=False)
        status_box = gr.Textbox(
            label="Messages:",
            value='Display status and error messages',
            interactive=False,
        )

        # translate_text returns (translation, message), matching the order
        # of the two output boxes.
        translate_btn.click(
            fn=translate_text,
            inputs=[source_box, source_dropdown, target_dropdown, model_dropdown],
            outputs=[result_box, status_box],
        )

    return demo
# Build the UI at import time and start the Gradio server.
interface = create_interface()
interface.launch()