|
import functools

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
|
|
|
# Dropdown display name -> two-letter code used in Helsinki-NLP checkpoint names.
LANGUAGE_CODES = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es"}

# The original T5 checkpoints were only trained to translate English into
# German, French and Romanian; other directions produce unusable output.
T5_SUPPORTED_PAIRS = frozenset({
    ("English", "German"),
    ("English", "French"),
    ("English", "Romanian"),
})


@functools.lru_cache(maxsize=4)
def _load_helsinki(sl, tl):
    """Load (tokenizer, model) for the Helsinki-NLP ``sl`` -> ``tl`` checkpoint.

    Tries ``opus-mt-{sl}-{tl}`` first and falls back to ``opus-tatoeba-{sl}-{tl}``.
    The successful pair is memoized (bounded, to limit memory held by cached
    weights), so repeated requests for the same direction skip reloading.
    Failures are not cached.

    Raises:
        gr.Error: if neither checkpoint can be loaded.
    """
    last_error = None
    for family in ("opus-mt", "opus-tatoeba"):
        checkpoint = f"Helsinki-NLP/{family}-{sl}-{tl}"
        try:
            return (
                AutoTokenizer.from_pretrained(checkpoint),
                AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
            )
        except OSError as err:  # from_pretrained raises OSError for unknown checkpoints
            last_error = err
    raise gr.Error(f"No Helsinki-NLP model is available for {sl} -> {tl}.") from last_error


@functools.lru_cache(maxsize=2)
def _load_t5(model_name):
    """Load and memoize (tokenizer, model) for a T5 checkpoint such as ``t5-base``."""
    return (
        T5Tokenizer.from_pretrained(model_name),
        T5ForConditionalGeneration.from_pretrained(model_name),
    )


def translate_text(input_text, sselected_language, tselected_language, model_name, max_new_tokens=512):
    """Translate ``input_text`` from one language to another with the chosen model.

    Args:
        input_text: text to translate.
        sselected_language: source language display name (a key of LANGUAGE_CODES).
        tselected_language: target language display name (a key of LANGUAGE_CODES).
        model_name: ``"Helsinki-NLP"`` to use the per-pair OPUS checkpoint,
            otherwise a T5 checkpoint name (e.g. ``"t5-base"``).
        max_new_tokens: upper bound on generated tokens. Passed explicitly because
            checkpoints whose config sets no length limit otherwise fall back to
            transformers' short default and truncate the translation.

    Returns:
        The translated text; ``""`` for blank input; ``input_text`` unchanged when
        source and target languages are the same.

    Raises:
        KeyError: if a language name is not in LANGUAGE_CODES.
        gr.Error: if no model is available for the requested direction.
    """
    sl = LANGUAGE_CODES[sselected_language]
    tl = LANGUAGE_CODES[tselected_language]

    if not input_text or not input_text.strip():
        return ""
    if sl == tl:
        # No checkpoint translates a language into itself (e.g. no opus-mt-en-en).
        return input_text

    if model_name == "Helsinki-NLP":
        tokenizer, model = _load_helsinki(sl, tl)
        # The language pair is fixed by the checkpoint, so the raw text is the input.
        prompt = input_text
    else:
        if (sselected_language, tselected_language) not in T5_SUPPORTED_PAIRS:
            raise gr.Error(
                f"{model_name} can only translate English into German, French or Romanian."
            )
        tokenizer, model = _load_t5(model_name)
        # T5 selects the task from a text prefix.
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"

    # truncation=True keeps over-long input within the model's maximum input length;
    # calling the tokenizer (not .encode) also yields the attention mask for generate.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
|
# Language names offered in both the source and target dropdowns; each one
# is looked up in translate_text's language-code table.
options = [
    "German",
    "Romanian",
    "English",
    "French",
    "Spanish",
]

# Model choices: "Helsinki-NLP" selects a per-language-pair checkpoint inside
# translate_text; the other entries are T5 checkpoint names loaded directly.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
]
|
|
|
def create_interface():
    """Build and return the translator UI as a ``gr.Blocks`` app (not launched).

    Layout, top to bottom: title, input textbox, source/target language
    dropdowns in one row, model dropdown, Translate button, read-only output
    textbox. Clicking the button runs ``translate_text`` on the four inputs.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Text Machine Translation")

        with gr.Row():
            source_box = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")

        with gr.Row():
            source_lang = gr.Dropdown(choices=options, value="German", label="Source language")
            target_lang = gr.Dropdown(choices=options, value="Romanian", label="Target language")

        model_choice = gr.Dropdown(choices=models, value="Helsinki-NLP", label="Select a model")
        run_button = gr.Button("Translate")

        result_box = gr.Textbox(label="Translated text:", interactive=False)

        # Order must match translate_text's positional parameters.
        handler_inputs = [source_box, source_lang, target_lang, model_choice]
        run_button.click(fn=translate_text, inputs=handler_inputs, outputs=result_box)

    return demo
|
|
|
|
|
# Build the UI at import time so other modules can still import ``interface``,
# but only start the web server when this file is executed as a script;
# previously merely importing the module launched the server.
interface = create_interface()

if __name__ == "__main__":
    interface.launch()