TiberiuCristianLeon commited on
Commit
c1b6fa4
·
verified ·
1 Parent(s): 6450772

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import streamlit as st
2
- from transformers import T5ForTranslation, T5Tokenizer
 
 
 
3
 
4
- # Load the pre-trained model and tokenizer
5
- model = T5ForTranslation.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
6
- tokenizer = T5Tokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
7
 
8
  # Create the app layout
9
  st.title("Text Translation App")
@@ -14,13 +16,15 @@ translated_text = st.text("")
14
  # Handle the submit button click
15
  if submit_button:
16
  # Encode the input text
17
- encoded = tokenizer(text_input, return_tensors="pt")
 
18
 
19
  # Perform translation
20
- translated = model.generate(**encoded)
 
21
 
22
  # Decode the translated text
23
- translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
24
 
25
  # Display the translated text
26
  st.write("Translated Text:", translated_text[0])
 
1
  import streamlit as st
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
+ model_name = 't5-base'
4
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
5
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
6
 
7
+ # model = T5ForTranslation.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
8
+ # tokenizer = T5Tokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
 
9
 
10
  # Create the app layout
11
  st.title("Text Translation App")
 
16
  # Handle the submit button click
17
  if submit_button:
18
  # Encode the input text
19
+ # encoded = tokenizer(text_input, return_tensors="pt")
20
+ input_ids = tokenizer.encode(input_text, return_tensors='pt')
21
 
22
  # Perform translation
23
+ output_ids = model.generate(input_ids)
24
+ translated = tokenizer.decode(output_ids[0], skip_special_tokens=True)
25
 
26
  # Decode the translated text
27
+ translated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
28
 
29
  # Display the translated text
30
  st.write("Translated Text:", translated_text[0])