TiberiuCristianLeon commited on
Commit
56f497c
·
verified ·
1 Parent(s): 247101d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
3
+
4
+ def translate_text(input_text, sselected_language, tselected_language, model_name):
5
+ langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es"}
6
+ sl = langs[sselected_language]
7
+ tl = langs[tselected_language]
8
+
9
+ if model_name == "Helsinki-NLP":
10
+ try:
11
+ model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name_full)
13
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
14
+ except EnvironmentError:
15
+ model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name_full)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
18
+ else:
19
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
20
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
21
+
22
+ if model_name.startswith("Helsinki-NLP"):
23
+ prompt = input_text
24
+ else:
25
+ prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
26
+
27
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
28
+ output_ids = model.generate(input_ids)
29
+ translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
30
+
31
+ return translated_text
32
+
33
+ options = ["German", "Romanian", "English", "French", "Spanish"]
34
+ models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
35
+
36
+ def create_interface():
37
+ with gr.Blocks() as interface:
38
+ gr.Markdown("## Text Machine Translation")
39
+
40
+ with gr.Row():
41
+ input_text = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")
42
+
43
+ with gr.Row():
44
+ sselected_language = gr.Dropdown(choices=options, value="English", label="Source language")
45
+ tselected_language = gr.Dropdown(choices=options, value="German", label="Target language")
46
+
47
+ model_name = gr.Dropdown(choices=models, value="Helsinki-NLP", label="Select a model")
48
+ translate_button = gr.Button("Translate")
49
+
50
+ translated_text = gr.Textbox(label="Translated text:", interactive=False)
51
+
52
+ translate_button.click(
53
+ translate_text,
54
+ inputs=[input_text, sselected_language, tselected_language, model_name],
55
+ outputs=translated_text
56
+ )
57
+
58
+ return interface
59
+
60
+ # Launch the Gradio interface
61
+ interface = create_interface()
62
+ interface.launch()