"""Streamlit app: multiple-choice QA over CosmosQA with a Flax GPT-2 model."""
import streamlit as st
import transformers
from transformers import GPT2Config, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')

from model_file import FlaxGPT2ForMultipleChoice
import jax.numpy as jnp

st.title('GPT2 for common sense reasoning')
st.write('Multiple Choice Question Answering using CosmosQA Dataset')

context = st.text_area('Context', height=25)
st.write(context)
question = st.text_input('Question')

# NOTE(review): st.beta_columns is deprecated in newer Streamlit releases
# (renamed st.columns) — kept as-is to preserve behavior on the pinned version.
buff, col, buff2 = st.beta_columns([5, 1, 2])
choice_a = buff.text_input('choice 0:')
choice_b = buff.text_input('choice 1:')
choice_c = buff.text_input('choice 2:')
choice_d = buff.text_input('choice 3:')

a = {}


def preprocess(context, question, choice_a, choice_b, choice_c, choice_d):
    """Pair the (context + question) prompt with each of the four choices.

    Returns a dict with 'first_sentence' (the prompt repeated 4x) and
    'second_sentence' (the four candidate answers), the layout the
    tokenizer expects for multiple-choice encoding.
    """
    a['context&question'] = context + question
    # One copy of the prompt per answer choice.
    a['first_sentence'] = [a['context&question']] * 4
    a['second_sentence'] = choice_a, choice_b, choice_c, choice_d
    return a


preprocessed_data = preprocess(context, question, choice_a, choice_b, choice_c, choice_d)


def tokenize(examples):
    """Tokenize prompt/choice pairs into fixed-length jax tensors."""
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='jax',
    )


tokenized_data = tokenize(preprocessed_data)

model = FlaxGPT2ForMultipleChoice.from_pretrained(
    "flax-community/gpt2-Cosmos", input_shape=(1, 4, 1)
)

# return_tensors='jax' already yields jax arrays; jnp.array is a cheap no-op copy.
input_id = jnp.array(tokenized_data['input_ids'])
att_mask = jnp.array(tokenized_data['attention_mask'])

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        outputs = model(input_id, att_mask)
        # argmax over the 4 choice logits picks the predicted answer index (0-3).
        final_output = jnp.argmax(outputs, axis=-1)
        # Bug fix: original if/elif ladder built `result` but then printed the
        # undefined name `result1` (NameError). The ladder was also equivalent
        # to a plain int->str conversion.
        result = str(int(final_output))
        st.success(f"The answer is choice {result}")