Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,8 @@ st.header("Text Machine Translation")
|
|
6 |
input_text = st.text_input("Enter text to translate:")
|
7 |
# Create a list of options for the select box
|
8 |
options = ["German", "Romanian", "English", "French", "Spanish"]
|
9 |
-
langs = {"English":"en", "Romanian":"ro", "German":"de", "French":"fr", "Spanish":"es"}
|
10 |
-
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large"]
|
11 |
|
12 |
# Create two columns
|
13 |
scol, tcol = st.columns(2)
|
@@ -34,9 +34,10 @@ if model_name == 'Helsinki-NLP':
|
|
34 |
model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
|
35 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
36 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
37 |
-
|
38 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
39 |
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
|
|
40 |
st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
|
41 |
submit_button = st.button("Translate")
|
42 |
translated_textarea = st.text("")
|
@@ -45,14 +46,30 @@ translated_textarea = st.text("")
|
|
45 |
if submit_button:
|
46 |
if model_name.startswith('Helsinki-NLP'):
|
47 |
prompt = input_text
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# Display the translated text
|
57 |
print(translated_text)
|
58 |
st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}")
|
|
|
# --- User input and model/language configuration -----------------------------
input_text = st.text_input("Enter text to translate:")

# Options for the source/target language select boxes.
# NOTE: every entry here must have a matching code in `langs` below.
# FIX: "Italian" was present in `langs` ("it") but missing from `options`,
# making it unselectable dead data — added so the two stay consistent.
options = ["German", "Romanian", "English", "French", "Spanish", "Italian"]

# Display name -> ISO 639-1 code, used to build Helsinki-NLP repo ids.
langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es", "Italian": "it"}

# Available translation backends: Helsinki-NLP (per-pair seq2seq), the t5
# family (prompt-prefixed seq2seq), and TowerInstruct (chat-style LLM).
models = ["Helsinki-NLP", "t5-base", "t5-small", "t5-large", "Unbabel/TowerInstruct-7B-v0.2"]

# Create two columns (source-language picker / target-language picker).
scol, tcol = st.columns(2)
|
|
|
# --- Backend loading ---------------------------------------------------------
if model_name == 'Helsinki-NLP':
    # Helsinki models are published one repo per language pair; assemble the
    # repo id from the selected source/target codes and load via the generic
    # Auto* classes. (`sl`/`tl` are the ISO codes picked from `langs`.)
    model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
if model_name.startswith('t5'):
    # t5 checkpoints need the T5-specific classes. This cannot fire after the
    # Helsinki branch above (model_name was rewritten to "Helsinki-NLP/...").
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
# NOTE: the Tower model deliberately loads nothing here — it is constructed
# lazily via pipeline() in the translate handler below.

st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
submit_button = st.button("Translate")
translated_textarea = st.text("")
|
|
|
def _seq2seq_translate(prompt):
    """Tokenize *prompt*, generate with the already-loaded seq2seq model, and
    decode the result. Factored out because the Helsinki and t5 branches below
    previously duplicated these three lines verbatim."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Perform translation
    output_ids = model.generate(input_ids)
    # Decode the translated text
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


if submit_button:
    if model_name.startswith('Helsinki-NLP'):
        # Helsinki pair-models translate raw text directly; no task prefix.
        prompt = input_text
        print(prompt)
        translated_text = _seq2seq_translate(prompt)
    elif model_name.startswith('t5'):
        # t5 expects the task spelled out as a natural-language prefix.
        prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
        print(prompt)
        translated_text = _seq2seq_translate(prompt)
    else:
        # TowerInstruct: chat-style LLM driven through the text-generation
        # pipeline. Constructed lazily here so the 7B weights are only pulled
        # when this backend is actually used.
        pipe = pipeline("text-generation", model="Unbabel/TowerInstruct-7B-v0.2", torch_dtype=torch.bfloat16, device_map="auto")
        # We use the tokenizer's chat template to format each message - see
        # https://huggingface.co/docs/transformers/main/en/chat_templating
        messages = [
            {"role": "user", "content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"},
        ]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
        generated = outputs[0]["generated_text"]
        # FIX: the text-generation pipeline returns prompt + completion in
        # "generated_text"; strip the prompt so only the translation remains.
        translated_text = generated[len(prompt):].strip() if generated.startswith(prompt) else generated

    # Display the translated text
    print(translated_text)
    st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}")
    # NOTE(review): within this view translated_text was only print()ed to the
    # server console; render it in the UI as well — confirm this does not
    # duplicate a display below this chunk.
    st.write(translated_text)