Vivek commited on
Commit
72fe35a
·
1 Parent(s): e98fee7

delte the files

Browse files
Files changed (1) hide show
  1. app.py +0 -63
app.py DELETED
@@ -1,63 +0,0 @@
1
- import streamlit as st
2
- import transformers
3
- from transformers import (
4
- GPT2Config,
5
- GPT2Tokenizer)
6
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
7
- from model_file import FlaxGPT2ForMultipleChoice
8
- import jax
9
- import jax.numpy as jnp
10
-
11
- st.title('GPT2 for common sense reasoning')
12
- st.write('Multiple Choice Question Answering using CosmosQA Dataset')
13
-
14
- context=st.text_area('Context',height=25)
15
-
16
- #context = st.text_input('Context :')
17
-
18
-
19
-
20
-
21
-
22
- question=st.text_input('Question')
23
-
24
-
25
- buff, col, buff2 = st.beta_columns([5,1,2])
26
- choice_a=buff.text_input('choice 0')
27
- choice_b=buff.text_input('choice 1')
28
- choice_c=buff.text_input('choice 2')
29
- choice_d=buff.text_input('choice 3')
30
-
31
- a={}
32
- def preprocess(context,question,choice_a,choice_b,choice_c,choice_d):
33
- a['context&question']=context+question
34
- a['first_sentence']=[a['context&question'],a['context&question'],a['context&question'],a['context&question']]
35
- a['second_sentence']=choice_a,choice_b,choice_c,choice_d
36
- return a
37
-
38
- preprocessed_data=preprocess(context,question,choice_a,choice_b,choice_c,choice_d)
39
-
40
- def tokenize(examples):
41
- b=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,return_tensors='jax')
42
- return b
43
-
44
- tokenized_data=tokenize(preprocessed_data)
45
-
46
-
47
- model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",input_shape=(1,4,1))
48
-
49
- input_id=jnp.array(tokenized_data['input_ids'])
50
- att_mask=jnp.array(tokenized_data['attention_mask'])
51
-
52
- input_id=input_id.reshape(1,4,-1)
53
- att_mask=att_mask.reshape(1,4,-1)
54
-
55
- if st.button("Run"):
56
- with st.spinner(text="Getting results..."):
57
- outputs=model(input_id,att_mask)
58
- final_output=jnp.argmax(outputs,axis=-1)
59
- #output=jax.device_get(final_output).item()
60
- st.success(f"The answer is choice {final_output}")
61
-
62
-
63
-