Spaces:
Sleeping
Sleeping
Commit
·
079a08e
1
Parent(s):
63ce440
Introducing team-specific defaults, immediate about page loading, and validation display on predict page
Browse files- app.py +3 -1
- pages/input.py +15 -0
- pages/predict.py +44 -4
- pages/validate.py +10 -0
app.py
CHANGED
@@ -155,4 +155,6 @@ def check_password():
|
|
155 |
menu() # Render the dynamic menu!
|
156 |
|
157 |
if not check_password():
|
158 |
-
st.stop()
|
|
|
|
|
|
155 |
menu() # Render the dynamic menu!
|
156 |
|
157 |
if not check_password():
|
158 |
+
st.stop()
|
159 |
+
|
160 |
+
st.switch_page("pages/about.py")
|
pages/input.py
CHANGED
@@ -43,6 +43,17 @@ if "query" not in st.session_state:
|
|
43 |
source_node_index = 0
|
44 |
target_node_type_index = 0
|
45 |
relation_index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
else:
|
47 |
source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
|
48 |
source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
|
@@ -88,6 +99,10 @@ if st.button("Submit Query"):
|
|
88 |
"relation": list(relation_options)
|
89 |
}
|
90 |
|
|
|
|
|
|
|
|
|
91 |
# # Write query to console
|
92 |
# st.write("Current Query:")
|
93 |
# st.write(st.session_state.query)
|
|
|
43 |
source_node_index = 0
|
44 |
target_node_type_index = 0
|
45 |
relation_index = 0
|
46 |
+
|
47 |
+
if st.session_state.team == "Clalit":
|
48 |
+
source_node_type_index = 2
|
49 |
+
source_node_index = 0
|
50 |
+
target_node_type_index = 3
|
51 |
+
relation_index = 2
|
52 |
+
|
53 |
+
if st.session_state.team == "ASAP":
|
54 |
+
source_node_type_index = 2
|
55 |
+
source_node_index = 10255
|
56 |
+
|
57 |
else:
|
58 |
source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
|
59 |
source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
|
|
|
99 |
"relation": list(relation_options)
|
100 |
}
|
101 |
|
102 |
+
# Delete validation from session state
|
103 |
+
if "validation" in st.session_state:
|
104 |
+
del st.session_state.validation
|
105 |
+
|
106 |
# # Write query to console
|
107 |
# st.write("Current Query:")
|
108 |
# st.write(st.session_state.query)
|
pages/predict.py
CHANGED
@@ -143,6 +143,37 @@ with st.spinner('Computing predictions...'):
|
|
143 |
# Add URLs to database column
|
144 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
# NODE SEARCH
|
148 |
|
@@ -152,11 +183,16 @@ with st.spinner('Computing predictions...'):
|
|
152 |
|
153 |
# Filter nodes
|
154 |
if len(selected_nodes) > 0:
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
# Show filtered nodes
|
158 |
if target_node_type not in ['disease', 'anatomy']:
|
159 |
-
st.dataframe(selected_display_data, use_container_width = True,
|
160 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
161 |
help = "Click to visit external database.",
|
162 |
display_text = display_database)})
|
@@ -185,14 +221,18 @@ with st.spinner('Computing predictions...'):
|
|
185 |
# Show top ranked nodes
|
186 |
st.subheader("Model Predictions", divider = "blue")
|
187 |
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
|
|
|
|
|
|
|
|
188 |
|
189 |
if target_node_type not in ['disease', 'anatomy']:
|
190 |
-
st.dataframe(
|
191 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
192 |
help = "Click to visit external database.",
|
193 |
display_text = display_database)})
|
194 |
else:
|
195 |
-
st.dataframe(
|
196 |
|
197 |
# Save to session state
|
198 |
st.session_state.predictions = display_data
|
|
|
143 |
# Add URLs to database column
|
144 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
145 |
|
146 |
+
# Check if validation data exists
|
147 |
+
if 'validation' in st.session_state:
|
148 |
+
|
149 |
+
# Checkbox to allow reverse edges
|
150 |
+
show_val = st.checkbox("Show Ground Truth Validation?", value = False)
|
151 |
+
|
152 |
+
if show_val:
|
153 |
+
|
154 |
+
# Get validation data
|
155 |
+
val_results = st.session_state.validation.copy()
|
156 |
+
|
157 |
+
# Merge with predictions
|
158 |
+
val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
|
159 |
+
val_display_data = val_display_data.fillna(0).drop(columns='y_id')
|
160 |
+
|
161 |
+
# Get new columns
|
162 |
+
val_relations = val_display_data.columns.difference(display_data.columns).tolist()
|
163 |
+
|
164 |
+
# Replace 0 with blank and 1 with check emoji in new columns
|
165 |
+
for col in val_relations:
|
166 |
+
val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})
|
167 |
+
|
168 |
+
# Define a function to apply styles
|
169 |
+
def style_val(val):
|
170 |
+
if val == '✅':
|
171 |
+
return 'background-color: #C2EABD;' # text-align: center;
|
172 |
+
return 'background-color: #F5F5F5;' # text-align: center;
|
173 |
+
|
174 |
+
else:
|
175 |
+
show_val = False
|
176 |
+
|
177 |
|
178 |
# NODE SEARCH
|
179 |
|
|
|
183 |
|
184 |
# Filter nodes
|
185 |
if len(selected_nodes) > 0:
|
186 |
+
|
187 |
+
if show_val:
|
188 |
+
# selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)]
|
189 |
+
selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].style.map(style_val, subset=val_relations)
|
190 |
+
else:
|
191 |
+
selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
|
192 |
|
193 |
# Show filtered nodes
|
194 |
if target_node_type not in ['disease', 'anatomy']:
|
195 |
+
st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
|
196 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
197 |
help = "Click to visit external database.",
|
198 |
display_text = display_database)})
|
|
|
221 |
# Show top ranked nodes
|
222 |
st.subheader("Model Predictions", divider = "blue")
|
223 |
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
224 |
+
|
225 |
+
# Show full results
|
226 |
+
# full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
|
227 |
+
full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
|
228 |
|
229 |
if target_node_type not in ['disease', 'anatomy']:
|
230 |
+
st.dataframe(full_results, use_container_width = True, hide_index = True,
|
231 |
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
232 |
help = "Click to visit external database.",
|
233 |
display_text = display_database)})
|
234 |
else:
|
235 |
+
st.dataframe(full_results, use_container_width = True, hide_index = True,)
|
236 |
|
237 |
# Save to session state
|
238 |
st.session_state.predictions = display_data
|
pages/validate.py
CHANGED
@@ -65,6 +65,16 @@ with st.spinner('Searching known relationships...'):
|
|
65 |
# If there exist edges in KG
|
66 |
if len(edges_in_kg) > 0:
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
with st.spinner('Plotting known relationships...'):
|
69 |
|
70 |
# Define a color map for different relations
|
|
|
65 |
# If there exist edges in KG
|
66 |
if len(edges_in_kg) > 0:
|
67 |
|
68 |
+
with st.spinner('Saving validation results...'):
|
69 |
+
|
70 |
+
# Cast long to wide
|
71 |
+
val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
|
72 |
+
val_results = (val_results > 0).astype(int).reset_index()
|
73 |
+
val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]
|
74 |
+
|
75 |
+
# Save validation results to session state
|
76 |
+
st.session_state.validation = val_results
|
77 |
+
|
78 |
with st.spinner('Plotting known relationships...'):
|
79 |
|
80 |
# Define a color map for different relations
|