ayushnoori commited on
Commit
f027c05
·
1 Parent(s): e6b01d4

Demo for ASAP grant renewal

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ media/pfp/*.png filter=lfs diff=lfs merge=lfs -text
37
+ data/*.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore Mac temporary files
2
+ *.DS_Store
3
+ .DS_Store
4
+
5
+ # Ignore python cache files
6
+ __pycache__/
7
+
8
+ # Ignore model files
9
+ data/*.pt
10
+
11
+ # Ignore secrets
12
+ .streamlit/secrets.toml
13
+ .streamlit/cipher_asap-user-db.json
14
+ test.ipynb
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [client]
2
+ showSidebarNavigation = false
3
+
4
+ [theme]
5
+ base="light"
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Cipher Asap
3
- emoji: 📈
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: streamlit
7
- sdk_version: 1.35.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: CIPHER
3
+ emoji: 🧠
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: streamlit
7
+ sdk_version: 1.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ # Do not load st-gsheets-connection
3
+ # from streamlit_gsheets import GSheetsConnection
4
+ import gspread
5
+ from oauth2client.service_account import ServiceAccountCredentials
6
+ import hmac
7
+
8
+ # Standard imports
9
+ import pandas as pd
10
+
11
+ # Custom and other imports
12
+ import project_config
13
+ # from utils import add_logo
14
+ from menu import menu
15
+
16
+ # Initialize st.session_state.role to None
17
+ if "role" not in st.session_state:
18
+ st.session_state.role = None
19
+
20
+ # # Retrieve the role from Session State to initialize the widget
21
+ # st.session_state._role = st.session_state.role
22
+
23
+ # def set_role():
24
+ # # Callback function to save the role selection to Session State
25
+ # st.session_state.role = st.session_state._role
26
+
27
+
28
+ # From https://stackoverflow.com/questions/55961295/serviceaccountcredentials-from-json-keyfile-name-equivalent-for-remote-json
29
+ # See also https://www.slingacademy.com/article/pandas-how-to-read-and-update-google-sheet-files/
30
+ # See also https://docs.streamlit.io/develop/tutorials/databases/private-gsheet
31
+ # Note that the secrets cannot be passed in a group in HuggingFace Spaces,
32
+ # which is required for the native Streamlit implementation
33
+ def create_keyfile_dict():
34
+ variables_keys = {
35
+ # "spreadsheet": st.secrets['spreadsheet'], # spreadsheet
36
+ "type": st.secrets['type'], # type
37
+ "project_id": st.secrets['project_id'], # project_id
38
+ "private_key_id": st.secrets['private_key_id'], # private_key_id
39
+ # Have to replace \n with new lines (^l in Word) by hand
40
+ "private_key": st.secrets['private_key'], # private_key
41
+ "client_email": st.secrets['client_email'], # client_email
42
+ "client_id": st.secrets['client_id'], # client_id
43
+ "auth_uri": st.secrets['auth_uri'], # auth_uri
44
+ "token_uri": st.secrets['token_uri'], # token_uri
45
+ "auth_provider_x509_cert_url": st.secrets['auth_provider_x509_cert_url'], # auth_provider_x509_cert_url
46
+ "client_x509_cert_url": st.secrets['client_x509_cert_url'], # client_x509_cert_url
47
+ "universe_domain": st.secrets['universe_domain'] # universe_domain
48
+ }
49
+ return variables_keys
50
+
51
+
52
+ def check_password():
53
+ """Returns `True` if the user had a correct password."""
54
+
55
+ def login_form():
56
+ """Form with widgets to collect user information"""
57
+ # Header
58
+ col1, col2, col3 = st.columns(3)
59
+ with col2:
60
+ st.image(str(project_config.MEDIA_DIR / 'cipher_logo.svg'), width=300)
61
+
62
+ # col1, col2, col3 = st.columns(3)
63
+ # with col1:
64
+ # st.header("Log In")
65
+
66
+ with st.form("Credentials"):
67
+ st.text_input("Username", key="username")
68
+ st.text_input("Password", type="password", key="password")
69
+ st.form_submit_button("Log In", on_click=password_entered)
70
+
71
+ def password_entered():
72
+ """Checks whether a password entered by the user is correct."""
73
+
74
+ # Define the scope
75
+ scope = [
76
+ 'https://spreadsheets.google.com/feeds',
77
+ 'https://www.googleapis.com/auth/drive'
78
+ ]
79
+
80
+ # Add credentials to the account
81
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(create_keyfile_dict(), scope)
82
+
83
+ # Authenticate and create the client
84
+ client = gspread.authorize(creds)
85
+
86
+ # Open the spreadsheet
87
+ sheet = client.open_by_url(st.secrets['spreadsheet']).worksheet("user_db")
88
+ data = sheet.get_all_records()
89
+ user_db = pd.DataFrame(data)
90
+
91
+ # # Create a connection object to Google Sheets
92
+ # conn = st.connection("gsheets", type=GSheetsConnection)
93
+
94
+ # # Read the user database
95
+ # user_db = conn.read()
96
+ # user_db.dropna(axis=0, how="all", inplace=True)
97
+ # user_db.dropna(axis=1, how="all", inplace=True)
98
+
99
+ # Check if the username is in the database
100
+ if st.session_state["username"] in user_db.username.values:
101
+
102
+ st.session_state["username_correct"] = True
103
+
104
+ # Check if the password is correct
105
+ if hmac.compare_digest(
106
+ st.session_state["password"],
107
+ user_db.loc[user_db.username == st.session_state["username"], "password"].values[0],
108
+ ):
109
+
110
+ st.session_state["password_correct"] = True
111
+
112
+ # Check if the username is an admin
113
+ if st.session_state["username"] in user_db[user_db.role == "admin"].username.values:
114
+ st.session_state["role"] = "admin"
115
+ else:
116
+ st.session_state["role"] = "user"
117
+
118
+ # Retrieve and store user name and team
119
+ st.session_state["name"] = user_db.loc[user_db.username == st.session_state["username"], "name"].values[0]
120
+ st.session_state["team"] = user_db.loc[user_db.username == st.session_state["username"], "team"].values[0]
121
+ # st.session_state["profile_pic"] = user_db.loc[user_db.username == st.session_state["username"], "profile_pic"].values[0]
122
+ st.session_state["profile_pic"] = st.session_state["username"]
123
+
124
+ # Don't store the username or password
125
+ del st.session_state["password"]
126
+ # del st.session_state["username"]
127
+
128
+ else:
129
+ st.session_state["password_correct"] = False
130
+
131
+ else:
132
+ st.session_state["username_correct"] = False
133
+ st.session_state["password_correct"] = False
134
+
135
+ # Return True if the username + password is validated
136
+ if st.session_state.get("password_correct", False):
137
+ return True
138
+
139
+ # Show inputs for username + password
140
+ login_form()
141
+ if "password_correct" in st.session_state:
142
+
143
+ if not st.session_state["username_correct"]:
144
+ st.error("User not found.")
145
+ elif not st.session_state["password_correct"]:
146
+ st.error("The password you entered is incorrect.")
147
+ else:
148
+ st.error("An unexpected error occurred.")
149
+
150
+ return False
151
+
152
+ menu() # Render the dynamic menu!
153
+
154
+ if not check_password():
155
+ st.stop()
156
+
157
+ st.switch_page("pages/about.py")
data/kg_edge_types.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee79b2f5021304a4dd82581568e8a8c940f94b29cd1206f7730bdff6b82cab4
3
+ size 5288
data/kg_edges.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab7d0e23c56381abf8e214cc5d4fae4e6a8b98957c8f2e5272b4f800953b1461
3
+ size 2765378133
data/kg_node_types.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1a0afff52deec5f48689a22a479d14cd49333759e054624366687ec4ef306c8
3
+ size 192
data/kg_nodes.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a21c42a1ee345195038d438854ee5d4befa7b1984e5efa4865ed74825a75b6d9
3
+ size 8529743
media/about_header.svg ADDED
media/cipher_logo.svg ADDED
media/explore_header.svg ADDED
media/input_header.svg ADDED
media/pfp/anoori.png ADDED

Git LFS Details

  • SHA256: 56f2cd51f6496ff1e43f0ce3fb63145a442772b16e3d456bba06cf86d78671cf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.53 MB
media/pfp/cipher.png ADDED

Git LFS Details

  • SHA256: ca98082f2798c07487a3af16cee2cbab54d068b9663429aac425d3188ec04304
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
media/pfp/gcroft.png ADDED

Git LFS Details

  • SHA256: b024c11c391dce7afde45676f7eae4cd4abfbc5fd7ad2eda6d56e96872162ceb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
media/pfp/jpowell.png ADDED

Git LFS Details

  • SHA256: 950a39928eb56fa99267c0858f8869f893a3579533a36393912d3514899c42dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
media/pfp/lstuder.png ADDED

Git LFS Details

  • SHA256: dee91beea710a873478b5eb6e9c09310b2cc8d916451392083807badb04edac5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
media/pfp/mzitnik.png ADDED

Git LFS Details

  • SHA256: b514858118909ce8004028a1f87f3e7a259415d730d34371830e884a1343da2f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
media/pfp/vkhurana.png ADDED

Git LFS Details

  • SHA256: d063af30f22a6aeb948dad3afeaeba705f5cbed9a60f6d911b3ccb870adf5e31
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
media/predict_header.svg ADDED
media/validate_header.svg ADDED
menu.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://docs.streamlit.io/develop/tutorials/multipage/st.page_link-nav
2
+ import streamlit as st
3
+ import os
4
+ import project_config
5
+
6
+ def authenticated_menu():
7
+
8
+ # Insert profile picture
9
+ pfp_path = str(project_config.MEDIA_DIR / 'pfp' / f"{st.session_state.profile_pic}.png")
10
+ if not os.path.exists(pfp_path):
11
+ pfp_path = str(project_config.MEDIA_DIR / 'pfp' / "cipher.png")
12
+ st.sidebar.image(pfp_path, use_column_width=True)
13
+ st.sidebar.markdown("---")
14
+
15
+ # Show a navigation menu for authenticated users
16
+ # st.sidebar.page_link("app.py", label="Switch Accounts", icon="🔒")
17
+ st.sidebar.page_link("pages/about.py", label="About", icon="📖")
18
+ st.sidebar.page_link("pages/input.py", label="Input", icon="💡")
19
+ st.sidebar.page_link("pages/predict.py", label="Predict", icon="🔍",
20
+ disabled=("query" not in st.session_state))
21
+ st.sidebar.page_link("pages/validate.py", label="Validate", icon="✅",
22
+ disabled=("query" not in st.session_state))
23
+ # st.sidebar.page_link("pages/explore.py", label="Explore", icon="🔍")
24
+ if st.session_state.role in ["admin"]:
25
+ st.sidebar.page_link("pages/admin.py", label="Manage Users", icon="🔧")
26
+
27
+ # Show the logout button
28
+ st.sidebar.markdown("---")
29
+ st.sidebar.button("Log Out", on_click=lambda: st.session_state.clear())
30
+
31
+
32
+ def unauthenticated_menu():
33
+
34
+ # Show a navigation menu for unauthenticated users
35
+ st.sidebar.page_link("app.py", label="Log In", icon="🔒")
36
+
37
+
38
+ def menu():
39
+ # Determine if a user is logged in or not, then show the correct navigation menu
40
+ if "role" not in st.session_state or st.session_state.role is None:
41
+ unauthenticated_menu()
42
+ return
43
+ authenticated_menu()
44
+
45
+
46
+ def menu_with_redirect():
47
+ # Redirect users to the main page if not logged in, otherwise continue to
48
+ # render the navigation menu
49
+ if "role" not in st.session_state or st.session_state.role is None:
50
+ st.switch_page("app.py")
51
+ menu()
pages/about.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Path manipulation
5
+ from pathlib import Path
6
+
7
+ # Custom and other imports
8
+ import project_config
9
+
10
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
11
+ menu_with_redirect()
12
+
13
+ # Header
14
+ st.image(str(project_config.MEDIA_DIR / 'about_header.svg'), use_column_width=True)
15
+
16
+ # Main content
17
+ st.markdown("Welcome to CIPHER, a knowledge-grounded artificial intelligence (AI) system for **C**ontextually **I**nformed **P**recision **HE**althca**R**e in Parkinson's disease (PD).")
18
+
19
+ # Subheader
20
+ st.subheader("About CIPHER", divider = "grey")
21
+
22
+ st.markdown("""
23
+ CIPHER is a knowledge graph-based AI algorithm for diagnostic and therapeutic discovery in PD.
24
+
25
+ *Knowledge graph construction.* To create CIPHER, we integrated diverse public information about basic biomedical interactions into a harmonized data platform amenable for training large-scale AI models. Specifically, we constructed a multiscale heterogeneous knowledge graph (KG) with *n = 143,093* nodes and *n = 7,048,795* edges by curating 36 high-quality primary data sources, ontologies, and knowledge bases.
26
+
27
+ *Model training.* Next, to convert this trove of knowledge into an AI model with diagnostic and therapeutic capabilities, we employed graph representation learning, a deep learning to model biomedical networks by embedding graphs into informative low-dimensional vector spaces. We trained a state-of-the-art heterogeneous graph transformer to learn graph embeddings that encode the relationships in the KG.
28
+
29
+ Through CIPHER, we seek to enable molecular subtyping and patient stratification of PD by integrating genetic and clinical progression data (*e.g.*, PPMI and HBS2.0 cohorts) and nominate genes, proteins, and pathways for in-depth mechanistic studies in stem cell and other PD models.
30
+ """)
pages/admin.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
5
+ menu_with_redirect()
6
+
7
+ # Verify the user's role
8
+ if st.session_state.role not in ["admin"]:
9
+ st.warning("You do not have permission to view this page.")
10
+ st.stop()
11
+
12
+ st.title("User Management")
13
+ st.markdown(f"You are currently logged with the role of {st.session_state.role}.")
pages/input.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # Path manipulation
9
+ from pathlib import Path
10
+
11
+ # Custom and other imports
12
+ import project_config
13
+ from utils import load_kg
14
+
15
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
16
+ menu_with_redirect()
17
+
18
+ # Header
19
+ st.image(str(project_config.MEDIA_DIR / 'input_header.svg'), use_column_width=True)
20
+
21
+ # Main content
22
+ # st.markdown(f"Hello, {st.session_state.name}!")
23
+
24
+ st.subheader("Construct Query", divider = "red")
25
+
26
+ # # Checkbox to allow reverse edges
27
+ # allow_reverse_edges = st.checkbox("Reverse Edges", value = False)
28
+
29
+ # Load knowledge graph
30
+ kg_nodes = load_kg()
31
+
32
+ with st.spinner('Loading knowledge graph...'):
33
+ # kg_nodes = nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
34
+ node_types = pd.read_csv(project_config.DATA_DIR / 'kg_node_types.csv')
35
+ edge_types = pd.read_csv(project_config.DATA_DIR / 'kg_edge_types.csv')
36
+
37
+ # if not allow_reverse_edges:
38
+ # edge_types = edge_types[edge_types.direction == 'forward']
39
+
40
+ # If query is not in session state, initialize it
41
+ if "query" not in st.session_state:
42
+ source_node_type_index = 0
43
+ source_node_index = 0
44
+ target_node_type_index = 0
45
+ relation_index = 0
46
+ else:
47
+ source_node_type_index = st.session_state.query_options['source_node_type'].index(st.session_state.query['source_node_type'])
48
+ source_node_index = st.session_state.query_options['source_node'].index(st.session_state.query['source_node'])
49
+ target_node_type_index = st.session_state.query_options['target_node_type'].index(st.session_state.query['target_node_type'])
50
+ relation_index = st.session_state.query_options['relation'].index(st.session_state.query['relation'])
51
+
52
+ # Select source node type
53
+ # source_node_type_options = node_types['node_type']
54
+ # source_node_type = st.selectbox("Source Node Type", source_node_type_options,
55
+ # format_func = lambda x: x.replace("_", " "), index = source_node_type_index)
56
+ source_node_type = "disease"
57
+
58
+ # Select source node
59
+ # source_node_options = kg_nodes[kg_nodes['node_type'] == source_node_type]['node_name']
60
+ # source_node = st.selectbox("Source Node", source_node_options,
61
+ # index = source_node_index)
62
+ source_node = "Parkinson disease"
63
+
64
+ # Select target node type
65
+ target_node_type_options = edge_types[edge_types.x_type == source_node_type].y_type.unique()
66
+ target_node_type = st.selectbox("Target Node Type", target_node_type_options,
67
+ format_func = lambda x: x.replace("_", " "), index = target_node_type_index)
68
+
69
+ # Select relation
70
+ relation_options = edge_types[(edge_types.x_type == source_node_type) & (edge_types.y_type == target_node_type)].relation.unique()
71
+ relation = st.selectbox("Edge Type", relation_options,
72
+ format_func = lambda x: x.replace("_", "-"), index = relation_index)
73
+
74
+ # Button to submit query
75
+ if st.button("Submit Query"):
76
+
77
+ # Save query to session state
78
+ st.session_state.query = {
79
+ "source_node_type": source_node_type,
80
+ "source_node": source_node,
81
+ "target_node_type": target_node_type,
82
+ "relation": relation
83
+ }
84
+
85
+ # Save query options to session state
86
+ st.session_state.query_options = {
87
+ # "source_node_type": list(source_node_type_options),
88
+ # "source_node": list(source_node_options),
89
+ "source_node_type": ["disease"],
90
+ "source_node": ["Parkinson disease"],
91
+ "target_node_type": list(target_node_type_options),
92
+ "relation": list(relation_options)
93
+ }
94
+
95
+ # # Write query to console
96
+ # st.write("Current Query:")
97
+ # st.write(st.session_state.query)
98
+ st.write("Query submitted.")
99
+
100
+ # Switch to the Predict page
101
+ st.switch_page("pages/predict.py")
102
+
103
+
104
+ # st.subheader("Knowledge Graph", divider = "red")
105
+ # display_data = kg_nodes[['node_id', 'node_type', 'node_name', 'node_source']].copy()
106
+ # display_data = display_data.rename(columns = {'node_id': 'ID', 'node_type': 'Type', 'node_name': 'Name', 'node_source': 'Database'})
107
+ # st.dataframe(display_data, use_container_width = True)
pages/predict.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ # Path manipulation
12
+ from pathlib import Path
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ # Plotting
16
+ import matplotlib.pyplot as plt
17
+ plt.rcParams['font.sans-serif'] = 'Arial'
18
+
19
+ # Custom and other imports
20
+ import project_config
21
+ from utils import capitalize_after_slash, load_kg
22
+
23
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
24
+ menu_with_redirect()
25
+
26
+ # Header
27
+ st.image(str(project_config.MEDIA_DIR / 'predict_header.svg'), use_column_width=True)
28
+
29
+ # Main content
30
+ # st.markdown(f"Hello, {st.session_state.name}!")
31
+
32
+ st.subheader(f"{capitalize_after_slash(st.session_state.query['target_node_type'])} Search", divider = "blue")
33
+
34
+ # Print current query
35
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
36
+
37
+ @st.cache_data(show_spinner = 'Downloading AI model...')
38
+ def get_embeddings():
39
+ # Get checkpoint name
40
+ # best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
41
+ best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
42
+ # best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
43
+
44
+ # Get paths to embeddings, relation weights, and edge types
45
+ # with st.spinner('Downloading AI model...'):
46
+ embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
47
+ filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
48
+ token=st.secrets["HF_TOKEN"])
49
+ relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
50
+ filename=(best_ckpt + "_relation_weights.pt"),
51
+ token=st.secrets["HF_TOKEN"])
52
+ edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
53
+ filename=(best_ckpt + "_edge_types.pt"),
54
+ token=st.secrets["HF_TOKEN"])
55
+ return embed_path, relation_weights_path, edge_types_path
56
+
57
+ @st.cache_data(show_spinner = 'Loading AI model...')
58
+ def load_embeddings(embed_path, relation_weights_path, edge_types_path):
59
+ # Load embeddings, relation weights, and edge types
60
+ # with st.spinner('Loading AI model...'):
61
+ embeddings = torch.load(embed_path)
62
+ relation_weights = torch.load(relation_weights_path)
63
+ edge_types = torch.load(edge_types_path)
64
+
65
+ return embeddings, relation_weights, edge_types
66
+
67
+ # Load knowledge graph and embeddings
68
+ kg_nodes = load_kg()
69
+ embed_path, relation_weights_path, edge_types_path = get_embeddings()
70
+ embeddings, relation_weights, edge_types = load_embeddings(embed_path, relation_weights_path, edge_types_path)
71
+
72
+ # # Print source node type
73
+ # st.write(f"Source Node Type: {st.session_state.query['source_node_type']}")
74
+
75
+ # # Print source node
76
+ # st.write(f"Source Node: {st.session_state.query['source_node']}")
77
+
78
+ # # Print relation
79
+ # st.write(f"Edge Type: {st.session_state.query['relation']}")
80
+
81
+ # # Print target node type
82
+ # st.write(f"Target Node Type: {st.session_state.query['target_node_type']}")
83
+
84
+ # Compute predictions
85
+ with st.spinner('Computing predictions...'):
86
+
87
+ source_node_type = st.session_state.query['source_node_type']
88
+ source_node = st.session_state.query['source_node']
89
+ relation = st.session_state.query['relation']
90
+ target_node_type = st.session_state.query['target_node_type']
91
+
92
+ # Get source node index
93
+ src_index = kg_nodes[(kg_nodes.node_type == source_node_type) & (kg_nodes.node_name == source_node)].node_index.values[0]
94
+
95
+ # Get relation index
96
+ edge_type_index = [i for i, etype in enumerate(edge_types) if etype == (source_node_type, relation, target_node_type)][0]
97
+
98
+ # Get target nodes indices
99
+ target_nodes = kg_nodes[kg_nodes.node_type == target_node_type].copy()
100
+ dst_indices = target_nodes.node_index.values
101
+ src_indices = np.repeat(src_index, len(dst_indices))
102
+
103
+ # Retrieve cached embeddings and apply activation function
104
+ src_embeddings = embeddings[src_indices]
105
+ dst_embeddings = embeddings[dst_indices]
106
+ src_embeddings = F.leaky_relu(src_embeddings)
107
+ dst_embeddings = F.leaky_relu(dst_embeddings)
108
+
109
+ # Get relation weights
110
+ rel_weights = relation_weights[edge_type_index]
111
+
112
+ # Compute weighted dot product
113
+ scores = torch.sum(src_embeddings * rel_weights * dst_embeddings, dim = 1)
114
+ scores = torch.sigmoid(scores)
115
+
116
+ # Add scores to dataframe
117
+ target_nodes['score'] = scores.detach().numpy()
118
+ target_nodes = target_nodes.sort_values(by = 'score', ascending = False)
119
+ target_nodes['rank'] = np.arange(1, target_nodes.shape[0] + 1)
120
+
121
+ # Rename columns
122
+ display_data = target_nodes[['rank', 'node_id', 'node_name', 'score', 'node_source']].copy()
123
+ display_data = display_data.rename(columns = {'rank': 'Rank', 'node_id': 'ID', 'node_name': 'Name', 'score': 'Score', 'node_source': 'Database'})
124
+
125
+ # Define dictionary mapping node types to database URLs
126
+ map_dbs = {
127
+ 'gene/protein': lambda x: f"https://ncbi.nlm.nih.gov/gene/?term={x}",
128
+ 'drug': lambda x: f"https://go.drugbank.com/drugs/{x}",
129
+ 'effect/phenotype': lambda x: f"https://hpo.jax.org/app/browse/term/HP:{x.zfill(7)}", # pad with 0s to 7 digits
130
+ 'disease': lambda x: x, # MONDO
131
+ # pad with 0s to 7 digits
132
+ 'biological_process': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
133
+ 'molecular_function': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
134
+ 'cellular_component': lambda x: f"https://amigo.geneontology.org/amigo/term/GO:{x.zfill(7)}",
135
+ 'exposure': lambda x: f"https://ctdbase.org/detail.go?type=chem&acc={x}",
136
+ 'pathway': lambda x: f"https://reactome.org/content/detail/{x}",
137
+ 'anatomy': lambda x: x,
138
+ }
139
+
140
+ # Get name of database
141
+ display_database = display_data['Database'].values[0]
142
+
143
+ # Add URLs to database column
144
+ display_data['Database'] = display_data.apply(lambda x: map_dbs[target_node_type](x['ID']), axis = 1)
145
+
146
+
147
+ # NODE SEARCH
148
+
149
+ # Use multiselect to search for specific nodes
150
+ selected_nodes = st.multiselect(f"Search for specific {st.session_state.query['source_node_type'].replace('_', ' ')} nodes to determine their ranking.",
151
+ display_data.Name, placeholder = "Type to search...")
152
+
153
+ # Filter nodes
154
+ if len(selected_nodes) > 0:
155
+ selected_display_data = display_data[display_data.Name.isin(selected_nodes)]
156
+
157
+ # Show filtered nodes
158
+ if target_node_type not in ['disease', 'anatomy']:
159
+ st.dataframe(selected_display_data, use_container_width = True,
160
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
161
+ help = "Click to visit external database.",
162
+ display_text = display_database)})
163
+ else:
164
+ st.dataframe(selected_display_data, use_container_width = True)
165
+
166
+ # Plot rank vs. score using matplotlib
167
+ st.markdown("**Rank vs. Score**")
168
+ fig, ax = plt.subplots(figsize = (10, 6))
169
+ ax.plot(display_data['Rank'], display_data['Score'])
170
+ ax.set_xlabel('Rank', fontsize = 12)
171
+ ax.set_ylabel('Score', fontsize = 12)
172
+ ax.set_xlim(1, display_data['Rank'].max())
173
+
174
+ # Add vertical line for selected nodes
175
+ for i, node in selected_display_data.iterrows():
176
+ ax.axvline(node['Rank'], color = 'red', linestyle = '--', label = node['Name'])
177
+ ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize = 10, color = 'red')
178
+
179
+ # Show plot
180
+ st.pyplot(fig)
181
+
182
+
183
+ # FULL RESULTS
184
+
185
+ # Show top ranked nodes
186
+ st.subheader("Model Predictions", divider = "blue")
187
+ top_k = st.slider('Select number of top ranked nodes to show.', 1, target_nodes.shape[0], min(500, target_nodes.shape[0]))
188
+
189
+ if target_node_type not in ['disease', 'anatomy']:
190
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True,
191
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
192
+ help = "Click to visit external database.",
193
+ display_text = display_database)})
194
+ else:
195
+ st.dataframe(display_data.iloc[:top_k], use_container_width = True)
196
+
197
+ # Save to session state
198
+ st.session_state.predictions = display_data
199
+ st.session_state.display_database = display_database
pages/validate.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from menu import menu_with_redirect
3
+
4
+ # Standard imports
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ # Path manipulation
9
+ from pathlib import Path
10
+
11
+ # Plotting
12
+ import matplotlib.pyplot as plt
13
+ plt.rcParams['font.sans-serif'] = 'Arial'
14
+ import matplotlib.colors as mcolors
15
+
16
+ # Custom and other imports
17
+ import project_config
18
+ from utils import load_kg, load_kg_edges
19
+
20
+ # Redirect to app.py if not logged in, otherwise show the navigation menu
21
+ menu_with_redirect()
22
+
23
+ # Header
24
+ st.image(str(project_config.MEDIA_DIR / 'validate_header.svg'), use_column_width=True)
25
+
26
+ # Main content
27
+ # st.markdown(f"Hello, {st.session_state.name}!")
28
+
29
+ st.subheader("Validate Predictions", divider = "green")
30
+
31
+ # Print current query
32
+ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' ')} ➡️ {st.session_state.query['relation'].replace('_', '-')} ➡️ {st.session_state.query['target_node_type'].replace('_', ' ')}")
33
+
34
+ # Coming soon
35
+ # st.write("Coming soon...")
36
+
37
+ source_node_type = st.session_state.query['source_node_type']
38
+ source_node = st.session_state.query['source_node']
39
+ relation = st.session_state.query['relation']
40
+ target_node_type = st.session_state.query['target_node_type']
41
+ predictions = st.session_state.predictions
42
+
43
+ kg_nodes = load_kg()
44
+ kg_edges = load_kg_edges()
45
+
46
+ # Convert tuple to hex
47
+ def rgba_to_hex(rgba):
48
+ return mcolors.to_hex(rgba[:3])
49
+
50
+ with st.spinner('Searching known relationships...'):
51
+
52
+ # Subset existing edges
53
+ edge_subset = kg_edges[(kg_edges.x_type == source_node_type) & (kg_edges.x_name == source_node)]
54
+ edge_subset = edge_subset[edge_subset.y_type == target_node_type]
55
+
56
+ # Merge edge subset with predictions
57
+ edges_in_kg = pd.merge(predictions, edge_subset[['relation', 'y_id']], left_on = 'ID', right_on = 'y_id', how = 'right')
58
+ edges_in_kg = edges_in_kg.sort_values(by = 'Score', ascending = False)
59
+ edges_in_kg = edges_in_kg.drop(columns = 'y_id')
60
+
61
+ # Rename relation to ground-truth
62
+ edges_in_kg = edges_in_kg[['relation'] + [col for col in edges_in_kg.columns if col != 'relation']]
63
+ edges_in_kg = edges_in_kg.rename(columns = {'relation': 'Known Relation'})
64
+
65
+ # If there exist edges in KG
66
+ if len(edges_in_kg) > 0:
67
+
68
+ with st.spinner('Plotting known relationships...'):
69
+
70
+ # Define a color map for different relations
71
+ color_map = plt.get_cmap('tab10')
72
+
73
+ # Group by relation and create separate plots
74
+ relations = edges_in_kg['Known Relation'].unique()
75
+ for idx, relation in enumerate(relations):
76
+
77
+ relation_data = edges_in_kg[edges_in_kg['Known Relation'] == relation]
78
+
79
+ # Get a color from the color map
80
+ color = color_map(idx % color_map.N)
81
+
82
+ fig, ax = plt.subplots(figsize=(10, 3))
83
+ ax.plot(predictions['Rank'], predictions['Score'])
84
+ ax.set_xlabel('Rank', fontsize=12)
85
+ ax.set_ylabel('Score', fontsize=12)
86
+ ax.set_xlim(1, predictions['Rank'].max())
87
+
88
+ for i, node in relation_data.iterrows():
89
+ ax.axvline(node['Rank'], color=color, linestyle='--', label=node['Name'])
90
+ # ax.text(node['Rank'] + 100, node['Score'], node['Name'], fontsize=10, color=color)
91
+
92
+ # ax.set_title(f'{relation.replace("_", "-")}')
93
+ # ax.legend()
94
+ color_hex = rgba_to_hex(color)
95
+
96
+ # Write header in color of relation
97
+ st.markdown(f"<h3 style='color:{color_hex}'>{relation.replace('_', ' ').title()}</h2>", unsafe_allow_html=True)
98
+
99
+ # Show plot
100
+ st.pyplot(fig)
101
+
102
+ # Drop known relation column
103
+ relation_data = relation_data.drop(columns = 'Known Relation')
104
+ if target_node_type not in ['disease', 'anatomy']:
105
+ st.dataframe(relation_data, use_container_width=True,
106
+ column_config={"Database": st.column_config.LinkColumn(width = "small",
107
+ help = "Click to visit external database.",
108
+ display_text = st.session_state.display_database)})
109
+ else:
110
+ st.dataframe(relation_data, use_container_width=True)
111
+
112
+ else:
113
+
114
+ st.error('No ground truth relationships found for the given query in the knowledge graph.', icon="✖️")
project_config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ PROJECT CONFIGURATION FILE
3
+ This file contains the configuration variables for the project. The variables are used
4
+ in the other scripts to define the paths to the data and results directories. The variables
5
+ are also used to set the random seed for reproducibility.
6
+ '''
7
+
8
+ # import libraries
9
+ import os
10
+ from pathlib import Path
11
+
12
+ # define project configuration variables
13
+ PROJECT_DIR = Path(os.getcwd())
14
+ DATA_DIR = PROJECT_DIR / 'data'
15
+ MODEL_DIR = PROJECT_DIR / 'models'
16
+ MEDIA_DIR = PROJECT_DIR / 'media'
17
+ SEED = 42
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ scikit-learn
4
+ matplotlib
5
+ seaborn
6
+ pathlib
7
+ torch
8
+ altair<5
9
+ gspread
10
+ oauth2client
11
+ huggingface_hub
utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import project_config
4
+ import base64
5
+
6
+
7
+ @st.cache_data(show_spinner = 'Loading knowledge graph nodes...')
8
+ def load_kg():
9
+ # with st.spinner('Loading knowledge graph...'):
10
+ kg_nodes = pd.read_csv(project_config.DATA_DIR / 'kg_nodes.csv', dtype = {'node_index': int}, low_memory = False)
11
+ return kg_nodes
12
+
13
+
14
+ @st.cache_data(show_spinner = 'Loading knowledge graph edges...')
15
+ def load_kg_edges():
16
+ # with st.spinner('Loading knowledge graph...'):
17
+ kg_edges = pd.read_csv(project_config.DATA_DIR / 'kg_edges.csv', dtype = {'edge_index': int, 'x_index': int, 'y_index': int}, low_memory = False)
18
+ return kg_edges
19
+
20
+
21
+ def capitalize_after_slash(s):
22
+ # Split the string by slashes first
23
+ parts = s.split('/')
24
+ # Capitalize each part separately
25
+ capitalized_parts = [part.title() for part in parts]
26
+ # Rejoin the parts with slashes
27
+ capitalized_string = '/'.join(capitalized_parts).replace('_', ' ')
28
+ return capitalized_string
29
+
30
+
31
+ # From https://stackoverflow.com/questions/73251012/put-logo-and-title-above-on-top-of-page-navigation-in-sidebar-of-streamlit-multi
32
+ # See also https://arnaudmiribel.github.io/streamlit-extras/extras/app_logo/
33
+ @st.cache_data()
34
+ def get_base64_of_bin_file(png_file):
35
+ with open(png_file, "rb") as f:
36
+ data = f.read()
37
+ return base64.b64encode(data).decode()
38
+
39
+
40
+ def build_markup_for_logo(
41
+ png_file,
42
+ background_position="50% 10%",
43
+ margin_top="10%",
44
+ padding="20px",
45
+ image_width="80%",
46
+ image_height="",
47
+ ):
48
+ binary_string = get_base64_of_bin_file(png_file)
49
+ return """
50
+ <style>
51
+ [data-testid="stSidebarNav"] {
52
+ background-image: url("data:image/png;base64,%s");
53
+ background-repeat: no-repeat;
54
+ background-position: %s;
55
+ margin-top: %s;
56
+ padding: %s;
57
+ background-size: %s %s;
58
+ }
59
+ </style>
60
+ """ % (
61
+ binary_string,
62
+ background_position,
63
+ margin_top,
64
+ padding,
65
+ image_width,
66
+ image_height,
67
+ )
68
+
69
+
70
+ def add_logo(png_file):
71
+ logo_markup = build_markup_for_logo(png_file)
72
+ st.markdown(
73
+ logo_markup,
74
+ unsafe_allow_html=True,
75
+ )