File size: 6,465 Bytes
931b71f
215212f
a23a2a4
7bc13dc
215212f
 
 
978158a
3ddb276
0a9420e
 
fa1dbbc
0a9420e
 
 
 
 
 
 
 
 
 
011b5f0
931b71f
 
011b5f0
 
30f984e
011b5f0
fa1dbbc
 
011b5f0
 
 
 
 
 
d04a69f
fa1dbbc
d04a69f
 
 
 
011b5f0
 
 
59764d5
931b71f
011b5f0
 
 
 
d04a69f
011b5f0
 
d04a69f
931b71f
011b5f0
 
 
 
 
 
931b71f
d04a69f
fa1dbbc
 
931b71f
 
 
 
 
 
 
 
 
 
 
1f648dc
931b71f
978158a
 
931b71f
 
 
 
 
 
 
cd2dcf6
1f648dc
 
 
 
 
 
215212f
86f6a5a
84325a1
e057a26
0f908c5
1f648dc
931b71f
1f648dc
978158a
1f648dc
 
 
 
e165141
a1af82c
1f648dc
e165141
 
059e62b
1f648dc
 
e165141
 
 
0cacd2e
 
0a9420e
0cacd2e
0a9420e
0ec1817
0a9420e
 
 
0ec1817
3a7e27a
931b71f
 
059e62b
931b71f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import urllib.parse

import httpx
import streamlit as st
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging

# Silence transformers' info/warning chatter so it doesn't flood the console
logging.set_verbosity_error()
# Run local models on GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def download_argos_model(from_code, to_code):
    """Fetch and install the Argos Translate package for a language pair.

    Refreshes the remote package index, picks the package matching
    ``from_code`` -> ``to_code``, downloads it, and installs it locally.

    Raises StopIteration when no package exists for the pair — the caller
    catches this and shows a friendly message.
    """
    # Imported lazily so the app starts even without argostranslate installed
    import argostranslate.package

    print('Downloading model', from_code, to_code)
    argostranslate.package.update_package_index()
    candidates = argostranslate.package.get_available_packages()
    # next() on an exhausted generator raises StopIteration (handled by caller)
    wanted = next(
        pkg for pkg in candidates
        if pkg.from_code == from_code and pkg.to_code == to_code
    )
    argostranslate.package.install_from_path(wanted.download())

# App layout
st.header("Text Machine Translation")
input_text = st.text_input("Enter text to translate:")

# Language options and mappings.
# NOTE: the ORDER of `options` and `models` is load-bearing — the session-state
# defaults below index into them (options[0]/options[1], models[1]).
options = ["German", "Romanian", "English", "French", "Spanish", "Italian"]
# Display name -> ISO 639-1 code, used to build model ids and API queries
langs = {"English": "en", "Romanian": "ro", "German": "de", "French": "fr", "Spanish": "es", "Italian": "it"}
models = ["Helsinki-NLP", "Argos", "t5-base", "t5-small", "t5-large", "Unbabel/Tower-Plus-2B", 
          "Unbabel/TowerInstruct-Mistral-7B-v0.2", "Google"]

# Seed session state on first load; later reruns keep the stored values.
# Defaults: source=German, target=Romanian, model=Argos.
_state_defaults = {
    "sselected_language": options[0],
    "tselected_language": options[1],
    "model_name": models[1],
}
for _key, _default in _state_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Model selection FIRST
model_name = st.selectbox("Select a model:", models, 
                          index=models.index(st.session_state["model_name"]))

# Create columns for language selection: source | swap button | target
scol, swapcol, tcol = st.columns([3, 1, 3])

with scol:
    sselected_language = st.selectbox("Source language:", options, 
                                      index=options.index(st.session_state["sselected_language"]))
with swapcol:
    if st.button("🔄 Swap"):
        st.session_state["model_name"] = model_name  # Preserve model
        # Exchange source/target in session state, then rerun so both
        # selectboxes are rebuilt with the swapped indices
        st.session_state["sselected_language"], st.session_state["tselected_language"] = \
            st.session_state["tselected_language"], st.session_state["sselected_language"]
        st.rerun()
with tcol:
    tselected_language = st.selectbox("Target language:", options, 
                                      index=options.index(st.session_state["tselected_language"]))

# Persist the live widget selections so the next rerun starts from them
st.session_state["sselected_language"] = sselected_language
st.session_state["tselected_language"] = tselected_language
st.session_state["model_name"] = model_name

# Language codes — derived from the CURRENT widget values, not from session
# state. (Previously these were read from session state *before* it was
# updated, so the codes lagged one interaction behind a fresh selectbox
# change and the wrong model was loaded on that rerun.)
sl = langs[sselected_language]
tl = langs[tselected_language]

# Eagerly load tokenizer/model for the local seq2seq backends.
# NOTE: for Helsinki, model_name is REBUILT into a full HF repo id here;
# the branch dispatch in the submit handler relies on the resulting
# startswith('Helsinki-NLP') / startswith('t5') prefixes.
if model_name == 'Helsinki-NLP':
    try:
        # Pair-specific OPUS-MT checkpoint, e.g. Helsinki-NLP/opus-mt-de-en
        model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    except EnvironmentError:
        # Fall back to the Tatoeba-trained variant when no opus-mt pair exists
        # (from_pretrained raises OSError/EnvironmentError for a missing repo)
        model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
if model_name.startswith('t5'):
    # T5 runs on `device`; Helsinki above stays on CPU (inputs are CPU too)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

st.write("Selected language combination:", sselected_language, " - ", tselected_language, "Selected model:", model_name)
submit_button = st.button("Translate")
# Placeholder; replaced with the translation result after a successful submit
translated_textarea = st.text("")

# Handle the submit button click: dispatch on the selected backend,
# produce `translated_text`, and render it below the button.
if submit_button:
    # Fallback so the display code never hits a NameError if no branch matches
    translated_text = ""
    if model_name.startswith('Helsinki-NLP'):
        prompt = input_text
        print(prompt)
        input_ids = tokenizer.encode(prompt, return_tensors='pt')
        # Perform translation
        output_ids = model.generate(input_ids)
        # Decode the translated text
        translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    elif model_name.startswith('Google'): 
        # GCLIENT holds the base endpoint (incl. trailing '?' or '&').
        # URL-encode the user text so spaces, '&', '#', '+' etc. don't break
        # or truncate the query string.
        url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={urllib.parse.quote(input_text)}'
        response = httpx.get(url)
        # Response shape: [[[translation, source, ...], ...], ...]
        translated_text = response.json()[0][0][0]
        print(response.json()[0][0])
    elif model_name.startswith('t5'):
        # T5 uses a task-prefix prompt with full language names
        prompt = f'translate {sselected_language} to {tselected_language}: {input_text}'
        print(prompt)
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Perform translation
        output_ids = model.generate(input_ids)
        # Decode the translated text
        translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if 'Unbabel' in model_name:   
        pipe = pipeline("text-generation", model=model_name, torch_dtype=torch.bfloat16, device_map="auto")
        # We use the tokenizer’s chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
        messages = [{"role": "user",
                     "content": f"Translate the following text from {sselected_language} into {tselected_language}.\n{sselected_language}: {input_text}.\n{tselected_language}:"}]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
        translated_text = outputs[0]["generated_text"]
        # Strip the echoed prompt: keep only what follows the end-of-turn tag
        start_marker = "<end_of_turn>"
        if start_marker in translated_text:
            translated_text = translated_text.split(start_marker)[1].strip()
        # Some checkpoints prefix the reply with "Answer:" — drop it
        translated_text = translated_text.replace('Answer:', '').strip() if translated_text.startswith('Answer:') else translated_text
    if 'Argos' in model_name:   
        import argostranslate.translate       
        # Translate
        try:
            download_argos_model(sl, tl)
            translated_text = argostranslate.translate.translate(input_text, sl, tl)
        except StopIteration:
            # No package for this pair in the Argos index
            translated_text = f"No Argos model for {sselected_language} to {tselected_language}. Try other model or languages combination!"
        except Exception as error:
            # Show the error MESSAGE, not the exception object itself
            translated_text = str(error)
    # Display the translated text
    print(translated_text)
    st.write(f"Translated text from {sselected_language} to {tselected_language} using {model_name}:")
    translated_textarea = st.text(translated_text)