cipher-asap / pages /validate.py
ayushnoori's picture
Demo for ASAP grant renewal
f027c05
raw
history blame
4.41 kB
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
# Path manipulation
from pathlib import Path
# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
import matplotlib.colors as mcolors
# Custom and other imports
import project_config
from utils import load_kg, load_kg_edges
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Page banner
header_image = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_image), use_column_width=True)

# Page title
st.subheader("Validate Predictions", divider = "green")

# Echo the active query as "source ➡️ relation ➡️ target type", with
# underscores swapped for friendlier separators
query = st.session_state.query
source_label = query['source_node'].replace('_', ' ')
relation_label = query['relation'].replace('_', '-')
target_label = query['target_node_type'].replace('_', ' ')
st.markdown(f"**Query:** {source_label} ➡️ {relation_label} ➡️ {target_label}")

# Unpack the query fields and the stored predictions used by the rest of the page
source_node_type = query['source_node_type']
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']
predictions = st.session_state.predictions

# Knowledge graph node and edge tables
kg_nodes = load_kg()
kg_edges = load_kg_edges()
def rgba_to_hex(rgba):
    """Return the hex color string for the RGB part of a color sequence.

    Only the first three components of ``rgba`` are handed to
    ``mcolors.to_hex``, so a fourth (alpha) component, if present, is ignored.
    """
    rgb_only = rgba[:3]
    return mcolors.to_hex(rgb_only)
# Look up the relationships the knowledge graph already holds for the query
# node, then show where each of them falls in the ranked predictions.
with st.spinner('Searching known relationships...'):

    # Keep KG edges that start at the query node and end at the requested target type
    edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
    edge_subset = edge_subset[edge_subset.y_type == target_node_type]

    # Right merge: every known edge is kept, so an edge whose y_id has no
    # matching prediction ID appears with NaN in the prediction columns
    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
    edges_in_kg = edges_in_kg.drop(columns = 'y_id')

    # Move the relation column to the front and rename it to ground-truth
    edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})

# If there exist edges in KG
if not edges_in_kg.empty:

    with st.spinner('Plotting known relationships...'):

        # Define a color map for different relations
        color_map = plt.get_cmap('tab10')

        # One header + plot + table per known relation type. The loop variable
        # is deliberately not called `relation`, so the query's relation read
        # from session state is not overwritten.
        known_relations = edges_in_kg['Known Relation'].unique()
        for idx, known_relation in enumerate(known_relations):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Cycle through the color map when there are more relations than colors
            color = color_map(idx % color_map.N)

            # Score-vs-rank curve of all predictions
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, predictions['Rank'].max())

            # Mark the rank of every known target with a dashed vertical line
            # (the label only becomes visible if a legend is added to the axes)
            for _, node in relation_data.iterrows():
                ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])

            color_hex = rgba_to_hex(color)

            # Write header in color of relation (opening and closing tags must both be h3)
            st.markdown(f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)

            # Show plot, then release the figure: pyplot keeps every figure it
            # created registered until it is closed, so figures would otherwise
            # pile up across relations and script reruns
            st.pyplot(fig)
            plt.close(fig)

            # Drop known relation column
            relation_data = relation_data.drop(columns = 'Known Relation')

            # The Database column is rendered as a link for every target type
            # except 'disease' and 'anatomy'
            if target_node_type not in ['disease', 'anatomy']:
                st.dataframe(relation_data, use_container_width=True,
                             column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                                    help = "Click to visit external database.",
                                                                                    display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")