import streamlit as st from menu import menu_with_redirect # Standard imports import numpy as np import pandas as pd # Path manipulation from pathlib import Path # Plotting import matplotlib.pyplot as plt plt.rcParams['font.sans-serif'] = 'Arial' import matplotlib.colors as mcolors # Custom and other imports import project_config from utils import load_kg, load_kg_edges # Redirect to app.py if not logged in, otherwise show the navigation menu menu_with_redirect() # Header st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True) # Main content # st.markdown(f"Hello, {st.session_state.name}!") st.subheader("Validate Predictions", divider = "green") # Print current query st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}") # Coming soon # st.write("Coming soon...") source_node_type = st.session_state.query['source_node_type'] source_node = st.session_state.query['source_node'] relation = st.session_state.query['relation'] target_node_type = st.session_state.query['target_node_type'] predictions = st.session_state.predictions kg_nodes = load_kg() kg_edges = load_kg_edges() # Convert tuple to hex def rgba_to_hex(rgba): return mcolors.to_hex(rgba[:3]) with st.spinner('Searching known relationships...'): # Subset existing edges edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)] edge_subset = edge_subset[edge_subset.y_type == target_node_type] # Merge edge subset with predictions edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right') edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False) edges_in_kg = edges_in_kg.drop(columns = 'y_id') # Rename relation to ground-truth edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']] edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'}) # If there exist edges in KG if len(edges_in_kg) > 0: with st.spinner('Plotting known relationships...'): # Define a color map for different relations color_map = plt.get_cmap('tab10') # Group by relation and create separate plots relations = edges_in_kg['Known Relation'].unique() for idx, relation in enumerate(relations): relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation] # Get a color from the color map color = color_map(idx % color_map.N) fig, ax = plt.subplots(figsize=(10, 3)) ax.plot(predictions['Rank'], predictions['Score']) ax.set_xlabel('Rank', fontsize=12) ax.set_ylabel('Score', fontsize=12) ax.set_xlim(1, predictions['Rank'].max()) for i, node in relation_data.iterrows(): ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name']) # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color) # ax.set_title(f'{relation.replace("_", "-")}') # ax.legend() color_hex = rgba_to_hex(color) # Write header in color of relation st.markdown(f"