import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download

# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'

# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()
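# This page reads st.session_state.query (source node type, source node, relation, and
# target node type), which is assumed to have been populated elsewhere in the app before
# this page is reached.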
# Header
st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
# Main content
# st.markdown(f"Hello, {st.session_state.name}!")
st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
# Print current query
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():

    # Get checkpoint name
    # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
    # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"

    # Get paths to embeddings, relation weights, and edge types
    # with st.spinner('Downloading AI model...'):
    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                 filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
                                 token=st.secrets["HF_TOKEN"])
    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                            filename=(best_ckpt + "_relation_weights.pt"),
                                            token=st.secrets["HF_TOKEN"])
    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                      filename=(best_ckpt + "_edge_types.pt"),
                                      token=st.secrets["HF_TOKEN"])

    return embed_path, relation_weights_path, edge_types_path
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):

    # Load embeddings, relation weights, and edge types
    # with st.spinner('Loading AI model...'):
    embeddings = torch.load(embed_path)
    relation_weights = torch.load(relation_weights_path)
    edge_types = torch.load(edge_types_path)

    return embeddings, relation_weights, edge_types
# Load knowledge graph and embeddings
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
# # Print source node type
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
# # Print source node
# st.write(f"Source Node: {st.session_state.query['source_node']}")
# # Print relation
# st.write(f"Edge Type: {st.session_state.query['relation']}")
# # Print target node type
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
# Compute predictions
with st.spinner('Computing predictions...'):
    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index
    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]

    # Get relation index
    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]

    # Get target node indices
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))

    # Retrieve cached embeddings and apply activation function
    src_embeddings = embeddings[src_indices]
    dst_embeddings = embeddings[dst_indices]
    src_embeddings = F.leaky_relu(src_embeddings)
    dst_embeddings = F.leaky_relu(dst_embeddings)

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe and rank target nodes from highest to lowest score
    target_nodes['score'] = scores.detach().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Rename columns
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})

    # Define dictionary mapping node types to database URLs
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
        'disease': lambda x: x, # MONDO
        # pad GO identifiers with 0s to 7 digits
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }

    # Get name of database
    display_database = display_data['Database'].values[0]

    # Add URLs to database column
    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
# Check if validation data exists
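# st.session_state.validation, if present, is expected to hold ground-truth labels keyed
# by target node ID (column 'y_id'), with 0/1 indicator columns (one per relation) that
# are merged into the prediction table and rendered as check marks.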
if 'validation' in st.session_state:

    # Checkbox to show ground truth validation results
    show_val = st.checkbox("Show Ground Truth Validation?", value = False)

    if show_val:

        # Get validation data
        val_results = st.session_state.validation.copy()

        # Merge with predictions
        val_display_data = pd.merge(display_data, val_results, left_on = 'ID', right_on = 'y_id', how='left')
        val_display_data = val_display_data.fillna(0).drop(columns='y_id')

        # Get new columns
        val_relations = val_display_data.columns.difference(display_data.columns).tolist()

        # Replace 0 with blank and 1 with check emoji in new columns
        for col in val_relations:
            val_display_data[col] = val_display_data[col].replace({0: '', 1: '✅'})

        # Define a function to apply styles
        def style_val(val):
            if val == '✅':
                return 'background-color: #C2EABD;' # text-align: center;
            return 'background-color: #F5F5F5;' # text-align: center;

else:
    show_val = False
# NODE SEARCH
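# Let the user look up specific nodes to see their rank and score, and visualize where
# they fall in the overall score distribution.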
# Use multiselect to search for specific nodes
selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
                                display_data.Name, placeholder = "Type to search...")
# Filter nodes
if len(selected_nodes) > 0:

    # Keep a plain DataFrame here; the ground-truth check-mark styling is applied just
    # before display, since a pandas Styler does not support DataFrame operations such
    # as copy() or column assignment
    if show_val:
        selected_display_data = val_display_data[val_display_data.Name.isin(selected_nodes)].copy()
    else:
        selected_display_data = display_data[display_data.Name.isin(selected_nodes)].copy()
    selected_display_data = selected_display_data.reset_index(drop=True)

    st.markdown(f"Out of {target_nodes.shape[0]} {target_node_type} nodes, the selected nodes rank as follows:")

    # Express each rank as a percentile of all predictions
    selected_display_data_with_rank = selected_display_data.copy()
    selected_display_data_with_rank['Rank'] = selected_display_data_with_rank['Rank'].apply(lambda x: f"{x} (top {(100*x/target_nodes.shape[0]):.2f}% of predictions)")
    if show_val:
        selected_display_data_with_rank = selected_display_data_with_rank.style.map(style_val, subset=val_relations)
    # Show filtered nodes
    if target_node_type not in ['disease', 'anatomy']:
        st.dataframe(selected_display_data_with_rank, use_container_width = True, hide_index = True,
                     column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                            help = "Click to visit external database.",
                                                                            display_text = display_database)})
    else:
        st.dataframe(selected_display_data_with_rank, use_container_width = True)
    # Show plot
    st.markdown(f"In the plot below, the dashed lines represent the rank of the selected {target_node_type} nodes across all predictions for {source_node}.")

    # Checkbox to show text labels
    show_labels = st.checkbox("Show Text Labels?", value = False)

    # Plot rank vs. score using matplotlib
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['Score'], color = 'black', linewidth = 1.5, zorder = 2)
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Get color palette
    # palette = plt.cm.get_cmap('tab10', len(selected_display_data))
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
    # Add a vertical line for each selected node (cycle through the palette if more than ten nodes are selected)
    for i, node in selected_display_data.iterrows():
        color = palette[i % len(palette)]
        ax.scatter(node['Rank'], node['Score'], color = color, zorder=3)
        ax.axvline(node['Rank'], color = color, linestyle = '--', linewidth = 1.5, label = node['Name'], zorder=3)
        if show_labels:
            ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = color, zorder=3)

    # Add legend
    ax.legend(loc = 'upper right', fontsize = 10)
    ax.grid(alpha = 0.2, zorder=0)
    st.pyplot(fig)
# FULL RESULTS
# Show top ranked nodes
st.subheader("Model Predictions", divider = "blue")
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
# Show full results
# full_results = val_display_data.iloc[:top_k] if show_val else display_data.iloc[:top_k]
full_results = val_display_data.iloc[:top_k].style.map(style_val, subset=val_relations) if show_val else display_data.iloc[:top_k]
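# When validation is shown, full_results is a pandas Styler with the check-mark cells
# highlighted; st.dataframe accepts either a DataFrame or a Styler.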
if target_node_type not in ['disease', 'anatomy']:
    st.dataframe(full_results, use_container_width = True, hide_index = True,
                 column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                        help = "Click to visit external database.",
                                                                        display_text = display_database)})
else:
    st.dataframe(full_results, use_container_width = True, hide_index = True)
# Save to session state
st.session_state.predictions = display_data
st.session_state.display_database = display_database