{ "cells": [ { "cell_type": "code", "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2025-01-09T16:36:25.227597Z", "iopub.status.busy": "2025-01-09T16:36:25.227303Z", "iopub.status.idle": "2025-01-09T16:36:35.081281Z", "shell.execute_reply": "2025-01-09T16:36:35.080659Z", "shell.execute_reply.started": "2025-01-09T16:36:25.227573Z" } }, "source": [ "import segmentation_models_pytorch as smp\n", "import os\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import numpy as np\n", "import torch\n", "from torch.fx.experimental.meta_tracer import torch_abs_override\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms, utils\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.optim import lr_scheduler\n", "import time\n", "import albumentations as Album\n", "import torch.nn.functional as Functional\n", "import pandas as pd\n", "import nibabel as nib\n", "from tqdm import tqdm" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "! pip show albumentations" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:36:48.479196Z", "iopub.status.busy": "2025-01-09T16:36:48.478879Z", "iopub.status.idle": "2025-01-09T16:36:48.500028Z", "shell.execute_reply": "2025-01-09T16:36:48.499404Z", "shell.execute_reply.started": "2025-01-09T16:36:48.479170Z" } }, "source": [ "training_df = pd.read_csv('data/archive/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/name_mapping.csv')\n", "root_df = 'data/archive/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:36:51.384165Z", "iopub.status.busy": "2025-01-09T16:36:51.383835Z", "iopub.status.idle": "2025-01-09T16:36:51.401352Z", "shell.execute_reply": "2025-01-09T16:36:51.400713Z", "shell.execute_reply.started": "2025-01-09T16:36:51.384140Z" } }, "source": [ "training_df.head(10)" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "Exporting CSV Files to be used as reference for MRI Imaging files (.nii) to their respective file paths" ] }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:36:57.780114Z", "iopub.status.busy": "2025-01-09T16:36:57.779827Z", "iopub.status.idle": "2025-01-09T16:36:59.207480Z", "shell.execute_reply": "2025-01-09T16:36:59.206793Z", "shell.execute_reply.started": "2025-01-09T16:36:57.780094Z" } }, "source": [ "root_list = []\n", "tot_list = []\n", "\n", "for filename_root in tqdm(np.sort(os.listdir(root_df))[:-2]):\n", " subpath = os.path.join(root_df, filename_root)\n", " file_list = []\n", "\n", " for filename in np.sort(os.listdir(subpath)):\n", " file_list.append(os.path.join(subpath, filename))\n", "\n", " root_list.append(filename_root)\n", " tot_list.append(file_list)\n", " \n", "maps = pd.concat(\n", " [pd.DataFrame(root_list, columns=['DIR']),\n", " pd.DataFrame(tot_list, columns=['flair', 'seg', 't1', 't1ce', 't2']) \n", "], axis=1)\n", "\n", "maps.to_csv('links.csv', index=False)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:37:07.946953Z", "iopub.status.busy": "2025-01-09T16:37:07.946665Z", 
"iopub.status.idle": "2025-01-09T16:37:07.955468Z", "shell.execute_reply": "2025-01-09T16:37:07.954634Z", "shell.execute_reply.started": "2025-01-09T16:37:07.946934Z" } }, "source": [ "image_path = {\n", " 'seg': [],\n", " 't1': [],\n", " 't1ce': [],\n", " 't2': [],\n", " 'flair': []\n", "}\n", "\n", "for path in training_df['BraTS_2020_subject_ID']:\n", " patient = os.path.join(root_df, path)\n", "\n", " for name in image_path:\n", " image_path[name].append(os.path.join(patient, path + f'_{name}.nii'))\n", "\n", "image_path['seg'][:5]" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:37:15.635134Z", "iopub.status.busy": "2025-01-09T16:37:15.634853Z", "iopub.status.idle": "2025-01-09T16:37:15.640048Z", "shell.execute_reply": "2025-01-09T16:37:15.639143Z", "shell.execute_reply.started": "2025-01-09T16:37:15.635113Z" } }, "source": [ "def load_image(image_path):\n", " return nib.load(image_path).get_fdata()\n", "\n", "\n", "def ccentre(image_slice, crop_x, crop_y):\n", " y, x = image_slice.shape\n", "\n", " start_x = x // 2 - (crop_x // 2)\n", " start_y = y // 2 - (crop_y // 2)\n", "\n", " return image_slice[start_y : start_y + crop_y, start_x : start_x + crop_x]\n", "\n", "\n", "def normalize(image_slice):\n", " return (image_slice - image_slice.mean()) / image_slice.std()" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:37:23.487997Z", "iopub.status.busy": "2025-01-09T16:37:23.487694Z", "iopub.status.idle": "2025-01-09T16:37:24.301565Z", "shell.execute_reply": "2025-01-09T16:37:24.300420Z", "shell.execute_reply.started": "2025-01-09T16:37:23.487971Z" } }, "source": [ "def create_dataset_directories(base_dir=\"dataset\"):\n", " os.makedirs(os.path.join(base_dir, \"t1\"), exist_ok=True)\n", " os.makedirs(os.path.join(base_dir, \"t1ce\"), exist_ok=True)\n", " os.makedirs(os.path.join(base_dir, \"t2\"), exist_ok=True)\n", " os.makedirs(os.path.join(base_dir, \"flair\"), exist_ok=True)\n", " os.makedirs(os.path.join(base_dir, \"seg\"), exist_ok=True)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "create_dataset_directories('dataset')\n", "# Save the stress because the directory already exists" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:37:51.309665Z", "iopub.status.busy": "2025-01-09T16:37:51.309191Z", "iopub.status.idle": "2025-01-09T16:39:04.326289Z", "shell.execute_reply": "2025-01-09T16:39:04.325310Z", "shell.execute_reply.started": "2025-01-09T16:37:51.309625Z" } }, "source": [ "images_saved = 0\n", "images = {}\n", "image_slice = {}\n", "\n", "save_limit = 5000\n", "\n", "for i in (range(len(image_path['seg']))):\n", " \n", " for name in image_path:\n", " images[name] = load_image(image_path[name][i])\n", "\n", " for j in range(155):\n", " for name in images:\n", " image_slice[name] = images[name][:, :, j]\n", " image_slice[name] = ccentre(image_slice[name], 128, 128)\n", "\n", " if image_slice['seg'].max() > 0:\n", " for name in ['t1', 't2', 't1ce', 'flair']:\n", " image_slice[name] = normalize(image_slice[name])\n", "\n", " for name in image_slice:\n", " np.save(f'dataset/{name}/image_{images_saved}.npy', image_slice[name])\n", "\n", " images_saved += 1\n", "\n", " if images_saved == save_limit:\n", " break\n", "\n", " if images_saved == save_limit:\n", " break" ], "outputs": 
[], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:00.898802Z", "iopub.status.busy": "2025-01-09T16:40:00.898500Z", "iopub.status.idle": "2025-01-09T16:40:00.902420Z", "shell.execute_reply": "2025-01-09T16:40:00.901607Z", "shell.execute_reply.started": "2025-01-09T16:40:00.898781Z" } }, "source": [ "# SOME BASIC IMAGE VISUALIZATIONS" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:06.557314Z", "iopub.status.busy": "2025-01-09T16:40:06.556901Z", "iopub.status.idle": "2025-01-09T16:40:07.667075Z", "shell.execute_reply": "2025-01-09T16:40:07.666168Z", "shell.execute_reply.started": "2025-01-09T16:40:06.557279Z" } }, "source": [ "fig = plt.figure(figsize = (24, 15))\n", "\n", "plt.subplot(1, 5, 1)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap='bone')\n", "plt.title('Original')\n", "\n", "plt.subplot(1, 5, 2)\n", "plt.imshow(np.load('dataset/seg/image_25.npy'), cmap='bone')\n", "plt.title('Segment')\n", "\n", "plt.subplot(1, 5, 3)\n", "plt.imshow(np.load('dataset/t1/image_25.npy'), cmap='bone')\n", "plt.title('T1')\n", "\n", "plt.subplot(1, 5, 4)\n", "plt.imshow(np.load('dataset/t1ce/image_25.npy'), cmap='bone')\n", "plt.title('T1CE')\n", "\n", "plt.subplot(1, 5, 5)\n", "plt.imshow(np.load('dataset/t2/image_25.npy'), cmap='bone')\n", "plt.title('T2')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:14.432473Z", "iopub.status.busy": "2025-01-09T16:40:14.432179Z", "iopub.status.idle": "2025-01-09T16:40:14.436037Z", "shell.execute_reply": "2025-01-09T16:40:14.435102Z", "shell.execute_reply.started": "2025-01-09T16:40:14.432449Z" } }, "source": [ "# WITH SOME COLOUR..." 
], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:15.200141Z", "iopub.status.busy": "2025-01-09T16:40:15.199879Z", "iopub.status.idle": "2025-01-09T16:40:16.473796Z", "shell.execute_reply": "2025-01-09T16:40:16.472822Z", "shell.execute_reply.started": "2025-01-09T16:40:15.200120Z" } }, "source": [ "fig = plt.figure(figsize = (24, 15))\n", "\n", "plt.subplot(1, 5, 1)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n", "plt.title('Original')\n", "\n", "plt.subplot(1, 5, 2)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n", "plt.imshow(np.load('dataset/seg/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n", "plt.title('Segment')\n", "\n", "plt.subplot(1, 5, 3)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n", "plt.imshow(np.load('dataset/t1/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n", "plt.title('T1')\n", "\n", "plt.subplot(1, 5, 4)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n", "plt.imshow(np.load('dataset/t1ce/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n", "plt.title('T1CE')\n", "\n", "plt.subplot(1, 5, 5)\n", "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n", "plt.imshow(np.load('dataset/t2/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n", "plt.title('T2')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:47.809814Z", "iopub.status.busy": "2025-01-09T16:40:47.809476Z", "iopub.status.idle": "2025-01-09T16:40:47.817498Z", "shell.execute_reply": "2025-01-09T16:40:47.816414Z", "shell.execute_reply.started": "2025-01-09T16:40:47.809789Z" } }, "source": [ "class DatasetGenerator(Dataset):\n", " def __init__(self, datapath='dataset/', augmentation=None):\n", " self.augmentation = augmentation\n", "\n", " self.folderpaths = {\n", " 'mask': os.path.join(datapath, 'seg/'),\n", " 't1': os.path.join(datapath, 't1/'),\n", " 't1ce': os.path.join(datapath, 't1ce/'),\n", " 't2': os.path.join(datapath, 't2/'),\n", " 'flair': os.path.join(datapath, 'flair/'),\n", " }\n", "\n", " def __getitem__(self, index):\n", " images = {}\n", "\n", " for name in self.folderpaths:\n", " images[name] = np.load(os.path.join(self.folderpaths[name], f'image_{index}.npy')).astype(np.float32)\n", "\n", " # print(f\"Loaded images for index {index}: {images.keys()}\")\n", " \n", " if self.augmentation:\n", " augmented = self.augmentation(\n", " image=images['flair'],\n", " mask=images['mask'],\n", " t1=images['t1'],\n", " t1ce=images['t1ce'],\n", " t2=images['t2']\n", " )\n", " # print(f\"Augmented images for index {index}: {augmented.keys()}\")\n", " images['flair'] = augmented['image']\n", " images['mask'] = augmented['mask']\n", " images['t1'] = augmented['t1']\n", " images['t1ce'] = augmented['t1ce']\n", " images['t2'] = augmented['t2']\n", "\n", " for name in images:\n", " images[name] = torch.from_numpy(images[name])\n", "\n", " # STACKING UP MULTI INPUTS\n", " input = torch.stack([\n", " images['t1'],\n", " images['t1ce'],\n", " images['t2'],\n", " images['flair']\n", " ], dim=0)\n", "\n", " images['mask'][images['mask'] == 4] = 3\n", "\n", " # ONE-HOT TRUTH LABEL ENCODING\n", " images['mask'] = Functional.one_hot(\n", " images['mask'].long().unsqueeze(0),\n", " num_classes=4\n", " ).permute(0, 3, 1, 2).contiguous().squeeze(0)\n", "\n", " return input.float(), images['mask'].long()\n", "\n", " def __len__(self):\n", " 
{ "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:40:52.376269Z", "iopub.status.busy": "2025-01-09T16:40:52.375862Z", "iopub.status.idle": "2025-01-09T16:40:52.404458Z", "shell.execute_reply": "2025-01-09T16:40:52.403612Z", "shell.execute_reply.started": "2025-01-09T16:40:52.376234Z" } }, "source": [ "augmentation = Album.Compose([\n", "    Album.OneOf([\n", "        Album.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),\n", "        Album.GridDistortion(p=0.5),\n", "        Album.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5)\n", "    ], p=0.8),\n", "    Album.RandomBrightnessContrast(p=0.8),\n", "\n", "    # Added classes for enhanced data augmentations\n", "    #Album.Rotate(limit=45, p=0.8),\n", "    #Album.HorizontalFlip(p=0.8),\n", "    #Album.VerticalFlip(p=0.8),\n", "    #Album.GaussNoise(p=0.5)\n", "\n", "], additional_targets={\n", "    't1': 'image',\n", "    't1ce': 'image',\n", "    't2': 'image'\n", "})\n", "\n", "\n", "valid_test_dataset = DatasetGenerator(datapath='dataset/', augmentation=None)\n", "train_dataset = DatasetGenerator(datapath='dataset/', augmentation=augmentation)\n", "\n", "# 3:1:1 (60/20/20) train-validation-test split\n", "train_length = int(0.6 * len(valid_test_dataset))\n", "valid_length = int(0.2 * len(valid_test_dataset))\n", "test_length = len(valid_test_dataset) - train_length - valid_length\n", "\n", "# Split twice with the same seed: the un-augmented dataset supplies the\n", "# validation/test subsets, the augmented one supplies the matching train subset\n", "_, valid_dataset, test_dataset = torch.utils.data.random_split(\n", "    valid_test_dataset,\n", "    (train_length, valid_length, test_length), generator=torch.Generator().manual_seed(42)\n", ")\n", "\n", "train_dataset, _, _ = torch.utils.data.random_split(\n", "    train_dataset,\n", "    (train_length, valid_length, test_length), generator=torch.Generator().manual_seed(42)\n", ")" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:41:01.714186Z", "iopub.status.busy": "2025-01-09T16:41:01.713852Z", "iopub.status.idle": "2025-01-09T16:41:01.719951Z", "shell.execute_reply": "2025-01-09T16:41:01.719031Z", "shell.execute_reply.started": "2025-01-09T16:41:01.714157Z" } }, "source": [ "train_loader = DataLoader(\n", "    train_dataset, batch_size=16,\n", "    num_workers=0, shuffle=True\n", ")\n", "\n", "valid_loader = DataLoader(\n", "    valid_dataset, batch_size=1,\n", "    num_workers=0, shuffle=True\n", ")\n", "\n", "test_loader = DataLoader(\n", "    test_dataset, batch_size=1,\n", "    num_workers=2, shuffle=True\n", ")" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "print(len(train_loader))\n", "print(len(test_loader))\n", "print(len(valid_loader))" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": { "execution": { "iopub.execute_input": "2025-01-09T16:41:04.716492Z", "iopub.status.busy": "2025-01-09T16:41:04.716204Z", "iopub.status.idle": "2025-01-09T16:41:05.078974Z", "shell.execute_reply": "2025-01-09T16:41:05.077171Z", "shell.execute_reply.started": "2025-01-09T16:41:04.716472Z" } }, "source": [ "a, b = next(iter(train_loader))" ], "outputs": [], "execution_count": null },
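{ "cell_type": "markdown", "metadata": {}, "source": [ "Shapes of one training batch: a batch of 16, four input channels, and four one-hot mask channels." ] },
{ "cell_type": "code", "metadata": {}, "source": [ "print(a.shape)  # expected: torch.Size([16, 4, 128, 128])\n", "print(b.shape)  # expected: torch.Size([16, 4, 128, 128])" ], "outputs": [], "execution_count": null },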
"2025-01-09T16:17:05.731822Z" } }, "source": [ "plt.imshow(a[0, 0], cmap='gray')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.497446Z", "iopub.status.idle": "2025-01-09T15:44:19.497913Z", "shell.execute_reply": "2025-01-09T15:44:19.497700Z" } }, "source": [ "temp = torch.argmax(b, 0)\n", "plt.imshow(temp[0], cmap='gray')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "! nvidia-smi" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.498903Z", "iopub.status.idle": "2025-01-09T15:44:19.499326Z", "shell.execute_reply": "2025-01-09T15:44:19.499132Z" } }, "source": [ "# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "print(torch.cuda.is_available())\n", "print(f'* CUDA Device: {torch.cuda.get_device_name(\"cuda:0\")}\\n* Device Properties: {torch.cuda.get_device_properties(\"cuda:0\")}')\n", "\n", "# device = torch.cuda.device(0)\n", "device = torch.device('cuda:0')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.500315Z", "iopub.status.idle": "2025-01-09T15:44:19.500623Z", "shell.execute_reply": "2025-01-09T15:44:19.500501Z" } }, "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "@torch.jit.script\n", "def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):\n", " if encoder_layer.shape[2:] != decoder_layer.shape[2:]:\n", " ds = encoder_layer.shape[2:]\n", " es = decoder_layer.shape[2:]\n", "\n", " assert ds[0] >= es[0]\n", " assert ds[1] >= es[1]\n", "\n", " # IN CASES OF 2D FORMAT\n", " if encoder_layer.dim() == 4:\n", " encoder_layer = encoder_layer[\n", " :, :, \n", " ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),\n", " ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2)\n", " ]\n", "\n", " # IN CASES OF 3D FORMATS\n", " elif encoder_layer.dim() == 5:\n", " assert ds[2] >= es[2]\n", "\n", " encoder_layer = encoder_layer[\n", " :, :, \n", " ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),\n", " ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2),\n", " ((ds[2] - es[2]) // 2) : ((ds[2] + es[2]) // 2)\n", " ]\n", "\n", " return encoder_layer, decoder_layer\n", " \n", " else: \n", " return encoder_layer, decoder_layer\n", "\n", "\n", "def convolution_layer(dim: int):\n", " if dim == 3: \n", " return nn.Conv3d\n", " elif dim == 2:\n", " return nn.Conv2d\n", "\n", "\n", "def get_convolution_layer(\n", " in_channels: int, out_channels: int,\n", " kernel_size: int = 3, stride: int = 1,\n", " padding: int = 1, bias: bool = True, dim: int = 2):\n", "\n", " return convolution_layer(dim)(in_channels, out_channels, kernel_size=kernel_size,\n", " stride=stride, padding=padding, bias=bias)\n", "\n", "\n", "def convolution_transpose_layer(dim: int):\n", " if dim == 3:\n", " return nn.ConvTranspose3d\n", " elif dim == 2:\n", " return nn.ConvTranspose2d\n", "\n", "\n", "def get_up_layer(\n", " in_channels: int, out_channels: int,\n", " kernel_size: int = 2, stride: int = 2,\n", " dim: int = 3, up_mode: str = 'transposed'):\n", "\n", " if up_mode == 'transposed':\n", " return convolution_transpose_layer(dim)(in_channels, out_channels, \n", " kernel_size=kernel_size, stride=stride)\n", " else:\n", " return nn.Upsample(scale_factor=2.0, mode=up_mode)\n", "\n", "\n", "def maxpool_layer(dim: int):\n", " if dim == 3:\n", " return 
"def convolution_layer(dim: int):\n", "    if dim == 3:\n", "        return nn.Conv3d\n", "    elif dim == 2:\n", "        return nn.Conv2d\n", "\n", "\n", "def get_convolution_layer(\n", "        in_channels: int, out_channels: int,\n", "        kernel_size: int = 3, stride: int = 1,\n", "        padding: int = 1, bias: bool = True, dim: int = 2):\n", "\n", "    return convolution_layer(dim)(in_channels, out_channels, kernel_size=kernel_size,\n", "                                  stride=stride, padding=padding, bias=bias)\n", "\n", "\n", "def convolution_transpose_layer(dim: int):\n", "    if dim == 3:\n", "        return nn.ConvTranspose3d\n", "    elif dim == 2:\n", "        return nn.ConvTranspose2d\n", "\n", "\n", "def get_up_layer(\n", "        in_channels: int, out_channels: int,\n", "        kernel_size: int = 2, stride: int = 2,\n", "        dim: int = 3, up_mode: str = 'transposed'):\n", "\n", "    if up_mode == 'transposed':\n", "        return convolution_transpose_layer(dim)(in_channels, out_channels,\n", "                                                kernel_size=kernel_size, stride=stride)\n", "    else:\n", "        return nn.Upsample(scale_factor=2.0, mode=up_mode)\n", "\n", "\n", "def maxpool_layer(dim: int):\n", "    if dim == 3:\n", "        return nn.MaxPool3d\n", "    elif dim == 2:\n", "        return nn.MaxPool2d\n", "\n", "\n", "def get_maxpool_layer(kernel_size: int = 2, stride: int = 2, padding: int = 0, dim: int = 2):\n", "    return maxpool_layer(dim=dim)(kernel_size=kernel_size, stride=stride, padding=padding)\n", "\n", "\n", "def get_activation(activation: str):\n", "    if activation == 'relu':\n", "        return nn.ReLU()\n", "    elif activation == 'leaky':\n", "        return nn.LeakyReLU(negative_slope=0.1)\n", "    elif activation == 'elu':\n", "        return nn.ELU()\n", "\n", "\n", "def get_normalization(normalization: str, num_channels: int, dim: int):\n", "    if normalization == 'batch':\n", "        if dim == 3:\n", "            return nn.BatchNorm3d(num_channels)\n", "        elif dim == 2:\n", "            return nn.BatchNorm2d(num_channels)\n", "\n", "    elif normalization == 'instance':\n", "        if dim == 3:\n", "            return nn.InstanceNorm3d(num_channels)\n", "        elif dim == 2:\n", "            return nn.InstanceNorm2d(num_channels)\n", "\n", "    # e.g. 'group8' -> GroupNorm with 8 groups\n", "    elif 'group' in normalization:\n", "        num_groups = int(normalization.partition('group')[-1])\n", "        return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)\n", "\n", "\n", "class ConcatenateLayer(nn.Module):\n", "    def __init__(self):\n", "        super(ConcatenateLayer, self).__init__()\n", "\n", "    def forward(self, layer_1, layer_2):\n", "        x = torch.cat((layer_1, layer_2), 1)\n", "\n", "        return x\n", "\n", "\n", "class DownBlock(nn.Module):\n", "    def __init__(\n", "            self,\n", "            in_channels: int,\n", "            out_channels: int,\n", "            pooling: bool = True,\n", "            activation: str = 'relu',\n", "            normalization: str = None,\n", "            dim: int = 2,\n", "            convolution_mode: str = 'same'):\n", "\n", "        super().__init__()\n", "\n", "        self.in_channels = in_channels\n", "        self.out_channels = out_channels\n", "        self.pooling = pooling\n", "        self.normalization = normalization\n", "\n", "        if convolution_mode == 'same':\n", "            self.padding = 1\n", "        elif convolution_mode == 'valid':\n", "            self.padding = 0\n", "\n", "        self.dim = dim\n", "        self.activation = activation\n", "\n", "        # CONVOLUTION LAYERS\n", "        self.convolution1 = get_convolution_layer(\n", "            self.in_channels, self.out_channels, kernel_size=3,\n", "            stride=1, padding=self.padding, bias=True, dim=self.dim\n", "        )\n", "        self.convolution2 = get_convolution_layer(\n", "            self.out_channels, self.out_channels, kernel_size=3,\n", "            stride=1, padding=self.padding, bias=True, dim=self.dim\n", "        )\n", "\n", "        # POOLING LAYER\n", "        if self.pooling:\n", "            self.pool = get_maxpool_layer(kernel_size=2, stride=2, padding=0, dim=self.dim)\n", "\n", "        # ACTIVATION LAYER\n", "        self.activation1 = get_activation(self.activation)\n", "        self.activation2 = get_activation(self.activation)\n", "\n", "        # NORMALIZATION LAYERS\n", "        if self.normalization:\n", "            self.normalization1 = get_normalization(\n", "                normalization=self.normalization, num_channels=self.out_channels,\n", "                dim=self.dim\n", "            )\n", "            self.normalization2 = get_normalization(\n", "                normalization=self.normalization, num_channels=self.out_channels,\n", "                dim=self.dim\n", "            )\n", "\n", "    def forward(self, x):\n", "        y = self.convolution1(x)\n", "        y = self.activation1(y)\n", "\n", "        if self.normalization:\n", "            y = self.normalization1(y)\n", "\n", "        y = self.convolution2(y)\n", "        y = self.activation2(y)\n", "\n", "        if self.normalization:\n", "            y = self.normalization2(y)\n", "\n", "        before_pooling = y\n", "\n", "        if self.pooling:\n", "            y = self.pool(y)\n", "\n", "        return y, before_pooling\n", "\n", "\n",
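"# Shape sketch (illustrative numbers): DownBlock(4, 32) applied to a\n", "# (N, 4, 128, 128) tensor returns the pooled output (N, 32, 64, 64)\n", "# together with the pre-pooling skip tensor (N, 32, 128, 128).\n", "\n", "\n",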
"class UpBlock(nn.Module):\n", "    def __init__(self,\n", "                 in_channels: int,\n", "                 out_channels: int,\n", "                 activation: str = 'relu',\n", "                 normalization: str = None,\n", "                 dim: int = 3,\n", "                 convolution_mode: str = 'same',\n", "                 up_mode: str = 'transposed'):\n", "\n", "        super().__init__()\n", "\n", "        self.in_channels = in_channels\n", "        self.out_channels = out_channels\n", "        self.normalization = normalization\n", "\n", "        if convolution_mode == 'same':\n", "            self.padding = 1\n", "        elif convolution_mode == 'valid':\n", "            self.padding = 0\n", "\n", "        self.dim = dim\n", "        self.activation = activation\n", "        self.up_mode = up_mode\n", "\n", "        # UP-CONVOLUTION/UP-SAMPLING LAYER\n", "        self.up = get_up_layer(\n", "            self.in_channels, self.out_channels, kernel_size=2,\n", "            stride=2, dim=self.dim, up_mode=self.up_mode\n", "        )\n", "\n", "        # 1x1 convolution that reduces channels after interpolation\n", "        # upsampling (nn.Upsample keeps in_channels, while a transposed\n", "        # convolution already outputs out_channels)\n", "        self.convolution0 = get_convolution_layer(\n", "            self.in_channels, self.out_channels, kernel_size=1,\n", "            stride=1, padding=0, bias=True, dim=self.dim\n", "        )\n", "        self.convolution1 = get_convolution_layer(\n", "            2 * self.out_channels, self.out_channels, kernel_size=3,\n", "            stride=1, padding=self.padding, bias=True, dim=self.dim\n", "        )\n", "        self.convolution2 = get_convolution_layer(\n", "            self.out_channels, self.out_channels, kernel_size=3,\n", "            stride=1, padding=self.padding, bias=True, dim=self.dim\n", "        )\n", "\n", "        # ACTIVATION LAYERS\n", "        self.activation0 = get_activation(self.activation)\n", "        self.activation1 = get_activation(self.activation)\n", "        self.activation2 = get_activation(self.activation)\n", "\n", "        # NORMALIZATION LAYERS\n", "        if self.normalization:\n", "            self.normalization0 = get_normalization(\n", "                normalization=self.normalization, num_channels=self.out_channels,\n", "                dim=self.dim\n", "            )\n", "            self.normalization1 = get_normalization(\n", "                normalization=self.normalization, num_channels=self.out_channels,\n", "                dim=self.dim\n", "            )\n", "            self.normalization2 = get_normalization(\n", "                normalization=self.normalization, num_channels=self.out_channels,\n", "                dim=self.dim\n", "            )\n", "\n", "        self.concat = ConcatenateLayer()\n", "\n", "    def forward(self, encoder_layer, decoder_layer):\n", "        up_layer = self.up(decoder_layer)\n", "        cropped_encoder_layer, dec_layer = autocrop(encoder_layer, up_layer)\n", "\n", "        # Only interpolation upsampling needs the 1x1 channel-reducing\n", "        # convolution; the transposed convolution already halved the channels\n", "        if self.up_mode != 'transposed':\n", "            up_layer = self.convolution0(up_layer)\n", "\n", "        up_layer = self.activation0(up_layer)\n", "\n", "        if self.normalization:\n", "            up_layer = self.normalization0(up_layer)\n", "\n", "        merged_layer = self.concat(up_layer, cropped_encoder_layer)\n", "\n", "        y = self.convolution1(merged_layer)\n", "        y = self.activation1(y)\n", "\n", "        if self.normalization:\n", "            y = self.normalization1(y)\n", "\n", "        y = self.convolution2(y)\n", "        y = self.activation2(y)\n", "\n", "        if self.normalization:\n", "            y = self.normalization2(y)\n", "\n", "        return y\n", "\n", "\n",
dim\n", " self.up_mode = up_mode\n", "\n", " self.down_blocks = []\n", " self.up_blocks = []\n", "\n", " # ENCODER PATH CREATION\n", " for i in range(self.n_blocks):\n", " num_filters_in = self.in_channels if i == 0 else num_filters_out\n", " num_filters_out = self.start_filters * (2 ** i)\n", " pooling = True if i < self.n_blocks - 1 else False\n", "\n", " down_block = DownBlock(\n", " in_channels=num_filters_in, out_channels=num_filters_out,\n", " pooling=pooling, activation=self.activation,\n", " normalization=self.normalization, convolution_mode=self.convolution_mode,\n", " dim=self.dim\n", " )\n", "\n", " self.down_blocks.append(down_block)\n", "\n", " # DECODER PATH CREATION (NEEDS ONLY N_BLOCKS-1)\n", " for i in range(n_blocks - 1):\n", " num_filters_in = num_filters_out\n", " num_filters_out = num_filters_in // 2\n", "\n", " up_block = UpBlock(\n", " in_channels=num_filters_in, out_channels=num_filters_out,\n", " activation=self.activation, normalization=self.normalization,\n", " convolution_mode=self.convolution_mode,\n", " dim=self.dim, up_mode=self.up_mode\n", " )\n", "\n", " self.up_blocks.append(up_block)\n", "\n", " # FINAL CONVOLUTION\n", " self.convolution_final = get_convolution_layer(\n", " num_filters_out, self.out_channels,\n", " kernel_size=1, stride=1,\n", " padding=0, bias=True, dim=self.dim\n", " )\n", "\n", " # ADDING LIST OF MODULES TO CURRENT MODULE\n", " self.down_blocks = nn.ModuleList(self.down_blocks)\n", " self.up_blocks = nn.ModuleList(self.up_blocks)\n", "\n", " # WEIGHT INITIALIZATION\n", " self.initialize_parameters()\n", "\n", " @staticmethod\n", " def weight_init(module, method, **kwargs):\n", " if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):\n", " method(module.weight, **kwargs)\n", "\n", " @staticmethod\n", " def bias_init(module, method, **kwargs):\n", " if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):\n", " method(module.bias, **kwargs)\n", "\n", " def initialize_parameters(self,\n", " method_weights=nn.init.xavier_uniform_,\n", " method_bias=nn.init.zeros_,\n", " kwargs_weights={},\n", " kwargs_bias={}):\n", "\n", " for module in self.modules():\n", " self.weight_init(module, method_weights, **kwargs_weights) # initialize weights\n", " self.bias_init(module, method_bias, **kwargs_bias) # initialize bias\n", "\n", " def forward(self, x: torch.tensor):\n", " encoder_output = []\n", "\n", " # ENCODER PATHWAY\n", " for module in self.down_blocks:\n", " x, before_pooling = module(x)\n", " encoder_output.append(before_pooling)\n", "\n", " # DECODER PATHWAY\n", " for i, module in enumerate(self.up_blocks):\n", " before_pool = encoder_output[-(i + 2)]\n", " x = module(before_pool, x)\n", "\n", " x = self.convolution_final(x)\n", "\n", " return x\n", "\n", " def __repr__(self):\n", " attributes = {attr_key: self.__dict__[attr_key] for attr_key in self.__dict__.keys() if '_' not in attr_key[0] and 'training' not in attr_key}\n", " d = {self.__class__.__name__: attributes}\n", "\n", " return f'{d}'" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.501954Z", "iopub.status.idle": "2025-01-09T15:44:19.502387Z", "shell.execute_reply": "2025-01-09T15:44:19.502197Z" } }, "source": [ "MODEL = UNet(\n", " in_channels=4, out_channels=4,\n", " n_blocks=4, start_filters=32,\n", " activation='relu', normalization='batch',\n", " convolution_mode='same', dim=2\n", ")" ], "outputs": [], "execution_count": 
{ "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.503541Z", "iopub.status.idle": "2025-01-09T15:44:19.503975Z", "shell.execute_reply": "2025-01-09T15:44:19.503784Z" } }, "source": [ "# Channel 0 is the background class; it is excluded from the metrics\n", "background_channel = [0]\n", "\n", "dice_loss = smp.utils.losses.DiceLoss(activation='softmax2d')\n", "\n", "optimizer = torch.optim.Adam([\n", "    dict(params=MODEL.parameters(), lr=0.0001)\n", "])\n", "\n", "metrics = [\n", "    smp.utils.metrics.IoU(threshold=0.5, ignore_channels=background_channel, activation='softmax2d'),\n", "    smp.utils.metrics.Fscore(ignore_channels=background_channel, activation='softmax2d'),\n", "]" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": { "execution": { "iopub.status.busy": "2025-01-09T15:44:19.505175Z", "iopub.status.idle": "2025-01-09T15:44:19.505582Z", "shell.execute_reply": "2025-01-09T15:44:19.505396Z" } }, "source": [ "train_epoch = smp.utils.train.TrainEpoch(\n", "    model=MODEL, loss=dice_loss,\n", "    metrics=[], optimizer=optimizer,\n", "    device=device, verbose=True\n", ")\n", "\n", "valid_epoch = smp.utils.train.ValidEpoch(\n", "    model=MODEL, loss=dice_loss,\n", "    metrics=metrics, device=device,\n", "    verbose=True\n", ")\n", "\n", "os.makedirs('model', exist_ok=True)\n", "\n", "max_dice_score = 0\n", "\n", "stats = {\n", "    'train_loss': [],\n", "    'valid_loss': [],\n", "    'fscore': [],\n", "    'iou_score': []\n", "}\n", "\n", "for i in range(50):\n", "    print(f'\\n |--- EPOCH-{i} ---| ')\n", "    train_logs = train_epoch.run(train_loader)\n", "    valid_logs = valid_epoch.run(valid_loader)\n", "\n", "    # Keep the checkpoint with the best validation Dice (F) score\n", "    if max_dice_score < valid_logs['fscore']:\n", "        max_dice_score = valid_logs['fscore']\n", "        torch.save(MODEL.state_dict(), 'model/model.pth')\n", "\n", "        print('model saved!')\n", "\n", "    # loss statistics\n", "    stats['train_loss'].append(train_logs['dice_loss'])\n", "    stats['valid_loss'].append(valid_logs['dice_loss'])\n", "\n", "    # metric statistics\n", "    stats['fscore'].append(valid_logs['fscore'])\n", "    stats['iou_score'].append(valid_logs['iou_score'])\n", "\n", "    np.save('model/model.npy', stats)" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "STATS = np.load('model/model.npy', allow_pickle=True).item()\n", "plt.plot(STATS['train_loss'], label='train_loss')\n", "plt.plot(STATS['valid_loss'], label='valid_loss')\n", "\n", "plt.legend(loc='upper right')\n", "\n", "plt.xlabel('EPOCH')\n", "plt.ylabel('LOSS')\n", "\n", "plt.title('TRAIN & VALIDATION LOSS')" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "STATS = np.load('model/model.npy', allow_pickle=True).item()\n", "plt.plot(STATS['fscore'], label='fscore')\n", "plt.plot(STATS['iou_score'], label='iou_score')\n", "\n", "plt.legend(loc='lower right')\n", "plt.xlabel('EPOCH')\n", "plt.ylabel('SCORE')\n", "plt.title('F-SCORE & IOU SCORE')" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "MODEL.load_state_dict(torch.load('model/model.pth', weights_only=True))" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "with torch.no_grad():\n", "    out = MODEL(a.to(device))" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "# out holds raw logits of shape (batch, class, H, W); argmax over the\n", "# class dimension turns them into a per-pixel label map\n", "pred = out.cpu().argmax(dim=1)\n", "\n", "plt.figure(figsize=(18, 10))\n", "plt.subplot(1, 3, 1)\n", "plt.imshow(a[2, 0], cmap='bone')\n", "plt.title('Input Image')\n", "\n", "plt.subplot(1, 3, 2)\n", "plt.imshow(a[2, 0], cmap='bone')\n", "plt.imshow(pred[2], alpha=0.5, cmap='nipy_spectral')\n", "plt.title('Predicted Segmentation')\n", "\n", "plt.subplot(1, 3, 3)\n", "plt.imshow(pred[2], cmap='bone')\n", "plt.title('Prediction')" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "# Enhanced Data Augmentation\n", "from albumentations import Compose, RandomCrop, ElasticTransform, GridDistortion, OpticalDistortion, RandomBrightnessContrast, GaussNoise, Flip\n", "\n", "def get_augmentation_pipeline():\n", "    return Compose([\n", "        Flip(p=0.5),\n", "        RandomCrop(height=128, width=128, p=0.5),\n", "        ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),\n", "        GridDistortion(p=0.5),\n", "        OpticalDistortion(p=0.5),\n", "        GaussNoise(p=0.5),\n", "        RandomBrightnessContrast(p=0.5)\n", "    ])\n", "\n", "augmentation_pipeline = get_augmentation_pipeline()" ], "outputs": [], "execution_count": null },
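{ "cell_type": "markdown", "metadata": {}, "source": [ "A usage sketch for the pipeline above (albumentations expects numpy arrays; the slice index is arbitrary). Note that, unlike the training pipeline, this one declares no `additional_targets`, so it handles a single image/mask pair:" ] },
{ "cell_type": "code", "metadata": {}, "source": [ "_img = np.load('dataset/flair/image_25.npy').astype(np.float32)\n", "_msk = np.load('dataset/seg/image_25.npy').astype(np.float32)\n", "\n", "_aug = augmentation_pipeline(image=_img, mask=_msk)\n", "print(_aug['image'].shape, _aug['mask'].shape)" ], "outputs": [], "execution_count": null },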
"\n", "plt.subplot(1, 3, 2)\n", "plt.imshow(a[2, 0],cmap='bone')\n", "plt.imshow(out.cpu()[2, 0], alpha = 0.5, cmap = 'nipy_spectral')\n", "plt.title('Predicted Segmentation')\n", "\n", "plt.subplot(1, 3, 3)\n", "plt.imshow(out.cpu()[2, 0], cmap = 'bone')\n", "plt.title('Prediction')" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "\n", "# Enhanced Data Augmentation\n", "from albumentations import Compose, RandomCrop, ElasticTransform, GridDistortion, OpticalDistortion, RandomBrightnessContrast, GaussNoise, Flip\n", "\n", "def get_augmentation_pipeline():\n", " return Compose([\n", " Flip(p=0.5),\n", " RandomCrop(height=128, width=128, p=0.5),\n", " ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),\n", " GridDistortion(p=0.5),\n", " OpticalDistortion(p=0.5),\n", " GaussNoise(p=0.5),\n", " RandomBrightnessContrast(p=0.5)\n", " ])\n", "\n", "augmentation_pipeline = get_augmentation_pipeline()\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "\n", "# Switching to Attention U-Net / UNet++ with Pre-trained Encoders\n", "import segmentation_models_pytorch as smp\n", "\n", "# Define a UNet++ with a ResNet34 encoder pre-trained on ImageNet\n", "model = smp.UnetPlusPlus(\n", " encoder_name=\"resnet34\",\n", " encoder_weights=\"imagenet\",\n", " in_channels=4,\n", " classes=4\n", ")\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "\n", "# Improved Loss Function\n", "import torch.nn as nn\n", "from segmentation_models_pytorch.losses import TverskyLoss\n", "\n", "# Combine Dice Loss and Tversky Loss\n", "class CombinedLoss(nn.Module):\n", " def __init__(self, alpha=0.5):\n", " super(CombinedLoss, self).__init__()\n", " self.dice_loss = smp.losses.DiceLoss(\"softmax\")\n", " self.tversky_loss = TverskyLoss(\"softmax\", alpha=0.7, beta=0.3)\n", " self.alpha = alpha\n", "\n", " def forward(self, y_pred, y_true):\n", " return self.alpha * self.dice_loss(y_pred, y_true) + (1 - self.alpha) * self.tversky_loss(y_pred, y_true)\n", "\n", "loss_fn = CombinedLoss()\n" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": {}, "source": [ "from sklearn.svm._liblinear import train_wrap\n", "\n", "num_epochs = 50\n", "\n", "# Learning Rate Scheduling\n", "from torch.optim.lr_scheduler import CosineAnnealingLR\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5) # Cosine Annealing\n", "\n", "# Update the scheduler in each epoch\n", "for epoch in range(num_epochs):\n", " train_wrap(...) 
{ "cell_type": "code", "metadata": {}, "source": [ "num_epochs = 50\n", "\n", "# Learning Rate Scheduling\n", "from torch.optim.lr_scheduler import CosineAnnealingLR\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)  # Cosine Annealing\n", "\n", "# Step the scheduler once per epoch\n", "for epoch in range(num_epochs):\n", "    # train_one_epoch(...)  # placeholder: run one training epoch here\n", "    scheduler.step()" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "# Post-Processing with CRF\n", "import pydensecrf.densecrf as dcrf\n", "\n", "def apply_crf(prob_map, img):\n", "    # prob_map: class probabilities, shape (4, H, W); img: uint8 RGB image, shape (H, W, 3)\n", "    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 4)  # 4 is the number of classes\n", "\n", "    # The unary energy must be a C-contiguous float32 array of shape (n_classes, n_pixels)\n", "    U = -np.log(prob_map.reshape(4, -1) + 1e-8)\n", "    d.setUnaryEnergy(np.ascontiguousarray(U, dtype=np.float32))\n", "\n", "    # Add pairwise terms\n", "    d.addPairwiseGaussian(sxy=3, compat=3)\n", "    d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=img, compat=10)\n", "\n", "    Q = d.inference(5)  # Number of iterations\n", "\n", "    return np.argmax(Q, axis=0).reshape((img.shape[0], img.shape[1]))" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "# Cross-Validation\n", "from sklearn.model_selection import KFold\n", "from torch.utils.data import Subset\n", "\n", "# 'dataset' and 'train_model' are placeholders: substitute the full\n", "# DatasetGenerator instance and a per-fold training routine\n", "kf = KFold(n_splits=5)\n", "for train_idx, valid_idx in kf.split(dataset):\n", "    train_data = Subset(dataset, train_idx)\n", "    valid_data = Subset(dataset, valid_idx)\n", "\n", "    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)\n", "    valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False)\n", "\n", "    train_model(train_loader, valid_loader)" ], "outputs": [], "execution_count": null },
{ "cell_type": "code", "metadata": {}, "source": [ "# Ensemble Learning\n", "class EnsembleModel(nn.Module):\n", "    def __init__(self, models):\n", "        super(EnsembleModel, self).__init__()\n", "        self.models = nn.ModuleList(models)\n", "\n", "    def forward(self, x):\n", "        outputs = [model(x) for model in self.models]\n", "        return torch.mean(torch.stack(outputs), dim=0)\n", "\n", "# Combine multiple trained models (model1, model2, model3 are placeholders\n", "# for separately trained networks)\n", "models = [model1, model2, model3]\n", "ensemble_model = EnsembleModel(models)" ], "outputs": [], "execution_count": null } ], "metadata": { "kaggle": { "accelerator": "nvidiaTeslaT4", "dataSources": [ { "datasetId": 723383, "sourceId": 1267593, "sourceType": "datasetVersion" }, { "datasetId": 751906, "sourceId": 1299795, "sourceType": "datasetVersion" } ], "dockerImageVersionId": 30823, "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }