from abc import abstractmethod
from functools import partial
from pathlib import Path

from PIL import Image, ImageFile
from torch import nn
from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset
from torchvision import transforms as T, utils


def exists(val):
    """Return True when *val* is not None."""
    return val is not None


def cycle(dl):
    """Yield items from iterable *dl* forever, restarting it on exhaustion.

    Typical use: wrap a DataLoader so training code can call ``next(...)``
    without worrying about epoch boundaries.
    """
    while True:
        for data in dl:
            yield data


def convert_image_to(img_type, image):
    """Convert a PIL *image* to mode *img_type* (e.g. 'RGB', 'L') if needed.

    Returns the image unchanged when it is already in the requested mode.
    """
    if image.mode != img_type:
        return image.convert(img_type)
    return image


class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        # sample_ids starts as the full valid set; subclasses may narrow it.
        self.sample_ids = valid_ids
        self.size = size
        # NOTE(review): __len__ is deliberately left undefined (it was
        # commented out upstream); IterableDataset consumers should not
        # rely on len() here.

    @abstractmethod
    def __iter__(self):
        pass


class BaseDataset(Dataset):
    """Map-style image-folder dataset.

    Recursively globs *folder* for files with the given extensions and
    serves them as augmented tensors (resize -> random h-flip -> center
    crop -> ToTensor).
    """

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png', 'tiff'),  # tuple: avoid mutable default argument
        convert_image_to_type = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(folder).glob(f'**/*.{ext}')]

        # Optionally normalize every image to a fixed PIL mode (e.g. 'RGB');
        # nn.Identity() is a no-op placeholder when no conversion is wanted.
        convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # Use a context manager so the underlying file handle is closed
        # promptly (bare Image.open leaks descriptors on long runs);
        # ToTensor forces a full pixel load before the handle is released.
        with Image.open(path) as img:
            return self.transform(img)