File size: 5,392 Bytes
dcb5590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f27535
 
dcb5590
 
 
 
 
 
 
 
 
 
0f27535
 
 
 
 
 
 
 
 
 
dcb5590
 
 
 
0f27535
dcb5590
 
 
 
 
 
 
 
 
 
0f27535
dcb5590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f27535
dcb5590
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import lightning as L
from loguru import logger
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive


class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat vs. Dog image classification using ``ImageFolder``.

    Downloads the ``cats_and_dogs_filtered`` archive on first use, splits the
    archive's ``train`` folder into train/validation subsets, and serves the
    archive's ``validation`` folder as the held-out test set.

    Args:
        data_dir: Root directory under which the archive is downloaded/extracted.
        batch_size: Batch size for all dataloaders.
        num_workers: Worker processes per dataloader.
        train_val_split: Two fractions ``(train, val)`` that sum to 1.0; only
            the first element is used to size the train subset, the remainder
            goes to validation. Immutable default avoids the shared
            mutable-default-argument pitfall.
        pin_memory: Passed through to each ``DataLoader``.
        image_size: Square side length images are resized to.
        url: Archive URL passed to ``download_and_extract_archive``.
    """

    def __init__(
        self,
        data_dir: Union[str, Path] = "data",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: Tuple[float, float] = (0.8, 0.2),
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
    ):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_val_split = train_val_split
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url

        # Populated lazily by setup(); None until then.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download and extract the dataset archive if it is not present.

        Called once per node by Lightning; must not assign state used by
        dataloaders (that belongs in ``setup``).
        """
        dataset_path = self.data_dir / "cats_and_dogs_filtered"
        if not dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.data_dir, remove_finished=True
            )
            logger.info("Download completed.")

    def setup(self, stage: Optional[str] = None):
        """Build the train, validation, and test datasets.

        Args:
            stage: One of ``"fit"``, ``"test"``, or ``None`` (build all).
        """

        # Augmentation is applied to training data only.
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # ImageNet normalization statistics.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Deterministic transform for validation and test.
        test_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        train_path = self.data_dir / "cats_and_dogs_filtered" / "train"
        test_path = self.data_dir / "cats_and_dogs_filtered" / "validation"

        if stage == "fit" or stage is None:
            # Two views of the same directory so the validation subset does
            # NOT receive training-time augmentation (RandomHorizontalFlip).
            train_base = ImageFolder(root=train_path, transform=train_transform)
            val_base = ImageFolder(root=train_path, transform=test_transform)

            train_size = int(self.train_val_split[0] * len(train_base))
            val_size = len(train_base) - train_size
            self.train_dataset, val_split = random_split(
                train_base, [train_size, val_size]
            )
            # Reuse the exact same indices on the un-augmented view.
            self.val_dataset = Subset(val_base, val_split.indices)
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage == "test" or stage is None:
            self.test_dataset = ImageFolder(root=test_path, transform=test_transform)
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the module's shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training subset."""
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Deterministic loader over the validation subset."""
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Deterministic loader over the test set."""
        return self._create_dataloader(self.test_dataset)


if __name__ == "__main__":
    import hydra
    import rootutils
    from omegaconf import DictConfig, OmegaConf

    # Locate the project root so Hydra can resolve the configs directory.
    root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
    logger.info(f"Root directory: {root}")

    @hydra.main(version_base="1.3", config_path=str(root / "configs"), config_name="train")
    def main(cfg: DictConfig) -> None:
        """Smoke-test the DataModule end-to-end using the Hydra config."""
        logger.info("Config:\n" + OmegaConf.to_yaml(cfg))

        data_cfg = cfg.data
        datamodule = CatDogImageDataModule(
            data_dir=data_cfg.data_dir,
            batch_size=data_cfg.batch_size,
            num_workers=data_cfg.num_workers,
            train_val_split=data_cfg.train_val_split,
            pin_memory=data_cfg.pin_memory,
            image_size=data_cfg.image_size,
            url=data_cfg.url,
        )

        # Download (if needed) and materialize all splits.
        datamodule.prepare_data()
        datamodule.setup()

        # Report the size of each dataloader in batches.
        logger.info(f"Train DataLoader: {len(datamodule.train_dataloader())} batches")
        logger.info(f"Validation DataLoader: {len(datamodule.val_dataloader())} batches")
        logger.info(f"Test DataLoader: {len(datamodule.test_dataloader())} batches")

    main()