File size: 4,871 Bytes
01b734e
 
 
ca764d6
 
 
 
01b734e
 
 
 
 
6efe11e
01b734e
 
 
 
ba1c7a0
 
 
 
ca764d6
 
 
 
 
 
 
6efe11e
 
 
ca764d6
6efe11e
ca764d6
 
 
 
 
 
e9c640b
 
 
 
 
 
079a08e
 
 
 
 
 
 
 
 
 
 
e9c640b
 
 
 
 
 
31e2a19
 
 
 
 
 
 
ca764d6
e9c640b
 
31e2a19
 
ca764d6
 
e9c640b
 
31e2a19
ca764d6
 
e9c640b
 
31e2a19
 
ca764d6
 
e9c640b
 
31e2a19
 
ca764d6
 
 
 
 
 
 
 
 
 
 
 
e9c640b
 
 
 
 
 
 
 
079a08e
 
 
 
ca764d6
 
 
 
 
6efe11e
 
 
 
ca764d6
 
 
31e2a19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd

# Path manipulation
from pathlib import Path

# Custom and other imports
import project_config
from utils import load_kg

# Gate the page: visitors who are not logged in are redirected to app.py,
# everyone else gets the navigation menu.
menu_with_redirect()

# Page banner
header_image = project_config.MEDIA_DIR / 'input_header.svg'
st.image(str(header_image), use_column_width=True)

# Query builder section
st.subheader("Construct Query", divider="red")

# Opt-in toggle for reverse-direction edge types (off by default)
allow_reverse_edges = st.checkbox("Reverse Edges", value=False)

# Knowledge graph node table
kg_nodes = load_kg()

# Node-type and edge-type lookup tables are read while a spinner is shown
with st.spinner('Loading knowledge graph...'):
    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')

    # Unless reverse edges were requested, keep only the forward-direction rows
    if not allow_reverse_edges:
        is_forward = edge_types['direction'] == 'forward'
        edge_types = edge_types[is_forward]

# Default selectbox positions for this run of the page. Everything starts at
# the first option; the values are then overridden either by a team-specific
# preset (no query submitted yet) or by the previously submitted query.
source_node_type_index = 0
source_node_index = 0
target_node_type_index = 0
relation_index = 0

if "query" not in st.session_state:
    # .get() keeps the page usable when no team is stored in the session
    team = st.session_state.get("team")

    if team == "Clalit":
        source_node_type_index = 2
        target_node_type_index = 3
        relation_index = 2
    elif team == "ASAP":
        source_node_type_index = 2
        source_node_index = 10255

else:
    saved_query = st.session_state.query
    saved_options = st.session_state.get("query_options", {})

    def _saved_position(field):
        """Return the position of the saved value of *field* in its saved options.

        Falls back to 0 when the saved option list or the saved value is
        missing. This happens, for example, when a query was submitted while a
        selectbox had no options: the stored value is then None and the stored
        option list is empty, so a bare ``.index()`` would raise ValueError.
        """
        try:
            return saved_options[field].index(saved_query[field])
        except (KeyError, ValueError):
            return 0

    source_node_type_index = _saved_position('source_node_type')
    source_node_index = _saved_position('source_node')
    target_node_type_index = _saved_position('target_node_type')
    relation_index = _saved_position('relation')

# Define error catching function
def catch_index_error(index: int, index_options) -> int:
    """Return *index* if it is a valid position in *index_options*, else 0.

    Used to sanitise the default ``index`` passed to a selectbox, whose option
    list may be shorter than the one the index was computed against.

    Args:
        index: Candidate zero-based position.
        index_options: Any sized collection of options (anything ``len()``
            accepts).

    Returns:
        ``index`` when ``0 <= index < len(index_options)``; otherwise ``0``.
        Negative positions are rejected as well, since they are not valid
        selectbox defaults.
    """
    if 0 <= index < len(index_options):
        return index
    return 0

# Label formatters for the selectboxes below: identifiers contain underscores,
# which are swapped for a friendlier separator when displayed.
def _spaced_label(option):
    return option.replace("_", " ")


def _hyphenated_label(option):
    return option.replace("_", "-")


# Source node type: taken from the node-type table
source_node_type_options = node_types['node_type']
source_node_type = st.selectbox(
    "Source Node Type",
    source_node_type_options,
    index=catch_index_error(source_node_type_index, source_node_type_options),
    format_func=_spaced_label,
)

# Source node: names of all nodes of the chosen source type
of_source_type = kg_nodes['node_type'] == source_node_type
source_node_options = kg_nodes.loc[of_source_type, 'node_name']
source_node = st.selectbox(
    "Source Node",
    source_node_options,
    index=catch_index_error(source_node_index, source_node_options),
)

# Target node type: distinct types reachable from the chosen source type
from_source = edge_types['x_type'] == source_node_type
target_node_type_options = edge_types.loc[from_source, 'y_type'].unique()
target_node_type = st.selectbox(
    "Target Node Type",
    target_node_type_options,
    index=catch_index_error(target_node_type_index, target_node_type_options),
    format_func=_spaced_label,
)

# Relation: distinct edge types linking the chosen source and target types
to_target = edge_types['y_type'] == target_node_type
relation_options = edge_types.loc[from_source & to_target, 'relation'].unique()
relation = st.selectbox(
    "Edge Type",
    relation_options,
    index=catch_index_error(relation_index, relation_options),
    format_func=_hyphenated_label,
)

# Submit handler: persist the current selection and move on to prediction
if st.button("Submit Query"):
    query_fields = ("source_node_type", "source_node", "target_node_type", "relation")
    chosen_values = (source_node_type, source_node, target_node_type, relation)
    offered_options = (
        source_node_type_options,
        source_node_options,
        target_node_type_options,
        relation_options,
    )

    # Remember what was picked ...
    st.session_state.query = dict(zip(query_fields, chosen_values))

    # ... and the option lists it was picked from, as plain lists
    st.session_state.query_options = {
        field: list(options)
        for field, options in zip(query_fields, offered_options)
    }

    # Drop any validation entry stored in the session
    if "validation" in st.session_state:
        del st.session_state["validation"]

    st.write("Query submitted.")

    # Hand over to the Predict page
    st.switch_page("pages/predict.py")


# Knowledge graph browser: show the node table under reader-friendly headers
st.subheader("Knowledge Graph", divider="red")

# Maps each displayed source column to its header; insertion order fixes the
# column order of the table.
column_labels = {
    'node_id': 'ID',
    'node_type': 'Type',
    'node_name': 'Name',
    'node_source': 'Database',
}
display_data = kg_nodes[list(column_labels)].copy().rename(columns=column_labels)
st.dataframe(display_data, use_container_width=True, hide_index=True)