ribesstefano commited on
Commit
ed339ed
·
1 Parent(s): 394ed39

Started renaming splits. Added datasets in paper as separate CSV.

Browse files
README.md CHANGED
@@ -1,3 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
2
  <a href="https://colab.research.google.com/github/ribesstefano/PROTAC-Degradation-Predictor/blob/main/notebooks/protac_degradation_predictor_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
3
  [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ailab-bio/PROTAC-Degradation-Predictor)
 
1
+ ---
2
+
3
+ title: PROTAC-Degradation-Predictor
4
+ emoji: 🧬
5
+ colorFrom: pink
6
+ colorTo: green
7
+ sdk: gradio
8
+ sdk_version: 4.37.2
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+
13
+ ---
14
+
15
  ![Maturity level-0](https://img.shields.io/badge/Maturity%20Level-ML--0-red)
16
  <a href="https://colab.research.google.com/github/ribesstefano/PROTAC-Degradation-Predictor/blob/main/notebooks/protac_degradation_predictor_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
17
  [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ailab-bio/PROTAC-Degradation-Predictor)
data/studies/e3_ligase_test_10split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/e3_ligase_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/similarity_test_10split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/similarity_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/standard_test_10split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/standard_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/target_test_10split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/studies/target_train_val_90split_Active_Dmax_0.6_pDC50_6.0.csv ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/predict_unknown_protacs.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -9,25 +9,11 @@ from .pytorch_models import (
9
  )
10
  from .protac_dataset import get_datasets
11
 
12
- from .sklearn_models import (
13
- train_sklearn_model,
14
- suggest_random_forest,
15
- suggest_logistic_regression,
16
- suggest_svc,
17
- suggest_gradient_boosting,
18
- )
19
-
20
  import torch
21
  import optuna
22
  from optuna.samplers import TPESampler
23
  import joblib
24
  import pandas as pd
25
- from sklearn.ensemble import (
26
- RandomForestClassifier,
27
- GradientBoostingClassifier,
28
- )
29
- from sklearn.linear_model import LogisticRegression
30
- from sklearn.svm import SVC
31
  from sklearn.model_selection import (
32
  StratifiedKFold,
33
  StratifiedGroupKFold,
@@ -270,14 +256,23 @@ def hyperparameter_tuning_and_training(
270
  """ Hyperparameter tuning and training of a PROTAC model.
271
 
272
  Args:
273
- train_df (pd.DataFrame): The training set.
274
- val_df (pd.DataFrame): The validation set.
 
 
275
  test_df (pd.DataFrame): The test set.
 
 
 
 
276
  fast_dev_run (bool): Whether to run a fast development run.
277
- n_trials (int): The number of hyperparameter optimization trials.
278
- logger_name (str): The name of the logger.
 
279
  active_label (str): The active label column.
280
- disabled_embeddings (List[str]): The list of disabled embeddings.
 
 
281
 
282
  Returns:
283
  tuple: The trained model, the trainer, and the best metrics.
@@ -507,148 +502,4 @@ def hyperparameter_tuning_and_training(
507
  }
508
  if not fast_dev_run:
509
  ret['majority_vote_report'] = majority_vote_report
510
- return ret
511
-
512
-
513
- def sklearn_model_objective(
514
- trial: optuna.Trial,
515
- protein2embedding: Dict,
516
- cell2embedding: Dict,
517
- smiles2fp: Dict,
518
- train_df: pd.DataFrame,
519
- val_df: pd.DataFrame,
520
- model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
521
- active_label: str = 'Active',
522
- ) -> float:
523
- """ Objective function for hyperparameter optimization.
524
-
525
- Args:
526
- trial (optuna.Trial): The Optuna trial object.
527
- train_df (pd.DataFrame): The training set.
528
- val_df (pd.DataFrame): The validation set.
529
- model_type (str): The model type.
530
- hyperparameters (Dict): The hyperparameters for the model.
531
- fast_dev_run (bool): Whether to run a fast development run.
532
- active_label (str): The active label column.
533
- """
534
-
535
- # Generate the hyperparameters
536
- use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
537
- if model_type == 'RandomForest':
538
- clf = suggest_random_forest(trial)
539
- elif model_type == 'SVC':
540
- clf = suggest_svc(trial)
541
- elif model_type == 'LogisticRegression':
542
- clf = suggest_logistic_regression(trial)
543
- elif model_type == 'GradientBoosting':
544
- clf = suggest_gradient_boosting(trial)
545
- else:
546
- raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
547
-
548
- # Train the model with the current set of hyperparameters
549
- _, metrics = train_sklearn_model(
550
- clf=clf,
551
- protein2embedding=protein2embedding,
552
- cell2embedding=cell2embedding,
553
- smiles2fp=smiles2fp,
554
- train_df=train_df,
555
- val_df=val_df,
556
- active_label=active_label,
557
- use_single_scaler=use_single_scaler,
558
- )
559
-
560
- # Metrics is a dictionary containing at least the validation loss
561
- val_acc = metrics['val_acc']
562
- val_roc_auc = metrics['val_roc_auc']
563
-
564
- # Optuna aims to minimize the sklearn_model_objective
565
- return - val_acc - val_roc_auc
566
-
567
-
568
- def hyperparameter_tuning_and_training_sklearn(
569
- protein2embedding: Dict,
570
- cell2embedding: Dict,
571
- smiles2fp: Dict,
572
- train_df: pd.DataFrame,
573
- val_df: pd.DataFrame,
574
- test_df: Optional[pd.DataFrame] = None,
575
- model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
576
- active_label: str = 'Active',
577
- n_trials: int = 50,
578
- logger_name: str = 'protac_hparam_search_sklearn',
579
- study_filename: Optional[str] = None,
580
- ) -> Tuple:
581
- """ Hyperparameter tuning and training of a PROTAC model.
582
-
583
- Args:
584
- train_df (pd.DataFrame): The training set.
585
- val_df (pd.DataFrame): The validation set.
586
- test_df (pd.DataFrame): The test set.
587
- model_type (str): The model type.
588
- n_trials (int): The number of hyperparameter optimization trials.
589
- logger_name (str): The name of the logger. Unused, for compatibility with hyperparameter_tuning_and_training.
590
- active_label (str): The active label column.
591
-
592
- Returns:
593
- tuple: The trained model and the best metrics.
594
- """
595
- # Set the verbosity of Optuna
596
- optuna.logging.set_verbosity(optuna.logging.WARNING)
597
- # Create an Optuna study object
598
- sampler = TPESampler(seed=42, multivariate=True)
599
- study = optuna.create_study(direction='minimize', sampler=sampler)
600
-
601
- study_loaded = False
602
- if study_filename:
603
- if os.path.exists(study_filename):
604
- study = joblib.load(study_filename)
605
- study_loaded = True
606
- logging.info(f'Loaded study from {study_filename}')
607
-
608
- if not study_loaded:
609
- study.optimize(
610
- lambda trial: sklearn_model_objective(
611
- trial=trial,
612
- protein2embedding=protein2embedding,
613
- cell2embedding=cell2embedding,
614
- smiles2fp=smiles2fp,
615
- train_df=train_df,
616
- val_df=val_df,
617
- model_type=model_type,
618
- active_label=active_label,
619
- ),
620
- n_trials=n_trials,
621
- )
622
- if study_filename:
623
- joblib.dump(study, study_filename)
624
-
625
- # Retrain the model with the best hyperparameters
626
- best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
627
- if model_type == 'RandomForest':
628
- clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
629
- elif model_type == 'SVC':
630
- clf = SVC(random_state=42, probability=True, **best_hyperparameters)
631
- elif model_type == 'LogisticRegression':
632
- clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
633
- elif model_type == 'GradientBoosting':
634
- clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
635
- else:
636
- raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
637
-
638
- model, metrics = train_sklearn_model(
639
- clf=clf,
640
- protein2embedding=protein2embedding,
641
- cell2embedding=cell2embedding,
642
- smiles2fp=smiles2fp,
643
- train_df=train_df,
644
- val_df=val_df,
645
- test_df=test_df,
646
- active_label=active_label,
647
- use_single_scaler=study.best_params['use_single_scaler'],
648
- )
649
-
650
- # Report the best hyperparameters found
651
- metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
652
-
653
- # Return the best metrics
654
- return model, metrics
 
9
  )
10
  from .protac_dataset import get_datasets
11
 
 
 
 
 
 
 
 
 
12
  import torch
13
  import optuna
14
  from optuna.samplers import TPESampler
15
  import joblib
16
  import pandas as pd
 
 
 
 
 
 
17
  from sklearn.model_selection import (
18
  StratifiedKFold,
19
  StratifiedGroupKFold,
 
256
  """ Hyperparameter tuning and training of a PROTAC model.
257
 
258
  Args:
259
+ protein2embedding (Dict): The protein to embedding dictionary.
260
+ cell2embedding (Dict): The cell to embedding dictionary.
261
+ smiles2fp (Dict): The SMILES to fingerprint dictionary.
262
+ train_val_df (pd.DataFrame): The training and validation set.
263
  test_df (pd.DataFrame): The test set.
264
+ kf (StratifiedKFold | StratifiedGroupKFold): The KFold object.
265
+ groups (np.array): The groups for the StratifiedGroupKFold.
266
+ split_type (str): The split type.
267
+ n_models_for_test (int): The number of models to train for the test set.
268
  fast_dev_run (bool): Whether to run a fast development run.
269
+ n_trials (int): The number of trials for the hyperparameter search.
270
+ logger_save_dir (str): The logger save directory.
271
+ logger_name (str): The logger name.
272
  active_label (str): The active label column.
273
+ max_epochs (int): The maximum number of epochs.
274
+ study_filename (str): The study filename.
275
+ force_study (bool): Whether to force the study.
276
 
277
  Returns:
278
  tuple: The trained model, the trainer, and the best metrics.
 
502
  }
503
  if not fast_dev_run:
504
  ret['majority_vote_report'] = majority_vote_report
505
+ return ret
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training Models
2
+
3
+ ## Dataset Specification
4
+
5
+ From the repository top level directory, run the following command to get the datasets reported in the paper:
6
+
7
+ ```bash
8
+ cd src
9
+ python get_studies_datasets.py
10
+ ```
11
+
12
+ For training on custom datasets, please refer to the class `PROTAC_Dataset` in the file [`protac_dataset.py`](../protac_degradation_predictor/protac_dataset.py). The class expects a Pandas DataFrame, so please assemble a file that can be parsed into a Pandas DataFrame with the following columns:
13
+
14
+ | Column Name | Type | Description |
15
+ | --- | --- | --- |
16
+ | Smiles | str | The SMILES representation of the PROTAC molecule. |
17
+ | Uniprot | str | The Uniprot ID of the target protein. |
18
+ | E3 Ligase Uniprot | str | The Uniprot ID of the E3 ligase. |
19
+ | Cell Line Identifier | str | The cell line identifier, as reported in Cellosaurus. |
20
+ | `<active_label>` | bool | The activity label of the PROTAC molecule to be predicted by the model. |
21
+
22
+ The column `<active_label>` is set to _"Active"_ by default in the `PROTAC_Dataset` class and in the `hyperparameter_tuning_and_training` function (see below for how to use it).
23
+
24
+ ## Training on Custom Data
25
+
26
+ For training on custom datasets, please refer to the function `hyperparameter_tuning_and_training` in [`optuna_utils.py`](../protac_degradation_predictor/optuna_utils.py) and the file [`run_experiments.py`](../src/run_experiments.py) for inspiration on how to use the function.
27
+
28
+ A skeleton implementation looks as follows:
29
+
30
+ ```python
31
+ import protac_degradation_predictor as pdp
32
+ import pandas as pd
33
+ import numpy as np
34
+ from sklearn.model_selection import StratifiedKFold
35
+
36
+ # Load train/val and test dataframes
37
+ train_val_df = pd.read_csv('path/to/custom_dataset.csv')
38
+ test_df = pd.read_csv('path/to/test_dataset.csv') # Load one of our test datasets
39
+
40
+ # NOTE: Make sure to avoid data leakage by removing leaking data in the train/val
41
+ # dataframe. Do NOT remove or alter the test set, as it would impair comparison
42
+ # with our work. Data leakage can occur if the test set contains any combination
43
+ # of SMILES, Uniprot, E3 Ligase Uniprot, or Cell Line Identifier that is present
44
+ # in the train/val set too.
45
+
46
+ # Precompute Morgan fingerprints
47
+ unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
48
+ smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
49
+
50
+ # Load embedding dictionaries
51
+ protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
52
+ cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
53
+
54
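+ # Name used below to tag the log directory and the saved Optuna study
+ experiment_name = 'custom_experiment'
+
+ # Setup Cross-Validation object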
55
+ kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
56
+ pdp.hyperparameter_tuning_and_training(
57
+ protein2embedding=protein2embedding,
58
+ cell2embedding=cell2embedding,
59
+ smiles2fp=smiles2fp,
60
+ train_val_df=train_val_df,
61
+ test_df=test_df,
62
+ kf=kf,
63
+ n_models_for_test=3,
64
+ n_trials=100,
65
+ max_epochs=20,
66
+ logger_save_dir='../logs',
67
+ logger_name=f'logs_{experiment_name}',
68
+ study_filename=f'../reports/study_{experiment_name}.pkl',
69
+ )
70
+ ```
src/get_studies_datasets.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5
+ import protac_degradation_predictor as pdp
6
+
7
+ from collections import defaultdict
8
+ import warnings
9
+ import logging
10
+ from typing import Literal
11
+
12
+ from sklearn.preprocessing import OrdinalEncoder
13
+ from tqdm import tqdm
14
+ import pandas as pd
15
+ import numpy as np
16
+ import pytorch_lightning as pl
17
+ from rdkit import DataStructs
18
+
19
+
20
# Route every log record (DEBUG and above, from any logger) to stdout, using a
# timestamped "time - logger - level - message" layout.
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# The handler is attached to the root logger, so it also receives the records
# emitted by imported libraries.
root = logging.getLogger()
root.setLevel(logging.DEBUG)
root.addHandler(handler)
28
+
29
+
30
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
    """ Pick the test set by uniformly sampling rows at random.

    Args:
        active_df (pd.DataFrame): The DataFrame to sample the test entries from.
        test_split (float): The fraction of rows (between 0 and 1) to place in the test set.

    Returns:
        pd.Index: The index labels of the sampled test rows.
    """
    # The seed is fixed so that the very same split is produced on every run
    sampled_rows = active_df.sample(frac=test_split, random_state=42)
    test_indices = sampled_rows.index
    return test_indices
42
+
43
+
44
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
    """ Pick the test set by E3 ligase: every entry that is neither VHL nor CRBN.

    As a side effect, an integer 'E3 Group' column (the ordinal encoding of the
    'E3 Ligase' column) is added to `active_df` in place.

    Args:
        active_df (pd.DataFrame): The DataFrame to split. It must contain an 'E3 Ligase' column.

    Returns:
        pd.Index: The index labels of the test rows.
    """
    ligase_codes = OrdinalEncoder().fit_transform(active_df[['E3 Ligase']])
    active_df['E3 Group'] = ligase_codes.astype(int)

    # VHL and CRBN entries stay in the train/val set; everything else is tested on
    is_vhl = active_df['E3 Ligase'] == 'VHL'
    is_crbn = active_df['E3 Ligase'] == 'CRBN'
    return active_df[~(is_vhl | is_crbn)].index
57
+
58
+
59
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
    """ Precompute the fingerprints and the average Tanimoto distance of each PROTAC.

    For every unique SMILES, the Tanimoto distance (1 - similarity) to all the
    unique SMILES in the DataFrame (itself included, contributing a zero) is
    computed and averaged. The result is stored in a new 'Avg Tanimoto' column
    of `protac_df`, which is modified in place.

    Args:
        protac_df (pd.DataFrame): The DataFrame containing the PROTACs. It must contain a 'Smiles' column.

    Returns:
        tuple: The SMILES to fingerprint dictionary, with the fingerprints
            converted to Numpy arrays, and the input DataFrame with the added
            'Avg Tanimoto' column.
    """
    unique_smiles = protac_df['Smiles'].unique().tolist()

    smiles2fp = {}
    for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
        smiles2fp[smiles] = pdp.get_fingerprint(smiles)

    # NOTE: The dictionary preserves the insertion order, hence fps[k] is the
    # fingerprint of unique_smiles[k]
    fps = list(smiles2fp.values())
    tanimoto_matrix = defaultdict(list)

    # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity.
    # Only the upper triangle (i to end) is computed, the lower one is filled
    # in by symmetry.
    for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
        similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:])
        for offset, similarity in enumerate(similarities):
            distance = 1 - similarity
            tanimoto_matrix[smiles1].append(distance)
            # An offset of zero is the comparison of the molecule with itself,
            # which must be counted only once
            if offset > 0:
                tanimoto_matrix[unique_smiles[i + offset]].append(distance)

    # Calculate average Tanimoto distance for each unique SMILES
    avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
    protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

    smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}

    return smiles2fp, protac_df
106
+
107
+
108
def get_tanimoto_split_indices(
        active_df: pd.DataFrame,
        active_col: str,
        test_split: float,
        n_bins_tanimoto: int = 200,
) -> pd.Index:
    """ Get the indices of the test set using the Tanimoto-based split.

    The 'Avg Tanimoto' column is binned into `n_bins_tanimoto` intervals and the
    ordinal id of each bin is stored in a new 'Tanimoto Group' column of
    `active_df`, which is modified in place. The bins are then visited by
    decreasing mean 'Avg Tanimoto' and greedily added, as a whole, to the test
    set as long as the test set stays smaller than `test_split` of the data
    and, once half of that budget is used, neither class exceeds 60% of it.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the labeled PROTACs. It must contain an 'Avg Tanimoto' column.
        active_col (str): The column containing the active/inactive information.
        test_split (float): The maximum fraction of `active_df` to use as the test set.
        n_bins_tanimoto (int): The number of bins to use for the average Tanimoto values.

    Returns:
        pd.Index: The indices of the test set.
    """
    tanimoto_bins = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto)
    encoder = OrdinalEncoder()
    active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_bins.values.reshape(-1, 1)).astype(int)
    # Sort the groups so that the samples with the highest average Tanimoto
    # value are placed in the test set first.
    # NOTE: get_smiles2fp_and_avg_tanimoto in this file fills 'Avg Tanimoto'
    # with distances (1 - similarity), so these are the "less similar" samples.
    tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index

    # Running totals of the test set, updated whenever a group is accepted
    test_groups = []
    num_entries_test = 0
    num_active_test = 0
    for group in tanimoto_groups:
        group_df = active_df[active_df['Tanimoto Group'] == group]
        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()

        # The very first group is always accepted, the other ones must fit
        if test_groups:
            num_entries_total = num_entries_test + num_entries
            # Skip the group if it would make the test set too large
            if not num_entries_total < test_split * len(active_df):
                continue
            # Anything is accepted in the first half of the budget. After that,
            # be more selective and keep active and inactive entries balanced.
            if not num_entries_total < test_split / 2 * len(active_df):
                num_active_total = num_active_test + num_active_group
                num_inactive_total = num_entries_total - num_active_total
                is_balanced = (
                    num_active_total / num_entries_total < 0.6
                    and num_inactive_total / num_entries_total < 0.6
                )
                if not is_balanced:
                    continue

        test_groups.append(group_df)
        num_entries_test += num_entries
        num_active_test += num_active_group

    return pd.concat(test_groups).index
163
+
164
+
165
def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
    """ Get the indices of the test set using the target-based split.

    The entries are grouped by their 'Uniprot' value, so that all the entries
    of a target end up on the same side of the split. The groups are visited
    from the least to the most populated one and greedily added to the test
    set as long as the test set stays smaller than `test_split` of the data
    and, once half of that budget is used, neither class exceeds 60% of it.

    As a side effect, an integer 'Uniprot Group' column (the ordinal encoding
    of the 'Uniprot' column) is added to `active_df` in place.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the labeled PROTACs. It must contain a 'Uniprot' column.
        active_col (str): The column containing the active/inactive information.
        test_split (float): The maximum fraction of `active_df` to use as the test set.

    Returns:
        pd.Index: The indices of the test set.
    """
    encoder = OrdinalEncoder()
    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)

    # Running totals of the test set, updated whenever a group is accepted
    test_groups = []
    num_entries_test = 0
    num_active_test = 0
    # Start the loop from the groups containing the smallest number of entries
    # (value_counts sorts the targets from the most to the least frequent one)
    for group in reversed(active_df['Uniprot'].value_counts().index):
        group_df = active_df[active_df['Uniprot'] == group]
        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()

        # The very first group is always accepted, the other ones must fit
        if test_groups:
            num_entries_total = num_entries_test + num_entries
            # Skip the group if it would make the test set too large
            if not num_entries_total < test_split * len(active_df):
                continue
            # Anything is accepted in the first half of the budget. After that,
            # be more selective and keep active and inactive entries balanced.
            if not num_entries_total < test_split / 2 * len(active_df):
                num_active_total = num_active_test + num_active_group
                num_inactive_total = num_entries_total - num_active_total
                is_balanced = (
                    num_active_total / num_entries_total < 0.6
                    and num_inactive_total / num_entries_total < 0.6
                )
                if not is_balanced:
                    continue

        test_groups.append(group_df)
        num_entries_test += num_entries
        num_active_test += num_active_group

    return pd.concat(test_groups).index
213
+
214
+
215
def main(
        active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
        test_split: float = 0.1,
        studies: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
):
    """ Get and save the datasets for the different studies.

    For each requested study, a train/val and a test CSV file are written to
    '../data/studies'. Both the input and the output paths are relative to the
    current working directory, so the script is expected to be run from the
    directory containing it.

    Args:
        active_col (str): The column containing the active/inactive information. It should be in the format 'Active (Dmax N, pDC50 M)', where N and M are the thresholds float values for Dmax and pDC50, respectively.
        test_split (float): The percentage of the active PROTACs to use as the test set.
        studies (str): The type of studies to save dataset for. Options: 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
    """
    pl.seed_everything(42)

    # Filesystem-friendly version of the column to predict, used in file names
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    # Get the Dmax and pDC50 thresholds from the active_col
    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())

    # Load the PROTAC dataset
    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
    # Map E3 Ligase Iap to IAP
    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
    protac_df[active_col] = protac_df.apply(
        lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
    )
    _, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)

    # Keep the labeled entries only
    active_df = protac_df[protac_df[active_col].notna()].copy()

    # Remove legacy column 'Active - OR' if it exists
    if 'Active - OR' in active_df.columns:
        active_df.drop(columns='Active - OR', inplace=True)

    # Get the test sets.
    # NOTE: Some split functions add a helper column to active_df in place, so
    # the order of the calls determines the order of those columns in the
    # saved files.
    test_indices = {}
    if studies in ('standard', 'all'):
        test_indices['standard'] = get_random_split_indices(active_df, test_split)
    if studies in ('target', 'all'):
        test_indices['target'] = get_target_split_indices(active_df, active_col, test_split)
    if studies in ('e3_ligase', 'all'):
        test_indices['e3_ligase'] = get_e3_ligase_split_indices(active_df)
    if studies in ('similarity', 'all'):
        test_indices['similarity'] = get_tanimoto_split_indices(active_df, active_col, test_split)

    # Make directory for studies datasets if it does not exist
    data_dir = '../data/studies'
    os.makedirs(data_dir, exist_ok=True)

    # Percentages reported in the file names. They are rounded, not truncated:
    # e.g., 0.29 * 100 evaluates to 28.999999999999996, which int() alone would
    # turn into 28. The train/val one is derived from the test one so that the
    # two always add up to 100.
    test_perc = int(round(test_split * 100))
    train_val_perc = 100 - test_perc

    # Save the train/val and test datasets of each study
    for split_type, indices in test_indices.items():
        test_df = active_df.loc[indices].copy()
        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()

        train_val_filename = f'{data_dir}/{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
        test_filename = f'{data_dir}/{split_type}_test_{test_perc}split_{active_name}.csv'

        print('')
        print(f'Saving train_val datasets as: {train_val_filename}')
        print(f'Saving test datasets as: {test_filename}')

        train_val_df.to_csv(train_val_filename, index=False)
        test_df.to_csv(test_filename, index=False)


if __name__ == '__main__':
    main()