|
import json |
|
import streamlit as st |
|
import networkx as nx |
|
from pyvis.network import Network |
|
import textwrap |
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
import firebase_admin |
|
from firebase_admin import credentials |
|
from firebase_admin import firestore |
|
import os |
|
from api.local_api import get_question_by_id, get_question_ids_with_correctness,init_json |
|
|
|
|
|
def get_question_ids():
    """Return the list of question IDs from the database stored in session state.

    Delegates to ``get_question_ids_with_correctness``; the exact ID format is
    whatever that helper produces.
    """
    active_db = st.session_state.db
    ids = get_question_ids_with_correctness(active_db)
    return ids
|
|
|
def get_question_context(question_id):
    """Fetch the stored record for the question selected in the UI.

    The last two characters of ``question_id`` are dropped before the lookup.
    NOTE(review): presumably this strips a suffix appended by
    ``get_question_ids_with_correctness`` (e.g. a correctness marker) — confirm.
    """
    active_db = st.session_state.db
    bare_id = question_id[:-2]
    return get_question_by_id(active_db, bare_id)
|
|
|
def create_interactive_graph(reasoning_chain, ratings):
    """Build a pyvis network that draws a reasoning chain as a top-down path.

    Args:
        reasoning_chain: Sequence of step texts. Step ``i`` becomes node ``i``
            (labelled "Step i+1"), and consecutive steps are joined by a
            directed edge.
        ratings: Integer threshold (``main`` passes the rating chosen with the
            rating buttons). Steps with index ``< ratings`` are filled blue
            (``#97C2FC``); the remaining steps are filled red (``#FF9999``).

    Returns:
        A configured ``pyvis.network.Network`` ready to be saved as HTML.
    """
    G = nx.DiGraph()
    net = Network(notebook=True, width="100%", height="600px", directed=True)

    for i, step in enumerate(reasoning_chain):
        wrapped_text = textwrap.fill(step, width=30)
        label = f"Step {i+1}\n\n{wrapped_text}"
        color = "#97C2FC" if i < ratings else "#FF9999"
        border_color = "#00FF00" if i >= ratings else "#FF0000"
        # NOTE(review): vis-network documents a node's border colour as
        # ``color.border``; a top-level ``borderColor`` key may be ignored by the
        # renderer — verify before relying on it.
        G.add_node(i, title=step, label=label, color=color, borderWidth=3, borderColor=border_color)
        if i > 0:
            # Chain the steps in order: step i-1 -> step i.
            G.add_edge(i - 1, i)

    net.from_nx(G)

    # Per-node display settings applied after conversion from networkx.
    for node in net.nodes:
        node['shape'] = 'box'
        node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
        node['widthConstraint'] = {'minimum': 200, 'maximum': 300}

    net.set_options('''
    var options = {
        "nodes": {
            "shape": "box",
            "physics": false,
            "margin": 10
        },
        "edges": {
            "smooth": {
                "type": "curvedCW",
                "roundness": 0.2
            },
            "arrows": {
                "to": {
                    "enabled": true,
                    "scaleFactor": 0.5
                }
            }
        },
        "layout": {
            "hierarchical": {
                "enabled": true,
                "direction": "UD",
                "sortMethod": "directed"
            }
        },
        "interaction": {
            "hover": true,
            "tooltipDelay": 100
        }
    }
    ''')

    return net
|
def update_rate(selected_question_id, i):
    """Store ``i`` as the rating for ``selected_question_id`` in session state.

    Used as the ``on_click`` callback of the rating buttons in ``main``. There,
    ``i`` is the button's index: 0 for "None", k for "Step k".
    """
    # The original code called update_rate again here, which recursed without
    # bound (or raised NameError at import if it was at module level).
    st.session_state.ratings[selected_question_id] = i
|
|
|
|
|
def _lookup_choice(label_choices, label):
    """Return the choice text for ``label``, falling back to a case-insensitive match.

    Answer labels are compared case-insensitively in ``main``, so a label may
    differ in case from the keys of ``label_choices``. A placeholder string is
    returned instead of raising KeyError when no key matches.
    """
    if label in label_choices:
        return label_choices[label]
    wanted = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == wanted:
            return text
    return "(not one of the listed choices)"


def _render_navigation(question_ids):
    """Draw the Previous/Next buttons and the ID selectbox; return the selected ID."""
    count = len(question_ids)
    # Keep the stored index valid if the list of IDs shrank since the last rerun.
    st.session_state.current_index %= count

    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % count
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % count

    selected_question_id = st.selectbox("Select a question ID", question_ids, index=st.session_state.current_index)
    # Sync the index with a manual selectbox choice.
    st.session_state.current_index = question_ids.index(selected_question_id)
    return selected_question_id


def _render_answers(original_question, generated_result):
    """Show the question, its choices, the correct and generated answers, and whether they match."""
    label_choices = original_question['label_choices']

    st.subheader("Question")
    st.write(original_question['question'])

    st.subheader("Choices")
    for label, choice in label_choices.items():
        st.write(f"{label}: {choice}")

    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    st.write(f"{answer_label}: {_lookup_choice(label_choices, answer_label)}")

    st.subheader("Generated Answer")
    generated_answer_label = generated_result['answer_key_vale']
    st.write(f"{generated_answer_label}: {_lookup_choice(label_choices, generated_answer_label)}")

    st.subheader("is it correct ?")
    is_same = answer_label.lower() == generated_answer_label.lower()
    if is_same:
        verdict = f"✅ answers Are the same. answer is {answer_label}"
    else:
        verdict = (
            "📛 answers do differ.\n But you can check reasonings.\n"
            f"original answer label : {answer_label}\n"
            f"generated answer label : {generated_answer_label}"
        )
    st.write(verdict)


def _render_rating_buttons(selected_question_id, step_count):
    """Draw a "None" button plus one button per step; each click calls ``update_rate``."""
    cols = st.columns(step_count + 1)
    for i, col in enumerate(cols):
        caption = "None" if i == 0 else f"Step {i}"
        col.button(caption, key=f"rate_{i}", on_click=update_rate, args=[selected_question_id, i])


def _render_graph(reasoning_chain, rating):
    """Build the reasoning-chain graph, save it to graph.html, and embed that HTML."""
    net = create_interactive_graph(reasoning_chain, rating)
    net.save_graph("graph.html")
    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Streamlit entry point: browse questions, rate reasoning steps, and view the step graph."""
    st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")

    question_ids = get_question_ids()
    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    if 'ratings' not in st.session_state:
        st.session_state.ratings = {}

    if not question_ids:
        # Navigation uses modulo arithmetic and a selectbox index, and both need
        # at least one ID.
        st.info("No questions available.")
        return

    selected_question_id = _render_navigation(question_ids)

    data = get_question_context(selected_question_id)
    original_question = data['original_question']
    generated_result = data['generated_result']
    reasoning_chain = generated_result['reasoning_chain']

    _render_answers(original_question, generated_result)

    st.subheader("Rate the Reasoning Steps")
    ratings = st.session_state.ratings
    # Until the user clicks a rating button, default to the record's max_depth.
    if selected_question_id not in ratings:
        ratings[selected_question_id] = data['max_depth']
    rating = ratings[selected_question_id]
    _render_rating_buttons(selected_question_id, len(reasoning_chain))

    _render_graph(reasoning_chain, rating)
|
def initialize_firebase():
    """
    Initialize Firebase app and return Firestore client.

    If the app is already initialized, it returns the existing Firestore client.

    Credentials come from the ``google_json`` environment variable. It may hold
    either the service-account JSON document itself or a value that
    ``credentials.Certificate`` accepts directly (e.g. a key-file path).

    Raises:
        ValueError: if ``google_json`` is not set, or if ``credentials.Certificate``
            rejects the value.
    """
    raw_cert = os.getenv('google_json')
    if raw_cert is None:
        raise ValueError("Environment variable 'google_json' is not set; cannot initialize Firebase.")

    try:
        # The variable may contain the JSON document inline...
        cert = json.loads(raw_cert)
    except json.JSONDecodeError:
        # ...or something Certificate accepts as-is, such as a file path.
        cert = raw_cert
    cred = credentials.Certificate(cert)

    try:
        firebase_admin.get_app()
        print("Default app already exists")
    except ValueError:
        # get_app() raises ValueError when no default app exists yet.
        firebase_admin.initialize_app(cred)

    return firestore.client()
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the trailing `or True` forces the local JSON backend no matter
    # what the `local` env var says, so the Firebase branch is unreachable. The
    # data helpers come from api.local_api and may not accept a Firestore client,
    # so confirm before removing it.
    use_local_db = os.getenv("local") == 'true' or True
    st.session_state.db = init_json() if use_local_db else initialize_firebase()
    main()