# T5-small / app.py
# wang2246478872's picture
# init
# df3368c
# raw
# history blame
# 944 Bytes
import streamlit as st
from transformers import T5Tokenizer, T5ForConditionalGeneration
@st.cache_resource
def init_model():
    """Load the pretrained "t5-small" model and its matching tokenizer.

    The function is wrapped in ``st.cache_resource``, so the loading work is
    meant to happen once and be reused rather than repeated on every rerun.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    checkpoint = "t5-small"
    t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    return t5_model, t5_tokenizer
# Token budgets for the model. Input longer than max_source_length is
# truncated by the tokenizer; generation is capped at max_target_length.
max_source_length = 512
max_target_length = 128

model, tokenizer = init_model()

st.title('T5-Small')

with st.form('my_form'):
    text = st.text_area('Enter text:', '')
    cols = st.columns(3)
    submitted = cols[0].form_submit_button('translate')
    # The prefix is prepended verbatim to the user's text before tokenizing.
    # NOTE(review): the public t5-small checkpoint is documented for
    # English -> German/French/Romanian translation; confirm that this
    # default prefix is the intended one.
    task_prefix = cols[1].text_input("input language", "translate Chinese to English: ")

# Empty slot that is filled with the model output once it is available.
placeholder = st.markdown("", unsafe_allow_html=True)

if submitted:
    if not text.strip():
        # Nothing to translate: skip the model call instead of generating
        # output from the task prefix alone.
        st.warning("Please enter some text to translate.")
    else:
        with st.spinner("Translating..."):
            # Truncate over-long input so the encoder never receives more
            # than max_source_length tokens.
            input_ids = tokenizer(
                f"{task_prefix}{text}",
                return_tensors="pt",
                max_length=max_source_length,
                truncation=True,
            ).input_ids
            # Without an explicit limit, generate() uses the library's
            # default generation length, which cuts longer outputs short.
            outputs = model.generate(input_ids, max_length=max_target_length)
        placeholder.markdown(tokenizer.decode(outputs[0], skip_special_tokens=True))