File size: 5,472 Bytes
56f497c 70ebea3 a103fac fe93b05 56f497c fc650b8 d1d386f 4746bc2 e13fba3 310e819 70ebea3 56f497c fc650b8 60b55c4 56f497c 310e819 56f497c e402119 310e819 e13fba3 60b55c4 310e819 d2894fa 988d5ac aab3949 a103fac fc650b8 2045817 60b55c4 310e819 d2894fa 44aa6cb d2894fa 310e819 d2894fa 56f497c bf6322b 56f497c 72e8644 2045817 72e8644 60b55c4 4fd4915 b98b60d 56f497c 2169b22 56f497c e402119 56f497c 2169b22 56f497c a4d39dd c6657b6 2169b22 28c6232 56f497c ba3b9f6 56f497c 07594e7 56f497c b98b60d 56f497c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
import spaces
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# UI display name -> two-letter code used to build Helsinki-NLP repo ids.
langs = dict(
    German="de",
    Romanian="ro",
    English="en",
    French="fr",
    Spanish="es",
    Italian="it",
)
# Dropdown choices, in the same order as `langs`.
options = [*langs]
# Model families offered in the UI; "Helsinki-NLP" is a family name that is
# expanded to a concrete repository id per language pair.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
    "facebook/nllb-200-distilled-600M",
    "facebook/nllb-200-distilled-1.3B",
    "facebook/mbart-large-50-many-to-many-mmt",
]
def model_to_cuda(model):
    """Return *model* moved to the CUDA device when one is available.

    Without CUDA the model is returned unchanged. Either way, a line saying
    which device is in use is printed.
    """
    if not torch.cuda.is_available():
        print("CUDA not available! Using CPU.")
        return model
    moved = model.to('cuda')
    print("CUDA is available! Using GPU.")
    return moved
def _load_seq2seq(repo_id):
    """Load a tokenizer and seq2seq model from *repo_id*; model goes to GPU if available."""
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(repo_id))
    return tokenizer, model


def _load_helsinki(sl, tl):
    """Load the Helsinki-NLP checkpoint for the ``sl -> tl`` direction.

    The ``opus-mt`` repository is tried first, then ``opus-tatoeba``.

    Returns:
        ``(tokenizer, model)``.

    Raises:
        EnvironmentError: if neither repository can be loaded (the error from
        the ``opus-tatoeba`` attempt propagates).
    """
    try:
        return _load_seq2seq(f"Helsinki-NLP/opus-mt-{sl}-{tl}")
    except EnvironmentError:
        # Some language pairs only exist as tatoeba checkpoints.
        return _load_seq2seq(f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}")


def _generate(tokenizer, model, prompt, max_length=512):
    """Encode *prompt*, run ``model.generate`` and decode the first output."""
    # The inputs must be on the model's device. Otherwise generate() fails
    # with a device mismatch once the model has been moved to CUDA.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def _translate_nllb(input_text, sselected_language, tselected_language, model_name):
    """Translate with an NLLB checkpoint through a ``translation`` pipeline."""
    # Project-local module (not visible here). It presumably maps UI language
    # names to NLLB language codes.
    from languagecodes import nllb_language_codes
    src_code = nllb_language_codes[sselected_language]
    tgt_code = nllb_language_codes[tselected_language]
    tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=src_code)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=src_code, tgt_lang=tgt_code)
    return translator(input_text, max_length=512)[0]['translation_text']


def _translate_mbart(input_text, sselected_language, tselected_language):
    """Translate with the mBART-50 many-to-many checkpoint."""
    # Project-local module (not visible here). It presumably maps UI language
    # names to mBART-50 language codes.
    from languagecodes import mbart_large_languages
    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    model = model_to_cuda(MBartForConditionalGeneration.from_pretrained(checkpoint))
    tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    tokenizer.src_lang = mbart_large_languages[sselected_language]
    # Keep the encoded tensors on the same device as the model.
    encoded = tokenizer(input_text, return_tensors="pt").to(model.device)
    generated_tokens = model.generate(
        **encoded,
        # Forcing the first generated token selects the target language.
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_large_languages[tselected_language]],
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]


@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate *input_text* between two UI languages with the chosen model.

    The ``spaces.GPU`` decorator requests a GPU for the call; this is
    presumably Hugging Face ZeroGPU, so confirm against the Space's settings.

    Args:
        input_text: Text to translate.
        sselected_language: Source language display name (a key of ``langs``).
        tselected_language: Target language display name (a key of ``langs``).
        model_name: One of the entries in ``models``.

    Returns:
        ``(translated_text, message)``, which fill the "Translated text" and
        "Messages" boxes. On failure the translation is ``""`` and the
        message describes the problem.
    """
    if not input_text or not input_text.strip():
        # Avoid loading a (multi-GB) model just to translate nothing.
        return "", "No text to translate."

    sl = langs[sselected_language]
    tl = langs[tselected_language]
    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'

    if model_name == "Helsinki-NLP":
        try:
            tokenizer, model = _load_helsinki(sl, tl)
        except EnvironmentError as error:
            tried = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
            return "", f"Error finding model: {tried}! Try other available language combination. ({error})"
        # Marian checkpoints are pair-specific, so no task prefix is needed.
        prompt = input_text
    elif model_name.startswith('facebook/nllb'):
        return _translate_nllb(input_text, sselected_language, tselected_language, model_name), message_text
    elif model_name.startswith('facebook/mbart-large'):
        return _translate_mbart(input_text, sselected_language, tselected_language), message_text
    elif model_name.startswith('t5'):
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        # T5 selects the task through a natural-language prefix.
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
    else:
        # Previously this fell through to a NameError on `tokenizer`.
        return "", f"Unknown model: {model_name}"

    translated_text = _generate(tokenizer, model, prompt)
    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} = {translated_text}', sep='\n')
    return translated_text, message_text
def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown values in reverse order, so source and target trade places."""
    swapped = (tgt_lang, src_lang)
    return swapped
def create_interface():
    """Assemble the translator's Gradio Blocks UI and return it without launching.

    The swap button exchanges the two language dropdowns via
    ``swap_languages``. The Translate button passes the text, both languages
    and the chosen model to ``translate_text``, and writes its two return
    values into the "Translated text" and "Messages" boxes.
    """
    with gr.Blocks() as ui:
        gr.Markdown("## Machine Text Translation")

        with gr.Row():
            source_box = gr.Textbox(
                label="Enter text to translate:",
                placeholder="Type your text here...",
            )

        with gr.Row():
            src_dropdown = gr.Dropdown(
                choices=options, value=options[0], label="Source language", interactive=True
            )
            tgt_dropdown = gr.Dropdown(
                choices=options, value=options[1], label="Target language", interactive=True
            )
            # NOTE(review): the original indentation was lost. The swap button
            # is assumed to share this row with the dropdowns; confirm.
            swap_btn = gr.Button("Swap Languages")

        swap_btn.click(
            fn=swap_languages,
            inputs=[src_dropdown, tgt_dropdown],
            outputs=[src_dropdown, tgt_dropdown],
        )

        model_picker = gr.Dropdown(choices=models, label="Select a model", value=models[4], interactive=True)
        go_btn = gr.Button("Translate")
        result_box = gr.Textbox(label="Translated text:", interactive=False)
        status_box = gr.Textbox(
            label="Messages:",
            placeholder="Display status and error messages",
            interactive=False,
        )
        go_btn.click(
            fn=translate_text,
            inputs=[source_box, src_dropdown, tgt_dropdown, model_picker],
            outputs=[result_box, status_box],
        )
    return ui
# Build the UI at module level so `interface` stays importable, but start the
# web server only when this file is run as a script.
interface = create_interface()

if __name__ == "__main__":
    interface.launch()