# NOTE: removed non-Python extraction residue (file-size line, git hash dump,
# and a column of line numbers) that made the file unparseable.
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Supported languages: display name -> (ISO-639-1 code used by the
# Helsinki/T5 paths, FLORES-200 code used by the NLLB pipeline).
langs = {
    "German": ("de", "deu_Latn"),
    "Romanian": ("ro", "ron_Latn"),
    "English": ("en", "eng_Latn"),
    "French": ("fr", "fra_Latn"),
    "Spanish": ("es", "spa_Latn"),
    "Italian": ("it", "ita_Latn"),
}
# Dropdown choices; dict insertion order matters because the UI defaults
# to options[0] (source) and options[1] (target).
options = list(langs.keys())
# Selectable translation backends; models[4] is the UI default.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
    "facebook/nllb-200-distilled-1.3B",
    "facebook/nllb-200-distilled-600M",
]
@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate ``input_text`` between the selected languages with the chosen model.

    Parameters
    ----------
    input_text : str
        Text to translate.
    sselected_language, tselected_language : str
        Display names of the source/target languages (keys of ``langs``).
    model_name : str
        One of the entries in ``models``.

    Returns
    -------
    tuple[str, str]
        ``(translated_text, message_text)`` matching the two Gradio output
        components. On a model-loading failure the first element carries the
        error message instead of a translation.
    """
    sl = langs[sselected_language][0]
    tl = langs[tselected_language][0]
    # Status line shown in the UI's "Message" box; built up front so every
    # return path can include it.
    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
    if model_name == "Helsinki-NLP":
        # Helsinki models are per-language-pair; try opus-mt first and fall
        # back to opus-tatoeba when that pair does not exist on the Hub.
        try:
            model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
            tokenizer = AutoTokenizer.from_pretrained(model_name_full)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
        except EnvironmentError:
            try:
                model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
                tokenizer = AutoTokenizer.from_pretrained(model_name_full)
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
            except EnvironmentError as error:
                # BUG FIX: return a 2-tuple so both Gradio outputs are filled
                # (the original returned a single string here).
                return f"Error finding required model! Try other available language combination. Error: {error}", message_text
        # opus models translate the raw text; no task prefix needed.
        prompt = input_text
    elif model_name.startswith('facebook/nllb'):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # NLLB identifies languages by FLORES-200 codes (second element of
        # each langs tuple).
        translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=langs[sselected_language][1], tgt_lang=langs[tselected_language][1])
        translated_text = translator(input_text, max_length=512)[0]['translation_text']
        print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} = {translated_text}', sep='\n')
        # BUG FIX: the original returned only the translation on this path,
        # leaving the second Gradio output (message box) without a value.
        return translated_text, message_text
    else:
        # Plain T5 checkpoints (t5-small / t5-base / t5-large) use a textual
        # task prefix to select the translation direction.
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(input_ids, max_length=512)
    translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} = {translated_text}', sep='\n')
    return translated_text, message_text
# Click handler for the "Swap Languages" button.
def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown selections in reversed order."""
    swapped = (tgt_lang, src_lang)
    return swapped
def create_interface():
    """Build and return the Gradio Blocks UI for the translation app.

    Layout: a text input row, a source/target language row, a swap button,
    a model selector, a translate button, and two read-only output boxes
    (translation + status message).
    """
    with gr.Blocks() as interface:
        gr.Markdown("## Machine Text Translation")
        with gr.Row():
            text_input = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")
        with gr.Row():
            src_dropdown = gr.Dropdown(choices=options, value=options[0], label="Source language", interactive=True)
            tgt_dropdown = gr.Dropdown(choices=options, value=options[1], label="Target language", interactive=True)
        swap_btn = gr.Button("Swap Languages")
        # Swapping simply exchanges the two dropdown values in place.
        swap_btn.click(fn=swap_languages, inputs=[src_dropdown, tgt_dropdown], outputs=[src_dropdown, tgt_dropdown])
        model_dropdown = gr.Dropdown(choices=models, label="Select a model", value=models[4], interactive=True)
        translate_btn = gr.Button("Translate")
        translation_box = gr.Textbox(label="Translated text:", interactive=False)
        status_box = gr.Textbox(label="Message:", interactive=False)
        # translate_text returns (translation, status message) to fill both boxes.
        translate_btn.click(
            translate_text,
            inputs=[text_input, src_dropdown, tgt_dropdown, model_dropdown],
            outputs=[translation_box, status_box],
        )
    return interface
# Build and launch the Gradio interface when the script is executed.
# BUG FIX: removed a stray trailing "|" after interface.launch() that made
# the line a syntax error (extraction artifact).
interface = create_interface()
interface.launch()