File size: 2,440 Bytes
c690b8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import random
from pathlib import Path


# Augmentation pipeline handed to the dataset: a random rotation
# (degrees=10), a random 512-pixel crop, colour jitter, and finally
# conversion of the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.RandomRotation(10),
        transforms.RandomCrop(size=512),
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
        ),
        transforms.ToTensor(),
    ]
)

class CustomDataset(Dataset):
    """Dataset of per-sample folders whose images are named by age.

    ``root_dir`` is expected to contain one sub-directory per sample. Inside
    a sub-directory, a usable image is a regular file whose stem parses as an
    integer (for example ``20.jpg``); that integer divided by 100 is used as
    the age value. Other entries (hidden files, notes, nested directories)
    are ignored.

    Each item is a random (source, target) image pair drawn from one folder.
    Two constant channels are appended to the source tensor: first the source
    age, then the target age.

    Args:
        root_dir: Directory whose immediate sub-directories are the samples.
        transform: Optional callable applied to both PIL images. It must
            return a tensor indexable as (channels, height, width). When it
            is omitted, the images are only converted with ``ToTensor``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so that an index always refers to the same folder,
        # independent of the order in which the filesystem lists entries.
        self.image_folders = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

    def __len__(self):
        """Return the number of sample folders found under ``root_dir``."""
        return len(self.image_folders)

    @staticmethod
    def _list_aged_images(folder_path):
        """Return ``(filename, age)`` pairs for usable images, sorted by filename."""
        aged_images = []
        for filename in sorted(os.listdir(folder_path)):
            if not os.path.isfile(os.path.join(folder_path, filename)):
                continue
            try:
                age = int(Path(filename).stem) / 100
            except ValueError:
                # Not an age-named image (e.g. ".DS_Store"); skip it.
                continue
            aged_images.append((filename, age))
        return aged_images

    @staticmethod
    def _load_rgb(image_path):
        """Open an image file and return an in-memory RGB copy of it."""
        with Image.open(image_path) as image:
            return image.convert('RGB')

    def _transform_pair(self, source_image, target_image):
        """Convert both PIL images to tensors using identical augmentations."""
        if not self.transform:
            # Without a transform the images would stay PIL objects and the
            # tensor indexing in __getitem__ would fail, so convert them here.
            to_tensor = transforms.ToTensor()
            return to_tensor(source_image), to_tensor(target_image)

        # Re-seed before each call so both images start from the same RNG
        # state. NOTE(review): this only yields identical crops/jitter if the
        # transform draws its randomness from torch's global RNG - confirm
        # for the installed torchvision version.
        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
        torch.manual_seed(seed)
        source_tensor = self.transform(source_image)
        torch.manual_seed(seed)
        target_tensor = self.transform(target_image)
        return source_tensor, target_tensor

    def __getitem__(self, idx):
        """Return ``(source, target)`` tensors for the folder at ``idx``.

        Returns:
            A tuple ``(source, target)``: ``source`` is the transformed source
            image with the source-age and target-age channels concatenated
            along dim 0, ``target`` is the transformed target image.

        Raises:
            ValueError: If the folder holds fewer than two usable images.
        """
        folder_path = os.path.join(self.root_dir, self.image_folders[idx])

        aged_images = self._list_aged_images(folder_path)
        if len(aged_images) < 2:
            raise ValueError(
                f"Need at least two age-named images in {folder_path!r}, "
                f"found {len(aged_images)}"
            )

        # Pick two distinct images from the folder at random.
        (source_name, source_age), (target_name, target_age) = random.sample(aged_images, 2)

        source_image = self._load_rgb(os.path.join(folder_path, source_name))
        target_image = self._load_rgb(os.path.join(folder_path, target_name))

        source_image, target_image = self._transform_pair(source_image, target_image)

        # Both age planes are sized from the source image because both are
        # concatenated onto it.
        age_plane = source_image[:1, :, :]
        source_age_channel = torch.full_like(age_plane, source_age)
        target_age_channel = torch.full_like(age_plane, target_age)

        # Append the two age channels to the source image.
        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)

        return source_image, target_image