File size: 4,173 Bytes
56f497c
70ebea3
fe93b05
56f497c
fc650b8
d1d386f
4acf86a
e13fba3
70ebea3
56f497c
fc650b8
 
60b55c4
56f497c
 
 
 
 
 
e402119
 
 
 
e13fba3
60b55c4
452443d
fc650b8
b8a0208
 
fc650b8
e3f013c
60b55c4
56f497c
 
bf6322b
56f497c
 
 
 
 
 
72e8644
9b57b81
72e8644
60b55c4
4fd4915
b98b60d
56f497c
2169b22
 
 
 
56f497c
 
e402119
56f497c
 
 
2169b22
56f497c
a4d39dd
c6657b6
2169b22
28c6232
56f497c
ba3b9f6
56f497c
 
 
4acf86a
56f497c
 
 
 
b98b60d
56f497c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Display name shown in the UI -> short language code interpolated into
# Helsinki-NLP model ids by translate_text.
langs = {
    "German": "de",
    "Romanian": "ro",
    "English": "en",
    "French": "fr",
    "Spanish": "es",
    "Italian": "it",
}

# Language names offered in the source/target dropdowns, in declaration order.
options = [*langs]

# Model choices offered in the UI; "Helsinki-NLP" is a family name that
# translate_text expands into a concrete per-language-pair model id.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
    "facebook/nllb-200-distilled-600M",
    "facebook/nllb-200-distilled-1.3B",
]

@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate ``input_text`` with the selected model.

    Args:
        input_text: Text to translate.
        sselected_language: Source language display name; must be a key of
            ``langs``.
        tselected_language: Target language display name; must be a key of
            ``langs``.
        model_name: ``"Helsinki-NLP"`` (expanded to a per-language-pair model
            id), a name starting with ``facebook/nllb`` (run through a
            translation pipeline), or any other name, which is loaded as a
            T5 checkpoint.

    Returns:
        A ``(translated_text, message)`` pair of strings. Blank input yields
        an empty translation plus a hint. When no Helsinki-NLP model can be
        loaded for the language pair, the first item describes the failure
        and the second is the text of the underlying exception.

    Raises:
        KeyError: If a language name is not a key of ``langs``.
    """
    if not input_text or not input_text.strip():
        # Nothing to translate: avoid loading a model only to generate
        # output from an empty prompt.
        return "", "Please enter some text to translate."

    sl = langs[sselected_language]
    tl = langs[tselected_language]
    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'

    if model_name == "Helsinki-NLP":
        # Try the opus-mt repository first and fall back to opus-tatoeba.
        load_error = None
        for family in ("opus-mt", "opus-tatoeba"):
            model_name_full = f"Helsinki-NLP/{family}-{sl}-{tl}"
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_name_full)
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
            except EnvironmentError as error:
                # Keep a reference: the ``as`` name is unbound once the
                # except clause ends.
                load_error = error
                continue
            break
        else:
            # Neither repository has a model for this pair. Return the
            # exception as text so both outputs are plain strings.
            return f"Error finding model: {model_name_full}! Try other available language combination.", str(load_error)
    elif model_name.startswith('facebook/nllb'):
        # Project-local table mapping display names to NLLB language codes.
        from languagecode import nllb_language_codes
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        translator = pipeline(
            'translation',
            model=model,
            tokenizer=tokenizer,
            src_lang=nllb_language_codes[sselected_language],
            tgt_lang=nllb_language_codes[tselected_language],
        )
        translated_text = translator(input_text, max_length=512)
        return translated_text[0]['translation_text'], message_text
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

    if model_name.startswith("Helsinki-NLP"):
        # The language pair is encoded in the model id; pass the raw text.
        prompt = input_text
    else:
        # T5 checkpoints receive the task as a text prefix.
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"

    # device_map="auto" may place the model on an accelerator, so the input
    # ids are moved to whatever device the model ended up on.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, max_length=512)
    translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} =  {translated_text}', sep='\n')
    return translated_text, message_text

# Define a function to swap dropdown values
def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown values in exchanged order.

    Args:
        src_lang: Current value of the source-language dropdown.
        tgt_lang: Current value of the target-language dropdown.

    Returns:
        The tuple ``(tgt_lang, src_lang)``.
    """
    exchanged = (tgt_lang, src_lang)
    return exchanged

def create_interface():
    """Assemble the translator's Gradio Blocks UI and return it.

    The components are created in this order: a heading, the input textbox,
    a row holding the source/target language dropdowns and a swap button,
    the model dropdown, the translate button, and two read-only textboxes
    for the translation and for status messages. The swap button is wired
    to ``swap_languages`` and the translate button to ``translate_text``.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Machine Text Translation")

        with gr.Row():
            source_box = gr.Textbox(
                label="Enter text to translate:",
                placeholder="Type your text here...",
            )

        with gr.Row():
            source_lang = gr.Dropdown(
                choices=options,
                value=options[0],
                label="Source language",
                interactive=True,
            )
            target_lang = gr.Dropdown(
                choices=options,
                value=options[1],
                label="Target language",
                interactive=True,
            )
            swap = gr.Button("Swap Languages")
            # Both dropdowns are inputs and outputs so their values trade places.
            swap.click(
                fn=swap_languages,
                inputs=[source_lang, target_lang],
                outputs=[source_lang, target_lang],
            )

        model_choice = gr.Dropdown(
            choices=models,
            label="Select a model",
            value=models[4],
            interactive=True,
        )
        run_button = gr.Button("Translate")

        result_box = gr.Textbox(label="Translated text:", interactive=False)
        status_box = gr.Textbox(
            label="Messages:",
            value='Display status and error messages',
            interactive=False,
        )

        # translate_text returns (translation, message), matching the outputs.
        run_button.click(
            translate_text,
            inputs=[source_box, source_lang, target_lang, model_choice],
            outputs=[result_box, status_box],
        )

    return demo

# Build the Gradio UI and launch it. Both statements run at module top level,
# so they execute whenever this file is run or imported.
interface = create_interface()
interface.launch()