TiberiuCristianLeon commited on
Commit
59764d5
·
verified ·
1 Parent(s): b14e60e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -52
app.py CHANGED
@@ -1,59 +1,62 @@
1
- import streamlit as st
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- # Create the app layout
5
- st.header("Text Machine Translation")
6
- input_text = st.text_input("Enter text to translate:")
7
- # Create a list of options for the select box
8
- options = ["German", "Romanian", "English", "French", "Spanish"]
9
- langs = {"English":"en", "Romanian":"ro", "German":"de", "French":"fr", "Spanish":"es"}
10
- models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
11
 
12
- # Create two columns
13
- scol, tcol = st.columns(2)
14
- # Place select boxes in columns
15
- with scol:
16
- sselected_language = st.selectbox("Source language:", options, index=0, placeholder="Select source language")
17
- with tcol:
18
- tselected_language = st.selectbox("Target language:", options, index=1, placeholder="Select target language")
19
- model_name = st.selectbox("Select a model:", models, index=0, placeholder="Select language model")
20
-
21
- sl = langs[sselected_language]
22
- tl = langs[tselected_language]
23
-
24
- st.session_state["sselected_language"] = sselected_language
25
- st.session_state["tselected_language"] = tselected_language
26
- st.session_state["model_name"] = model_name
27
-
28
- if model_name == 'Helsinki-NLP':
29
- try:
30
- model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
31
- tokenizer = AutoTokenizer.from_pretrained(model_name)
32
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
33
- except EnvironmentError:
34
- model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
35
- tokenizer = AutoTokenizer.from_pretrained(model_name)
36
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
37
- else:
38
- tokenizer = T5Tokenizer.from_pretrained(model_name)
39
- model = T5ForConditionalGeneration.from_pretrained(model_name)
40
- st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
41
- submit_button = st.button("Translate")
42
- translated_textarea = st.text("")
43
-
44
- # Handle the submit button click
45
- if submit_button:
46
- if model_name.startswith('Helsinki-NLP'):
47
  prompt = input_text
48
  else:
49
- prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
50
- print(prompt)
51
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
52
- # Perform translation
53
  output_ids = model.generate(input_ids)
54
- # Decode the translated text
55
  translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
56
- # Display the translated text
57
- print(translated_text)
58
- st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}")
59
- translated_textarea = st.text(translated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ def translate_text(input_text, sselected_language, tselected_language, model_name):
5
+ langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es"}
6
+ sl = langs[sselected_language]
7
+ tl = langs[tselected_language]
 
 
 
8
 
9
+ if model_name == "Helsinki-NLP":
10
+ try:
11
+ model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name_full)
13
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
14
+ except EnvironmentError:
15
+ model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name_full)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
18
+ else:
19
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
20
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
21
+
22
+ if model_name.startswith("Helsinki-NLP"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  prompt = input_text
24
  else:
25
+ prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
26
+
27
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
28
  output_ids = model.generate(input_ids)
 
29
  translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
30
+
31
+ return translated_text
32
+
33
+ options = ["German", "Romanian", "English", "French", "Spanish"]
34
+ models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
35
+
36
+ def create_interface():
37
+ with gr.Blocks() as interface:
38
+ gr.Markdown("## Text Machine Translation")
39
+
40
+ with gr.Row():
41
+ input_text = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")
42
+
43
+ with gr.Row():
44
+ sselected_language = gr.Dropdown(options=options, value="English", label="Source language")
45
+ tselected_language = gr.Dropdown(options=options, value="German", label="Target language")
46
+
47
+ model_name = gr.Dropdown(options=models, value="Helsinki-NLP", label="Select a model")
48
+ translate_button = gr.Button("Translate")
49
+
50
+ translated_text = gr.Textbox(label="Translated text:", interactive=False)
51
+
52
+ translate_button.click(
53
+ translate_text,
54
+ inputs=[input_text, sselected_language, tselected_language, model_name],
55
+ outputs=translated_text
56
+ )
57
+
58
+ return interface
59
+
60
+ # Launch the Gradio interface
61
+ interface = create_interface()
62
+ interface.launch()