Commit 8aec0bb
Parent(s): 1ee75b1

Started working on cell line one-hot encoding experiments
protac_degradation_predictor/protac_dataset.py
CHANGED
```diff
@@ -1,5 +1,7 @@
 from typing import Literal, List, Tuple, Optional, Dict
 from collections import defaultdict
+import random
+import logging
 
 from .data_utils import (
     get_fingerprint,
@@ -24,15 +26,16 @@ class PROTAC_Dataset(Dataset):
     def __init__(
         self,
         protac_df: pd.DataFrame,
-        protein2embedding: Dict,
-        cell2embedding: Dict,
-        smiles2fp: Dict,
+        protein2embedding: Dict[str, np.ndarray],
+        cell2embedding: Dict[str, np.ndarray],
+        smiles2fp: Dict[str, np.ndarray],
         use_smote: bool = False,
         oversampler: Optional[SMOTE | ADASYN] = None,
         active_label: str = 'Active',
         disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
         scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
         use_single_scaler: Optional[bool] = None,
+        shuffle_embedding_prob: float = 0.0,
     ):
         """ Initialize the PROTAC dataset
 
@@ -47,6 +50,7 @@ class PROTAC_Dataset(Dataset):
             disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
             scaler (StandardScaler | dict): The scaler to use for the embeddings
             use_single_scaler (bool): Whether to use a single scaler for all features
+            shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling.
         """
         # Filter out examples with NaN in active_label column
         self.data = protac_df  # [~protac_df[active_label].isna()]
@@ -84,6 +88,22 @@ class PROTAC_Dataset(Dataset):
         self.oversampler = oversampler
         if self.use_smote:
             self.apply_smote()
+
+        if shuffle_embedding_prob > 0.0:
+            self.shuffle_embedding_prob = shuffle_embedding_prob
+            # Set random seed
+            random.seed(42)
+            if self.protein_emb_dim != self.cell_emb_dim:
+                logging.warning('Protein and cell embeddings have different dimensions. Shuffling will be on POI and E3 embeddings only.')
+
+    def get_smiles_emb_dim(self):
+        return self.smiles_emb_dim
+
+    def get_protein_emb_dim(self):
+        return self.protein_emb_dim
+
+    def get_cell_emb_dim(self):
+        return self.cell_emb_dim
 
     def apply_smote(self):
         # Prepare the dataset for SMOTE
@@ -269,6 +289,17 @@ class PROTAC_Dataset(Dataset):
         else:
             cell_emb = self.data['Cell Line Identifier'].iloc[idx]
 
+        # Shuffle the embeddings if the probability is met
+        if random.random() < self.shuffle_embedding_prob:
+            if self.protein_emb_dim == self.cell_emb_dim:
+                # Randomly shuffle the embeddings for POI, cell, and E3
+                embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
+                np.random.shuffle(embeddings)
+                poi_emb, e3_emb, cell_emb = embeddings
+            else:
+                # Swap POI and E3 embeddings only, because of different dimensions
+                poi_emb, e3_emb = e3_emb, poi_emb
+
         elem = {
             'smiles_emb': smiles_emb,
             'poi_emb': poi_emb,
@@ -293,6 +324,7 @@ def get_datasets(
     scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
     use_single_scaler: Optional[bool] = None,
     apply_scaling: bool = False,
+    shuffle_embedding_prob: float = 0.0,
 ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
     """ Get the datasets for training the PROTAC model.
 
@@ -323,6 +355,7 @@ def get_datasets(
         disabled_embeddings=disabled_embeddings,
         scaler=scaler,
         use_single_scaler=use_single_scaler,
+        shuffle_embedding_prob=shuffle_embedding_prob,
     )
     val_ds = PROTAC_Dataset(
         val_df,
```
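Reviewer note: `shuffle_embedding_prob` implements a "barcode" sanity check. If validation performance is unchanged when the POI, E3, and cell embeddings of a sample are randomly permuted among the three roles, the model is probably treating them as arbitrary sample identifiers rather than as biological signal. One caveat as committed: `self.shuffle_embedding_prob` is only assigned when the probability is positive, while `__getitem__` always reads it, so constructing the dataset with the default `0.0` would likely raise `AttributeError`; assigning the attribute unconditionally in `__init__` would avoid this. A minimal standalone sketch of the row-permutation mechanics (toy arrays; names and dimensions are illustrative only):

```python
import random
import numpy as np

random.seed(42)

# Toy embeddings with a shared dimension, so the three-way shuffle applies.
poi_emb = np.array([1.0, 1.0, 1.0])
e3_emb = np.array([2.0, 2.0, 2.0])
cell_emb = np.array([3.0, 3.0, 3.0])

shuffle_embedding_prob = 1.0  # always shuffle, for demonstration only

if random.random() < shuffle_embedding_prob:
    # Stack the three vectors as rows and permute the rows in place.
    # np.random.shuffle only permutes along the first axis, so each
    # embedding vector stays intact but may land in a different role.
    embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
    np.random.shuffle(embeddings)
    poi_emb, e3_emb, cell_emb = embeddings

print(poi_emb, e3_emb, cell_emb)  # the same three vectors, roles permuted
```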
protac_degradation_predictor/pytorch_models.py
CHANGED
```diff
@@ -23,7 +23,8 @@ from torchmetrics import (
     MetricCollection,
 )
 from imblearn.over_sampling import SMOTE
-from sklearn.preprocessing import StandardScaler
+from sklearn.preprocessing import StandardScaler, OneHotEncoder
+from sklearn.feature_extraction.text import CountVectorizer
 
 
 class PROTAC_Predictor(nn.Module):
@@ -402,9 +403,9 @@ class PROTAC_Model(pl.LightningModule):
 
 # TODO: Use some sort of **kwargs to pass all the parameters to the model...
 def train_model(
-    protein2embedding: Dict,
-    cell2embedding: Dict,
-    smiles2fp: Dict,
+    protein2embedding: Dict[str, np.ndarray],
+    cell2embedding: Dict[str, np.ndarray],
+    smiles2fp: Dict[str, np.ndarray],
    train_df: pd.DataFrame,
     val_df: pd.DataFrame,
     test_df: Optional[pd.DataFrame] = None,
@@ -414,10 +415,6 @@ def train_model(
     dropout: float = 0.2,
     max_epochs: int = 50,
     use_batch_norm: bool = False,
-    smiles_emb_dim: int = config.fingerprint_size,
-    poi_emb_dim: int = config.protein_embedding_size,
-    e3_emb_dim: int = config.protein_embedding_size,
-    cell_emb_dim: int = config.cell_embedding_size,
     join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
     smote_k_neighbors:int = 5,
     use_smote: bool = True,
@@ -431,29 +428,61 @@ def train_model(
     checkpoint_model_name: str = 'protac',
     disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
     return_predictions: bool = False,
+    shuffle_embedding_prob: float = 0.0,
+    use_cells_one_hot: bool = False,
+    use_amino_acid_count: bool = False,
 ) -> tuple:
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
     Args:
-        protein2embedding (dict):
-        cell2embedding (dict):
-        smiles2fp (dict):
-        train_df (pd.DataFrame): The training
-        val_df (pd.DataFrame): The validation
-        test_df (pd.DataFrame): The test
-        hidden_dim (int): The hidden dimension of the model
-        batch_size (int): The batch size
-        learning_rate (float): The learning rate
-
-
-
-
-
-
+        protein2embedding (dict): A dictionary mapping protein identifiers to embeddings.
+        cell2embedding (dict): A dictionary mapping cell line identifiers to embeddings.
+        smiles2fp (dict): A dictionary mapping SMILES strings to fingerprints.
+        train_df (pd.DataFrame): The training dataframe.
+        val_df (pd.DataFrame): The validation dataframe.
+        test_df (Optional[pd.DataFrame]): The test dataframe.
+        hidden_dim (int): The hidden dimension of the model
+        batch_size (int): The batch size
+        learning_rate (float): The learning rate
+        dropout (float): The dropout rate
+        max_epochs (int): The maximum number of epochs
+        use_batch_norm (bool): Whether to use batch normalization
+        join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
+        smote_k_neighbors (int): The number of neighbors to use in SMOTE
+        use_smote (bool): Whether to use SMOTE
+        apply_scaling (bool): Whether to apply scaling to the embeddings
+        active_label (str): The name of the active label. Default: 'Active'
+        fast_dev_run (bool): Whether to run a fast development run (see PyTorch Lightning documentation)
+        use_logger (bool): Whether to use a logger
+        logger_save_dir (str): The directory to save the logs
+        logger_name (str): The name of the logger
+        enable_checkpointing (bool): Whether to enable checkpointing
+        checkpoint_model_name (str): The name of the model for checkpointing
+        disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
+        return_predictions (bool): Whether to return predictions on the validation and test sets
 
     Returns:
         tuple: The trained model, the trainer, and the metrics over the validation and test sets.
     """
+    if use_cells_one_hot:
+        # Get one-hot encoded embeddings for cell lines
+        onehotenc = OneHotEncoder(sparse_output=False)
+        cell_embeddings = onehotenc.fit_transform(
+            np.array(list(cell2embedding.keys()))
+        )
+        cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
+
+    if use_amino_acid_count:
+        # Get count vectorized embeddings for proteins
+        # NOTE: Check that the protein2embedding is a dictionary of strings
+        if not all(isinstance(k, str) for k in protein2embedding.keys()):
+            raise ValueError("All keys in `protein2embedding` must be strings.")
+        countvec = CountVectorizer(ngram_range=(1,1), analyzer='char')
+        protein_embeddings = countvec.fit_transform(
+            list(protein2embedding.keys())
+        )
+        protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
+
     train_ds, val_ds, test_ds = get_datasets(
         train_df,
         val_df,
@@ -465,7 +494,14 @@ def train_model(
         smote_k_neighbors=smote_k_neighbors,
         active_label=active_label,
         disabled_embeddings=disabled_embeddings,
+        shuffle_embedding_prob=shuffle_embedding_prob,
     )
+    # NOTE: The embeddings dimensions should already match in all sets
+    smiles_emb_dim = train_ds.get_smiles_emb_dim()
+    poi_emb_dim = train_ds.get_protein_emb_dim()
+    e3_emb_dim = train_ds.get_protein_emb_dim()
+    cell_emb_dim = train_ds.get_cell_emb_dim()
+
     loggers = [
         pl.loggers.TensorBoardLogger(
            save_dir=logger_save_dir,
```
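Reviewer note: the two new flags swap the learned embeddings for simple baselines: `use_cells_one_hot` encodes each cell line identifier as a one-hot vector, and `use_amino_acid_count` encodes each protein key (assumed to be an amino acid sequence) as per-character counts. Two details worth flagging: scikit-learn's `OneHotEncoder` expects a 2D array, so the committed `fit_transform(np.array(list(cell2embedding.keys())))` call would likely fail with `ValueError: Expected 2D array` without a `reshape(-1, 1)`, and `CountVectorizer.fit_transform` returns a sparse matrix, so the values stored in `protein2embedding` are sparse rows rather than `np.ndarray`s. A standalone sketch with toy data that handles both points (the dictionary contents below are placeholders):

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn.feature_extraction.text import CountVectorizer

# Toy stand-ins for the real embedding dictionaries (values are placeholders).
cell2embedding = {'HeLa': None, 'HEK293': None, 'MCF7': None}
protein2embedding = {'MKTAYIAK': None, 'MGSSHHHH': None}

# One-hot encode cell line identifiers. OneHotEncoder requires 2D input,
# hence the reshape into a single-feature column vector.
onehotenc = OneHotEncoder(sparse_output=False)
cell_embeddings = onehotenc.fit_transform(
    np.array(list(cell2embedding.keys())).reshape(-1, 1)
)
cell2embedding = dict(zip(cell2embedding.keys(), cell_embeddings))
print(cell2embedding['HeLa'])  # one column per distinct cell line, e.g. [0. 1. 0.]

# Count single characters (unigrams) in each protein key. CountVectorizer
# lowercases by default, which here only affects the feature names.
countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
protein_counts = countvec.fit_transform(list(protein2embedding.keys()))
protein2embedding = {
    # Densify each sparse row so downstream numpy code gets a plain vector.
    k: row.toarray().ravel()
    for k, row in zip(protein2embedding.keys(), protein_counts)
}
print(countvec.get_feature_names_out())  # the character alphabet seen in the keys
```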
src/run_experiments.py
CHANGED
```diff
@@ -238,10 +238,15 @@ def main(
     """ Train a PROTAC model using the given datasets and hyperparameters.
 
     Args:
-
-        n_trials (int): The number of hyperparameter
-        n_splits (int): The number of cross-validation splits.
+        active_col (str): The column containing the active/inactive information. Must be in the format 'Active (Dmax N, pDC50 M)'.
+        n_trials (int): The number of hyperparameter tuning trials to run.
         fast_dev_run (bool): Whether to run a fast development run.
+        test_split (float): The percentage of the active PROTACs to use as the test set.
+        cv_n_splits (int): The number of cross-validation splits to use.
+        max_epochs (int): The maximum number of epochs to train the model.
+        run_sklearn (bool): Whether to run sklearn models.
+        force_study (bool): Whether to force the creation of a new Optuna study.
+        experiments (str): The type of experiments to run.
     """
     pl.seed_everything(42)
 
```
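Reviewer note: the docstring now pins down the expected `active_col` format. Purely as a hypothetical illustration of that naming convention (the 0.6/6.0 thresholds and the `>=` direction are assumptions; the commit documents only the column-name format, not how the column is computed):

```python
import pandas as pd

# Hypothetical: derive a column matching the 'Active (Dmax N, pDC50 M)' format.
df = pd.DataFrame({'Dmax': [0.9, 0.3], 'pDC50': [7.1, 5.2]})
active_col = 'Active (Dmax 0.6, pDC50 6.0)'
df[active_col] = (df['Dmax'] >= 0.6) & (df['pDC50'] >= 6.0)
print(df[active_col].tolist())  # [True, False]
```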
src/run_experiments_cells_onehot.py
ADDED
```diff
@@ -0,0 +1,137 @@
+import os
+import sys
+from collections import defaultdict
+import warnings
+import logging
+from typing import Literal
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import protac_degradation_predictor as pdp
+
+import pytorch_lightning as pl
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from rdkit import DataStructs
+from jsonargparse import CLI
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+from sklearn.preprocessing import OrdinalEncoder
+from sklearn.model_selection import (
+    StratifiedKFold,
+    StratifiedGroupKFold,
+)
+
+# Ignore UserWarning from Matplotlib
+warnings.filterwarnings("ignore", ".*FixedLocator*")
+# Ignore UserWarning from PyTorch Lightning
+warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
+
+root = logging.getLogger()
+root.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+root.addHandler(handler)
+
+def main(
+    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
+    n_trials: int = 100,
+    fast_dev_run: bool = False,
+    test_split: float = 0.1,
+    cv_n_splits: int = 5,
+    max_epochs: int = 100,
+    force_study: bool = False,
+    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
+):
+    """ Run experiments with the cells one-hot encoding model.
+
+    Args:
+        active_col (str): Name of the column containing the active values.
+        n_trials (int): Number of hyperparameter optimization trials.
+        fast_dev_run (bool): Whether to run a fast development run.
+        test_split (float): Percentage of data to use for testing.
+        cv_n_splits (int): Number of cross-validation splits.
+        max_epochs (int): Maximum number of epochs to train the model.
+        force_study (bool): Whether to force the creation of a new study.
+        experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
+    """
+
+    # Make directory ../reports if it does not exist
+    if not os.path.exists('../reports'):
+        os.makedirs('../reports')
+
+    # Load embedding dictionaries
+    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
+    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
+
+    studies_dir = '../data/studies'
+    train_val_perc = f'{int((1 - test_split) * 100)}'
+    test_perc = f'{int(test_split * 100)}'
+    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
+
+    if experiments == 'all':
+        experiments = ['standard', 'similarity', 'target']
+
+    # Cross-Validation Training
+    reports = defaultdict(list)
+    for split_type in experiments:
+
+        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
+        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'
+
+        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
+        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))
+
+        # Get SMILES and precompute fingerprints dictionary
+        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
+        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
+
+        # Get the CV object
+        if split_type == 'standard':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = None
+        elif split_type == 'e3_ligase':
+            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['E3 Group'].to_numpy()
+        elif split_type == 'similarity':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Tanimoto Group'].to_numpy()
+        elif split_type == 'target':
+            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
+            group = train_val_df['Uniprot Group'].to_numpy()
+
+        # Start the experiment
+        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
+        optuna_reports = pdp.hyperparameter_tuning_and_training(
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_val_df=train_val_df,
+            test_df=test_df,
+            kf=kf,
+            groups=group,
+            split_type=split_type,
+            n_models_for_test=3,
+            fast_dev_run=fast_dev_run,
+            n_trials=n_trials,
+            max_epochs=max_epochs,
+            logger_save_dir='../logs',
+            logger_name=f'logs_{experiment_name}',
+            active_label=active_col,
+            study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
+            force_study=force_study,
+        )
+
+        # Save the reports to file
+        for report_name, report in optuna_reports.items():
+            report.to_csv(f'../reports/cellsonehot_{report_name}_{experiment_name}.csv', index=False)
+            reports[report_name].append(report.copy())
+
+
+if __name__ == '__main__':
+    cli = CLI(main)
```
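Reviewer note: since the script wraps `main` in jsonargparse's `CLI`, every parameter of `main` becomes a command-line option (e.g. `python src/run_experiments_cells_onehot.py --n_trials 50 --fast_dev_run true`). One quirk as committed: any `experiments` value other than `'all'` stays a plain string, so `for split_type in experiments` would iterate over its characters; only `'all'`, which is expanded to a list, runs as intended. A hypothetical programmatic smoke test (assumes the working directory is `src/` so the module is importable):

```python
# Hypothetical driver: call the entry point directly instead of via the CLI.
from run_experiments_cells_onehot import main

main(
    active_col='Active (Dmax 0.6, pDC50 6.0)',
    n_trials=5,          # keep the Optuna search short for a smoke test
    fast_dev_run=True,   # forwarded so Lightning runs only a single batch per loop
    cv_n_splits=5,
    experiments='all',   # expands to ['standard', 'similarity', 'target']
)
```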
src/{run_xgboost_experiments.py → run_experiments_xgboost.py}
RENAMED
File without changes