ReasMe / app.py
MSNP1381
app updated
937b552
raw
history blame
4.55 kB
import os
import tempfile
import textwrap

import networkx as nx
import streamlit as st
from pyvis.network import Network

from api.apis import get_question_by_id, get_question_ids_with_correctness
def get_question_ids():
    """Return every question ID available for browsing, as a list.

    Delegates to ``get_question_ids_with_correctness`` from ``api.apis``.
    The result is materialised into a list because callers such as ``main``
    take its ``len()``, call ``.index()`` on it and index into it; a ``None``
    result is treated as "no questions" and becomes an empty list.
    """
    question_ids = get_question_ids_with_correctness()
    if question_ids is None:
        return []
    return list(question_ids)
def get_question_context(question_id):
    """Fetch the stored record for *question_id* via ``get_question_by_id``.

    The final two characters of the ID are removed before the lookup.
    NOTE(review): presumably the IDs produced by
    ``get_question_ids_with_correctness`` carry a two-character correctness
    suffix that the lookup API does not expect — confirm against ``api.apis``.

    The returned mapping is read by ``main`` for its ``'original_question'``
    and ``'generated_result'`` entries.
    """
    base_id = question_id[:-2]
    return get_question_by_id(base_id)
# Display options handed verbatim to pyvis' Network.set_options(): boxed
# nodes, curved arrowed edges, a top-down hierarchical layout and hover tips.
_REASONING_GRAPH_OPTIONS = '''
var options = {
"nodes": {
"shape": "box",
"physics": false,
"margin": 10
},
"edges": {
"smooth": {
"type": "curvedCW",
"roundness": 0.2
},
"arrows": {
"to": {
"enabled": true,
"scaleFactor": 0.5
}
}
},
"layout": {
"hierarchical": {
"enabled": true,
"direction": "UD",
"sortMethod": "directed"
}
},
"interaction": {
"hover": true,
"tooltipDelay": 100
}
}
'''


def create_interactive_graph(reasoning_chain):
    """Build a pyvis Network that draws *reasoning_chain* as a linked chain.

    Each step becomes a node labelled ``Step <n>`` followed by the step text
    wrapped at 30 characters; the full, unwrapped text is stored as the
    node's ``title`` (shown on hover). Consecutive steps are joined by
    directed edges, and every node is restyled as a fixed-colour box.

    Args:
        reasoning_chain: iterable of step strings (each is passed to
            ``textwrap.fill``, so it must be a ``str``).

    Returns:
        The configured ``pyvis.network.Network``.
    """
    chain_graph = nx.DiGraph()
    network = Network(notebook=True, width="100%", height="600px", directed=True)

    previous_id = None
    for node_id, step_text in enumerate(reasoning_chain):
        body = textwrap.fill(step_text, width=30)
        chain_graph.add_node(
            node_id,
            title=step_text,
            label="Step {}\n\n{}".format(node_id + 1, body),
        )
        # Link each step to the one before it (the first step has no parent).
        if previous_id is not None:
            chain_graph.add_edge(previous_id, node_id)
        previous_id = node_id

    network.from_nx(chain_graph)

    # Fresh dicts per node so no two nodes share mutable style objects.
    for node in network.nodes:
        node.update(
            shape='box',
            color='#97C2FC',
            font={'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'},
            widthConstraint={'minimum': 200, 'maximum': 300},
        )

    network.set_options(_REASONING_GRAPH_OPTIONS)
    return network
def _lookup_choice(label_choices, label):
    """Return the choice text stored under *label*, or ``None`` if absent.

    An exact key match is tried first; otherwise keys are compared
    case-insensitively, mirroring the ``.lower()`` comparison ``main`` uses
    to decide whether the generated answer is correct.
    """
    if label in label_choices:
        return label_choices[label]
    wanted = str(label).lower()
    for key, text in label_choices.items():
        if str(key).lower() == wanted:
            return text
    return None


def _format_choice(label_choices, label):
    """Return ``"<label>: <choice>"``, flagging labels that match no choice."""
    text = _lookup_choice(label_choices, label)
    if text is None:
        return f"{label}: (not one of the listed choices)"
    return f"{label}: {text}"


def _select_question_id(question_ids):
    """Render Previous/Next buttons and a selectbox; return the chosen ID.

    The selectbox is bound to ``st.session_state["selected_question_id"]``,
    and the buttons move that value inside ``on_click`` callbacks (which run
    before the script body). A keyless selectbox whose ``index=`` changes
    between runs is recreated as a new widget, which silently discarded every
    other manual selection in the previous implementation.

    Args:
        question_ids: non-empty sequence of question IDs.
    """
    # Seed or repair the selection, e.g. on first load or after the ID list
    # changed; start from the previously stored index when there is one.
    if st.session_state.get("selected_question_id") not in question_ids:
        start = st.session_state.get("current_index", 0) % len(question_ids)
        st.session_state.selected_question_id = question_ids[start]

    def _step(offset):
        # Move the selection by *offset*, wrapping around both ends.
        current = st.session_state.get("selected_question_id")
        position = question_ids.index(current) if current in question_ids else 0
        st.session_state.selected_question_id = question_ids[
            (position + offset) % len(question_ids)
        ]

    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        st.button("Previous", on_click=_step, args=(-1,))
    with col3:
        st.button("Next", on_click=_step, args=(1,))

    selected_question_id = st.selectbox(
        "Select a question ID", question_ids, key="selected_question_id"
    )
    # Keep current_index in sync, as the original navigation stored it.
    st.session_state.current_index = question_ids.index(selected_question_id)
    return selected_question_id


def _render_reasoning_graph(reasoning_chain):
    """Render *reasoning_chain* as an embedded interactive pyvis graph.

    The HTML is written to a private temporary directory instead of a fixed
    ``graph.html`` in the working directory, so concurrent sessions cannot
    overwrite each other's file and nothing is left behind.
    """
    net = create_interactive_graph(reasoning_chain)
    with tempfile.TemporaryDirectory() as tmp_dir:
        graph_path = os.path.join(tmp_dir, "graph.html")
        net.save_graph(graph_path)
        with open(graph_path, 'r', encoding='utf-8') as f:
            html = f.read()
    st.components.v1.html(html, height=600)


def main():
    """Render the Q&A browser: navigation, question details and reasoning graph."""
    st.title("Interactive Q&A App with Reasoning Chain Graph")

    question_ids = get_question_ids()
    if not question_ids:
        # Without this guard the wrap-around navigation divides by zero.
        st.warning("No questions available.")
        return

    selected_question_id = _select_question_id(question_ids)

    # Get the context for the selected question.
    data = get_question_context(selected_question_id)
    original_question = data['original_question']
    generated_result = data['generated_result']
    label_choices = original_question['label_choices']

    st.subheader("Question")
    st.write(original_question['question'])

    st.subheader("Choices")
    for label, choice in label_choices.items():
        st.write(f"{label}: {choice}")

    answer_label = original_question['answer']
    st.subheader("Correct Answer")
    st.write(_format_choice(label_choices, answer_label))

    # 'answer_key_vale' (sic) is the key spelling used in the returned data.
    generated_answer_label = generated_result['answer_key_vale']
    st.subheader("Generated Answer")
    st.write(_format_choice(label_choices, generated_answer_label))

    st.subheader("is it correct ?")
    is_same = str(answer_label).lower() == str(generated_answer_label).lower()
    if is_same:
        st.write(f"✅ answers Are the same. answer is {answer_label}")
    else:
        # Markdown needs "  \n" for a hard line break; a bare "\n" is collapsed.
        st.write(
            "📛 answers do differ.  \n"
            "But you can check reasonings.  \n"
            f"original answer label : {answer_label}  \n"
            f"generated answer label : {generated_answer_label}"
        )

    _render_reasoning_graph(generated_result['reasoning_chain'])
# Entry point: build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()