ayushnoori committed on
Commit
950486e
·
1 Parent(s): f027c05

Demo for ASAP reviewers

Browse files
{pages → deprecated}/input.py RENAMED
File without changes
deprecated/predict.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # Path manipulation
12
+ from pathlib import Path
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Plotting
16
+ import matplotlib.pyplot as plt
17
+ plt.rcParams['font.sans-serif'] = 'Arial'
18
+
19
+ # Custom and other imports
20
+ import project_config
21
+ from utils import capitalize_after_slash, load_kg
22
+
23
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
24
+ menu_with_redirect()
25
+
26
+ # Header
27
+ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
28
+
29
+ # Main content
30
+ # st.markdown(f"Hello, {st.session_state.name}!")
31
+
32
+ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
33
+
34
+ # Print current query
35
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
+
37
+ @st.cache_data(show_spinner = 'Downloading AI model...')
38
+ def get_embeddings():
39
+ # Get checkpoint name
40
+ # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
41
+ best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
42
+ # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
43
+
44
+ # Get paths to embeddings, relation weights, and edge types
45
+ # with st.spinner('Downloading AI model...'):
46
+ embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
47
+ filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
48
+ token=st.secrets["HF_TOKEN"])
49
+ relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
50
+ filename=(best_ckpt + "_relation_weights.pt"),
51
+ token=st.secrets["HF_TOKEN"])
52
+ edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
53
+ filename=(best_ckpt + "_edge_types.pt"),
54
+ token=st.secrets["HF_TOKEN"])
55
+ return embed_path, relation_weights_path, edge_types_path
56
+
57
+ @st.cache_data(show_spinner = 'Loading AI model...')
58
+ def load_embeddings(embed_path, relation_weights_path, edge_types_path):
59
+ # Load embeddings, relation weights, and edge types
60
+ # with st.spinner('Loading AI model...'):
61
+ embeddings = torch.load(embed_path)
62
+ relation_weights = torch.load(relation_weights_path)
63
+ edge_types = torch.load(edge_types_path)
64
+
65
+ return embeddings, relation_weights, edge_types
66
+
67
+ # Load knowledge graph and embeddings
68
+ kg_nodes = load_kg()
69
+ embed_path, relation_weights_path, edge_types_path = get_embeddings()
70
+ embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
71
+
72
+ # # Print source node type
73
+ # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
74
+
75
+ # # Print source node
76
+ # st.write(f"Source Node: {st.session_state.query['source_node']}")
77
+
78
+ # # Print relation
79
+ # st.write(f"Edge Type: {st.session_state.query['relation']}")
80
+
81
+ # # Print target node type
82
+ # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
83
+
84
+ # Compute predictions
85
+ with st.spinner('Computing predictions...'):
86
+
87
+ source_node_type = st.session_state.query['source_node_type']
88
+ source_node = st.session_state.query['source_node']
89
+ relation = st.session_state.query['relation']
90
+ target_node_type = st.session_state.query['target_node_type']
91
+
92
+ # Get source node index
93
+ src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
94
+
95
+ # Get relation index
96
+ edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
97
+
98
+ # Get target nodes indices
99
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
100
+ dst_indices = target_nodes.node_index.values
101
+ src_indices = np.repeat(src_index, len(dst_indices))
102
+
103
+ # Retrieve cached embeddings and apply activation function
104
+ src_embeddings = embeddings[src_indices]
105
+ dst_embeddings = embeddings[dst_indices]
106
+ src_embeddings = F.leaky_relu(src_embeddings)
107
+ dst_embeddings = F.leaky_relu(dst_embeddings)
108
+
109
+ # Get relation weights
110
+ rel_weights = relation_weights[edge_type_index]
111
+
112
+ # Compute weighted dot product
113
+ scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
114
+ scores = torch.sigmoid(scores)
115
+
116
+ # Add scores to dataframe
117
+ target_nodes['score'] = scores.detach().numpy()
118
+ target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
119
+ target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
120
+
121
+ # Rename columns
122
+ display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
123
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
124
+
125
+ # Define dictionary mapping node types to database URLs
126
+ map_dbs = {
127
+ 'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
128
+ 'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
129
+ 'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
130
+ 'disease': lambda x: x, # MONDO
131
+ # pad with 0s to 7 digits
132
+ 'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
133
+ 'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
134
+ 'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
135
+ 'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
136
+ 'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
137
+ 'anatomy': lambda x: x,
138
+ }
139
+
140
+ # Get name of database
141
+ display_database = display_data['Database'].values[0]
142
+
143
+ # Add URLs to database column
144
+ display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
145
+
146
+
147
+ # NODE SEARCH
148
+
149
+ # Use multiselect to search for specific nodes
150
+ selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
151
+ display_data.Name, placeholder = "Type to search...")
152
+
153
+ # Filter nodes
154
+ if len(selected_nodes) > 0:
155
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
156
+
157
+ # Show filtered nodes
158
+ if target_node_type not in ['disease', 'anatomy']:
159
+ st.dataframe(selected_display_data, use_container_width = True,
160
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
161
+ help = "Click to visit external database.",
162
+ display_text = display_database)})
163
+ else:
164
+ st.dataframe(selected_display_data, use_container_width = True)
165
+
166
+ # Plot rank vs. score using matplotlib
167
+ st.markdown("**Rank vs. Score**")
168
+ fig, ax = plt.subplots(figsize = (10, 6))
169
+ ax.plot(display_data['Rank'], display_data['Score'])
170
+ ax.set_xlabel('Rank', fontsize = 12)
171
+ ax.set_ylabel('Score', fontsize = 12)
172
+ ax.set_xlim(1, display_data['Rank'].max())
173
+
174
+ # Add vertical line for selected nodes
175
+ for i, node in selected_display_data.iterrows():
176
+ ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
177
+ ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
178
+
179
+ # Show plot
180
+ st.pyplot(fig)
181
+
182
+
183
+ # FULL RESULTS
184
+
185
+ # Show top ranked nodes
186
+ st.subheader("Model Predictions", divider = "blue")
187
+ top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
188
+
189
+ if target_node_type not in ['disease', 'anatomy']:
190
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True,
191
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
192
+ help = "Click to visit external database.",
193
+ display_text = display_database)})
194
+ else:
195
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True)
196
+
197
+ # Save to session state
198
+ st.session_state.predictions = display_data
199
+ st.session_state.display_database = display_database
{pages → deprecated}/validate.py RENAMED
File without changes
media/predict_header.svg CHANGED
menu.py CHANGED
@@ -15,11 +15,12 @@ def authenticated_menu():
15
  # Show a navigation menu for authenticated users
16
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
17
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
18
- st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
19
- st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
20
- disabled=("query" not in st.session_state))
21
- st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
22
- disabled=("query" not in st.session_state))
 
23
  # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
24
  if st.session_state.role in ["admin"]:
25
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
 
15
  # Show a navigation menu for authenticated users
16
  # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
17
  st.sidebar.page_link("pages/about.py", label="About", icon="📖")
18
+ st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍")
19
+ # st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
20
+ # st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
21
+ # disabled=("query" not in st.session_state))
22
+ # st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
23
+ # disabled=("query" not in st.session_state))
24
  # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
25
  if st.session_state.role in ["admin"]:
26
  st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
pages/about.py CHANGED
@@ -22,9 +22,9 @@ st.subheader("About CIPHER", divider = "grey")
22
  st.markdown("""
23
  CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
24
 
25
- *Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable for training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n = 143,093* nodes and *n = 7,048,795* edges by curating 36 high-quality primary data sources, ontologies, and knowledge bases.
26
 
27
- *Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning to model biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph transformer to learn graph embeddings that encode the relationships in the KG.
28
 
29
  Through CIPHER, we seek to enable molecular subtyping and patient stratification of PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
30
  """)
 
22
  st.markdown("""
23
  CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
24
 
25
+ *Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable for training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n* = 143,093 nodes and *n* = 7,048,795 edges by curating 36 high-quality primary data sources, ontologies, and knowledge bases.
26
 
27
+ *Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning method to model biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph Transformer to learn graph embeddings that encode the relationships in the KG.
28
 
29
  Through CIPHER, we seek to enable molecular subtyping and patient stratification of PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
30
  """)
pages/predict.py CHANGED
@@ -24,15 +24,51 @@ from utils import capitalize_after_slash, load_kg
24
  menu_with_redirect()
25
 
26
  # Header
27
- st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
28
-
29
- # Main content
30
- # st.markdown(f"Hello, {st.session_state.name}!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
 
33
 
34
- # Print current query
35
- st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
 
37
  @st.cache_data(show_spinner = 'Downloading AI model...')
38
  def get_embeddings():
@@ -69,26 +105,9 @@ kg_nodes = load_kg()
69
  embed_path, relation_weights_path, edge_types_path = get_embeddings()
70
  embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
71
 
72
- # # Print source node type
73
- # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
74
-
75
- # # Print source node
76
- # st.write(f"Source Node: {st.session_state.query['source_node']}")
77
-
78
- # # Print relation
79
- # st.write(f"Edge Type: {st.session_state.query['relation']}")
80
-
81
- # # Print target node type
82
- # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
83
-
84
  # Compute predictions
85
  with st.spinner('Computing predictions...'):
86
 
87
- source_node_type = st.session_state.query['source_node_type']
88
- source_node = st.session_state.query['source_node']
89
- relation = st.session_state.query['relation']
90
- target_node_type = st.session_state.query['target_node_type']
91
-
92
  # Get source node index
93
  src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
94
 
@@ -120,7 +139,7 @@ with st.spinner('Computing predictions...'):
120
 
121
  # Rename columns
122
  display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
123
- display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
124
 
125
  # Define dictionary mapping node types to database URLs
126
  map_dbs = {
@@ -143,57 +162,65 @@ with st.spinner('Computing predictions...'):
143
  # Add URLs to database column
144
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
145
 
146
-
147
  # NODE SEARCH
148
 
149
- # Use multiselect to search for specific nodes
150
- selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
151
- display_data.Name, placeholder = "Type to search...")
152
-
153
  # Filter nodes
154
  if len(selected_nodes) > 0:
155
- selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
156
-
157
- # Show filtered nodes
158
- if target_node_type not in ['disease', 'anatomy']:
159
- st.dataframe(selected_display_data, use_container_width = True,
160
- column_config={"Database": st.column_config.LinkColumn(width = "small",
161
- help = "Click to visit external database.",
162
- display_text = display_database)})
163
- else:
164
- st.dataframe(selected_display_data, use_container_width = True)
165
 
166
  # Plot rank vs. score using matplotlib
167
- st.markdown("**Rank vs. Score**")
168
  fig, ax = plt.subplots(figsize = (10, 6))
169
- ax.plot(display_data['Rank'], display_data['Score'])
170
  ax.set_xlabel('Rank', fontsize = 12)
171
- ax.set_ylabel('Score', fontsize = 12)
172
  ax.set_xlim(1, display_data['Rank'].max())
173
 
 
 
 
 
 
174
  # Add vertical line for selected nodes
175
  for i, node in selected_display_data.iterrows():
176
- ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
177
- ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  # Show plot
 
180
  st.pyplot(fig)
181
 
182
 
183
- # FULL RESULTS
184
 
185
- # Show top ranked nodes
186
- st.subheader("Model Predictions", divider = "blue")
187
- top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
188
 
189
- if target_node_type not in ['disease', 'anatomy']:
190
- st.dataframe(display_data.iloc[:top_k], use_container_width = True,
191
- column_config={"Database": st.column_config.LinkColumn(width = "small",
192
- help = "Click to visit external database.",
193
- display_text = display_database)})
194
- else:
195
- st.dataframe(display_data.iloc[:top_k], use_container_width = True)
196
-
197
- # Save to session state
198
- st.session_state.predictions = display_data
199
- st.session_state.display_database = display_database
 
24
  menu_with_redirect()
25
 
26
  # Header
27
+ st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
28
+
29
+ st.markdown(
30
+ '''
31
+ Use CIPHER to predict how closely genes of interest are associated with Parkinson's disease. Search for specific genes to determine their ranking of PD association.
32
+ ''')
33
+
34
+ # source_node_type = st.session_state.query['source_node_type']
35
+ # source_node = st.session_state.query['source_node']
36
+ # relation = st.session_state.query['relation']
37
+ # target_node_type = st.session_state.query['target_node_type']
38
+ source_node_type = "disease"
39
+ source_node = "Parkinson disease"
40
+ relation = "disease_protein"
41
+ target_node_type = "gene/protein"
42
+
43
+ # target_node_type = st.selectbox("I am interested in searching for...", ['gene/protein', 'effect/phenotype', 'drug'],
44
+ # format_func = lambda x: x.replace("_", " "), index = 1)
45
+
46
+ # relation = {
47
+ # 'gene/protein': 'disease_protein',
48
+ # 'effect/phenotype': 'disease_phenotype_positive',
49
+ # 'drug': 'indication'
50
+ # }
51
+
52
+ # Get list of allowed nodes
53
+ allowed_nodes = {
54
+ 'gene/protein': ['RHOA', 'XRN1', 'SNCA', 'LRRK2', 'GBA1'],
55
+ 'effect/phenotype': ['Parkinsonism', 'Parkinsonism with favorable response to dopaminergic medication'],
56
+ 'drug': ['Levodopa']
57
+ }
58
+
59
+ # Use multiselect to search for specific nodes
60
+ selected_nodes = st.multiselect("Select genes to search for...",
61
+ allowed_nodes[target_node_type], placeholder = "Type to search...",
62
+ label_visibility = 'collapsed',)
63
+
64
+
65
+ # Add line break
66
+ st.markdown("---")
67
 
68
+ # Header
69
+ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
70
 
71
+ # st.subheader("Gene Search", divider = "blue")
 
72
 
73
  @st.cache_data(show_spinner = 'Downloading AI model...')
74
  def get_embeddings():
 
105
  embed_path, relation_weights_path, edge_types_path = get_embeddings()
106
  embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Compute predictions
109
  with st.spinner('Computing predictions...'):
110
 
 
 
 
 
 
111
  # Get source node index
112
  src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
113
 
 
139
 
140
  # Rename columns
141
  display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
142
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Gene', 'score': 'CIPHER Score', 'node_source': 'Database'})
143
 
144
  # Define dictionary mapping node types to database URLs
145
  map_dbs = {
 
162
  # Add URLs to database column
163
  display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
164
 
 
165
  # NODE SEARCH
166
 
 
 
 
 
167
  # Filter nodes
168
  if len(selected_nodes) > 0:
169
+ selected_display_data = display_data[display_data['Gene'].isin(selected_nodes)].copy().reset_index(drop = True)
 
 
 
 
 
 
 
 
 
170
 
171
  # Plot rank vs. score using matplotlib
 
172
  fig, ax = plt.subplots(figsize = (10, 6))
173
+ ax.plot(display_data['Rank'], display_data['CIPHER Score'], color = 'black')
174
  ax.set_xlabel('Rank', fontsize = 12)
175
+ ax.set_ylabel('CIPHER Score', fontsize = 12)
176
  ax.set_xlim(1, display_data['Rank'].max())
177
 
178
+ # Get color palette
179
+ # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
180
+ palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
181
+
182
+
183
  # Add vertical line for selected nodes
184
  for i, node in selected_display_data.iterrows():
185
+ ax.axvline(node['Rank'], color = palette[i], linestyle = '--', label = node['Gene'], linewidth = 1.5)
186
+ # ax.text(node['Rank'] + 100, node['CIPHER Score'], node['Gene'], fontsize = 10, color = palette(i))
187
+
188
+ # Add legend
189
+ ax.legend(loc = 'upper right', fontsize = 10)
190
+ ax.grid(alpha = 0.2)
191
+
192
+
193
+ st.markdown(f"Out of 35,189 genes, the selected genes rank as follows:")
194
+ selected_display_data['Rank'] = selected_display_data['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
195
+
196
+ # Show filtered nodes
197
+ if target_node_type not in ['disease', 'anatomy']:
198
+ st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
199
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
200
+ help = "Click to visit external database.",
201
+ display_text = display_database)})
202
+ else:
203
+ st.dataframe(selected_display_data, use_container_width = True)
204
 
205
  # Show plot
206
+ st.markdown(f"In the plot below, the dashed lines represent the rank of the selected genes across all CIPHER predictions for Parkinson's disease.")
207
  st.pyplot(fig)
208
 
209
 
210
+ # # FULL RESULTS
211
 
212
+ # # Show top ranked nodes
213
+ # st.subheader("Model Predictions", divider = "blue")
214
+ # top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
215
 
216
+ # if target_node_type not in ['disease', 'anatomy']:
217
+ # st.dataframe(display_data.iloc[:top_k], use_container_width = True,
218
+ # column_config={"Database": st.column_config.LinkColumn(width = "small",
219
+ # help = "Click to visit external database.",
220
+ # display_text = display_database)})
221
+ # else:
222
+ # st.dataframe(display_data.iloc[:top_k], use_container_width = True)
223
+
224
+ # # Save to session state
225
+ # st.session_state.predictions = display_data
226
+ # st.session_state.display_database = display_database