"""Streamlit app for reviewing generated answers and rating their reasoning chains."""
import json
import os
import textwrap

import networkx as nx
import streamlit as st
from dotenv import load_dotenv
from pyvis.network import Network

# Load .env before importing the modules below, matching the original import
# order (they may read environment variables at import time).
load_dotenv()

import firebase_admin  # noqa: E402
from firebase_admin import credentials, firestore  # noqa: E402

from api.local_api import (  # noqa: E402
    get_question_by_id,
    get_question_ids_with_correctness,
    init_json,
)

# vis.js options passed verbatim to pyvis: boxed nodes, curved arrowed edges,
# and a top-down hierarchical layout.
_GRAPH_OPTIONS = '''
var options = {
  "nodes": {
    "shape": "box",
    "physics": false,
    "margin": 10
  },
  "edges": {
    "smooth": {
      "type": "curvedCW",
      "roundness": 0.2
    },
    "arrows": {
      "to": {
        "enabled": true,
        "scaleFactor": 0.5
      }
    }
  },
  "layout": {
    "hierarchical": {
      "enabled": true,
      "direction": "UD",
      "sortMethod": "directed"
    }
  },
  "interaction": {
    "hover": true,
    "tooltipDelay": 100
  }
}
'''

# Text shown when an answer label has no matching entry in label_choices.
_UNKNOWN_CHOICE = "(not among the listed choices)"


def get_question_ids():
    """Return the list of question IDs from the active database.

    Delegates to ``get_question_ids_with_correctness`` using the backend stored
    in ``st.session_state.db``.
    """
    return get_question_ids_with_correctness(st.session_state.db)


def get_question_context(question_id):
    """Return the stored record for *question_id*.

    The last two characters of the ID are dropped before the lookup.
    NOTE(review): presumably a suffix added by
    ``get_question_ids_with_correctness`` (e.g. a correctness marker); confirm.
    """
    return get_question_by_id(st.session_state.db, question_id[:-2])


def create_interactive_graph(reasoning_chain, ratings):
    """Build a top-down pyvis graph of a linear reasoning chain.

    Args:
        reasoning_chain: Sequence of step texts; each becomes a node, and
            consecutive steps are joined by a directed edge.
        ratings: Integer threshold. Nodes whose index is below it are filled
            ``#97C2FC``; the remaining nodes are filled ``#FF9999``.

    Returns:
        A configured ``pyvis.network.Network`` ready to be saved as HTML.
    """
    graph = nx.DiGraph()
    for i, step in enumerate(reasoning_chain):
        label = f"Step {i+1}\n\n{textwrap.fill(step, width=30)}"
        color = "#97C2FC" if i < ratings else "#FF9999"
        # NOTE(review): vis.js may ignore a top-level borderColor (it normally
        # reads color.border), and this polarity is the opposite of the fill
        # rule above; confirm the intended styling.
        border_color = "#00FF00" if i >= ratings else "#FF0000"
        graph.add_node(
            i,
            title=step,
            label=label,
            color=color,
            borderWidth=3,
            borderColor=border_color,
        )
        if i > 0:
            graph.add_edge(i - 1, i)

    net = Network(notebook=True, width="100%", height="600px", directed=True)
    net.from_nx(graph)
    for node in net.nodes:
        node['shape'] = 'box'
        node['font'] = {'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'}
        node['widthConstraint'] = {'minimum': 200, 'maximum': 300}
    net.set_options(_GRAPH_OPTIONS)
    return net


def update_rate(selected_question_id, i):
    """Button callback: record *i* as the rating for *selected_question_id*.

    The previous version called itself unconditionally, which recursed until
    RecursionError on every click.
    """
    st.session_state.ratings[selected_question_id] = i


def _choice_text(label_choices, label):
    """Return the choice text for *label*: exact key first, then case-insensitive, else a placeholder."""
    if label in label_choices:
        return label_choices[label]
    folded = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == folded:
            return text
    return _UNKNOWN_CHOICE


def _select_question(question_ids):
    """Render Previous/Next buttons and the ID selectbox; return the selected ID.

    *question_ids* must be non-empty.
    """
    count = len(question_ids)
    # Guard against an index left over from a longer question list.
    st.session_state.current_index = min(st.session_state.current_index, count - 1)

    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % count
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % count

    selected_question_id = st.selectbox(
        "Select a question ID", question_ids, index=st.session_state.current_index
    )
    # Keep the index in sync when the user picks directly from the selectbox.
    st.session_state.current_index = question_ids.index(selected_question_id)
    return selected_question_id


def _render_answers(original_question, generated_result):
    """Show the question, its choices, the correct and generated answers, and whether they match."""
    label_choices = original_question['label_choices']

    st.subheader("Question")
    st.write(original_question['question'])

    st.subheader("Choices")
    for label, choice in label_choices.items():
        st.write(f"{label}: {choice}")

    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    st.write(f"{answer_label}: {_choice_text(label_choices, answer_label)}")

    st.subheader("Generated Answer")
    # 'answer_key_vale' is the key name as spelled in the stored results.
    generated_answer_label = generated_result['answer_key_vale']
    st.write(f"{generated_answer_label}: {_choice_text(label_choices, generated_answer_label)}")

    st.subheader("is it correct ?")
    is_same = answer_label.lower() == generated_answer_label.lower()
    st.write(
        f"✅ answers Are the same. answer is {answer_label}"
        if is_same
        else f"📛 answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}"
    )


def _render_rating(selected_question_id, data):
    """Render the rating buttons and the reasoning-chain graph for one question."""
    reasoning_chain = data['generated_result']['reasoning_chain']

    st.subheader("Rate the Reasoning Steps")
    if selected_question_id not in st.session_state.ratings:
        st.session_state.ratings[selected_question_id] = data['max_depth']
    # on_click callbacks run before the rerun, so this already reflects a click.
    rating = st.session_state.ratings[selected_question_id]

    # Button 0 ("None") sets rating 0; button i ("Step i") sets rating i.
    cols = st.columns(len(reasoning_chain) + 1)
    for i, col in enumerate(cols):
        button_label = "None" if i == 0 else f"Step {i}"
        col.button(button_label, key=f"rate_{i}", on_click=update_rate, args=[selected_question_id, i])

    net = create_interactive_graph(reasoning_chain, rating)
    # NOTE(review): a fixed file in the CWD is shared by all concurrent sessions.
    net.save_graph("graph.html")
    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Render the review page: navigation, answer comparison, step rating and graph."""
    st.title("Interactive Q&A App with Reasoning Chain Graph and Rating")

    question_ids = get_question_ids()

    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    if 'ratings' not in st.session_state:
        st.session_state.ratings = {}

    if not question_ids:
        # Without this guard, Previous/Next would hit modulo-by-zero.
        st.warning("No questions available.")
        return

    selected_question_id = _select_question(question_ids)
    data = get_question_context(selected_question_id)

    _render_answers(data['original_question'], data['generated_result'])
    _render_rating(selected_question_id, data)


def initialize_firebase():
    """Initialise the default Firebase app (if needed) and return a Firestore client.

    The ``google_json`` environment variable holds either the service-account
    JSON itself or a path to the service-account file.

    Raises:
        ValueError: if ``google_json`` is unset, or if firebase_admin rejects
            the certificate.
    """
    raw_cert = os.getenv('google_json')
    if raw_cert is None:
        raise ValueError("Environment variable 'google_json' is not set; cannot initialise Firebase.")
    try:
        cert = json.loads(raw_cert)
    except json.JSONDecodeError:
        # Not inline JSON: treat the value as a path to the key file.
        cert = raw_cert

    try:
        firebase_admin.get_app()
        print("Default app already exists")
    except ValueError:
        # Initialize the app with a service account, granting admin privileges
        firebase_admin.initialize_app(credentials.Certificate(cert))
    return firestore.client()


if __name__ == "__main__":
    # NOTE(review): `or True` forces the local JSON backend regardless of the
    # `local` env var, so the Firebase branch is currently unreachable.
    if os.getenv("local") == 'true' or True:
        st.session_state.db = init_json()
    else:
        st.session_state.db = initialize_firebase()
    main()