Spaces:
Runtime error
Runtime error
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from PIL import Image | |
import os | |
import random | |
from pathlib import Path | |
# Augmentation pipeline for PIL images. Steps run in the listed order and the
# last step converts the augmented image to a tensor.
transform = transforms.Compose(
    [
        # Random rotation with the angle drawn from [-10, 10] degrees.
        transforms.RandomRotation(10),
        # Random 512x512 crop.
        transforms.RandomCrop(size=512),
        # Random photometric perturbation.
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.ToTensor(),
    ]
)
class CustomDataset(Dataset):
    """Dataset of image pairs drawn from the same folder, tagged with ages.

    ``root_dir`` must contain sub-folders (presumably one per subject --
    confirm against the data layout). Every usable image inside a sub-folder
    is named ``<int>.<ext>`` (for example ``20.jpg``); the integer stem divided
    by 100 is the age value attached to that image.

    Each item is a ``(source, target)`` tuple of tensors:

    * ``source`` -- the source image with two constant channels appended along
      dim 0: first the source age, then the target age.
    * ``target`` -- the target image.

    Args:
        root_dir: Directory whose sub-directories hold the images.
        transform: Optional callable applied to both PIL images. It is run
            under the same torch seed for both images so random augmentations
            match. It must return a tensor. When ``None`` the images are only
            converted with ``transforms.ToTensor()``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort so that index -> folder is
        # reproducible between runs and machines.
        self.image_folders = sorted(
            folder
            for folder in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, folder))
        )

    def __len__(self):
        """Return the number of sub-folders found under ``root_dir``."""
        return len(self.image_folders)

    @staticmethod
    def _age_from_name(file_name):
        """Return ``int(stem) / 100`` for *file_name*, or ``None`` if the stem is not an int."""
        try:
            return int(Path(file_name).stem) / 100
        except ValueError:
            return None

    @staticmethod
    def _list_age_images(folder_path):
        """Return the sorted names of regular files in *folder_path* with an integer stem."""
        return sorted(
            name
            for name in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, name))
            and CustomDataset._age_from_name(name) is not None
        )

    @staticmethod
    def _load_rgb(image_path):
        """Open *image_path* and return an RGB copy, closing the file afterwards."""
        with Image.open(image_path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(source, target)`` for two distinct random images of folder *idx*.

        Raises:
            IndexError: If *idx* is outside the folder list.
            ValueError: If the folder holds fewer than two age-named images.
        """
        folder_name = self.image_folders[idx]
        folder_path = os.path.join(self.root_dir, folder_name)

        # Ignore stray files (hidden files, notes, sub-directories) that would
        # otherwise break age parsing or image loading at random.
        image_filenames = self._list_age_images(folder_path)
        if len(image_filenames) < 2:
            raise ValueError(
                f"Folder {folder_path!r} needs at least two age-named images, "
                f"found {len(image_filenames)}"
            )

        # Pick two distinct images from the folder.
        source_image_name, target_image_name = random.sample(image_filenames, 2)
        source_age = self._age_from_name(source_image_name)
        target_age = self._age_from_name(target_image_name)

        source_image = self._load_rgb(os.path.join(folder_path, source_image_name))
        target_image = self._load_rgb(os.path.join(folder_path, target_image_name))

        if self.transform is not None:
            # Re-seed torch before each call so both images receive the same
            # random crop and augmentations.
            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
            torch.manual_seed(seed)
            source_image = self.transform(source_image)
            torch.manual_seed(seed)
            target_image = self.transform(target_image)
        else:
            # Without a transform the images are still PIL objects; convert
            # them so the tensor slicing below works.
            to_tensor = transforms.ToTensor()
            source_image = to_tensor(source_image)
            target_image = to_tensor(target_image)

        # One constant channel per age, shaped like a single source channel.
        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)

        # Concatenate the age channels with the source image along dim 0.
        source_image = torch.cat(
            [source_image, source_age_channel, target_age_channel], dim=0
        )

        return source_image, target_image