Update app.py
Browse files
app.py
CHANGED
@@ -70,12 +70,12 @@ st.session_state["model_name"] = model_name
|
|
70 |
if model_name == 'Helsinki-NLP':
|
71 |
try:
|
72 |
model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
|
73 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
74 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
75 |
except EnvironmentError:
|
76 |
model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
|
77 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
78 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
79 |
if model_name.startswith('t5'):
|
80 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
81 |
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
@@ -87,13 +87,18 @@ translated_textarea = st.text("")
|
|
87 |
# Handle the submit button click
|
88 |
if submit_button:
|
89 |
if model_name.startswith('Helsinki-NLP'):
|
90 |
-
prompt = input_text
|
91 |
-
print(prompt)
|
92 |
-
input_ids = tokenizer.encode(prompt, return_tensors='pt')
|
93 |
-
# Perform translation
|
94 |
-
output_ids = model.generate(input_ids)
|
95 |
-
# Decode the translated text
|
96 |
-
translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
97 |
elif model_name.startswith('Google'):
|
98 |
url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={input_text}'
|
99 |
response = httpx.get(url)
|
|
|
70 |
if model_name == 'Helsinki-NLP':
|
71 |
try:
|
72 |
model_name = f"Helsinki-NLP/opus-mt-{sl}-{tl}"
|
73 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
74 |
+
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
75 |
except EnvironmentError:
|
76 |
model_name = f"Helsinki-NLP/opus-tatoeba-{sl}-{tl}"
|
77 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name)
|
78 |
+
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
79 |
if model_name.startswith('t5'):
|
80 |
tokenizer = T5Tokenizer.from_pretrained(model_name)
|
81 |
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
|
|
87 |
# Handle the submit button click
|
88 |
if submit_button:
|
89 |
if model_name.startswith('Helsinki-NLP'):
|
90 |
+
# prompt = input_text
|
91 |
+
# print(prompt)
|
92 |
+
# input_ids = tokenizer.encode(prompt, return_tensors='pt')
|
93 |
+
# # Perform translation
|
94 |
+
# output_ids = model.generate(input_ids)
|
95 |
+
# # Decode the translated text
|
96 |
+
# translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
97 |
+
# Use a pipeline as a high-level helper
|
98 |
+
pipe = pipeline("translation", model=model_name)
|
99 |
+
translation = pipe(input_text)
|
100 |
+
translated_text = translation[0]['translation_text']
|
101 |
+
|
102 |
elif model_name.startswith('Google'):
|
103 |
url = os.environ['GCLIENT'] + f'sl={sl}&tl={tl}&q={input_text}'
|
104 |
response = httpx.get(url)
|