|
import os
import sys
import time
import uuid
import warnings
import urllib.request
from io import BytesIO

import joblib
import networkx as nx
import numpy as np
import pandas as pd
import requests
import streamlit as st
from datasets import load_dataset
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Draw

warnings.filterwarnings("ignore")
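# RDKit logs through its own logger rather than the warnings module; silencing
# it here presumably matches the intent of the filter above (an assumption —
# remove this line if RDKit's console output is wanted).
RDLogger.DisableLog("rdApp.*")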
|
from model import OnTheFlyModel, HitSelectorByOverlap, CommunityDetector, task_evaluator
# The two star imports are expected to provide the descriptor helpers plus the
# CRF-related globals used below (CRF_PATTERN, CRF_PATTERN_0, CRF_PATTERN_1,
# crf_0 and CombineMols).
from morgan_desc import *
from physchem_desc import *
from fragment_embedder import FragmentEmbedder
|
|
# Percentiles of the pairwise similarity values; the matrix loaders below turn
# these into edge cutoffs for the protein graph.
SIMILARITY_PERCENTILES = [95, 90]
|
|
def get_session_id():
    # One stable UUID per Streamlit session, so cached artifacts can be keyed
    # to the visiting user.
    if "session_id" not in st.session_state:
        st.session_state["session_id"] = str(uuid.uuid4())
    return st.session_state["session_id"]
|
|
def clear_old_cache(cache_folder, hours=24):
    # Delete files in the cache folder that are older than `hours`.
    if not os.path.isdir(cache_folder):
        return
    current_time = time.time()
    for filename in os.listdir(cache_folder):
        file_path = os.path.join(cache_folder, filename)
        if os.path.isfile(file_path):
            file_modified_time = os.path.getmtime(file_path)
            if current_time - file_modified_time > hours * 3600:
                print(f"Deleting {filename} as it is older than {hours} hours")
                os.remove(file_path)
|
|
def _load_protein_matrix(url):
    # Shared loader for the two protein-protein similarity matrices: fetch a
    # (uniprot_acs, matrix) pair and derive percentile cutoffs. Note that
    # np.triu(M, k=1) zero-fills the diagonal and lower triangle, so those
    # zeros are included in the percentile computation.
    with urllib.request.urlopen(url) as response:
        uniprot_acs, M = joblib.load(BytesIO(response.read()))
    values = np.triu(M, k=1).ravel()
    cutoffs = [np.percentile(values, p) for p in SIMILARITY_PERCENTILES]
    return uniprot_acs, M, cutoffs


@st.cache_resource
def load_protein_spearman_similarity_matrix():
    # Cached so Streamlit reruns do not fetch the file again (assumes a
    # Streamlit version that provides st.cache_resource).
    url = "https://huggingface.co/datasets/ligdis/data/resolve/main/protein_protein_spearman_correlations.joblib"
    return _load_protein_matrix(url)


@st.cache_resource
def load_protein_hit_similarity_matrix():
    url = "https://huggingface.co/datasets/ligdis/data/resolve/main/protein_protein_hit_cosines.joblib"
    return _load_protein_matrix(url)


# Loaded once at import time and shared as module-level globals with
# get_protein_graph() below.
global_uniprot_acs_0, M0, cutoffs_0 = load_protein_spearman_similarity_matrix()
global_uniprot_acs_1, M1, cutoffs_1 = load_protein_hit_similarity_matrix()
|
|
def get_protein_graph(uniprot_acs):
    # Build a graph over the requested proteins. For every pair, each cutoff
    # passed in each of the two similarity matrices adds one unit of edge
    # weight, so weights count how many thresholds the pair clears.
    G = nx.Graph()
    G.add_nodes_from(uniprot_acs)
    pid2idx_0 = {k: i for i, k in enumerate(global_uniprot_acs_0)}
    pid2idx_1 = {k: i for i, k in enumerate(global_uniprot_acs_1)}

    def _add_edge_weight(pid_0, pid_1, M, pid2idx, cutoffs):
        v = M[pid2idx[pid_0], pid2idx[pid_1]]
        for cutoff in cutoffs:
            if v >= cutoff:
                if G.has_edge(pid_0, pid_1):
                    G[pid_0][pid_1]["weight"] += 1
                else:
                    G.add_edge(pid_0, pid_1, weight=1)

    for i, pid_0 in enumerate(uniprot_acs):
        for j, pid_1 in enumerate(uniprot_acs):
            if i >= j:
                continue
            _add_edge_weight(pid_0, pid_1, M0, pid2idx_0, cutoffs_0)
            _add_edge_weight(pid_0, pid_1, M1, pid2idx_1, cutoffs_1)
    return G
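# Illustrative use (a sketch; the accessions below are placeholders and must be
# present in both loaded matrices). With SIMILARITY_PERCENTILES = [95, 90], a
# pair can gain up to two weight units per matrix, so edge weights fall in
# {1, 2, 3, 4}:
#
#     G = get_protein_graph(["P00533", "P04637"])
#     if G.has_edge("P00533", "P04637"):
#         print(G["P00533"]["P04637"]["weight"])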
|
|
|
def load_hits():
    # Hits table plus two companion mappings (fid_prom, pid_prom); pid_prom is
    # used below as the per-protein fragment-hit count.
    url = "https://huggingface.co/datasets/ligdis/data/resolve/main/hits.joblib"
    with urllib.request.urlopen(url) as response:
        hits, fid_prom, pid_prom = joblib.load(BytesIO(response.read()))
    return hits, fid_prom, pid_prom
|
|
def load_fid2smi():
    # Map fragment identifiers to SMILES strings from cemm_smiles.csv.
    dataset = load_dataset("ligdis/data", data_files="cemm_smiles.csv")
    d = dataset["train"].to_pandas()
    fid2smi = {}
    for r in d.values:
        fid2smi[r[0]] = r[1]
    return fid2smi
|
|
def pid2name_mapper():
    # Lookup tables between UniProt accessions and primary gene names.
    dataset = load_dataset("ligdis/data", data_files="pid2name_primary.tsv")
    df = dataset["train"].to_pandas()
    df.columns = ["uniprot_ac", "gene_name"]
    name2pid = {}
    pid2name = {}
    any2pid = {}  # resolves either an accession or a gene name to the accession
    for r in df.values:
        name2pid[r[1]] = r[0]
        pid2name[r[0]] = r[1]
        any2pid[r[0]] = r[0]
        any2pid[r[1]] = r[0]
    return pid2name, name2pid, any2pid
|
|
def pids_to_dataframe(pids, pid2name, pid_prom):
    # Tabulate proteins with their gene names and fragment-hit counts.
    R = []
    for pid in pids:
        R.append([pid, pid2name[pid], pid_prom[pid]])
    df = (
        pd.DataFrame(R, columns=["UniprotAC", "Gene Name", "Fragment Hits"])
        .drop_duplicates()
        .reset_index(drop=True)
    )
    return df
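# Typical wiring of the loaders above (a sketch; assumes every protein in
# pid_prom has a gene-name entry):
#
#     hits, fid_prom, pid_prom = load_hits()
#     pid2name, name2pid, any2pid = pid2name_mapper()
#     df = pids_to_dataframe(list(pid_prom.keys()), pid2name, pid_prom)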
|
|
|
def is_valid_smiles(smiles):
    # RDKit returns None (rather than raising) for most unparsable SMILES; the
    # try/except guards the remaining edge cases (e.g. non-string input).
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        mol = None
    return mol is not None
|
|
def has_crf(mol, crf_pattern):
    # True if the molecule matches the given CRF SMARTS, or both of the two
    # partial patterns (CRF_PATTERN_0 and CRF_PATTERN_1, expected to come from
    # the star imports at the top of the module).
    if mol.HasSubstructMatch(Chem.MolFromSmarts(crf_pattern)):
        return True
    return mol.HasSubstructMatch(
        Chem.MolFromSmarts(CRF_PATTERN_0)
    ) and mol.HasSubstructMatch(Chem.MolFromSmarts(CRF_PATTERN_1))
|
|
def attach_crf(smiles):
    # Attach the CRF group to the input molecule through an "O" linker and
    # return the first clean product. CombineMols, crf_0 and CRF_PATTERN are
    # expected to come from the star imports at the top of the module.
    mol = Chem.MolFromSmiles(smiles)
    combined_mol = list(CombineMols(mol, crf_0, "O"))
    result = []
    for cm in combined_mol:
        smi = Chem.MolToSmiles(cm)
        if "." in smi:  # skip disconnected, multi-fragment products
            continue
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        if not has_crf(mol, CRF_PATTERN):  # has_crf needs the pattern argument
            continue
        result.append(Chem.MolToSmiles(mol))
    if result:
        return result[0]
    return None
|
|
def get_fragment_image(smiles):
    # 200x200 2D depiction of the fragment.
    m = Chem.MolFromSmiles(smiles)
    AllChem.Compute2DCoords(m)
    return Draw.MolToImage(m, size=(200, 200))
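

if __name__ == "__main__":
    # Minimal smoke test for the SMILES helpers (a sketch; note that importing
    # this module already downloads the similarity matrices above). Benzene is
    # used as an arbitrary valid input.
    assert is_valid_smiles("c1ccccc1")
    assert not is_valid_smiles("not-a-smiles")
    get_fragment_image("c1ccccc1").save("fragment_preview.png")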
|
|