Commit f027c05
Parent(s): e6b01d4

Demo for ASAP grant renewal
Files changed:
- .gitattributes +2 -0
- .gitignore +14 -0
- .streamlit/config.toml +5 -0
- README.md +5 -5
- app.py +157 -0
- data/kg_edge_types.csv +3 -0
- data/kg_edges.csv +3 -0
- data/kg_node_types.csv +3 -0
- data/kg_nodes.csv +3 -0
- media/about_header.svg +1 -0
- media/cipher_logo.svg +1 -0
- media/explore_header.svg +1 -0
- media/input_header.svg +1 -0
- media/pfp/anoori.png +3 -0
- media/pfp/cipher.png +3 -0
- media/pfp/gcroft.png +3 -0
- media/pfp/jpowell.png +3 -0
- media/pfp/lstuder.png +3 -0
- media/pfp/mzitnik.png +3 -0
- media/pfp/vkhurana.png +3 -0
- media/predict_header.svg +1 -0
- media/validate_header.svg +1 -0
- menu.py +51 -0
- pages/about.py +30 -0
- pages/admin.py +13 -0
- pages/input.py +107 -0
- pages/predict.py +199 -0
- pages/validate.py +114 -0
- project_config.py +17 -0
- requirements.txt +11 -0
- utils.py +75 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
+data/*.csv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,14 @@
+# Ignore Mac temporary files
+*.DS_Store
+.DS_Store
+
+# Ignore python cache files
+__pycache__/
+
+# Ignore model files
+data/*.pt
+
+# Ignore secrets
+.streamlit/secrets.toml
+.streamlit/cipher_asap-user-db.json
+test.ipynb
.streamlit/config.toml
ADDED
@@ -0,0 +1,5 @@
+[client]
+showSidebarNavigation = false
+
+[theme]
+base="light"
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: CIPHER
+emoji: 🧠
+colorFrom: red
+colorTo: purple
 sdk: streamlit
-sdk_version: 1.
+sdk_version: 1.34.0
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,157 @@
+import streamlit as st
+# Do not load st-gsheets-connection
+# from streamlit_gsheets import GSheetsConnection
+import gspread
+from oauth2client.service_account import ServiceAccountCredentials
+import hmac
+
+# Standard imports
+import pandas as pd
+
+# Custom and other imports
+import project_config
+# from utils import add_logo
+from menu import menu
+
+# Initialize st.session_state.role to None
+if "role" not in st.session_state:
+    st.session_state.role = None
+
+# # Retrieve the role from Session State to initialize the widget
+# st.session_state._role = st.session_state.role
+
+# def set_role():
+#     # Callback function to save the role selection to Session State
+#     st.session_state.role = st.session_state._role
+
+
+# From https://stackoverflow.com/questions/55961295/serviceaccountcredentials-from-json-keyfile-name-equivalent-for-remote-json
+# See also https://www.slingacademy.com/article/pandas-how-to-read-and-update-google-sheet-files/
+# See also https://docs.streamlit.io/develop/tutorials/databases/private-gsheet
+# Note that the secrets cannot be passed in a group in HuggingFace Spaces,
+# which is required for the native Streamlit implementation
+def create_keyfile_dict():
+    variables_keys = {
+        # "spreadsheet": st.secrets['spreadsheet'], # spreadsheet
+        "type": st.secrets['type'], # type
+        "project_id": st.secrets['project_id'], # project_id
+        "private_key_id": st.secrets['private_key_id'], # private_key_id
+        # Have to replace \n with new lines (^l in Word) by hand
+        "private_key": st.secrets['private_key'], # private_key
+        "client_email": st.secrets['client_email'], # client_email
+        "client_id": st.secrets['client_id'], # client_id
+        "auth_uri": st.secrets['auth_uri'], # auth_uri
+        "token_uri": st.secrets['token_uri'], # token_uri
+        "auth_provider_x509_cert_url": st.secrets['auth_provider_x509_cert_url'], # auth_provider_x509_cert_url
+        "client_x509_cert_url": st.secrets['client_x509_cert_url'], # client_x509_cert_url
+        "universe_domain": st.secrets['universe_domain'] # universe_domain
+    }
+    return variables_keys
+
+
+def check_password():
+    """Returns `True` if the user had a correct password."""
+
+    def login_form():
+        """Form with widgets to collect user information"""
+        # Header
+        col1, col2, col3 = st.columns(3)
+        with col2:
+            st.image(str(project_config.MEDIA_DIR / 'cipher_logo.svg'), width=300)
+
+        # col1, col2, col3 = st.columns(3)
+        # with col1:
+        #     st.header("Log In")
+
+        with st.form("Credentials"):
+            st.text_input("Username", key="username")
+            st.text_input("Password", type="password", key="password")
+            st.form_submit_button("Log In", on_click=password_entered)
+
+    def password_entered():
+        """Checks whether a password entered by the user is correct."""
+
+        # Define the scope
+        scope = [
+            'https://spreadsheets.google.com/feeds',
+            'https://www.googleapis.com/auth/drive'
+        ]
+
+        # Add credentials to the account
+        creds = ServiceAccountCredentials.from_json_keyfile_dict(create_keyfile_dict(), scope)
+
+        # Authenticate and create the client
+        client = gspread.authorize(creds)
+
+        # Open the spreadsheet
+        sheet = client.open_by_url(st.secrets['spreadsheet']).worksheet("user_db")
+        data = sheet.get_all_records()
+        user_db = pd.DataFrame(data)
+
+        # # Create a connection object to Google Sheets
+        # conn = st.connection("gsheets", type=GSheetsConnection)
+
+        # # Read the user database
+        # user_db = conn.read()
+        # user_db.dropna(axis=0, how="all", inplace=True)
+        # user_db.dropna(axis=1, how="all", inplace=True)
+
+        # Check if the username is in the database
+        if st.session_state["username"] in user_db.username.values:
+
+            st.session_state["username_correct"] = True
+
+            # Check if the password is correct
+            if hmac.compare_digest(
+                st.session_state["password"],
+                user_db.loc[user_db.username == st.session_state["username"], "password"].values[0],
+            ):
+
+                st.session_state["password_correct"] = True
+
+                # Check if the username is an admin
+                if st.session_state["username"] in user_db[user_db.role == "admin"].username.values:
+                    st.session_state["role"] = "admin"
+                else:
+                    st.session_state["role"] = "user"
+
+                # Retrieve and store user name and team
+                st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
+                st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
+                # st.session_state["profile_pic"] = user_db.loc[user_db.username == st.session_state["username"], "profile_pic"].values[0]
+                st.session_state["profile_pic"] = st.session_state["username"]
+
+                # Don't store the username or password
+                del st.session_state["password"]
+                # del st.session_state["username"]
+
+            else:
+                st.session_state["password_correct"] = False
+
+        else:
+            st.session_state["username_correct"] = False
+            st.session_state["password_correct"] = False
+
+    # Return True if the username + password is validated
+    if st.session_state.get("password_correct", False):
+        return True
+
+    # Show inputs for username + password
+    login_form()
+    if "password_correct" in st.session_state:
+
+        if not st.session_state["username_correct"]:
+            st.error("User not found.")
+        elif not st.session_state["password_correct"]:
+            st.error("The password you entered is incorrect.")
+        else:
+            st.error("An unexpected error occurred.")
+
+    return False
+
+menu() # Render the dynamic menu!
+
+if not check_password():
+    st.stop()
+
+st.switch_page("pages/about.py")
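
For reference, the login flow in app.py builds a service-account keyfile dict from st.secrets, authorizes a gspread client, reads the "user_db" worksheet into a DataFrame, and compares credentials with hmac.compare_digest. A minimal sketch of the same round trip outside Streamlit, for local testing only (it assumes a local service_account.json and a placeholder sheet URL; it is not part of this commit):

import hmac
import gspread
import pandas as pd
from oauth2client.service_account import ServiceAccountCredentials

SCOPE = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
SHEET_URL = "https://docs.google.com/spreadsheets/d/..."  # hypothetical placeholder

# Authorize with a local keyfile instead of st.secrets
creds = ServiceAccountCredentials.from_json_keyfile_name("service_account.json", SCOPE)
client = gspread.authorize(creds)

# Read the user database worksheet into a DataFrame
user_db = pd.DataFrame(client.open_by_url(SHEET_URL).worksheet("user_db").get_all_records())

# Constant-time comparison of a candidate password against the stored value
row = user_db.loc[user_db.username == "jdoe"]  # "jdoe" is a made-up username
ok = (not row.empty) and hmac.compare_digest("candidate-password", row.password.values[0])
print(ok)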
data/kg_edge_types.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ee79b2f5021304a4dd82581568e8a8c940f94b29cd1206f7730bdff6b82cab4
+size 5288
data/kg_edges.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab7d0e23c56381abf8e214cc5d4fae4e6a8b98957c8f2e5272b4f800953b1461
+size 2765378133
data/kg_node_types.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1a0afff52deec5f48689a22a479d14cd49333759e054624366687ec4ef306c8
+size 192
data/kg_nodes.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a21c42a1ee345195038d438854ee5d4befa7b1984e5efa4865ed74825a75b6d9
+size 8529743
media/about_header.svg
ADDED
media/cipher_logo.svg
ADDED
media/explore_header.svg
ADDED
media/input_header.svg
ADDED
media/pfp/anoori.png
ADDED (image stored with Git LFS)
media/pfp/cipher.png
ADDED (image stored with Git LFS)
media/pfp/gcroft.png
ADDED (image stored with Git LFS)
media/pfp/jpowell.png
ADDED (image stored with Git LFS)
media/pfp/lstuder.png
ADDED (image stored with Git LFS)
media/pfp/mzitnik.png
ADDED (image stored with Git LFS)
media/pfp/vkhurana.png
ADDED (image stored with Git LFS)
media/predict_header.svg
ADDED
media/validate_header.svg
ADDED
menu.py
ADDED
@@ -0,0 +1,51 @@
+# From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
+import streamlit as st
+import os
+import project_config
+
+def authenticated_menu():
+
+    # Insert profile picture
+    pfp_path = str(project_config.MEDIA_DIR / 'pfp' / f"{st.session_state.profile_pic}.png")
+    if not os.path.exists(pfp_path):
+        pfp_path = str(project_config.MEDIA_DIR / 'pfp' / "cipher.png")
+    st.sidebar.image(pfp_path, use_column_width=True)
+    st.sidebar.markdown("---")
+
+    # Show a navigation menu for authenticated users
+    # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
+    st.sidebar.page_link("pages/about.py", label="About", icon="📖")
+    st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
+    st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
+                         disabled=("query" not in st.session_state))
+    st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
+                         disabled=("query" not in st.session_state))
+    # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
+    if st.session_state.role in ["admin"]:
+        st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
+
+    # Show the logout button
+    st.sidebar.markdown("---")
+    st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
+
+
+def unauthenticated_menu():
+
+    # Show a navigation menu for unauthenticated users
+    st.sidebar.page_link("app.py", label="Log In", icon="🔒")
+
+
+def menu():
+    # Determine if a user is logged in or not, then show the correct navigation menu
+    if "role" not in st.session_state or st.session_state.role is None:
+        unauthenticated_menu()
+        return
+    authenticated_menu()
+
+
+def menu_with_redirect():
+    # Redirect users to the main page if not logged in, otherwise continue to
+    # render the navigation menu
+    if "role" not in st.session_state or st.session_state.role is None:
+        st.switch_page("app.py")
+    menu()
pages/about.py
ADDED
@@ -0,0 +1,30 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Path manipulation
+from pathlib import Path
+
+# Custom and other imports
+import project_config
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
+
+# Main content
+st.markdown("Welcome to CIPHER, a knowledge-grounded artificial intelligence (AI) system for **C**ontextually **I**nformed **P**recision **HE**althca**R**e in Parkinson's disease (PD).")
+
+# Subheader
+st.subheader("About CIPHER", divider = "grey")
+
+st.markdown("""
+CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
+
+*Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable to training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n = 143,093* nodes and *n = 7,048,795* edges by curating 36 high-quality primary data sources, ontologies, and knowledge bases.
+
+*Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning approach that models biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph transformer to learn graph embeddings that encode the relationships in the KG.
+
+Through CIPHER, we seek to enable molecular subtyping and patient stratification in PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and to nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
+""")
pages/admin.py
ADDED
@@ -0,0 +1,13 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Verify the user's role
+if st.session_state.role not in ["admin"]:
+    st.warning("You do not have permission to view this page.")
+    st.stop()
+
+st.title("User Management")
+st.markdown(f"You are currently logged in with the role of {st.session_state.role}.")
pages/input.py
ADDED
@@ -0,0 +1,107 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+
+# Path manipulation
+from pathlib import Path
+
+# Custom and other imports
+import project_config
+from utils import load_kg
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader("Construct Query", divider = "red")
+
+# # Checkbox to allow reverse edges
+# allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
+
+# Load knowledge graph
+kg_nodes = load_kg()
+
+with st.spinner('Loading knowledge graph...'):
+    # kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
+    node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
+    edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
+
+    # if not allow_reverse_edges:
+    #     edge_types = edge_types[edge_types.direction == 'forward']
+
+# If query is not in session state, initialize it
+if "query" not in st.session_state:
+    source_node_type_index = 0
+    source_node_index = 0
+    target_node_type_index = 0
+    relation_index = 0
+else:
+    source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
+    source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
+    target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
+    relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
+
+# Select source node type
+# source_node_type_options = node_types['node_type']
+# source_node_type = st.selectbox("Source Node Type", source_node_type_options,
+#                                 format_func = lambda x: x.replace("_", " "), index = source_node_type_index)
+source_node_type = "disease"
+
+# Select source node
+# source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
+# source_node = st.selectbox("Source Node", source_node_options,
+#                            index = source_node_index)
+source_node = "Parkinson disease"
+
+# Select target node type
+target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
+target_node_type = st.selectbox("Target Node Type", target_node_type_options,
+                                format_func = lambda x: x.replace("_", " "), index = target_node_type_index)
+
+# Select relation
+relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
+relation = st.selectbox("Edge Type", relation_options,
+                        format_func = lambda x: x.replace("_", "-"), index = relation_index)
+
+# Button to submit query
+if st.button("Submit Query"):
+
+    # Save query to session state
+    st.session_state.query = {
+        "source_node_type": source_node_type,
+        "source_node": source_node,
+        "target_node_type": target_node_type,
+        "relation": relation
+    }
+
+    # Save query options to session state
+    st.session_state.query_options = {
+        # "source_node_type": list(source_node_type_options),
+        # "source_node": list(source_node_options),
+        "source_node_type": ["disease"],
+        "source_node": ["Parkinson disease"],
+        "target_node_type": list(target_node_type_options),
+        "relation": list(relation_options)
+    }
+
+    # # Write query to console
+    # st.write("Current Query:")
+    # st.write(st.session_state.query)
+    st.write("Query submitted.")
+
+    # Switch to the Predict page
+    st.switch_page("pages/predict.py")
+
+
+# st.subheader("Knowledge Graph", divider = "red")
+# display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
+# display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
+# st.dataframe(display_data, use_container_width = True)
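
For reference, the hand-off from this page to pages/predict.py is just two dictionaries in st.session_state. A sketch of their contents after "Submit Query" (the target type and relation shown are illustrative examples; the actual values depend on the selection and on data/kg_edge_types.csv):

query = {
    "source_node_type": "disease",
    "source_node": "Parkinson disease",
    "target_node_type": "gene/protein",  # whichever target type was selected
    "relation": "disease_protein",       # whichever relation was selected
}
query_options = {
    "source_node_type": ["disease"],
    "source_node": ["Parkinson disease"],
    "target_node_type": ["gene/protein", "drug", "effect/phenotype"],  # example options from the edge-type table
    "relation": ["disease_protein"],                                   # example options for the chosen type pair
}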
pages/predict.py
ADDED
@@ -0,0 +1,199 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Path manipulation
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+# Plotting
+import matplotlib.pyplot as plt
+plt.rcParams['font.sans-serif'] = 'Arial'
+
+# Custom and other imports
+import project_config
+from utils import capitalize_after_slash, load_kg
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
+
+# Print current query
+st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
+
+@st.cache_data(show_spinner = 'Downloading AI model...')
+def get_embeddings():
+    # Get checkpoint name
+    # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
+    best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
+    # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
+
+    # Get paths to embeddings, relation weights, and edge types
+    # with st.spinner('Downloading AI model...'):
+    embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                 filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
+                                 token=st.secrets["HF_TOKEN"])
+    relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                            filename=(best_ckpt + "_relation_weights.pt"),
+                                            token=st.secrets["HF_TOKEN"])
+    edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
+                                      filename=(best_ckpt + "_edge_types.pt"),
+                                      token=st.secrets["HF_TOKEN"])
+    return embed_path, relation_weights_path, edge_types_path
+
+@st.cache_data(show_spinner = 'Loading AI model...')
+def load_embeddings(embed_path, relation_weights_path, edge_types_path):
+    # Load embeddings, relation weights, and edge types
+    # with st.spinner('Loading AI model...'):
+    embeddings = torch.load(embed_path)
+    relation_weights = torch.load(relation_weights_path)
+    edge_types = torch.load(edge_types_path)
+
+    return embeddings, relation_weights, edge_types
+
+# Load knowledge graph and embeddings
+kg_nodes = load_kg()
+embed_path, relation_weights_path, edge_types_path = get_embeddings()
+embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
+
+# # Print source node type
+# st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
+
+# # Print source node
+# st.write(f"Source Node: {st.session_state.query['source_node']}")
+
+# # Print relation
+# st.write(f"Edge Type: {st.session_state.query['relation']}")
+
+# # Print target node type
+# st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
+
+# Compute predictions
+with st.spinner('Computing predictions...'):
+
+    source_node_type = st.session_state.query['source_node_type']
+    source_node = st.session_state.query['source_node']
+    relation = st.session_state.query['relation']
+    target_node_type = st.session_state.query['target_node_type']
+
+    # Get source node index
+    src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
+
+    # Get relation index
+    edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
+
+    # Get target nodes indices
+    target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
+    dst_indices = target_nodes.node_index.values
+    src_indices = np.repeat(src_index, len(dst_indices))
+
+    # Retrieve cached embeddings and apply activation function
+    src_embeddings = embeddings[src_indices]
+    dst_embeddings = embeddings[dst_indices]
+    src_embeddings = F.leaky_relu(src_embeddings)
+    dst_embeddings = F.leaky_relu(dst_embeddings)
+
+    # Get relation weights
+    rel_weights = relation_weights[edge_type_index]
+
+    # Compute weighted dot product
+    scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
+    scores = torch.sigmoid(scores)
+
+    # Add scores to dataframe
+    target_nodes['score'] = scores.detach().numpy()
+    target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
+    target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
+
+    # Rename columns
+    display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
+    display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
+
+    # Define dictionary mapping node types to database URLs
+    map_dbs = {
+        'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
+        'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
+        'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
+        'disease': lambda x: x, # MONDO
+        # pad with 0s to 7 digits
+        'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
+        'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
+        'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
+        'anatomy': lambda x: x,
+    }
+
+    # Get name of database
+    display_database = display_data['Database'].values[0]
+
+    # Add URLs to database column
+    display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
+
+
+# NODE SEARCH
+
+# Use multiselect to search for specific nodes
+selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
+                                display_data.Name, placeholder = "Type to search...")
+
+# Filter nodes
+if len(selected_nodes) > 0:
+    selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
+
+    # Show filtered nodes
+    if target_node_type not in ['disease', 'anatomy']:
+        st.dataframe(selected_display_data, use_container_width = True,
+                     column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                            help = "Click to visit external database.",
+                                                                            display_text = display_database)})
+    else:
+        st.dataframe(selected_display_data, use_container_width = True)
+
+    # Plot rank vs. score using matplotlib
+    st.markdown("**Rank vs. Score**")
+    fig, ax = plt.subplots(figsize = (10, 6))
+    ax.plot(display_data['Rank'], display_data['Score'])
+    ax.set_xlabel('Rank', fontsize = 12)
+    ax.set_ylabel('Score', fontsize = 12)
+    ax.set_xlim(1, display_data['Rank'].max())
+
+    # Add vertical line for selected nodes
+    for i, node in selected_display_data.iterrows():
+        ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
+        ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
+
+    # Show plot
+    st.pyplot(fig)
+
+
+# FULL RESULTS
+
+# Show top ranked nodes
+st.subheader("Model Predictions", divider = "blue")
+top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
+
+if target_node_type not in ['disease', 'anatomy']:
+    st.dataframe(display_data.iloc[:top_k], use_container_width = True,
+                 column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                        help = "Click to visit external database.",
+                                                                        display_text = display_database)})
+else:
+    st.dataframe(display_data.iloc[:top_k], use_container_width = True)

+# Save to session state
+st.session_state.predictions = display_data
+st.session_state.display_database = display_database
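
For reference, the scoring step above reduces to a relation-weighted dot product of the leaky-ReLU-activated source and target embeddings, passed through a sigmoid: score = sigmoid(sum_d src_d * w_d * dst_d). A self-contained sketch with toy tensors (shapes are illustrative only; the real embeddings and relation weights come from the downloaded checkpoint):

import torch
import torch.nn.functional as F

num_targets, dim = 4, 8
src = F.leaky_relu(torch.randn(num_targets, dim))  # source embedding repeated once per candidate target
dst = F.leaky_relu(torch.randn(num_targets, dim))  # one embedding per candidate target node
rel_w = torch.randn(dim)                           # learned weight vector for this edge type

scores = torch.sigmoid(torch.sum(src * rel_w * dst, dim=1))  # one score in (0, 1) per target
order = torch.argsort(scores, descending=True)               # target indices from highest to lowest score
print(scores, order)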
pages/validate.py
ADDED
@@ -0,0 +1,114 @@
+import streamlit as st
+from menu import menu_with_redirect
+
+# Standard imports
+import numpy as np
+import pandas as pd
+
+# Path manipulation
+from pathlib import Path
+
+# Plotting
+import matplotlib.pyplot as plt
+plt.rcParams['font.sans-serif'] = 'Arial'
+import matplotlib.colors as mcolors
+
+# Custom and other imports
+import project_config
+from utils import load_kg, load_kg_edges
+
+# Redirect to app.py if not logged in, otherwise show the navigation menu
+menu_with_redirect()
+
+# Header
+st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
+
+# Main content
+# st.markdown(f"Hello, {st.session_state.name}!")
+
+st.subheader("Validate Predictions", divider = "green")
+
+# Print current query
+st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
+
+# Coming soon
+# st.write("Coming soon...")
+
+source_node_type = st.session_state.query['source_node_type']
+source_node = st.session_state.query['source_node']
+relation = st.session_state.query['relation']
+target_node_type = st.session_state.query['target_node_type']
+predictions = st.session_state.predictions
+
+kg_nodes = load_kg()
+kg_edges = load_kg_edges()
+
+# Convert tuple to hex
+def rgba_to_hex(rgba):
+    return mcolors.to_hex(rgba[:3])
+
+with st.spinner('Searching known relationships...'):
+
+    # Subset existing edges
+    edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
+    edge_subset = edge_subset[edge_subset.y_type == target_node_type]
+
+    # Merge edge subset with predictions
+    edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
+    edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
+    edges_in_kg = edges_in_kg.drop(columns = 'y_id')
+
+    # Rename relation to ground-truth
+    edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
+    edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
+
+# If there exist edges in KG
+if len(edges_in_kg) > 0:
+
+    with st.spinner('Plotting known relationships...'):
+
+        # Define a color map for different relations
+        color_map = plt.get_cmap('tab10')
+
+        # Group by relation and create separate plots
+        relations = edges_in_kg['Known Relation'].unique()
+        for idx, relation in enumerate(relations):
+
+            relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
+
+            # Get a color from the color map
+            color = color_map(idx % color_map.N)
+
+            fig, ax = plt.subplots(figsize=(10, 3))
+            ax.plot(predictions['Rank'], predictions['Score'])
+            ax.set_xlabel('Rank', fontsize=12)
+            ax.set_ylabel('Score', fontsize=12)
+            ax.set_xlim(1, predictions['Rank'].max())
+
+            for i, node in relation_data.iterrows():
+                ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
+                # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
+
+            # ax.set_title(f'{relation.replace("_", "-")}')
+            # ax.legend()
+            color_hex = rgba_to_hex(color)
+
+            # Write header in color of relation
+            st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h3>", unsafe_allow_html=True)
+
+            # Show plot
+            st.pyplot(fig)
+
+            # Drop known relation column
+            relation_data = relation_data.drop(columns = 'Known Relation')
+            if target_node_type not in ['disease', 'anatomy']:
+                st.dataframe(relation_data, use_container_width=True,
+                             column_config={"Database": st.column_config.LinkColumn(width = "small",
+                                                                                    help = "Click to visit external database.",
+                                                                                    display_text = st.session_state.display_database)})
+            else:
+                st.dataframe(relation_data, use_container_width=True)
+
+else:
+
+    st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
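
For reference, the pd.merge with how='right' above keeps only predictions that match a known edge for the query source: prediction rows without a ground-truth edge are dropped, while every known edge is retained. A small sketch of that behavior with made-up frames (column names follow the code above):

import pandas as pd

predictions = pd.DataFrame({'ID': ['g1', 'g2', 'g3'], 'Name': ['A', 'B', 'C'], 'Score': [0.9, 0.7, 0.2]})
edge_subset = pd.DataFrame({'relation': ['disease_protein'], 'y_id': ['g2']})  # one known edge (example relation name)

edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']],
                       left_on='ID', right_on='y_id', how='right')
print(edges_in_kg)  # only the row for g2 survives, now carrying its known relation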
project_config.py
ADDED
@@ -0,0 +1,17 @@
+'''
+PROJECT CONFIGURATION FILE
+This file contains the configuration variables for the project. The variables are used
+in the other scripts to define the paths to the data and results directories. The variables
+are also used to set the random seed for reproducibility.
+'''
+
+# import libraries
+import os
+from pathlib import Path
+
+# define project configuration variables
+PROJECT_DIR = Path(os.getcwd())
+DATA_DIR = PROJECT_DIR / 'data'
+MODEL_DIR = PROJECT_DIR / 'models'
+MEDIA_DIR = PROJECT_DIR / 'media'
+SEED = 42
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+numpy
+pandas
+scikit-learn
+matplotlib
+seaborn
+pathlib
+torch
+altair<5
+gspread
+oauth2client
+huggingface_hub
utils.py
ADDED
@@ -0,0 +1,75 @@
+import streamlit as st
+import pandas as pd
+import project_config
+import base64
+
+
+@st.cache_data(show_spinner = 'Loading knowledge graph nodes...')
+def load_kg():
+    # with st.spinner('Loading knowledge graph...'):
+    kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
+    return kg_nodes
+
+
+@st.cache_data(show_spinner = 'Loading knowledge graph edges...')
+def load_kg_edges():
+    # with st.spinner('Loading knowledge graph...'):
+    kg_edges = pd.read_csv(project_config.DATA_DIR / 'kg_edges.csv', dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)
+    return kg_edges
+
+
+def capitalize_after_slash(s):
+    # Split the string by slashes first
+    parts = s.split('/')
+    # Capitalize each part separately
+    capitalized_parts = [part.title() for part in parts]
+    # Rejoin the parts with slashes
+    capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
+    return capitalized_string
+
+
+# From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
+# See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
+@st.cache_data()
+def get_base64_of_bin_file(png_file):
+    with open(png_file, "rb") as f:
+        data = f.read()
+    return base64.b64encode(data).decode()
+
+
+def build_markup_for_logo(
+    png_file,
+    background_position="50% 10%",
+    margin_top="10%",
+    padding="20px",
+    image_width="80%",
+    image_height="",
+):
+    binary_string = get_base64_of_bin_file(png_file)
+    return """
+            <style>
+                [data-testid="stSidebarNav"] {
+                    background-image: url("data:image/png;base64,%s");
+                    background-repeat: no-repeat;
+                    background-position: %s;
+                    margin-top: %s;
+                    padding: %s;
+                    background-size: %s %s;
+                }
+            </style>
+            """ % (
+        binary_string,
+        background_position,
+        margin_top,
+        padding,
+        image_width,
+        image_height,
+    )
+
+
+def add_logo(png_file):
+    logo_markup = build_markup_for_logo(png_file)
+    st.markdown(
+        logo_markup,
+        unsafe_allow_html=True,
+    )
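
For reference, capitalize_after_slash title-cases each slash-separated part and then replaces underscores with spaces, which is how node types are prettified for page headers (e.g., the subheader on the Predict page). A quick check of the behavior on node types used elsewhere in the app:

from utils import capitalize_after_slash

print(capitalize_after_slash("gene/protein"))        # -> "Gene/Protein"
print(capitalize_after_slash("effect/phenotype"))    # -> "Effect/Phenotype"
print(capitalize_after_slash("biological_process"))  # -> "Biological Process"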