ayushnoori commited on
Commit
7e225aa
·
1 Parent(s): 6efe11e

Add validate

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
37
+ data/*.csv filter=lfs diff=lfs merge=lfs -text
data/kg_edge_types.csv CHANGED
@@ -1,75 +1,3 @@
1
- x_type,relation,display_relation,y_type,direction,N
2
- anatomy,anatomy_protein_present,expression present,gene/protein,forward,3831782
3
- gene/protein,rev_anatomy_protein_present,expression present,anatomy,reverse,3831782
4
- drug,drug_drug,synergistic interaction,drug,forward,1433261
5
- drug,rev_drug_drug,synergistic interaction,drug,reverse,1433261
6
- anatomy,anatomy_protein_absent,expression absent,gene/protein,forward,324186
7
- gene/protein,rev_anatomy_protein_absent,expression absent,anatomy,reverse,324186
8
- gene/protein,protein_protein,ppi,gene/protein,forward,321090
9
- gene/protein,rev_protein_protein,ppi,gene/protein,reverse,321090
10
- disease,disease_phenotype_positive,phenotype present,effect/phenotype,forward,200354
11
- effect/phenotype,rev_disease_phenotype_positive,phenotype present,disease,reverse,200354
12
- disease,disease_protein,associated with,gene/protein,forward,147984
13
- gene/protein,rev_disease_protein,associated with,disease,reverse,147984
14
- biological_process,bioprocess_protein,interacts with,gene/protein,forward,138297
15
- gene/protein,rev_bioprocess_protein,interacts with,biological_process,reverse,138297
16
- cellular_component,cellcomp_protein,interacts with,gene/protein,forward,83089
17
- gene/protein,rev_cellcomp_protein,interacts with,cellular_component,reverse,83089
18
- disease,disease_protein_negative,expression downregulated,gene/protein,forward,71135
19
- gene/protein,rev_disease_protein_negative,expression downregulated,disease,reverse,71135
20
- gene/protein,molfunc_protein,interacts with,molecular_function,forward,70291
21
- molecular_function,rev_molfunc_protein,interacts with,gene/protein,reverse,70291
22
- disease,disease_protein_positive,expression upregulated,gene/protein,forward,69488
23
- gene/protein,rev_disease_protein_positive,expression upregulated,disease,reverse,69488
24
- drug,drug_effect,side effect,effect/phenotype,forward,64249
25
- effect/phenotype,rev_drug_effect,side effect,drug,reverse,64249
26
- biological_process,bioprocess_bioprocess,parent-child,biological_process,forward,50232
27
- biological_process,rev_bioprocess_bioprocess,parent-child,biological_process,reverse,50232
28
- gene/protein,pathway_protein,interacts with,pathway,forward,44116
29
- pathway,rev_pathway_protein,interacts with,gene/protein,reverse,44116
30
- disease,disease_disease,parent-child,disease,forward,37808
31
- disease,rev_disease_disease,parent-child,disease,reverse,37808
32
- disease,contraindication,contraindication,drug,forward,26899
33
- drug,rev_contraindication,contraindication,disease,reverse,26899
34
- effect/phenotype,phenotype_phenotype,parent-child,effect/phenotype,forward,20183
35
- effect/phenotype,rev_phenotype_phenotype,parent-child,effect/phenotype,reverse,20183
36
- drug,drug_protein,target,gene/protein,forward,18513
37
- gene/protein,rev_drug_protein,target,drug,reverse,18513
38
- disease,weak_clinical_evidence,clinical candidate,drug,forward,16111
39
- drug,rev_weak_clinical_evidence,clinical candidate,disease,reverse,16111
40
- anatomy,anatomy_anatomy,parent-child,anatomy,forward,14383
41
- anatomy,rev_anatomy_anatomy,parent-child,anatomy,reverse,14383
42
- molecular_function,molfunc_molfunc,parent-child,molecular_function,forward,13735
43
- molecular_function,rev_molfunc_molfunc,parent-child,molecular_function,reverse,13735
44
- disease,indication,indication,drug,forward,12608
45
- drug,rev_indication,indication,disease,reverse,12608
46
- drug,drug_protein,enzyme,gene/protein,forward,5919
47
- gene/protein,rev_drug_protein,enzyme,drug,reverse,5919
48
- disease,strong_clinical_evidence,clinical candidate,drug,forward,5352
49
- drug,rev_strong_clinical_evidence,clinical candidate,disease,reverse,5352
50
- cellular_component,cellcomp_cellcomp,parent-child,cellular_component,forward,4683
51
- cellular_component,rev_cellcomp_cellcomp,parent-child,cellular_component,reverse,4683
52
- effect/phenotype,phenotype_protein,associated with,gene/protein,forward,4437
53
- gene/protein,rev_phenotype_protein,associated with,effect/phenotype,reverse,4437
54
- drug,drug_protein,transporter,gene/protein,forward,3349
55
- gene/protein,rev_drug_protein,transporter,drug,reverse,3349
56
- pathway,pathway_pathway,parent-child,pathway,forward,2647
57
- pathway,rev_pathway_pathway,parent-child,pathway,reverse,2647
58
- disease,exposure_disease,linked to,exposure,forward,2421
59
- exposure,rev_exposure_disease,linked to,disease,reverse,2421
60
- disease,off_label_use,off-label use,drug,forward,2370
61
- drug,rev_off_label_use,off-label use,disease,reverse,2370
62
- exposure,exposure_exposure,parent-child,exposure,forward,2263
63
- exposure,rev_exposure_exposure,parent-child,exposure,reverse,2263
64
- exposure,exposure_protein,interacts with,gene/protein,forward,2012
65
- gene/protein,rev_exposure_protein,interacts with,exposure,reverse,2012
66
- biological_process,exposure_bioprocess,interacts with,exposure,forward,1990
67
- exposure,rev_exposure_bioprocess,interacts with,biological_process,reverse,1990
68
- drug,drug_protein,carrier,gene/protein,forward,993
69
- gene/protein,rev_drug_protein,carrier,drug,reverse,993
70
- disease,disease_phenotype_negative,phenotype absent,effect/phenotype,forward,508
71
- effect/phenotype,rev_disease_phenotype_negative,phenotype absent,disease,reverse,508
72
- exposure,exposure_molfunc,interacts with,molecular_function,forward,45
73
- molecular_function,rev_exposure_molfunc,interacts with,exposure,reverse,45
74
- cellular_component,exposure_cellcomp,interacts with,exposure,forward,12
75
- exposure,rev_exposure_cellcomp,interacts with,cellular_component,reverse,12
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee79b2f5021304a4dd82581568e8a8c940f94b29cd1206f7730bdff6b82cab4
3
+ size 5288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/kg_edges.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab7d0e23c56381abf8e214cc5d4fae4e6a8b98957c8f2e5272b4f800953b1461
3
+ size 2765378133
data/kg_node_types.csv CHANGED
@@ -1,11 +1,3 @@
1
- node_type,N
2
- gene/protein,35198
3
- biological_process,27668
4
- disease,22201
5
- effect/phenotype,16711
6
- anatomy,14384
7
- molecular_function,11228
8
- drug,8160
9
- cellular_component,4054
10
- pathway,2629
11
- exposure,860
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1a0afff52deec5f48689a22a479d14cd49333759e054624366687ec4ef306c8
3
+ size 192
 
 
 
 
 
 
 
 
data/kg_nodes.csv CHANGED
The diff for this file is too large to render. See raw diff
 
pages/predict.py CHANGED
@@ -12,6 +12,10 @@ import torch.nn.functional as F
12
  from pathlib import Path
13
  from huggingface_hub import hf_hub_download
14
 
 
 
 
 
15
  # Custom and other imports
16
  import project_config
17
  from utils import capitalize_after_slash, load_kg
@@ -134,6 +138,9 @@ with st.spinner('Computing predictions...'):
134
  # Add URLs to database column
135
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
136
 
 
 
 
137
  # Use multiselect to search for specific nodes
138
  selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
139
  display_data.Name, placeholder = "Type to search...")
@@ -151,6 +158,25 @@ with st.spinner('Computing predictions...'):
151
  else:
152
  st.dataframe(selected_display_data, use_container_width = True)
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  # Show top ranked nodes
155
  st.subheader("Model Predictions", divider = "blue")
156
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
@@ -165,3 +191,4 @@ with st.spinner('Computing predictions...'):
165
 
166
  # Save to session state
167
  st.session_state.predictions = display_data
 
 
12
  from pathlib import Path
13
  from huggingface_hub import hf_hub_download
14
 
15
+ # Plotting
16
+ import matplotlib.pyplot as plt
17
+ plt.rcParams['font.sans-serif'] = 'Arial'
18
+
19
  # Custom and other imports
20
  import project_config
21
  from utils import capitalize_after_slash, load_kg
 
138
  # Add URLs to database column
139
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
140
 
141
+
142
+ # NODE SEARCH
143
+
144
  # Use multiselect to search for specific nodes
145
  selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
146
  display_data.Name, placeholder = "Type to search...")
 
158
  else:
159
  st.dataframe(selected_display_data, use_container_width = True)
160
 
161
+ # Plot rank vs. score using matplotlib
162
+ st.markdown("**Rank vs. Score**")
163
+ fig, ax = plt.subplots(figsize = (10, 6))
164
+ ax.plot(display_data['Rank'], display_data['Score'])
165
+ ax.set_xlabel('Rank', fontsize = 12)
166
+ ax.set_ylabel('Score', fontsize = 12)
167
+ ax.set_xlim(1, display_data['Rank'].max())
168
+
169
+ # Add vertical line for selected nodes
170
+ for i, node in selected_display_data.iterrows():
171
+ ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
172
+ ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
173
+
174
+ # Show plot
175
+ st.pyplot(fig)
176
+
177
+
178
+ # FULL RESULTS
179
+
180
  # Show top ranked nodes
181
  st.subheader("Model Predictions", divider = "blue")
182
  top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
 
191
 
192
  # Save to session state
193
  st.session_state.predictions = display_data
194
+ st.session_state.display_database = display_database
pages/validate.py CHANGED
@@ -1,11 +1,21 @@
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
 
 
 
 
4
  # Path manipulation
5
  from pathlib import Path
6
 
 
 
 
 
 
7
  # Custom and other imports
8
  import project_config
 
9
 
10
  # Redirect to app.py if not logged in, otherwise show the navigation menu
11
  menu_with_redirect()
@@ -21,6 +31,76 @@ st.subheader("Validate Predictions", divider = "green")
21
  # Print current query
22
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
23
 
24
-
25
  # Coming soon
26
- st.write("Coming soon...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from menu import menu_with_redirect
3
 
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
  # Path manipulation
9
  from pathlib import Path
10
 
11
+ # Plotting
12
+ import matplotlib.pyplot as plt
13
+ plt.rcParams['font.sans-serif'] = 'Arial'
14
+ import matplotlib.colors as mcolors
15
+
16
  # Custom and other imports
17
  import project_config
18
+ from utils import load_kg, load_kg_edges
19
 
20
  # Redirect to app.py if not logged in, otherwise show the navigation menu
21
  menu_with_redirect()
 
31
  # Print current query
32
  st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
33
 
 
34
  # Coming soon
35
+ # st.write("Coming soon...")
36
+
37
+ source_node_type = st.session_state.query['source_node_type']
38
+ source_node = st.session_state.query['source_node']
39
+ relation = st.session_state.query['relation']
40
+ target_node_type = st.session_state.query['target_node_type']
41
+ predictions = st.session_state.predictions
42
+
43
+ kg_nodes = load_kg()
44
+ kg_edges = load_kg_edges()
45
+
46
+
47
+ with st.spinner('Searching known relationships...'):
48
+
49
+ # Subset existing edges
50
+ edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
51
+ edge_subset = edge_subset[edge_subset.y_type == target_node_type]
52
+
53
+ # Merge edge subset with predictions
54
+ edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
55
+ edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
56
+ edges_in_kg = edges_in_kg.drop(columns = 'y_id')
57
+
58
+ # Rename relation to ground-truth
59
+ edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
60
+ edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
61
+
62
+
63
+ with st.spinner('Plotting known relationships...'):
64
+
65
+ # Define a color map for different relations
66
+ color_map = plt.cm.get_cmap('tab10')
67
+
68
+ # Group by relation and create separate plots
69
+ relations = edges_in_kg['Known Relation'].unique()
70
+ for idx, relation in enumerate(relations):
71
+ fig, ax = plt.subplots(figsize=(10, 3))
72
+ ax.plot(predictions['Rank'], predictions['Score'])
73
+ ax.set_xlabel('Rank', fontsize=12)
74
+ ax.set_ylabel('Score', fontsize=12)
75
+ ax.set_xlim(1, predictions['Rank'].max())
76
+
77
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
78
+ color = color_map(idx % color_map.N) # Get a color from the color map
79
+
80
+ for i, node in relation_data.iterrows():
81
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
82
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
83
+
84
+ # ax.set_title(f'{relation.replace("_", "-")}')
85
+ # ax.legend()
86
+
87
+ # Convert tuple to hex
88
+ def rgba_to_hex(rgba):
89
+ return mcolors.to_hex(rgba[:3])
90
+ color_hex = rgba_to_hex(color)
91
+
92
+ # Write header in color of relation
93
+ st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
94
+
95
+ # Show plot
96
+ st.pyplot(fig)
97
+
98
+ # Drop known relation column
99
+ relation_data = relation_data.drop(columns = 'Known Relation')
100
+ if target_node_type not in ['disease', 'anatomy']:
101
+ st.dataframe(relation_data, use_container_width=True,
102
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
103
+ help = "Click to visit external database.",
104
+ display_text = st.session_state.display_database)})
105
+ else:
106
+ st.dataframe(relation_data, use_container_width=True)
utils.py CHANGED
@@ -3,12 +3,21 @@ import pandas as pd
3
  import project_config
4
  import base64
5
 
6
- @st.cache_data(show_spinner = 'Loading knowledge graph...')
 
7
  def load_kg():
8
  # with st.spinner('Loading knowledge graph...'):
9
  kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
10
  return kg_nodes
11
 
 
 
 
 
 
 
 
 
12
  def capitalize_after_slash(s):
13
  # Split the string by slashes first
14
  parts = s.split('/')
@@ -18,6 +27,7 @@ def capitalize_after_slash(s):
18
  capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
19
  return capitalized_string
20
 
 
21
  # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
22
  # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
23
  @st.cache_data()
@@ -26,6 +36,7 @@ def get_base64_of_bin_file(png_file):
26
  data = f.read()
27
  return base64.b64encode(data).decode()
28
 
 
29
  def build_markup_for_logo(
30
  png_file,
31
  background_position="50% 10%",
@@ -55,6 +66,7 @@ def build_markup_for_logo(
55
  image_height,
56
  )
57
 
 
58
  def add_logo(png_file):
59
  logo_markup = build_markup_for_logo(png_file)
60
  st.markdown(
 
3
  import project_config
4
  import base64
5
 
6
+
7
+ @st.cache_data(show_spinner = 'Loading knowledge graph nodes...')
8
  def load_kg():
9
  # with st.spinner('Loading knowledge graph...'):
10
  kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
11
  return kg_nodes
12
 
13
+
14
+ @st.cache_data(show_spinner = 'Loading knowledge graph edges...')
15
+ def load_kg_edges():
16
+ # with st.spinner('Loading knowledge graph...'):
17
+ kg_edges = pd.read_csv(project_config.DATA_DIR / 'kg_edges.csv', dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)
18
+ return kg_edges
19
+
20
+
21
  def capitalize_after_slash(s):
22
  # Split the string by slashes first
23
  parts = s.split('/')
 
27
  capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
28
  return capitalized_string
29
 
30
+
31
  # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
32
  # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
33
  @st.cache_data()
 
36
  data = f.read()
37
  return base64.b64encode(data).decode()
38
 
39
+
40
  def build_markup_for_logo(
41
  png_file,
42
  background_position="50% 10%",
 
66
  image_height,
67
  )
68
 
69
+
70
  def add_logo(png_file):
71
  logo_markup = build_markup_for_logo(png_file)
72
  st.markdown(