import streamlit as st
import streamlit.components.v1 as component
from googletrans import Translator
from model import load_model
# from huggingface_hub import snapshot_download
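# Sidebar controls: which checkpoint to use, the Sinhala seed text, and the
# sampling parameters that are passed to generate() below.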
page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
translator = Translator()
seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)
max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)
gen_bt = st.sidebar.button('Generate')
def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample `seq_num` continuations of `seed`, each up to `max_len` tokens long."""
    sentences = []
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    # Nucleus sampling with a mild temperature and a no-repeated-bigram constraint.
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,
        early_stopping=True
    )
    for output in outputs:
        sentences.append(tokenizer.decode(output, skip_special_tokens=True))
    return sentences
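# HTML "card" helpers, rendered through st.markdown with unsafe_allow_html=True.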
def html(body):
    st.markdown(body, unsafe_allow_html=True)

def card_begin_str(sinhala_sentence):
    return (
        "<style>div.card{background-color:#023b1d;border-radius: 5px;box-shadow: 0 4px 8px 0 rgba(0,0,0,0.2);transition: 0.3s;} small{ margin: 5px;}</style>"
        '<div class="card">'
        '<div class="container">'
        f'<small class="text-white-400">{sinhala_sentence}</small>'
    )

def card_end_str():
    return "</div></div>"

def card(sinhala_sentence, english_sentence):
    lines = [card_begin_str(sinhala_sentence), f"<p>{english_sentence}</p>", card_end_str()]
    html("".join(lines))

def br(n):
    html(n * "<br>")
def card_html(sinhala_sentence, english_sentence):
    with open('./app.css') as f:
        css_file = f.read()
    return component.html(
        f"""
        <style>{css_file}</style>
        <article class="class_1 bg-white rounded-lg p-4 relative">
            <p class="font-bold items-center text-sm text-primary relative mb-1">{sinhala_sentence}</p>
            <div class="flex items-center text-white-400 mb-4">
                <i class="fab fa-google mx-2"></i>
                <small class="text-white-400">English Translations are by Google Translate</small>
            </div>
            <p class="not-italic items-center text-sm text-primary relative mb-4">
                {english_sentence}
            </p>
        </article>
        """
    )
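# Page body: load the selected checkpoint and, when "Generate" is pressed, show
# each Sinhala sequence in a card followed by its Google-translated English version.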
if page == 'Pretrained GPT2':
    st.title('Sinhala Text generation with GPT2')
    st.markdown('A simple demo using the [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during the HF/Flax community week')
    model, tokenizer = load_model('flax-community/Sinhala-gpt2')
    if gen_bt:
        try:
            with st.spinner('Generating...'):
                # generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                # seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
                seqs = generate(model, tokenizer, seed, seq_num, max_len)
            st.warning("English sentences were translated by Google Translate.")
            for i, seq in enumerate(seqs):
                english_sentence = translator.translate(seq, src='si', dest='en').text
                # card(seq, english_sentence)
                html(card_begin_str(seq))
                st.info(english_sentence)
                html(card_end_str())
        except Exception as e:
            st.exception(e)
else:
    st.title('Sinhala Text generation with Finetuned GPT2')
    st.markdown('This is the [Sinhala-gpt2 model finetuned](https://huggingface.co/keshan/sinhala-gpt2-newswire) on 6,000 news articles (~12 MB)')
    model, tokenizer = load_model('keshan/sinhala-gpt2-newswire')
    if gen_bt:
        try:
            with st.spinner('Generating...'):
                # generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                # seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
                seqs = generate(model, tokenizer, seed, seq_num, max_len)
            st.warning("English sentences were translated by Google Translate.")
            for i, seq in enumerate(seqs):
                # st.info(f'Generated sequence {i+1}:')
                # st.write(seq)
                # st.info('English translation (by Google Translate):')
                # st.write(translator.translate(seq, src='si', dest='en').text)
                english_sentence = translator.translate(seq, src='si', dest='en').text
                # card(seq, english_sentence)
                html(card_begin_str(seq))
                st.info(english_sentence)
                html(card_end_str())
        except Exception as e:
            st.exception(e)
st.markdown('____________')