Spaces:
Runtime error
Runtime error
import streamlit as st | |
import streamlit.components.v1 as component | |
from googletrans import Translator | |
from model import load_model | |
# from huggingface_hub import snapshot_download | |
# Shared Google Translate client, used for the English rendering of each
# generated sequence.
translator = Translator()

# Sidebar controls. Creation order is kept as-is because it determines the
# order in which the widgets appear in the sidebar.
page = st.sidebar.selectbox(
    "Model ",
    options=["Finetuned on News data", "Pretrained GPT2"],
)
seed = st.sidebar.text_input('Starting text', value='ආයුබෝවන්')
seq_num = st.sidebar.number_input(
    'Number of sequences to generate ',
    min_value=1,
    max_value=20,
    value=5,
)
max_len = st.sidebar.number_input(
    'Length of a sequence ',
    min_value=5,
    max_value=300,
    value=100,
)
gen_bt = st.sidebar.button('Generate')
def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample ``seq_num`` continuations of ``seed`` and return them as text.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method.
        tokenizer: Object exposing ``encode`` and ``decode``.
        seed: Prompt text that is encoded and handed to the model.
        seq_num: Forwarded as ``num_return_sequences``.
        max_len: Forwarded as ``max_length``.

    Returns:
        A list with one decoded string per sequence returned by the model,
        decoded with ``skip_special_tokens=True``.
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    # Pure sampling (top-k / nucleus). ``early_stopping`` is intentionally not
    # passed: it only affects beam search, and no beams are requested here.
    sampled_outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,
    )
    return [
        tokenizer.decode(token_ids, skip_special_tokens=True)
        for token_ids in sampled_outputs
    ]
def html(body):
    """Render ``body`` through ``st.markdown`` with raw HTML enabled."""
    st.markdown(body=body, unsafe_allow_html=True)
def card_begin_str(Sinhala_sentence):
    """Return the opening markup of a result card containing the sentence.

    The returned string holds the card's inline ``<style>`` rule, two opening
    ``<div>`` tags (left open on purpose; ``card_end_str`` supplies the
    closing tags) and the sentence inside a ``<small>`` element.

    Args:
        Sinhala_sentence: Text to show in the card. It is HTML-escaped before
            interpolation so characters such as ``<`` or ``&`` in generated
            text cannot break or inject markup.

    Returns:
        The HTML fragment as a single string.
    """
    # Imported locally: this file defines its own ``html()`` function, which
    # would shadow a module-level ``import html``.
    from html import escape

    safe_sentence = escape(str(Sinhala_sentence))
    return (
        "<style>div.card{background-color:#023b1d;border-radius: 5px;box-shadow: 0 4px 8px 0 rgba(0,0,0,0.2);transition: 0.3s;} small{ margin: 5px;}</style>"
        '<div class="card">'
        '<div class="container">'
        f'<small class="text-white-400">{safe_sentence}</small>'
    )
def card_end_str():
    """Return two closing ``</div>`` tags, one per div opened for a card."""
    closing_div = "</div>"
    return closing_div * 2
def card(sinhala_sentence, english_sentence):
    """Render one complete card: the Sinhala text, then its translation.

    Args:
        sinhala_sentence: Text handed unchanged to ``card_begin_str``.
        english_sentence: Translation placed in a ``<p>`` element. It is
            HTML-escaped here because it comes from an external service and
            the result is rendered as raw HTML.
    """
    # Imported locally: a module-level ``import html`` would be shadowed by
    # this file's own ``html()`` function, which is called below.
    from html import escape

    translation = f"<p>{escape(str(english_sentence))}</p>"
    html(card_begin_str(sinhala_sentence) + translation + card_end_str())
def br(n):
    """Emit ``n`` consecutive ``<br>`` tags through ``html``."""
    line_breaks = "<br>" * n
    html(line_breaks)
def card_html(sinhala_sentence, english_sentence):
    """Render a styled card for a sentence pair via ``component.html``.

    The stylesheet is read from ``./app.css`` (relative to the current working
    directory) on every call and inlined into the markup.

    Args:
        sinhala_sentence: Text shown in the card's first paragraph.
        english_sentence: Text shown in the card's last paragraph.

    Returns:
        Whatever ``component.html`` returns for the assembled markup.

    Raises:
        OSError: If ``./app.css`` cannot be opened or read.
    """
    # Imported locally: a module-level ``import html`` would be shadowed by
    # this file's own ``html()`` function.
    from html import escape

    # Explicit encoding so the read does not depend on the platform default.
    with open('./app.css', encoding='utf-8') as f:
        css_file = f.read()

    # Both texts are escaped so they cannot break or inject markup.
    sinhala_text = escape(str(sinhala_sentence))
    english_text = escape(str(english_sentence))
    return component.html(
        f"""
        <style>{css_file}</style>
        <article class="class_1 bg-white rounded-lg p-4 relative">
        <p class="font-bold items-center text-sm text-primary relative mb-1">{sinhala_text}</p>
        <div class="flex items-center text-white-400 mb-4">
        <i class="fab fa-google mx-2"></i>
        <small class="text-white-400">English Translations are by Google Translate</small>
        </div>
        <p class="not-italic items-center text-sm text-primary relative mb-4">
        {english_text}
        </p>
        </article>
        """
    )
# The two pages differ only in title, description and checkpoint name, so
# those are selected here and the rest of the flow is shared.
if page == 'Pretrained GPT2':
    page_title = 'Sinhala Text generation with GPT2'
    page_blurb = 'A simple demo using [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during hf-flax week'
    model_name = 'flax-community/Sinhala-gpt2'
else:
    page_title = 'Sinhala Text generation with Finetuned GPT2'
    page_blurb = 'This model has been [finetuned Sinhala-gpt2 model](https://huggingface.co/keshan/sinhala-gpt2-newswire) with 6000 news articles(~12MB)'
    model_name = 'keshan/sinhala-gpt2-newswire'

st.title(page_title)
st.markdown(page_blurb)
model, tokenizer = load_model(model_name)

if gen_bt:
    try:
        # NOTE(review): original indentation was lost; the spinner is assumed
        # to cover the generation call only - confirm.
        with st.spinner('Generating...'):
            seqs = generate(model, tokenizer, seed, seq_num, max_len)
        st.warning("English sentences were translated by Google Translate.")
        for seq in seqs:
            english_sentence = translator.translate(seq, src='si', dest='en').text
            html(card_begin_str(seq))
            st.info(english_sentence)
            html(card_end_str())
    except Exception as e:
        # Top-level UI boundary: show any failure on the page instead of
        # crashing the app. The exception object itself is passed so
        # Streamlit can display its type and traceback.
        st.exception(e)

# NOTE(review): original indentation was lost; this divider is assumed to be
# rendered for both pages - confirm.
st.markdown('____________')