Spaces:
Sleeping
Sleeping
File size: 9,944 Bytes
f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e f027c05 950486e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download
# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg
# Gate the page: redirect to app.py if not logged in, otherwise show the
# navigation menu.
menu_with_redirect()

# Banner for the input section
input_header_path = project_config.MEDIA_DIR / 'input_header.svg'
st.image(str(input_header_path), use_column_width=True)

# One-paragraph description of the page
intro_text = '''
Use CIPHER to predict how closely genes of interest are associated with Parkinson's disease. Search for specific genes to determine their ranking of PD association.
'''
st.markdown(intro_text)

# The query is fixed on this page. The session-state driven version is kept
# below for reference:
# source_node_type = st.session_state.query['source_node_type']
# source_node = st.session_state.query['source_node']
# relation = st.session_state.query['relation']
# target_node_type = st.session_state.query['target_node_type']
source_node_type, source_node = "disease", "Parkinson disease"
relation, target_node_type = "disease_protein", "gene/protein"

# Disabled selector that let the user choose the target node type, together
# with the relation that goes with each choice:
# target_node_type = st.selectbox("I am interested in searching for...", ['gene/protein', 'effect/phenotype', 'drug'],
# format_func = lambda x: x.replace("_", " "), index = 1)
# relation = {
# 'gene/protein': 'disease_protein',
# 'effect/phenotype': 'disease_phenotype_positive',
# 'drug': 'indication'
# }

# Node names the user may search for, keyed by target node type
allowed_nodes = {
    'gene/protein': [
        'RHOA',
        'XRN1',
        'SNCA',
        'LRRK2',
        'GBA1',
    ],
    'effect/phenotype': [
        'Parkinsonism',
        'Parkinsonism with favorable response to dopaminergic medication',
    ],
    'drug': [
        'Levodopa',
    ],
}

# Search box for specific nodes; the label is collapsed so only the
# placeholder text is visible.
selected_nodes = st.multiselect(
    label="Select genes to search for...",
    options=allowed_nodes[target_node_type],
    placeholder="Type to search...",
    label_visibility='collapsed',
)

# Horizontal rule between the input and prediction sections
st.markdown("---")

# Banner for the prediction section
predict_header_path = project_config.MEDIA_DIR / 'predict_header.svg'
st.image(str(predict_header_path), use_column_width=True)
# st.subheader("Gene Search", divider = "blue")
@st.cache_data(show_spinner='Downloading AI model...')
def get_embeddings():
    """Download the model artifacts for the selected checkpoint.

    Fetches three files from the ``ayushnoori/galaxy`` Hugging Face Hub
    repository, authenticating with the ``HF_TOKEN`` Streamlit secret. The
    result is cached by ``st.cache_data``.

    Returns:
        A 3-tuple ``(embed_path, relation_weights_path, edge_types_path)`` of
        the values returned by ``hf_hub_download`` for the embeddings,
        relation weights, and edge types files, in that order.
    """
    hub_repo = "ayushnoori/galaxy"

    # Checkpoint in use; other checkpoints that were tried are listed for reference.
    # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
    # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"

    # Filename suffixes, in the order the paths are returned
    artifact_suffixes = (
        "-thresh=4000_embeddings.pt",
        "_relation_weights.pt",
        "_edge_types.pt",
    )

    # Download one file per suffix, sequentially and in order
    embed_path, relation_weights_path, edge_types_path = (
        hf_hub_download(
            repo_id=hub_repo,
            filename=(best_ckpt + suffix),
            token=st.secrets["HF_TOKEN"],
        )
        for suffix in artifact_suffixes
    )
    return embed_path, relation_weights_path, edge_types_path
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Load the embeddings, relation weights, and edge types from disk.

    Each file is deserialized with ``torch.load`` and the result is cached by
    ``st.cache_data``.

    Args:
        embed_path: Path to the serialized node embeddings.
        relation_weights_path: Path to the serialized relation weights.
        edge_types_path: Path to the serialized edge types.

    Returns:
        A 3-tuple ``(embeddings, relation_weights, edge_types)`` holding
        whatever objects were stored in the three files, in that order.
        (Presumably tensors for the first two and a sequence of edge-type
        tuples for the third - TODO confirm against the export script.)
    """
    # NOTE(security): torch.load unpickles its input, so it must only be given
    # files from a trusted source.
    # map_location="cpu" remaps any tensors that were saved from a GPU onto the
    # CPU, so loading does not fail on a host without CUDA.
    embeddings = torch.load(embed_path, map_location="cpu")
    relation_weights = torch.load(relation_weights_path, map_location="cpu")
    edge_types = torch.load(edge_types_path, map_location="cpu")
    return embeddings, relation_weights, edge_types
# Load knowledge graph and embeddings
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)

# Compute predictions: score every node of the target type against the single
# source node, rank the nodes by score, and build the table that is displayed.
with st.spinner('Computing predictions...'):

    # Get source node index. Fail with a readable message instead of an opaque
    # IndexError when the query node is not present in the graph.
    is_source = (kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)
    source_rows = kg_nodes[is_source]
    if source_rows.empty:
        raise LookupError(
            f"Source node {source_node!r} of type {source_node_type!r} "
            "was not found in the knowledge graph."
        )
    src_index = source_rows.node_index.values[0]

    # Get relation index: position of the first edge type equal to the query triple
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next(
        (i for i, etype in enumerate(edge_types) if etype == query_edge_type),
        None,
    )
    if edge_type_index is None:
        raise LookupError(
            f"Edge type {query_edge_type!r} was not found among the model's edge types."
        )

    # Get target nodes indices
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        raise LookupError(
            f"No nodes of type {target_node_type!r} were found in the knowledge graph."
        )
    dst_indices = target_nodes.node_index.values

    # Retrieve cached embeddings and apply activation function. The single
    # source row is broadcast against all target rows in the product below,
    # so it is not copied once per target node.
    src_embedding = F.leaky_relu(embeddings[int(src_index)])
    dst_embeddings = F.leaky_relu(embeddings[dst_indices])

    # Get relation weights
    # (assumes this broadcasts against an embedding row - TODO confirm shape)
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product, one score per target node, squashed by a sigmoid
    scores = torch.sum(src_embedding * rel_weights * dst_embeddings, dim=1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe; rank 1 is the highest score
    target_nodes['score'] = scores.detach().numpy()
    target_nodes = target_nodes.sort_values(by='score', ascending=False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Select and rename the columns shown to the user
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns={
        'rank': 'Rank',
        'node_id': 'ID',
        'node_name': 'Gene',
        'score': 'CIPHER Score',
        'node_source': 'Database',
    })

    # Define dictionary mapping node types to database URLs
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}",  # pad with 0s to 7 digits
        'disease': lambda x: x,  # MONDO
        # pad with 0s to 7 digits
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }

    # Get name of database (taken from the top-ranked row) before the column
    # is overwritten with URLs
    display_database = display_data['Database'].values[0]

    # Add URLs to database column. The URL depends only on the ID, so map the
    # ID column directly instead of applying a function to every full row.
    display_data['Database'] = display_data['ID'].map(map_dbs[target_node_type])
# NODE SEARCH
# When at least one gene is selected, show where the selected genes fall in
# the ranking, both as a table and as dashed lines on the rank-vs-score curve.
if selected_nodes:

    # Filter nodes; the index is reset so that row labels run 0..n-1 and can
    # be used to pick a colour below
    selected_display_data = display_data[display_data['Gene'].isin(selected_nodes)].copy().reset_index(drop=True)

    # Plot rank vs. score using matplotlib
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(display_data['Rank'], display_data['CIPHER Score'], color='black')
    ax.set_xlabel('Rank', fontsize=12)
    ax.set_ylabel('CIPHER Score', fontsize=12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Get color palette
    # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    # Add vertical line for selected nodes. Colours wrap around so that more
    # selections than palette entries do not raise an IndexError.
    for i, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color=palette[i % len(palette)], linestyle='--', label=node['Gene'], linewidth=1.5)
        # ax.text(node['Rank'] + 100, node['CIPHER Score'], node['Gene'], fontsize = 10, color = palette(i))

    # Add legend
    ax.legend(loc='upper right', fontsize=10)
    ax.grid(alpha=0.2)

    # NOTE(review): the gene count in this sentence is hard-coded, while the
    # percentage below is computed from target_nodes.shape[0] - confirm they agree.
    st.markdown("Out of 35,189 genes, the selected genes rank as follows:")
    selected_display_data['Rank'] = selected_display_data['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")

    # Show filtered nodes; the Database column is rendered as a link except
    # for node types whose "URL" is just the raw ID
    if target_node_type not in ['disease', 'anatomy']:
        st.dataframe(selected_display_data, use_container_width=True, hide_index=True,
                     column_config={"Database": st.column_config.LinkColumn(width="small",
                                                                            help="Click to visit external database.",
                                                                            display_text=display_database)})
    else:
        st.dataframe(selected_display_data, use_container_width=True)

    # Show plot
    st.markdown("In the plot below, the dashed lines represent the rank of the selected genes across all CIPHER predictions for Parkinson's disease.")
    st.pyplot(fig)

    # Figures created through pyplot stay registered until closed. This script
    # reruns on every interaction, so close the figure once it is rendered to
    # avoid accumulating open figures.
    plt.close(fig)
# # FULL RESULTS
# # Show top ranked nodes
# st.subheader("Model Predictions", divider = "blue")
# top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
# if target_node_type not in ['disease', 'anatomy']:
# st.dataframe(display_data.iloc[:top_k], use_container_width = True,
# column_config={"Database": st.column_config.LinkColumn(width = "small",
# help = "Click to visit external database.",
# display_text = display_database)})
# else:
# st.dataframe(display_data.iloc[:top_k], use_container_width = True)
# # Save to session state
# st.session_state.predictions = display_data
# st.session_state.display_database = display_database
|