# File-viewer residue (revision 5e5e2eb, 4,548 bytes) — not part of the program.
import streamlit as st
import networkx as nx
from pyvis.network import Network
import textwrap
from api.apis import get_question_by_id, get_question_ids_with_correctness
def get_question_ids():
    """Return the question IDs offered by the backend API.

    This is a direct pass-through to ``api.apis.get_question_ids_with_correctness``;
    the value is handed back unchanged. The page code calls ``len()`` and
    ``.index()`` on it, so it is presumably a sequence — TODO confirm against
    the API.
    """
    ids = get_question_ids_with_correctness()
    return ids
def get_question_context(question_id):
    """Fetch the stored record for *question_id* from the backend API.

    Thin pass-through to ``api.apis.get_question_by_id``; the record is
    returned unchanged. The page renderer reads ``original_question`` and
    ``generated_result`` from it — schema not validated here.
    """
    record = get_question_by_id(question_id)
    return record
def create_interactive_graph(reasoning_chain):
    """Lay out a reasoning chain as an interactive, top-down pyvis graph.

    Every item of ``reasoning_chain`` becomes one node. The node's label is
    ``"Step <n>"`` (1-based), a blank line, then the item's text wrapped at
    30 columns. The unwrapped text is attached as the node's hover ``title``.
    Step ``n`` gets a directed edge to step ``n + 1``.

    Args:
        reasoning_chain: Ordered reasoning steps. Each item is passed to
            ``textwrap.fill``, so it must be a string.

    Returns:
        pyvis.network.Network: the styled network; the caller saves/renders it.
    """
    chain = nx.DiGraph()
    net = Network(notebook=True, width="100%", height="600px", directed=True)

    # Nodes are keyed by their 0-based position in the chain.
    for position, text in enumerate(reasoning_chain):
        chain.add_node(
            position,
            title=text,
            label="Step {}\n\n{}".format(position + 1, textwrap.fill(text, width=30)),
        )
    # Link each step to the one that follows it.
    chain.add_edges_from((pos - 1, pos) for pos in range(1, chain.number_of_nodes()))

    net.from_nx(chain)

    # Per-node vis.js styling: blue boxes, centred 12pt-size Arial text,
    # width constrained between 200 and 300. Fresh nested dicts per node.
    for vis_node in net.nodes:
        vis_node.update(
            shape="box",
            color="#97C2FC",
            font={"size": 12, "face": "arial", "multi": "html", "align": "center"},
            widthConstraint={"minimum": 200, "maximum": 300},
        )

    # Global vis.js options: nodes excluded from physics, curved arrowed
    # edges, and an up-down hierarchical layout sorted by edge direction.
    layout_options = '''
var options = {
"nodes": {
"shape": "box",
"physics": false,
"margin": 10
},
"edges": {
"smooth": {
"type": "curvedCW",
"roundness": 0.2
},
"arrows": {
"to": {
"enabled": true,
"scaleFactor": 0.5
}
}
},
"layout": {
"hierarchical": {
"enabled": true,
"direction": "UD",
"sortMethod": "directed"
}
},
"interaction": {
"hover": true,
"tooltipDelay": 100
}
}
'''
    net.set_options(layout_options)
    return net
def _describe_choice(label_choices, label):
    """Format ``"<label>: <choice text>"``, matching *label* case-insensitively."""
    if label in label_choices:
        return f"{label}: {label_choices[label]}"
    # Labels are compared case-insensitively below, so a generated "a" must
    # still resolve to choice "A" instead of raising KeyError.
    wanted = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == wanted:
            return f"{label}: {text}"
    return f"{label}: (not one of the listed choices)"


def _show_reasoning_graph(reasoning_chain):
    """Render *reasoning_chain* as an embedded interactive pyvis graph."""
    import os
    import tempfile

    net = create_interactive_graph(reasoning_chain)
    # Use a per-run temporary directory rather than a fixed "graph.html" in
    # the working directory: Streamlit serves all sessions from one process,
    # so a shared file can be overwritten by another user between the
    # write and the read.
    with tempfile.TemporaryDirectory() as tmp_dir:
        graph_path = os.path.join(tmp_dir, "graph.html")
        net.save_graph(graph_path)
        with open(graph_path, "r", encoding="utf-8") as f:
            html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Render the Q&A review page for one question at a time.

    Shows Previous/Next buttons and a selectbox to choose a question ID (the
    position is kept in ``st.session_state.current_index``). It then displays:

    - the question and its choices;
    - the correct answer and the generated answer;
    - whether the two answer labels match (case-insensitively);
    - the generated reasoning chain as an interactive graph.
    """
    st.title("Interactive Q&A App with Reasoning Chain Graph")

    # Materialize once: navigation needs len() and .index().
    question_ids = list(get_question_ids())
    if not question_ids:
        # Without this guard "% len(...)" raises ZeroDivisionError and
        # .index() on an empty selection raises ValueError.
        st.warning("No questions available.")
        return
    count = len(question_ids)

    # Initialize the session index, and clamp it in case the question list
    # shrank since it was stored (selectbox rejects out-of-range indexes).
    if "current_index" not in st.session_state:
        st.session_state.current_index = 0
    st.session_state.current_index = min(max(st.session_state.current_index, 0), count - 1)

    # Navigation buttons (wrap around at both ends).
    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (st.session_state.current_index - 1) % count
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (st.session_state.current_index + 1) % count

    selected_question_id = st.selectbox(
        "Select a question ID", question_ids, index=st.session_state.current_index
    )
    # Keep the stored index in sync when the user picks from the dropdown.
    st.session_state.current_index = question_ids.index(selected_question_id)

    data = get_question_context(selected_question_id)
    original_question = data["original_question"]
    generated_result = data["generated_result"]
    label_choices = original_question["label_choices"]

    st.subheader("Question")
    st.write(original_question["question"])

    st.subheader("Choices")
    for label, choice in label_choices.items():
        st.write(f"{label}: {choice}")

    st.subheader("Correct Answer")
    answer_label = original_question["answer"]
    st.write(_describe_choice(label_choices, answer_label))

    st.subheader("Generated Answer")
    # NOTE: "answer_key_vale" (sic) is the key the API returns; do not rename.
    generated_answer_label = generated_result["answer_key_vale"]
    st.write(_describe_choice(label_choices, generated_answer_label))

    st.subheader("is it correct ?")
    # str() guards against non-string labels (e.g. None or ints).
    is_same = str(answer_label).lower() == str(generated_answer_label).lower()
    st.write(
        f"✅ answers Are the same. answer is {answer_label}"
        if is_same
        else f"📛 answers do differ.\n But you can check reasonings.\noriginal answer label : {answer_label}\ngenerated answer label : {generated_answer_label}"
    )

    _show_reasoning_graph(generated_result["reasoning_chain"])
# Script entry point (e.g. `streamlit run <file>`).
if __name__ == "__main__":
    main()