Spaces:
Runtime error
Runtime error
File size: 3,989 Bytes
40e9898 36338f2 40e9898 fb12737 40e9898 36338f2 40e9898 36338f2 1a15874 36338f2 40e9898 36338f2 40e9898 36338f2 2b02259 40e9898 36338f2 e4461ed 6673aaa 1c9ee74 cab7f25 1c9ee74 6673aaa cab7f25 1c9ee74 cab7f25 6673aaa 36338f2 2b02259 40e9898 e4461ed 2b02259 40e9898 36338f2 e4461ed 36338f2 2b02259 e4461ed 36338f2 2b02259 e4461ed 36338f2 5a923ef 36338f2 5a923ef e4461ed 5a923ef e4461ed 5a923ef 36338f2 e4461ed 5a923ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
""" Script for streamlit demo
@author: AbinayaM02
"""
# Install necessary libraries
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
import streamlit as st
import json

# Read the config
# config.json is expected to provide: "models" (page options), one key per
# page listing its datasets, "examples" (default prompts), and a model-id
# entry per dataset — inferred from the lookups below; confirm against the
# actual config.json.
with open("config.json") as f:
    config = json.loads(f.read())
# Set page layout
st.set_page_config(
page_title="Tamil Language Models",
page_icon="U+270D",
layout="wide",
initial_sidebar_state="expanded"
)
# Load the model
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Download and cache the model and tokenizer for ``model_name``.

    Parameters
    ----------
    model_name : str
        Hugging Face Hub id of the model to load.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` ready to be passed to ``pipeline``.
    """
    # AutoModelWithLMHead is deprecated (and removed in recent transformers
    # releases); AutoModelForCausalLM is its documented replacement for
    # GPT-2-style generation models.
    from transformers import AutoModelForCausalLM

    with st.spinner('Waiting for the model to load.....'):
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
# Side bar
img = st.sidebar.image("images/tamil_logo.jpg", width=300)

# Choose the model based on selection
st.sidebar.title("கதை சொல்லி!")
page = st.sidebar.selectbox(label="Select model",
                            options=config["models"],
                            help="Select the model to generate the text")
data = st.sidebar.selectbox(label="Select data",
                            options=config[page],
                            help="Select the data on which the model is trained")

# wandb-run and model-card links for each known (model page, dataset) pair;
# unknown pairs get no sidebar links.
_SIDEBAR_LINKS = {
    ("Text Generation", "Oscar + IndicNLP"): (
        "[Model tracking on wandb](https://wandb.ai/wandb/hf-flax-gpt2-tamil/runs/watdq7ib/overview?workspace=user-abinayam)",
        "[Model card](https://huggingface.co/abinayam/gpt-2-tamil)",
    ),
    ("Text Generation", "Oscar"): (
        "[Model tracking on wandb](https://wandb.ai/abinayam/hf-flax-gpt-2-tamil/runs/1ddv4131/overview?workspace=user-abinayam)",
        "[Model card](https://huggingface.co/flax-community/gpt-2-tamil)",
    ),
}
for _link in _SIDEBAR_LINKS.get((page, data), ()):
    st.sidebar.markdown(_link, unsafe_allow_html=True)
# Main page
st.title("Tamil Language Demos")
st.markdown(
    "Built as part of the Flax/Jax Community week, this demo uses "
    "[GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
    # BUG FIX: the space between "]" and "(" broke the markdown link below.
    "and [GPT2 trained on Oscar & IndicNLP dataset](https://huggingface.co/abinayam/gpt-2-tamil) "
    "to show language generation!"
)

# Set default options for examples
prompts = config["examples"] + ["Custom"]

if page == 'Text Generation' and data == 'Oscar':
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar dataset!')
    model, tokenizer = load_model(config[data])
# BUG FIX: the selectbox value (see the sidebar comparison above) is
# "Oscar + IndicNLP"; the old check against "Oscar + Indic Corpus" could
# never match, so the second model was never loaded.
elif page == 'Text Generation' and data == "Oscar + IndicNLP":
    st.header('Tamil text generation with GPT2')
    st.markdown('A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset')
    model, tokenizer = load_model(config[data])
else:
    st.title('Tamil News classification with Finetuned GPT2')
    st.markdown('In progress')
if page == "Text Generation":
    # Set default options
    prompt = st.selectbox('Examples', prompts, index=0)
    if prompt == "Custom":
        text = st.text_input(
            'Add your custom text in Tamil',
            "",
            max_chars=1000)
        # BUG FIX: the original did `prompt_box = "",` — the trailing comma
        # made it a 1-tuple and the user's typed text was ignored; generate
        # from the custom text instead.
        prompt_box = text
    else:
        prompt_box = prompt
        text = st.text_input(
            'Selected example in Tamil',
            prompt,
            max_chars=1000)
    max_len = st.slider('Select length of the sentence to generate', 25, 300, 100)
    gen_bt = st.button('Generate')
    # Generate text
    if gen_bt:
        try:
            with st.spinner('Generating...'):
                generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
                seqs = generator(prompt_box, max_length=max_len)[0]['generated_text']
            st.write(seqs)
        except Exception as e:
            # BUG FIX: st.exception expects an exception object, not a
            # formatted string; pass the exception so the traceback renders.
            st.exception(e)
|