import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download
# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()
# Header
st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
# Main content
# st.markdown(f"Hello, {st.session_state.name}!")
st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
# Print current query
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
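# Download the pretrained model artifacts (node embeddings, relation weights, and edge types)
# from the Hugging Face Hub; st.cache_data memoizes the returned paths so the files are not
# re-downloaded on every Streamlit rerun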
@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():
    # Get checkpoint name
    # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
    # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"

    # Get paths to embeddings, relation weights, and edge types
    # with st.spinner('Downloading AI model...'):
    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                 filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
                                 token=st.secrets["HF_TOKEN"])
    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                            filename=(best_ckpt + "_relation_weights.pt"),
                                            token=st.secrets["HF_TOKEN"])
    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
                                      filename=(best_ckpt + "_edge_types.pt"),
                                      token=st.secrets["HF_TOKEN"])

    return embed_path, relation_weights_path, edge_types_path
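# Deserialize the downloaded checkpoint tensors; also cached so the files are only read from disk once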
@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    # Load embeddings, relation weights, and edge types
    # with st.spinner('Loading AI model...'):
    embeddings = torch.load(embed_path)
    relation_weights = torch.load(relation_weights_path)
    edge_types = torch.load(edge_types_path)

    return embeddings, relation_weights, edge_types
# Load knowledge graph and embeddings
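# kg_nodes is a dataframe with one row per KG node (node_index, node_id, node_type, node_name, node_source)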
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
# # Print source node type
# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
# # Print source node
# st.write(f"Source Node: {st.session_state.query['source_node']}")
# # Print relation
# st.write(f"Edge Type: {st.session_state.query['relation']}")
# # Print target node type
# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
# Compute predictions
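# Score every node of the requested target type against the query source node,
# using the cached node embeddings and relation-specific weights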
with st.spinner('Computing predictions...'):
    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index
    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]

    # Get relation index
    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]

    # Get target node indices
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    dst_indices = target_nodes.node_index.values
    src_indices = np.repeat(src_index, len(dst_indices))
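    # embeddings is indexed by node_index; relation_weights is indexed by the position of the
    # (source type, relation, target type) triple in edge_types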
    # Retrieve cached embeddings and apply activation function
    src_embeddings = embeddings[src_indices]
    dst_embeddings = embeddings[dst_indices]
    src_embeddings = F.leaky_relu(src_embeddings)
    dst_embeddings = F.leaky_relu(dst_embeddings)

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]
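    # Each candidate edge is scored as sigmoid(sum_d src_d * w_rel_d * dst_d), a relation-weighted
    # (DistMult-style) inner product squashed to (0, 1)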
    # Compute weighted dot product
    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe
    target_nodes['score'] = scores.detach().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Rename columns
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
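    # IDs for disease (MONDO) and anatomy nodes are left as-is; no external link is rendered for
    # these node types below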
    # Define dictionary mapping node types to database URLs
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
        'disease': lambda x: x, # MONDO
        # pad with 0s to 7 digits
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }
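    # All rows share the same source database for a given node type, so the first row's value is
    # reused as the display text for every link in the table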
    # Get name of database
    display_database = display_data['Database'].values[0]

    # Add URLs to database column
    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
# NODE SEARCH
# Use multiselect to search for specific nodes
selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['target_node_type'].replace('_', ' ')} nodes to determine their ranking.",
                                display_data.Name, placeholder = "Type to search...")
# Filter nodes
if len(selected_nodes) > 0:
    selected_display_data = display_data[display_data.Name.isin(selected_nodes)]

    # Show filtered nodes
    if target_node_type not in ['disease', 'anatomy']:
        st.dataframe(selected_display_data, use_container_width = True,
                     column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                            help = "Click to visit external database.",
                                                                            display_text = display_database)})
    else:
        st.dataframe(selected_display_data, use_container_width = True)

    # Plot rank vs. score using matplotlib
    st.markdown("**Rank vs. Score**")
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['Score'])
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Add vertical line for selected nodes
    for i, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
        ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')

    # Show plot
    st.pyplot(fig)
# FULL RESULTS
# Show top ranked nodes
st.subheader("Model Predictions", divider = "blue")
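# Slider defaults to the top 500 predictions (or all of them, if fewer candidates exist)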
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
if target_node_type not in ['disease', 'anatomy']:
    st.dataframe(display_data.iloc[:top_k], use_container_width = True,
                 column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                        help = "Click to visit external database.",
                                                                        display_text = display_database)})
else:
    st.dataframe(display_data.iloc[:top_k], use_container_width = True)
# Save to session state
st.session_state.predictions = display_data
st.session_state.display_database = display_database
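# Keeping the ranked predictions in session state lets other pages of the app reuse them
# without recomputing (assumes a multipage layout, as suggested by menu_with_redirect above)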