Commit ffaff91 (committed by root)
1 Parent(s): 1e6a1f0

adding utility files used throughout FusOn-pLM training and benchmarking
Files changed:
- fuson_plm/utils/README.md +3 -0
- fuson_plm/utils/__init__.py +0 -0
- fuson_plm/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/clustering.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/constants.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/data_cleaning.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/embedding.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/logging.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/splitting.cpython-310.pyc +0 -0
- fuson_plm/utils/__pycache__/visualizing.cpython-310.pyc +0 -0
- fuson_plm/utils/clustering.py +139 -0
- fuson_plm/utils/constants.py +108 -0
- fuson_plm/utils/data_cleaning.py +126 -0
- fuson_plm/utils/embedding.py +193 -0
- fuson_plm/utils/logging.py +116 -0
- fuson_plm/utils/splitting.py +206 -0
- fuson_plm/utils/visualizing.py +545 -0
fuson_plm/utils/README.md
ADDED
@@ -0,0 +1,3 @@
This folder contains common functions for data cleaning, clustering, train-test splitting, visualization, embedding, and logging.

The functions in these scripts are used throughout the repository for training the main model, FusOn-pLM, as well as benchmarks.
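To illustrate the intended usage pattern across the repository, here is a minimal sketch (the log file name is hypothetical) of the logging convention these modules share: scripts open a log file and route status updates through it.

from fuson_plm.utils.logging import open_logfile, log_update

# All print output inside this block is redirected to the (hypothetical) log file.
with open_logfile("data_cleaning_log.txt"):
    log_update("Starting data cleaning...")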
fuson_plm/utils/__init__.py
ADDED
File without changes
fuson_plm/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes).
fuson_plm/utils/__pycache__/clustering.cpython-310.pyc
ADDED
Binary file (4.87 kB).
fuson_plm/utils/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (2.48 kB).
fuson_plm/utils/__pycache__/data_cleaning.cpython-310.pyc
ADDED
Binary file (4.45 kB).
fuson_plm/utils/__pycache__/embedding.cpython-310.pyc
ADDED
Binary file (5.13 kB).
fuson_plm/utils/__pycache__/logging.cpython-310.pyc
ADDED
Binary file (3.31 kB).
fuson_plm/utils/__pycache__/splitting.cpython-310.pyc
ADDED
Binary file (6.95 kB).
fuson_plm/utils/__pycache__/visualizing.cpython-310.pyc
ADDED
Binary file (13.4 kB).
fuson_plm/utils/clustering.py
ADDED
@@ -0,0 +1,139 @@
import pandas as pd
import os
import subprocess
import sys
from Bio import SeqIO
import shutil
from fuson_plm.utils.logging import open_logfile, log_update

def ensure_mmseqs_in_path(mmseqs_dir):
    """
    Checks if MMseqs2 is in the PATH. If it's not, add it. MMseqs2 will not run if this is not done correctly.

    Args:
        mmseqs_dir (str): Directory containing MMseqs2 binaries
    """
    mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs')

    # Check if mmseqs is already in PATH
    if shutil.which('mmseqs') is None:
        # Export the MMseqs2 directory to PATH
        os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}"
        log_update(f"\tAdded {mmseqs_dir} to PATH")

def process_fasta(fasta_path):
    fasta_sequences = SeqIO.parse(open(fasta_path), 'fasta')
    d = {}
    for fasta in fasta_sequences:
        id, sequence = fasta.id, str(fasta.seq)
        d[id] = sequence

    return d

def analyze_clustering_result(input_fasta: str, tsv_path: str):
    """
    Args:
        input_fasta (str): path to input fasta file
        tsv_path (str): path to the clusters .tsv file produced by MMseqs2
    """
    # Process input fasta
    input_d = process_fasta(input_fasta)

    # Process clusters.tsv
    clusters = pd.read_csv(f'{tsv_path}', sep='\t', header=None)
    clusters = clusters.rename(columns={
        0: 'representative seq_id',
        1: 'member seq_id'
    })

    clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id])
    clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id])

    # Sort them so that splitting results are reproducible
    clusters = clusters.sort_values(by=['representative seq_id','member seq_id'], ascending=True).reset_index(drop=True)

    return clusters

def make_fasta(sequences: dict, fasta_path: str):
    """
    Makes a fasta file from sequences, where the key is the header and the value is the sequence.

    Args:
        sequences (dict): A dictionary where the key is the header and the value is the sequence.

    Returns:
        str: The path to the fasta file.
    """
    with open(fasta_path, 'w') as f:
        for header, sequence in sequences.items():
            f.write(f'>{header}\n{sequence}\n')

    return fasta_path

def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs'):
    """
    Runs MMseqs2 clustering using the easy-cluster module.

    Args:
        input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence....
        output_dir (str): path to output dir for clustering results
        min_seq_id (float): number [0,1] representing --min-seq-id in cluster command
        c (float): number [0,1] representing -c in cluster command
        cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in cluster command
        cluster_mode (int): number 0, 1, or 2 representing --cluster-mode in cluster command
    """
    # Get mmseqs dir
    log_update("\nRunning MMseqs2 clustering...")
    mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin')

    # Ensure MMseqs2 is in the PATH
    ensure_mmseqs_in_path(mmseqs_dir)

    # Define paths for MMseqs2
    mmseqs_bin = "mmseqs"  # Ensure this is in your PATH or provide the full path to the mmseqs binary

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Run MMseqs2 easy-cluster
    cmd_easy_cluster = [
        mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir,
        "--min-seq-id", str(min_seq_id),
        "-c", str(c),
        "--cov-mode", str(cov_mode),
        "--cluster-mode", str(cluster_mode),
        "--dbtype", "1"
    ]

    # Write the command to the log file
    log_update("\n\tCommand entered to MMseqs2:")
    log_update("\t" + " ".join(cmd_easy_cluster) + "\n")

    subprocess.run(cmd_easy_cluster, check=True)

    log_update(f"Clustering completed. Results are in {output_dir}")

def cluster_summary(clusters: pd.DataFrame):
    """
    Summarizes how many clusters were formed, how big they are, etc.
    """
    grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    assert len(grouped_clusters) == len(clusters['representative seq_id'].unique())  # make sure number of cluster reps = # grouped clusters

    total_seqs = sum(grouped_clusters['member count'])
    log_update(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences")
    log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1")
    csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count'])
    log_update(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)")

    log_update(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1")
    csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count'])
    log_update(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)")
    log_update(f"\tlargest cluster: {max(grouped_clusters['member count'])}")

    log_update("\nCluster size breakdown below...")

    value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'index':'cluster size (n_members)','member count': 'n_clusters'})
    log_update(value_counts.sort_values(by='cluster size (n_members)', ascending=True).to_string(index=False))
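Taken together, these functions form a small pipeline: write sequences to a FASTA file, cluster with MMseqs2, then load and summarize the result. A minimal sketch, assuming MMseqs2 binaries live under the default fuson_plm/mmseqs path and relying on MMseqs2's <prefix>_cluster.tsv output naming; the sequences and file paths are hypothetical:

from fuson_plm.utils.clustering import make_fasta, run_mmseqs_clustering, analyze_clustering_result, cluster_summary

# Hypothetical inputs: seq_id -> sequence
seqs = {'seq1': 'MKTAYIAKQR', 'seq2': 'MKTAYIAKQK', 'seq3': 'GGSGGSGGSG'}

fasta = make_fasta(seqs, 'input.fasta')            # write the query FASTA
run_mmseqs_clustering(fasta, 'clustering_output')  # invokes `mmseqs easy-cluster`
clusters = analyze_clustering_result(fasta, 'clustering_output/mmseqs_cluster.tsv')
cluster_summary(clusters)                          # log cluster-size statistics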
fuson_plm/utils/constants.py
ADDED
@@ -0,0 +1,108 @@
# Data Cleaning Parameters
# TCGA abbreviations for cancer. From https://gdc.cancer.gov/resources-tcga-users/tcga-code-tables/tcga-study-abbreviations
TCGA_CODES = {
    'LAML': 'Acute Myeloid Leukemia',
    'ACC': 'Adrenocortical carcinoma',
    'BLCA': 'Bladder Urothelial Carcinoma',
    'LGG': 'Brain Lower Grade Glioma',
    'BRCA': 'Breast invasive carcinoma',
    'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
    'CHOL': 'Cholangiocarcinoma',
    'LCML': 'Chronic Myelogenous Leukemia',
    'COAD': 'Colon adenocarcinoma',
    'CNTL': 'Controls',
    'ESCA': 'Esophageal carcinoma',
    'FPPP': 'FFPE Pilot Phase II',
    'GBM': 'Glioblastoma multiforme',
    'HNSC': 'Head and Neck squamous cell carcinoma',
    'KICH': 'Kidney Chromophobe',
    'KIRC': 'Kidney renal clear cell carcinoma',
    'KIRP': 'Kidney renal papillary cell carcinoma',
    'LIHC': 'Liver hepatocellular carcinoma',
    'LUAD': 'Lung adenocarcinoma',
    'LUSC': 'Lung squamous cell carcinoma',
    'DLBC': 'Lymphoid Neoplasm Diffuse Large B-cell Lymphoma',
    'MESO': 'Mesothelioma',
    'MISC': 'Miscellaneous',
    'OV': 'Ovarian serous cystadenocarcinoma',
    'PAAD': 'Pancreatic adenocarcinoma',
    'PCPG': 'Pheochromocytoma and Paraganglioma',
    'PRAD': 'Prostate adenocarcinoma',
    'READ': 'Rectum adenocarcinoma',
    'SARC': 'Sarcoma',
    'SKCM': 'Skin Cutaneous Melanoma',
    'STAD': 'Stomach adenocarcinoma',
    'TGCT': 'Testicular Germ Cell Tumors',
    'THYM': 'Thymoma',
    'THCA': 'Thyroid carcinoma',
    'UCS': 'Uterine Carcinosarcoma',
    'UCEC': 'Uterine Corpus Endometrial Carcinoma',
    'UVM': 'Uveal Melanoma'
}

FODB_CODES = {
    'ACC': 'Adenoid cystic carcinoma',
    'ALL': 'Acute Lymphoid Leukemia',
    'AML': 'Acute Myeloid Leukemia',
    'BALL': 'B-cell acute lymphoblastic leukemia',
    'BLCA': 'Bladder Urothelial Carcinoma',
    'BRCA': 'Breast invasive carcinoma',
    'CESC': 'Cervical squamous cell carcinoma and endocervical adenocarcinoma',
    'CHOL': 'Cholangiocarcinoma',
    'EPD': 'Ependymoma',
    'HGG': 'High-grade glioma',
    'HNSC': 'Head and Neck squamous cell carcinoma',
    'KIRC': 'Kidney renal clear cell carcinoma',
    'LGG': 'Low-grade glioma',
    'LUAD': 'Lung adenocarcinoma',
    'LUSC': 'Lung squamous cell carcinoma',
    'MEL': 'Melanoma',
    'MESO': 'Mesothelioma',
    'NBL': 'Neuroblastoma',
    'OS': 'Osteosarcoma',
    'OV': 'Ovarian serous cystadenocarcinoma',
    'PCPG': 'Pheochromocytoma and Paraganglioma',
    'PRAD': 'Prostate adenocarcinoma',
    'READ': 'Rectum adenocarcinoma',
    'RHB': 'Rhabdomyosarcoma',
    'SARC': 'Sarcoma',
    'STAD': 'Stomach adenocarcinoma',
    'TALL': 'T-cell acute lymphoblastic leukemia',
    'THYM': 'Thymoma',
    'UCEC': 'Uterine Corpus Endometrial Carcinoma',
    'UCS': 'Uterine Carcinosarcoma',
    'UVM': 'Uveal Melanoma',
    'WLM': 'Wilms tumor'
}

VALID_AAS = {'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
             'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'}

DELIMITERS = {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}
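Since these constants are plain Python sets, validity checks reduce to set algebra; a quick illustration with a made-up sequence:

from fuson_plm.utils.constants import VALID_AAS, DELIMITERS

seq = 'MKTAYIAKQRX'             # hypothetical sequence with one non-canonical character
print(set(seq) - VALID_AAS)     # {'X'}: characters outside the 20 canonical amino acids
print(DELIMITERS & set('a,b'))  # {','}: delimiters present in a string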
fuson_plm/utils/data_cleaning.py
ADDED
@@ -0,0 +1,126 @@
import pandas as pd
import numpy as np
from fuson_plm.utils.logging import log_update

def clean_rows_and_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Deletes empty rows and columns.

    Args:
        df (pd.DataFrame): input DataFrame to be cleaned

    Returns:
        pd.DataFrame: cleaned DataFrame
    """
    # Delete rows with no data
    log_update(f"\trow cleaning...\n\t\toriginal # rows: {len(df)}")
    log_update("\t\tdropping rows where all entries are np.nan...")
    df = df.dropna(how='all')
    log_update(f"\t\tnew # rows: {len(df)}")

    # Delete columns with no data
    log_update(f"\tcolumn cleaning...\n\t\toriginal # columns: {len(df.columns)}")
    log_update("\t\tdropping columns where all entries are np.nan...")
    df = df.dropna(axis=1, how='all')
    log_update(f"\t\tnew # columns: {len(df.columns)}")
    log_update(f"\t\tcolumn names: {','.join(list(df.columns))}")

    return df

def check_columns_for_listlike(df: pd.DataFrame, cols_of_interest: list, delimiters: set):
    """
    Checks if a column contains any listlike items.

    Args:
        df (pd.DataFrame): DataFrame to be investigated
        cols_of_interest (list): columns in df to be investigated for list-containing potential
        delimiters (set): set of potential delimiting strings to search for. A column with any of these strings is considered listlike.

    Returns:
        dict: dictionary containing a set {} of all delimiters found in each column
            e.g., { 'col1': {',',';'},
                    'col2': {'|'} }
    """
    # return the delimiters/listlike things found for each column
    return_dict = {}

    log_update("\tchecking if any of our columns of interest look listlike (contain list objects or delimiters)...")
    for col in cols_of_interest:
        unique_col = list(df[col].value_counts().index)
        listlike = any([check_item_for_listlike(x, delimiters) for x in unique_col])

        if listlike:
            found_delims = df[col].apply(lambda x: check_item_for_listlike(x, delimiters)).value_counts().reset_index()['index'].to_list()
            unique_found_delims = set()
            for x in found_delims:
                unique_found_delims = unique_found_delims.union(x)

            return_dict[col] = unique_found_delims
        else:
            return_dict[col] = False

        # display the return dict
        log_update(f"\t\tcolumn name: {col}\tlistlike: {return_dict[col]}")

    return return_dict

def check_item_for_listlike(x, delimiters: set):
    """
    Checks if an item looks like it contains a list of values, rather than an individual value, based on string delimiters.

    Args:
        x: the item to check. Any dtype.
        delimiters: a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}

    Returns:
        If x is a string: the set (may be empty) of delimiters contained in the string
        If x is not a string: the dtype of x
    """
    if isinstance(x, str):
        return find_delimiters(x, delimiters)
    else:
        if x is None:
            # if it's None, it's not listlike, it's just empty. return {} because it has no delimiters.
            return {}
        if type(x)==float:
            # if it's nan, it's not listlike, it's just empty. return {} because it has no delimiters.
            if np.isnan(x):
                return {}
        return type(x)

def find_delimiters(seq: str, delimiters: set) -> set:
    """
    Find and return a set of delimiters in a sequence. Helper method for check_item_for_listlike.

    Args:
        seq (str): The sequence you wish to search for delimiters.
        delimiters (set): a set of delimiters to check for. e.g., {',', ';', '|', '\t', ' ', ':', '-', '/', '\\', '\n'}

    Returns:
        set: A set of characters in the sequence that are in the set of delimiters.
    """
    unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"
    overlap = delimiters.intersection(unique_chars)

    if len(overlap)==0:
        return {}
    else:
        return overlap

def find_invalid_chars(seq: str, valid_chars: set) -> set:
    """
    Find and return a set of invalid characters in a sequence.

    Args:
        seq (str): The sequence you wish to search for invalid characters.
        valid_chars (set): A set of valid characters.

    Returns:
        set: A set of characters in the sequence that are not in the set of valid characters.
    """
    unique_chars = set(seq)  # set of all characters in the sequence; unique_chars = {A, C} for protein="AAACCC"

    if unique_chars.issubset(valid_chars):  # e.g. unique_chars = {A,C}, and {A,C} is a subset of valid_chars
        return ''
    else:  # e.g. unique_chars = {A,X}. {A,X} is not a subset of valid_chars because X is not in valid_chars
        return unique_chars.difference(valid_chars)  # e.g. {A,X} - valid_chars = {X}
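A minimal sketch of these helpers on a toy DataFrame (the column names and values here are hypothetical):

import numpy as np
import pandas as pd
from fuson_plm.utils.constants import VALID_AAS
from fuson_plm.utils.data_cleaning import clean_rows_and_cols, find_invalid_chars

# Toy DataFrame with one all-NaN row and one all-NaN column
df = pd.DataFrame({
    'seq': ['MKTA', 'GGSGU', np.nan],
    'empty': [np.nan, np.nan, np.nan],
})
df = clean_rows_and_cols(df)  # drops the all-NaN row and the 'empty' column

# Flag sequences containing non-canonical residues (e.g. 'U')
df['invalid_chars'] = df['seq'].apply(lambda s: find_invalid_chars(s, VALID_AAS))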
fuson_plm/utils/embedding.py
ADDED
@@ -0,0 +1,193 @@
import pickle
import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel
import logging
from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle dictionary and redumps it in its location. This allows a clean reset for a pickle built with 'ab+'.
    """
    entries = {}
    # Load one by one
    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # End of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Redump
    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)

def load_esm2_type(esm_type, device=None):
    """
    Loads an ESM-2 model of a specified version (e.g. esm2_t33_650M_UR50D).
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device

def load_prott5():
    # Initialize tokenizer and model
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)
    return model, tokenizer, device

def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM embeddings.

    Args:
        model: ESM model used to generate embeddings
        tokenizer: tokenizer matching the model
        sequences: list of sequences to embed
        device: torch device to run inference on
        average: if True, the average (mean-pooled) embeddings will be taken and returned
        savepath: if savepath is not None, the embeddings will be saved there. It must be a pickle
        save_at_end: if True, dump the whole dictionary once at the end instead of appending per sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus BOS/EOS
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len+2  # +2 for BOS, EOS

    # Initialize an empty dict to store the ESM embeddings
    embedding_dict = {}
    # Iterate through the seqs
    for i in range(len(sequences)):
        sequence = sequences[i]
        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # remove extra dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Convert embeddings to numpy array (if needed)
            embedding = embedding.cpu().numpy()

        # Average (if necessary)
        if average:
            embedding = embedding.mean(0)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end):
            with open(savepath, 'ab+') as f:
                d = {sequence: embedding}
                pickle.dump(d, f)

        # Print update (if necessary)
        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict

def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    # Correct save path to pickle if necessary
    if savepath is not None:
        if savepath[-4::] != '.pkl': savepath += '.pkl'

    # If no max length was passed, just set it to the maximum in the dataset
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len+2  # +2 for BOS, EOS

    # the ProtT5 tokenizer requires that there are spaces between residues
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    # Store embeddings here
    embedding_dict = {}

    for i in range(0, len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]  # get current sequence
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)  # shouldn't have to pad because batch size is 1
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Pass through the model with no gradient to get embeddings
            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Process the embedding
        seq_length = len(seq)  # length of the sequence after removing spaces
        embedding = embedding_repr.last_hidden_state.squeeze(0)  # remove batch dimension
        embedding = embedding[0:-1]  # remove EOS token (there is no BOS token)
        embedding = embedding.cpu().numpy()  # put on CPU and numpy
        embedding_log = f"\tembedding shape: {embedding.shape}"
        # MAKE SURE the embedding lengths are right with an assert. We expect embedding dimension 1024, and sequence length to match real sequence length
        assert embedding.shape[1] == 1024
        assert embedding.shape[0] == seq_length

        # Average (if necessary)
        if average:
            dim_before = embedding.shape
            embedding = embedding.mean(0)
            embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

        # Add the embedding to the dictionary
        embedding_dict[seq] = embedding

        # Save individual embedding (if necessary)
        if not(savepath is None) and not(save_at_end):
            with open(savepath, 'ab+') as f:
                d = {seq: embedding}
                pickle.dump(d, f)

        if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    # Dump all at once at the end (if necessary)
    if not(savepath is None):
        # If saving for the first time, just dump it
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        # If we've been saving all along and made it here without crashing, correct the pickle file so it can be loaded nicely
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
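A minimal sketch of the intended call pattern (the sequences and save path are hypothetical; the weights are pulled from the facebook/esm2_t33_650M_UR50D checkpoint on the Hugging Face Hub on first use):

from fuson_plm.utils.embedding import load_esm2_type, get_esm_embeddings

sequences = ['MKTAYIAKQR', 'GGSGGSGGSG']  # hypothetical input sequences

model, tokenizer, device = load_esm2_type('esm2_t33_650M_UR50D')
embedding_dict = get_esm_embeddings(model, tokenizer, sequences, device,
                                    average=True,                  # mean-pool over residues
                                    savepath='esm2_embeddings.pkl',
                                    save_at_end=True)
print(embedding_dict['MKTAYIAKQR'].shape)  # (1280,) for the 650M ESM-2 model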
fuson_plm/utils/logging.py
ADDED
@@ -0,0 +1,116 @@
from datetime import datetime
from contextlib import contextmanager
import sys
import pytz
import os

class CustomParams:
    """
    Class for custom parameters where dictionary elements can be accessed as attributes.
    """
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def print_config(self, indent=''):
        for attr, value in self.__dict__.items():
            print(f"{indent}{attr}: {value}")

def log_update(text: str):
    """
    Logs input text to an output file.

    Args:
        text (str): the text to be logged
    """
    print(text)  # print the text (stdout may be redirected to a log file)
    sys.stdout.flush()  # flush to automatically update the output file

@contextmanager
def open_logfile(log_path, mode='w'):
    """
    Open a log file for real-time logging of the most important updates.
    """
    log_file = open(log_path, mode)  # open
    original_stdout = sys.stdout  # save original stdout
    sys.stdout = log_file  # redirect stdout to log_file
    try:
        yield log_file
    finally:
        sys.stdout = original_stdout
        log_file.close()

@contextmanager
def open_errfile(log_path, mode='w'):
    """
    Redirects stderr (error messages) to a separate log file.
    """
    log_file = open(log_path, mode)  # open the error log file for writing
    original_stderr = sys.stderr  # save original stderr
    sys.stderr = log_file  # redirect stderr to log_file
    try:
        yield log_file
    finally:
        sys.stderr = original_stderr  # restore original stderr
        log_file.close()  # close the error log file

def print_configpy(module):
    """
    Prints all the configurations in a config.py file.
    """
    log_update("All configurations:")
    # Iterate over attributes
    for attribute in dir(module):
        # Filter out built-in attributes and methods
        if not attribute.startswith("__"):
            value = getattr(module, attribute)
            log_update(f"\t{attribute}: {value}")

def get_local_time(timezone_str='US/Eastern'):
    """
    Get the current time in the specified timezone.

    Args:
        timezone_str (str): The timezone to retrieve time for. Defaults to 'US/Eastern'.

    Returns:
        str: The formatted current time in the specified timezone.
    """
    try:
        timezone = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"

    current_datetime = datetime.now(pytz.utc).astimezone(timezone)
    return current_datetime.strftime('%m-%d-%Y-%H:%M:%S')

def get_local_date_yr(timezone_str='US/Eastern'):
    """
    Get the current date in the specified timezone.

    Args:
        timezone_str (str): The timezone to retrieve the date for. Defaults to 'US/Eastern'.

    Returns:
        str: The formatted current date in the specified timezone.
    """
    try:
        timezone = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone_str}"

    current_datetime = datetime.now(pytz.utc).astimezone(timezone)
    return current_datetime.strftime('%m_%d_%Y')

def find_fuson_plm_directory():
    """
    Constructs a path backwards to the fuson_plm directory so we don't have to use absolute paths (helps for docker containers).
    """
    current_dir = os.path.abspath(os.getcwd())

    while True:
        if 'fuson_plm' in os.listdir(current_dir):
            return os.path.join(current_dir, 'fuson_plm')
        parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
        if parent_dir == current_dir:  # If we've reached the root directory
            raise FileNotFoundError("fuson_plm directory not found.")
        current_dir = parent_dir
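These pieces compose naturally: CustomParams holds a run configuration, and everything printed inside an open_logfile block lands in the log. A short illustrative sketch (the file name and settings are hypothetical):

from fuson_plm.utils.logging import CustomParams, open_logfile, log_update, get_local_time

config = CustomParams(batch_size=8, learning_rate=3e-4)  # hypothetical settings

with open_logfile('train_log.txt'):
    log_update(f"Run started at {get_local_time()}")
    config.print_config(indent='\t')  # print_config output is captured by the log too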
fuson_plm/utils/splitting.py
ADDED
@@ -0,0 +1,206 @@
import pandas as pd
from sklearn.model_selection import train_test_split
from fuson_plm.utils.logging import log_update

def split_clusters_train_test(X, y, benchmark_cluster_reps=[], random_state=1, test_size=0.20):
    # split with random state fixed for reproducible results
    log_update(f"\tPerforming split: all clusters -> train clusters ({round(1-test_size,3)}) and test clusters ({test_size})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train)==len(set(X_train))
    assert len(X_test)==len(set(X_test))

    return {
        'X_train': X_train,
        'X_test': X_test
    }

def split_clusters_train_val_test(X, y, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    # split with random states fixed for reproducible results
    log_update(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=random_state_1)
    log_update(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
    X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=random_state_2)

    # add benchmark representatives back to X_test
    log_update(f"\tManually adding {len(benchmark_cluster_reps)} clusters containing benchmark seqs into X_test")
    X_test += benchmark_cluster_reps

    # assert no duplicates within the train, val, or test sets (there shouldn't be, if the input data was clean)
    assert len(X_train)==len(set(X_train))
    assert len(X_val)==len(set(X_val))
    assert len(X_test)==len(set(X_test))

    return {
        'X_train': X_train,
        'X_val': X_val,
        'X_test': X_test
    }

def split_clusters(cluster_representatives: list, val_set=True, benchmark_cluster_reps=[], random_state_1=1, random_state_2=1, test_size_1=0.20, test_size_2=0.50):
    """
    Cluster-splitting method amenable to either train-test or train-val-test.
    For train-val-test, there are two splits.
    """
    log_update("\nPerforming splits...")
    # Approx. 80/10/10 split
    X = [x for x in cluster_representatives if not(x in benchmark_cluster_reps)]  # X, for splitting, does NOT include benchmark reps. We'll add these clusters to test.
    y = [0]*len(X)  # y is a dummy array here; there are no values.

    split_dict = None
    if val_set:
        split_dict = split_clusters_train_val_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                                   random_state_1=random_state_1, random_state_2=random_state_2,
                                                   test_size_1=test_size_1, test_size_2=test_size_2)
    else:
        split_dict = split_clusters_train_test(X, y, benchmark_cluster_reps=benchmark_cluster_reps,
                                               random_state=random_state_1,
                                               test_size=test_size_1)

    return split_dict

def check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=None):
    """
    Args:
        train_clusters (pd.DataFrame): cluster assignments for the train set
        val_clusters (pd.DataFrame): cluster assignments for the validation set (optional - can pass None if there is no validation set)
        test_clusters (pd.DataFrame): cluster assignments for the test set
    """
    # Make grouped versions of these DataFrames for size analysis
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})

    # Calculate stats - clusters
    n_train_clusters = len(train_clustersgb)
    n_val_clusters, n_test_clusters = 0, 0
    if val_clusters is not None:
        n_val_clusters = len(val_clustersgb)
    if test_clusters is not None:
        n_test_clusters = len(test_clustersgb)
    n_clusters = n_train_clusters + n_val_clusters + n_test_clusters

    assert len(train_clusters['representative seq_id'].unique()) == len(train_clustersgb)
    if val_clusters is not None:
        assert len(val_clusters['representative seq_id'].unique()) == len(val_clustersgb)
    if test_clusters is not None:
        assert len(test_clusters['representative seq_id'].unique()) == len(test_clustersgb)

    train_cluster_pcnt = round(100*n_train_clusters/n_clusters, 2)
    if val_clusters is not None:
        val_cluster_pcnt = round(100*n_val_clusters/n_clusters, 2)
    if test_clusters is not None:
        test_cluster_pcnt = round(100*n_test_clusters/n_clusters, 2)

    # Calculate stats - proteins
    n_train_proteins = len(train_clusters)
    n_val_proteins, n_test_proteins = 0, 0
    if val_clusters is not None:
        n_val_proteins = len(val_clusters)
    if test_clusters is not None:
        n_test_proteins = len(test_clusters)
    n_proteins = n_train_proteins + n_val_proteins + n_test_proteins

    assert len(train_clusters) == sum(train_clustersgb['member count'])
    if val_clusters is not None:
        assert len(val_clusters) == sum(val_clustersgb['member count'])
    if test_clusters is not None:
        assert len(test_clusters) == sum(test_clustersgb['member count'])

    train_protein_pcnt = round(100*n_train_proteins/n_proteins, 2)
    if val_clusters is not None:
        val_protein_pcnt = round(100*n_val_proteins/n_proteins, 2)
    if test_clusters is not None:
        test_protein_pcnt = round(100*n_test_proteins/n_proteins, 2)

    # Print results
    log_update("\nCluster breakdown...")
    log_update(f"Total clusters = {n_clusters}, total proteins = {n_proteins}")
    log_update(f"\tTrain set:\n\t\tTotal Clusters = {len(train_clustersgb)} ({train_cluster_pcnt}%)\n\t\tTotal Proteins = {len(train_clusters)} ({train_protein_pcnt}%)")
    if val_clusters is not None:
        log_update(f"\tVal set:\n\t\tTotal Clusters = {len(val_clustersgb)} ({val_cluster_pcnt}%)\n\t\tTotal Proteins = {len(val_clusters)} ({val_protein_pcnt}%)")
    if test_clusters is not None:
        log_update(f"\tTest set:\n\t\tTotal Clusters = {len(test_clustersgb)} ({test_cluster_pcnt}%)\n\t\tTotal Proteins = {len(test_clusters)} ({test_protein_pcnt}%)")

    # Check for overlap in both sequence ID and actual sequence
    train_protein_ids = set(train_clusters['member seq_id'])
    train_protein_seqs = set(train_clusters['member seq'])
    if val_clusters is not None:
        val_protein_ids = set(val_clusters['member seq_id'])
        val_protein_seqs = set(val_clusters['member seq'])
    if test_clusters is not None:
        test_protein_ids = set(test_clusters['member seq_id'])
        test_protein_seqs = set(test_clusters['member seq'])

    # Print results
    log_update("\nChecking for overlap...")
    if (val_clusters is not None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}\n\t\tVal-Test Overlap: {len(val_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}\n\t\tVal-Test Overlap: {len(val_protein_seqs.intersection(test_protein_seqs))}")
    if (val_clusters is not None) and (test_clusters is None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Val Overlap: {len(train_protein_ids.intersection(val_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Val Overlap: {len(train_protein_seqs.intersection(val_protein_seqs))}")
    if (val_clusters is None) and (test_clusters is not None):
        log_update(f"\tSequence IDs...\n\t\tTrain-Test Overlap: {len(train_protein_ids.intersection(test_protein_ids))}")
        log_update(f"\tSequences...\n\t\tTrain-Test Overlap: {len(train_protein_seqs.intersection(test_protein_seqs))}")

    # Assert no sequence overlap
    if val_clusters is not None:
        assert len(train_protein_seqs.intersection(val_protein_seqs))==0
    if test_clusters is not None:
        assert len(train_protein_seqs.intersection(test_protein_seqs))==0
    if (val_clusters is not None) and (test_clusters is not None):
        assert len(val_protein_seqs.intersection(test_protein_seqs))==0

    # Finally, check that benchmark sequences appear only in test - if there are benchmark sequences
    if not(benchmark_sequences is None):
        bench_in_train = len(train_clusters.loc[train_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        bench_in_val, bench_in_test = 0, 0
        if val_clusters is not None:
            bench_in_val = len(val_clusters.loc[val_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())
        if test_clusters is not None:
            bench_in_test = len(test_clusters.loc[test_clusters['member seq'].isin(benchmark_sequences)]['member seq'].unique())

        # Assert this
        log_update("\nChecking for benchmark sequence presence in test, and absence from train and val...")
        log_update(f"\tTotal benchmark sequences: {len(benchmark_sequences)}")
        log_update(f"\tBenchmark sequences in train: {bench_in_train}")
        if val_clusters is not None:
            log_update(f"\tBenchmark sequences in val: {bench_in_val}")
        if test_clusters is not None:
            log_update(f"\tBenchmark sequences in test: {bench_in_test}")
        assert bench_in_train == bench_in_val == 0
        assert bench_in_test == len(benchmark_sequences)

def check_class_distributions(train_df, val_df, test_df, class_col='class'):
    """
    Checks class distributions within train, val, and test sets.
    Expects input dataframes to have 'sequence' column and 'class' column.
    """
    train_vc = pd.DataFrame(train_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'train_count'})
    train_vc['train_pct'] = (train_vc['train_count'] / train_vc['train_count'].sum()).round(3)*100
    if val_df is not None:
        val_vc = pd.DataFrame(val_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'val_count'})
        val_vc['val_pct'] = (val_vc['val_count'] / val_vc['val_count'].sum()).round(3)*100
    test_vc = pd.DataFrame(test_df[class_col].value_counts()).reset_index().rename(columns={'index':class_col, class_col:'test_count'})
    test_vc['test_pct'] = (test_vc['test_count'] / test_vc['test_count'].sum()).round(3)*100
    # concatenate so the distributions can be compared side by side
    if val_df is not None:
        compare = pd.concat([train_vc, val_vc, test_vc], axis=1)
        compare['train-val diff'] = (compare['train_pct'] - compare['val_pct']).apply(lambda x: abs(x))
        compare['val-test diff'] = (compare['val_pct'] - compare['test_pct']).apply(lambda x: abs(x))
    else:
        compare = pd.concat([train_vc, test_vc], axis=1)
        compare['train-test diff'] = (compare['train_pct'] - compare['test_pct']).apply(lambda x: abs(x))

    compare_str = compare.to_string(index=False)
    compare_str = "\t" + compare_str.replace("\n", "\n\t")
    log_update(f"\nClass distribution:\n{compare_str}")
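A minimal sketch of the splitting entry point with hypothetical cluster IDs; clusters holding benchmark sequences are excluded from the random split and appended to the test set:

from fuson_plm.utils.splitting import split_clusters

reps = [f'rep{i}' for i in range(100)]  # hypothetical cluster representatives
splits = split_clusters(reps, val_set=True, benchmark_cluster_reps=['rep3'],
                        test_size_1=0.20, test_size_2=0.50)  # approx. 80/10/10
print(len(splits['X_train']), len(splits['X_val']), len(splits['X_test']))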
fuson_plm/utils/visualizing.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import matplotlib.font_manager as fm
|
| 3 |
+
from matplotlib.font_manager import FontProperties
|
| 4 |
+
from scipy.stats import entropy
|
| 5 |
+
from sklearn.manifold import TSNE
|
| 6 |
+
import pickle
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import os
|
| 9 |
+
import numpy as np
|
| 10 |
+
from fuson_plm.utils.logging import log_update, find_fuson_plm_directory
|
| 11 |
+
|
| 12 |
+
def set_font():
|
| 13 |
+
# Load and set the font
|
| 14 |
+
fuson_plm_dir = find_fuson_plm_directory()
|
| 15 |
+
|
| 16 |
+
# Paths for regular, bold, italic fonts
|
| 17 |
+
regular_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Regular.ttf')
|
| 18 |
+
bold_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Bold.ttf')
|
| 19 |
+
italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-Italic.ttf')
|
| 20 |
+
bold_italic_font_path = os.path.join(fuson_plm_dir, 'ubuntu_font', 'Ubuntu-BoldItalic.ttf')
|
| 21 |
+
|
| 22 |
+
# Load the font properties
|
| 23 |
+
regular_font = FontProperties(fname=regular_font_path)
|
| 24 |
+
bold_font = FontProperties(fname=bold_font_path)
|
| 25 |
+
italic_font = FontProperties(fname=italic_font_path)
|
| 26 |
+
bold_italic_font = FontProperties(fname=bold_italic_font_path)
|
| 27 |
+
|
| 28 |
+
# Add the fonts to the font manager
|
| 29 |
+
fm.fontManager.addfont(regular_font_path)
|
| 30 |
+
fm.fontManager.addfont(bold_font_path)
|
| 31 |
+
fm.fontManager.addfont(italic_font_path)
|
| 32 |
+
fm.fontManager.addfont(bold_italic_font_path)
|
| 33 |
+
|
| 34 |
+
# Set the font family globally to Ubuntu
|
| 35 |
+
plt.rcParams['font.family'] = regular_font.get_name()
|
| 36 |
+
|
| 37 |
+
# Set the fonts for math text (like for labels) to use the loaded Ubuntu fonts
|
| 38 |
+
plt.rcParams['mathtext.fontset'] = 'custom'
|
| 39 |
+
plt.rcParams['mathtext.rm'] = regular_font.get_name()
|
| 40 |
+
plt.rcParams['mathtext.it'] = f'{italic_font.get_name()}'
|
| 41 |
+
plt.rcParams['mathtext.bf'] = f'{bold_font.get_name()}'
|
| 42 |
+
|
| 43 |
+
global default_color_map
|
| 44 |
+
default_color_map = {
|
| 45 |
+
'train': '#0072B2',
|
| 46 |
+
'val': '#009E73',
|
| 47 |
+
'test': '#E69F00'
|
| 48 |
+
}
|
| 49 |
+
|
def get_avg_embeddings_for_tsne(train_sequences=None, val_sequences=None, test_sequences=None, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
    """Load precomputed average embeddings and split them into train, val, and test groups."""
    if train_sequences is None: train_sequences = []
    if val_sequences is None: val_sequences = []
    if test_sequences is None: test_sequences = []

    try:
        with open(embedding_path, 'rb') as f:
            embeddings = pickle.load(f)

        train_embeddings = [v for k, v in embeddings.items() if k in train_sequences]
        val_embeddings = [v for k, v in embeddings.items() if k in val_sequences]
        test_embeddings = [v for k, v in embeddings.items() if k in test_sequences]

        return train_embeddings, val_embeddings, test_embeddings
    except (OSError, pickle.UnpicklingError) as e:
        log_update(f"Could not open embeddings at {embedding_path}: {e}")
        return [], [], []

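# Sketch of the expected pickle layout (an assumption inferred from the reads above):
# a dict mapping each sequence string to its average per-protein embedding, e.g.
#     {'MAGT...': np.array([...]), 'MKLV...': np.array([...])}
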
def calculate_aa_composition(sequences):
    """Compute the relative frequency of each amino acid across a list of sequences."""
    composition = {}
    total_length = sum([len(seq) for seq in sequences])

    for seq in sequences:
        for aa in seq:
            if aa in composition:
                composition[aa] += 1
            else:
                composition[aa] = 1

    # Convert counts to relative frequency
    for aa in composition:
        composition[aa] /= total_length

    return composition

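# Worked example (illustrative): frequencies are relative to the pooled length of all
# sequences, so they sum to 1.0.
#     calculate_aa_composition(['AAG', 'G'])  # -> {'A': 0.5, 'G': 0.5}
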
def calculate_shannon_entropy(sequence):
    """
    Calculate the Shannon entropy for a given sequence.

    Args:
        sequence (str): A sequence of characters (e.g., amino acids or nucleotides).

    Returns:
        float: Shannon entropy value in bits.
    """
    bases = set(sequence)
    counts = [sequence.count(base) for base in bases]
    return entropy(counts, base=2)

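# Worked example (illustrative): a single-residue sequence has zero entropy, and n
# equally frequent residues give log2(n) bits.
#     calculate_shannon_entropy('AAAA')  # -> 0.0
#     calculate_shannon_entropy('ACDE')  # -> 2.0
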
def visualize_splits_hist(train_lengths=None, val_lengths=None, test_lengths=None, colormap=None, savepath='splits/length_distributions.png', axes=None):
    """
    Plot sequence-length histograms. Works for train/val/test, train/val, or train/test.
    """
    set_font()
    if colormap is None: colormap = default_color_map

    log_update('\nMaking histogram of length distributions')

    # Determine subplot indices: drop the val and/or test panel if those splits are absent
    val_plot_index, test_plot_index, total_plots = 1, 2, 3
    if val_lengths is None:
        val_plot_index = None
        test_plot_index -= 1
        total_plots -= 1
    if test_lengths is None:
        test_plot_index = None
        total_plots -= 1

    # Create a figure with one column per split
    fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))

    # Plot on the standalone axes, and also on the caller's axes if provided
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    # Shared axis labels
    xlabel, ylabel = 'Sequence Length (AA)', 'Frequency'

    for cur_axes in axes_list:
        # Train histogram
        cur_axes[0].hist(train_lengths, bins=20, edgecolor='k', color=colormap['train'])
        cur_axes[0].set_xlabel(xlabel)
        cur_axes[0].set_ylabel(ylabel)
        cur_axes[0].set_title(f'Train Set Length Distribution (n={len(train_lengths)})')
        cur_axes[0].grid(True)
        cur_axes[0].set_axisbelow(True)

        # Validation histogram
        if val_plot_index is not None:
            cur_axes[val_plot_index].hist(val_lengths, bins=20, edgecolor='k', color=colormap['val'])
            cur_axes[val_plot_index].set_xlabel(xlabel)
            cur_axes[val_plot_index].set_ylabel(ylabel)
            cur_axes[val_plot_index].set_title(f'Validation Set Length Distribution (n={len(val_lengths)})')
            cur_axes[val_plot_index].grid(True)
            cur_axes[val_plot_index].set_axisbelow(True)

        # Test histogram
        if test_plot_index is not None:
            cur_axes[test_plot_index].hist(test_lengths, bins=20, edgecolor='k', color=colormap['test'])
            cur_axes[test_plot_index].set_xlabel(xlabel)
            cur_axes[test_plot_index].set_ylabel(ylabel)
            cur_axes[test_plot_index].set_title(f'Test Set Length Distribution (n={len(test_lengths)})')
            cur_axes[test_plot_index].grid(True)
            cur_axes[test_plot_index].set_axisbelow(True)

    # Adjust layout and save
    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

def visualize_splits_scatter(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, colormap=None, savepath='splits/scatterplot.png', axes=None):
    """Scatterplot of the distribution of cluster sizes across the train, val, and test sets."""
    set_font()
    if colormap is None: colormap = default_color_map

    # Create a standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Plot on the standalone axes, and also on the caller's axes if provided
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")
    # Group each split by cluster representative to count members per cluster
    train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if val_clusters is not None:
        val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    if test_clusters is not None:
        test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
        # Isolate benchmark-containing clusters so their contribution can be plotted separately
        total_test_proteins = sum(test_clustersgb['member count'])
        if benchmark_cluster_reps is not None:
            test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
            benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
            test_clustersgb = test_clustersgb.loc[~test_clustersgb['benchmark cluster']].reset_index(drop=True)

    # Convert to value counts: how many clusters exist at each cluster size
    train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
    if val_clusters is not None:
        val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
    if test_clusters is not None:
        test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
        if benchmark_cluster_reps is not None:
            benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})

    # Get the percentage of each dataset made up by each cluster size
    train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)']*train_clustersgb['n_clusters']  # proteins per cluster * n clusters = n proteins
    train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins']/sum(train_clustersgb['n_proteins'])
    if val_clusters is not None:
        val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)']*val_clustersgb['n_clusters']
        val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins']/sum(val_clustersgb['n_proteins'])
    if test_clusters is not None:
        test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)']*test_clustersgb['n_clusters']
        test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins']/total_test_proteins
        if benchmark_cluster_reps is not None:
            benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)']*benchmark_clustersgb['n_clusters']
            benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins']/total_test_proteins

    # Mark the benchmark clusters specially because they cannot be reallocated between splits
    for ax in axes_list:
        ax.plot(train_clustersgb['cluster size (n_members)'], train_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['train'], label='train')
        if val_clusters is not None:
            ax.plot(val_clustersgb['cluster size (n_members)'], val_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['val'], label='val')
        if test_clusters is not None:
            ax.plot(test_clustersgb['cluster size (n_members)'], test_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['test'], label='test')
        if benchmark_cluster_reps is not None:
            ax.plot(benchmark_clustersgb['cluster size (n_members)'], benchmark_clustersgb['percent_proteins'],
                    marker='o',
                    linestyle='None',
                    markerfacecolor=colormap['test'],  # fill same as test
                    markeredgecolor='black',           # outline black
                    markeredgewidth=1.5,
                    label='benchmark'
                    )
        ax.set(ylabel='Percentage of Proteins in Dataset', xlabel='Cluster Size (n_members)')
        ax.legend()

    # Save the figure
    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

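# Worked example (illustrative): if a split has 3 clusters of size 2 plus 4 singletons,
# cluster size 2 contributes (3*2)/(3*2 + 4) = 60% of that split's proteins, and size 1
# contributes 40%.
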
def visualize_splits_tsne(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='splits/tsne_plot.png', axes=None):
    """
    Generate a t-SNE plot of embeddings for the train, val, and test sets.
    """
    set_font()
    if colormap is None: colormap = default_color_map

    log_update('\nMaking t-SNE plot of train, val, and test embeddings')
    # Create a standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Plot on the standalone axes, and also on the caller's axes if provided
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    # Load the embeddings for each split and combine them into one array.
    # get_avg_embeddings_for_tsne returns empty lists for absent splits, so check
    # emptiness rather than None before concatenating.
    train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences=train_sequences,
                                                                                    val_sequences=val_sequences,
                                                                                    test_sequences=test_sequences, embedding_path=embedding_path)
    if val_embeddings and test_embeddings:
        embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)
    elif val_embeddings:
        embeddings = np.concatenate([train_embeddings, val_embeddings])
        labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings)
    elif test_embeddings:
        embeddings = np.concatenate([train_embeddings, test_embeddings])
        labels = ['train'] * len(train_embeddings) + ['test'] * len(test_embeddings)
    else:
        embeddings = np.array(train_embeddings)
        labels = ['train'] * len(train_embeddings)

    # Perform t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Convert t-SNE results into a DataFrame
    tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
    tsne_df['label'] = labels

    for ax in axes_list:
        # Scatter plot for each set
        for label, color in colormap.items():
            subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
            ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)

        ax.set_title(f't-SNE of {esm_type} Embeddings')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.legend()
        ax.grid(True)

    # Save the figure
    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

def visualize_splits_shannon_entropy(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/shannon_entropy_plot.png', axes=None):
    """
    Generate Shannon entropy histograms for the train, validation, and test sets.
    """
    set_font()
    # Determine subplot indices: drop the val and/or test panel if those splits are absent
    val_plot_index, test_plot_index, total_plots = 1, 2, 3
    if val_sequences is None:
        val_plot_index = None
        test_plot_index -= 1
        total_plots -= 1
    if test_sequences is None:
        test_plot_index = None
        total_plots -= 1

    if colormap is None: colormap = default_color_map
    # Create a figure with one column per split
    fig_individual, axes_individual = plt.subplots(1, total_plots, figsize=(6*total_plots, 6))

    # Plot on the standalone axes, and also on the caller's axes if provided
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update('\nMaking histogram of Shannon Entropy distributions')
    train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
    if val_plot_index is not None:
        val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
    if test_plot_index is not None:
        test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]

    for ax in axes_list:
        ax[0].hist(train_entropy, bins=20, edgecolor='k', color=colormap['train'])
        ax[0].set_title(f'Train Set (n={len(train_entropy)})')
        ax[0].set_xlabel('Shannon Entropy')
        ax[0].set_ylabel('Frequency')
        ax[0].grid(True)
        ax[0].set_axisbelow(True)

        if val_plot_index is not None:
            ax[val_plot_index].hist(val_entropy, bins=20, edgecolor='k', color=colormap['val'])
            ax[val_plot_index].set_title(f'Validation Set (n={len(val_entropy)})')
            ax[val_plot_index].set_xlabel('Shannon Entropy')
            ax[val_plot_index].grid(True)
            ax[val_plot_index].set_axisbelow(True)

        if test_plot_index is not None:
            ax[test_plot_index].hist(test_entropy, bins=20, edgecolor='k', color=colormap['test'])
            ax[test_plot_index].set_title(f'Test Set (n={len(test_entropy)})')
            ax[test_plot_index].set_xlabel('Shannon Entropy')
            ax[test_plot_index].grid(True)
            ax[test_plot_index].set_axisbelow(True)

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

def visualize_splits_aa_composition(train_sequences=None, val_sequences=None, test_sequences=None, colormap=None, savepath='splits/aa_comp.png', axes=None):
    """Bar plot of amino acid composition across the train, val, and test sets."""
    set_font()
    if colormap is None: colormap = default_color_map

    # Create a standalone figure
    fig_individual, axes_individual = plt.subplots(figsize=(18, 6))

    # Plot on the standalone axes, and also on the caller's axes if provided
    axes_list = [axes_individual] if axes is None else [axes_individual, axes]

    log_update('\nMaking bar plot of AA composition across each set')
    train_comp = calculate_aa_composition(train_sequences)
    if val_sequences is not None:
        val_comp = calculate_aa_composition(val_sequences)
    if test_sequences is not None:
        test_comp = calculate_aa_composition(test_sequences)

    # Create a DataFrame with one column per split
    if val_sequences is not None and test_sequences is not None:
        comp_df = pd.DataFrame([train_comp, val_comp, test_comp], index=['train', 'val', 'test']).T
    if val_sequences is not None and test_sequences is None:
        comp_df = pd.DataFrame([train_comp, val_comp], index=['train', 'val']).T
    if val_sequences is None and test_sequences is not None:
        comp_df = pd.DataFrame([train_comp, test_comp], index=['train', 'test']).T
    colors = [colormap[col] for col in comp_df.columns]

    # Plotting
    for ax in axes_list:
        comp_df.plot(kind='bar', color=colors, ax=ax)
        ax.set_title('Amino Acid Composition Across Datasets')
        ax.set_xlabel('Amino Acid')
        ax.set_ylabel('Relative Frequency')

    fig_individual.set_tight_layout(True)
    fig_individual.savefig(savepath)
    log_update(f"\tSaved figure to {savepath}")

### Outer methods for visualizing splits
def visualize_splits(train_clusters=None, val_clusters=None, test_clusters=None, benchmark_cluster_reps=None, train_color='#0072B2', val_color='#009E73', test_color='#E69F00', esm_embeddings_path=None, onehot_embeddings_path=None):
    """Dispatch to the appropriate visualization routine based on which splits are provided."""
    colormap = {
        'train': train_color,
        'val': val_color,
        'test': test_color
    }
    valid_entry = False
    if train_clusters is not None and val_clusters is not None and test_clusters is not None:
        visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True
    if train_clusters is not None and val_clusters is None and test_clusters is not None:
        visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True
    if train_clusters is not None and val_clusters is not None and test_clusters is None:
        visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=benchmark_cluster_reps, colormap=colormap, esm_embeddings_path=esm_embeddings_path)
        valid_entry = True

    if not valid_entry: raise Exception("Must pass train clusters and at least one of val or test clusters")

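# Expected input schema (inferred from the column accesses in this module): each
# *_clusters DataFrame has one row per cluster member, with columns
# 'representative seq_id', 'member seq_id', and 'member seq'.
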
def visualize_train_val_test_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    val_clusters['member length'] = val_clusters['member seq'].str.len()
    test_clusters['member length'] = test_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    # Create a combined figure with 3 rows and 3 columns
    set_font()
    fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))

    # Make the individual visualization plots, writing each onto the combined figure as well
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[2, 2])
    else:
        # Leave the last subplot blank
        axs[2, 2].axis('off')

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

def visualize_train_test_splits(train_clusters, test_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    test_clusters['member length'] = test_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    test_lengths = test_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    test_sequences = test_clusters['member seq'].tolist()

    # Create a combined figure: 4 rows x 2 columns with a t-SNE plot, 3 x 2 otherwise
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        set_font()
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=None,
                              test_sequences=test_sequences,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
        axs[-1, 1].axis('off')
    else:
        set_font()
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    # Make the individual visualization plots, writing each onto the combined figure as well
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=None,
                          test_lengths=test_lengths,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=None,
                                     test_sequences=test_sequences,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=None,
                             test_clusters=test_clusters,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=None,
                                    test_sequences=test_sequences,
                                    colormap=colormap, axes=axs[2, 1])

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")

def visualize_train_val_splits(train_clusters, val_clusters, benchmark_cluster_reps=None, colormap=None, esm_embeddings_path=None, onehot_embeddings_path=None):
    if colormap is None: colormap = default_color_map
    # Add length column
    train_clusters['member length'] = train_clusters['member seq'].str.len()
    val_clusters['member length'] = val_clusters['member seq'].str.len()

    # Prepare lengths and seqs for plotting
    train_lengths = train_clusters['member length'].tolist()
    val_lengths = val_clusters['member length'].tolist()
    train_sequences = train_clusters['member seq'].tolist()
    val_sequences = val_clusters['member seq'].tolist()

    # Create a combined figure: 4 rows x 2 columns with a t-SNE plot, 3 x 2 otherwise
    if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
        set_font()
        fig_combined, axs = plt.subplots(4, 2, figsize=(18, 36))
        visualize_splits_tsne(train_sequences=train_sequences,
                              val_sequences=val_sequences,
                              test_sequences=None,
                              embedding_path=esm_embeddings_path,
                              colormap=colormap, axes=axs[3, 0])
        axs[-1, 1].axis('off')
    else:
        set_font()
        fig_combined, axs = plt.subplots(3, 2, figsize=(18, 18))

    # Make the individual visualization plots, writing each onto the combined figure as well
    visualize_splits_hist(train_lengths=train_lengths,
                          val_lengths=val_lengths,
                          test_lengths=None,
                          colormap=colormap, axes=axs[0])
    visualize_splits_shannon_entropy(train_sequences=train_sequences,
                                     val_sequences=val_sequences,
                                     test_sequences=None,
                                     colormap=colormap, axes=axs[1])
    visualize_splits_scatter(train_clusters=train_clusters,
                             val_clusters=val_clusters,
                             test_clusters=None,
                             benchmark_cluster_reps=benchmark_cluster_reps,
                             colormap=colormap, axes=axs[2, 0])
    visualize_splits_aa_composition(train_sequences=train_sequences,
                                    val_sequences=val_sequences,
                                    test_sequences=None,
                                    colormap=colormap, axes=axs[2, 1])

    plt.tight_layout()
    fig_combined.savefig('splits/combined_plot.png')
    log_update("\nSaved combined figure to splits/combined_plot.png")
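
# Minimal smoke test (a sketch; assumes the toy values below and that the `ubuntu_font/`
# and `splits/` directories exist relative to the package root):
# if __name__ == '__main__':
#     toy = pd.DataFrame({
#         'representative seq_id': ['r1', 'r1', 'r2'],
#         'member seq_id': ['m1', 'm2', 'm3'],
#         'member seq': ['MKTAYIAK', 'MKTAYIAR', 'GSHMLEDP'],
#     })
#     visualize_splits(train_clusters=toy, val_clusters=toy.copy(), test_clusters=toy.copy())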