GradioTranslate / app.py
TiberiuCristianLeon's picture
Update app.py
5a309fc verified
raw
history blame
2.54 kB
from functools import lru_cache

import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
def translate_text(input_text, sselected_language, tselected_language, model_name):
langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es"}
sl = langs[sselected_language]
tl = langs[tselected_language]
if model_name == "Helsinki-NLP":
try:
model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
tokenizer = AutoTokenizer.from_pretrained(model_name_full)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
except EnvironmentError:
model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
tokenizer = AutoTokenizer.from_pretrained(model_name_full)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
else:
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
if model_name.startswith("Helsinki-NLP"):
prompt = input_text
else:
prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
output_ids = model.generate(input_ids)
translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return translated_text
# Language names listed (in this display order) in the source and target dropdowns.
options = [
    "German",
    "Romanian",
    "English",
    "French",
    "Spanish",
]

# Model choices for the model dropdown. "Helsinki-NLP" is resolved per language
# pair by translate_text; the others are T5 checkpoint names.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
]
def create_interface():
    """Build and return the Gradio Blocks UI for the translator.

    Layout: a heading, a row with the input textbox, a row with the
    source-language, target-language and model dropdowns, then a Translate
    button and a read-only output textbox. Clicking the button passes the
    text, both language names and the model name to ``translate_text`` and
    shows its return value in the output box.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Text Machine Translation")

        with gr.Row():
            text_in = gr.Textbox(
                label="Enter text to translate:",
                placeholder="Type your text here...",
            )

        with gr.Row():
            src_lang = gr.Dropdown(choices=options, value="German", label="Source language")
            dst_lang = gr.Dropdown(choices=options, value="Romanian", label="Target language")
            model_pick = gr.Dropdown(choices=models, value="Helsinki-NLP", label="Select a model")

        go_button = gr.Button("Translate")
        text_out = gr.Textbox(label="Translated text:", interactive=False)

        # Argument order must match translate_text's positional parameters.
        handler_inputs = [text_in, src_lang, dst_lang, model_pick]
        go_button.click(fn=translate_text, inputs=handler_inputs, outputs=text_out)

    return demo
# Build the UI at import time so the Blocks object is available as the
# module-level name `interface`, but only start the web server when this file
# is executed as a script (e.g. `python app.py`), not when it is imported.
interface = create_interface()

if __name__ == "__main__":
    interface.launch()