File size: 4,910 Bytes
40fa5b9
f128fe5
40fa5b9
7e225aa
 
 
 
40fa5b9
 
 
7e225aa
 
 
 
 
40fa5b9
 
7e225aa
40fa5b9
f128fe5
 
40fa5b9
ba1c7a0
 
 
 
ca764d6
 
f18a5e1
ca764d6
6efe11e
 
 
f18a5e1
7e225aa
 
 
 
 
 
 
 
 
 
 
e9c640b
 
 
7e225aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c640b
 
7e225aa
079a08e
 
 
 
 
 
 
 
 
 
e9c640b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
from menu import menu_with_redirect

# Standard imports
import numpy as np
import pandas as pd

# Path manipulation
from pathlib import Path

# Plotting
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'Arial'
import matplotlib.colors as mcolors

# Custom and other imports
import project_config
from utils import load_kg, load_kg_edges

# Redirect to app.py if not logged in, otherwise show the navigation menu
menu_with_redirect()

# Banner image at the top of the page
header_path = project_config.MEDIA_DIR / 'validate_header.svg'
st.image(str(header_path), use_column_width=True)

# Page title
st.subheader("Validate Predictions", divider = "green")

# Echo the active query, swapping underscores for friendlier separators
query = st.session_state.query
source_label = query['source_node'].replace('_', ' ')
relation_label = query['relation'].replace('_', '-')
target_label = query['target_node_type'].replace('_', ' ')
st.markdown(f"**Query:** {source_label} ➡️ {relation_label} ➡️ {target_label}")

# Unpack the query and the stored predictions for use further down the page
source_node_type = query['source_node_type']
source_node = query['source_node']
relation = query['relation']
target_node_type = query['target_node_type']
predictions = st.session_state.predictions

# Knowledge graph node and edge tables
kg_nodes = load_kg()
kg_edges = load_kg_edges()

# Convert tuple to hex
def rgba_to_hex(rgba):
    """Return the hex colour string for a colour tuple.

    Only the first three components of ``rgba`` are passed on to
    ``mcolors.to_hex``; a fourth (alpha) component, if present, is dropped.
    """
    rgb_only = rgba[:3]
    hex_code = mcolors.to_hex(rgb_only)
    return hex_code

# Look up edges already present in the knowledge graph for the current query,
# store a wide 0/1 validation table in session state, and draw one
# rank-vs-score plot (plus a table) per known relation type.
with st.spinner('Searching known relationships...'):

    # Keep only KG edges that start at the query node and end at a node of the
    # requested target type
    is_query_edge = (
        (kg_edges.x_type == source_node_type)
        & (kg_edges.x_name == source_node)
        & (kg_edges.y_type == target_node_type)
    )
    edge_subset = kg_edges[is_query_edge]

    # Attach the prediction columns to every known edge. The right join keeps a
    # known edge even when `predictions` has no row with a matching ID.
    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
    edges_in_kg = edges_in_kg.drop(columns = 'y_id')

    # Move the relation column to the front and label it as ground truth
    other_columns = [col for col in edges_in_kg.columns if col != 'relation']
    edges_in_kg = edges_in_kg[['relation'] + other_columns]
    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})

# If there exist edges in KG
if not edges_in_kg.empty:

    with st.spinner('Saving validation results...'):

        # Cast long to wide: one row per target ID, one 0/1 column per relation
        val_results = edge_subset[['relation', 'y_id']].pivot_table(index='y_id', columns='relation', aggfunc='size', fill_value=0)
        val_results = (val_results > 0).astype(int).reset_index()
        val_results.columns = [val_results.columns[0]] + [x.replace('_', ' ').title() for x in val_results.columns[1:]]

        # Save validation results to session state
        st.session_state.validation = val_results

    with st.spinner('Plotting known relationships...'):

        # Define a color map for different relations
        color_map = plt.get_cmap('tab10')

        # Loop-invariant upper bound of the x axis
        max_rank = predictions['Rank'].max()

        # One plot and one table per known relation type. The loop variable is
        # deliberately NOT called `relation`, so the query's relation stored at
        # module level is not overwritten.
        known_relations = edges_in_kg['Known Relation'].unique()
        for idx, known_relation in enumerate(known_relations):

            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == known_relation]

            # Cycle through the color map so any number of relations gets a color
            color = color_map(idx % color_map.N)
            color_hex = rgba_to_hex(color)

            # Score curve over all predictions
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(predictions['Rank'], predictions['Score'])
            ax.set_xlabel('Rank', fontsize=12)
            ax.set_ylabel('Score', fontsize=12)
            ax.set_xlim(1, max_rank)

            # Mark the rank of every known target with a dashed vertical line
            for rank, name in zip(relation_data['Rank'], relation_data['Name']):
                ax.axvline(rank, color=color, linestyle='--', label=name)

            # Write header in color of relation (opening and closing tags match)
            st.markdown(f"<h3 style='color:{color_hex}'>{known_relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)

            # Show plot, then release the figure so figures do not pile up in
            # pyplot's registry across loop iterations and script reruns
            st.pyplot(fig)
            plt.close(fig)

            # Drop known relation column
            relation_data = relation_data.drop(columns = 'Known Relation')
            if target_node_type not in ['disease', 'anatomy']:
                st.dataframe(relation_data, use_container_width=True,
                            column_config={"Database": st.column_config.LinkColumn(width = "small",
                                                                                    help = "Click to visit external database.",
                                                                                    display_text = st.session_state.display_database)})
            else:
                st.dataframe(relation_data, use_container_width=True)

else:

    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")