# cipher-asap/pages/predict.py
# Author: ayushnoori
# Commit f083f50: "Add predict button"
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download
# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg
# Gate the page: logged-out visitors are redirected to app.py, everyone else
# gets the navigation menu.
menu_with_redirect()

# Banner image and one-paragraph description of the page.
input_banner = project_config.MEDIA_DIR / 'input_header.svg'
st.image(str(input_banner), use_column_width=True)
st.markdown(
    '''
Use CIPHER to predict how closely genes of interest are associated with Parkinson's disease. Search for specific genes to determine their ranking of PD association.
''')

# The query is fixed on this page: rank every gene/protein against Parkinson
# disease via the disease_protein relation.
# NOTE: these four values could instead come from st.session_state.query
# (keys: source_node_type, source_node, relation, target_node_type), and the
# target type could be user-selectable with the relation derived from it:
#   gene/protein     -> disease_protein
#   effect/phenotype -> disease_phenotype_positive
#   drug             -> indication
source_node_type, source_node, relation, target_node_type = (
    "disease",
    "Parkinson disease",
    "disease_protein",
    "gene/protein",
)

# Node names offered in the search box, keyed by target node type.
allowed_nodes = {
    'gene/protein': [
        'RHOA',
        'XRN1',
        'SNCA',
        'LRRK2',
        'GBA1',
    ],
    'effect/phenotype': [
        'Parkinsonism',
        'Parkinsonism with favorable response to dopaminergic medication',
    ],
    'drug': [
        'Levodopa',
    ],
}

# Search box (label hidden) restricted to the names allowed for the target type.
candidate_names = allowed_nodes[target_node_type]
selected_nodes = st.multiselect(
    "Select genes to search for...",
    candidate_names,
    placeholder="Type to search...",
    label_visibility='collapsed',
)

# Horizontal rule, then the banner that opens the results section.
st.markdown("---")
predict_banner = project_config.MEDIA_DIR / 'predict_header.svg'
st.image(str(predict_banner), use_column_width=True)
@st.cache_data(show_spinner='Downloading AI model...')
def get_embeddings():
    """Fetch the model artifacts for the chosen checkpoint from the Hugging Face Hub.

    Three files are requested from the ``ayushnoori/galaxy`` repository using
    the ``HF_TOKEN`` Streamlit secret: the node embeddings, the relation
    weights, and the edge types.

    Returns:
        tuple: ``(embed_path, relation_weights_path, edge_types_path)``, the
        values returned by ``hf_hub_download`` for each file, in that order.
    """
    # Checkpoint currently served. Other checkpoint IDs kept for reference:
    #   2024_05_22_11_59_43_epoch=18-step=22912
    #   2024_03_29_04_12_52_epoch=3-step=54291
    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"

    # Filename suffix of each artifact, in the order the paths are returned.
    artifact_suffixes = (
        "-thresh=4000_embeddings.pt",
        "_relation_weights.pt",
        "_edge_types.pt",
    )

    embed_path, relation_weights_path, edge_types_path = (
        hf_hub_download(
            repo_id="ayushnoori/galaxy",
            filename=(best_ckpt + suffix),
            token=st.secrets["HF_TOKEN"],
        )
        for suffix in artifact_suffixes
    )
    return embed_path, relation_weights_path, edge_types_path
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Load the embeddings, relation weights, and edge types from disk.

    Args:
        embed_path: Path to the saved node embeddings.
        relation_weights_path: Path to the saved relation weights.
        edge_types_path: Path to the saved edge types.

    Returns:
        tuple: ``(embeddings, relation_weights, edge_types)`` as deserialized
        by ``torch.load``, with any tensors placed on the CPU.
    """
    # Force CPU placement: a checkpoint saved from a GPU otherwise fails to
    # load on a CPU-only host, and the scoring code converts results with
    # `.numpy()`, which requires CPU tensors.
    # NOTE(review): torch.load unpickles its input; this is only acceptable
    # because the files come from the project's own Hub repository. Consider
    # passing weights_only=True once the saved object types are confirmed.
    embeddings = torch.load(embed_path, map_location="cpu")
    relation_weights = torch.load(relation_weights_path, map_location="cpu")
    edge_types = torch.load(edge_types_path, map_location="cpu")

    return embeddings, relation_weights, edge_types
# Load knowledge graph and embeddings
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)

# Define dictionary mapping node types to database URLs
map_dbs = {
    'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
    'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
    'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}",  # pad with 0s to 7 digits
    'disease': lambda x: x,  # MONDO
    # pad with 0s to 7 digits
    'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
    'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
    'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
    'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
    'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
    'anatomy': lambda x: x,
}

# Compute predictions
with st.spinner('Computing predictions...'):

    # Get source node index; fail with a readable message instead of a bare
    # IndexError when the query node is not in the knowledge graph.
    source_rows = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)]
    if source_rows.empty:
        raise ValueError(f"Source node {source_node!r} of type {source_node_type!r} not found in knowledge graph.")
    src_index = source_rows.node_index.values[0]

    # Get relation index (position of the queried triple in edge_types)
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next((i for i, etype in enumerate(edge_types) if etype == query_edge_type), None)
    if edge_type_index is None:
        raise ValueError(f"Edge type {query_edge_type!r} not found in model edge types.")

    # Get target nodes indices
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        raise ValueError(f"No nodes of type {target_node_type!r} found in knowledge graph.")
    dst_indices = target_nodes.node_index.values

    # Retrieve cached embeddings and apply activation function.
    # assumes embeddings is a 2-D tensor with one row per node_index — TODO confirm
    with torch.no_grad():
        # The single source embedding broadcasts against every target row, so
        # it does not need to be repeated once per target.
        src_embeddings = F.leaky_relu(embeddings[int(src_index)])
        dst_embeddings = F.leaky_relu(embeddings[dst_indices])

        # Get relation weights
        rel_weights = relation_weights[edge_type_index]

        # Compute weighted dot product, squashed to (0, 1)
        scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
        scores = torch.sigmoid(scores)

    # Add scores to dataframe (move to CPU first; .numpy() rejects GPU tensors)
    target_nodes['score'] = scores.detach().cpu().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Rename columns
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Gene', 'score': 'CIPHER Score', 'node_source': 'Database'})

    # Get name of database (taken from the top-ranked row) before the column
    # is overwritten with URLs
    display_database = display_data['Database'].values[0]

    # Add URLs to database column; mapping the ID column directly avoids a
    # slow row-by-row DataFrame.apply over every target node.
    display_data['Database'] = display_data['ID'].map(map_dbs[target_node_type])
# NODE SEARCH
# Render the ranking table and plot only when at least one node is selected.
if selected_nodes:

    # Filter nodes; the fresh 0..n-1 index doubles as the colour index below.
    selected_display_data = display_data[display_data['Gene'].isin(selected_nodes)].copy().reset_index(drop = True)

    # Tell the user about selections that have no row in the predictions
    # instead of silently dropping them.
    found_names = set(selected_display_data['Gene'])
    missing_nodes = [name for name in selected_nodes if name not in found_names]
    if missing_nodes:
        st.warning(f"No predictions were found for: {', '.join(missing_nodes)}.")

    # Total number of ranked nodes, used for both the summary and percentiles.
    total_nodes = target_nodes.shape[0]

    # Plot rank vs. score using matplotlib
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['CIPHER Score'], color = 'black')
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('CIPHER Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Get color palette (the ten tab10 colours)
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    # Add vertical line for selected nodes; cycle through the palette so more
    # than ten selections cannot index past its end.
    for i, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color = palette[i % len(palette)], linestyle = '--', label = node['Gene'], linewidth = 1.5)

    # Add legend (skipped when nothing was drawn, which would only emit a warning)
    if not selected_display_data.empty:
        ax.legend(loc = 'upper right', fontsize = 10)
    ax.grid(alpha = 0.2)

    # Report the actual number of ranked nodes rather than a hard-coded count,
    # so the sentence always agrees with the percentile shown in the table.
    st.markdown(f"Out of {total_nodes:,} genes, the selected genes rank as follows:")
    selected_display_data['Rank'] = selected_display_data['Rank'].apply(lambda x: f"{x} (top {(100*x/total_nodes):.2f}% of predictions)")

    # Show filtered nodes; node types whose mapping is not a URL get a plain table
    if target_node_type not in ['disease', 'anatomy']:
        st.dataframe(selected_display_data, use_container_width = True, hide_index = True,
                     column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                            help = "Click to visit external database.",
                                                                            display_text = display_database)})
    else:
        st.dataframe(selected_display_data, use_container_width = True)

    # Show plot
    st.markdown("In the plot below, the dashed lines represent the rank of the selected genes across all CIPHER predictions for Parkinson's disease.")
    st.pyplot(fig)

    # pyplot keeps every figure alive until it is closed; release it so
    # figures do not accumulate across Streamlit reruns.
    plt.close(fig)
# # FULL RESULTS
# # Show top ranked nodes
# st.subheader("Model Predictions", divider = "blue")
# top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
# if target_node_type not in ['disease', 'anatomy']:
# st.dataframe(display_data.iloc[:top_k], use_container_width = True,
# column_config={"Database": st.column_config.LinkColumn(width = "small",
# help = "Click to visit external database.",
# display_text = display_database)})
# else:
# st.dataframe(display_data.iloc[:top_k], use_container_width = True)
# # Save to session state
# st.session_state.predictions = display_data
# st.session_state.display_database = display_database