import os

import imageio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm


class pix2pixDataset(Dataset):
    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/",
                 split="train", normalize=True, transforms=None, preload=False,
                 image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            print(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        self.direction = direction

        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor (C, H, W)
                Resize((image_size, image_size), antialias=False),  # Resize to image_size x image_size
            ])
        else:
            self.transforms = transforms

        if self.preload:
            # Load every image pair into memory up front and stack into tensors.
            self.x_list, self.y_list = (), ()
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        # Each file stores the two paired images side by side; split at the horizontal midpoint.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # Apply the resize transform
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Scale pixel values from [0, 255] to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values in [0, 255], since these are images.
        return ((x / 2 + 0.5) * 255).int().clamp(0, 255)

    def __getitem__(self, index):
        # getitem should return x0, x1, y (where y is the class label for class-conditional generation)
        class_cond = None
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            name = self.image_name_list[index]
            x_img, y_img = self.load_every(name)
        # if self.direction == "BtoA":
        #     return x_img, y_img, class_cond
        # elif self.direction == "AtoB":
        #     return y_img, x_img, class_cond
        batch = {
            "image1": x_img,
            "image2": y_img,
        }
        return batch

    def __len__(self):
        return len(self.image_name_list)


class FishDataset(Dataset):
    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train",
                 normalize=True, transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)
        if not os.path.exists(self.datadir):
            print(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload

        if transforms is None:
            # self.transforms = Compose([
            #     ToTensor(),  # Convert to torch tensor
            #     Resize((image_size, image_size), antialias=False),  # Resize to image_size x image_size
            # ])
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor
            ])
        else:
            self.transforms = transforms

        if self.preload:
            # The class id is encoded at the end of the file name (the token after the last "_",
            # with the 4-character extension stripped).
            self.x_list, self.y_list, self.class_id = (), (), []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                cls_id = int(name.split("_")[-1][:-4])
                self.x_list = self.x_list + (x,)
                self.y_list = self.y_list + (y,)
                self.class_id.append(cls_id)
            self.x_list = torch.stack(self.x_list, 0)
            self.y_list = torch.stack(self.y_list, 0)
            self.class_id = torch.tensor(self.class_id)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        # Each file stores the two paired images side by side; split at the horizontal midpoint.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_H, img_W = img_array.shape[0], img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # Apply the transforms
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Scale pixel values from [0, 255] to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values in [0, 255], since these are images.
        return ((x / 2 + 0.5) * 255).int().clamp(0, 255)

    def __getitem__(self, index):
        # getitem should return x0, x1, y (where y is the class label for class-conditional generation)
        if self.preload:
            x_img, y_img, class_id = self.x_list[index], self.y_list[index], self.class_id[index]
        else:
            name = self.image_name_list[index]
            class_id = torch.tensor(int(name.split("_")[-1][:-4]))
            x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        return len(self.image_name_list)
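

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original classes above: wraps pix2pixDataset in a
    # DataLoader and inspects one batch. It assumes the default data_dir exists and contains a
    # "maps" dataset with a "train" split of side-by-side image pairs; adjust paths for your setup.
    from torch.utils.data import DataLoader

    dataset = pix2pixDataset(dataset="maps", split="train", preload=False)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    # Each batch is a dict holding the two halves of the paired image, each of shape
    # (batch_size, channels, image_size, image_size).
    print(batch["image1"].shape, batch["image2"].shape)
    # With normalize=True the pixel values lie in [-1, 1]; unnormalize_fn maps them back to [0, 255].
    print(batch["image1"].min().item(), batch["image1"].max().item())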