File size: 4,407 Bytes
40fa5b9 f128fe5 40fa5b9 7e225aa 40fa5b9 7e225aa 40fa5b9 7e225aa 40fa5b9 f128fe5 40fa5b9 ba1c7a0 ca764d6 f18a5e1 ca764d6 6efe11e f18a5e1 7e225aa e9c640b 7e225aa e9c640b 7e225aa e9c640b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import streamlit as st
from menu import menu_with_redirect
# Standard imports
import numpy as np
import pandas as pd
# Path manipulation
from pathlib import Path
# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
import matplotlib.colors as mcolors
# Custom and other imports
import project_config
from utils import load_kg, load_kg_edges
# Redirect if not logged in, otherwise show the navigation menu.
# The helper is imported at the top of the file; it must actually be called
# here for the redirect/menu described above to take effect.
menu_with_redirect()

# Header
st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)

# Main content
st.subheader("Validate Predictions", divider = "green")

# Read the current query from the session once and reuse its fields below.
query = st.session_state.query
source_node_type = query['source_node_type']
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']

# Print current query (underscores are replaced for display only)
st.markdown(
    f"**Query:** {source_node.replace('_', ' ')} ➡️ "
    f"{relation.replace('_', '-')} ➡️ "
    f"{target_node_type.replace('_', ' ')}"
)

# Predictions stored in the session for the current query
predictions = st.session_state.predictions

# Knowledge graph nodes and edges
kg_nodes = load_kg()
kg_edges = load_kg_edges()
# Convert tuple to hex
def rgba_to_hex(rgba):
    """Return the hex color string for the RGB part of a color tuple.

    Only the first three components of ``rgba`` are handed to
    ``matplotlib.colors.to_hex``; anything after them (such as an alpha
    channel) is dropped before the conversion.
    """
    rgb_only = rgba[:3]
    return mcolors.to_hex(rgb_only)
with st.spinner('Searching known relationships...'):

    # Keep only the KG edges that start at the queried source node and end at
    # a node of the queried target type.
    from_source = (kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)
    to_target_type = kg_edges.y_type == target_node_type
    edge_subset = kg_edges[from_source & to_target_type]

    # Right-merge onto the predictions so that every matching KG edge yields
    # a row, then order the rows by prediction score (highest first).
    known_targets = edge_subset[['relation', 'y_id']]
    merged = pd.merge(predictions, known_targets, how = 'right', left_on = 'ID', right_on = 'y_id')
    merged = merged.sort_values(by = 'Score', ascending = False).drop(columns = 'y_id')

    # Move the relation to the front and present it as the ground-truth column.
    remaining_columns = [col for col in merged.columns if col != 'relation']
    edges_in_kg = merged[['relation'] + remaining_columns].rename(columns = {'relation': 'Known Relation'})
# If there exist edges in KG, show one plot and one table per known relation;
# otherwise report that no ground truth was found for this query.
if not edges_in_kg.empty:

    with st.spinner('Plotting known relationships...'):

        # Define a color map for different relations
        color_map = plt.get_cmap('tab10')

        # Group by relation and create separate plots. The loop variable is
        # deliberately not named `relation`, so the queried relation stored in
        # the module-level `relation` is not overwritten by the loop.
        for idx, known_relation in enumerate(edges_in_kg['Known Relation'].unique()):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Get a color from the color map, cycling when relations outnumber colors
            color = color_map(idx % color_map.N)
            color_hex = rgba_to_hex(color)

            # Score-versus-rank curve over all predictions
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, predictions['Rank'].max())

            # Mark the rank of each known target with a dashed vertical line.
            # Rows with a missing Rank (possible because the known edges are
            # right-merged onto the predictions) cannot be placed on the axis
            # and are skipped.
            ranked = relation_data.dropna(subset = ['Rank'])
            for rank, name in zip(ranked['Rank'], ranked['Name']):
                ax.axvline(rank, color=color, linestyle='--', label=name)

            # Write header in color of relation (closing tag matches the opening <h3>)
            st.markdown(f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)

            # Show plot, then close the figure so pyplot does not keep one
            # open figure per relation alive across reruns.
            st.pyplot(fig)
            plt.close(fig)

            # Drop known relation column; it is already shown in the header
            relation_data = relation_data.drop(columns = 'Known Relation')

            # For target types other than disease and anatomy, render the
            # Database column as a clickable link; otherwise show a plain table.
            if target_node_type not in ['disease', 'anatomy']:
                st.dataframe(relation_data, use_container_width=True,
                             column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                                    help = "Click to visit external database.",
                                                                                    display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")