ReasMe / app.py
MSNP1381
ratings added
f4f1e28
raw
history blame
6.13 kB
import json
import streamlit as st
import networkx as nx
from pyvis.network import Network
import textwrap
from dotenv import load_dotenv
load_dotenv()
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import os
from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json
def get_question_ids():
    """Return the question IDs to offer in the question selector.

    Delegates to ``get_question_ids_with_correctness`` with the database handle
    stored in ``st.session_state.db``. main() uses the result as a list
    (``len``, ``.index`` and as ``st.selectbox`` options).
    NOTE(review): judging by the helper's name and by get_question_context
    stripping two trailing characters, each ID presumably carries a
    correctness marker suffix — confirm against api.local_api.
    """
    database = st.session_state.db
    return get_question_ids_with_correctness(database)
def get_question_context(question_id):
    """Fetch the stored record for *question_id* from the active database.

    The last two characters of *question_id* are removed before the lookup.
    NOTE(review): presumably that two-character tail is the correctness marker
    added by get_question_ids_with_correctness — verify the ID format there.
    main() reads ``original_question``, ``generated_result`` and ``max_depth``
    from the returned record.
    """
    db = st.session_state.db
    lookup_id = question_id[:-2]
    return get_question_by_id(db, lookup_id)
# vis.js options applied to every reasoning graph via Network.set_options:
# box-shaped nodes with physics disabled, a top-down ("UD") hierarchical
# layout, small curved arrows, and hover tooltips.
_GRAPH_OPTIONS = '''
var options = {
  "nodes": {
    "shape": "box",
    "physics": false,
    "margin": 10
  },
  "edges": {
    "smooth": {
      "type": "curvedCW",
      "roundness": 0.2
    },
    "arrows": {
      "to": {
        "enabled": true,
        "scaleFactor": 0.5
      }
    }
  },
  "layout": {
    "hierarchical": {
      "enabled": true,
      "direction": "UD",
      "sortMethod": "directed"
    }
  },
  "interaction": {
    "hover": true,
    "tooltipDelay": 100
  }
}
'''


def create_interactive_graph(reasoning_chain, ratings):
    """Build a pyvis network drawing a reasoning chain as a linked path of steps.

    Args:
        reasoning_chain: sequence of step texts. Step ``i`` becomes node ``i``
            (labelled "Step i+1" plus the text wrapped at 30 chars, with the
            full text as hover title) and gets an edge to step ``i + 1``.
        ratings: despite the plural name, a single number — the rating stored
            for the question. Steps with index ``< ratings`` are filled
            ``#97C2FC`` (blue); the remaining steps are filled ``#FF9999`` (red).

    Returns:
        A ``pyvis.network.Network`` with box nodes and the options in
        ``_GRAPH_OPTIONS`` applied, ready for ``save_graph``.
    """
    G = nx.DiGraph()
    net = Network(notebook=True, width="100%", height="600px", directed=True)
    for i, step in enumerate(reasoning_chain):
        wrapped_text = textwrap.fill(step, width=30)
        label = f"Step {i+1}\n\n{wrapped_text}"
        color = "#97C2FC" if i < ratings else "#FF9999"
        # NOTE(review): vis.js reads a node's border colour from
        # ``color.border``, not a ``borderColor`` attribute, so this value
        # likely has no visible effect. Its condition is also the inverse of
        # the fill (green border on red-filled steps) — confirm the intent
        # before wiring it into ``color``.
        border_color = "#00FF00" if i >= ratings else "#FF0000"
        G.add_node(i, title=step, label=label, color=color, borderWidth=3, borderColor=border_color)
        if i > 0:
            G.add_edge(i - 1, i)
    net.from_nx(G)
    # Per-node styling applied after conversion to the pyvis node dicts.
    for node in net.nodes:
        node['shape'] = 'box'
        node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
        node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
    net.set_options(_GRAPH_OPTIONS)
    return net
def update_rate(selected_question_id, i):
    """Record rating *i* for *selected_question_id* in ``st.session_state.ratings``.

    Registered as the ``on_click`` callback of the rating buttons built in
    main(): button 0 ("None") stores 0, button k ("Step k") stores k.
    The previous version called itself unconditionally after the assignment,
    which recursed until RecursionError on every click.
    """
    st.session_state.ratings[selected_question_id] = i
def _lookup_choice(label_choices, label, default=None):
    """Return the choice text for *label*: exact key first, then case-insensitive, else *default*."""
    if label in label_choices:
        return label_choices[label]
    wanted = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == wanted:
            return text
    return default


def _render_answer_section(original_question, generated_result):
    """Show the question, its choices, the correct and generated answers, and whether they agree."""
    choices = original_question['label_choices']
    st.subheader("Question")
    st.write(original_question['question'])
    st.subheader("Choices")
    for label, choice in choices.items():
        st.write(f"{label}: {choice}")
    # Display correct answer
    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    st.write(f"{answer_label}: {choices[answer_label]}")
    st.subheader("Generated Answer")
    generated_answer_label = generated_result['answer_key_vale']
    # The agreement check below ignores case, so the generated label may not
    # match a choice key exactly; an exact-only lookup raised KeyError here.
    generated_answer = _lookup_choice(choices, generated_answer_label,
                                      "(not one of the listed choices)")
    st.write(f"{generated_answer_label}: {generated_answer}")
    st.subheader("is it correct ?")
    is_same = str(answer_label).lower() == str(generated_answer_label).lower()
    if is_same:
        st.write(f"✅ answers Are the same. answer is {answer_label}")
    else:
        st.write(f"📛 answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}")


def _render_rating_buttons(selected_question_id, n_steps):
    """Draw a "None" button plus one "Step i" button per step; clicking stores index i via update_rate."""
    cols = st.columns(n_steps + 1)
    for i, col in enumerate(cols):
        caption = "None" if i == 0 else f"Step {i}"
        col.button(caption, key=f"rate_{i}", on_click=update_rate, args=[selected_question_id, i])


def _render_graph(reasoning_chain, rating):
    """Build the reasoning graph, write it to graph.html and embed that HTML in the page."""
    net = create_interactive_graph(reasoning_chain, rating)
    # NOTE(review): "graph.html" is a fixed path in the working directory, so
    # concurrent sessions can overwrite each other's file between the save and
    # the read below — consider a per-session temporary file.
    net.save_graph("graph.html")
    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Render the review page for the currently selected question.

    Uses ``st.session_state`` keys:
      * ``db`` — must be set before main() runs (read by get_question_ids).
      * ``current_index`` — position in the question list, driven by the
        Previous/Next buttons and the selectbox.
      * ``ratings`` — question ID -> chosen rating, defaulting to the
        record's ``max_depth`` the first time a question is shown.
    """
    st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")
    question_ids = get_question_ids()
    if not question_ids:
        # Previously: modulo by zero on Previous/Next and .index(None) after an
        # empty selectbox.
        st.warning("No questions available.")
        return
    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    # Keep the stored index valid if the question list shrank since the last
    # rerun; st.selectbox rejects an out-of-range index.
    st.session_state.current_index %= len(question_ids)
    if 'ratings' not in st.session_state:
        st.session_state.ratings = {}

    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % len(question_ids)
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % len(question_ids)
    selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
    st.session_state.current_index = question_ids.index(selected_question_id)

    data = get_question_context(selected_question_id)
    original_question = data['original_question']
    generated_result = data['generated_result']
    _render_answer_section(original_question, generated_result)

    st.subheader("Rate the Reasoning Steps")
    if selected_question_id not in st.session_state.ratings:
        st.session_state.ratings[selected_question_id] = data['max_depth']
    rating = st.session_state.ratings[selected_question_id]
    reasoning_chain = generated_result['reasoning_chain']
    _render_rating_buttons(selected_question_id, len(reasoning_chain))
    _render_graph(reasoning_chain, rating)
def initialize_firebase():
    """Initialize the default Firebase app if needed and return a Firestore client.

    Credentials come from the ``google_json`` environment variable:
      * if it parses as JSON, the parsed object is passed to
        ``credentials.Certificate``;
      * otherwise (non-JSON text, or unset) the raw value is passed through
        unchanged — e.g. a path to a service-account key file.
    If the default app already exists, it is reused and the credentials are
    not read at all.
    """
    try:
        firebase_admin.get_app()
        print("Default app already exists")
    except ValueError:
        raw_cert = os.getenv('google_json')
        try:
            cert = json.loads(raw_cert)
        except (TypeError, ValueError):
            # TypeError: variable unset (None); ValueError (JSONDecodeError):
            # not JSON text. Previously a bare ``except:`` also swallowed
            # KeyboardInterrupt/SystemExit here.
            cert = raw_cert
        cred = credentials.Certificate(cert)
        # Initialize the app with a service account, granting admin privileges
        firebase_admin.initialize_app(cred)
    return firestore.client()
if __name__ == "__main__":
    # Select the data backend and store it in st.session_state.db, where
    # get_question_ids / get_question_context read it: the local JSON store
    # when the "local" environment variable is exactly 'true', Firestore
    # otherwise.
    use_local_store = os.getenv("local") == 'true'
    st.session_state.db = init_json() if use_local_store else initialize_firebase()
    main()