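"""Streamlit app for text machine translation across several selectable backends."""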
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging, AutoModelForCausalLM
import torch
import os
import httpx
from urllib.parse import quote
logging.set_verbosity_error()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def download_argos_model(from_code, to_code):
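    """Download and install the Argos Translate package for the from_code -> to_code pair.

    Raises StopIteration (via next) when no package exists for the pair.
    """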
    import argostranslate.package
    print('Downloading model', from_code, to_code)
    # Download and install Argos Translate package
    argostranslate.package.update_package_index()
    available_packages = argostranslate.package.get_available_packages()
    package_to_install = next(
        filter(
            lambda x: x.from_code == from_code and x.to_code == to_code, available_packages
        )
    )
    argostranslate.package.install_from_path(package_to_install.download())
def wingpt(model_name, sl, tl, input_text):
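    """Translate input_text into the tl language with the WiNGPT-Babel chat model.

    The sl argument is accepted for symmetry but unused: the prompt only names
    the target language. The model is reloaded on every call; caching it (for
    example with st.cache_resource) would avoid repeated loads.
    """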
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [
        {"role": "system", "content": f"Translate this to {tl} language"},
        {"role": "user", "content": input_text}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.1
    )
    # Strip the prompt tokens so only the completion is decoded
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    rawresult = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(rawresult)
    # The model sometimes prefixes its answer; keep only the last line
    result = rawresult.split('\n')[-1].strip()
    return result
# App layout
st.header("Text Machine Translation")
input_text = st.text_input("Enter text to translate:")
# Language options and mappings
options = ["German", "Romanian", "English", "French", "Spanish", "Italian"]
langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es", "Italian": "it"}
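# Backends: Helsinki-NLP MarianMT (with an opus-tatoeba fallback), offline
# Argos Translate packages, T5 prompting, Unbabel Tower chat models,
# WiNGPT-Babel-2, and a Google endpoint configured via the GCLIENT secret.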
models = ["Helsinki-NLP", "Argos", "t5-base", "t5-small", "t5-large", "Unbabel/Tower-Plus-2B",
"Unbabel/TowerInstruct-Mistral-7B-v0.2", "winninghealth/WiNGPT-Babel-2", "Google"]
# Initialize session state if not already set
if "sselected_language" not in st.session_state:
    st.session_state["sselected_language"] = options[0]
if "tselected_language" not in st.session_state:
    st.session_state["tselected_language"] = options[1]
if "model_name" not in st.session_state:
    st.session_state["model_name"] = models[1]
# Model selection FIRST
model_name = st.selectbox("Select a model:", models,
                          index=models.index(st.session_state["model_name"]))
# Create columns for language selection
scol, swapcol, tcol = st.columns([3, 1, 3])
with scol:
    sselected_language = st.selectbox("Source language:", options,
                                      index=options.index(st.session_state["sselected_language"]))
with swapcol:
    if st.button("🔄 Swap"):
        st.session_state["model_name"] = model_name  # Preserve model
        st.session_state["sselected_language"], st.session_state["tselected_language"] = \
            st.session_state["tselected_language"], st.session_state["sselected_language"]
        st.rerun()
with tcol:
    tselected_language = st.selectbox("Target language:", options,
                                      index=options.index(st.session_state["tselected_language"]))
# Store selections
st.session_state["sselected_language"] = sselected_language
st.session_state["tselected_language"] = tselected_language
st.session_state["model_name"] = model_name
# Language codes (taken from the current widget values, not from session
# state, which still holds last run's selections at this point)
sl = langs[sselected_language]
tl = langs[tselected_language]
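# Helsinki-NLP and T5 models are loaded eagerly here, before the Translate
# button is handled, so their tokenizer/pipeline objects exist at submit time.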
if model_name == 'Helsinki-NLP':
    try:
        model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("translation", model=model, tokenizer=tokenizer)
    except OSError:
        # Some pairs are only published under the opus-tatoeba name
        model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("translation", model=model, tokenizer=tokenizer)
if model_name.startswith('t5'):
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
submit_button = st.button("Translate")
translated_textarea = st.text("")
# Handle the submit button click
if submit_button:
    if model_name.startswith('Helsinki-NLP'):
        # Use the translation pipeline built above as a high-level helper
        translation = pipe(input_text)
        translated_text = translation[0]['translation_text']
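    # GCLIENT (an environment secret) is assumed to hold the translation
    # endpoint base URL, ending with '?' or '&' so the query string below
    # can be appended directly.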
    elif model_name.startswith('Google'):
        url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={quote(input_text)}'
        response = httpx.get(url)
        translated_text = response.json()[0][0][0]
        print(response.json()[0][0])
    elif model_name.startswith('t5'):
        prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
        print(prompt)
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Perform translation
        output_ids = model.generate(input_ids)
        # Decode the translated text
        translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    elif 'Unbabel' in model_name:
        pipe = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
        # We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
        messages = [{"role": "user",
                     "content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"}]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
        translated_text = outputs[0]["generated_text"]
        # The pipeline returns prompt + completion; keep only the text after
        # the end-of-turn marker emitted by Gemma-based Tower checkpoints
        start_marker = "<end_of_turn>"
        if start_marker in translated_text:
            translated_text = translated_text.split(start_marker)[1].strip()
        translated_text = translated_text.removeprefix('Answer:').strip()
    elif 'Argos' in model_name:
        import argostranslate.translate
        try:
            download_argos_model(sl, tl)
            translated_text = argostranslate.translate.translate(input_text, sl, tl)
        except StopIteration:
            translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try another model or language combination!"
        except Exception as error:
            translated_text = str(error)
    elif model_name == "winninghealth/WiNGPT-Babel-2":
        translated_text = wingpt(model_name, sselected_language, tselected_language, input_text)
    # Display the translated text
    print(translated_text)
    st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
    translated_textarea = st.text(translated_text)