TiberiuCristianLeon commited on
Commit
931b71f
·
verified ·
1 Parent(s): ca58f82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -55
app.py CHANGED
@@ -1,62 +1,59 @@
1
- import gradio as gr
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- def translate_text(input_text, sselected_language, tselected_language, model_name):
5
- langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es"}
6
- sl = langs[sselected_language]
7
- tl = langs[tselected_language]
8
-
9
- if model_name == "Helsinki-NLP":
10
- try:
11
- model_name_full = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
12
- tokenizer = AutoTokenizer.from_pretrained(model_name_full)
13
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
14
- except EnvironmentError:
15
- model_name_full = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
16
- tokenizer = AutoTokenizer.from_pretrained(model_name_full)
17
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name_full)
18
- else:
19
- tokenizer = T5Tokenizer.from_pretrained(model_name)
20
- model = T5ForConditionalGeneration.from_pretrained(model_name)
21
 
22
- if model_name.startswith("Helsinki-NLP"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  prompt = input_text
24
  else:
25
- prompt = f"translate {sselected_language} to {tselected_language}: {input_text}"
26
-
27
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
28
  output_ids = model.generate(input_ids)
 
29
  translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
30
-
31
- return translated_text
32
-
33
- options = ["German", "Romanian", "English", "French", "Spanish"]
34
- models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
35
-
36
- def create_interface():
37
- with gr.Blocks() as interface:
38
- gr.Markdown("## Text Machine Translation")
39
-
40
- with gr.Row():
41
- input_text = gr.Textbox(label="Enter text to translate:", placeholder="Type your text here...")
42
-
43
- with gr.Row():
44
- sselected_language = gr.Dropdown(choices=options, value="English", label="Source language")
45
- tselected_language = gr.Dropdown(choices=options, value="German", label="Target language")
46
-
47
- model_name = gr.Dropdown(choices=models, value="Helsinki-NLP", label="Select a model")
48
- translate_button = gr.Button("Translate")
49
-
50
- translated_text = gr.Textbox(label="Translated text:", interactive=False)
51
-
52
- translate_button.click(
53
- translate_text,
54
- inputs=[input_text, sselected_language, tselected_language, model_name],
55
- outputs=translated_text
56
- )
57
-
58
- return interface
59
-
60
- # Launch the Gradio interface
61
- interface = create_interface()
62
- interface.launch()
 
1
+ import streamlit as st
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # Create the app layout
5
+ st.header("Text Machine Translation")
6
+ input_text = st.text_input("Enter text to translate:")
7
+ # Create a list of options for the select box
8
+ options = ["German", "Romanian", "English", "French", "Spanish"]
9
+ langs = {"English":"en", "Romanian":"ro", "German":"de", "French":"fr", "Spanish":"es"}
10
+ models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Create two columns
13
+ scol, tcol = st.columns(2)
14
+ # Place select boxes in columns
15
+ with scol:
16
+ sselected_language = st.selectbox("Source language:", options, index=0, placeholder="Select source language")
17
+ with tcol:
18
+ tselected_language = st.selectbox("Target language:", options, index=1, placeholder="Select target language")
19
+ model_name = st.selectbox("Select a model:", models, index=0, placeholder="Select language model")
20
+
21
+ sl = langs[sselected_language]
22
+ tl = langs[tselected_language]
23
+
24
+ st.session_state["sselected_language"] = sselected_language
25
+ st.session_state["tselected_language"] = tselected_language
26
+ st.session_state["model_name"] = model_name
27
+
28
+ if model_name == 'Helsinki-NLP':
29
+ try:
30
+ model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
32
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
33
+ except EnvironmentError:
34
+ model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
35
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
37
+ else:
38
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
39
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
40
+ st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
41
+ submit_button = st.button("Translate")
42
+ translated_textarea = st.text("")
43
+
44
+ # Handle the submit button click
45
+ if submit_button:
46
+ if model_name.startswith('Helsinki-NLP'):
47
  prompt = input_text
48
  else:
49
+ prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
50
+ print(prompt)
51
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
52
+ # Perform translation
53
  output_ids = model.generate(input_ids)
54
+ # Decode the translated text
55
  translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
56
+ # Display the translated text
57
+ print(translated_text)
58
+ st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}")
59
+ translated_textarea = st.text(translated_text)