Spaces:
Sleeping
Sleeping
File size: 4,871 Bytes
01b734e ca764d6 01b734e 6efe11e 01b734e ba1c7a0 ca764d6 6efe11e ca764d6 6efe11e ca764d6 e9c640b 079a08e e9c640b 31e2a19 ca764d6 e9c640b 31e2a19 ca764d6 e9c640b 31e2a19 ca764d6 e9c640b 31e2a19 ca764d6 e9c640b 31e2a19 ca764d6 e9c640b 079a08e ca764d6 6efe11e ca764d6 31e2a19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
# Path manipulation
from pathlib import Path
# Custom and other imports
import project_config
from utils import load_kg
# Send visitors who are not logged in back to app.py; otherwise render the navigation menu.
menu_with_redirect()

# Page banner image
header_image_path = project_config.MEDIA_DIR / 'input_header.svg'
st.image(str(header_image_path), use_column_width=True)

# Query-builder section
st.subheader("Construct Query", divider="red")

# When left unchecked, only edge types whose direction is 'forward' are offered below.
allow_reverse_edges = st.checkbox("Reverse Edges", value=False)
# Load the knowledge-graph node table and the node-type / edge-type lookup tables.
with st.spinner('Loading knowledge graph...'):
    # load_kg() is called inside the spinner so the loading message is visible
    # while the node table itself loads, not only during the CSV reads below.
    kg_nodes = load_kg()
    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')

# Unless reverse edges were requested, keep only edge types whose direction is 'forward'.
if not allow_reverse_edges:
    edge_types = edge_types[edge_types.direction == 'forward']
def _saved_query_index(field):
    """Return the position of the previously submitted value for ``field``.

    Looks up ``st.session_state.query[field]`` in
    ``st.session_state.query_options[field]``; both are written when the query
    is submitted on this page. Falls back to 0 when the field is missing or the
    saved value is not in its saved option list (e.g. a ``None`` selection saved
    from an empty option list), instead of letting ``list.index`` raise
    ``ValueError``.
    """
    try:
        return st.session_state.query_options[field].index(st.session_state.query[field])
    except (KeyError, ValueError):
        return 0


# Decide which option each selectbox below should preselect.
if "query" not in st.session_state:
    # No query submitted yet: default every selectbox to its first option,
    # then apply per-team presets.
    source_node_type_index = 0
    source_node_index = 0
    target_node_type_index = 0
    relation_index = 0
    # .get avoids an AttributeError if no team has been stored in the session.
    team = st.session_state.get("team")
    if team == "Clalit":
        source_node_type_index = 2
        source_node_index = 0
        target_node_type_index = 3
        relation_index = 2
    elif team == "ASAP":
        source_node_type_index = 2
        # NOTE(review): hard-coded position in the source-node option list;
        # confirm which node this is meant to select.
        source_node_index = 10255
else:
    # A query was submitted before: restore its selections by position.
    source_node_type_index = _saved_query_index('source_node_type')
    source_node_index = _saved_query_index('source_node')
    target_node_type_index = _saved_query_index('target_node_type')
    relation_index = _saved_query_index('relation')
# Clamp a preselection index so st.selectbox never receives an invalid one.
def catch_index_error(index, index_options):
    """Return ``index`` if it is a valid position in ``index_options``, else 0.

    Preselection indices on this page are either hard-coded or restored from a
    previous query, so they may not fit the option list currently shown.
    Streamlit's selectbox rejects an index outside ``0 <= index < len(options)``,
    so any such index (too large or negative) falls back to the first option.

    Args:
        index: Candidate position of the preselected option.
        index_options: Sized collection of options (list, Series, ndarray).

    Returns:
        ``index`` when it is in range, otherwise 0.
    """
    if 0 <= index < len(index_options):
        return index
    return 0
# Cascading selectors: each choice narrows the options offered by the next one.

# Source node type: every entry of the node_type column in kg_node_types.csv.
source_node_type_options = node_types['node_type']
source_node_type = st.selectbox(
    "Source Node Type",
    source_node_type_options,
    format_func=lambda label: label.replace("_", " "),
    index=catch_index_error(source_node_type_index, source_node_type_options),
)

# Source node: names of the graph nodes whose type matches the chosen type.
source_type_mask = kg_nodes['node_type'] == source_node_type
source_node_options = kg_nodes.loc[source_type_mask, 'node_name']
source_node = st.selectbox(
    "Source Node",
    source_node_options,
    index=catch_index_error(source_node_index, source_node_options),
)

# Target node type: distinct y_type values among edge types whose x_type is the source type.
edges_from_source = edge_types[edge_types['x_type'] == source_node_type]
target_node_type_options = edges_from_source['y_type'].unique()
target_node_type = st.selectbox(
    "Target Node Type",
    target_node_type_options,
    format_func=lambda label: label.replace("_", " "),
    index=catch_index_error(target_node_type_index, target_node_type_options),
)

# Relation: distinct relations among edge types linking the chosen source and target types.
target_type_mask = edges_from_source['y_type'] == target_node_type
relation_options = edges_from_source.loc[target_type_mask, 'relation'].unique()
relation = st.selectbox(
    "Edge Type",
    relation_options,
    format_func=lambda label: label.replace("_", "-"),
    index=catch_index_error(relation_index, relation_options),
)
# A selectbox with no options returns None (e.g. a source node type with no
# matching edge types). Submitting then would hand an incomplete query to the
# Predict page, and restoring it later would call list.index(None) on an empty
# saved option list, so the button stays disabled until every field is set.
query_incomplete = any(
    value is None
    for value in (source_node_type, source_node, target_node_type, relation)
)

# Button to submit query
if st.button("Submit Query", disabled=query_incomplete):
    # Save the chosen query to session state
    st.session_state.query = {
        "source_node_type": source_node_type,
        "source_node": source_node,
        "target_node_type": target_node_type,
        "relation": relation
    }
    # Save the option lists shown alongside it, so this page can restore each
    # selection by its position in these lists on a later visit.
    st.session_state.query_options = {
        "source_node_type": list(source_node_type_options),
        "source_node": list(source_node_options),
        "target_node_type": list(target_node_type_options),
        "relation": list(relation_options)
    }
    # Remove any existing validation entry from session state
    if "validation" in st.session_state:
        del st.session_state.validation
    st.write("Query submitted.")
    # Switch to the Predict page
    st.switch_page("pages/predict.py")
st.subheader("Knowledge Graph", divider="red")

# Node-table columns to show, mapped to their user-facing headers (in display order).
display_columns = {
    'node_id': 'ID',
    'node_type': 'Type',
    'node_name': 'Name',
    'node_source': 'Database',
}
display_data = kg_nodes[list(display_columns)].rename(columns=display_columns)
st.dataframe(display_data, use_container_width=True, hide_index=True)