|
""" |
|
CelebFaces Attributes (CelebA) Dataset |
|
https://www.kaggle.com/datasets/jessicali9530/celeba-dataset |
|
""" |
|
|
|
|
|
import os |
|
|
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import DataLoader, Dataset |
|
from torchvision import transforms |
|
|
|
|
|
class CelebADataset(Dataset):
    """Dataset of CelebA face images read from a single flat directory.

    Every entry returned by ``os.listdir(root)`` is treated as an image file.
    The names are sorted so that an index maps to the same file regardless of
    the order in which the operating system lists the directory.
    """

    def __init__(self, root, img_shape=(64, 64)) -> None:
        """Index the images under *root* and prepare the transform pipeline.

        Args:
            root: Directory containing the image files.
            img_shape: Target size passed to ``transforms.Resize``.
        """
        super().__init__()
        self.root = root
        self.img_shape = img_shape
        self.filenames = sorted(os.listdir(root))
        # The pipeline depends only on ``img_shape``, so build it once here
        # instead of re-creating it for every item fetched.
        # NOTE: changing ``self.img_shape`` after construction does not
        # affect this pipeline.
        self._pipeline = transforms.Compose([
            transforms.CenterCrop(168),
            transforms.Resize(self.img_shape),
            transforms.ToTensor(),
        ])

    def __len__(self) -> int:
        """Return the number of files found in ``root``."""
        return len(self.filenames)

    def __getitem__(self, index: int):
        """Load, crop, resize and tensorize the image at position *index*.

        Args:
            index: Position in the sorted filename list.

        Returns:
            The output of the transform pipeline (centre crop of 168,
            resize to ``img_shape``, then ``transforms.ToTensor``).
        """
        path = os.path.join(self.root, self.filenames[index])
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed deterministically as soon as it has been read.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self._pipeline(rgb)
|
|
|
|
|
def get_dataloader(root='data/celebA/img_align_celeba', *, batch_size=16,
                   **kwargs):
    """Build a shuffling ``DataLoader`` over a ``CelebADataset``.

    Args:
        root: Directory containing the image files.
        batch_size: Number of samples per batch. Defaults to 16, the value
            that was previously hard-coded.
        **kwargs: Forwarded to ``CelebADataset`` (e.g. ``img_shape``).

    Returns:
        A ``DataLoader`` over the dataset with ``shuffle=True``.
    """
    dataset = CelebADataset(root, **kwargs)
    return DataLoader(dataset, batch_size, shuffle=True)
|
|