Spaces:
Sleeping
Sleeping

Introducing team-specific defaults, immediate about page loading, and validation display on predict page
079a08e
import streamlit as st | |
from menu import menu_with_redirect | |
# Standard imports | |
import numpy as np | |
import pandas as pd | |
# Path manipulation | |
from pathlib import Path | |
# Plotting | |
import matplotlib.pyplot as plt | |
plt.rcParams['font.sans-serif'] = 'Arial' | |
import matplotlib.colors as mcolors | |
# Custom and other imports | |
import project_config | |
from utils import load_kg, load_kg_edges | |
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Header
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)

# Main content
st.subheader("Validate Predictions", divider = "green")

# Everything below reads the query and predictions from session state.
# Stop with a readable message instead of an exception when the page is
# opened before either of them has been stored.
if 'query' not in st.session_state or 'predictions' not in st.session_state:
    st.error('No query or predictions found. Please run a prediction first.', icon="✖️")
    st.stop()

# Read the current query once and reuse the values below
query = st.session_state.query
source_node_type = query['source_node_type']
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']
predictions = st.session_state.predictions

# Print current query (underscores are replaced for display only)
st.markdown(f"**Query:** {source_node.replace('_', ' ')} ➡️ {relation.replace('_', '-')} ➡️ {target_node_type.replace('_', ' ')}")

# Load knowledge graph nodes and edges
kg_nodes = load_kg()
kg_edges = load_kg_edges()
# Convert tuple to hex | |
def rgba_to_hex(rgba):
    """Return the hex color for the first three components of *rgba*.

    Only ``rgba[:3]`` is handed to ``mcolors.to_hex``, so a fourth (alpha)
    component, if present, is ignored.
    """
    rgb = rgba[:3]
    return mcolors.to_hex(rgb)
with st.spinner('Searching known relationships...'):

    # Keep only KG edges that start at the query node and end at a node of
    # the requested target type
    starts_at_source = (kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)
    ends_at_target_type = kg_edges.y_type == target_node_type
    edge_subset = kg_edges[starts_at_source & ends_at_target_type]

    # Attach prediction columns to every known edge (right join keeps all
    # known edges), order by score, and drop the redundant join key
    known_edges = edge_subset[['relation', 'y_id']]
    edges_in_kg = (
        predictions
        .merge(known_edges, left_on='ID', right_on='y_id', how='right')
        .sort_values(by='Score', ascending=False)
        .drop(columns='y_id')
    )

    # Move the relation column to the front and label it as ground truth
    other_columns = [name for name in edges_in_kg.columns if name != 'relation']
    edges_in_kg = edges_in_kg[['relation', *other_columns]]
    edges_in_kg = edges_in_kg.rename(columns={'relation': 'Known Relation'})
# If there exist edges in KG
if not edges_in_kg.empty:

    with st.spinner('Saving validation results...'):

        # Cast long to wide: one row per y_id, one 0/1 column per relation
        val_results = edge_subset[['relation', 'y_id']].pivot_table(
            index='y_id', columns='relation', aggfunc='size', fill_value=0)
        val_results = (val_results > 0).astype(int).reset_index()
        val_results.columns = [val_results.columns[0]] + [
            x.replace('_', ' ').title() for x in val_results.columns[1:]]

        # Save validation results to session state
        st.session_state.validation = val_results

    with st.spinner('Plotting known relationships...'):

        # Define a color map for different relations
        color_map = plt.get_cmap('tab10')

        # Group by relation and create separate plots.
        # The loop variable is deliberately not called `relation`, so the
        # query relation read earlier in the script is not overwritten.
        known_relations = edges_in_kg['Known Relation'].unique()
        for idx, known_relation in enumerate(known_relations):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Get a color from the color map (wraps around when exhausted)
            color = color_map(idx % color_map.N)
            color_hex = rgba_to_hex(color)

            # Score-vs-rank curve of all predictions, with one dashed
            # vertical line per known target of this relation
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, predictions['Rank'].max())
            for rank, name in zip(relation_data['Rank'], relation_data['Name']):
                ax.axvline(rank, color=color, linestyle='--', label=name)

            # Write header in color of relation
            st.markdown(
                f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>",
                unsafe_allow_html=True)

            # Show plot, then close the figure so pyplot does not keep it
            # alive across script reruns
            st.pyplot(fig)
            plt.close(fig)

            # Drop known relation column (already shown in the header)
            relation_data = relation_data.drop(columns = 'Known Relation')

            if target_node_type not in ['disease', 'anatomy']:
                # Render the Database column as a clickable link
                st.dataframe(relation_data, use_container_width=True,
                             column_config={"Database": st.column_config.LinkColumn(
                                 width = "small",
                                 help = "Click to visit external database.",
                                 display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")