|
import numpy as np |
|
import random |
|
import torch |
|
from image_dataset import ImageDataset |
|
from typing import Generator, Tuple |
|
|
|
|
|
class BatchSampler:
    """
    Iterable that, given a dataset and a batch size, yields batches of data.

    Each batch is a tuple ``(images, labels)``: images are combined with
    ``torch.stack`` and cast to ``float``, labels with ``torch.tensor`` and cast
    to ``long``. The sample order is reshuffled at the start of every iteration,
    and the final batch may be smaller than ``batch_size`` when the number of
    samples is not a multiple of it.

    With ``balanced=True`` every class contributes the same number of samples
    (the size of the smallest class), selected once at construction time. If
    your dataset is heavily imbalanced, this may throw away many samples from
    over-represented classes!
    """

    def __init__(
        self, batch_size: int, dataset: "ImageDataset", balanced: bool = False
    ) -> None:
        """
        Args:
            batch_size: Maximum number of samples per batch.
            dataset: Indexable dataset whose items are ``(image, label)``
                pairs. It must also expose ``targets`` (one label per item)
                when ``balanced`` is True.
            balanced: If True, subsample every class down to the size of the
                smallest class.
        """
        self.batch_size = batch_size
        self.dataset = dataset
        self.balanced = balanced

        if self.balanced:
            # np.asarray makes ``targets == label`` element-wise even when
            # ``targets`` is a plain Python list.
            targets = np.asarray(self.dataset.targets)
            labels, counts = np.unique(targets, return_counts=True)
            per_class = int(counts.min()) if counts.size else 0

            self.indexes = []
            # Iterate over the actual label values, not their positions, so
            # non-contiguous labels (e.g. {3, 7}) select the right samples.
            for label in labels:
                members = np.flatnonzero(targets == label)
                chosen = np.random.choice(members, size=per_class, replace=False)
                # Store plain ints, matching the unbalanced branch.
                self.indexes.extend(chosen.tolist())
        else:
            self.indexes = list(range(len(dataset)))

    def __len__(self) -> int:
        """Return the number of batches per epoch, including a trailing partial batch."""
        # Ceiling division: exact multiples get no extra (empty) batch.
        return (len(self.indexes) + self.batch_size - 1) // self.batch_size

    def shuffle(self) -> None:
        """Shuffle ``self.indexes`` in place using Python's ``random`` module."""
        random.shuffle(self.indexes)

    def __iter__(self) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        """
        Reshuffle the indexes, then yield ``(images, labels)`` batches.

        Every index is visited exactly once per epoch. Each dataset item is
        fetched once, so an image and its label come from the same
        ``__getitem__`` call.
        """
        self.shuffle()
        for start in range(0, len(self.indexes), self.batch_size):
            # Slicing also produces the short final batch, so no special case.
            batch_indexes = self.indexes[start:start + self.batch_size]
            samples = [self.dataset[k] for k in batch_indexes]
            images = [sample[0] for sample in samples]
            labels = [sample[1] for sample in samples]
            yield torch.stack(images).float(), torch.tensor(labels).long()
|
|