Soutrik committed on
Commit dcb5590 · 1 Parent(s): cf754af

added: optim datamodule

Files changed (1)
  1. src/datamodules/catdog_datamodule.py +138 -0
src/datamodules/catdog_datamodule.py ADDED
@@ -0,0 +1,138 @@
+ from pathlib import Path
+ from typing import Union, Optional, List
+ import lightning as L
+ from torch.utils.data import DataLoader, random_split
+ from torchvision import transforms
+ from torchvision.datasets import ImageFolder
+ from torchvision.datasets.utils import download_and_extract_archive
+ from loguru import logger
+
+
+ class CatDogImageDataModule(L.LightningDataModule):
+     """DataModule for cat vs. dog image classification using ImageFolder."""
+
+     def __init__(
+         self,
+         data_dir: Union[str, Path] = "data",
+         batch_size: int = 32,
+         num_workers: int = 4,
+         train_val_split: List[float] = [0.8, 0.2],
+         pin_memory: bool = False,
+         image_size: int = 224,
+         url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
+     ):
+         super().__init__()
+         self.data_dir = Path(data_dir)
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.train_val_split = train_val_split
+         self.pin_memory = pin_memory
+         self.image_size = image_size
+         self.url = url
+
+         # Datasets are created lazily in setup()
+         self.train_dataset = None
+         self.val_dataset = None
+         self.test_dataset = None
+
+     def prepare_data(self):
+         """Download and extract the dataset if it is not already present."""
+         dataset_path = self.data_dir / "cats_and_dogs_filtered"
+         if not dataset_path.exists():
+             logger.info("Downloading and extracting dataset.")
+             download_and_extract_archive(
+                 url=self.url, download_root=self.data_dir, remove_finished=True
+             )
+             logger.info("Download completed.")
+
+     def setup(self, stage: Optional[str] = None):
+         """Set up the train, validation, and test datasets."""
+         # Note: the same transform (including the random flip) is applied to all splits.
+         transform = transforms.Compose(
+             [
+                 transforms.Resize((self.image_size, self.image_size)),
+                 transforms.RandomHorizontalFlip(),
+                 transforms.ToTensor(),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+
+         train_path = self.data_dir / "cats_and_dogs_filtered" / "train"
+         test_path = self.data_dir / "cats_and_dogs_filtered" / "validation"
+
+         if stage == "fit" or stage is None:
+             full_train_dataset = ImageFolder(root=train_path, transform=transform)
+             train_size = int(self.train_val_split[0] * len(full_train_dataset))
+             val_size = len(full_train_dataset) - train_size
+             self.train_dataset, self.val_dataset = random_split(
+                 full_train_dataset, [train_size, val_size]
+             )
+             logger.info(
+                 f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
+             )
+
+         if stage == "test" or stage is None:
+             self.test_dataset = ImageFolder(root=test_path, transform=transform)
+             logger.info(f"Test dataset size: {len(self.test_dataset)} images.")
+
+     def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
+         """Helper to create a DataLoader with the shared settings."""
+         return DataLoader(
+             dataset=dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             pin_memory=self.pin_memory,
+             shuffle=shuffle,
+         )
+
+     def train_dataloader(self) -> DataLoader:
+         return self._create_dataloader(self.train_dataset, shuffle=True)
+
+     def val_dataloader(self) -> DataLoader:
+         return self._create_dataloader(self.val_dataset)
+
+     def test_dataloader(self) -> DataLoader:
+         return self._create_dataloader(self.test_dataset)
+
+
+ if __name__ == "__main__":
+     from omegaconf import DictConfig, OmegaConf
+     import hydra
+     import rootutils
+
+     # Set up the project root so the Hydra config path resolves
+     root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+     logger.info(f"Root directory: {root}")
+
+     @hydra.main(
+         version_base="1.3",
+         config_path=str(root / "configs"),
+         config_name="train",
+     )
+     def main(cfg: DictConfig):
+         # Log configuration
+         logger.info("Config:\n" + OmegaConf.to_yaml(cfg))
+
+         # Initialize DataModule
+         datamodule = CatDogImageDataModule(
+             data_dir=cfg.data.data_dir,
+             batch_size=cfg.data.batch_size,
+             num_workers=cfg.data.num_workers,
+             train_val_split=cfg.data.train_val_split,
+             pin_memory=cfg.data.pin_memory,
+             image_size=cfg.data.image_size,
+             url=cfg.data.dataset_url,
+         )
+         datamodule.prepare_data()
+         datamodule.setup()
+
+         # Log DataLoader sizes
+         logger.info(f"Train DataLoader: {len(datamodule.train_dataloader())} batches")
+         logger.info(
+             f"Validation DataLoader: {len(datamodule.val_dataloader())} batches"
+         )
+         logger.info(f"Test DataLoader: {len(datamodule.test_dataloader())} batches")
+
+     main()
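
The __main__ block above only smoke-tests the DataModule through Hydra. For context, a DataModule like this is normally handed straight to a Lightning Trainer. The sketch below is not part of the commit: the LitCatDogClassifier model, its hyperparameters, and the import path (which assumes the repository root is on PYTHONPATH, as rootutils arranges) are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F
import lightning as L
from torchvision.models import resnet18

from src.datamodules.catdog_datamodule import CatDogImageDataModule


class LitCatDogClassifier(L.LightningModule):
    """Hypothetical two-class classifier, used only to exercise the DataModule."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.backbone = resnet18(num_classes=2)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        loss = F.cross_entropy(self.backbone(images), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        self.log("val_loss", F.cross_entropy(self.backbone(images), labels))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    datamodule = CatDogImageDataModule(data_dir="data", batch_size=32, num_workers=4)
    model = LitCatDogClassifier()
    # The Trainer calls prepare_data() and setup() on the DataModule itself.
    trainer = L.Trainer(max_epochs=1, accelerator="auto")
    trainer.fit(model, datamodule=datamodule)

In the actual project, these constructor arguments would presumably come from the Hydra train config (cfg.data.*) rather than being hard-coded as they are in this sketch.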