import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image import os import random from pathlib import Path # Define the transformations transform = transforms.Compose([ transforms.RandomRotation(degrees=10), transforms.RandomCrop(512), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), transforms.ToTensor(), ]) class CustomDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))] def __len__(self): return len(self.image_folders) def __getitem__(self, idx): folder_name = self.image_folders[idx] folder_path = os.path.join(self.root_dir, folder_name) # # Get the list of image filenames in the folder # image_filenames = [f"{i}.jpg" for i in range(0, 101, 10)] image_filenames = os.listdir(folder_path) # Pick two random assets from the folder source_image_name, target_image_name = random.sample(image_filenames, 2) # source_image_name, target_image_name = '20.jpg', '80.jpg' source_age = int(Path(source_image_name).stem) / 100 target_age = int(Path(target_image_name).stem) / 100 # Randomly select two assets from the folder source_image_path = os.path.join(folder_path, source_image_name) target_image_path = os.path.join(folder_path, target_image_name) source_image = Image.open(source_image_path).convert('RGB') target_image = Image.open(target_image_path).convert('RGB') # Apply the same random crop and augmentations to both assets if self.transform: seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() torch.manual_seed(seed) source_image = self.transform(source_image) torch.manual_seed(seed) target_image = self.transform(target_image) source_age_channel = torch.full_like(source_image[:1, :, :], source_age) target_age_channel = torch.full_like(source_image[:1, :, :], target_age) # Concatenate the age channels with the source_image source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0) return source_image, target_image