File size: 4,121 Bytes
56f497c
70ebea3
fe93b05
56f497c
c6657b6
d1d386f
452443d
e13fba3
70ebea3
56f497c
60e7bda
 
56f497c
 
 
 
 
 
e402119
 
 
 
e13fba3
 
452443d
b8a0208
 
60e7bda
e3f013c
71f8c73
56f497c
 
bf6322b
56f497c
 
 
 
 
 
72e8644
9b57b81
72e8644
ba3b9f6
4fd4915
b98b60d
56f497c
2169b22
 
 
 
56f497c
 
e402119
56f497c
 
 
2169b22
56f497c
a4d39dd
c6657b6
2169b22
28c6232
56f497c
ba3b9f6
56f497c
 
 
ba3b9f6
56f497c
 
 
 
b98b60d
56f497c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Display name -> (short code, NLLB code).
# The short code is interpolated into Helsinki-NLP repository names; the
# second code is handed to the NLLB pipeline as src_lang / tgt_lang.
langs = {
    "German": ("de", "deu_Latn"),
    "Romanian": ("ro", "ron_Latn"),
    "English": ("en", "eng_Latn"),
    "French": ("fr", "fra_Latn"),
    "Spanish": ("es", "spa_Latn"),
    "Italian": ("it", "ita_Latn"),
}

# Language names offered in the source/target dropdowns (dict insertion order).
options = list(langs)

# Model choices shown in the UI; "Helsinki-NLP" is a family placeholder that is
# expanded into a concrete per-language-pair repository at translation time.
models = [
    "Helsinki-NLP",
    "t5-base",
    "t5-small",
    "t5-large",
    "facebook/nllb-200-distilled-1.3B",
    "facebook/nllb-200-distilled-600M",
]

@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate ``input_text`` between two languages with the chosen model.

    Args:
        input_text: Text to translate.
        sselected_language: Source language name; must be a key of ``langs``.
        tselected_language: Target language name; must be a key of ``langs``.
        model_name: ``"Helsinki-NLP"`` (resolved to a per-language-pair
            repository), a ``facebook/nllb*`` checkpoint, or any other name,
            which is loaded as a T5 checkpoint.

    Returns:
        A ``(translated_text, message_text)`` tuple on every path, because the
        UI binds this handler to two output components. On failure the
        translated text is empty and the message describes the problem.

    Raises:
        KeyError: If either language name is not present in ``langs``.
    """
    sl, src_nllb = langs[sselected_language]
    tl, tgt_nllb = langs[tselected_language]

    # Nothing to translate: skip the (expensive) model download and load.
    if not input_text or not input_text.strip():
        return "", "Please enter some text to translate."

    if model_name.startswith('facebook/nllb'):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=src_nllb, tgt_lang=tgt_nllb)
        translated_text = translator(input_text, max_length=512)[0]['translation_text']
    else:
        if model_name == "Helsinki-NLP":
            # Try the opus-mt repository first, then the opus-tatoeba one.
            candidates = (
                f"Helsinki-NLP/opus-mt-{sl}-{tl}",
                f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}",
            )
            last_error = None
            for model_name_full in candidates:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(model_name_full)
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
                    break
                except EnvironmentError as error:
                    # ``error`` is unbound after the except block, so keep a copy.
                    last_error = error
            else:
                # Neither repository could be loaded for this language pair.
                return "", f"Error finding required model! Try other available language combination. Error: {last_error}"
            # Helsinki models are pair-specific and take the raw text.
            prompt = input_text
        else:
            tokenizer = T5Tokenizer.from_pretrained(model_name)
            model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
            # T5 expects the task spelled out as a text prefix.
            prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"

        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        # device_map="auto" may place the model on an accelerator; the inputs
        # must live on the same device as the model for generate().
        input_ids = input_ids.to(model.device)
        output_ids = model.generate(input_ids, max_length=512)
        translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} =  {translated_text}', sep='\n')
    return translated_text, message_text

def swap_languages(src_lang, tgt_lang):
    """Return the two dropdown values in reversed order: (target, source)."""
    swapped = (tgt_lang, src_lang)
    return swapped

def create_interface():
    """Build and return the Gradio Blocks UI for the translator.

    The layout is, top to bottom: a title, the input textbox, a row holding
    the source/target language dropdowns plus a swap button, the model
    dropdown, the translate button, and two read-only textboxes that receive
    the translation and a status message.
    """
    default_source = options[0]
    default_target = options[1]
    default_model = models[4]

    with gr.Blocks() as app:
        gr.Markdown("## Machine Text Translation")

        with gr.Row():
            text_in = gr.Textbox(
                label="Enter text to translate:",
                placeholder="Type your text here...",
            )

        with gr.Row():
            source_dd = gr.Dropdown(
                choices=options,
                value=default_source,
                label="Source language",
                interactive=True,
            )
            target_dd = gr.Dropdown(
                choices=options,
                value=default_target,
                label="Target language",
                interactive=True,
            )
            swap_btn = gr.Button("Swap Languages")
            # Swapping writes each dropdown's value into the other one.
            swap_btn.click(
                fn=swap_languages,
                inputs=[source_dd, target_dd],
                outputs=[source_dd, target_dd],
            )

        model_dd = gr.Dropdown(
            choices=models,
            label="Select a model",
            value=default_model,
            interactive=True,
        )
        run_btn = gr.Button("Translate")

        text_out = gr.Textbox(label="Translated text:", interactive=False)
        status_out = gr.Textbox(label="Message:", interactive=False)

        run_btn.click(
            translate_text,
            inputs=[text_in, source_dd, target_dd, model_dd],
            outputs=[text_out, status_out],
        )

    return app

# Build the Blocks UI at module load and start serving it; the module-level
# name `interface` holds the object returned by create_interface().
interface = create_interface()
interface.launch()