darkPhantomX committed on
Commit 999c545 · verified · 1 Parent(s): 856acc8

Upload 12 files

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Samson
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
brats_pretrained.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
brats_scratch-temp-modified.ipynb ADDED
@@ -0,0 +1,1433 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "metadata": {
6
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
7
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
8
+ "execution": {
9
+ "iopub.execute_input": "2025-01-09T16:36:25.227597Z",
10
+ "iopub.status.busy": "2025-01-09T16:36:25.227303Z",
11
+ "iopub.status.idle": "2025-01-09T16:36:35.081281Z",
12
+ "shell.execute_reply": "2025-01-09T16:36:35.080659Z",
13
+ "shell.execute_reply.started": "2025-01-09T16:36:25.227573Z"
14
+ }
15
+ },
16
+ "source": [
17
+ "import segmentation_models_pytorch as smp\n",
18
+ "import os\n",
19
+ "import matplotlib.pyplot as plt\n",
20
+ "from PIL import Image\n",
21
+ "import numpy as np\n",
22
+ "import torch\n",
23
+ "from torch.fx.experimental.meta_tracer import torch_abs_override\n",
24
+ "from torch.utils.data import Dataset, DataLoader\n",
25
+ "from torchvision import transforms, utils\n",
26
+ "import torch.nn as nn\n",
27
+ "import torch.optim as optim\n",
28
+ "from torch.optim import lr_scheduler\n",
29
+ "import time\n",
30
+ "import albumentations as Album\n",
31
+ "import torch.nn.functional as Functional\n",
32
+ "import pandas as pd\n",
33
+ "import nibabel as nib\n",
34
+ "from tqdm import tqdm"
35
+ ],
36
+ "outputs": [],
37
+ "execution_count": null
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "metadata": {},
42
+ "source": [
43
+ "! pip show albumentations"
44
+ ],
45
+ "outputs": [],
46
+ "execution_count": null
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "metadata": {
51
+ "execution": {
52
+ "iopub.execute_input": "2025-01-09T16:36:48.479196Z",
53
+ "iopub.status.busy": "2025-01-09T16:36:48.478879Z",
54
+ "iopub.status.idle": "2025-01-09T16:36:48.500028Z",
55
+ "shell.execute_reply": "2025-01-09T16:36:48.499404Z",
56
+ "shell.execute_reply.started": "2025-01-09T16:36:48.479170Z"
57
+ }
58
+ },
59
+ "source": [
60
+ "training_df = pd.read_csv('data/archive/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/name_mapping.csv')\n",
61
+ "root_df = 'data/archive/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'"
62
+ ],
63
+ "outputs": [],
64
+ "execution_count": null
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "metadata": {
69
+ "execution": {
70
+ "iopub.execute_input": "2025-01-09T16:36:51.384165Z",
71
+ "iopub.status.busy": "2025-01-09T16:36:51.383835Z",
72
+ "iopub.status.idle": "2025-01-09T16:36:51.401352Z",
73
+ "shell.execute_reply": "2025-01-09T16:36:51.400713Z",
74
+ "shell.execute_reply.started": "2025-01-09T16:36:51.384140Z"
75
+ }
76
+ },
77
+ "source": [
78
+ "training_df.head(10)"
79
+ ],
80
+ "outputs": [],
81
+ "execution_count": null
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {},
86
+ "source": [
87
+ "Exporting CSV Files to be used as reference for MRI Imaging files (.nii) to their respective file paths"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "metadata": {
93
+ "execution": {
94
+ "iopub.execute_input": "2025-01-09T16:36:57.780114Z",
95
+ "iopub.status.busy": "2025-01-09T16:36:57.779827Z",
96
+ "iopub.status.idle": "2025-01-09T16:36:59.207480Z",
97
+ "shell.execute_reply": "2025-01-09T16:36:59.206793Z",
98
+ "shell.execute_reply.started": "2025-01-09T16:36:57.780094Z"
99
+ }
100
+ },
101
+ "source": [
102
+ "root_list = []\n",
103
+ "tot_list = []\n",
104
+ "\n",
105
+ "for filename_root in tqdm(np.sort(os.listdir(root_df))[:-2]):\n",
106
+ " subpath = os.path.join(root_df, filename_root)\n",
107
+ " file_list = []\n",
108
+ "\n",
109
+ " for filename in np.sort(os.listdir(subpath)):\n",
110
+ " file_list.append(os.path.join(subpath, filename))\n",
111
+ "\n",
112
+ " root_list.append(filename_root)\n",
113
+ " tot_list.append(file_list)\n",
114
+ " \n",
115
+ "maps = pd.concat(\n",
116
+ " [pd.DataFrame(root_list, columns=['DIR']),\n",
117
+ " pd.DataFrame(tot_list, columns=['flair', 'seg', 't1', 't1ce', 't2']) \n",
118
+ "], axis=1)\n",
119
+ "\n",
120
+ "maps.to_csv('links.csv', index=False)"
121
+ ],
122
+ "outputs": [],
123
+ "execution_count": null
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "metadata": {
128
+ "execution": {
129
+ "iopub.execute_input": "2025-01-09T16:37:07.946953Z",
130
+ "iopub.status.busy": "2025-01-09T16:37:07.946665Z",
131
+ "iopub.status.idle": "2025-01-09T16:37:07.955468Z",
132
+ "shell.execute_reply": "2025-01-09T16:37:07.954634Z",
133
+ "shell.execute_reply.started": "2025-01-09T16:37:07.946934Z"
134
+ }
135
+ },
136
+ "source": [
137
+ "image_path = {\n",
138
+ " 'seg': [],\n",
139
+ " 't1': [],\n",
140
+ " 't1ce': [],\n",
141
+ " 't2': [],\n",
142
+ " 'flair': []\n",
143
+ "}\n",
144
+ "\n",
145
+ "for path in training_df['BraTS_2020_subject_ID']:\n",
146
+ " patient = os.path.join(root_df, path)\n",
147
+ "\n",
148
+ " for name in image_path:\n",
149
+ " image_path[name].append(os.path.join(patient, path + f'_{name}.nii'))\n",
150
+ "\n",
151
+ "image_path['seg'][:5]"
152
+ ],
153
+ "outputs": [],
154
+ "execution_count": null
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "metadata": {
159
+ "execution": {
160
+ "iopub.execute_input": "2025-01-09T16:37:15.635134Z",
161
+ "iopub.status.busy": "2025-01-09T16:37:15.634853Z",
162
+ "iopub.status.idle": "2025-01-09T16:37:15.640048Z",
163
+ "shell.execute_reply": "2025-01-09T16:37:15.639143Z",
164
+ "shell.execute_reply.started": "2025-01-09T16:37:15.635113Z"
165
+ }
166
+ },
167
+ "source": [
168
+ "def load_image(image_path):\n",
169
+ " return nib.load(image_path).get_fdata()\n",
170
+ "\n",
171
+ "\n",
172
+ "def ccentre(image_slice, crop_x, crop_y):\n",
173
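+ " # crop a (crop_y, crop_x) window from the centre of the slice\n",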
+ " y, x = image_slice.shape\n",
174
+ "\n",
175
+ " start_x = x // 2 - (crop_x // 2)\n",
176
+ " start_y = y // 2 - (crop_y // 2)\n",
177
+ "\n",
178
+ " return image_slice[start_y : start_y + crop_y, start_x : start_x + crop_x]\n",
179
+ "\n",
180
+ "\n",
181
+ "def normalize(image_slice):\n",
182
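+ " # z-score normalise the slice (zero mean, unit variance)\n",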
+ " return (image_slice - image_slice.mean()) / image_slice.std()"
183
+ ],
184
+ "outputs": [],
185
+ "execution_count": null
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "metadata": {
190
+ "execution": {
191
+ "iopub.execute_input": "2025-01-09T16:37:23.487997Z",
192
+ "iopub.status.busy": "2025-01-09T16:37:23.487694Z",
193
+ "iopub.status.idle": "2025-01-09T16:37:24.301565Z",
194
+ "shell.execute_reply": "2025-01-09T16:37:24.300420Z",
195
+ "shell.execute_reply.started": "2025-01-09T16:37:23.487971Z"
196
+ }
197
+ },
198
+ "source": [
199
+ "def create_dataset_directories(base_dir=\"dataset\"):\n",
200
+ " os.makedirs(os.path.join(base_dir, \"t1\"), exist_ok=True)\n",
201
+ " os.makedirs(os.path.join(base_dir, \"t1ce\"), exist_ok=True)\n",
202
+ " os.makedirs(os.path.join(base_dir, \"t2\"), exist_ok=True)\n",
203
+ " os.makedirs(os.path.join(base_dir, \"flair\"), exist_ok=True)\n",
204
+ " os.makedirs(os.path.join(base_dir, \"seg\"), exist_ok=True)"
205
+ ],
206
+ "outputs": [],
207
+ "execution_count": null
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "metadata": {},
212
+ "source": [
213
+ "create_dataset_directories('dataset')\n",
214
+ "# Save the stress because the directory already exists"
215
+ ],
216
+ "outputs": [],
217
+ "execution_count": null
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "metadata": {
222
+ "execution": {
223
+ "iopub.execute_input": "2025-01-09T16:37:51.309665Z",
224
+ "iopub.status.busy": "2025-01-09T16:37:51.309191Z",
225
+ "iopub.status.idle": "2025-01-09T16:39:04.326289Z",
226
+ "shell.execute_reply": "2025-01-09T16:39:04.325310Z",
227
+ "shell.execute_reply.started": "2025-01-09T16:37:51.309625Z"
228
+ }
229
+ },
230
+ "source": [
231
+ "images_saved = 0\n",
232
+ "images = {}\n",
233
+ "image_slice = {}\n",
234
+ "\n",
235
+ "save_limit = 5000\n",
236
+ "\n",
237
+ "for i in (range(len(image_path['seg']))):\n",
238
+ " \n",
239
+ " for name in image_path:\n",
240
+ " images[name] = load_image(image_path[name][i])\n",
241
+ "\n",
242
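+ " # each BraTS volume has 155 axial slices\n",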
+ " for j in range(155):\n",
243
+ " for name in images:\n",
244
+ " image_slice[name] = images[name][:, :, j]\n",
245
+ " image_slice[name] = ccentre(image_slice[name], 128, 128)\n",
246
+ "\n",
247
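+ " # keep only slices that actually contain tumour (non-zero segmentation labels)\n",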
+ " if image_slice['seg'].max() > 0:\n",
248
+ " for name in ['t1', 't2', 't1ce', 'flair']:\n",
249
+ " image_slice[name] = normalize(image_slice[name])\n",
250
+ "\n",
251
+ " for name in image_slice:\n",
252
+ " np.save(f'dataset/{name}/image_{images_saved}.npy', image_slice[name])\n",
253
+ "\n",
254
+ " images_saved += 1\n",
255
+ "\n",
256
+ " if images_saved == save_limit:\n",
257
+ " break\n",
258
+ "\n",
259
+ " if images_saved == save_limit:\n",
260
+ " break"
261
+ ],
262
+ "outputs": [],
263
+ "execution_count": null
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "metadata": {
268
+ "execution": {
269
+ "iopub.execute_input": "2025-01-09T16:40:00.898802Z",
270
+ "iopub.status.busy": "2025-01-09T16:40:00.898500Z",
271
+ "iopub.status.idle": "2025-01-09T16:40:00.902420Z",
272
+ "shell.execute_reply": "2025-01-09T16:40:00.901607Z",
273
+ "shell.execute_reply.started": "2025-01-09T16:40:00.898781Z"
274
+ }
275
+ },
276
+ "source": [
277
+ "# SOME BASIC IMAGE VISUALIZATIONS"
278
+ ],
279
+ "outputs": [],
280
+ "execution_count": null
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "metadata": {
285
+ "execution": {
286
+ "iopub.execute_input": "2025-01-09T16:40:06.557314Z",
287
+ "iopub.status.busy": "2025-01-09T16:40:06.556901Z",
288
+ "iopub.status.idle": "2025-01-09T16:40:07.667075Z",
289
+ "shell.execute_reply": "2025-01-09T16:40:07.666168Z",
290
+ "shell.execute_reply.started": "2025-01-09T16:40:06.557279Z"
291
+ }
292
+ },
293
+ "source": [
294
+ "fig = plt.figure(figsize = (24, 15))\n",
295
+ "\n",
296
+ "plt.subplot(1, 5, 1)\n",
297
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap='bone')\n",
298
+ "plt.title('Original')\n",
299
+ "\n",
300
+ "plt.subplot(1, 5, 2)\n",
301
+ "plt.imshow(np.load('dataset/seg/image_25.npy'), cmap='bone')\n",
302
+ "plt.title('Segment')\n",
303
+ "\n",
304
+ "plt.subplot(1, 5, 3)\n",
305
+ "plt.imshow(np.load('dataset/t1/image_25.npy'), cmap='bone')\n",
306
+ "plt.title('T1')\n",
307
+ "\n",
308
+ "plt.subplot(1, 5, 4)\n",
309
+ "plt.imshow(np.load('dataset/t1ce/image_25.npy'), cmap='bone')\n",
310
+ "plt.title('T1CE')\n",
311
+ "\n",
312
+ "plt.subplot(1, 5, 5)\n",
313
+ "plt.imshow(np.load('dataset/t2/image_25.npy'), cmap='bone')\n",
314
+ "plt.title('T2')"
315
+ ],
316
+ "outputs": [],
317
+ "execution_count": null
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "metadata": {
322
+ "execution": {
323
+ "iopub.execute_input": "2025-01-09T16:40:14.432473Z",
324
+ "iopub.status.busy": "2025-01-09T16:40:14.432179Z",
325
+ "iopub.status.idle": "2025-01-09T16:40:14.436037Z",
326
+ "shell.execute_reply": "2025-01-09T16:40:14.435102Z",
327
+ "shell.execute_reply.started": "2025-01-09T16:40:14.432449Z"
328
+ }
329
+ },
330
+ "source": [
331
+ "# WITH SOME COLOUR..."
332
+ ],
333
+ "outputs": [],
334
+ "execution_count": null
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "metadata": {
339
+ "execution": {
340
+ "iopub.execute_input": "2025-01-09T16:40:15.200141Z",
341
+ "iopub.status.busy": "2025-01-09T16:40:15.199879Z",
342
+ "iopub.status.idle": "2025-01-09T16:40:16.473796Z",
343
+ "shell.execute_reply": "2025-01-09T16:40:16.472822Z",
344
+ "shell.execute_reply.started": "2025-01-09T16:40:15.200120Z"
345
+ }
346
+ },
347
+ "source": [
348
+ "fig = plt.figure(figsize = (24, 15))\n",
349
+ "\n",
350
+ "plt.subplot(1, 5, 1)\n",
351
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n",
352
+ "plt.title('Original')\n",
353
+ "\n",
354
+ "plt.subplot(1, 5, 2)\n",
355
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n",
356
+ "plt.imshow(np.load('dataset/seg/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n",
357
+ "plt.title('Segment')\n",
358
+ "\n",
359
+ "plt.subplot(1, 5, 3)\n",
360
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n",
361
+ "plt.imshow(np.load('dataset/t1/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n",
362
+ "plt.title('T1')\n",
363
+ "\n",
364
+ "plt.subplot(1, 5, 4)\n",
365
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n",
366
+ "plt.imshow(np.load('dataset/t1ce/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n",
367
+ "plt.title('T1CE')\n",
368
+ "\n",
369
+ "plt.subplot(1, 5, 5)\n",
370
+ "plt.imshow(np.load('dataset/flair/image_25.npy'), cmap = 'bone')\n",
371
+ "plt.imshow(np.load('dataset/t2/image_25.npy'), alpha=0.5, cmap='nipy_spectral')\n",
372
+ "plt.title('T2')"
373
+ ],
374
+ "outputs": [],
375
+ "execution_count": null
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "metadata": {
380
+ "execution": {
381
+ "iopub.execute_input": "2025-01-09T16:40:47.809814Z",
382
+ "iopub.status.busy": "2025-01-09T16:40:47.809476Z",
383
+ "iopub.status.idle": "2025-01-09T16:40:47.817498Z",
384
+ "shell.execute_reply": "2025-01-09T16:40:47.816414Z",
385
+ "shell.execute_reply.started": "2025-01-09T16:40:47.809789Z"
386
+ }
387
+ },
388
+ "source": [
389
+ "class DatasetGenerator(Dataset):\n",
390
+ " def __init__(self, datapath='dataset/', augmentation=None):\n",
391
+ " self.augmentation = augmentation\n",
392
+ "\n",
393
+ " self.folderpaths = {\n",
394
+ " 'mask': os.path.join(datapath, 'seg/'),\n",
395
+ " 't1': os.path.join(datapath, 't1/'),\n",
396
+ " 't1ce': os.path.join(datapath, 't1ce/'),\n",
397
+ " 't2': os.path.join(datapath, 't2/'),\n",
398
+ " 'flair': os.path.join(datapath, 'flair/'),\n",
399
+ " }\n",
400
+ "\n",
401
+ " def __getitem__(self, index):\n",
402
+ " images = {}\n",
403
+ "\n",
404
+ " for name in self.folderpaths:\n",
405
+ " images[name] = np.load(os.path.join(self.folderpaths[name], f'image_{index}.npy')).astype(np.float32)\n",
406
+ "\n",
407
+ " # print(f\"Loaded images for index {index}: {images.keys()}\")\n",
408
+ " \n",
409
+ " if self.augmentation:\n",
410
+ " augmented = self.augmentation(\n",
411
+ " image=images['flair'],\n",
412
+ " mask=images['mask'],\n",
413
+ " t1=images['t1'],\n",
414
+ " t1ce=images['t1ce'],\n",
415
+ " t2=images['t2']\n",
416
+ " )\n",
417
+ " # print(f\"Augmented images for index {index}: {augmented.keys()}\")\n",
418
+ " images['flair'] = augmented['image']\n",
419
+ " images['mask'] = augmented['mask']\n",
420
+ " images['t1'] = augmented['t1']\n",
421
+ " images['t1ce'] = augmented['t1ce']\n",
422
+ " images['t2'] = augmented['t2']\n",
423
+ "\n",
424
+ " for name in images:\n",
425
+ " images[name] = torch.from_numpy(images[name])\n",
426
+ "\n",
427
+ " # STACKING UP MULTI INPUTS\n",
428
+ " input = torch.stack([\n",
429
+ " images['t1'],\n",
430
+ " images['t1ce'],\n",
431
+ " images['t2'],\n",
432
+ " images['flair']\n",
433
+ " ], dim=0)\n",
434
+ "\n",
435
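+ " # BraTS masks use labels {0, 1, 2, 4}; remap 4 -> 3 so class indices are contiguous for one-hot encoding\n",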
+ " images['mask'][images['mask'] == 4] = 3\n",
436
+ "\n",
437
+ " # ONE-HOT TRUTH LABEL ENCODING\n",
438
+ " images['mask'] = Functional.one_hot(\n",
439
+ " images['mask'].long().unsqueeze(0),\n",
440
+ " num_classes=4\n",
441
+ " ).permute(0, 3, 1, 2).contiguous().squeeze(0)\n",
442
+ "\n",
443
+ " return input.float(), images['mask'].long()\n",
444
+ "\n",
445
+ " def __len__(self):\n",
446
+ " return len(os.listdir(self.folderpaths['mask'])) - 1"
447
+ ],
448
+ "outputs": [],
449
+ "execution_count": null
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "metadata": {
454
+ "execution": {
455
+ "iopub.execute_input": "2025-01-09T16:40:52.376269Z",
456
+ "iopub.status.busy": "2025-01-09T16:40:52.375862Z",
457
+ "iopub.status.idle": "2025-01-09T16:40:52.404458Z",
458
+ "shell.execute_reply": "2025-01-09T16:40:52.403612Z",
459
+ "shell.execute_reply.started": "2025-01-09T16:40:52.376234Z"
460
+ }
461
+ },
462
+ "source": [
463
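+ "# additional_targets (below) tells albumentations to apply the same transform to t1, t1ce and t2 as to the main image (flair)\n",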
+ "augmentation = Album.Compose([\n",
464
+ " Album.OneOf([\n",
465
+ " Album.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),\n",
466
+ " Album.GridDistortion(p=0.5),\n",
467
+ " Album.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=0.5)\n",
468
+ "\n",
469
+ " ], p=0.8),\n",
470
+ " Album.RandomBrightnessContrast(p=0.8),\n",
471
+ "\n",
472
+ " # Added classes for enhanced data augmentations\n",
473
+ " #Album.Rotate(limit=45, p=0.8),\n",
474
+ " #Album.HorizontalFlip(p=0.8),\n",
475
+ " #Album.VerticalFlip(p=0.8),\n",
476
+ " #Album.GaussNoise(p=0.5)\n",
477
+ "\n",
478
+ "], additional_targets={\n",
479
+ " 't1': 'image',\n",
480
+ " 't1ce': 'image',\n",
481
+ " 't2': 'image'\n",
482
+ "})\n",
483
+ "\n",
484
+ "\n",
485
+ "valid_test_dataset = DatasetGenerator(datapath='dataset/', augmentation=None)\n",
486
+ "train_dataset = DatasetGenerator(datapath='dataset/', augmentation=augmentation)\n",
487
+ "\n",
488
+ "# USING A 4:1:1 train-validation-test\n",
489
+ "train_length = int(0.6 * len(valid_test_dataset))\n",
490
+ "valid_length = int(0.2 * len(valid_test_dataset))\n",
491
+ "test_length = len(valid_test_dataset) - train_length - valid_length\n",
492
+ "\n",
493
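+ "# both splits use manual_seed(42), so the augmented and non-augmented datasets are partitioned identically\n",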
+ "_, valid_dataset, test_dataset = torch.utils.data.random_split(\n",
494
+ " valid_test_dataset,\n",
495
+ " (train_length, valid_length, test_length), generator=torch.Generator().manual_seed(42)\n",
496
+ ")\n",
497
+ "\n",
498
+ "train_dataset, _, _ = torch.utils.data.random_split(\n",
499
+ " train_dataset,\n",
500
+ " (train_length, valid_length, test_length), generator=torch.Generator().manual_seed(42)\n",
501
+ ")"
502
+ ],
503
+ "outputs": [],
504
+ "execution_count": null
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "metadata": {
509
+ "execution": {
510
+ "iopub.execute_input": "2025-01-09T16:41:01.714186Z",
511
+ "iopub.status.busy": "2025-01-09T16:41:01.713852Z",
512
+ "iopub.status.idle": "2025-01-09T16:41:01.719951Z",
513
+ "shell.execute_reply": "2025-01-09T16:41:01.719031Z",
514
+ "shell.execute_reply.started": "2025-01-09T16:41:01.714157Z"
515
+ }
516
+ },
517
+ "source": [
518
+ "train_loader = DataLoader(\n",
519
+ " train_dataset, batch_size=16,\n",
520
+ " num_workers=0, shuffle=True\n",
521
+ ")\n",
522
+ "\n",
523
+ "valid_loader = DataLoader(\n",
524
+ " valid_dataset, batch_size=1,\n",
525
+ " num_workers=0, shuffle=True\n",
526
+ ")\n",
527
+ "\n",
528
+ "test_loader = DataLoader(\n",
529
+ " test_dataset, batch_size=1,\n",
530
+ " num_workers=2, shuffle=True\n",
531
+ ")"
532
+ ],
533
+ "outputs": [],
534
+ "execution_count": null
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "metadata": {},
539
+ "source": [
540
+ "print(len(train_loader))\n",
541
+ "print(len(test_loader))\n",
542
+ "print(len(valid_loader))"
543
+ ],
544
+ "outputs": [],
545
+ "execution_count": null
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "metadata": {
550
+ "execution": {
551
+ "iopub.execute_input": "2025-01-09T16:41:04.716492Z",
552
+ "iopub.status.busy": "2025-01-09T16:41:04.716204Z",
553
+ "iopub.status.idle": "2025-01-09T16:41:05.078974Z",
554
+ "shell.execute_reply": "2025-01-09T16:41:05.077171Z",
555
+ "shell.execute_reply.started": "2025-01-09T16:41:04.716472Z"
556
+ }
557
+ },
558
+ "source": [
559
+ "a, b = next(iter(train_loader))"
560
+ ],
561
+ "outputs": [],
562
+ "execution_count": null
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "metadata": {
567
+ "execution": {
568
+ "iopub.execute_input": "2025-01-09T16:17:05.731880Z",
569
+ "iopub.status.busy": "2025-01-09T16:17:05.731375Z",
570
+ "iopub.status.idle": "2025-01-09T16:17:05.752440Z",
571
+ "shell.execute_reply": "2025-01-09T16:17:05.750970Z",
572
+ "shell.execute_reply.started": "2025-01-09T16:17:05.731822Z"
573
+ }
574
+ },
575
+ "source": [
576
+ "plt.imshow(a[0, 0], cmap='gray')"
577
+ ],
578
+ "outputs": [],
579
+ "execution_count": null
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "metadata": {
584
+ "execution": {
585
+ "iopub.status.busy": "2025-01-09T15:44:19.497446Z",
586
+ "iopub.status.idle": "2025-01-09T15:44:19.497913Z",
587
+ "shell.execute_reply": "2025-01-09T15:44:19.497700Z"
588
+ }
589
+ },
590
+ "source": [
591
+ "temp = torch.argmax(b, 0)\n",
592
+ "plt.imshow(temp[0], cmap='gray')"
593
+ ],
594
+ "outputs": [],
595
+ "execution_count": null
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "metadata": {},
600
+ "source": [
601
+ "! nvidia-smi"
602
+ ],
603
+ "outputs": [],
604
+ "execution_count": null
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "metadata": {
609
+ "execution": {
610
+ "iopub.status.busy": "2025-01-09T15:44:19.498903Z",
611
+ "iopub.status.idle": "2025-01-09T15:44:19.499326Z",
612
+ "shell.execute_reply": "2025-01-09T15:44:19.499132Z"
613
+ }
614
+ },
615
+ "source": [
616
+ "# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
617
+ "print(torch.cuda.is_available())\n",
618
+ "print(f'* CUDA Device: {torch.cuda.get_device_name(\"cuda:0\")}\\n* Device Properties: {torch.cuda.get_device_properties(\"cuda:0\")}')\n",
619
+ "\n",
620
+ "# device = torch.cuda.device(0)\n",
621
+ "device = torch.device('cuda:0')"
622
+ ],
623
+ "outputs": [],
624
+ "execution_count": null
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "metadata": {
629
+ "execution": {
630
+ "iopub.status.busy": "2025-01-09T15:44:19.500315Z",
631
+ "iopub.status.idle": "2025-01-09T15:44:19.500623Z",
632
+ "shell.execute_reply": "2025-01-09T15:44:19.500501Z"
633
+ }
634
+ },
635
+ "source": [
636
+ "import torch\n",
637
+ "import torch.nn as nn\n",
638
+ "\n",
639
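+ "# centre-crop the encoder feature map so its spatial size matches the upsampled decoder feature map before concatenation\n",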
+ "@torch.jit.script\n",
640
+ "def autocrop(encoder_layer: torch.Tensor, decoder_layer: torch.Tensor):\n",
641
+ " if encoder_layer.shape[2:] != decoder_layer.shape[2:]:\n",
642
+ " ds = encoder_layer.shape[2:]\n",
643
+ " es = decoder_layer.shape[2:]\n",
644
+ "\n",
645
+ " assert ds[0] >= es[0]\n",
646
+ " assert ds[1] >= es[1]\n",
647
+ "\n",
648
+ " # IN CASES OF 2D FORMAT\n",
649
+ " if encoder_layer.dim() == 4:\n",
650
+ " encoder_layer = encoder_layer[\n",
651
+ " :, :, \n",
652
+ " ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),\n",
653
+ " ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2)\n",
654
+ " ]\n",
655
+ "\n",
656
+ " # IN CASES OF 3D FORMATS\n",
657
+ " elif encoder_layer.dim() == 5:\n",
658
+ " assert ds[2] >= es[2]\n",
659
+ "\n",
660
+ " encoder_layer = encoder_layer[\n",
661
+ " :, :, \n",
662
+ " ((ds[0] - es[0]) // 2) : ((ds[0] + es[0]) // 2),\n",
663
+ " ((ds[1] - es[1]) // 2) : ((ds[1] + es[1]) // 2),\n",
664
+ " ((ds[2] - es[2]) // 2) : ((ds[2] + es[2]) // 2)\n",
665
+ " ]\n",
666
+ "\n",
667
+ " return encoder_layer, decoder_layer\n",
668
+ " \n",
669
+ " else: \n",
670
+ " return encoder_layer, decoder_layer\n",
671
+ "\n",
672
+ "\n",
673
+ "def convolution_layer(dim: int):\n",
674
+ " if dim == 3: \n",
675
+ " return nn.Conv3d\n",
676
+ " elif dim == 2:\n",
677
+ " return nn.Conv2d\n",
678
+ "\n",
679
+ "\n",
680
+ "def get_convolution_layer(\n",
681
+ " in_channels: int, out_channels: int,\n",
682
+ " kernel_size: int = 3, stride: int = 1,\n",
683
+ " padding: int = 1, bias: bool = True, dim: int = 2):\n",
684
+ "\n",
685
+ " return convolution_layer(dim)(in_channels, out_channels, kernel_size=kernel_size,\n",
686
+ " stride=stride, padding=padding, bias=bias)\n",
687
+ "\n",
688
+ "\n",
689
+ "def convolution_transpose_layer(dim: int):\n",
690
+ " if dim == 3:\n",
691
+ " return nn.ConvTranspose3d\n",
692
+ " elif dim == 2:\n",
693
+ " return nn.ConvTranspose2d\n",
694
+ "\n",
695
+ "\n",
696
+ "def get_up_layer(\n",
697
+ " in_channels: int, out_channels: int,\n",
698
+ " kernel_size: int = 2, stride: int = 2,\n",
699
+ " dim: int = 3, up_mode: str = 'transposed'):\n",
700
+ "\n",
701
+ " if up_mode == 'transposed':\n",
702
+ " return convolution_transpose_layer(dim)(in_channels, out_channels, \n",
703
+ " kernel_size=kernel_size, stride=stride)\n",
704
+ " else:\n",
705
+ " return nn.Upsample(scale_factor=2.0, mode=up_mode)\n",
706
+ "\n",
707
+ "\n",
708
+ "def maxpool_layer(dim: int):\n",
709
+ " if dim == 3:\n",
710
+ " return nn.MaxPool3d\n",
711
+ " elif dim == 2:\n",
712
+ " return nn.MaxPool2d\n",
713
+ "\n",
714
+ "\n",
715
+ "def get_maxpool_layer(kernel_size: int = 2, stride: int = 2, padding: int = 0, dim: int = 2):\n",
716
+ " return maxpool_layer(dim=dim)(kernel_size=kernel_size, stride=stride, padding=padding)\n",
717
+ "\n",
718
+ "# LeakyReLU Problem\n",
719
+ "def get_activation(activation: str):\n",
720
+ " if activation == 'relu':\n",
721
+ " return nn.ReLU()\n",
722
+ " elif activation == 'leaky':\n",
723
+ " return nn.LeakyReLU(negative_slope=0.1)\n",
724
+ " elif activation == 'elu':\n",
725
+ " return nn.ELU()\n",
726
+ "\n",
727
+ "\n",
728
+ "def get_normalization(normalization: str, num_channels: int, dim: int):\n",
729
+ " if normalization == 'batch':\n",
730
+ " if dim == 3:\n",
731
+ " return nn.BatchNorm3d(num_channels)\n",
732
+ " elif dim == 2:\n",
733
+ " return nn.BatchNorm2d(num_channels)\n",
734
+ "\n",
735
+ " elif normalization == 'instance':\n",
736
+ " if dim == 3:\n",
737
+ " return nn.InstanceNorm3d(num_channels)\n",
738
+ " elif dim == 2:\n",
739
+ " return nn.InstanceNorm2d(num_channels)\n",
740
+ "\n",
741
+ " elif 'group' in normalization:\n",
742
+ " num_groups = int(normalization.partition('group')[-1])\n",
743
+ " return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)\n",
744
+ "\n",
745
+ "\n",
746
+ "class ConcatenateLayer(nn.Module):\n",
747
+ " def __init__(self):\n",
748
+ " super(ConcatenateLayer, self).__init__()\n",
749
+ "\n",
750
+ " def forward(self, layer_1, layer_2):\n",
751
+ " x = torch.cat((layer_1, layer_2), 1)\n",
752
+ "\n",
753
+ " return x\n",
754
+ "\n",
755
+ "\n",
756
+ "class DownBlock(nn.Module):\n",
757
+ " def __init__(\n",
758
+ " self, \n",
759
+ " in_channels: int,\n",
760
+ " out_channels: int, \n",
761
+ " pooling: bool = True,\n",
762
+ " activation: str = 'relu',\n",
763
+ " normalization: str = None,\n",
764
+ " dim: int = 2,\n",
765
+ " convolution_mode: str = 'same'):\n",
766
+ "\n",
767
+ " super().__init__()\n",
768
+ "\n",
769
+ " self.in_channels = in_channels\n",
770
+ " self.out_channels = out_channels\n",
771
+ " self.pooling = pooling\n",
772
+ " self.normalization = normalization\n",
773
+ "\n",
774
+ " if convolution_mode == 'same':\n",
775
+ " self.padding = 1\n",
776
+ " elif convolution_mode == 'valid':\n",
777
+ " self.padding = 0\n",
778
+ "\n",
779
+ " self.dim = dim\n",
780
+ " self.activation = activation\n",
781
+ "\n",
782
+ " # CONVOLUTION LAYERS\n",
783
+ " self.convolution1 = get_convolution_layer(\n",
784
+ " self.in_channels, self.out_channels, kernel_size=3,\n",
785
+ " stride=1, padding=self.padding, bias=True, dim=self.dim\n",
786
+ " )\n",
787
+ " self.convolution2 = get_convolution_layer(\n",
788
+ " self.out_channels, self.out_channels, kernel_size=3,\n",
789
+ " stride=1, padding=self.padding, bias=True, dim=self.dim\n",
790
+ " )\n",
791
+ "\n",
792
+ " # POOLING LAYER\n",
793
+ " if self.pooling:\n",
794
+ " self.pool = get_maxpool_layer(kernel_size=2, stride=2, padding=0, dim=self.dim)\n",
795
+ "\n",
796
+ " # ACTIVATION LAYER\n",
797
+ " self.activation1 = get_activation(self.activation)\n",
798
+ " self.activation2 = get_activation(self.activation)\n",
799
+ "\n",
800
+ " # NORMALIZATION LAYERS\n",
801
+ " if self.normalization:\n",
802
+ " self.normalization1 = get_normalization(\n",
803
+ " normalization=self.normalization, num_channels=self.out_channels,\n",
804
+ " dim=self.dim\n",
805
+ " )\n",
806
+ " self.normalization2 = get_normalization(\n",
807
+ " normalization=self.normalization, num_channels=self.out_channels,\n",
808
+ " dim=self.dim\n",
809
+ " )\n",
810
+ "\n",
811
+ " def forward(self, x):\n",
812
+ " y = self.convolution1(x)\n",
813
+ " y = self.activation1(y)\n",
814
+ "\n",
815
+ " if self.normalization:\n",
816
+ " y = self.normalization1(y)\n",
817
+ "\n",
818
+ " y = self.convolution2(y)\n",
819
+ " y = self.activation2(y)\n",
820
+ "\n",
821
+ " if self.normalization:\n",
822
+ " y = self.normalization2(y)\n",
823
+ "\n",
824
+ " before_pooling = y\n",
825
+ "\n",
826
+ " if self.pooling:\n",
827
+ " y = self.pool(y)\n",
828
+ "\n",
829
+ " return y, before_pooling\n",
830
+ "\n",
831
+ "\n",
832
+ "import torch\n",
833
+ "import torch.nn as nn\n",
834
+ "\n",
835
+ "class UpBlock(nn.Module):\n",
836
+ " def __init__(self,\n",
837
+ " in_channels: int,\n",
838
+ " out_channels: int,\n",
839
+ " activation: str = 'relu',\n",
840
+ " normalization: str = None,\n",
841
+ " dim: int = 3,\n",
842
+ " convolution_mode: str = 'same',\n",
843
+ " up_mode: str = 'transposed'):\n",
844
+ "\n",
845
+ " super().__init__()\n",
846
+ "\n",
847
+ " self.in_channels = in_channels\n",
848
+ " self.out_channels = out_channels\n",
849
+ " self.normalization = normalization\n",
850
+ "\n",
851
+ " if convolution_mode == 'same':\n",
852
+ " self.padding = 1\n",
853
+ " elif convolution_mode == 'valid':\n",
854
+ " self.padding = 0\n",
855
+ "\n",
856
+ " self.dim = dim\n",
857
+ " self.activation = activation\n",
858
+ " self.up_mode = up_mode\n",
859
+ "\n",
860
+ " # UP-CONVOLUTION/UP-SAMPLING LAYER\n",
861
+ " self.up = get_up_layer(\n",
862
+ " self.in_channels, self.out_channels, kernel_size=2,\n",
863
+ " stride=2, dim=self.dim, up_mode=self.up_mode\n",
864
+ " )\n",
865
+ "\n",
866
+ " self.convolution0 = get_convolution_layer(\n",
867
+ " self.out_channels, self.out_channels, kernel_size=1,\n",
868
+ " stride=1, padding=0, bias=True, dim=self.dim\n",
869
+ " )\n",
870
+ " self.convolution1 = get_convolution_layer(\n",
871
+ " 2 * self.out_channels, self.out_channels, kernel_size=3,\n",
872
+ " stride=1, padding=self.padding, bias=True, dim=self.dim\n",
873
+ " )\n",
874
+ " self.convolution2 = get_convolution_layer(\n",
875
+ " self.out_channels, self.out_channels, kernel_size=3,\n",
876
+ " stride=1, padding=self.padding, bias=True, dim=self.dim\n",
877
+ " )\n",
878
+ "\n",
879
+ " # ACTIVATION LAYERS\n",
880
+ " self.activation0 = get_activation(self.activation)\n",
881
+ " self.activation1 = get_activation(self.activation)\n",
882
+ " self.activation2 = get_activation(self.activation)\n",
883
+ "\n",
884
+ " # NORMALIZATION LAYERS\n",
885
+ " if self.normalization:\n",
886
+ " self.normalization0 = get_normalization(\n",
887
+ " normalization=self.normalization, num_channels=self.out_channels,\n",
888
+ " dim=self.dim\n",
889
+ " )\n",
890
+ " self.normalization1 = get_normalization(\n",
891
+ " normalization=self.normalization, num_channels=self.out_channels,\n",
892
+ " dim=self.dim\n",
893
+ " )\n",
894
+ " self.normalization2 = get_normalization(\n",
895
+ " normalization=self.normalization, num_channels=self.out_channels,\n",
896
+ " dim=self.dim\n",
897
+ " )\n",
898
+ "\n",
899
+ " self.concat = ConcatenateLayer()\n",
900
+ "\n",
901
+ " def forward(self, encoder_layer, decoder_layer):\n",
902
+ " up_layer = self.up(decoder_layer)\n",
903
+ " cropped_encoder_layer, dec_layer = autocrop(encoder_layer, up_layer)\n",
904
+ "\n",
905
+ " if self.up_mode != 'transposed':\n",
906
+ " up_layer = self.convolution0(up_layer)\n",
907
+ "\n",
908
+ " up_layer = self.convolution0(up_layer)\n",
909
+ "\n",
910
+ " if self.normalization:\n",
911
+ " up_layer = self.normalization0(up_layer)\n",
912
+ "\n",
913
+ " merged_layer = self.concat(up_layer, cropped_encoder_layer)\n",
914
+ "\n",
915
+ " y = self.convolution1(merged_layer)\n",
916
+ " y = self.activation1(y)\n",
917
+ "\n",
918
+ " if self.normalization:\n",
919
+ " y = self.normalization1(y)\n",
920
+ "\n",
921
+ " y = self.convolution2(y)\n",
922
+ " y = self.activation2(y)\n",
923
+ "\n",
924
+ " if self.normalization:\n",
925
+ " y = self.normalization2(y)\n",
926
+ "\n",
927
+ " return y\n",
928
+ "\n",
929
+ "\n",
930
+ "class UNet(nn.Module):\n",
931
+ " def __init__(\n",
932
+ " self,\n",
933
+ " in_channels: int = 1,\n",
934
+ " out_channels: int = 2,\n",
935
+ " n_blocks: int = 4,\n",
936
+ " start_filters: int = 32,\n",
937
+ " activation: str = 'relu',\n",
938
+ " normalization: str = 'batch',\n",
939
+ " convolution_mode: str = 'same',\n",
940
+ " dim: int = 2,\n",
941
+ " up_mode: str = 'transposed'):\n",
942
+ "\n",
943
+ " super().__init__()\n",
944
+ "\n",
945
+ " self.in_channels = in_channels\n",
946
+ " self.out_channels = out_channels\n",
947
+ " self.n_blocks = n_blocks\n",
948
+ " self.start_filters = start_filters\n",
949
+ " self.activation = activation\n",
950
+ " self.normalization = normalization\n",
951
+ " self.convolution_mode = convolution_mode\n",
952
+ " self.dim = dim\n",
953
+ " self.up_mode = up_mode\n",
954
+ "\n",
955
+ " self.down_blocks = []\n",
956
+ " self.up_blocks = []\n",
957
+ "\n",
958
+ " # ENCODER PATH CREATION\n",
959
+ " for i in range(self.n_blocks):\n",
960
+ " num_filters_in = self.in_channels if i == 0 else num_filters_out\n",
961
+ " num_filters_out = self.start_filters * (2 ** i)\n",
962
+ " pooling = True if i < self.n_blocks - 1 else False\n",
963
+ "\n",
964
+ " down_block = DownBlock(\n",
965
+ " in_channels=num_filters_in, out_channels=num_filters_out,\n",
966
+ " pooling=pooling, activation=self.activation,\n",
967
+ " normalization=self.normalization, convolution_mode=self.convolution_mode,\n",
968
+ " dim=self.dim\n",
969
+ " )\n",
970
+ "\n",
971
+ " self.down_blocks.append(down_block)\n",
972
+ "\n",
973
+ " # DECODER PATH CREATION (NEEDS ONLY N_BLOCKS-1)\n",
974
+ " for i in range(n_blocks - 1):\n",
975
+ " num_filters_in = num_filters_out\n",
976
+ " num_filters_out = num_filters_in // 2\n",
977
+ "\n",
978
+ " up_block = UpBlock(\n",
979
+ " in_channels=num_filters_in, out_channels=num_filters_out,\n",
980
+ " activation=self.activation, normalization=self.normalization,\n",
981
+ " convolution_mode=self.convolution_mode,\n",
982
+ " dim=self.dim, up_mode=self.up_mode\n",
983
+ " )\n",
984
+ "\n",
985
+ " self.up_blocks.append(up_block)\n",
986
+ "\n",
987
+ " # FINAL CONVOLUTION\n",
988
+ " self.convolution_final = get_convolution_layer(\n",
989
+ " num_filters_out, self.out_channels,\n",
990
+ " kernel_size=1, stride=1,\n",
991
+ " padding=0, bias=True, dim=self.dim\n",
992
+ " )\n",
993
+ "\n",
994
+ " # ADDING LIST OF MODULES TO CURRENT MODULE\n",
995
+ " self.down_blocks = nn.ModuleList(self.down_blocks)\n",
996
+ " self.up_blocks = nn.ModuleList(self.up_blocks)\n",
997
+ "\n",
998
+ " # WEIGHT INITIALIZATION\n",
999
+ " self.initialize_parameters()\n",
1000
+ "\n",
1001
+ " @staticmethod\n",
1002
+ " def weight_init(module, method, **kwargs):\n",
1003
+ " if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):\n",
1004
+ " method(module.weight, **kwargs)\n",
1005
+ "\n",
1006
+ " @staticmethod\n",
1007
+ " def bias_init(module, method, **kwargs):\n",
1008
+ " if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):\n",
1009
+ " method(module.bias, **kwargs)\n",
1010
+ "\n",
1011
+ " def initialize_parameters(self,\n",
1012
+ " method_weights=nn.init.xavier_uniform_,\n",
1013
+ " method_bias=nn.init.zeros_,\n",
1014
+ " kwargs_weights={},\n",
1015
+ " kwargs_bias={}):\n",
1016
+ "\n",
1017
+ " for module in self.modules():\n",
1018
+ " self.weight_init(module, method_weights, **kwargs_weights) # initialize weights\n",
1019
+ " self.bias_init(module, method_bias, **kwargs_bias) # initialize bias\n",
1020
+ "\n",
1021
+ " def forward(self, x: torch.tensor):\n",
1022
+ " encoder_output = []\n",
1023
+ "\n",
1024
+ " # ENCODER PATHWAY\n",
1025
+ " for module in self.down_blocks:\n",
1026
+ " x, before_pooling = module(x)\n",
1027
+ " encoder_output.append(before_pooling)\n",
1028
+ "\n",
1029
+ " # DECODER PATHWAY\n",
1030
+ " for i, module in enumerate(self.up_blocks):\n",
1031
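+ " # -(i + 2): skip the bottleneck output (-1) and walk back through the stored encoder outputs\n",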
+ " before_pool = encoder_output[-(i + 2)]\n",
1032
+ " x = module(before_pool, x)\n",
1033
+ "\n",
1034
+ " x = self.convolution_final(x)\n",
1035
+ "\n",
1036
+ " return x\n",
1037
+ "\n",
1038
+ " def __repr__(self):\n",
1039
+ " attributes = {attr_key: self.__dict__[attr_key] for attr_key in self.__dict__.keys() if '_' not in attr_key[0] and 'training' not in attr_key}\n",
1040
+ " d = {self.__class__.__name__: attributes}\n",
1041
+ "\n",
1042
+ " return f'{d}'"
1043
+ ],
1044
+ "outputs": [],
1045
+ "execution_count": null
1046
+ },
1047
+ {
1048
+ "cell_type": "code",
1049
+ "metadata": {
1050
+ "execution": {
1051
+ "iopub.status.busy": "2025-01-09T15:44:19.501954Z",
1052
+ "iopub.status.idle": "2025-01-09T15:44:19.502387Z",
1053
+ "shell.execute_reply": "2025-01-09T15:44:19.502197Z"
1054
+ }
1055
+ },
1056
+ "source": [
1057
+ "MODEL = UNet(\n",
1058
+ " in_channels=4, out_channels=4,\n",
1059
+ " n_blocks=4, start_filters=32,\n",
1060
+ " activation='relu', normalization='batch',\n",
1061
+ " convolution_mode='same', dim=2\n",
1062
+ ")"
1063
+ ],
1064
+ "outputs": [],
1065
+ "execution_count": null
1066
+ },
1067
+ {
1068
+ "cell_type": "code",
1069
+ "metadata": {
1070
+ "execution": {
1071
+ "iopub.status.busy": "2025-01-09T15:44:19.503541Z",
1072
+ "iopub.status.idle": "2025-01-09T15:44:19.503975Z",
1073
+ "shell.execute_reply": "2025-01-09T15:44:19.503784Z"
1074
+ }
1075
+ },
1076
+ "source": [
1077
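+ "# channel 0 is background; it is excluded from the IoU / F-score metrics below\n",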
+ "background_channel = [0]\n",
1078
+ "\n",
1079
+ "dice_loss = smp.utils.losses.DiceLoss(activation='softmax2d')\n",
1080
+ "\n",
1081
+ "optimizer = torch.optim.Adam([\n",
1082
+ " dict(params=MODEL.parameters(), lr=0.0001)\n",
1083
+ "])\n",
1084
+ "\n",
1085
+ "metrics = [\n",
1086
+ " smp.utils.metrics.IoU(threshold=0.5, ignore_channels=background_channel, activation='softmax2d'),\n",
1087
+ " smp.utils.metrics.Fscore(ignore_channels=background_channel, activation='softmax2d'),\n",
1088
+ "]"
1089
+ ],
1090
+ "outputs": [],
1091
+ "execution_count": null
1092
+ },
1093
+ {
1094
+ "cell_type": "code",
1095
+ "metadata": {
1096
+ "execution": {
1097
+ "iopub.status.busy": "2025-01-09T15:44:19.505175Z",
1098
+ "iopub.status.idle": "2025-01-09T15:44:19.505582Z",
1099
+ "shell.execute_reply": "2025-01-09T15:44:19.505396Z"
1100
+ }
1101
+ },
1102
+ "source": [
1103
+ "train_epoch = smp.utils.train.TrainEpoch(\n",
1104
+ " model=MODEL, loss=dice_loss,\n",
1105
+ " metrics=[], optimizer=optimizer,\n",
1106
+ " device=device, verbose=True\n",
1107
+ ")\n",
1108
+ "\n",
1109
+ "valid_epoch = smp.utils.train.ValidEpoch(\n",
1110
+ " model=MODEL, loss=dice_loss,\n",
1111
+ " metrics=metrics, device=device,\n",
1112
+ " verbose=True\n",
1113
+ ")\n",
1114
+ "\n",
1115
+ "max_dice_score = 0\n",
1116
+ "\n",
1117
+ "stats = {\n",
1118
+ " 'train_loss' : [],\n",
1119
+ " 'valid_loss' : [],\n",
1120
+ " 'fscore' : [],\n",
1121
+ " 'iou_score' : []\n",
1122
+ "}\n",
1123
+ "\n",
1124
+ "for i in range(50):\n",
1125
+ " print(f'\\n |--- EPOCH-{i} ---| ')\n",
1126
+ " train_logs = train_epoch.run(train_loader)\n",
1127
+ " valid_logs = valid_epoch.run(valid_loader)\n",
1128
+ " \n",
1129
+ " if max_dice_score < valid_logs['fscore']:\n",
1130
+ " max_dice_score = valid_logs['fscore']\n",
1131
+ " torch.save(MODEL.state_dict(), f'model/model.pth')\n",
1132
+ " \n",
1133
+ " print('model saved!')\n",
1134
+ " \n",
1135
+ " # loss statistics\n",
1136
+ " stats['train_loss'].append(train_logs['dice_loss'])\n",
1137
+ " stats['valid_loss'].append(valid_logs['dice_loss'])\n",
1138
+ " \n",
1139
+ " # metric statistics\n",
1140
+ " stats['fscore'].append(valid_logs['fscore'])\n",
1141
+ " stats['iou_score'].append(valid_logs['iou_score'])\n",
1142
+ " \n",
1143
+ " np.save(f'model/model.npy', stats)\n",
1144
+ " "
1145
+ ],
1146
+ "outputs": [],
1147
+ "execution_count": null
1148
+ },
1149
+ {
1150
+ "cell_type": "code",
1151
+ "metadata": {},
1152
+ "source": [
1153
+ "STATS = np.load(f'model/model.npy', allow_pickle=True).item()\n",
1154
+ "plt.plot(STATS['train_loss'], label='train_loss')\n",
1155
+ "plt.plot(STATS['valid_loss'], label='valid_loss')\n",
1156
+ "\n",
1157
+ "plt.legend(loc='upper right')\n",
1158
+ "\n",
1159
+ "plt.xlabel('EPOCH')\n",
1160
+ "plt.ylabel('LOSS')\n",
1161
+ "\n",
1162
+ "plt.title('TRAIN & VALIDATION LOSS')"
1163
+ ],
1164
+ "outputs": [],
1165
+ "execution_count": null
1166
+ },
1167
+ {
1168
+ "cell_type": "code",
1169
+ "metadata": {},
1170
+ "source": [
1171
+ "STATS = np.load(f'model/model.npy', allow_pickle=True).item()\n",
1172
+ "plt.plot(STATS['fscore'], label ='fscore')\n",
1173
+ "plt.legend(loc = \"lower right\")\n",
1174
+ "plt.ylabel('SCORE')\n",
1175
+ "plt.xlabel('EPOCH')\n",
1176
+ "plt.title('F_SCORE')\n",
1177
+ "\n",
1178
+ "plt.plot(STATS['iou_score'], label ='iou_score')\n",
1179
+ "plt.legend(loc = \"lower right\")\n",
1180
+ "plt.ylabel('SCORE')\n",
1181
+ "plt.xlabel('EPOCH')\n",
1182
+ "plt.title('IOU_SCORE')"
1183
+ ],
1184
+ "outputs": [],
1185
+ "execution_count": null
1186
+ },
1187
+ {
1188
+ "cell_type": "code",
1189
+ "metadata": {},
1190
+ "source": [
1191
+ "MODEL.load_state_dict(torch.load('model/model.pth', weights_only=True))"
1192
+ ],
1193
+ "outputs": [],
1194
+ "execution_count": null
1195
+ },
1196
+ {
1197
+ "cell_type": "code",
1198
+ "metadata": {},
1199
+ "source": [
1200
+ "with torch.no_grad():\n",
1201
+ " out = MODEL(a.cuda())"
1202
+ ],
1203
+ "outputs": [],
1204
+ "execution_count": null
1205
+ },
1206
+ {
1207
+ "cell_type": "code",
1208
+ "metadata": {},
1209
+ "source": [
1210
+ "plt.figure(figsize = (18, 10))\n",
1211
+ "plt.subplot(1, 3, 1)\n",
1212
+ "plt.imshow(a[2, 0],cmap='bone')\n",
1213
+ "plt.title('Input Image')\n",
1214
+ "\n",
1215
+ "plt.subplot(1, 3, 2)\n",
1216
+ "plt.imshow(a[2, 0],cmap='bone')\n",
1217
+ "plt.imshow(out.cpu()[2, 0], alpha = 0.5, cmap = 'nipy_spectral')\n",
1218
+ "plt.title('Predicted Segmentation')\n",
1219
+ "\n",
1220
+ "plt.subplot(1, 3, 3)\n",
1221
+ "plt.imshow(out.cpu()[2, 0], cmap = 'bone')\n",
1222
+ "plt.title('Prediction')"
1223
+ ],
1224
+ "outputs": [],
1225
+ "execution_count": null
1226
+ },
1227
+ {
1228
+ "cell_type": "code",
1229
+ "metadata": {},
1230
+ "source": [],
1231
+ "outputs": [],
1232
+ "execution_count": null
1233
+ },
1234
+ {
1235
+ "cell_type": "code",
1236
+ "metadata": {},
1237
+ "source": [
1238
+ "\n",
1239
+ "# Enhanced Data Augmentation\n",
1240
+ "from albumentations import Compose, RandomCrop, ElasticTransform, GridDistortion, OpticalDistortion, RandomBrightnessContrast, GaussNoise, Flip\n",
1241
+ "\n",
1242
+ "def get_augmentation_pipeline():\n",
1243
+ " return Compose([\n",
1244
+ " Flip(p=0.5),\n",
1245
+ " RandomCrop(height=128, width=128, p=0.5),\n",
1246
+ " ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),\n",
1247
+ " GridDistortion(p=0.5),\n",
1248
+ " OpticalDistortion(p=0.5),\n",
1249
+ " GaussNoise(p=0.5),\n",
1250
+ " RandomBrightnessContrast(p=0.5)\n",
1251
+ " ])\n",
1252
+ "\n",
1253
+ "augmentation_pipeline = get_augmentation_pipeline()\n"
1254
+ ],
1255
+ "outputs": [],
1256
+ "execution_count": null
1257
+ },
1258
+ {
1259
+ "cell_type": "code",
1260
+ "metadata": {},
1261
+ "source": [
1262
+ "\n",
1263
+ "# Switching to Attention U-Net / UNet++ with Pre-trained Encoders\n",
1264
+ "import segmentation_models_pytorch as smp\n",
1265
+ "\n",
1266
+ "# Define a UNet++ with a ResNet34 encoder pre-trained on ImageNet\n",
1267
+ "model = smp.UnetPlusPlus(\n",
1268
+ " encoder_name=\"resnet34\",\n",
1269
+ " encoder_weights=\"imagenet\",\n",
1270
+ " in_channels=4,\n",
1271
+ " classes=4\n",
1272
+ ")\n"
1273
+ ],
1274
+ "outputs": [],
1275
+ "execution_count": null
1276
+ },
1277
+ {
1278
+ "cell_type": "code",
1279
+ "metadata": {},
1280
+ "source": [
1281
+ "\n",
1282
+ "# Improved Loss Function\n",
1283
+ "import torch.nn as nn\n",
1284
+ "from segmentation_models_pytorch.losses import TverskyLoss\n",
1285
+ "\n",
1286
+ "# Combine Dice Loss and Tversky Loss\n",
1287
+ "class CombinedLoss(nn.Module):\n",
1288
+ " def __init__(self, alpha=0.5):\n",
1289
+ " super(CombinedLoss, self).__init__()\n",
1290
+ " self.dice_loss = smp.losses.DiceLoss(\"softmax\")\n",
1291
+ " self.tversky_loss = TverskyLoss(\"softmax\", alpha=0.7, beta=0.3)\n",
1292
+ " self.alpha = alpha\n",
1293
+ "\n",
1294
+ " def forward(self, y_pred, y_true):\n",
1295
+ " return self.alpha * self.dice_loss(y_pred, y_true) + (1 - self.alpha) * self.tversky_loss(y_pred, y_true)\n",
1296
+ "\n",
1297
+ "loss_fn = CombinedLoss()\n"
1298
+ ],
1299
+ "outputs": [],
1300
+ "execution_count": null
1301
+ },
1302
+ {
1303
+ "cell_type": "code",
1304
+ "metadata": {},
1305
+ "source": [
1306
+ "from sklearn.svm._liblinear import train_wrap\n",
1307
+ "\n",
1308
+ "num_epochs = 50\n",
1309
+ "\n",
1310
+ "# Learning Rate Scheduling\n",
1311
+ "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
1312
+ "\n",
1313
+ "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
1314
+ "scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5) # Cosine Annealing\n",
1315
+ "\n",
1316
+ "# Update the scheduler in each epoch\n",
1317
+ "for epoch in range(num_epochs):\n",
1318
+ " train_wrap(...) # Train your model for one epoch\n",
1319
+ " scheduler.step()\n"
1320
+ ],
1321
+ "outputs": [],
1322
+ "execution_count": null
1323
+ },
1324
+ {
1325
+ "cell_type": "code",
1326
+ "metadata": {},
1327
+ "source": [
1328
+ "\n",
1329
+ "# Post-Processing with CRF\n",
1330
+ "import pydensecrf.densecrf as dcrf\n",
1331
+ "\n",
1332
+ "def apply_crf(prob_map, img):\n",
1333
+ " d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 4) # 4 is the number of classes\n",
1334
+ " U = -np.log(prob_map)\n",
1335
+ " d.setUnaryEnergy(U)\n",
1336
+ "\n",
1337
+ " # Add pairwise terms\n",
1338
+ " d.addPairwiseGaussian(sxy=3, compat=3)\n",
1339
+ " d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=img, compat=10)\n",
1340
+ "\n",
1341
+ " Q = d.inference(5) # Number of iterations\n",
1342
+ " \n",
1343
+ " return np.argmax(Q, axis=0).reshape((img.shape[0], img.shape[1]))\n"
1344
+ ],
1345
+ "outputs": [],
1346
+ "execution_count": null
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "metadata": {},
1351
+ "source": [
1352
+ "\n",
1353
+ "# Cross-Validation\n",
1354
+ "from sklearn.model_selection import KFold\n",
1355
+ "\n",
1356
+ "kf = KFold(n_splits=5)\n",
1357
+ "for train_idx, valid_idx in kf.split(dataset):\n",
1358
+ " train_data = Subset(dataset, train_idx)\n",
1359
+ " valid_data = Subset(dataset, valid_idx)\n",
1360
+ "\n",
1361
+ " train_loader = DataLoader(train_data, batch_size=16, shuffle=True)\n",
1362
+ " valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False)\n",
1363
+ "\n",
1364
+ " train_model(train_loader, valid_loader)\n"
1365
+ ],
1366
+ "outputs": [],
1367
+ "execution_count": null
1368
+ },
1369
+ {
1370
+ "cell_type": "code",
1371
+ "metadata": {},
1372
+ "source": [
1373
+ "\n",
1374
+ "# Ensemble Learning\n",
1375
+ "class EnsembleModel(nn.Module):\n",
1376
+ " def __init__(self, models):\n",
1377
+ " super(EnsembleModel, self).__init__()\n",
1378
+ " self.models = nn.ModuleList(models)\n",
1379
+ "\n",
1380
+ " def forward(self, x):\n",
1381
+ " outputs = [model(x) for model in self.models]\n",
1382
+ " return torch.mean(torch.stack(outputs), dim=0)\n",
1383
+ "\n",
1384
+ "# Combine multiple trained models\n",
1385
+ "models = [model1, model2, model3] # Pre-trained models\n",
1386
+ "ensemble_model = EnsembleModel(models)\n"
1387
+ ],
1388
+ "outputs": [],
1389
+ "execution_count": null
1390
+ }
1391
+ ],
1392
+ "metadata": {
1393
+ "kaggle": {
1394
+ "accelerator": "nvidiaTeslaT4",
1395
+ "dataSources": [
1396
+ {
1397
+ "datasetId": 723383,
1398
+ "sourceId": 1267593,
1399
+ "sourceType": "datasetVersion"
1400
+ },
1401
+ {
1402
+ "datasetId": 751906,
1403
+ "sourceId": 1299795,
1404
+ "sourceType": "datasetVersion"
1405
+ }
1406
+ ],
1407
+ "dockerImageVersionId": 30823,
1408
+ "isGpuEnabled": true,
1409
+ "isInternetEnabled": true,
1410
+ "language": "python",
1411
+ "sourceType": "notebook"
1412
+ },
1413
+ "kernelspec": {
1414
+ "display_name": "Python 3",
1415
+ "language": "python",
1416
+ "name": "python3"
1417
+ },
1418
+ "language_info": {
1419
+ "codemirror_mode": {
1420
+ "name": "ipython",
1421
+ "version": 3
1422
+ },
1423
+ "file_extension": ".py",
1424
+ "mimetype": "text/x-python",
1425
+ "name": "python",
1426
+ "nbconvert_exporter": "python",
1427
+ "pygments_lexer": "ipython3",
1428
+ "version": "3.10.12"
1429
+ }
1430
+ },
1431
+ "nbformat": 4,
1432
+ "nbformat_minor": 4
1433
+ }
brats_scratch-temp.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
brats_scratch.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
links.csv ADDED
The diff for this file is too large to render. See raw diff
 
model/model.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df6534f6cb7866692c00232c381826a053dd4bdd88127c3c4b26fcc24d41e387
+ size 5961
model/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2369e10b3b15fbc78f934c5173efe9b09f8e91d575b6abfaa326987d188b7ad0
+ size 7853934
pre_links.csv ADDED
The diff for this file is too large to render. See raw diff
 
pre_model/model.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a11df31cd0be24bf393c5c64a6add5004026176ae02ecaaca2ee8d8b735651b
+ size 24561
pre_model/model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c029a6b882c7e96f3eac714d99dcc4066bb8cff4b1d5e584f2a8498fa38316ca
+ size 295128250
tests.py ADDED
@@ -0,0 +1,121 @@
+ # Suggestions to Improve BraTS U-Net Segmentation Pipeline
+
+ # 1. Enhanced Data Augmentation
+ from albumentations import Compose, RandomCrop, ElasticTransform, GridDistortion, OpticalDistortion, RandomBrightnessContrast, GaussNoise, Flip
+
+
+ def get_augmentation_pipeline():
+     return Compose([
+         Flip(p=0.5),
+         RandomCrop(height=128, width=128, p=0.5),
+         ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
+         GridDistortion(p=0.5),
+         OpticalDistortion(p=0.5),
+         GaussNoise(p=0.5),
+         RandomBrightnessContrast(p=0.5)
+     ])
+
+ augmentation_pipeline = get_augmentation_pipeline()
+
+ # Apply this pipeline in your dataset loader as part of preprocessing.
+
+ # 2. Switching to Attention U-Net / UNet++ with Pre-trained Encoders
+ import segmentation_models_pytorch as smp
+
+ # Define a UNet++ with a ResNet34 encoder pre-trained on ImageNet
+ model = smp.UnetPlusPlus(
+     encoder_name="resnet34",     # Encoder architecture
+     encoder_weights="imagenet",  # Use ImageNet pre-trained weights
+     in_channels=4,               # Number of input channels (BraTS has 4 modalities)
+     classes=4                    # Number of output classes
+ )
+
+ # 3. Improved Loss Function
+ import torch
+ import torch.nn as nn
+ from segmentation_models_pytorch.losses import TverskyLoss
+
+ # Combine Dice Loss and Tversky Loss
+ class CombinedLoss(nn.Module):
+     def __init__(self, alpha=0.5):
+         super(CombinedLoss, self).__init__()
+         self.dice_loss = smp.losses.DiceLoss("softmax")
+         self.tversky_loss = TverskyLoss("softmax", alpha=0.7, beta=0.3)
+         self.alpha = alpha
+
+     def forward(self, y_pred, y_true):
+         return self.alpha * self.dice_loss(y_pred, y_true) + (1 - self.alpha) * self.tversky_loss(y_pred, y_true)
+
+ loss_fn = CombinedLoss()
+
+ # 4. Learning Rate Scheduling
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)  # Cosine Annealing
+
+ # Update the scheduler in each epoch
+ num_epochs = 50
+ for epoch in range(num_epochs):
+     train(...)  # placeholder: train your model for one epoch
+     scheduler.step()
+
+ # 5. Post-Processing with CRF
+ import numpy as np
+ import pydensecrf.densecrf as dcrf
+
+ def apply_crf(prob_map, img):
+     d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 4)  # 4 is the number of classes
+     # note: setUnaryEnergy expects a C-contiguous float32 array of shape (n_classes, n_pixels)
+     U = -np.log(prob_map)
+     d.setUnaryEnergy(U)
+
+     # Add pairwise terms
+     d.addPairwiseGaussian(sxy=3, compat=3)
+     d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=img, compat=10)
+
+     Q = d.inference(5)  # Number of iterations
+     return np.argmax(Q, axis=0).reshape((img.shape[0], img.shape[1]))
+
+ # Apply this on your predicted probabilities
+
+ # 6. Cross-Validation
+ from sklearn.model_selection import KFold
+ from torch.utils.data import DataLoader, Subset
+
+ kf = KFold(n_splits=5)
+ for train_idx, valid_idx in kf.split(dataset):  # 'dataset' is your full Dataset instance
+     train_data = Subset(dataset, train_idx)
+     valid_data = Subset(dataset, valid_idx)
+
+     train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
+     valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False)
+
+     train_model(train_loader, valid_loader)  # placeholder training routine
+
+ # 7. Ensemble Learning
+ class EnsembleModel(nn.Module):
+     def __init__(self, models):
+         super(EnsembleModel, self).__init__()
+         self.models = nn.ModuleList(models)
+
+     def forward(self, x):
+         outputs = [model(x) for model in self.models]
+         return torch.mean(torch.stack(outputs), dim=0)
+
+ # Combine multiple trained models
+ models = [model1, model2, model3]  # Pre-trained models
+ ensemble_model = EnsembleModel(models)
+
+ # 8. Hyperparameter Tuning with Grid Search (Example)
+ from sklearn.model_selection import ParameterGrid
+
+ param_grid = {
+     'learning_rate': [1e-3, 1e-4],
+     'batch_size': [8, 16],
+     'loss_alpha': [0.5, 0.7]
+ }
+
+ for params in ParameterGrid(param_grid):
+     optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
+     loss_fn = CombinedLoss(alpha=params['loss_alpha'])
+     train_loader = DataLoader(train_data, batch_size=params['batch_size'])
+
+     train_model(train_loader, valid_loader)