keshan's picture
updating model loading msg
8f192a0
raw
history blame
1.04 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
st.title('Sinhala Text generation with GPT2')
st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')

# Generation controls. Widget order here defines the on-page layout.
seed = st.text_input(label='Starting text', value='ආයුබෝවන්')
seq_num = st.number_input(
    label='Number of sentences to generate ',
    min_value=1,
    max_value=20,
    value=5,
)
max_len = st.number_input(
    label='Length of the sentence ',
    min_value=5,
    max_value=300,
    value=100,
)
# True only on the rerun triggered by clicking the button.
go = st.button(label='Generate')
# Streamlit re-executes this entire script on every widget interaction, so
# loading the model at top level would re-initialise it on every rerun.
# Cache the loaded objects across reruns instead.
if hasattr(st, 'cache_resource'):
    # Newer Streamlit: dedicated cache for global, unhashable resources.
    _cache_model = st.cache_resource(show_spinner=False)
else:
    # Older Streamlit: legacy cache. allow_output_mutation=True skips the
    # mutation check, which would otherwise hash the returned model on reads.
    _cache_model = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_model
def _load_model(model_name):
    """Load and return ``(model, tokenizer)`` for *model_name*.

    The result is cached by Streamlit, so the checkpoint is only loaded once
    per process rather than on every script rerun. The cache's own spinner is
    disabled because the caller already shows one.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


with st.spinner('Waiting for the model to load.....'):
    model, tokenizer = _load_model('flax-community/Sinhala-gpt2')
st.success('Model loaded!!')
# Run generation only on the rerun triggered by the "Generate" button.
if go:
    try:
        generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
        # max_length and num_return_sequences come straight from the inputs above.
        seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
    except Exception as e:
        # Top-level UI boundary: report the failure in the app instead of
        # crashing. st.exception needs the exception object itself so it can
        # show the real exception type and traceback; a formatted string
        # loses both.
        st.exception(e)
    else:
        st.write(seqs)
# Footer: a horizontal-rule-like separator followed by the attribution line.
for footer_line in ('____________', 'by Keshan with Flax Community'):
    st.markdown(footer_line)