MSNP1381 commited on
Commit
f4f1e28
Β·
1 Parent(s): 937b552

ratings added

Browse files
Files changed (5) hide show
  1. .gitignore +3 -1
  2. all_data.json +0 -0
  3. api/apis.py +6 -30
  4. api/local_api.py +36 -0
  5. app.py +66 -25
.gitignore CHANGED
@@ -160,4 +160,6 @@ cython_debug/
160
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
  # and can be added to the global gitignore or merged into this file. For a more nuclear
162
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
- #.idea/
 
 
 
160
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
  # and can be added to the global gitignore or merged into this file. For a more nuclear
162
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+ tmp.py
165
+ answers/
all_data.json ADDED
The diff for this file is too large to render. See raw diff
 
api/apis.py CHANGED
@@ -1,17 +1,7 @@
1
- import json
2
- import firebase_admin
3
- from firebase_admin import credentials
4
  from firebase_admin import firestore
5
- import os
6
- # Initialize Firebase
7
- cert_dict=json.loads(os.getenv('google_json'))
8
- cred = credentials.Certificate(cert_dict)
9
- firebase_admin.initialize_app(cred)
10
 
11
- # Initialize Firestore DB
12
- db = firestore.client()
13
-
14
- def get_question_by_id(question_id):
15
  doc_ref = db.collection('questions').document(question_id)
16
  doc = doc_ref.get()
17
  if doc.exists:
@@ -19,20 +9,10 @@ def get_question_by_id(question_id):
19
  else:
20
  return None
21
 
22
- # Example usage
23
- question_id = "Mercury_7090615"
24
- question_data = get_question_by_id(question_id)
25
- if question_data:
26
- print(f"Question data for {question_id}: {question_data}")
27
- else:
28
- print(f"No data found for question ID: {question_id}")
29
-
30
-
31
-
32
- def get_question_ids_with_correctness():
33
  questions_ref = db.collection('questions')
34
  docs = questions_ref.stream()
35
-
36
  results = []
37
  for doc in docs:
38
  data = doc.to_dict()
@@ -42,10 +22,6 @@ def get_question_ids_with_correctness():
42
  correctness = "βœ…" if correct_answer == generated_answer else 'πŸ“›'
43
 
44
  results.append(f"{question_id} {correctness}")
 
45
 
46
- return results
47
-
48
- # Example usage
49
- correctness_list = get_question_ids_with_correctness()
50
- for result in correctness_list:
51
- print(result)
 
1
+ from typing import List
 
 
2
  from firebase_admin import firestore
 
 
 
 
 
3
 
4
+ def get_question_by_id(db,question_id):
 
 
 
5
  doc_ref = db.collection('questions').document(question_id)
6
  doc = doc_ref.get()
7
  if doc.exists:
 
9
  else:
10
  return None
11
 
12
+ def get_question_ids_with_correctness(db: firestore.Client) -> List[str]:
 
 
 
 
 
 
 
 
 
 
13
  questions_ref = db.collection('questions')
14
  docs = questions_ref.stream()
15
+ print("started")
16
  results = []
17
  for doc in docs:
18
  data = doc.to_dict()
 
22
  correctness = "βœ…" if correct_answer == generated_answer else 'πŸ“›'
23
 
24
  results.append(f"{question_id} {correctness}")
25
+ print(results)
26
 
27
+ return results
 
 
 
 
 
api/local_api.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import json
3
+
4
+ def get_question_by_id(l,question_id):
5
+ doc=None
6
+ for i in l:
7
+ if i['original_question']['question_id']==question_id:
8
+ doc=i
9
+ break
10
+ return doc
11
+ def get_question_ids_with_correctness(l) -> List[str]:
12
+
13
+ print("started")
14
+ results = []
15
+ for doc in l:
16
+ data = doc
17
+ question_id = data['original_question']['question_id']
18
+ correct_answer = data['original_question']['answer']
19
+ generated_answer = data['generated_result']['answer_key_vale']
20
+ correctness = "βœ…" if correct_answer == generated_answer else 'πŸ“›'
21
+
22
+ results.append(f"{question_id} {correctness}")
23
+
24
+ return results
25
+
26
+ def init_json():
27
+ with open("./all_data.json") as f:
28
+ return json.load(f)
29
+ def update_depth(question_id,depth):
30
+ with open('all_data_json' ,'w') as f:
31
+ l=json.load(f)
32
+ for i in l:
33
+ if i['original_question']['question_id']==question_id:
34
+ i["max_depth"]=depth
35
+ break
36
+ f.write(l)
app.py CHANGED
@@ -1,26 +1,35 @@
 
1
  import streamlit as st
2
  import networkx as nx
3
  from pyvis.network import Network
4
  import textwrap
5
-
6
- from api.apis import get_question_by_id, get_question_ids_with_correctness
 
 
 
 
 
7
 
8
  # This function should be implemented to return a list of all question IDs
9
  def get_question_ids():
10
  # Placeholder implementation
11
- return get_question_ids_with_correctness()
12
  # This function should be implemented to return the context for a given question ID
13
  def get_question_context(question_id):
14
- return get_question_by_id(question_id[:-2])
15
 
16
- def create_interactive_graph(reasoning_chain):
17
  G = nx.DiGraph()
18
  net = Network(notebook=True, width="100%", height="600px", directed=True)
19
-
20
  for i, step in enumerate(reasoning_chain):
 
21
  wrapped_text = textwrap.fill(step, width=30)
22
  label = f"Step {i+1}\n\n{wrapped_text}"
23
- G.add_node(i, title=step, label=label)
 
 
24
  if i > 0:
25
  G.add_edge(i-1, i)
26
 
@@ -28,7 +37,6 @@ def create_interactive_graph(reasoning_chain):
28
 
29
  for node in net.nodes:
30
  node['shape'] = 'box'
31
- node['color'] = '#97C2FC'
32
  node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
33
  node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
34
 
@@ -66,18 +74,20 @@ def create_interactive_graph(reasoning_chain):
66
  ''')
67
 
68
  return net
69
-
 
 
 
 
70
  def main():
71
- st.title("Interactive Q&A App with Reasoning Chain Graph")
72
 
73
- # Get all question IDs
74
  question_ids = get_question_ids()
75
-
76
- # Initialize session state for current index
77
  if 'current_index' not in st.session_state:
78
  st.session_state.current_index = 0
 
 
79
 
80
- # Navigation buttons
81
  col1, col2, col3 = st.columns([1,3,1])
82
  with col1:
83
  if st.button("Previous"):
@@ -86,27 +96,21 @@ def main():
86
  if st.button("Next"):
87
  st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
88
 
89
- # Select a question ID
90
  selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
91
-
92
- # Update current_index when selection changes
93
  st.session_state.current_index = question_ids.index(selected_question_id)
94
 
95
- # Get the context for the selected question
96
  data = get_question_context(selected_question_id)
97
  original_question = data['original_question']
98
  generated_result = data['generated_result']
99
 
100
- # Display question
101
  st.subheader("Question")
102
  st.write(original_question['question'])
103
 
104
- # Display choices
105
  st.subheader("Choices")
106
  for label, choice in original_question['label_choices'].items():
107
  st.write(f"{label}: {choice}")
108
 
109
- # Display correct answer
110
  st.subheader("Correct Answer")
111
  correct_answer_label = original_question['answer']
112
  correct_answer = original_question['label_choices'][correct_answer_label]
@@ -122,15 +126,52 @@ def main():
122
  answer_label = original_question['answer']
123
  is_same=answer_label.lower()==generated_answer_label.lower()
124
  st.write(f"βœ… answers Are the same. answer is {answer_label}"if is_same else f"πŸ“› answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}")
125
-
126
- # Create and display interactive reasoning chain graph
127
- net = create_interactive_graph(generated_result['reasoning_chain'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  net.save_graph("graph.html")
129
 
130
- # Display the interactive graph
131
  with open("graph.html", 'r', encoding='utf-8') as f:
132
  html = f.read()
133
  st.components.v1.html(html, height=600)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
 
 
 
 
 
136
  main()
 
1
+ import json
2
  import streamlit as st
3
  import networkx as nx
4
  from pyvis.network import Network
5
  import textwrap
6
+ from dotenv import load_dotenv
7
+ load_dotenv()
8
+ import firebase_admin
9
+ from firebase_admin import credentials
10
+ from firebase_admin import firestore
11
+ import os
12
+ from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json
13
 
14
  # This function should be implemented to return a list of all question IDs
15
  def get_question_ids():
16
  # Placeholder implementation
17
+ return get_question_ids_with_correctness(st.session_state.db)
18
  # This function should be implemented to return the context for a given question ID
19
  def get_question_context(question_id):
20
+ return get_question_by_id(st.session_state.db,question_id[:-2])
21
 
22
+ def create_interactive_graph(reasoning_chain, ratings):
23
  G = nx.DiGraph()
24
  net = Network(notebook=True, width="100%", height="600px", directed=True)
25
+ print(ratings)
26
  for i, step in enumerate(reasoning_chain):
27
+
28
  wrapped_text = textwrap.fill(step, width=30)
29
  label = f"Step {i+1}\n\n{wrapped_text}"
30
+ color = "#97C2FC" if i < ratings else "#FF9999"
31
+ border_color = "#00FF00" if i >= ratings else "#FF0000"
32
+ G.add_node(i, title=step, label=label, color=color, borderWidth=3, borderColor=border_color)
33
  if i > 0:
34
  G.add_edge(i-1, i)
35
 
 
37
 
38
  for node in net.nodes:
39
  node['shape'] = 'box'
 
40
  node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
41
  node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
42
 
 
74
  ''')
75
 
76
  return net
77
+ def update_rate(selected_question_id,i):
78
+ st.session_state.ratings[selected_question_id] = i
79
+ update_rate(selected_question_id,i)
80
+
81
+
82
  def main():
83
+ st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")
84
 
 
85
  question_ids = get_question_ids()
 
 
86
  if 'current_index' not in st.session_state:
87
  st.session_state.current_index = 0
88
+ if 'ratings' not in st.session_state:
89
+ st.session_state.ratings = {}
90
 
 
91
  col1, col2, col3 = st.columns([1,3,1])
92
  with col1:
93
  if st.button("Previous"):
 
96
  if st.button("Next"):
97
  st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
98
 
 
99
  selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
 
 
100
  st.session_state.current_index = question_ids.index(selected_question_id)
101
 
 
102
  data = get_question_context(selected_question_id)
103
  original_question = data['original_question']
104
  generated_result = data['generated_result']
105
 
 
106
  st.subheader("Question")
107
  st.write(original_question['question'])
108
 
 
109
  st.subheader("Choices")
110
  for label, choice in original_question['label_choices'].items():
111
  st.write(f"{label}: {choice}")
112
 
113
+ # Display correct answer
114
  st.subheader("Correct Answer")
115
  correct_answer_label = original_question['answer']
116
  correct_answer = original_question['label_choices'][correct_answer_label]
 
126
  answer_label = original_question['answer']
127
  is_same=answer_label.lower()==generated_answer_label.lower()
128
  st.write(f"βœ… answers Are the same. answer is {answer_label}"if is_same else f"πŸ“› answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}")
129
+
130
+
131
+ st.subheader("Rate the Reasoning Steps")
132
+ if selected_question_id not in st.session_state.ratings:
133
+ st.session_state.ratings[selected_question_id] = data['max_depth']
134
+
135
+ rating = st.session_state.ratings[selected_question_id]
136
+ cols = st.columns(len(generated_result['reasoning_chain'])+1)
137
+ for i, col in enumerate(cols):
138
+ if i==0:
139
+ col.button(f"None", key=f"rate_{i}",on_click=update_rate,args=[selected_question_id,i])
140
+ continue
141
+
142
+ col.button(f"Step {i}", key=f"rate_{i}",on_click=update_rate,args=[selected_question_id,i])
143
+
144
+
145
+ net = create_interactive_graph(generated_result['reasoning_chain'], rating)
146
  net.save_graph("graph.html")
147
 
 
148
  with open("graph.html", 'r', encoding='utf-8') as f:
149
  html = f.read()
150
  st.components.v1.html(html, height=600)
151
+ def initialize_firebase():
152
+ """
153
+ Initialize Firebase app and return Firestore client.
154
+ If the app is already initialized, it returns the existing Firestore client.
155
+ """
156
+ try:
157
+ cert=json.loads(os.getenv('google_json'))
158
+ except:
159
+ cert=os.getenv('google_json')
160
+ cred = credentials.Certificate(cert)
161
+ try:
162
+ firebase_admin.get_app()
163
+ print("Default app already exists")
164
+ except ValueError:
165
+ # Initialize the app with a service account, granting admin privileges
166
+ firebase_admin.initialize_app(cred)
167
+
168
+ return firestore.client()
169
+
170
 
171
  if __name__ == "__main__":
172
+ if os.getenv("local")=='true':
173
+ st.session_state.db=init_json()
174
+ else:
175
+ st.session_state.db=initialize_firebase()
176
+
177
  main()