Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from menu import menu_with_redirect | |
| # Standard imports | |
| import numpy as np | |
| import pandas as pd | |
| # Path manipulation | |
| from pathlib import Path | |
| # Plotting | |
| import matplotlib.pyplot as plt | |
| plt.rcParams['font.sans-serif'] = 'Arial' | |
| import matplotlib.colors as mcolors | |
| # Custom and other imports | |
| import project_config | |
| from utils import load_kg, load_kg_edges | |
# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Header banner for the validation page
header_image = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_image), use_column_width=True)

# Main content
st.subheader("Validate Predictions", divider = "green")

# Echo the current query as "source ➡️ relation ➡️ target type", with underscores prettified
st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")

# Pull the individual query fields and the stored predictions out of the session state
query = st.session_state.query
source_node_type, source_node, relation, target_node_type = (
    query[field]
    for field in ('source_node_type', 'source_node', 'relation', 'target_node_type')
)
predictions = st.session_state.predictions

# Load the knowledge graph node and edge tables via the project helpers
kg_nodes = load_kg()
kg_edges = load_kg_edges()
def rgba_to_hex(rgba):
    """Convert a color tuple to a hex color string.

    Only the first three components of ``rgba`` are handed to
    ``matplotlib.colors.to_hex``, so a fourth (alpha) component is ignored.
    """
    red_green_blue = rgba[:3]
    return mcolors.to_hex(red_green_blue)
# Look up which predicted targets are already connected to the query node in the
# knowledge graph, then draw one score-vs-rank plot and one table per known relation.
with st.spinner('Searching known relationships...'):

    # Subset existing edges: start at the queried source node, end at the queried target type
    from_source = (kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)
    edge_subset = kg_edges[from_source]
    edge_subset = edge_subset[edge_subset.y_type == target_node_type]

    # Merge edge subset with predictions. how='right' keeps every known edge, so an edge
    # whose target has no row in `predictions` survives with NaN in the prediction columns.
    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']],
                           left_on = 'ID', right_on = 'y_id', how = 'right')
    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
    edges_in_kg = edges_in_kg.drop(columns = 'y_id')

    # Move the relation to the first column and rename it to the ground-truth label
    edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})

# If there exist edges in KG
if not edges_in_kg.empty:

    with st.spinner('Plotting known relationships...'):

        # Define a color map for different relations
        color_map = plt.get_cmap('tab10')

        # The x-axis limit is the same for every plot, so compute it once
        max_rank = predictions['Rank'].max()

        # Group by relation and create separate plots.
        # The loop variable is deliberately NOT called `relation`, so the queried
        # relation stored under that module-level name is not overwritten.
        for idx, known_relation in enumerate(edges_in_kg['Known Relation'].unique()):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Get a color from the color map, wrapping around if there are more relations than colors
            color = color_map(idx % color_map.N)
            color_hex = rgba_to_hex(color)

            # Score curve over all predictions, with a dashed marker at the rank of each known target
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, max_rank)
            for _, node in relation_data.iterrows():
                ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])

            # Write header in color of relation (opening and closing tags must both be h3)
            st.markdown(f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)

            # Show plot, then close the figure: pyplot keeps every figure it creates
            # alive until it is explicitly closed, so open figures would pile up otherwise.
            st.pyplot(fig)
            plt.close(fig)

            # Drop known relation column (it is already shown in the header)
            relation_data = relation_data.drop(columns = 'Known Relation')

            # For every target type except disease and anatomy, render the
            # "Database" column as a clickable link
            if target_node_type not in ['disease', 'anatomy']:
                st.dataframe(relation_data, use_container_width=True,
                             column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                                    help = "Click to visit external database.",
                                                                                    display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")