Spaces:
Sleeping
Sleeping
Commit
·
e9c640b
1
Parent(s):
b2ee56b
Persist query
Browse files- pages/input.py +32 -7
- pages/validate.py +52 -44
pages/input.py
CHANGED
@@ -37,20 +37,37 @@ with st.spinner('Loading knowledge graph...'):
|
|
37 |
if not allow_reverse_edges:
|
38 |
edge_types = edge_types[edge_types.direction == 'forward']
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
# Select source node type
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
# Select source node
|
45 |
-
|
|
|
|
|
46 |
|
47 |
# Select target node type
|
48 |
-
|
49 |
-
|
|
|
50 |
|
51 |
# Select relation
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
# Button to submit query
|
56 |
if st.button("Submit Query"):
|
@@ -63,6 +80,14 @@ if st.button("Submit Query"):
|
|
63 |
"relation": relation
|
64 |
}
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# # Write query to console
|
67 |
# st.write("Current Query:")
|
68 |
# st.write(st.session_state.query)
|
|
|
37 |
if not allow_reverse_edges:
|
38 |
edge_types = edge_types[edge_types.direction == 'forward']
|
39 |
|
40 |
+
# If query is not in session state, initialize it
|
41 |
+
if "query" not in st.session_state:
|
42 |
+
source_node_type_index = 0
|
43 |
+
source_node_index = 0
|
44 |
+
target_node_type_index = 0
|
45 |
+
relation_index = 0
|
46 |
+
else:
|
47 |
+
source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
|
48 |
+
source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
|
49 |
+
target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
|
50 |
+
relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
|
51 |
+
|
52 |
# Select source node type
|
53 |
+
source_node_type_options = node_types['node_type']
|
54 |
+
source_node_type = st.selectbox("Source Node Type", source_node_type_options,
|
55 |
+
format_func = lambda x: x.replace("_", " "), index = source_node_type_index)
|
56 |
|
57 |
# Select source node
|
58 |
+
source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
|
59 |
+
source_node = st.selectbox("Source Node", source_node_options,
|
60 |
+
index = source_node_index)
|
61 |
|
62 |
# Select target node type
|
63 |
+
target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
|
64 |
+
target_node_type = st.selectbox("Target Node Type", target_node_type_options,
|
65 |
+
format_func = lambda x: x.replace("_", " "), index = target_node_type_index)
|
66 |
|
67 |
# Select relation
|
68 |
+
relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
|
69 |
+
relation = st.selectbox("Edge Type", relation_options,
|
70 |
+
format_func = lambda x: x.replace("_", "-"), index = relation_index)
|
71 |
|
72 |
# Button to submit query
|
73 |
if st.button("Submit Query"):
|
|
|
80 |
"relation": relation
|
81 |
}
|
82 |
|
83 |
+
# Save query options to session state
|
84 |
+
st.session_state.query_options = {
|
85 |
+
"source_node_type": list(source_node_type_options),
|
86 |
+
"source_node": list(source_node_options),
|
87 |
+
"target_node_type": list(target_node_type_options),
|
88 |
+
"relation": list(relation_options)
|
89 |
+
}
|
90 |
+
|
91 |
# # Write query to console
|
92 |
# st.write("Current Query:")
|
93 |
# st.write(st.session_state.query)
|
pages/validate.py
CHANGED
@@ -43,6 +43,9 @@ predictions = st.session_state.predictions
|
|
43 |
kg_nodes = load_kg()
|
44 |
kg_edges = load_kg_edges()
|
45 |
|
|
|
|
|
|
|
46 |
|
47 |
with st.spinner('Searching known relationships...'):
|
48 |
|
@@ -59,48 +62,53 @@ with st.spinner('Searching known relationships...'):
|
|
59 |
edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
|
60 |
edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
|
61 |
|
|
|
|
|
62 |
|
63 |
-
with st.spinner('Plotting known relationships...'):
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
ax.
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
43 |
kg_nodes = load_kg()
|
44 |
kg_edges = load_kg_edges()
|
45 |
|
46 |
+
# Convert tuple to hex
|
47 |
+
def rgba_to_hex(rgba):
|
48 |
+
return mcolors.to_hex(rgba[:3])
|
49 |
|
50 |
with st.spinner('Searching known relationships...'):
|
51 |
|
|
|
62 |
edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
|
63 |
edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
|
64 |
|
65 |
+
# If there exist edges in KG
|
66 |
+
if len(edges_in_kg) > 0:
|
67 |
|
68 |
+
with st.spinner('Plotting known relationships...'):
|
69 |
+
|
70 |
+
# Define a color map for different relations
|
71 |
+
color_map = plt.get_cmap('tab10')
|
72 |
+
|
73 |
+
# Group by relation and create separate plots
|
74 |
+
relations = edges_in_kg['Known Relation'].unique()
|
75 |
+
for idx, relation in enumerate(relations):
|
76 |
+
|
77 |
+
relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
|
78 |
+
|
79 |
+
# Get a color from the color map
|
80 |
+
color = color_map(idx % color_map.N)
|
81 |
+
|
82 |
+
fig, ax = plt.subplots(figsize=(10, 3))
|
83 |
+
ax.plot(predictions['Rank'], predictions['Score'])
|
84 |
+
ax.set_xlabel('Rank', fontsize=12)
|
85 |
+
ax.set_ylabel('Score', fontsize=12)
|
86 |
+
ax.set_xlim(1, predictions['Rank'].max())
|
87 |
+
|
88 |
+
for i, node in relation_data.iterrows():
|
89 |
+
ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
|
90 |
+
# ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
|
91 |
+
|
92 |
+
# ax.set_title(f'{relation.replace("_", "-")}')
|
93 |
+
# ax.legend()
|
94 |
+
color_hex = rgba_to_hex(color)
|
95 |
+
|
96 |
+
# Write header in color of relation
|
97 |
+
st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
|
98 |
+
|
99 |
+
# Show plot
|
100 |
+
st.pyplot(fig)
|
101 |
+
|
102 |
+
# Drop known relation column
|
103 |
+
relation_data = relation_data.drop(columns = 'Known Relation')
|
104 |
+
if target_node_type not in ['disease', 'anatomy']:
|
105 |
+
st.dataframe(relation_data, use_container_width=True,
|
106 |
+
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
107 |
+
help = "Click to visit external database.",
|
108 |
+
display_text = st.session_state.display_database)})
|
109 |
+
else:
|
110 |
+
st.dataframe(relation_data, use_container_width=True)
|
111 |
+
|
112 |
+
else:
|
113 |
+
|
114 |
+
st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
|