# Source captured from a Hugging Face Spaces page; the Space's status was shown as "Runtime error".
import jax.numpy as jnp
import streamlit as st
import transformers
from transformers import (
    GPT2Config,
    GPT2Tokenizer,
)

from model_file import FlaxGPT2ForMultipleChoice

# All imports are grouped above the first executable statement (PEP 8, E402).
# GPT-2's tokenizer defines no pad token by default; reuse the end-of-text token
# so the fixed-length padding used during tokenization is possible.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
st.title('GPT2 for common sense reasoning')
st.write('Multiple Choice Question Answering using CosmosQA Dataset')

# Recent Streamlit releases reject st.text_area heights below 68 px
# (the original 25 px raises StreamlitAPIException), so use a valid height.
context = st.text_area('Context', height=100)
st.write(context)
question = st.text_input('Question')

# st.beta_columns was removed from Streamlit in favour of st.columns; prefer the
# stable name and fall back only on old releases that predate it.
_columns = getattr(st, "columns", None) or st.beta_columns
# Only the wide first column holds widgets; `col` and `buff2` act as spacing.
buff, col, buff2 = _columns([5, 1, 2])
choice_a = buff.text_input('choice 0:')
choice_b = buff.text_input('choice 1:')
choice_c = buff.text_input('choice 2:')
choice_d = buff.text_input('choice 3:')
def preprocess(context, question, choice_a, choice_b, choice_c, choice_d):
    """Build the sentence-pair inputs for one four-way multiple-choice question.

    Every answer choice is paired with the same prompt, which is the context
    immediately followed by the question.

    Args:
        context: Passage text.
        question: Question text, appended directly to ``context``.
        choice_a, choice_b, choice_c, choice_d: The four candidate answers.

    Returns:
        A new dict on every call (earlier results are never modified) with keys:
          'context&question': ``context + question``;
          'first_sentence': list of four copies of that prompt;
          'second_sentence': list ``[choice_a, choice_b, choice_c, choice_d]``.
    """
    # NOTE(review): no separator is inserted between context and question;
    # presumably this mirrors the training-time preprocessing, so confirm before changing.
    prompt = context + question
    return {
        'context&question': prompt,
        'first_sentence': [prompt] * 4,
        'second_sentence': [choice_a, choice_b, choice_c, choice_d],
    }
# Assemble the four (prompt, choice) pairs from the current widget values.
preprocessed_data = preprocess(
    context=context,
    question=question,
    choice_a=choice_a,
    choice_b=choice_b,
    choice_c=choice_c,
    choice_d=choice_d,
)
def tokenize(examples):
    """Encode the sentence pairs in *examples* with the module-level GPT-2 tokenizer.

    ``examples['first_sentence']`` and ``examples['second_sentence']`` are passed
    as the first and second text of each pair. Every pair is truncated or padded
    to exactly 256 tokens, and the result is returned as JAX tensors. Callers read
    its 'input_ids' and 'attention_mask' entries.
    """
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='jax',
    )
tokenized_data = tokenize(preprocessed_data)

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached from_pretrained() call rebuilds the model each time. Use
# st.cache_resource to keep one instance across reruns when this Streamlit
# version provides it; otherwise load on every run, as before.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model():
    """Load the flax-community/gpt2-Cosmos multiple-choice checkpoint."""
    return FlaxGPT2ForMultipleChoice.from_pretrained(
        "flax-community/gpt2-Cosmos", input_shape=(1, 4, 1)
    )


model = _load_model()
# NOTE(review): input_shape=(1, 4, 1) suggests (batch, choices, seq), while these
# arrays are (4, 256) (four pairs padded to 256 tokens). Confirm that model_file
# reshapes them as expected.
input_id = jnp.array(tokenized_data['input_ids'])
att_mask = jnp.array(tokenized_data['attention_mask'])
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        outputs = model(input_id, att_mask)
        final_output = jnp.argmax(outputs, axis=-1)
        # A single question is scored, so argmax yields either a 0-d or a
        # one-element array. Flatten it and take that one index. This replaces
        # the if/elif chain, which left `result` unbound for any other value.
        result = str(int(jnp.ravel(final_output)[0]))
    # Fixed: the original interpolated the undefined name `result1` (NameError).
    st.success(f"The answer is choice {result}")