Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
·
dcb5590
1
Parent(s):
cf754af
added: optim datamodule
Browse files
src/datamodules/catdog_datamodule.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Union, Tuple, Optional, List
|
3 |
+
import os
|
4 |
+
import lightning as L
|
5 |
+
from torch.utils.data import DataLoader, random_split
|
6 |
+
from torchvision import transforms
|
7 |
+
from torchvision.datasets import ImageFolder
|
8 |
+
from torchvision.datasets.utils import download_and_extract_archive
|
9 |
+
from loguru import logger
|
10 |
+
|
11 |
+
|
12 |
+
class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat vs. Dog image classification using ``ImageFolder``.

    Downloads and extracts the ``cats_and_dogs_filtered`` archive on first
    use, splits its ``train`` folder into train/validation subsets, and
    serves its ``validation`` folder as the test set.
    """

    def __init__(
        self,
        data_dir: Union[str, Path] = "data",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: Optional[List[float]] = None,
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
    ):
        """Store configuration only; no I/O happens here.

        Args:
            data_dir: Root directory under which the archive is extracted.
            batch_size: Batch size used by every dataloader.
            num_workers: Worker processes per dataloader.
            train_val_split: ``[train_fraction, val_fraction]``. Only the
                first fraction is used; validation receives the remainder.
                Defaults to ``[0.8, 0.2]``.
            pin_memory: Forwarded to ``DataLoader``.
            image_size: Images are resized to ``image_size x image_size``.
            url: Download URL for the dataset archive.
        """
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Avoid a mutable default argument (a shared list across calls):
        # fall back to the documented default here instead.
        self.train_val_split = (
            list(train_val_split) if train_val_split is not None else [0.8, 0.2]
        )
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url

        # Datasets are populated lazily in setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download and extract the dataset if it is not present on disk."""
        dataset_path = self.data_dir / "cats_and_dogs_filtered"
        if not dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.data_dir, remove_finished=True
            )
            logger.info("Download completed.")

    def setup(self, stage: Optional[str] = None):
        """Set up the train, validation, and test datasets.

        Args:
            stage: ``"fit"``, ``"test"``, or ``None`` (sets up everything).
        """
        # NOTE(review): one shared transform (including RandomHorizontalFlip)
        # is applied to train, validation, and test data alike — confirm the
        # flip augmentation is intended outside of training.
        transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # ImageNet channel statistics (standard for pretrained backbones).
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        train_path = self.data_dir / "cats_and_dogs_filtered" / "train"
        test_path = self.data_dir / "cats_and_dogs_filtered" / "validation"

        if stage == "fit" or stage is None:
            full_train_dataset = ImageFolder(root=train_path, transform=transform)
            train_size = int(self.train_val_split[0] * len(full_train_dataset))
            val_size = len(full_train_dataset) - train_size
            # NOTE(review): random_split is unseeded, so the train/val split
            # differs between runs — pass a seeded generator if reproducible
            # validation metrics are required.
            self.train_dataset, self.val_dataset = random_split(
                full_train_dataset, [train_size, val_size]
            )
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage == "test" or stage is None:
            self.test_dataset = ImageFolder(root=test_path, transform=transform)
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Helper function to create a DataLoader with the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader (shuffled)."""
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader."""
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader."""
        return self._create_dataloader(self.test_dataset)
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == "__main__":
    from omegaconf import DictConfig, OmegaConf
    import hydra
    import rootutils

    # Resolve the project root so hydra can locate the configs directory.
    root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
    logger.info(f"Root directory: {root}")

    @hydra.main(
        version_base="1.3",
        config_path=str(root / "configs"),
        config_name="train",
    )
    def main(cfg: DictConfig):
        """Smoke-test the datamodule: build it and report dataloader sizes."""
        logger.info("Config:\n" + OmegaConf.to_yaml(cfg))

        # Collect the datamodule arguments from the hydra config.
        dm_kwargs = dict(
            data_dir=cfg.data.data_dir,
            batch_size=cfg.data.batch_size,
            num_workers=cfg.data.num_workers,
            train_val_split=cfg.data.train_val_split,
            pin_memory=cfg.data.pin_memory,
            image_size=cfg.data.image_size,
            url=cfg.data.dataset_url,
        )
        datamodule = CatDogImageDataModule(**dm_kwargs)
        datamodule.prepare_data()
        datamodule.setup()

        # Report the number of batches each dataloader yields.
        loaders = [
            ("Train", datamodule.train_dataloader),
            ("Validation", datamodule.val_dataloader),
            ("Test", datamodule.test_dataloader),
        ]
        for split_name, make_loader in loaders:
            logger.info(f"{split_name} DataLoader: {len(make_loader())} batches")

    main()
|