File size: 5,370 Bytes
23fa981 356b6f2 23fa981 159c02d 356b6f2 23fa981 159c02d 23fa981 356b6f2 23fa981 356b6f2 23fa981 356b6f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import lightning as L
import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.io import read_image
from torchvision.transforms import v2 as T
class DRDataset(Dataset):
    """Image-classification dataset described by a CSV manifest.

    The CSV must provide an ``image_path`` column (paths to image files) and
    a ``label`` column (integer class ids). Every listed path is checked when
    the dataset is constructed, so a broken manifest fails immediately rather
    than in the middle of an epoch.

    Args:
        csv_path: Path to the CSV manifest.
        transform: Optional callable applied to each decoded image.

    Raises:
        FileNotFoundError: If the CSV file or any listed image is missing.
        ValueError: If a required column is absent from the CSV.
    """

    def __init__(self, csv_path: str, transform=None):
        self.csv_path = csv_path
        self.transform = transform
        self.image_paths, self.labels = self.load_csv_data()

    def load_csv_data(self):
        """Read and validate the manifest.

        Returns:
            A ``(image_paths, labels)`` pair: the list of paths from the
            ``image_path`` column and the ``label`` column as a LongTensor.

        Raises:
            FileNotFoundError: If the CSV or any referenced image is missing.
            ValueError: If the ``image_path`` or ``label`` column is absent.
        """
        if not os.path.isfile(self.csv_path):
            raise FileNotFoundError(f"CSV file '{self.csv_path}' not found.")

        data = pd.read_csv(self.csv_path)

        if "image_path" not in data.columns or "label" not in data.columns:
            raise ValueError("CSV file must contain 'image_path' and 'label' columns.")

        image_paths = data["image_path"].tolist()
        labels = data["label"].tolist()

        # Fail fast: report every missing file at once instead of crashing
        # on the first one during training.
        invalid_image_paths = [
            img_path for img_path in image_paths if not os.path.isfile(img_path)
        ]
        if invalid_image_paths:
            raise FileNotFoundError(f"Invalid image paths found: {invalid_image_paths}")

        labels = torch.LongTensor(labels)
        return image_paths, labels

    def __len__(self) -> int:
        """Return the number of samples listed in the manifest."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        Raises:
            IOError: If the image cannot be loaded.
            RuntimeError: If the transform fails on the loaded image.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            image = read_image(image_path)
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise IOError(f"Error loading image at path '{image_path}': {e}") from e

        # Compare against None explicitly: a callable container that happens
        # to be falsy must still be applied.
        if self.transform is not None:
            try:
                image = self.transform(image)
            except Exception as e:
                raise RuntimeError(
                    f"Error applying transformations to image at path '{image_path}': {e}"
                ) from e

        return image, label
class DRDataModule(L.LightningDataModule):
    """Lightning data module serving the train/validation ``DRDataset`` splits.

    Training batches are drawn through a ``WeightedRandomSampler`` so that
    every class is sampled with roughly equal probability; validation batches
    are read in manifest order.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each ``DataLoader``.
        train_csv: CSV manifest of the training split.
        val_csv: CSV manifest of the validation split.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 4,
        train_csv: str = "data/train.csv",
        val_csv: str = "data/val.csv",
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_csv = train_csv
        self.val_csv = val_csv

        # Training pipeline: resize, light geometric/colour augmentation,
        # then conversion to scaled float32 and per-channel normalisation.
        self.train_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.RandomAffine(degrees=10, translate=(0.01, 0.01), scale=(0.99, 1.01)),
                T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.01),
                T.RandomHorizontalFlip(p=0.5),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # Validation pipeline: same preprocessing, no random augmentation.
        self.val_transform = T.Compose(
            [
                T.Resize((224, 224), antialias=True),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        self.num_classes = 5

    def setup(self, stage=None):
        """Build the train and validation datasets from their CSV manifests."""
        self.train_dataset = DRDataset(self.train_csv, transform=self.train_transform)
        self.val_dataset = DRDataset(self.val_csv, transform=self.val_transform)

        # Loss class weights are left unset; train_dataloader() rebalances the
        # classes through a weighted sampler. compute_class_weights() remains
        # available for callers that want them.
        self.class_weights = None

    def train_dataloader(self):
        """Return the training loader, sampling with class-balancing weights."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            # A sampler is mutually exclusive with shuffle=True.
            sampler=self._get_weighted_sampler(self.train_dataset.labels.numpy()),
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def compute_class_weights(self, labels):
        """Return scikit-learn "balanced" class weights as a float32 tensor.

        The tensor holds one weight per class that actually occurs in
        ``labels`` (in sorted class order), so it can be shorter than
        ``num_classes`` when a class is absent.
        """
        class_weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(labels), y=labels
        )
        return torch.tensor(class_weights, dtype=torch.float32)

    def _get_weighted_sampler(self, labels: np.ndarray) -> WeightedRandomSampler:
        """Return a WeightedRandomSampler that balances the classes in ``labels``.

        The sampler needs one weight per *sample* (not per class); each sample
        is weighted by the inverse frequency of its class.
        Have a look at this post for an example: https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2
        https://www.maskaravivek.com/post/pytorch-weighted-random-sampler/
        """
        labels = np.asarray(labels).ravel()

        # `inverse` maps every sample to the position of its class among the
        # sorted unique classes, so the lookup stays correct even when the
        # class ids are not the contiguous range 0..K-1 (e.g. a class is
        # missing from the training split).
        _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
        samples_weight = torch.from_numpy(1.0 / counts[inverse])

        return WeightedRandomSampler(
            weights=samples_weight, num_samples=len(labels), replacement=True
        )
|