import os

import lightning as L
import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.io import read_image
from torchvision.transforms import v2 as T


class DRDataset(Dataset):
    def __init__(self, csv_path: str, transform=None):
        self.csv_path = csv_path
        self.transform = transform
        self.image_paths, self.labels = self.load_csv_data()

    def load_csv_data(self):
        # Check if CSV file exists
        if not os.path.isfile(self.csv_path):
            raise FileNotFoundError(f"CSV file '{self.csv_path}' not found.")

        # Load data from CSV file
        data = pd.read_csv(self.csv_path)

        # Check if 'image_path' and 'label' columns exist
        if "image_path" not in data.columns or "label" not in data.columns:
            raise ValueError("CSV file must contain 'image_path' and 'label' columns.")

        # Extract image paths and labels
        image_paths = data["image_path"].tolist()
        labels = data["label"].tolist()

        # Check if any image paths are invalid
        invalid_image_paths = [
            img_path for img_path in image_paths if not os.path.isfile(img_path)
        ]
        if invalid_image_paths:
            raise FileNotFoundError(f"Invalid image paths found: {invalid_image_paths}")

        # Convert labels to LongTensor
        labels = torch.LongTensor(labels)

        return image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image
        try:
            image = read_image(image_path)
        except Exception as e:
            raise IOError(f"Error loading image at path '{image_path}': {e}")

        # Apply transformations if provided
        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                raise RuntimeError(
                    f"Error applying transformations to image at path '{image_path}': {e}"
                )

        return image, label


class DRDataModule(L.LightningDataModule):
    def __init__(self, batch_size: int = 8, num_workers: int = 4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Define the transformations
        self.train_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        self.val_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        self.num_classes = 5

    def setup(self, stage=None):
        self.train_dataset = DRDataset("data/train.csv", transform=self.train_transform)
        self.val_dataset = DRDataset("data/val.csv", transform=self.val_transform)

        # Class imbalance is handled by the WeightedRandomSampler in
        # train_dataloader, so loss re-weighting is disabled here. To re-enable it:
        # self.class_weights = self.compute_class_weights(self.train_dataset.labels.numpy())
        self.class_weights = None

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            # A custom sampler is mutually exclusive with shuffle=True, so
            # shuffling is delegated to the WeightedRandomSampler.
            sampler=self._get_weighted_sampler(self.train_dataset.labels.numpy()),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def compute_class_weights(self, labels):
        class_weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(labels), y=labels
        )
        return torch.tensor(class_weights, dtype=torch.float32)
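
    # NOTE (assumed usage, not wired up in this file): the tensor returned by
    # compute_class_weights would typically be passed to the training loss in
    # the LightningModule, e.g. torch.nn.CrossEntropyLoss(weight=class_weights).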
    
    def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
        """Return a WeightedRandomSampler that rebalances the training classes.

        The ``weights`` tensor must contain one weight per sample, not one per
        class. For worked examples, see:
        https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
        https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/
        """
        # Inverse-frequency weight for each class...
        class_sample_count = np.array(
            [len(np.where(labels == label)[0]) for label in np.unique(labels)]
        )
        weight = 1.0 / class_sample_count
        # ...then one weight per sample, looked up via the sample's label.
        samples_weight = np.array([weight[label] for label in labels])
        samples_weight = torch.from_numpy(samples_weight)

        return WeightedRandomSampler(
            weights=samples_weight, num_samples=len(labels), replacement=True
        )
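

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of driving
# DRDataModule with a Lightning Trainer. The `DRClassifier` below is a
# hypothetical stand-in model, and running it assumes `data/train.csv` and
# `data/val.csv` exist with valid image paths.
if __name__ == "__main__":
    import torch.nn.functional as F
    from torchvision.models import resnet18

    class DRClassifier(L.LightningModule):
        """Tiny 5-class classifier used only to demonstrate the datamodule."""

        def __init__(self, num_classes: int = 5, lr: float = 1e-4):
            super().__init__()
            self.model = resnet18(num_classes=num_classes)
            self.lr = lr

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            images, labels = batch
            loss = F.cross_entropy(self(images), labels)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            images, labels = batch
            self.log("val_loss", F.cross_entropy(self(images), labels))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)

    dm = DRDataModule(batch_size=8, num_workers=4)
    trainer = L.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(DRClassifier(num_classes=dm.num_classes), datamodule=dm)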