File size: 8,157 Bytes
931b71f 5e9286a a23a2a4 7bc13dc 215212f 978158a 3ddb276 0a9420e fa1dbbc 0a9420e 5f86d7f 8ba2e89 a68ab45 8ba2e89 a68ab45 8ba2e89 a68ab45 8ba2e89 a68ab45 8ba2e89 563ff16 5f86d7f 011b5f0 931b71f 011b5f0 30f984e 011b5f0 fa1dbbc 7f87809 011b5f0 d04a69f fa1dbbc d04a69f 011b5f0 59764d5 931b71f 011b5f0 d04a69f 011b5f0 d04a69f 931b71f 011b5f0 931b71f d04a69f fa1dbbc 931b71f 181890d bd94fe4 931b71f 181890d 1f648dc 931b71f 978158a 931b71f 4d3b257 215212f 86f6a5a 84325a1 e057a26 0f908c5 1f648dc 931b71f 1f648dc 978158a 1f648dc a68ab45 a1af82c 1f648dc e165141 059e62b 1f648dc e165141 0cacd2e a68ab45 0a9420e 0cacd2e 0a9420e 0ec1817 0a9420e 0ec1817 3a7e27a a68ab45 5f86d7f 11103fd 931b71f 059e62b 931b71f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging, AutoModelForCausalLM
import torch
import os
import httpx
# Silence transformers' info/warning logs so they don't clutter the app output.
logging.set_verbosity_error()
# Prefer GPU when available; used by the t5 branches below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def download_argos_model(from_code, to_code):
    """Fetch and install the Argos Translate package for a language pair.

    Raises StopIteration when no package exists for the requested pair;
    the caller catches this to report an unsupported combination.
    """
    import argostranslate.package
    print('Downloading model', from_code, to_code)
    # Refresh the remote package index, then locate the matching package.
    argostranslate.package.update_package_index()
    available_packages = argostranslate.package.get_available_packages()
    package_to_install = next(
        pkg for pkg in available_packages
        if pkg.from_code == from_code and pkg.to_code == to_code
    )
    # Download the package archive and install it locally.
    argostranslate.package.install_from_path(package_to_install.download())
def wingpt(model_name, sl, tl, input_text):
    """Translate ``input_text`` into language ``tl`` with a WiNGPT chat model.

    Parameters
    ----------
    model_name : str
        Hugging Face model id (e.g. "winninghealth/WiNGPT-Babel-2").
    sl : str
        Source language name. Currently unused — the model infers the
        source language — but kept for call-site compatibility.
    tl : str
        Target language name, injected into the system prompt.
    input_text : str
        Text to translate.

    Returns
    -------
    str
        The translation: the last line of the decoded model output.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = [
        {"role": "system", "content": f"Translate this to {tl} language"},
        {"role": "user", "content": input_text},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Greedy decoding (do_sample defaults to False); a `temperature` kwarg
    # would be ignored here, so none is passed.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
    )
    # Strip the prompt tokens so only newly generated text remains.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    rawresult = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # The model may emit preamble on earlier lines; keep only the final one.
    # (str.split on an absent separator returns the whole string, so no
    # separate "no newline" case is needed.)
    return rawresult.split('\n')[-1].strip()
# App layout
st.header("Text Machine Translation")
input_text = st.text_input("Enter text to translate:")
# Language options and mappings (display name -> ISO 639-1 code)
options = ["German", "Romanian", "English", "French", "Spanish", "Italian"]
langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es", "Italian": "it"}
models = ["Helsinki-NLP", "Argos", "t5-base", "t5-small", "t5-large", "Unbabel/Tower-Plus-2B",
          "Unbabel/TowerInstruct-Mistral-7B-v0.2", "winninghealth/WiNGPT-Babel-2", "Google"]
# Initialize session state if not already set.
# These keys persist the user's choices across Streamlit reruns;
# defaults: German -> Romanian using the Argos model.
if "sselected_language" not in st.session_state:
    st.session_state["sselected_language"] = options[0]
if "tselected_language" not in st.session_state:
    st.session_state["tselected_language"] = options[1]
if "model_name" not in st.session_state:
    st.session_state["model_name"] = models[1]
# Model selection FIRST so the swap button can preserve the current choice
model_name = st.selectbox("Select a model:", models,
                          index=models.index(st.session_state["model_name"]))
# Create columns for language selection
scol, swapcol, tcol = st.columns([3, 1, 3])
with scol:
    sselected_language = st.selectbox("Source language:", options,
                                      index=options.index(st.session_state["sselected_language"]))
with swapcol:
    if st.button("🔄 Swap"):
        st.session_state["model_name"] = model_name  # Preserve model
        st.session_state["sselected_language"], st.session_state["tselected_language"] = \
            st.session_state["tselected_language"], st.session_state["sselected_language"]
        st.rerun()
with tcol:
    tselected_language = st.selectbox("Target language:", options,
                                      index=options.index(st.session_state["tselected_language"]))
# Persist the current widget selections for the next rerun
st.session_state["sselected_language"] = sselected_language
st.session_state["tselected_language"] = tselected_language
st.session_state["model_name"] = model_name
# Language codes: derive from the CURRENT widget values, not from the
# pre-update session state — the previous ordering read session state
# before it was synced, so the codes lagged one interaction behind.
sl = langs[sselected_language]
tl = langs[tselected_language]
# Build tokenizer/model/pipeline eagerly for the model families that need it.
# NOTE(review): these download/load on every Streamlit rerun — consider
# st.cache_resource to avoid reloading; confirm before changing.
if model_name == 'Helsinki-NLP':
    try:
        # Pair-specific OPUS-MT checkpoint, e.g. Helsinki-NLP/opus-mt-de-ro.
        # model_name is reassigned here so the submit handler's
        # startswith('Helsinki-NLP') check still matches.
        model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("translation", model=model, tokenizer=tokenizer)
    except (EnvironmentError, OSError):
        # No opus-mt checkpoint for this pair — fall back to the
        # Tatoeba-trained variant (raises again if that is also missing).
        model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = pipeline("translation", model=model, tokenizer=tokenizer)
if model_name.startswith('t5'):
    # t5-* checkpoints use the dedicated T5 classes and the explicit device.
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
submit_button = st.button("Translate")
translated_textarea = st.text("")
# Handle the submit button click: dispatch to the selected backend,
# then render the translation.
if submit_button:
    if model_name.startswith('Helsinki-NLP'):
        # Use the translation pipeline built above as a high-level helper
        translation = pipe(input_text)
        translated_text = translation[0]['translation_text']
    elif model_name.startswith('Google'):
        # GCLIENT holds the endpoint prefix (presumably ending in '?' or
        # '&' — TODO confirm). Percent-encode the user text so spaces and
        # special characters survive the query string.
        from urllib.parse import quote
        url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={quote(input_text)}'
        response = httpx.get(url)
        # Response shape: nested lists, first leaf is the translated text
        translated_text = response.json()[0][0][0]
    elif model_name.startswith('t5'):
        # T5 expects a task-prefixed prompt with full language names
        prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Perform translation
        output_ids = model.generate(input_ids)
        # Decode the translated text
        translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    elif 'Unbabel' in model_name:
        pipe = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
        # We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
        messages = [{"role": "user",
                     "content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"}]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
        translated_text = outputs[0]["generated_text"]
        # Keep only the model's answer after the end-of-turn marker
        start_marker = "<end_of_turn>"
        if start_marker in translated_text:
            translated_text = translated_text.split(start_marker)[1].strip()
        translated_text = translated_text.replace('Answer:', '').strip() if translated_text.startswith('Answer:') else translated_text
    elif 'Argos' in model_name:
        import argostranslate.translate
        # Translate
        try:
            download_argos_model(sl, tl)
            translated_text = argostranslate.translate.translate(input_text, sl, tl)
        except StopIteration:
            # download_argos_model found no package for this language pair
            translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
        except Exception as error:
            # Show the error message, not the exception object itself
            translated_text = str(error)
    elif model_name == "winninghealth/WiNGPT-Babel-2":
        translated_text = wingpt(model_name, sselected_language, tselected_language, input_text)
    # Display the translated text
    st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
    translated_textarea = st.text(translated_text)