File size: 4,407 Bytes
f027c05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd

# Path manipulation
from pathlib import Path

# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
import matplotlib.colors as mcolors

# Custom and other imports
import project_config
from utils import load_kg, load_kg_edges

# Send users who are not logged in back to app.py; otherwise render the
# navigation menu. Called first, before any page content is rendered.
menu_with_redirect()

# Page banner
header_svg = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_svg), use_column_width=True)

st.subheader("Validate Predictions", divider = "green")

# Unpack the active query from session state (set elsewhere in the app —
# presumably by the prediction page; verify) and echo it to the user.
query = st.session_state.query
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']
st.markdown(f"**Query:** {source_node.replace('_', ' ')} ➡️ {relation.replace('_', '-')} ➡️ {target_node_type.replace('_', ' ')}")

source_node_type = query['source_node_type']
predictions = st.session_state.predictions

# Knowledge-graph node and edge tables, loaded through the project's utils
kg_nodes = load_kg()
kg_edges = load_kg_edges()

# Convert a matplotlib color to hex
def rgba_to_hex(rgba):
    """Return the ``#rrggbb`` hex string for a matplotlib color, dropping alpha.

    Args:
        rgba: Any color spec accepted by ``matplotlib.colors.to_hex`` — e.g.
            the RGBA tuple returned by a colormap call, an RGB tuple, or a
            hex/named color string.

    Returns:
        A lowercase ``#rrggbb`` string (the alpha channel is discarded).
    """
    # keep_alpha=False discards alpha inside to_hex itself; slicing the input
    # (rgba[:3]) would truncate string specs such as '#ff000080' to '#ff'.
    return mcolors.to_hex(rgba, keep_alpha=False)

with st.spinner('Searching known relationships...'):

    # Existing KG edges from the query's source node (matching name and type)
    # to nodes of the query's target type
    edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
    edge_subset = edge_subset[edge_subset.y_type == target_node_type]

    # Right-join onto the predictions so every known edge is kept; a known
    # target with no matching 'ID' in `predictions` gets NaN prediction columns.
    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
    edges_in_kg = edges_in_kg.drop(columns = 'y_id')

    # Move the KG relation to the first column and label it as ground truth
    edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})

# If there exist edges in KG
if not edges_in_kg.empty:

    with st.spinner('Plotting known relationships...'):

        # Qualitative color map: one color per known relation type
        color_map = plt.get_cmap('tab10')

        # One header + plot + table per known relation type. The loop variable
        # is `known_relation` so it does not overwrite the query's `relation`.
        for idx, known_relation in enumerate(edges_in_kg['Known Relation'].unique()):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Cycle through the color map if there are more relations than colors
            color = color_map(idx % color_map.N)
            color_hex = rgba_to_hex(color)

            # Score-vs-rank curve over all predictions, with a dashed vertical
            # line at the rank of each target known under this relation
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, predictions['Rank'].max())
            for rank, name in zip(relation_data['Rank'], relation_data['Name']):
                ax.axvline(rank, color=color, linestyle='--', label=name)

            # Header in the relation's color (closing tag matches the <h3>)
            st.markdown(f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)

            # Show plot, then release it: pyplot keeps every figure open until
            # closed, and Streamlit re-runs this script on each interaction.
            st.pyplot(fig)
            plt.close(fig)

            # The relation is already shown in the header, so drop its column
            relation_data = relation_data.drop(columns = 'Known Relation')
            # NOTE(review): presumably disease/anatomy rows have no external
            # 'Database' link column — confirm against the predictions schema.
            if target_node_type not in ['disease', 'anatomy']:
                st.dataframe(relation_data, use_container_width=True,
                            column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                                    help = "Click to visit external database.",
                                                                                    display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")