Vivek commited on
Commit
e98fee7
·
1 Parent(s): b1b7999

added final files

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import transformers
3
+ from transformers import (
4
+ GPT2Config,
5
+ GPT2Tokenizer)
6
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
7
+ from model_file import FlaxGPT2ForMultipleChoice
8
+ import jax
9
+ import jax.numpy as jnp
10
+
11
+ st.title('GPT2 for common sense reasoning')
12
+ st.write('Multiple Choice Question Answering using CosmosQA Dataset')
13
+
14
+ context=st.text_area('Context',height=25)
15
+
16
+ #context = st.text_input('Context :')
17
+
18
+
19
+
20
+
21
+
22
+ question=st.text_input('Question')
23
+
24
+
25
+ buff, col, buff2 = st.beta_columns([5,1,2])
26
+ choice_a=buff.text_input('choice 0')
27
+ choice_b=buff.text_input('choice 1')
28
+ choice_c=buff.text_input('choice 2')
29
+ choice_d=buff.text_input('choice 3')
30
+
31
+ a={}
32
+ def preprocess(context,question,choice_a,choice_b,choice_c,choice_d):
33
+ a['context&question']=context+question
34
+ a['first_sentence']=[a['context&question'],a['context&question'],a['context&question'],a['context&question']]
35
+ a['second_sentence']=choice_a,choice_b,choice_c,choice_d
36
+ return a
37
+
38
+ preprocessed_data=preprocess(context,question,choice_a,choice_b,choice_c,choice_d)
39
+
40
+ def tokenize(examples):
41
+ b=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,return_tensors='jax')
42
+ return b
43
+
44
+ tokenized_data=tokenize(preprocessed_data)
45
+
46
+
47
+ model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",input_shape=(1,4,1))
48
+
49
+ input_id=jnp.array(tokenized_data['input_ids'])
50
+ att_mask=jnp.array(tokenized_data['attention_mask'])
51
+
52
+ input_id=input_id.reshape(1,4,-1)
53
+ att_mask=att_mask.reshape(1,4,-1)
54
+
55
+ if st.button("Run"):
56
+ with st.spinner(text="Getting results..."):
57
+ outputs=model(input_id,att_mask)
58
+ final_output=jnp.argmax(outputs,axis=-1)
59
+ #output=jax.device_get(final_output).item()
60
+ st.success(f"The answer is choice {final_output}")
61
+
62
+
63
+