ayushnoori commited on
Commit
079a08e
·
1 Parent(s): 63ce440

Introducing team-specific defaults, immediate about page loading, and validation display on predict page

Browse files
Files changed (4) hide show
  1. app.py +3 -1
  2. pages/input.py +15 -0
  3. pages/predict.py +44 -4
  4. pages/validate.py +10 -0
app.py CHANGED
@@ -155,4 +155,6 @@ def check_password():
155
  menu() # Render the dynamic menu!
156
 
157
  if not check_password():
158
- st.stop()
 
 
 
155
  menu() # Render the dynamic menu!
156
 
157
  if not check_password():
158
+ st.stop()
159
+
160
+ st.switch_page("pages/about.py")
pages/input.py CHANGED
@@ -43,6 +43,17 @@ if "query" not in st.session_state:
43
  source_node_index = 0
44
  target_node_type_index = 0
45
  relation_index = 0
 
 
 
 
 
 
 
 
 
 
 
46
  else:
47
  source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
48
  source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
@@ -88,6 +99,10 @@ if st.button("Submit Query"):
88
  "relation": list(relation_options)
89
  }
90
 
 
 
 
 
91
  # # Write query to console
92
  # st.write("Current Query:")
93
  # st.write(st.session_state.query)
 
43
  source_node_index = 0
44
  target_node_type_index = 0
45
  relation_index = 0
46
+
47
+ if st.session_state.team == "Clalit":
48
+ source_node_type_index = 2
49
+ source_node_index = 0
50
+ target_node_type_index = 3
51
+ relation_index = 2
52
+
53
+ if st.session_state.team == "ASAP":
54
+ source_node_type_index = 2
55
+ source_node_index = 10255
56
+
57
  else:
58
  source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
59
  source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
 
99
  "relation": list(relation_options)
100
  }
101
 
102
+ # Delete validation from session state
103
+ if "validation" in st.session_state:
104
+ del st.session_state.validation
105
+
106
  # # Write query to console
107
  # st.write("Current Query:")
108
  # st.write(st.session_state.query)
pages/predict.py CHANGED
@@ -143,6 +143,37 @@ with st.spinner('Computing predictions...'):
143
  # Add URLs to database column
144
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  # NODE SEARCH
148
 
@@ -152,11 +183,16 @@ with st.spinner('Computing predictions...'):
152
 
153
  # Filter nodes
154
  if len(selected_nodes) > 0:
155
- selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
 
 
 
 
 
156
 
157
  # Show filtered nodes
158
  if target_node_type not in ['disease', 'anatomy']:
159
- st.dataframe(selected_display_data, use_container_width = True,
160
  column_config={"Database": st.column_config.LinkColumn(width = "small",
161
  help = "Click to visit external database.",
162
  display_text = display_database)})
@@ -185,14 +221,18 @@ with st.spinner('Computing predictions...'):
185
  # Show top ranked nodes
186
  st.subheader("Model Predictions", divider = "blue")
187
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
 
 
 
 
188
 
189
  if target_node_type not in ['disease', 'anatomy']:
190
- st.dataframe(display_data.iloc[:top_k], use_container_width = True,
191
  column_config={"Database": st.column_config.LinkColumn(width = "small",
192
  help = "Click to visit external database.",
193
  display_text = display_database)})
194
  else:
195
- st.dataframe(display_data.iloc[:top_k], use_container_width = True)
196
 
197
  # Save to session state
198
  st.session_state.predictions = display_data
 
143
  # Add URLs to database column
144
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
145
 
146
+ # Check if validation data exists
147
+ if 'validation' in st.session_state:
148
+
149
+ # Checkbox to allow reverse edges
150
+ show_val = st.checkbox("Show Ground Truth Validation?", value = False)
151
+
152
+ if show_val:
153
+
154
+ # Get validation data
155
+ val_results = st.session_state.validation.copy()
156
+
157
+ # Merge with predictions
158
+ val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
159
+ val_display_data = val_display_data.fillna(0).drop(columns='y_id')
160
+
161
+ # Get new columns
162
+ val_relations = val_display_data.columns.difference(display_data.columns).tolist()
163
+
164
+ # Replace 0 with blank and 1 with check emoji in new columns
165
+ for col in val_relations:
166
+ val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
167
+
168
+ # Define a function to apply styles
169
+ def style_val(val):
170
+ if val == '✅':
171
+ return 'background-color: #C2EABD;' # text-align: center;
172
+ return 'background-color: #F5F5F5;' # text-align: center;
173
+
174
+ else:
175
+ show_val = False
176
+
177
 
178
  # NODE SEARCH
179
 
 
183
 
184
  # Filter nodes
185
  if len(selected_nodes) > 0:
186
+
187
+ if show_val:
188
+ # selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
189
+ selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].style.map(style_val, subset=val_relations)
190
+ else:
191
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
192
 
193
  # Show filtered nodes
194
  if target_node_type not in ['disease', 'anatomy']:
195
+ st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
196
  column_config={"Database": st.column_config.LinkColumn(width = "small",
197
  help = "Click to visit external database.",
198
  display_text = display_database)})
 
221
  # Show top ranked nodes
222
  st.subheader("Model Predictions", divider = "blue")
223
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
224
+
225
+ # Show full results
226
+ # full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
227
+ full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
228
 
229
  if target_node_type not in ['disease', 'anatomy']:
230
+ st.dataframe(full_results, use_container_width = True, hide_index = True,
231
  column_config={"Database": st.column_config.LinkColumn(width = "small",
232
  help = "Click to visit external database.",
233
  display_text = display_database)})
234
  else:
235
+ st.dataframe(full_results, use_container_width = True, hide_index = True,)
236
 
237
  # Save to session state
238
  st.session_state.predictions = display_data
pages/validate.py CHANGED
@@ -65,6 +65,16 @@ with st.spinner('Searching known relationships...'):
65
  # If there exist edges in KG
66
  if len(edges_in_kg) > 0:
67
 
 
 
 
 
 
 
 
 
 
 
68
  with st.spinner('Plotting known relationships...'):
69
 
70
  # Define a color map for different relations
 
65
  # If there exist edges in KG
66
  if len(edges_in_kg) > 0:
67
 
68
+ with st.spinner('Saving validation results...'):
69
+
70
+ # Cast long to wide
71
+ val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
72
+ val_results = (val_results > 0).astype(int).reset_index()
73
+ val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
74
+
75
+ # Save validation results to session state
76
+ st.session_state.validation = val_results
77
+
78
  with st.spinner('Plotting known relationships...'):
79
 
80
  # Define a color map for different relations