MSNP1381
committed on
Commit
·
f4f1e28
1
Parent(s):
937b552
ratings added
Browse files- .gitignore +3 -1
- all_data.json +0 -0
- api/apis.py +6 -30
- api/local_api.py +36 -0
- app.py +66 -25
.gitignore
CHANGED
@@ -160,4 +160,6 @@ cython_debug/
|
|
160 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
-
#.idea/
|
|
|
|
|
|
160 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
161 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
162 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
163 |
+
#.idea/
|
164 |
+
tmp.py
|
165 |
+
answers/
|
all_data.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
api/apis.py
CHANGED
@@ -1,17 +1,7 @@
|
|
1 |
-
import
|
2 |
-
import firebase_admin
|
3 |
-
from firebase_admin import credentials
|
4 |
from firebase_admin import firestore
|
5 |
-
import os
|
6 |
-
# Initialize Firebase
|
7 |
-
cert_dict=json.loads(os.getenv('google_json'))
|
8 |
-
cred = credentials.Certificate(cert_dict)
|
9 |
-
firebase_admin.initialize_app(cred)
|
10 |
|
11 |
-
|
12 |
-
db = firestore.client()
|
13 |
-
|
14 |
-
def get_question_by_id(question_id):
|
15 |
doc_ref = db.collection('questions').document(question_id)
|
16 |
doc = doc_ref.get()
|
17 |
if doc.exists:
|
@@ -19,20 +9,10 @@ def get_question_by_id(question_id):
|
|
19 |
else:
|
20 |
return None
|
21 |
|
22 |
-
|
23 |
-
question_id = "Mercury_7090615"
|
24 |
-
question_data = get_question_by_id(question_id)
|
25 |
-
if question_data:
|
26 |
-
print(f"Question data for {question_id}: {question_data}")
|
27 |
-
else:
|
28 |
-
print(f"No data found for question ID: {question_id}")
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
def get_question_ids_with_correctness():
|
33 |
questions_ref = db.collection('questions')
|
34 |
docs = questions_ref.stream()
|
35 |
-
|
36 |
results = []
|
37 |
for doc in docs:
|
38 |
data = doc.to_dict()
|
@@ -42,10 +22,6 @@ def get_question_ids_with_correctness():
|
|
42 |
correctness = "β
" if correct_answer == generated_answer else 'π'
|
43 |
|
44 |
results.append(f"{question_id} {correctness}")
|
|
|
45 |
|
46 |
-
return results
|
47 |
-
|
48 |
-
# Example usage
|
49 |
-
correctness_list = get_question_ids_with_correctness()
|
50 |
-
for result in correctness_list:
|
51 |
-
print(result)
|
|
|
1 |
+
from typing import List
|
|
|
|
|
2 |
from firebase_admin import firestore
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
def get_question_by_id(db,question_id):
|
|
|
|
|
|
|
5 |
doc_ref = db.collection('questions').document(question_id)
|
6 |
doc = doc_ref.get()
|
7 |
if doc.exists:
|
|
|
9 |
else:
|
10 |
return None
|
11 |
|
12 |
+
def get_question_ids_with_correctness(db: firestore.Client) -> List[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
questions_ref = db.collection('questions')
|
14 |
docs = questions_ref.stream()
|
15 |
+
print("started")
|
16 |
results = []
|
17 |
for doc in docs:
|
18 |
data = doc.to_dict()
|
|
|
22 |
correctness = "β
" if correct_answer == generated_answer else 'π'
|
23 |
|
24 |
results.append(f"{question_id} {correctness}")
|
25 |
+
print(results)
|
26 |
|
27 |
+
return results
|
|
|
|
|
|
|
|
|
|
api/local_api.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import json
|
3 |
+
|
4 |
+
def get_question_by_id(l, question_id):
    """Return the first record in ``l`` whose question id equals ``question_id``.

    Args:
        l: Iterable of record dicts. Each record is read as
            ``record['original_question']['question_id']``.
        question_id: Identifier to search for.

    Returns:
        The matching record (the same object held by ``l``), or ``None``
        when no record matches.

    Raises:
        KeyError: If a record visited before the match lacks the
            ``original_question`` / ``question_id`` keys.
    """
    return next(
        (
            record
            for record in l
            if record['original_question']['question_id'] == question_id
        ),
        None,
    )
|
11 |
+
def get_question_ids_with_correctness(l) -> List[str]:
    """Build one label per record in the form ``"<question_id> <marker>"``.

    The marker is a check mark when the reference answer equals the
    generated answer, and a different single-character marker otherwise.

    Args:
        l: Iterable of record dicts providing
            ``['original_question']['question_id']``,
            ``['original_question']['answer']`` and
            ``['generated_result']['answer_key_vale']``.

    Returns:
        A list of label strings, in the same order as ``l``.
    """
    print("started")
    results = []
    for record in l:
        question_id = record['original_question']['question_id']
        correct_answer = record['original_question']['answer']
        # NOTE(review): 'answer_key_vale' looks like a misspelling of
        # 'answer_key_value', but it is the key this code reads; confirm
        # against the stored data before renaming it.
        generated_answer = record['generated_result']['answer_key_vale']
        # NOTE(review): the 'π' marker appears to be an emoji whose bytes were
        # decoded with the wrong codec; the intended character cannot be
        # recovered from this view, so the literal is kept unchanged.
        correctness = "✅" if correct_answer == generated_answer else 'π'
        results.append(f"{question_id} {correctness}")
    return results
|
25 |
+
|
26 |
+
def init_json():
    """Load and return the parsed contents of ``./all_data.json``.

    The path is relative to the current working directory. The file is read
    as UTF-8 explicitly so the result does not depend on the platform's
    default text encoding.

    Returns:
        Whatever ``json.load`` produces for the file's top-level value.

    Raises:
        FileNotFoundError: If ``./all_data.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open("./all_data.json", encoding="utf-8") as f:
        return json.load(f)
|
29 |
+
def update_depth(question_id, depth):
    """Store ``depth`` as ``max_depth`` on the matching record in ``./all_data.json``.

    The file is read completely, the first record whose
    ``['original_question']['question_id']`` equals ``question_id`` is
    updated, and the whole data set is written back as JSON. When no record
    matches, the file is left untouched.

    Args:
        question_id: Identifier of the record to update.
        depth: Value to store under the record's ``"max_depth"`` key.

    Raises:
        FileNotFoundError: If ``./all_data.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = "./all_data.json"

    # Read first, in a separate handle: opening with 'w' truncates the file
    # and yields a write-only handle, so the data must be loaded beforehand.
    with open(path, encoding="utf-8") as f:
        records = json.load(f)

    for record in records:
        if record['original_question']['question_id'] == question_id:
            record["max_depth"] = depth
            break
    else:
        # No matching record: nothing changed, so skip the rewrite.
        return

    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)
|
app.py
CHANGED
@@ -1,26 +1,35 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
import networkx as nx
|
3 |
from pyvis.network import Network
|
4 |
import textwrap
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# This function should be implemented to return a list of all question IDs
|
9 |
def get_question_ids():
|
10 |
# Placeholder implementation
|
11 |
-
return get_question_ids_with_correctness()
|
12 |
# This function should be implemented to return the context for a given question ID
|
13 |
def get_question_context(question_id):
|
14 |
-
return get_question_by_id(question_id[:-2])
|
15 |
|
16 |
-
def create_interactive_graph(reasoning_chain):
|
17 |
G = nx.DiGraph()
|
18 |
net = Network(notebook=True, width="100%", height="600px", directed=True)
|
19 |
-
|
20 |
for i, step in enumerate(reasoning_chain):
|
|
|
21 |
wrapped_text = textwrap.fill(step, width=30)
|
22 |
label = f"Step {i+1}\n\n{wrapped_text}"
|
23 |
-
|
|
|
|
|
24 |
if i > 0:
|
25 |
G.add_edge(i-1, i)
|
26 |
|
@@ -28,7 +37,6 @@ def create_interactive_graph(reasoning_chain):
|
|
28 |
|
29 |
for node in net.nodes:
|
30 |
node['shape'] = 'box'
|
31 |
-
node['color'] = '#97C2FC'
|
32 |
node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
|
33 |
node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
|
34 |
|
@@ -66,18 +74,20 @@ def create_interactive_graph(reasoning_chain):
|
|
66 |
''')
|
67 |
|
68 |
return net
|
69 |
-
|
|
|
|
|
|
|
|
|
70 |
def main():
|
71 |
-
st.title("Interactive Q&A App with Reasoning Chain Graph")
|
72 |
|
73 |
-
# Get all question IDs
|
74 |
question_ids = get_question_ids()
|
75 |
-
|
76 |
-
# Initialize session state for current index
|
77 |
if 'current_index' not in st.session_state:
|
78 |
st.session_state.current_index = 0
|
|
|
|
|
79 |
|
80 |
-
# Navigation buttons
|
81 |
col1, col2, col3 = st.columns([1,3,1])
|
82 |
with col1:
|
83 |
if st.button("Previous"):
|
@@ -86,27 +96,21 @@ def main():
|
|
86 |
if st.button("Next"):
|
87 |
st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
|
88 |
|
89 |
-
# Select a question ID
|
90 |
selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
|
91 |
-
|
92 |
-
# Update current_index when selection changes
|
93 |
st.session_state.current_index = question_ids.index(selected_question_id)
|
94 |
|
95 |
-
# Get the context for the selected question
|
96 |
data = get_question_context(selected_question_id)
|
97 |
original_question = data['original_question']
|
98 |
generated_result = data['generated_result']
|
99 |
|
100 |
-
# Display question
|
101 |
st.subheader("Question")
|
102 |
st.write(original_question['question'])
|
103 |
|
104 |
-
# Display choices
|
105 |
st.subheader("Choices")
|
106 |
for label, choice in original_question['label_choices'].items():
|
107 |
st.write(f"{label}: {choice}")
|
108 |
|
109 |
-
|
110 |
st.subheader("Correct Answer")
|
111 |
correct_answer_label = original_question['answer']
|
112 |
correct_answer = original_question['label_choices'][correct_answer_label]
|
@@ -122,15 +126,52 @@ def main():
|
|
122 |
answer_label = original_question['answer']
|
123 |
is_same=answer_label.lower()==generated_answer_label.lower()
|
124 |
st.write(f"β
answers Are the same. answer is {answer_label}"if is_same else f"π answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}")
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
net.save_graph("graph.html")
|
129 |
|
130 |
-
# Display the interactive graph
|
131 |
with open("graph.html", 'r', encoding='utf-8') as f:
|
132 |
html = f.read()
|
133 |
st.components.v1.html(html, height=600)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
136 |
main()
|
|
|
1 |
+
import json
|
2 |
import streamlit as st
|
3 |
import networkx as nx
|
4 |
from pyvis.network import Network
|
5 |
import textwrap
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
load_dotenv()
|
8 |
+
import firebase_admin
|
9 |
+
from firebase_admin import credentials
|
10 |
+
from firebase_admin import firestore
|
11 |
+
import os
|
12 |
+
from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json
|
13 |
|
14 |
# This function should be implemented to return a list of all question IDs
|
15 |
def get_question_ids():
    """Return the question labels for the data source held in session state.

    Delegates to ``get_question_ids_with_correctness``, passing
    ``st.session_state.db`` as the data handle.
    """
    return get_question_ids_with_correctness(st.session_state.db)
|
18 |
# This function should be implemented to return the context for a given question ID
|
19 |
def get_question_context(question_id):
    """Return the stored record for a question label.

    Args:
        question_id: A label of the form ``"<question id> <marker>"``.

    Returns:
        Whatever ``get_question_by_id`` returns for the bare question id,
        looked up in ``st.session_state.db``.
    """
    # Split at the last space instead of dropping a fixed two characters, so
    # the lookup does not depend on the marker being exactly one character.
    bare_id = question_id.rsplit(" ", 1)[0]
    return get_question_by_id(st.session_state.db, bare_id)
|
21 |
|
22 |
+
def create_interactive_graph(reasoning_chain, ratings):
|
23 |
G = nx.DiGraph()
|
24 |
net = Network(notebook=True, width="100%", height="600px", directed=True)
|
25 |
+
print(ratings)
|
26 |
for i, step in enumerate(reasoning_chain):
|
27 |
+
|
28 |
wrapped_text = textwrap.fill(step, width=30)
|
29 |
label = f"Step {i+1}\n\n{wrapped_text}"
|
30 |
+
color = "#97C2FC" if i < ratings else "#FF9999"
|
31 |
+
border_color = "#00FF00" if i >= ratings else "#FF0000"
|
32 |
+
G.add_node(i, title=step, label=label, color=color, borderWidth=3, borderColor=border_color)
|
33 |
if i > 0:
|
34 |
G.add_edge(i-1, i)
|
35 |
|
|
|
37 |
|
38 |
for node in net.nodes:
|
39 |
node['shape'] = 'box'
|
|
|
40 |
node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
|
41 |
node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
|
42 |
|
|
|
74 |
''')
|
75 |
|
76 |
return net
|
77 |
+
def update_rate(selected_question_id, i):
    """Record ``i`` as the rating for ``selected_question_id`` in session state.

    Used as a button ``on_click`` callback.

    Args:
        selected_question_id: Key under which the rating is stored in
            ``st.session_state.ratings``.
        i: Rating value to store.
    """
    st.session_state.ratings[selected_question_id] = i
    # The function previously ended by calling itself with the same
    # arguments, which recursed until RecursionError on every click.
    # NOTE(review): a persistence call may have been intended here - confirm.
|
80 |
+
|
81 |
+
|
82 |
def main():
|
83 |
+
st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")
|
84 |
|
|
|
85 |
question_ids = get_question_ids()
|
|
|
|
|
86 |
if 'current_index' not in st.session_state:
|
87 |
st.session_state.current_index = 0
|
88 |
+
if 'ratings' not in st.session_state:
|
89 |
+
st.session_state.ratings = {}
|
90 |
|
|
|
91 |
col1, col2, col3 = st.columns([1,3,1])
|
92 |
with col1:
|
93 |
if st.button("Previous"):
|
|
|
96 |
if st.button("Next"):
|
97 |
st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
|
98 |
|
|
|
99 |
selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
|
|
|
|
|
100 |
st.session_state.current_index = question_ids.index(selected_question_id)
|
101 |
|
|
|
102 |
data = get_question_context(selected_question_id)
|
103 |
original_question = data['original_question']
|
104 |
generated_result = data['generated_result']
|
105 |
|
|
|
106 |
st.subheader("Question")
|
107 |
st.write(original_question['question'])
|
108 |
|
|
|
109 |
st.subheader("Choices")
|
110 |
for label, choice in original_question['label_choices'].items():
|
111 |
st.write(f"{label}: {choice}")
|
112 |
|
113 |
+
# Display correct answer
|
114 |
st.subheader("Correct Answer")
|
115 |
correct_answer_label = original_question['answer']
|
116 |
correct_answer = original_question['label_choices'][correct_answer_label]
|
|
|
126 |
answer_label = original_question['answer']
|
127 |
is_same=answer_label.lower()==generated_answer_label.lower()
|
128 |
st.write(f"β
answers Are the same. answer is {answer_label}"if is_same else f"π answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}")
|
129 |
+
|
130 |
+
|
131 |
+
st.subheader("Rate the Reasoning Steps")
|
132 |
+
if selected_question_id not in st.session_state.ratings:
|
133 |
+
st.session_state.ratings[selected_question_id] = data['max_depth']
|
134 |
+
|
135 |
+
rating = st.session_state.ratings[selected_question_id]
|
136 |
+
cols = st.columns(len(generated_result['reasoning_chain'])+1)
|
137 |
+
for i, col in enumerate(cols):
|
138 |
+
if i==0:
|
139 |
+
col.button(f"None", key=f"rate_{i}",on_click=update_rate,args=[selected_question_id,i])
|
140 |
+
continue
|
141 |
+
|
142 |
+
col.button(f"Step {i}", key=f"rate_{i}",on_click=update_rate,args=[selected_question_id,i])
|
143 |
+
|
144 |
+
|
145 |
+
net = create_interactive_graph(generated_result['reasoning_chain'], rating)
|
146 |
net.save_graph("graph.html")
|
147 |
|
|
|
148 |
with open("graph.html", 'r', encoding='utf-8') as f:
|
149 |
html = f.read()
|
150 |
st.components.v1.html(html, height=600)
|
151 |
+
def initialize_firebase():
    """Initialize the default Firebase app if needed and return a Firestore client.

    If the default app already exists it is reused, and the ``google_json``
    environment variable is not read. Otherwise ``google_json`` is parsed as
    JSON; when it is not valid JSON (or is unset) the raw value is passed to
    ``credentials.Certificate`` unchanged.

    Returns:
        The client returned by ``firestore.client()``.
    """
    try:
        firebase_admin.get_app()
    except ValueError:
        app_exists = False
    else:
        app_exists = True
        print("Default app already exists")

    if not app_exists:
        raw = os.getenv('google_json')
        try:
            cert = json.loads(raw)
        except (TypeError, ValueError):
            # TypeError: variable unset (None). ValueError: not JSON, so
            # hand the raw value to Certificate as-is.
            cert = raw
        # Initialize the app with a service account, granting admin privileges
        firebase_admin.initialize_app(credentials.Certificate(cert))

    return firestore.client()
|
169 |
+
|
170 |
|
171 |
if __name__ == "__main__":
    # Choose the data source: the parsed local JSON data when the "local"
    # environment variable is exactly 'true', otherwise a Firestore client.
    if os.getenv("local") == 'true':
        st.session_state.db = init_json()
    else:
        st.session_state.db = initialize_firebase()

    main()
|