Spaces:
Runtime error
Runtime error
File size: 1,953 Bytes
d006373 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import streamlit as st
import transformers
from transformers import (
GPT2Config,
GPT2Tokenizer)
# The stock GPT-2 tokenizer defines no padding token; reuse the end-of-text
# token so that the padding='max_length' tokenization used below works.
tokenizer = GPT2Tokenizer.from_pretrained(
    "gpt2",
    pad_token="<|endoftext|>",
)
from model_file import FlaxGPT2ForMultipleChoice
import jax.numpy as jnp
# --- Page header -----------------------------------------------------------
st.title('GPT2 for common sense reasoning')
st.write('Multiple Choice Question Answering using CosmosQA Dataset')

# --- Context and question inputs ------------------------------------------
# 68 px is the smallest height recent Streamlit releases accept for
# st.text_area (smaller values raise); older releases simply render it.
context = st.text_area('Context', height=68)
# Echo the entered context back onto the page.
st.write(context)
question = st.text_input('Question')

# --- Answer-choice inputs --------------------------------------------------
# `st.beta_columns` was renamed to `st.columns` and the beta_ alias was later
# removed; prefer the current name and fall back for old installations.
_make_columns = st.columns if hasattr(st, "columns") else st.beta_columns
buff, col, buff2 = _make_columns([5, 1, 2])
# Only the wide left column holds widgets; the other two act as spacing.
choice_a = buff.text_input('choice 0:')
choice_b = buff.text_input('choice 1:')
choice_c = buff.text_input('choice 2:')
choice_d = buff.text_input('choice 3:')
def preprocess(context, question, choice_a, choice_b, choice_c, choice_d):
    """Pair the context+question prompt with each of the four answer choices.

    Args:
        context: Passage text entered by the user.
        question: Question text entered by the user.
        choice_a, choice_b, choice_c, choice_d: The four candidate answers.

    Returns:
        A new dict (never shared between calls) with keys:
          'context&question': ``context`` and ``question`` concatenated.
          'first_sentence': four copies of that prompt, one per choice.
          'second_sentence': the four choices, in order.
    """
    # NOTE(review): no separator is inserted between context and question;
    # presumably this mirrors the fine-tuning preprocessing — confirm.
    prompt = context + question
    # Build a fresh dict each call so returned results never alias shared
    # module-level state.
    return {
        'context&question': prompt,
        'first_sentence': [prompt] * 4,
        'second_sentence': [choice_a, choice_b, choice_c, choice_d],
    }
# Build the four (prompt, choice) pairs from the current widget values.
preprocessed_data = preprocess(context, question, choice_a, choice_b, choice_c, choice_d)


def tokenize(examples):
    """Tokenize the sentence pairs produced by ``preprocess``.

    Each 'first_sentence' entry is encoded together with the 'second_sentence'
    entry at the same position, padded/truncated to 256 tokens, and returned
    as JAX arrays (``return_tensors='jax'``).
    """
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='jax',
    )


tokenized_data = tokenize(preprocessed_data)


def _load_model():
    """Load the fine-tuned multiple-choice model from the Hugging Face Hub."""
    return FlaxGPT2ForMultipleChoice.from_pretrained(
        "flax-community/gpt2-Cosmos", input_shape=(1, 4, 1)
    )


# Streamlit re-executes this whole script on every widget interaction, so an
# uncached load would re-create the model on each keystroke. Cache it where
# st.cache_resource exists; older Streamlit keeps the uncached behaviour.
if hasattr(st, "cache_resource"):
    _load_model = st.cache_resource(_load_model)
model = _load_model()

input_id = jnp.array(tokenized_data['input_ids'])
att_mask = jnp.array(tokenized_data['attention_mask'])
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        outputs = model(input_id, att_mask)
        # Index of the highest-scoring choice along the last axis.
        final_output = jnp.argmax(outputs, axis=-1)
        # NOTE(review): assumes a single question per run, i.e. final_output
        # holds one element — confirm against the model's output shape.
        # Deriving the label directly avoids both the previous NameError
        # (`result1`) and an unbound `result` for unexpected indices.
        result = str(int(jnp.ravel(final_output)[0]))
    st.success(f"The answer is choice {result}")
|