ayushnoori commited on
Commit
e9c640b
·
1 Parent(s): b2ee56b

Persist query

Browse files
Files changed (2) hide show
  1. pages/input.py +32 -7
  2. pages/validate.py +52 -44
pages/input.py CHANGED
@@ -37,20 +37,37 @@ with st.spinner('Loading knowledge graph...'):
37
  if not allow_reverse_edges:
38
  edge_types = edge_types[edge_types.direction == 'forward']
39
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  # Select source node type
41
- source_node_type = st.selectbox("Source Node Type", node_types['node_type'],
42
- format_func = lambda x: x.replace("_", " "))
 
43
 
44
  # Select source node
45
- source_node = st.selectbox("Source Node", kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name'])
 
 
46
 
47
  # Select target node type
48
- target_node_type = st.selectbox("Target Node Type", edge_types[edge_types.x_type == source_node_type].y_type.unique(),
49
- format_func = lambda x: x.replace("_", " "))
 
50
 
51
  # Select relation
52
- relation = st.selectbox("Edge Type", edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique(),
53
- format_func = lambda x: x.replace("_", "-"))
 
54
 
55
  # Button to submit query
56
  if st.button("Submit Query"):
@@ -63,6 +80,14 @@ if st.button("Submit Query"):
63
  "relation": relation
64
  }
65
 
 
 
 
 
 
 
 
 
66
  # # Write query to console
67
  # st.write("Current Query:")
68
  # st.write(st.session_state.query)
 
37
  if not allow_reverse_edges:
38
  edge_types = edge_types[edge_types.direction == 'forward']
39
 
40
+ # If query is not in session state, initialize it
41
+ if "query" not in st.session_state:
42
+ source_node_type_index = 0
43
+ source_node_index = 0
44
+ target_node_type_index = 0
45
+ relation_index = 0
46
+ else:
47
+ source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
48
+ source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
49
+ target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
50
+ relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
51
+
52
  # Select source node type
53
+ source_node_type_options = node_types['node_type']
54
+ source_node_type = st.selectbox("Source Node Type", source_node_type_options,
55
+ format_func = lambda x: x.replace("_", " "), index = source_node_type_index)
56
 
57
  # Select source node
58
+ source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
59
+ source_node = st.selectbox("Source Node", source_node_options,
60
+ index = source_node_index)
61
 
62
  # Select target node type
63
+ target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
64
+ target_node_type = st.selectbox("Target Node Type", target_node_type_options,
65
+ format_func = lambda x: x.replace("_", " "), index = target_node_type_index)
66
 
67
  # Select relation
68
+ relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
69
+ relation = st.selectbox("Edge Type", relation_options,
70
+ format_func = lambda x: x.replace("_", "-"), index = relation_index)
71
 
72
  # Button to submit query
73
  if st.button("Submit Query"):
 
80
  "relation": relation
81
  }
82
 
83
+ # Save query options to session state
84
+ st.session_state.query_options = {
85
+ "source_node_type": list(source_node_type_options),
86
+ "source_node": list(source_node_options),
87
+ "target_node_type": list(target_node_type_options),
88
+ "relation": list(relation_options)
89
+ }
90
+
91
  # # Write query to console
92
  # st.write("Current Query:")
93
  # st.write(st.session_state.query)
pages/validate.py CHANGED
@@ -43,6 +43,9 @@ predictions = st.session_state.predictions
43
  kg_nodes = load_kg()
44
  kg_edges = load_kg_edges()
45
 
 
 
 
46
 
47
  with st.spinner('Searching known relationships...'):
48
 
@@ -59,48 +62,53 @@ with st.spinner('Searching known relationships...'):
59
  edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
60
  edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
61
 
 
 
62
 
63
- with st.spinner('Plotting known relationships...'):
64
-
65
- # Define a color map for different relations
66
- color_map = plt.get_cmap('tab10')
67
-
68
- # Group by relation and create separate plots
69
- relations = edges_in_kg['Known Relation'].unique()
70
- for idx, relation in enumerate(relations):
71
- fig, ax = plt.subplots(figsize=(10, 3))
72
- ax.plot(predictions['Rank'], predictions['Score'])
73
- ax.set_xlabel('Rank', fontsize=12)
74
- ax.set_ylabel('Score', fontsize=12)
75
- ax.set_xlim(1, predictions['Rank'].max())
76
-
77
- relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
78
- color = color_map(idx % color_map.N) # Get a color from the color map
79
-
80
- for i, node in relation_data.iterrows():
81
- ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
82
- # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
83
-
84
- # ax.set_title(f'{relation.replace("_", "-")}')
85
- # ax.legend()
86
-
87
- # Convert tuple to hex
88
- def rgba_to_hex(rgba):
89
- return mcolors.to_hex(rgba[:3])
90
- color_hex = rgba_to_hex(color)
91
-
92
- # Write header in color of relation
93
- st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
94
-
95
- # Show plot
96
- st.pyplot(fig)
97
-
98
- # Drop known relation column
99
- relation_data = relation_data.drop(columns = 'Known Relation')
100
- if target_node_type not in ['disease', 'anatomy']:
101
- st.dataframe(relation_data, use_container_width=True,
102
- column_config={"Database": st.column_config.LinkColumn(width = "small",
103
- help = "Click to visit external database.",
104
- display_text = st.session_state.display_database)})
105
- else:
106
- st.dataframe(relation_data, use_container_width=True)
 
 
 
 
43
  kg_nodes = load_kg()
44
  kg_edges = load_kg_edges()
45
 
46
+ # Convert tuple to hex
47
+ def rgba_to_hex(rgba):
48
+ return mcolors.to_hex(rgba[:3])
49
 
50
  with st.spinner('Searching known relationships...'):
51
 
 
62
  edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
63
  edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
64
 
65
+ # If there exist edges in KG
66
+ if len(edges_in_kg) > 0:
67
 
68
+ with st.spinner('Plotting known relationships...'):
69
+
70
+ # Define a color map for different relations
71
+ color_map = plt.get_cmap('tab10')
72
+
73
+ # Group by relation and create separate plots
74
+ relations = edges_in_kg['Known Relation'].unique()
75
+ for idx, relation in enumerate(relations):
76
+
77
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
78
+
79
+ # Get a color from the color map
80
+ color = color_map(idx % color_map.N)
81
+
82
+ fig, ax = plt.subplots(figsize=(10, 3))
83
+ ax.plot(predictions['Rank'], predictions['Score'])
84
+ ax.set_xlabel('Rank', fontsize=12)
85
+ ax.set_ylabel('Score', fontsize=12)
86
+ ax.set_xlim(1, predictions['Rank'].max())
87
+
88
+ for i, node in relation_data.iterrows():
89
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
90
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
91
+
92
+ # ax.set_title(f'{relation.replace("_", "-")}')
93
+ # ax.legend()
94
+ color_hex = rgba_to_hex(color)
95
+
96
+ # Write header in color of relation
97
+ st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
98
+
99
+ # Show plot
100
+ st.pyplot(fig)
101
+
102
+ # Drop known relation column
103
+ relation_data = relation_data.drop(columns = 'Known Relation')
104
+ if target_node_type not in ['disease', 'anatomy']:
105
+ st.dataframe(relation_data, use_container_width=True,
106
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
107
+ help = "Click to visit external database.",
108
+ display_text = st.session_state.display_database)})
109
+ else:
110
+ st.dataframe(relation_data, use_container_width=True)
111
+
112
+ else:
113
+
114
+ st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")