# GPT2-CosmosQA / app.py
import streamlit as st
import jax
import jax.numpy as jnp
from transformers import GPT2Tokenizer

from model_file import FlaxGPT2ForMultipleChoice

# GPT-2 has no dedicated padding token, so the end-of-text token is reused for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
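# Streamlit re-runs this script on every widget interaction, so the tokenizer above and
# the model loaded further down are re-created each time; wrapping their loaders in
# st.cache is one way to memoize them across reruns.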
st.title('GPT2 for common sense reasoning')
st.write('Multiple Choice Question Answering using CosmosQA Dataset')
# st.text_area requires a height of at least 68px in current Streamlit releases.
context = st.text_area('Context', height=100)
question = st.text_input('Question')
# st.beta_columns was removed in later Streamlit releases; st.columns is the current API.
buff, col, buff2 = st.columns([5, 1, 2])
choice_a = buff.text_input('choice 0')
choice_b = buff.text_input('choice 1')
choice_c = buff.text_input('choice 2')
choice_d = buff.text_input('choice 3')
def preprocess(context, question, choice_a, choice_b, choice_c, choice_d):
    # Pair the same context+question with each of the four candidate answers,
    # matching the multiple-choice format used for CosmosQA.
    a = {}
    a['context&question'] = context + question
    a['first_sentence'] = [a['context&question']] * 4
    a['second_sentence'] = [choice_a, choice_b, choice_c, choice_d]
    return a
preprocessed_data = preprocess(context, question, choice_a, choice_b, choice_c, choice_d)
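# For illustration, with hypothetical inputs the dict built above would look like:
#   preprocess('He left early.', ' Why did he leave?', 'rain', 'work', 'boredom', 'hunger')
#   -> {'context&question': 'He left early. Why did he leave?',
#       'first_sentence': [that string repeated 4 times],
#       'second_sentence': ['rain', 'work', 'boredom', 'hunger']}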
def tokenize(examples):
    # Tokenize the four (context+question, choice) pairs; padding and truncation keep
    # every option at the same sequence length.
    b = tokenizer(examples['first_sentence'], examples['second_sentence'],
                  padding='max_length', truncation=True, return_tensors='jax')
    return b
tokenized_data = tokenize(preprocessed_data)
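# tokenized_data['input_ids'] and ['attention_mask'] come back with one row per choice
# (4 x the tokenizer's model max length, 1024 for GPT-2) and are reshaped below into the
# (batch, num_choices, seq_len) layout the multiple-choice head expects.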
# Load the Flax GPT-2 multiple-choice model fine-tuned on CosmosQA.
model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))
input_id = jnp.array(tokenized_data['input_ids'])
att_mask = jnp.array(tokenized_data['attention_mask'])

# Reshape to (batch=1, num_choices=4, seq_len) for the multiple-choice head.
input_id = input_id.reshape(1, 4, -1)
att_mask = att_mask.reshape(1, 4, -1)
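# The reshape above relies on all four choices being padded to the same length, which
# padding='max_length' in tokenize() guarantees.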
if st.button("Run"):
with st.spinner(text="Getting results..."):
outputs=model(input_id,att_mask)
final_output=jnp.argmax(outputs,axis=-1)
#output=jax.device_get(final_output).item()
st.success(f"The answer is choice {final_output}")