Spaces:
Sleeping
Sleeping
Commit
·
950486e
1
Parent(s):
f027c05
Demo for ASAP reviewers
Browse files- {pages → deprecated}/input.py +0 -0
- deprecated/predict.py +199 -0
- {pages → deprecated}/validate.py +0 -0
- media/predict_header.svg +1 -1
- menu.py +6 -5
- pages/about.py +2 -2
- pages/predict.py +87 -60
{pages → deprecated}/input.py
RENAMED
File without changes
|
deprecated/predict.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from menu import menu_with_redirect
|
3 |
+
|
4 |
+
# Standard imports
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
# Path manipulation
|
12 |
+
from pathlib import Path
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
|
15 |
+
# Plotting
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
plt.rcParams['font.sans-serif'] = 'Arial'
|
18 |
+
|
19 |
+
# Custom and other imports
|
20 |
+
import project_config
|
21 |
+
from utils import capitalize_after_slash, load_kg
|
22 |
+
|
23 |
+
# Redirect to app.py if not logged in, otherwise show the navigation menu
|
24 |
+
menu_with_redirect()

# Page banner
header_image = project_config.MEDIA_DIR / 'predict_header.svg'
st.image(str(header_image), use_column_width=True)

# Title the page after the type of node being searched for
query = st.session_state.query
search_title = capitalize_after_slash(query['target_node_type'])
st.subheader(f"{search_title} Search", divider = "blue")

# Echo the current query as source ➡️ relation ➡️ target type,
# replacing underscores so the identifiers read naturally
source_label = query['source_node'].replace('_', ' ')
relation_label = query['relation'].replace('_', '-')
target_label = query['target_node_type'].replace('_', ' ')
st.markdown(f"**Query:** {source_label} ➡️ {relation_label} ➡️ {target_label}")
|
36 |
+
|
37 |
+
@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():
    """Download the model artifacts for the pinned checkpoint.

    Fetches three files (embeddings, relation weights, edge types) from the
    ``ayushnoori/galaxy`` Hugging Face repository, authenticating with the
    ``HF_TOKEN`` Streamlit secret.

    Returns:
        Tuple ``(embed_path, relation_weights_path, edge_types_path)`` holding
        the values returned by ``hf_hub_download``, in that order.
    """
    # Checkpoint currently in use. Other checkpoints that have been tried:
    #   "2024_05_22_11_59_43_epoch=18-step=22912"
    #   "2024_03_29_04_12_52_epoch=3-step=54291"
    checkpoint = "2024_05_15_13_05_33_epoch=2-step=40383"

    # File-name suffixes, listed in the order the paths are returned
    suffixes = (
        "-thresh=4000_embeddings.pt",
        "_relation_weights.pt",
        "_edge_types.pt",
    )

    downloaded = []
    for suffix in suffixes:
        local_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                     filename=(checkpoint + suffix),
                                     token=st.secrets["HF_TOKEN"])
        downloaded.append(local_path)

    embed_path, relation_weights_path, edge_types_path = downloaded
    return embed_path, relation_weights_path, edge_types_path
|
56 |
+
|
57 |
+
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Load the downloaded model artifacts from disk.

    Args:
        embed_path: Path to the saved node embeddings file.
        relation_weights_path: Path to the saved relation weights file.
        edge_types_path: Path to the saved edge types file.

    Returns:
        Tuple ``(embeddings, relation_weights, edge_types)`` with the objects
        deserialized from the three files, in that order.
    """
    # map_location="cpu": tensors saved from a GPU would otherwise fail to
    # deserialize on a CPU-only host; objects already on the CPU are unaffected.
    #
    # NOTE(security): torch.load unpickles the file contents. Only load files
    # from a trusted source. NOTE(review): consider weights_only=True once the
    # contents of all three files are confirmed to be compatible with it.
    embeddings = torch.load(embed_path, map_location="cpu")
    relation_weights = torch.load(relation_weights_path, map_location="cpu")
    edge_types = torch.load(edge_types_path, map_location="cpu")

    return embeddings, relation_weights, edge_types
|
66 |
+
|
67 |
+
# Load knowledge graph and embeddings
|
68 |
+
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)

# Compute predictions: score every node of the target type against the query's
# source node under the query's relation, then rank the targets by score.
with st.spinner('Computing predictions...'):

    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index. Stop with a readable message instead of an
    # IndexError when the queried node is not in the knowledge graph.
    source_rows = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)]
    if source_rows.empty:
        st.error(f"Could not find {source_node_type} node '{source_node}' in the knowledge graph.")
        st.stop()
    src_index = source_rows.node_index.values[0]

    # Get relation index: position of the (source type, relation, target type)
    # triple within edge_types
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next((i for i, etype in enumerate(edge_types) if etype == query_edge_type), None)
    if edge_type_index is None:
        st.error(f"The model has no edge type {query_edge_type}.")
        st.stop()

    # Get target node indices; the source index is repeated so both index
    # arrays have one entry per target node
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))

    # Retrieve cached embeddings and apply activation function
    # (assumes embeddings is a 2-D tensor indexed by node_index — TODO confirm)
    src_embeddings = F.leaky_relu(embeddings[src_indices])
    dst_embeddings = F.leaky_relu(embeddings[dst_indices])

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product, squashed to (0, 1) by the sigmoid
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe; .cpu() makes the NumPy conversion work even if
    # the tensors were loaded onto another device
    target_nodes['score'] = scores.detach().cpu().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Select and rename the columns shown to the user
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})

    # Define dictionary mapping node types to database URLs
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
        'disease': lambda x: x, # MONDO
        # pad with 0s to 7 digits
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }

    # Get name of database (taken from the top-ranked row, before the column
    # is overwritten with URLs below)
    display_database = display_data['Database'].values[0]

    # Add URLs to database column
    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)

# 'disease' and 'anatomy' map IDs to themselves in map_dbs (no URL), so their
# Database column stays plain text; every other type is rendered as a link.
if target_node_type not in ['disease', 'anatomy']:
    results_column_config = {"Database": st.column_config.LinkColumn(width = "small",
                                                                    help = "Click to visit external database.",
                                                                    display_text = display_database)}
else:
    results_column_config = None


# NODE SEARCH

# Use multiselect to search for specific nodes. The options are the ranked
# target nodes, so the label names the target node type.
selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
                                display_data.Name, placeholder = "Type to search...")

# Filter nodes
if selected_nodes:
    selected_display_data = display_data[display_data.Name.isin(selected_nodes)]

    # Show filtered nodes
    st.dataframe(selected_display_data, use_container_width = True,
                 column_config = results_column_config)

    # Plot rank vs. score using matplotlib
    st.markdown("**Rank vs. Score**")
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['Score'])
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Add vertical line for selected nodes
    for _, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
        ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')

    # Show plot, then close the figure so it is not kept alive by pyplot
    # across script reruns
    st.pyplot(fig)
    plt.close(fig)


# FULL RESULTS

# Show top ranked nodes
st.subheader("Model Predictions", divider = "blue")
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))

st.dataframe(display_data.iloc[:top_k], use_container_width = True,
             column_config = results_column_config)

# Save to session state
st.session_state.predictions = display_data
st.session_state.display_database = display_database
|
{pages → deprecated}/validate.py
RENAMED
File without changes
|
media/predict_header.svg
CHANGED
|
|
menu.py
CHANGED
@@ -15,11 +15,12 @@ def authenticated_menu():
|
|
15 |
# Show a navigation menu for authenticated users
|
16 |
# st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
|
17 |
st.sidebar.page_link("pages/about.py", label="About", icon="📖")
|
18 |
-
st.sidebar.page_link("pages/
|
19 |
-
st.sidebar.page_link("pages/
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
23 |
# st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
|
24 |
if st.session_state.role in ["admin"]:
|
25 |
st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
|
|
|
15 |
# Show a navigation menu for authenticated users
|
16 |
# st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
|
17 |
st.sidebar.page_link("pages/about.py", label="About", icon="📖")
|
18 |
+
st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍")
|
19 |
+
# st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
|
20 |
+
# st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
|
21 |
+
# disabled=("query" not in st.session_state))
|
22 |
+
# st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
|
23 |
+
# disabled=("query" not in st.session_state))
|
24 |
# st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
|
25 |
if st.session_state.role in ["admin"]:
|
26 |
st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
|
pages/about.py
CHANGED
@@ -22,9 +22,9 @@ st.subheader("About CIPHER", divider = "grey")
|
|
22 |
st.markdown("""
|
23 |
CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
|
24 |
|
25 |
-
*Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable for training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n = 143,093
|
26 |
|
27 |
-
*Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning to model biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph
|
28 |
|
29 |
Through CIPHER, we seek to enable molecular subtyping and patient stratification of PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
|
30 |
""")
|
|
|
22 |
st.markdown("""
|
23 |
CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
|
24 |
|
25 |
+
*Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable for training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n* = 143,093 nodes and *n* = 7,048,795 edges by curating 36 high-quality primary data sources, ontologies, and knowledge bases.
|
26 |
|
27 |
+
*Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning method to model biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph Transformer to learn graph embeddings that encode the relationships in the KG.
|
28 |
|
29 |
Through CIPHER, we seek to enable molecular subtyping and patient stratification of PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
|
30 |
""")
|
pages/predict.py
CHANGED
@@ -24,15 +24,51 @@ from utils import capitalize_after_slash, load_kg
|
|
24 |
menu_with_redirect()
|
25 |
|
26 |
# Header
|
27 |
-
st.image(str(project_config.MEDIA_DIR / '
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
-
|
|
|
33 |
|
34 |
-
#
|
35 |
-
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
|
36 |
|
37 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
38 |
def get_embeddings():
|
@@ -69,26 +105,9 @@ kg_nodes = load_kg()
|
|
69 |
embed_path, relation_weights_path, edge_types_path = get_embeddings()
|
70 |
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
|
71 |
|
72 |
-
# # Print source node type
|
73 |
-
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
|
74 |
-
|
75 |
-
# # Print source node
|
76 |
-
# st.write(f"Source Node: {st.session_state.query['source_node']}")
|
77 |
-
|
78 |
-
# # Print relation
|
79 |
-
# st.write(f"Edge Type: {st.session_state.query['relation']}")
|
80 |
-
|
81 |
-
# # Print target node type
|
82 |
-
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
|
83 |
-
|
84 |
# Compute predictions
|
85 |
with st.spinner('Computing predictions...'):
|
86 |
|
87 |
-
source_node_type = st.session_state.query['source_node_type']
|
88 |
-
source_node = st.session_state.query['source_node']
|
89 |
-
relation = st.session_state.query['relation']
|
90 |
-
target_node_type = st.session_state.query['target_node_type']
|
91 |
-
|
92 |
# Get source node index
|
93 |
src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
|
94 |
|
@@ -120,7 +139,7 @@ with st.spinner('Computing predictions...'):
|
|
120 |
|
121 |
# Rename columns
|
122 |
display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
|
123 |
-
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': '
|
124 |
|
125 |
# Define dictionary mapping node types to database URLs
|
126 |
map_dbs = {
|
@@ -143,57 +162,65 @@ with st.spinner('Computing predictions...'):
|
|
143 |
# Add URLs to database column
|
144 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
145 |
|
146 |
-
|
147 |
# NODE SEARCH
|
148 |
|
149 |
-
# Use multiselect to search for specific nodes
|
150 |
-
selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
|
151 |
-
display_data.Name, placeholder = "Type to search...")
|
152 |
-
|
153 |
# Filter nodes
|
154 |
if len(selected_nodes) > 0:
|
155 |
-
selected_display_data = display_data[display_data.
|
156 |
-
|
157 |
-
# Show filtered nodes
|
158 |
-
if target_node_type not in ['disease', 'anatomy']:
|
159 |
-
st.dataframe(selected_display_data, use_container_width = True,
|
160 |
-
column_config={"Database": st.column_config.LinkColumn(width = "small",
|
161 |
-
help = "Click to visit external database.",
|
162 |
-
display_text = display_database)})
|
163 |
-
else:
|
164 |
-
st.dataframe(selected_display_data, use_container_width = True)
|
165 |
|
166 |
# Plot rank vs. score using matplotlib
|
167 |
-
st.markdown("**Rank vs. Score**")
|
168 |
fig, ax = plt.subplots(figsize = (10, 6))
|
169 |
-
ax.plot(display_data['Rank'], display_data['Score'])
|
170 |
ax.set_xlabel('Rank', fontsize = 12)
|
171 |
-
ax.set_ylabel('Score', fontsize = 12)
|
172 |
ax.set_xlim(1, display_data['Rank'].max())
|
173 |
|
|
|
|
|
|
|
|
|
|
|
174 |
# Add vertical line for selected nodes
|
175 |
for i, node in selected_display_data.iterrows():
|
176 |
-
ax.axvline(node['Rank'], color =
|
177 |
-
ax.text(node['Rank'] + 100, node['Score'], node['
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
# Show plot
|
|
|
180 |
st.pyplot(fig)
|
181 |
|
182 |
|
183 |
-
# FULL RESULTS
|
184 |
|
185 |
-
# Show top ranked nodes
|
186 |
-
st.subheader("Model Predictions", divider = "blue")
|
187 |
-
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
188 |
|
189 |
-
if target_node_type not in ['disease', 'anatomy']:
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
else:
|
195 |
-
|
196 |
-
|
197 |
-
# Save to session state
|
198 |
-
st.session_state.predictions = display_data
|
199 |
-
st.session_state.display_database = display_database
|
|
|
24 |
menu_with_redirect()

# Banner for the input section
input_banner = project_config.MEDIA_DIR / 'input_header.svg'
st.image(str(input_banner), use_column_width=True)

st.markdown(
'''
Use CIPHER to predict how closely genes of interest are associated with Parkinson's disease. Search for specific genes to determine their ranking of PD association.
''')

# The query is fixed for this demo instead of being read from session state:
# source_node_type = st.session_state.query['source_node_type']
# source_node = st.session_state.query['source_node']
# relation = st.session_state.query['relation']
# target_node_type = st.session_state.query['target_node_type']
source_node_type, source_node = "disease", "Parkinson disease"
relation, target_node_type = "disease_protein", "gene/protein"

# Disabled alternative: let the user choose the target type and derive the relation
# target_node_type = st.selectbox("I am interested in searching for...", ['gene/protein', 'effect/phenotype', 'drug'],
#                                 format_func = lambda x: x.replace("_", " "), index = 1)

# relation = {
#     'gene/protein': 'disease_protein',
#     'effect/phenotype': 'disease_phenotype_positive',
#     'drug': 'indication'
# }

# Node names offered in the search box, keyed by target node type
allowed_nodes = {
    'gene/protein': ['RHOA', 'XRN1', 'SNCA', 'LRRK2', 'GBA1'],
    'effect/phenotype': ['Parkinsonism', 'Parkinsonism with favorable response to dopaminergic medication'],
    'drug': ['Levodopa'],
}

# Multiselect over the allowed nodes for the fixed target type; the label is
# collapsed because the paragraph above already introduces the widget
node_choices = allowed_nodes[target_node_type]
selected_nodes = st.multiselect(
    "Select genes to search for...",
    node_choices,
    placeholder = "Type to search...",
    label_visibility = 'collapsed',
)


# Horizontal rule between the input and prediction sections
st.markdown("---")

# Banner for the prediction section
predict_banner = project_config.MEDIA_DIR / 'predict_header.svg'
st.image(str(predict_banner), use_column_width=True)

# st.subheader("Gene Search", divider = "blue")
|
|
|
72 |
|
73 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
74 |
def get_embeddings():
|
|
|
105 |
embed_path, relation_weights_path, edge_types_path = get_embeddings()
|
106 |
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
# Compute predictions
|
109 |
with st.spinner('Computing predictions...'):
|
110 |
|
|
|
|
|
|
|
|
|
|
|
111 |
# Get source node index
|
112 |
src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
|
113 |
|
|
|
139 |
|
140 |
# Rename columns
|
141 |
display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
|
142 |
+
display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Gene', 'score': 'CIPHER Score', 'node_source': 'Database'})
|
143 |
|
144 |
# Define dictionary mapping node types to database URLs
|
145 |
map_dbs = {
|
|
|
162 |
# Add URLs to database column
|
163 |
display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
|
164 |
|
|
|
165 |
# NODE SEARCH
|
166 |
|
|
|
|
|
|
|
|
|
167 |
# Filter nodes
|
168 |
# Report the rank of each selected gene as a table and as dashed lines on the
# rank-vs-score curve of all predictions. Nothing is shown until the user
# selects at least one gene.
if selected_nodes:
    # Rows for the selected genes, re-indexed 0..k-1 so the row index can be
    # used to pick a palette color
    selected_display_data = display_data[display_data['Gene'].isin(selected_nodes)].copy().reset_index(drop = True)
    total_predictions = target_nodes.shape[0]

    # Plot rank vs. score using matplotlib
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['CIPHER Score'], color = 'black')
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('CIPHER Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Color palette (hex values of matplotlib's tab10)
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    # Add vertical line for selected nodes; colors wrap around so more rows
    # than palette entries cannot raise an IndexError
    for i, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color = palette[i % len(palette)], linestyle = '--', label = node['Gene'], linewidth = 1.5)

    # Add legend
    ax.legend(loc = 'upper right', fontsize = 10)
    ax.grid(alpha = 0.2)

    # Use the actual number of ranked targets so the sentence always agrees
    # with the percentages computed below
    st.markdown(f"Out of {total_predictions:,} genes, the selected genes rank as follows:")
    selected_display_data['Rank'] = selected_display_data['Rank'].apply(lambda x: f"{x} (top {(100*x/total_predictions):.2f}% of predictions)")

    # Show filtered nodes. Only types other than 'disease' and 'anatomy' get a
    # link column; the index is hidden either way because it was just reset.
    if target_node_type not in ['disease', 'anatomy']:
        selected_column_config = {"Database": st.column_config.LinkColumn(width = "small",
                                                                         help = "Click to visit external database.",
                                                                         display_text = display_database)}
    else:
        selected_column_config = None
    st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
                 column_config = selected_column_config)

    # Show plot, then close the figure so it is not kept alive by pyplot
    # across script reruns
    st.markdown("In the plot below, the dashed lines represent the rank of the selected genes across all CIPHER predictions for Parkinson's disease.")
    st.pyplot(fig)
    plt.close(fig)
|
208 |
|
209 |
|
210 |
+
# # FULL RESULTS
|
211 |
|
212 |
+
# # Show top ranked nodes
|
213 |
+
# st.subheader("Model Predictions", divider = "blue")
|
214 |
+
# top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
|
215 |
|
216 |
+
# if target_node_type not in ['disease', 'anatomy']:
|
217 |
+
# st.dataframe(display_data.iloc[:top_k], use_container_width = True,
|
218 |
+
# column_config={"Database": st.column_config.LinkColumn(width = "small",
|
219 |
+
# help = "Click to visit external database.",
|
220 |
+
# display_text = display_database)})
|
221 |
+
# else:
|
222 |
+
# st.dataframe(display_data.iloc[:top_k], use_container_width = True)
|
223 |
+
|
224 |
+
# # Save to session state
|
225 |
+
# st.session_state.predictions = display_data
|
226 |
+
# st.session_state.display_database = display_database
|