# Page header captured along with the code -- Spaces status: Sleeping
import gradio as gr | |
import torch | |
from transformers import T5Tokenizer, MT5ForConditionalGeneration | |
# Run on the GPU when torch can see one; otherwise stay on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The tokenizer is loaded from the google/mt5-small checkpoint, while the
# weights come from the werent4/mt5TranslatorLT checkpoint.
TOKENIZER = T5Tokenizer.from_pretrained("google/mt5-small")

# ``.to()`` moves the module in place and returns it, so the load and the
# device transfer can be chained into one statement.
MODEL = MT5ForConditionalGeneration.from_pretrained(
    "werent4/mt5TranslatorLT"
).to(DEVICE)
def translate(text, mode, max_length, num_beams):
    """Translate ``text`` between English and Lithuanian.

    Uses the module-level ``TOKENIZER``, ``MODEL`` and ``DEVICE``.

    Args:
        text: Source text to translate.
        mode: ``"En2Lt"`` selects English -> Lithuanian. Any other value
            (including ``None``) selects Lithuanian -> English.
        max_length: Length limit passed both to the tokenizer (input
            truncation) and to ``generate`` (output length). Coerced to int.
        num_beams: Number of beams for beam search. Coerced to int.

    Returns:
        The decoded translation with special tokens removed, or an empty
        string when ``text`` is empty or only whitespace.
    """
    # Nothing to translate: skip the model call instead of letting it
    # generate output for an empty prompt.
    if not text or not text.strip():
        return ""

    # UI sliders can hand these over as floats, but the tokenizer and
    # ``generate`` expect integer lengths / beam counts.
    max_length = int(max_length)
    num_beams = int(num_beams)

    if mode == "En2Lt":
        prompt = f"translate English to Lithuanian: {text}"
    else:
        prompt = f"translate Lithuanian to English: {text}"

    encoded_input = TOKENIZER(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(DEVICE)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output_tokens = MODEL.generate(
            **encoded_input,
            max_length=max_length,
            num_beams=num_beams,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )

    return TOKENIZER.decode(output_tokens[0], skip_special_tokens=True)
# Sample code displayed in the "How to run the model locally" accordion.
# It is shown to the user, not executed by this app.
_LOCAL_USAGE_SNIPPET = """import torch
from transformers import T5Tokenizer, MT5ForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained('google/mt5-small')
model = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
model.to(device)
def translate(text, model, tokenizer, device):
    input_text = f"translate English to Lithuanian: {text}"
    encoded_input = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
    with torch.no_grad():
        output_tokens = model.generate(
            **encoded_input,
            max_length=128,
            num_beams=5,
            no_repeat_ngram_size=2,
            early_stopping=True
        )
    translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    return translated_text
text = "I live in Kaunas"
translate(text, model, tokenizer, device)
"""

# NOTE(review): the component nesting below was reconstructed; in particular
# the placement of ``output_text`` (here: full width under the input row)
# should be confirmed against the intended layout.
with gr.Blocks() as interface:
    gr.Markdown("<h1>Lt🔄En: Lithuanian to English and vice versa</h1>")

    # Generation settings.
    with gr.Row():
        max_length = gr.Slider(
            1, 512, value=128, step=1, label="Max length", interactive=True
        )
        # Beam count must be a whole number, so step in units of 1.
        num_beams = gr.Slider(
            1, 16, value=5, step=1, label="Num beams", interactive=True
        )

    # Input text on the left; direction selector and action button on the right.
    with gr.Row():
        input_text = gr.Textbox(
            label="Text input", placeholder="Enter your text here"
        )
        with gr.Column():
            # Start with an explicit direction selected so the mode shown
            # in the UI always matches the one that is actually used.
            mode = gr.Dropdown(
                label="Mode", choices=["En2Lt", "Lt2En"], value="En2Lt"
            )
            translate_button = gr.Button("Translate")

    output_text = gr.Textbox(label="Translated text")

    with gr.Accordion("How to run the model locally:", open=False):
        gr.Code(_LOCAL_USAGE_SNIPPET, language='python')

    # Wire the button to the translation function.
    translate_button.click(
        fn=translate,
        inputs=[input_text, mode, max_length, num_beams],
        outputs=[output_text],
    )

interface.launch(share=True)