ReasMe / app.py
MSNP1381
change in data read
40a2be0
raw
history blame
6.14 kB
import json
import streamlit as st
import networkx as nx
from pyvis.network import Network
import textwrap
from dotenv import load_dotenv
load_dotenv()
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import os
from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json
def get_question_ids():
    """Return the list of question IDs offered in the selector.

    Delegates to ``get_question_ids_with_correctness`` with the data source
    kept in ``st.session_state.db``.
    NOTE(review): judging by the helper's name, the returned IDs presumably
    carry a correctness marker — confirm in ``api.local_api``.
    """
    source = st.session_state.db
    return get_question_ids_with_correctness(source)
def get_question_context(question_id):
    """Fetch the stored record for the question displayed as *question_id*.

    The last two characters of the displayed ID are dropped before the lookup
    in ``st.session_state.db``.
    NOTE(review): presumably that suffix is a correctness marker appended by
    ``get_question_ids_with_correctness`` — confirm it is always exactly two
    characters long.
    """
    bare_id = question_id[:-2]
    return get_question_by_id(st.session_state.db, bare_id)
def create_interactive_graph(reasoning_chain, ratings):
    """Build a pyvis network that draws a reasoning chain top-down.

    Each step becomes a box node labelled "Step k" followed by the step text
    wrapped at 30 characters. Consecutive steps are joined by a directed edge,
    so the graph is a single path laid out hierarchically.

    Args:
        reasoning_chain: Ordered sequence of step strings.
        ratings: Integer cut-off (despite the plural name). Steps whose
            zero-based index is below it get the blue fill; the rest get the
            red fill.

    Returns:
        A configured ``pyvis.network.Network``, ready to be saved as HTML.
    """
    net = Network(notebook=True, width="100%", height="600px", directed=True)
    graph = nx.DiGraph()
    for i, step in enumerate(reasoning_chain):
        accepted = i < ratings
        graph.add_node(
            i,
            title=step,  # full, unwrapped text for the hover tooltip
            label=f"Step {i+1}\n\n{textwrap.fill(step, width=30)}",
            color="#97C2FC" if accepted else "#FF9999",
            borderWidth=3,
            # NOTE(review): vis.js takes the border colour from
            # ``color.border``; a top-level ``borderColor`` is probably
            # ignored — confirm before relying on it.
            borderColor="#FF0000" if accepted else "#00FF00",
        )
        if i > 0:
            graph.add_edge(i - 1, i)
    net.from_nx(graph)
    for node in net.nodes:
        node['shape'] = 'box'
        node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
        node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
    # pyvis parses the options string as JSON, so serialize a plain dict
    # rather than maintaining a hand-written JavaScript literal.
    options = {
        "nodes": {"shape": "box", "physics": False, "margin": 10},
        "edges": {
            "smooth": {"type": "curvedCW", "roundness": 0.2},
            "arrows": {"to": {"enabled": True, "scaleFactor": 0.5}},
        },
        "layout": {
            "hierarchical": {
                "enabled": True,
                "direction": "UD",
                "sortMethod": "directed",
            }
        },
        "interaction": {"hover": True, "tooltipDelay": 100},
    }
    net.set_options(json.dumps(options))
    return net
def update_rate(selected_question_id, i):
    """Store the user's rating for a question; used as a button ``on_click`` callback.

    Args:
        selected_question_id: Key under which the rating is stored in
            ``st.session_state.ratings``.
        i: Rating value. The rating buttons pass 0 for "None" and k for
            "Step k".
    """
    # A single assignment is the whole job. Re-invoking this function here
    # (as the previous body did) recursed forever.
    st.session_state.ratings[selected_question_id] = i
def _lookup_choice(label_choices, label):
    """Return the choice text for *label*, falling back to a case-insensitive match.

    Returns a placeholder when nothing matches, so an unexpected label is shown
    on the page instead of raising KeyError.
    """
    if label in label_choices:
        return label_choices[label]
    wanted = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == wanted:
            return text
    return "(not among the listed choices)"


def _select_question(question_ids):
    """Render the Previous/Next buttons and the ID selector; return the chosen ID.

    *question_ids* must be a non-empty list.
    """
    count = len(question_ids)
    # Guard against a stored index that no longer fits the current list.
    st.session_state.current_index %= count
    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % count
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % count
    selected_question_id = st.selectbox(
        "Select a question ID", question_ids, index=st.session_state.current_index
    )
    st.session_state.current_index = question_ids.index(selected_question_id)
    return selected_question_id


def _render_answers(original_question, generated_result):
    """Show the question, its choices, the reference and generated answers, and whether they agree."""
    choices = original_question['label_choices']
    st.subheader("Question")
    st.write(original_question['question'])
    st.subheader("Choices")
    for label, choice in choices.items():
        st.write(f"{label}: {choice}")
    # Display correct answer
    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    st.write(f"{answer_label}: {_lookup_choice(choices, answer_label)}")
    st.subheader("Generated Answer")
    # 'answer_key_vale' (sic) is the key as read from the data; keep the spelling.
    generated_answer_label = generated_result['answer_key_vale']
    st.write(f"{generated_answer_label}: {_lookup_choice(choices, generated_answer_label)}")
    st.subheader("is it correct ?")
    is_same = answer_label.lower() == generated_answer_label.lower()
    st.write(
        f"✅ answers Are the same. answer is {answer_label}"
        if is_same
        else f"📛 answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}"
    )


def _render_rating(selected_question_id, data, reasoning_chain):
    """Render a "None" button plus one button per step; return the question's current rating.

    A question without a stored rating starts at ``data['max_depth']``.
    """
    st.subheader("Rate the Reasoning Steps")
    ratings = st.session_state.ratings
    if selected_question_id not in ratings:
        ratings[selected_question_id] = data['max_depth']
    rating = ratings[selected_question_id]
    for i, col in enumerate(st.columns(len(reasoning_chain) + 1)):
        col.button(
            "None" if i == 0 else f"Step {i}",
            key=f"rate_{i}",
            on_click=update_rate,
            args=[selected_question_id, i],
        )
    return rating


def main():
    """Render the page for browsing questions and rating their reasoning chains.

    Shows navigation controls, the selected question with its reference and
    generated answers, rating buttons for the reasoning steps, and an
    interactive graph of the chain coloured by the current rating.
    """
    st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")
    question_ids = get_question_ids()
    if not question_ids:
        # Without any IDs the modulo navigation would divide by zero and the
        # selector would return None.
        st.warning("No questions available.")
        return
    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    if 'ratings' not in st.session_state:
        st.session_state.ratings = {}
    selected_question_id = _select_question(question_ids)
    data = get_question_context(selected_question_id)
    generated_result = data['generated_result']
    _render_answers(data['original_question'], generated_result)
    reasoning_chain = generated_result['reasoning_chain']
    rating = _render_rating(selected_question_id, data, reasoning_chain)
    net = create_interactive_graph(reasoning_chain, rating)
    # NOTE(review): this fixed path in the working directory is shared by all
    # sessions; concurrent users may overwrite each other's graph.
    net.save_graph("graph.html")
    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)
def initialize_firebase():
    """
    Initialize the default Firebase app (once) and return a Firestore client.

    Service-account credentials come from the ``google_json`` environment
    variable. It may hold either the JSON document itself or a path to a key
    file; ``credentials.Certificate`` accepts a dict or a path.
    If the app is already initialized, the existing default app is reused.

    Returns:
        A Firestore client for the default app.

    Raises:
        RuntimeError: if ``google_json`` is not set.
    """
    raw = os.getenv('google_json')
    if raw is None:
        raise RuntimeError("The 'google_json' environment variable is not set.")
    try:
        cert = json.loads(raw)
    except json.JSONDecodeError:
        # Not a JSON document: treat the value as a path to the key file.
        cert = raw
    try:
        firebase_admin.get_app()
        print("Default app already exists")
    except ValueError:
        # Initialize the app with a service account, granting admin privileges
        firebase_admin.initialize_app(credentials.Certificate(cert))
    return firestore.client()
if __name__ == "__main__":
    # Pick the data source, store it in session state, then render the page.
    # NOTE(review): the trailing `or True` makes this always true, so the
    # Firebase branch is unreachable regardless of the `local` env var —
    # presumably a debugging override; confirm before removing it.
    use_local = os.getenv("local") == 'true' or True
    st.session_state.db = init_json() if use_local else initialize_firebase()
    main()