# --- Cell 1: imports, configuration, and data loading -----------------------
import os
import sys
from collections import defaultdict
import warnings
import logging
from typing import Literal

# FIX: sys.path does not expand '~' (unlike pandas.read_csv), so the original
# literal '~/...' entry could never be imported from. Expand it explicitly.
sys.path.append(os.path.expanduser('~/PROTAC-Degradation-Predictor/protac_degradation_predictor'))
import protac_degradation_predictor as pdp

import pytorch_lightning as pl
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from jsonargparse import CLI
import pandas as pd
# Import tqdm for notebook
from tqdm.notebook import tqdm
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import (
    StratifiedKFold,
    StratifiedGroupKFold,
)

# Activity thresholds: a PROTAC is labelled active when pDC50 >= 6.0 and
# Dmax >= 60% (the actual rule lives in pdp.is_active).
active_col = 'Active (Dmax 0.6, pDC50 6.0)'
pDC50_threshold = 6.0
Dmax_threshold = 0.6

protac_df = pd.read_csv('~/PROTAC-Degradation-Predictor/data/PROTAC-Degradation-DB.csv')
# Normalize E3 ligase naming ('Iap' -> 'IAP') so the column has one spelling.
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
# Row-wise activity label; NaN where DC50/Dmax are insufficient to decide.
protac_df[active_col] = protac_df.apply(
    lambda x: pdp.is_active(
        x['DC50 (nM)'],
        x['Dmax (%)'],
        pDC50_threshold=pDC50_threshold,
        Dmax_threshold=Dmax_threshold,
    ),
    axis=1,
)


# --- Cell 2: train/validation vs. test split --------------------------------
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
    """ Get the indices of the test set using a random split.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
        test_split (float): The percentage of the active PROTACs to use as the test set.

    Returns:
        pd.Index: The indices of the test set.
    """
    # Fixed seed so the held-out test set is reproducible across re-runs.
    test_df = active_df.sample(frac=test_split, random_state=42)
    return test_df.index


# Keep only rows with a defined activity label; hold out 10% as the test set.
active_df = protac_df[protac_df[active_col].notna()].copy()
test_split = 0.1
test_indices = get_random_split_indices(active_df, test_split)
train_val_df = active_df[~active_df.index.isin(test_indices)].copy()
train_val_df.shape[0] if False else len(train_val_df)


# --- Cell 3: Optuna objective — minimize fingerprint collisions -------------
import optuna


def objective(trial: optuna.Trial, verbose: int = 0) -> float:
    """Score a Morgan-fingerprint configuration by how many training SMILES
    collide onto the same fingerprint.

    Args:
        trial: Optuna trial used to sample `radius` (1-15) and `fpsize`
            (128-2048, step 128).
        verbose: If non-zero, print collision statistics for this trial.

    Returns:
        num_overlaps + radius + fpsize / 100 — the collision count in
        train_val_df, with small tie-breaking penalties that prefer a smaller
        radius and a shorter fingerprint.
    """
    radius = trial.suggest_int('radius', 1, 15)
    fpsize = trial.suggest_int('fpsize', 128, 2048, step=128)

    morgan_fpgen = AllChem.GetMorganGenerator(
        radius=radius,
        fpSize=fpsize,
        includeChirality=True,
    )

    # One fingerprint per unique SMILES in the train/val split.
    smiles2fp = {
        smiles: pdp.get_fingerprint(smiles, morgan_fpgen)
        for smiles in train_val_df['Smiles'].unique().tolist()
    }

    # FIX: the original built `unique_fps` with a set comprehension and then
    # immediately overwrote it with an empty set — a dead O(n) pass over every
    # fingerprint. The dead computation has been removed; the loop below is
    # the single source of truth for both `unique_fps` and the overlaps.
    overlapping_smiles = []
    unique_fps = set()
    for smiles, fp in smiles2fp.items():
        fp_key = tuple(fp)
        if fp_key in unique_fps:
            # A previously-seen fingerprint: this SMILES is indistinguishable
            # from another molecule at this (radius, fpsize).
            overlapping_smiles.append(smiles)
        else:
            unique_fps.add(fp_key)
    # Count affected rows (not just unique SMILES) in each frame.
    num_overlaps = len(train_val_df[train_val_df["Smiles"].isin(overlapping_smiles)])
    num_overlaps_tot = len(protac_df[protac_df["Smiles"].isin(overlapping_smiles)])

    if verbose:
        print(f'Radius: {radius}')
        print(f'FP length: {fpsize}')
        print(f'Number of unique SMILES: {len(smiles2fp)}')
        print(f'Number of unique fingerprints: {len(unique_fps)}')
        print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')
        print(f'Number of overlapping SMILES in train_val_df: {num_overlaps}')
        print(f'Number of overlapping SMILES in protac_df: {num_overlaps_tot}')
    return num_overlaps + radius + fpsize / 100


# NOTE(review): the notebook's final cell — which creates the Optuna study and
# appears to run `objective` for 50 trials — is truncated in this chunk and is
# not reproduced here; confirm its contents against the full notebook.