|
import streamlit as st |
|
import networkx as nx |
|
from pyvis.network import Network |
|
import textwrap |
|
|
|
from api.apis import get_question_by_id, get_question_ids_with_correctness |
|
|
|
|
|
def get_question_ids():
    """Return the question identifiers offered in the selector.

    Delegates to ``get_question_ids_with_correctness`` from ``api.apis`` and
    passes its result back unchanged.
    """
    ids_with_correctness = get_question_ids_with_correctness()
    return ids_with_correctness
|
|
|
def get_question_context(question_id):
    """Fetch the record for a question identifier chosen in the UI.

    The last two characters of ``question_id`` are dropped before calling
    ``get_question_by_id``.
    NOTE(review): presumably those two characters are a correctness suffix
    added by ``get_question_ids_with_correctness`` -- confirm against the API.
    """
    bare_id = question_id[:-2]
    return get_question_by_id(bare_id)
|
|
|
def create_interactive_graph(reasoning_chain):
    """Build a pyvis network that shows a reasoning chain as linked boxes.

    Every step becomes one node whose label is "Step N" followed by the step
    text wrapped at 30 characters; the unwrapped text is stored as the node
    ``title``. Each step is connected to the one before it by a directed edge,
    and the hierarchical layout is configured with direction "UD".

    Args:
        reasoning_chain: Iterable of step strings, in display order.

    Returns:
        The configured ``pyvis.network.Network`` instance.
    """
    chain_graph = nx.DiGraph()
    prev_idx = None
    for idx, step_text in enumerate(reasoning_chain):
        body = textwrap.fill(step_text, width=30)
        chain_graph.add_node(
            idx,
            title=step_text,
            label=f"Step {idx+1}\n\n{body}",
        )
        # Link each step to its predecessor; the first step has none.
        if prev_idx is not None:
            chain_graph.add_edge(prev_idx, idx)
        prev_idx = idx

    net = Network(notebook=True, width="100%", height="600px", directed=True)
    net.from_nx(chain_graph)

    # Per-node styling; fresh dicts are built for every node.
    for vis_node in net.nodes:
        vis_node.update(
            shape='box',
            color='#97C2FC',
            font={'size': 12, 'face': 'arial', 'multi': 'html', 'align': 'center'},
            widthConstraint={'minimum': 200, 'maximum': 300},
        )

    # Network-wide vis.js options, handed to pyvis as a "var options = {...}" string.
    vis_options = '''
    var options = {
        "nodes": {
            "shape": "box",
            "physics": false,
            "margin": 10
        },
        "edges": {
            "smooth": {
                "type": "curvedCW",
                "roundness": 0.2
            },
            "arrows": {
                "to": {
                    "enabled": true,
                    "scaleFactor": 0.5
                }
            }
        },
        "layout": {
            "hierarchical": {
                "enabled": true,
                "direction": "UD",
                "sortMethod": "directed"
            }
        },
        "interaction": {
            "hover": true,
            "tooltipDelay": 100
        }
    }
    '''
    net.set_options(vis_options)

    return net
|
|
|
def _find_choice_text(label_choices, label):
    """Return the choice text for ``label``, matching keys case-insensitively.

    An exact key match wins; otherwise the first key equal to ``label`` once
    both are lower-cased is used. Returns ``None`` when nothing matches.
    """
    if label in label_choices:
        return label_choices[label]
    wanted = label.lower()
    for key, text in label_choices.items():
        if key.lower() == wanted:
            return text
    return None


def main():
    """Render the Streamlit page for browsing questions and their reasoning.

    Shows a question selector with Previous/Next buttons, the question text,
    its choices, the reference answer, the generated answer, whether the two
    labels agree (ignoring case), and the reasoning chain as an interactive
    graph embedded from ``graph.html``.
    """
    st.title("Interactive Q&A App with Reasoning Chain Graph")

    question_ids = get_question_ids()
    if not question_ids:
        # Without this guard the modulo below divides by zero and the
        # selectbox returns None, which `.index()` cannot find.
        st.warning("No questions are available.")
        return

    if 'current_index' not in st.session_state:
        st.session_state.current_index = 0
    # Keep a stored index valid if the list is shorter than on an earlier run.
    st.session_state.current_index %= len(question_ids)

    # Navigation buttons sit in the outer columns; the middle one is a spacer.
    col1, _, col3 = st.columns([1, 3, 1])
    with col1:
        if st.button("Previous"):
            st.session_state.current_index = (
                (st.session_state.current_index - 1) % len(question_ids)
            )
    with col3:
        if st.button("Next"):
            st.session_state.current_index = (
                (st.session_state.current_index + 1) % len(question_ids)
            )

    selected_question_id = st.selectbox(
        "Select a question ID",
        question_ids,
        index=st.session_state.current_index,
    )
    # A manual pick in the selectbox becomes the new navigation position.
    st.session_state.current_index = question_ids.index(selected_question_id)

    data = get_question_context(selected_question_id)
    original_question = data['original_question']
    generated_result = data['generated_result']
    label_choices = original_question['label_choices']

    st.subheader("Question")
    st.write(original_question['question'])

    st.subheader("Choices")
    for label, choice in label_choices.items():
        st.write(f"{label}: {choice}")

    st.subheader("Correct Answer")
    answer_label = original_question['answer']
    correct_answer = label_choices[answer_label]
    st.write(f"{answer_label}: {correct_answer}")

    st.subheader("Generated Answer")
    # Key spelling is kept exactly as the existing code reads it.
    # NOTE(review): 'answer_key_vale' looks like a typo -- confirm against the API.
    generated_answer_label = generated_result['answer_key_vale']
    # The correctness check below ignores case, so the lookup must too;
    # a plain subscript raised KeyError on a case-only mismatch.
    generated_answer = _find_choice_text(label_choices, generated_answer_label)
    if generated_answer is None:
        st.write(f"{generated_answer_label}: (not one of the listed choices)")
    else:
        st.write(f"{generated_answer_label}: {generated_answer}")

    st.subheader("is it correct ?")
    is_same = answer_label.lower() == generated_answer_label.lower()
    # Emoji are written as escapes so the source stays ASCII and cannot be
    # corrupted by a wrong file encoding.
    if is_same:
        st.write(f"\u2705 answers Are the same. answer is {answer_label}")
    else:
        st.write(
            f"\u274c answers do differ.\n But you can check reasonings.\n"
            f"original answer label : {answer_label}\n"
            f"generated answer label : {generated_answer_label}"
        )

    # pyvis writes a standalone HTML page; read it back and embed it.
    net = create_interactive_graph(generated_result['reasoning_chain'])
    net.save_graph("graph.html")

    with open("graph.html", 'r', encoding='utf-8') as f:
        html = f.read()
    st.components.v1.html(html, height=600)
|
|
|
# Run the app only when this file is the entry script, not when it is imported.
if __name__ == "__main__":
    main()