import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Path manipulation
from pathlib import Path
from huggingface_hub import hf_hub_download

# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'

# Custom and other imports
import project_config
from utils import capitalize_after_slash, load_kg

# Hugging Face Hub repository that hosts the model artifacts
HF_REPO_ID = "ayushnoori/galaxy"

# Checkpoint whose artifacts are downloaded. Other checkpoints that were
# previously referenced here:
#   "2024_05_22_11_59_43_epoch=18-step=22912"
#   "2024_03_29_04_12_52_epoch=3-step=54291"
BEST_CKPT = "2024_05_15_13_05_33_epoch=2-step=40383"

# Node types whose entry in the URL mapping below returns the raw ID rather
# than a URL, so the Database column must not be rendered as a link
NO_LINK_NODE_TYPES = ('disease', 'anatomy')


@st.cache_data(show_spinner = 'Downloading AI model...')
def get_embeddings():
    """Download the model artifacts for ``BEST_CKPT`` from the Hugging Face Hub.

    The Hub token is read from ``st.secrets["HF_TOKEN"]``.

    Returns:
        Tuple ``(embed_path, relation_weights_path, edge_types_path)`` with the
        values returned by ``hf_hub_download`` for the embeddings, relation
        weights, and edge types files, in that order.
    """
    token = st.secrets["HF_TOKEN"]
    embed_path = hf_hub_download(repo_id=HF_REPO_ID,
                                 filename=(BEST_CKPT + "-thresh=4000_embeddings.pt"),
                                 token=token)
    relation_weights_path = hf_hub_download(repo_id=HF_REPO_ID,
                                            filename=(BEST_CKPT + "_relation_weights.pt"),
                                            token=token)
    edge_types_path = hf_hub_download(repo_id=HF_REPO_ID,
                                      filename=(BEST_CKPT + "_edge_types.pt"),
                                      token=token)
    return embed_path, relation_weights_path, edge_types_path


@st.cache_data(show_spinner = 'Loading AI model...')
def load_embeddings(embed_path, relation_weights_path, edge_types_path):
    """Load the embeddings, relation weights, and edge types from disk.

    Args:
        embed_path: Path to the saved embeddings file.
        relation_weights_path: Path to the saved relation weights file.
        edge_types_path: Path to the saved edge types file.

    Returns:
        Tuple ``(embeddings, relation_weights, edge_types)`` as returned by
        ``torch.load`` for each path.
    """
    # map_location='cpu' lets artifacts that were saved from a GPU load on a
    # CPU-only host; the scores are converted to NumPy further down anyway.
    embeddings = torch.load(embed_path, map_location='cpu')
    relation_weights = torch.load(relation_weights_path, map_location='cpu')
    edge_types = torch.load(edge_types_path, map_location='cpu')
    return embeddings, relation_weights, edge_types


def _show_predictions_table(data, target_node_type, display_database):
    """Render ``data`` with ``st.dataframe``, linking the Database column when it holds URLs."""
    if target_node_type in NO_LINK_NODE_TYPES:
        st.dataframe(data, use_container_width = True)
        return
    st.dataframe(
        data,
        use_container_width = True,
        column_config = {
            "Database": st.column_config.LinkColumn(
                width = "small",
                help = "Click to visit external database.",
                display_text = display_database,
            )
        },
    )


# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Header
st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)

# Main content
st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")

# Print current query
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")

# Load knowledge graph and embeddings
kg_nodes = load_kg()
embed_path, relation_weights_path, edge_types_path = get_embeddings()
embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)

# Compute predictions
with st.spinner('Computing predictions...'):

    source_node_type = st.session_state.query['source_node_type']
    source_node = st.session_state.query['source_node']
    relation = st.session_state.query['relation']
    target_node_type = st.session_state.query['target_node_type']

    # Get source node index; stop with a readable message instead of an
    # IndexError when the query does not match any node
    src_matches = kg_nodes.loc[
        (kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node),
        'node_index',
    ]
    if src_matches.empty:
        st.error(f"Source node '{source_node}' of type '{source_node_type}' was not found in the knowledge graph.")
        st.stop()
    src_index = int(src_matches.iloc[0])

    # Get relation index (first edge type equal to the queried triple)
    query_edge_type = (source_node_type, relation, target_node_type)
    edge_type_index = next(
        (i for i, etype in enumerate(edge_types) if etype == query_edge_type),
        None,
    )
    if edge_type_index is None:
        st.error(f"The model has no edge type {query_edge_type}.")
        st.stop()

    # Get target nodes indices
    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
    if target_nodes.empty:
        st.warning(f"No nodes of type '{target_node_type}' were found in the knowledge graph.")
        st.stop()
    dst_indices = target_nodes.node_index.values

    # Retrieve cached embeddings and apply activation function. The single
    # source embedding is broadcast against every target embedding below, so
    # it does not need to be repeated once per target node.
    src_embedding = F.leaky_relu(embeddings[src_index])
    dst_embeddings = F.leaky_relu(embeddings[dst_indices])

    # Get relation weights
    rel_weights = relation_weights[edge_type_index]

    # Compute weighted dot product
    scores = torch.sum(src_embedding * rel_weights * dst_embeddings, dim = 1)
    scores = torch.sigmoid(scores)

    # Add scores to dataframe
    target_nodes['score'] = scores.detach().cpu().numpy()
    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)

    # Rename columns
    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})

    # Define dictionary mapping node types to database URLs
    map_dbs = {
        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}",  # pad with 0s to 7 digits
        'disease': lambda x: x,  # MONDO
        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",  # pad with 0s to 7 digits
        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
        'anatomy': lambda x: x,
    }

    # Get name of database (taken from the top-ranked row, before the column
    # is overwritten with URLs)
    display_database = display_data['Database'].values[0]

    # Add URLs to database column
    display_data['Database'] = display_data['ID'].map(map_dbs[target_node_type])

# NODE SEARCH

# Use multiselect to search for specific nodes. The options are the ranked
# target nodes, so the label names the target node type.
selected_nodes = st.multiselect(f"Search for specific {target_node_type.replace('_', ' ')} nodes to determine their ranking.",
                                display_data.Name, placeholder = "Type to search...")

# Filter nodes
if len(selected_nodes) > 0:

    selected_display_data = display_data[display_data.Name.isin(selected_nodes)]

    # Show filtered nodes
    _show_predictions_table(selected_display_data, target_node_type, display_database)

    # Plot rank vs. score using matplotlib
    st.markdown("**Rank vs. Score**")
    fig, ax = plt.subplots(figsize = (10, 6))
    ax.plot(display_data['Rank'], display_data['Score'])
    ax.set_xlabel('Rank', fontsize = 12)
    ax.set_ylabel('Score', fontsize = 12)
    ax.set_xlim(1, display_data['Rank'].max())

    # Add vertical line for selected nodes
    for _, node in selected_display_data.iterrows():
        ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
        ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')

    # Show plot, then close the figure so figures do not accumulate in
    # pyplot's global registry across Streamlit reruns
    st.pyplot(fig)
    plt.close(fig)

# FULL RESULTS

# Show top ranked nodes
st.subheader("Model Predictions", divider = "blue")
top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))

_show_predictions_table(display_data.iloc[:top_k], target_node_type, display_database)

# Save to session state
st.session_state.predictions = display_data
st.session_state.display_database = display_database