File size: 4,209 Bytes
f027c05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd

# Path manipulation
from pathlib import Path

# Custom and other imports
import project_config
from utils import load_kg

# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Page header banner
st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)

# Main content
# st.markdown(f"Hello, {st.session_state.name}!")

st.subheader("Construct Query", divider = "red")

# # Checkbox to allow reverse edges
# allow_reverse_edges = st.checkbox("Reverse Edges", value = False)

with st.spinner('Loading knowledge graph...'):
    # Load the knowledge-graph nodes inside the spinner so the user sees
    # feedback while the (potentially slow) load runs.
    # NOTE(review): kg_nodes is currently only referenced by commented-out
    # code on this page.
    kg_nodes = load_kg()

    # Type tables. edge_types is filtered on its x_type / y_type / relation
    # columns to build the target-type and edge-type selectors on this page.
    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')

    # if not allow_reverse_edges:
    #    edge_types = edge_types[edge_types.direction == 'forward']

def _stored_index(options, key):
    """Return the preselected position for a selectbox over ``options``.

    When a query has been submitted before (``st.session_state.query``), the
    value stored under ``key`` is looked up in the *current* ``options`` so the
    user's last choice is restored. Returns 0 (the first option) when no query
    exists or when the stored value is not among the current options. That
    happens, for example, with the relation list after the user switches to a
    different target node type.

    Resolving against the current options (instead of a saved index into a
    previously saved list) guarantees the result is a valid index, which
    ``st.selectbox`` requires.
    """
    query = st.session_state.get("query") or {}
    options = list(options)
    stored = query.get(key)
    return options.index(stored) if stored in options else 0


# The source node type and source node are currently fixed. The commented-out
# selectors show the interactive version of these choices.
# source_node_type_options = node_types['node_type']
# source_node_type = st.selectbox("Source Node Type", source_node_type_options,
#                                 format_func = lambda x: x.replace("_", " "),
#                                 index = _stored_index(source_node_type_options, "source_node_type"))
source_node_type = "disease"

# source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
# source_node = st.selectbox("Source Node", source_node_options,
#                            index = _stored_index(source_node_options, "source_node"))
source_node = "Parkinson disease"

# Select target node type: every y_type reachable from the source node type
target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
target_node_type = st.selectbox("Target Node Type", target_node_type_options,
                                format_func = lambda x: x.replace("_", " "),
                                index = _stored_index(target_node_type_options, "target_node_type"))

# Select relation: edge types linking the source type to the chosen target type.
# The default index is recomputed against these options, because they change
# whenever the target node type changes.
relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
relation = st.selectbox("Edge Type", relation_options,
                        format_func = lambda x: x.replace("_", "-"),
                        index = _stored_index(relation_options, "relation"))

# Button to submit query
if st.button("Submit Query"):

    # Persist the chosen query in session state. This page reads it back to
    # restore the selections; presumably pages/predict.py consumes it too
    # (verify there).
    st.session_state.query = {
        "source_node_type": source_node_type,
        "source_node": source_node,
        "target_node_type": target_node_type,
        "relation": relation,
    }

    # Persist the option lists the selections were made from. They are built
    # from the variables above rather than repeated literals, so each list
    # always contains the submitted value. The source type and source node are
    # currently fixed, hence the one-element lists.
    st.session_state.query_options = {
        "source_node_type": [source_node_type],
        "source_node": [source_node],
        "target_node_type": list(target_node_type_options),
        "relation": list(relation_options),
    }

    st.write("Query submitted.")

    # Switch to the Predict page
    st.switch_page("pages/predict.py")


# Disabled knowledge-graph table view (kept for reference):
# st.subheader("Knowledge Graph", divider = "red")
# display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
# display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
# st.dataframe(display_data, use_container_width = True)