File size: 5,472 Bytes
56f497c
70ebea3
a103fac
fe93b05
56f497c
fc650b8
d1d386f
4746bc2
e13fba3
310e819
 
 
 
 
 
 
 
70ebea3
56f497c
fc650b8
 
60b55c4
56f497c
 
 
 
310e819
56f497c
e402119
 
 
310e819
e13fba3
60b55c4
310e819
d2894fa
988d5ac
aab3949
a103fac
fc650b8
2045817
60b55c4
310e819
d2894fa
 
 
 
 
 
 
44aa6cb
d2894fa
 
 
 
310e819
d2894fa
 
56f497c
bf6322b
56f497c
 
 
 
 
 
72e8644
2045817
72e8644
60b55c4
4fd4915
b98b60d
56f497c
2169b22
 
 
 
56f497c
 
e402119
56f497c
 
 
2169b22
56f497c
a4d39dd
c6657b6
2169b22
28c6232
56f497c
ba3b9f6
56f497c
 
 
07594e7
56f497c
 
 
 
b98b60d
56f497c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import spaces
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Human-readable language names mapped to their ISO 639-1 codes.
langs = {"German": "de", "Romanian": "ro", "English": "en", "French": "fr", "Spanish": "es", "Italian": "it"}
# Dropdown choices for the UI (the readable language names).
options = list(langs.keys())
# Selectable translation backends; "Helsinki-NLP" is resolved to a concrete
# per-language-pair checkpoint (opus-mt / opus-tatoeba) at request time.
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "facebook/nllb-200-distilled-600M", "facebook/nllb-200-distilled-1.3B", "facebook/mbart-large-50-many-to-many-mmt"]

def model_to_cuda(model):
    """Return *model* placed on the GPU when CUDA is available, else unchanged.

    Args:
        model: a torch ``nn.Module`` (e.g. a transformers model).

    Returns:
        The same model, moved to ``cuda`` if a GPU is present.
    """
    if not torch.cuda.is_available():
        print("CUDA not available! Using CPU.")
        return model
    print("CUDA is available! Using GPU.")
    return model.to('cuda')
@spaces.GPU
def translate_text(input_text, sselected_language, tselected_language, model_name):
    """Translate *input_text* between the two selected languages using *model_name*.

    Args:
        input_text: source-language text to translate.
        sselected_language: source language name (key of module-level ``langs``).
        tselected_language: target language name (key of module-level ``langs``).
        model_name: entry from the module-level ``models`` list.

    Returns:
        ``(translated_text, status_message)`` for the two output textboxes.
        If no Helsinki-NLP checkpoint exists for the language pair, returns an
        error string and the underlying ``EnvironmentError`` instead.
    """
    sl = langs[sselected_language]
    tl = langs[tselected_language]
    message_text = f'Translated from {sselected_language} to {tselected_language} with {model_name}'
    if model_name == "Helsinki-NLP":
        # Helsinki checkpoints are per language pair: try opus-mt first,
        # then fall back to opus-tatoeba before giving up.
        try:
            model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
            tokenizer = AutoTokenizer.from_pretrained(model_name_full)
            model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
        except EnvironmentError:
            try:
                model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
                tokenizer = AutoTokenizer.from_pretrained(model_name_full)
                model = model_to_cuda(AutoModelForSeq2SeqLM.from_pretrained(model_name_full))
            except EnvironmentError as error:
                return f"Error finding model: {model_name_full}! Try other available language combination.", error

    if model_name.startswith('facebook/nllb'):
        from languagecodes import nllb_language_codes
        tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=nllb_language_codes[sselected_language])
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")
        translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=nllb_language_codes[sselected_language], tgt_lang=nllb_language_codes[tselected_language])
        translated_text = translator(input_text, max_length=512)
        return translated_text[0]['translation_text'], message_text

    if model_name.startswith('facebook/mbart-large'):
        from languagecodes import mbart_large_languages
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
        model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        # mBART needs the source language set on the tokenizer and the target
        # language forced as the first generated token.
        tokenizer.src_lang = mbart_large_languages[sselected_language]
        encoded = tokenizer(input_text, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[mbart_large_languages[tselected_language]]
        )
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0], message_text

    if model_name.startswith('t5'):
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

    # T5 is text-to-text and expects a task prefix; Helsinki models translate
    # the raw input directly.
    if model_name.startswith("Helsinki-NLP"):
        prompt = input_text
    else:
        prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"

    # Bug fix: move the encoded input onto the model's device. The Helsinki
    # model may have been moved to CUDA (model_to_cuda) and T5 is loaded with
    # device_map="auto", while return_tensors="pt" defaults to CPU — without
    # this, generate() raises a device-mismatch error on GPU.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(input_ids, max_length=512)
    translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f'Translating from {sselected_language} to {tselected_language} with {model_name}:', f'{input_text} =  {translated_text}', sep='\n')
    return translated_text, message_text

# Callback for the swap button: exchanges the two dropdown selections.
def swap_languages(src_lang, tgt_lang):
    """Return the pair ``(tgt_lang, src_lang)`` so the dropdowns trade values."""
    swapped = (tgt_lang, src_lang)
    return swapped

def create_interface():
    """Assemble and return the Gradio Blocks UI for the translation demo.

    Returns:
        gr.Blocks: the built interface (caller is responsible for launching).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Machine Text Translation")

        with gr.Row():
            text_in = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")

        with gr.Row():
            src_dropdown = gr.Dropdown(choices=options, value=options[0], label="Source language", interactive=True)
            tgt_dropdown = gr.Dropdown(choices=options, value=options[1], label="Target language", interactive=True)
            swap_btn = gr.Button("Swap Languages")
            # Swapping writes each dropdown's value into the other.
            swap_btn.click(fn=swap_languages, inputs=[src_dropdown, tgt_dropdown], outputs=[src_dropdown, tgt_dropdown])

        model_dropdown = gr.Dropdown(choices=models, label="Select a model", value=models[4], interactive=True)
        translate_btn = gr.Button("Translate")

        result_box = gr.Textbox(label="Translated text:", interactive=False)
        status_box = gr.Textbox(label="Messages:", placeholder="Display status and error messages", interactive=False)

        translate_btn.click(
            translate_text,
            inputs=[text_in, src_dropdown, tgt_dropdown, model_dropdown],
            outputs=[result_box, status_box],
        )

    return demo

# Launch the Gradio interface: build the Blocks app, then start the web
# server (blocking call; module-level so Hugging Face Spaces runs it on import).
interface = create_interface()
interface.launch()