File size: 6,137 Bytes
f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb f4f1e28 5e5e2eb 40a2be0 f4f1e28 5e5e2eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import json
import streamlit as st
import networkx as nx
from pyvis.network import Network
import textwrap
from dotenv import load_dotenv
load_dotenv()
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import os
from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json
def get_question_ids():
    """Return the question IDs offered in the question selector.

    The IDs are produced by ``get_question_ids_with_correctness`` from the
    data-source handle kept in ``st.session_state.db``.
    NOTE(review): presumably each ID carries a short correctness marker that
    ``get_question_context`` strips again — confirm against api.local_api.
    """
    database = st.session_state.db
    return get_question_ids_with_correctness(database)
def get_question_context(question_id):
    """Fetch the stored record for the question ID selected in the UI.

    The final two characters of ``question_id`` are dropped before the lookup.
    NOTE(review): presumably that is the correctness marker appended by
    ``get_question_ids_with_correctness`` — confirm.
    """
    database = st.session_state.db
    lookup_id = question_id[:-2]
    return get_question_by_id(database, lookup_id)
def create_interactive_graph(reasoning_chain, ratings):
    """Build a pyvis network drawing a reasoning chain as a top-down list of boxes.

    Each step becomes one node labelled "Step N" followed by the step text
    wrapped at 30 characters. Consecutive steps are joined by a directed edge.

    Args:
        reasoning_chain: Ordered sequence of step strings.
        ratings: Integer threshold. Nodes whose 0-based index is below it are
            filled blue (#97C2FC) with a red border. The remaining nodes are
            filled pink (#FF9999) with a green border. In ``main`` this is the
            rating stored for the current question.

    Returns:
        A configured ``pyvis.network.Network`` (hierarchical, top-down layout).
    """
    G = nx.DiGraph()
    net = Network(notebook=True, width="100%", height="600px", directed=True)
    for i, step in enumerate(reasoning_chain):
        wrapped_text = textwrap.fill(step, width=30)
        label = f"Step {i+1}\n\n{wrapped_text}"
        below_threshold = i < ratings
        background = "#97C2FC" if below_threshold else "#FF9999"
        # NOTE(review): the border colours are the reverse of the fill scheme
        # (red border on blue nodes, green on pink); kept as written — confirm intent.
        border = "#FF0000" if below_threshold else "#00FF00"
        # vis.js has no top-level ``borderColor`` node option; the border colour
        # must live inside the ``color`` object or it is silently ignored.
        G.add_node(
            i,
            title=step,
            label=label,
            color={"background": background, "border": border},
            borderWidth=3,
        )
        if i > 0:
            G.add_edge(i - 1, i)
    net.from_nx(G)
    for node in net.nodes:
        node['shape'] = 'box'
        node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
        node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
    net.set_options('''
    var options = {
        "nodes": {
            "shape": "box",
            "physics": false,
            "margin": 10
        },
        "edges": {
            "smooth": {
                "type": "curvedCW",
                "roundness": 0.2
            },
            "arrows": {
                "to": {
                    "enabled": true,
                    "scaleFactor": 0.5
                }
            }
        },
        "layout": {
            "hierarchical": {
                "enabled": true,
                "direction": "UD",
                "sortMethod": "directed"
            }
        },
        "interaction": {
            "hover": true,
            "tooltipDelay": 100
        }
    }
    ''')
    return net
def update_rate(selected_question_id, i):
    """Store the rating chosen for a question; used as the rating buttons' ``on_click`` callback.

    Args:
        selected_question_id: ID of the question being rated, as shown in the selectbox.
        i: Rating value. The buttons in ``main`` pass 0 for "None" and k for "Step k".
    """
    # The original re-invoked update_rate here, which either recursed forever
    # (inside the function) or raised NameError at import (at module level).
    st.session_state.ratings[selected_question_id] = i
def _render_answers(original_question, generated_result):
    """Show the question, its choices, the reference and generated answers, and whether they agree.

    Args:
        original_question: Record with 'question', 'label_choices' (label -> text) and 'answer' keys.
        generated_result: Record whose 'answer_key_vale' entry holds the generated answer label.
    """
    choices = original_question['label_choices']

    st.subheader("Question")
    st.write(original_question['question'])

    st.subheader("Choices")
    for label, choice in choices.items():
        st.write(f"{label}: {choice}")

    # Display correct answer
    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    st.write(f"{answer_label}: {choices[answer_label]}")

    st.subheader("Generated Answer")
    # The key name, including its "vale" spelling, must match the stored record.
    generated_answer_label = generated_result['answer_key_vale']
    # The generated label is compared case-insensitively below, so it is not
    # guaranteed to be an exact key of ``choices``; avoid a page-aborting KeyError.
    generated_answer = choices.get(generated_answer_label, "(not one of the listed choices)")
    st.write(f"{generated_answer_label}: {generated_answer}")

    st.subheader("is it correct ?")
    is_same = answer_label.lower() == generated_answer_label.lower()
    st.write(
        f"✅ answers Are the same. answer is {answer_label}"
        if is_same
        else f"📛 answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}"
    )


def _render_rating_buttons(selected_question_id, reasoning_chain):
    """Draw one "None" button plus one "Step k" button per reasoning step, each calling update_rate."""
    cols = st.columns(len(reasoning_chain) + 1)
    for i, col in enumerate(cols):
        caption = "None" if i == 0 else f"Step {i}"
        col.button(caption, key=f"rate_{i}", on_click=update_rate, args=[selected_question_id, i])


def _render_graph(reasoning_chain, rating):
    """Render the reasoning-chain graph by saving it to ``graph.html`` and embedding that HTML."""
    net = create_interactive_graph(reasoning_chain, rating)
    net.save_graph("graph.html")
    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Streamlit entry point.

    It lets the user pick a question, compares the correct and generated
    answers, rates the reasoning steps and shows the graph.

    ``current_index`` and ``ratings`` are kept in ``st.session_state`` across
    reruns. The data comes from ``st.session_state.db`` via the lookup helpers.
    """
    st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")
    question_ids = get_question_ids()
    if not question_ids:
        # Previous/Next use ``% len(question_ids)`` and the selectbox needs at least
        # one option, so an empty data source would otherwise crash the page.
        st.warning("No questions available.")
        return
    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    if 'ratings' not in st.session_state:
        st.session_state.ratings = {}
    # Keep a stored index valid if the number of questions changed between reruns.
    st.session_state.current_index %= len(question_ids)

    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % len(question_ids)
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
    selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
    st.session_state.current_index = question_ids.index(selected_question_id)

    data = get_question_context(selected_question_id)
    generated_result = data['generated_result']
    _render_answers(data['original_question'], generated_result)

    st.subheader("Rate the Reasoning Steps")
    if selected_question_id not in st.session_state.ratings:
        # First visit to this question: default the rating to the record's 'max_depth'.
        st.session_state.ratings[selected_question_id] = data['max_depth']
    rating = st.session_state.ratings[selected_question_id]
    reasoning_chain = generated_result['reasoning_chain']
    _render_rating_buttons(selected_question_id, reasoning_chain)
    _render_graph(reasoning_chain, rating)
def initialize_firebase():
    """
    Initialize the default Firebase app if needed and return a Firestore client.

    If the default app already exists, it is reused and no credentials are
    loaded. Otherwise the credentials come from the ``google_json``
    environment variable:

    - If the value parses as JSON, the parsed object is used.
    - Otherwise the raw value is passed unchanged to
      ``credentials.Certificate``. NOTE(review): presumably a file path — confirm.
    """
    try:
        firebase_admin.get_app()
        already_initialized = True
    except ValueError:
        # get_app() raises ValueError when the default app has not been created yet.
        already_initialized = False

    if already_initialized:
        print("Default app already exists")
    else:
        raw_cert = os.getenv('google_json')
        try:
            cert = json.loads(raw_cert)
        except (TypeError, ValueError):
            # TypeError: variable unset (None); ValueError (JSONDecodeError): not inline JSON.
            cert = raw_cert
        # Initialize the app with a service account, granting admin privileges
        firebase_admin.initialize_app(credentials.Certificate(cert))
    return firestore.client()
if __name__ == "__main__":
    # The original condition ``== 'true' or True`` was always true, so Firestore
    # was unreachable. Local JSON stays the default when ``local`` is unset, but
    # setting ``local`` to anything other than "true" now selects Firestore.
    use_local = os.getenv("local", "true").lower() == "true"
    if use_local:
        st.session_state.db = init_json()
    else:
        st.session_state.db = initialize_firebase()
    main()