Spaces:
Runtime error
Runtime error
Soutrik
commited on
Commit
·
b0bdbcf
1
Parent(s):
aeaa968
datamodule new tested
Browse files
configs/data/catdog.yaml
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
_target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
|
| 2 |
-
|
|
|
|
| 3 |
url: ${paths.data_url}
|
| 4 |
num_workers: 4
|
| 5 |
batch_size: 32
|
| 6 |
train_val_split: [0.8, 0.2]
|
| 7 |
pin_memory: False
|
| 8 |
-
image_size:
|
|
|
|
| 1 |
_target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
|
| 2 |
+
root_dir: ${paths.data_dir}
|
| 3 |
+
data_dir: "cats_and_dogs_filtered"
|
| 4 |
url: ${paths.data_url}
|
| 5 |
num_workers: 4
|
| 6 |
batch_size: 32
|
| 7 |
train_val_split: [0.8, 0.2]
|
| 8 |
pin_memory: False
|
| 9 |
+
image_size: 224
|
configs/experiment/catdog_experiment.yaml
CHANGED
|
@@ -18,10 +18,11 @@ seed: 42
|
|
| 18 |
name: "catdog_experiment"
|
| 19 |
|
| 20 |
data:
|
| 21 |
-
|
|
|
|
| 22 |
num_workers: 8
|
| 23 |
pin_memory: True
|
| 24 |
-
image_size:
|
| 25 |
|
| 26 |
model:
|
| 27 |
lr: 1e-3
|
|
|
|
| 18 |
name: "catdog_experiment"
|
| 19 |
|
| 20 |
data:
|
| 21 |
+
dataset: "cats_and_dogs_filtered"
|
| 22 |
+
batch_size: 32
|
| 23 |
num_workers: 8
|
| 24 |
pin_memory: True
|
| 25 |
+
image_size: 224
|
| 26 |
|
| 27 |
model:
|
| 28 |
lr: 1e-3
|
notebooks/datamodule_lightning.ipynb
CHANGED
|
@@ -53,13 +53,222 @@
|
|
| 53 |
}
|
| 54 |
],
|
| 55 |
"source": [
|
| 56 |
-
"\n",
|
| 57 |
"import os\n",
|
| 58 |
"\n",
|
| 59 |
"os.chdir(\"..\")\n",
|
| 60 |
"print(os.getcwd())"
|
| 61 |
]
|
| 62 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
{
|
| 64 |
"cell_type": "code",
|
| 65 |
"execution_count": null,
|
|
|
|
| 53 |
}
|
| 54 |
],
|
| 55 |
"source": [
|
|
|
|
| 56 |
"import os\n",
|
| 57 |
"\n",
|
| 58 |
"os.chdir(\"..\")\n",
|
| 59 |
"print(os.getcwd())"
|
| 60 |
]
|
| 61 |
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": 3,
|
| 65 |
+
"metadata": {},
|
| 66 |
+
"outputs": [
|
| 67 |
+
{
|
| 68 |
+
"name": "stderr",
|
| 69 |
+
"output_type": "stream",
|
| 70 |
+
"text": [
|
| 71 |
+
"/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 72 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 73 |
+
]
|
| 74 |
+
}
|
| 75 |
+
],
|
| 76 |
+
"source": [
|
| 77 |
+
"from pathlib import Path\n",
|
| 78 |
+
"from typing import Union, Tuple, Optional, List\n",
|
| 79 |
+
"import os\n",
|
| 80 |
+
"import lightning as L\n",
|
| 81 |
+
"from torch.utils.data import DataLoader, random_split\n",
|
| 82 |
+
"from torchvision import transforms\n",
|
| 83 |
+
"from torchvision.datasets import ImageFolder\n",
|
| 84 |
+
"from torchvision.datasets.utils import download_and_extract_archive\n",
|
| 85 |
+
"from loguru import logger"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": 32,
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": [
|
| 94 |
+
"class CatDogImageDataModule(L.LightningDataModule):\n",
|
| 95 |
+
" \"\"\"DataModule for Cat and Dog Image Classification using ImageFolder.\"\"\"\n",
|
| 96 |
+
"\n",
|
| 97 |
+
" def __init__(\n",
|
| 98 |
+
" self,\n",
|
| 99 |
+
" data_root: Union[str, Path] = \"data\",\n",
|
| 100 |
+
" data_dir: Union[str, Path] = \"cats_and_dogs_filtered\",\n",
|
| 101 |
+
" batch_size: int = 32,\n",
|
| 102 |
+
" num_workers: int = 4,\n",
|
| 103 |
+
" train_val_split: List[float] = [0.8, 0.2],\n",
|
| 104 |
+
" pin_memory: bool = False,\n",
|
| 105 |
+
" image_size: int = 224,\n",
|
| 106 |
+
" url: str = \"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
|
| 107 |
+
" ):\n",
|
| 108 |
+
" super().__init__()\n",
|
| 109 |
+
" self.data_root = Path(data_root)\n",
|
| 110 |
+
" self.data_dir = data_dir\n",
|
| 111 |
+
" self.batch_size = batch_size\n",
|
| 112 |
+
" self.num_workers = num_workers\n",
|
| 113 |
+
" self.train_val_split = train_val_split\n",
|
| 114 |
+
" self.pin_memory = pin_memory\n",
|
| 115 |
+
" self.image_size = image_size\n",
|
| 116 |
+
" self.url = url\n",
|
| 117 |
+
"\n",
|
| 118 |
+
" # Initialize variables for datasets\n",
|
| 119 |
+
" self.train_dataset = None\n",
|
| 120 |
+
" self.val_dataset = None\n",
|
| 121 |
+
" self.test_dataset = None\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" def prepare_data(self):\n",
|
| 124 |
+
" \"\"\"Download the dataset if it doesn't exist.\"\"\"\n",
|
| 125 |
+
" self.dataset_path = self.data_root / self.data_dir\n",
|
| 126 |
+
" if not self.dataset_path.exists():\n",
|
| 127 |
+
" logger.info(\"Downloading and extracting dataset.\")\n",
|
| 128 |
+
" download_and_extract_archive(\n",
|
| 129 |
+
" url=self.url, download_root=self.data_root, remove_finished=True\n",
|
| 130 |
+
" )\n",
|
| 131 |
+
" logger.info(\"Download completed.\")\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" def setup(self, stage: Optional[str] = None):\n",
|
| 134 |
+
" \"\"\"Set up the train, validation, and test datasets.\"\"\"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
" train_transform = transforms.Compose(\n",
|
| 137 |
+
" [\n",
|
| 138 |
+
" transforms.Resize((self.image_size, self.image_size)),\n",
|
| 139 |
+
" transforms.RandomHorizontalFlip(0.1),\n",
|
| 140 |
+
" transforms.RandomRotation(10),\n",
|
| 141 |
+
" transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n",
|
| 142 |
+
" transforms.RandomAutocontrast(0.1),\n",
|
| 143 |
+
" transforms.RandomAdjustSharpness(2, 0.1),\n",
|
| 144 |
+
" transforms.ToTensor(),\n",
|
| 145 |
+
" transforms.Normalize(\n",
|
| 146 |
+
" mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
| 147 |
+
" ),\n",
|
| 148 |
+
" ]\n",
|
| 149 |
+
" )\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" test_transform = transforms.Compose(\n",
|
| 152 |
+
" [\n",
|
| 153 |
+
" transforms.Resize((self.image_size, self.image_size)),\n",
|
| 154 |
+
" transforms.ToTensor(),\n",
|
| 155 |
+
" transforms.Normalize(\n",
|
| 156 |
+
" mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
| 157 |
+
" ),\n",
|
| 158 |
+
" ]\n",
|
| 159 |
+
" )\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" train_path = self.dataset_path / \"train\"\n",
|
| 162 |
+
" test_path = self.dataset_path / \"test\"\n",
|
| 163 |
+
"\n",
|
| 164 |
+
" self.prepare_data()\n",
|
| 165 |
+
"\n",
|
| 166 |
+
" if stage == \"fit\" or stage is None:\n",
|
| 167 |
+
" full_train_dataset = ImageFolder(root=train_path, transform=train_transform)\n",
|
| 168 |
+
" self.class_names = full_train_dataset.classes\n",
|
| 169 |
+
" train_size = int(self.train_val_split[0] * len(full_train_dataset))\n",
|
| 170 |
+
" val_size = len(full_train_dataset) - train_size\n",
|
| 171 |
+
" self.train_dataset, self.val_dataset = random_split(\n",
|
| 172 |
+
" full_train_dataset, [train_size, val_size]\n",
|
| 173 |
+
" )\n",
|
| 174 |
+
" logger.info(\n",
|
| 175 |
+
" f\"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images.\"\n",
|
| 176 |
+
" )\n",
|
| 177 |
+
"\n",
|
| 178 |
+
" if stage == \"test\" or stage is None:\n",
|
| 179 |
+
" self.test_dataset = ImageFolder(root=test_path, transform=test_transform)\n",
|
| 180 |
+
" logger.info(f\"Test dataset size: {len(self.test_dataset)} images.\")\n",
|
| 181 |
+
"\n",
|
| 182 |
+
" def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:\n",
|
| 183 |
+
" \"\"\"Helper function to create a DataLoader.\"\"\"\n",
|
| 184 |
+
" return DataLoader(\n",
|
| 185 |
+
" dataset=dataset,\n",
|
| 186 |
+
" batch_size=self.batch_size,\n",
|
| 187 |
+
" num_workers=self.num_workers,\n",
|
| 188 |
+
" pin_memory=self.pin_memory,\n",
|
| 189 |
+
" shuffle=shuffle,\n",
|
| 190 |
+
" )\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" def train_dataloader(self) -> DataLoader:\n",
|
| 193 |
+
" return self._create_dataloader(self.train_dataset, shuffle=True)\n",
|
| 194 |
+
"\n",
|
| 195 |
+
" def val_dataloader(self) -> DataLoader:\n",
|
| 196 |
+
" return self._create_dataloader(self.val_dataset)\n",
|
| 197 |
+
"\n",
|
| 198 |
+
" def test_dataloader(self) -> DataLoader:\n",
|
| 199 |
+
" return self._create_dataloader(self.test_dataset)\n",
|
| 200 |
+
"\n",
|
| 201 |
+
" def get_class_names(self) -> List[str]:\n",
|
| 202 |
+
" return self.class_names"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "code",
|
| 207 |
+
"execution_count": 33,
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"outputs": [],
|
| 210 |
+
"source": [
|
| 211 |
+
"datamodule = CatDogImageDataModule(\n",
|
| 212 |
+
" data_root=\"data\",\n",
|
| 213 |
+
" data_dir=\"cats_and_dogs_filtered\",\n",
|
| 214 |
+
" batch_size=32,\n",
|
| 215 |
+
" num_workers=4,\n",
|
| 216 |
+
" train_val_split=[0.8, 0.2],\n",
|
| 217 |
+
" pin_memory=True,\n",
|
| 218 |
+
" image_size=224,\n",
|
| 219 |
+
" url=\"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
|
| 220 |
+
")"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "code",
|
| 225 |
+
"execution_count": 35,
|
| 226 |
+
"metadata": {},
|
| 227 |
+
"outputs": [
|
| 228 |
+
{
|
| 229 |
+
"name": "stderr",
|
| 230 |
+
"output_type": "stream",
|
| 231 |
+
"text": [
|
| 232 |
+
"\u001b[32m2024-11-10 05:37:17.840\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m81\u001b[0m - \u001b[1mTrain/Validation split: 2241 train, 561 validation images.\u001b[0m\n"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"name": "stderr",
|
| 237 |
+
"output_type": "stream",
|
| 238 |
+
"text": [
|
| 239 |
+
"\u001b[32m2024-11-10 05:37:17.910\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m87\u001b[0m - \u001b[1mTest dataset size: 198 images.\u001b[0m\n"
|
| 240 |
+
]
|
| 241 |
+
}
|
| 242 |
+
],
|
| 243 |
+
"source": [
|
| 244 |
+
"datamodule.prepare_data()\n",
|
| 245 |
+
"datamodule.setup()\n",
|
| 246 |
+
"class_names = datamodule.get_class_names()\n",
|
| 247 |
+
"train_dataloader = datamodule.train_dataloader()\n",
|
| 248 |
+
"val_dataloader= datamodule.val_dataloader()\n",
|
| 249 |
+
"test_dataloader= datamodule.test_dataloader()"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "code",
|
| 254 |
+
"execution_count": 36,
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"outputs": [
|
| 257 |
+
{
|
| 258 |
+
"data": {
|
| 259 |
+
"text/plain": [
|
| 260 |
+
"['cats', 'dogs']"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
"execution_count": 36,
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"output_type": "execute_result"
|
| 266 |
+
}
|
| 267 |
+
],
|
| 268 |
+
"source": [
|
| 269 |
+
"class_names"
|
| 270 |
+
]
|
| 271 |
+
},
|
| 272 |
{
|
| 273 |
"cell_type": "code",
|
| 274 |
"execution_count": null,
|
src/datamodules/catdog_datamodule.py
CHANGED
|
@@ -14,7 +14,8 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
| 14 |
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
-
|
|
|
|
| 18 |
batch_size: int = 32,
|
| 19 |
num_workers: int = 4,
|
| 20 |
train_val_split: List[float] = [0.8, 0.2],
|
|
@@ -23,7 +24,8 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
| 23 |
url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
|
| 24 |
):
|
| 25 |
super().__init__()
|
| 26 |
-
self.
|
|
|
|
| 27 |
self.batch_size = batch_size
|
| 28 |
self.num_workers = num_workers
|
| 29 |
self.train_val_split = train_val_split
|
|
@@ -38,21 +40,27 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
| 38 |
|
| 39 |
def prepare_data(self):
|
| 40 |
"""Download the dataset if it doesn't exist."""
|
| 41 |
-
dataset_path = self.
|
| 42 |
-
if not dataset_path.exists():
|
| 43 |
logger.info("Downloading and extracting dataset.")
|
| 44 |
download_and_extract_archive(
|
| 45 |
-
url=self.url, download_root=self.
|
| 46 |
)
|
| 47 |
logger.info("Download completed.")
|
| 48 |
|
| 49 |
def setup(self, stage: Optional[str] = None):
|
| 50 |
"""Set up the train, validation, and test datasets."""
|
| 51 |
|
|
|
|
|
|
|
| 52 |
train_transform = transforms.Compose(
|
| 53 |
[
|
| 54 |
transforms.Resize((self.image_size, self.image_size)),
|
| 55 |
-
transforms.RandomHorizontalFlip(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
transforms.ToTensor(),
|
| 57 |
transforms.Normalize(
|
| 58 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
@@ -70,11 +78,12 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
| 70 |
]
|
| 71 |
)
|
| 72 |
|
| 73 |
-
train_path = self.
|
| 74 |
-
test_path = self.
|
| 75 |
|
| 76 |
if stage == "fit" or stage is None:
|
| 77 |
full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
|
|
|
|
| 78 |
train_size = int(self.train_val_split[0] * len(full_train_dataset))
|
| 79 |
val_size = len(full_train_dataset) - train_size
|
| 80 |
self.train_dataset, self.val_dataset = random_split(
|
|
@@ -107,43 +116,42 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
| 107 |
def test_dataloader(self) -> DataLoader:
|
| 108 |
return self._create_dataloader(self.test_dataset)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
if __name__ == "__main__":
|
| 112 |
-
|
| 113 |
import hydra
|
|
|
|
| 114 |
import rootutils
|
| 115 |
|
| 116 |
-
|
| 117 |
-
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 118 |
-
logger.info(f"Root directory: {root}")
|
| 119 |
|
| 120 |
@hydra.main(
|
| 121 |
-
version_base="1.3",
|
| 122 |
-
config_path=str(root / "configs"),
|
| 123 |
-
config_name="train",
|
| 124 |
)
|
| 125 |
-
def
|
| 126 |
-
|
| 127 |
-
logger.info("Config:\n" + OmegaConf.to_yaml(cfg))
|
| 128 |
-
|
| 129 |
-
# Initialize DataModule
|
| 130 |
datamodule = CatDogImageDataModule(
|
|
|
|
| 131 |
data_dir=cfg.data.data_dir,
|
| 132 |
batch_size=cfg.data.batch_size,
|
| 133 |
num_workers=cfg.data.num_workers,
|
| 134 |
train_val_split=cfg.data.train_val_split,
|
| 135 |
pin_memory=cfg.data.pin_memory,
|
| 136 |
image_size=cfg.data.image_size,
|
| 137 |
-
url=cfg.data.url,
|
| 138 |
-
)
|
| 139 |
-
datamodule.prepare_data()
|
| 140 |
-
datamodule.setup()
|
| 141 |
-
|
| 142 |
-
# Log DataLoader sizes
|
| 143 |
-
logger.info(f"Train DataLoader: {len(datamodule.train_dataloader())} batches")
|
| 144 |
-
logger.info(
|
| 145 |
-
f"Validation DataLoader: {len(datamodule.val_dataloader())} batches"
|
| 146 |
)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
+
data_root: Union[str, Path] = "data",
|
| 18 |
+
data_dir: Union[str, Path] = "cats_and_dogs_filtered",
|
| 19 |
batch_size: int = 32,
|
| 20 |
num_workers: int = 4,
|
| 21 |
train_val_split: List[float] = [0.8, 0.2],
|
|
|
|
| 24 |
url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
|
| 25 |
):
|
| 26 |
super().__init__()
|
| 27 |
+
self.data_root = Path(data_root)
|
| 28 |
+
self.data_dir = data_dir
|
| 29 |
self.batch_size = batch_size
|
| 30 |
self.num_workers = num_workers
|
| 31 |
self.train_val_split = train_val_split
|
|
|
|
| 40 |
|
| 41 |
def prepare_data(self):
|
| 42 |
"""Download the dataset if it doesn't exist."""
|
| 43 |
+
self.dataset_path = self.data_root / self.data_dir
|
| 44 |
+
if not self.dataset_path.exists():
|
| 45 |
logger.info("Downloading and extracting dataset.")
|
| 46 |
download_and_extract_archive(
|
| 47 |
+
url=self.url, download_root=self.data_root, remove_finished=True
|
| 48 |
)
|
| 49 |
logger.info("Download completed.")
|
| 50 |
|
| 51 |
def setup(self, stage: Optional[str] = None):
|
| 52 |
"""Set up the train, validation, and test datasets."""
|
| 53 |
|
| 54 |
+
self.prepare_data()
|
| 55 |
+
|
| 56 |
train_transform = transforms.Compose(
|
| 57 |
[
|
| 58 |
transforms.Resize((self.image_size, self.image_size)),
|
| 59 |
+
transforms.RandomHorizontalFlip(0.1),
|
| 60 |
+
transforms.RandomRotation(10),
|
| 61 |
+
transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
|
| 62 |
+
transforms.RandomAutocontrast(0.1),
|
| 63 |
+
transforms.RandomAdjustSharpness(2, 0.1),
|
| 64 |
transforms.ToTensor(),
|
| 65 |
transforms.Normalize(
|
| 66 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
|
|
| 78 |
]
|
| 79 |
)
|
| 80 |
|
| 81 |
+
train_path = self.dataset_path / "train"
|
| 82 |
+
test_path = self.dataset_path / "test"
|
| 83 |
|
| 84 |
if stage == "fit" or stage is None:
|
| 85 |
full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
|
| 86 |
+
self.class_names = full_train_dataset.classes
|
| 87 |
train_size = int(self.train_val_split[0] * len(full_train_dataset))
|
| 88 |
val_size = len(full_train_dataset) - train_size
|
| 89 |
self.train_dataset, self.val_dataset = random_split(
|
|
|
|
| 116 |
def test_dataloader(self) -> DataLoader:
|
| 117 |
return self._create_dataloader(self.test_dataset)
|
| 118 |
|
| 119 |
+
def get_class_names(self) -> List[str]:
|
| 120 |
+
return self.class_names
|
| 121 |
+
|
| 122 |
|
| 123 |
if __name__ == "__main__":
|
| 124 |
+
# Test the CatDogImageDataModule
|
| 125 |
import hydra
|
| 126 |
+
from omegaconf import DictConfig, OmegaConf
|
| 127 |
import rootutils
|
| 128 |
|
| 129 |
+
root = rootutils.setup_root(__file__, indicator=".project-root")
|
|
|
|
|
|
|
| 130 |
|
| 131 |
@hydra.main(
|
| 132 |
+
config_path=str(root / "configs"), version_base="1.3", config_name="train"
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
+
def test_datamodule(cfg: DictConfig):
|
| 135 |
+
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
|
|
|
|
|
|
|
|
|
| 136 |
datamodule = CatDogImageDataModule(
|
| 137 |
+
data_root=cfg.paths.data_dir,
|
| 138 |
data_dir=cfg.data.data_dir,
|
| 139 |
batch_size=cfg.data.batch_size,
|
| 140 |
num_workers=cfg.data.num_workers,
|
| 141 |
train_val_split=cfg.data.train_val_split,
|
| 142 |
pin_memory=cfg.data.pin_memory,
|
| 143 |
image_size=cfg.data.image_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
)
|
| 145 |
+
datamodule.setup(stage="fit")
|
| 146 |
+
train_loader = datamodule.train_dataloader()
|
| 147 |
+
val_loader = datamodule.val_dataloader()
|
| 148 |
+
datamodule.setup(stage="test")
|
| 149 |
+
test_loader = datamodule.test_dataloader()
|
| 150 |
+
class_names = datamodule.get_class_names()
|
| 151 |
+
|
| 152 |
+
logger.info(f"Train loader: {len(train_loader)} batches")
|
| 153 |
+
logger.info(f"Validation loader: {len(val_loader)} batches")
|
| 154 |
+
logger.info(f"Test loader: {len(test_loader)} batches")
|
| 155 |
+
logger.info(f"Class names: {class_names}")
|
| 156 |
+
|
| 157 |
+
test_datamodule()
|
src/train_new.py
CHANGED
|
@@ -122,7 +122,7 @@ def run_test_module(
|
|
| 122 |
return test_metrics[0] if test_metrics else {}
|
| 123 |
|
| 124 |
|
| 125 |
-
@hydra.main(config_path="../configs", config_name="train", version_base="1.
|
| 126 |
def setup_run_trainer(cfg: DictConfig):
|
| 127 |
"""Set up and run the Trainer for training and testing."""
|
| 128 |
# Display configuration
|
|
|
|
| 122 |
return test_metrics[0] if test_metrics else {}
|
| 123 |
|
| 124 |
|
| 125 |
+
@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
|
| 126 |
def setup_run_trainer(cfg: DictConfig):
|
| 127 |
"""Set up and run the Trainer for training and testing."""
|
| 128 |
# Display configuration
|