timroelofs123 committed
Commit e2a98a8 · Parent(s): 7efa9d0
app.py CHANGED
@@ -1,12 +1,25 @@
 import gradio as gr
 import torch
-import argparse
 import git
+import os, shutil
 
 git.Repo.clone_from("https://huggingface.co/timroelofs123/face_re-aging", "./hf")
 
+git.Repo.clone_from("https://github.com/timroelofs123/face_reaging", "./fr")
+
+shutil.move('./fr/assets', '.')
+shutil.move('./fr/models', '.')
+shutil.move('./fr/scripts', '.')
+shutil.move('./fr/utils', '.')
+
 from model.models import UNet
-from scripts.test_functions import process_image
+from scripts.test_functions import process_image, process_video
+
+# default settings
+window_size = 512
+stride = 256
+steps = 18
+frame_count = 100
 
 
 model_path = "hf/best_unet_model.pth"
@@ -15,12 +28,20 @@ unet_model = UNet().to(device)
 unet_model.load_state_dict(torch.load(model_path, map_location=device))
 unet_model.eval()
 
-def block(image, source_age, target_age):
+def block_img(image, source_age, target_age):
     return process_image(unet_model, image, video=False, source_age=source_age,
-                         target_age=target_age, window_size=512, stride=256)
+                         target_age=target_age, window_size=window_size, stride=stride)
+
+def block_img_vid(image, source_age):
+    return process_image(unet_model, image, video=True, source_age=source_age,
+                         target_age=0, window_size=window_size, stride=stride, steps=steps)
+
+def block_vid(video_path, source_age, target_age):
+    return process_video(unet_model, video_path, source_age, target_age,
+                         window_size=window_size, stride=stride, frame_count=frame_count)
 
-demo = gr.Interface(
-    fn=block,
+demo_img = gr.Interface(
+    fn=block_img,
     inputs=[
         gr.Image(type="pil"),
         gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
@@ -29,12 +50,50 @@ demo = gr.Interface(
     outputs="image",
     examples=[
         ['assets/gradio_example_images/1.png', 20, 80],
-        # ['assets/gradio_example_images/2.png', 75, 40],
-        # ['assets/gradio_example_images/3.png', 30, 70],
-        # ['assets/gradio_example_images/4.png', 22, 60],
-        # ['assets/gradio_example_images/5.png', 28, 75],
-        # ['assets/gradio_example_images/6.png', 35, 15]
+        ['assets/gradio_example_images/2.png', 75, 40],
+        ['assets/gradio_example_images/3.png', 30, 70],
+        ['assets/gradio_example_images/4.png', 22, 60],
+        ['assets/gradio_example_images/5.png', 28, 75],
+        ['assets/gradio_example_images/6.png', 35, 15]
+    ],
+    description="Input an image of a person and age them from the source age to the target age."
+)
+
+demo_img_vid = gr.Interface(
+    fn=block_img_vid,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+    ],
+    outputs=gr.Video(),
+    examples=[
+        ['assets/gradio_example_images/1.png', 20],
+        ['assets/gradio_example_images/2.png', 75],
+        ['assets/gradio_example_images/3.png', 30],
+        ['assets/gradio_example_images/4.png', 22],
+        ['assets/gradio_example_images/5.png', 28],
+        ['assets/gradio_example_images/6.png', 35]
+    ],
+    description="Input an image of a person and a video will be returned of the person at different ages."
+)
+
+demo_vid = gr.Interface(
+    fn=block_vid,
+    inputs=[
+        gr.Video(),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
     ],
+    outputs=gr.Video(),
+    # examples=[
+    #     ['assets/gradio_example_images/orig.mp4', 35, 60],
+    # ],
+    description="Input a video of a person, and it will be aged frame-by-frame."
 )
 
+demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
+                          tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
+                          title="Face Re-Aging Demo",
+                          )
+
 demo.launch()
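For readers following the change: the app now exposes three entry points (process_image for stills, process_image with video=True for an aging animation, and the new process_video for clips). A minimal sketch of driving the image path without the Gradio UI, assuming the clone/move steps above have already run; the example file and ages are illustrative:

```python
# Sketch only: exercise the model outside Gradio, assuming app.py's
# clone/move steps have already populated ./hf, model/, scripts/, assets/.
import torch
from PIL import Image
from model.models import UNet
from scripts.test_functions import process_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet_model = UNet().to(device)
unet_model.load_state_dict(torch.load("hf/best_unet_model.pth", map_location=device))
unet_model.eval()

# Age a single face from 20 to 80; returns a PIL image.
aged = process_image(unet_model, Image.open("assets/gradio_example_images/1.png"),
                     video=False, source_age=20, target_age=80,
                     window_size=512, stride=256)
aged.save("aged_80.png")
```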
assets/gradio_example_images/1.png DELETED
Binary file (987 kB)
 
assets/mask1024.jpg DELETED
Binary file (207 kB)
 
assets/mask512.jpg DELETED
Binary file (10.5 kB)
 
model/__init__.py DELETED
File without changes
model/losses.py DELETED
@@ -1,70 +0,0 @@
-import torch
-import torch.nn as nn
-import lpips  # LPIPS library for perceptual loss
-
-class GeneratorLoss(nn.Module):
-    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
-                 device="cpu"):
-        super(GeneratorLoss, self).__init__()
-        self.discriminator_model = discriminator_model
-        self.l1_weight = l1_weight
-        self.perceptual_weight = perceptual_weight
-        self.adversarial_weight = adversarial_weight
-        self.criterion_l1 = nn.L1Loss()
-        self.criterion_adversarial = nn.BCEWithLogitsLoss()
-        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
-
-    def forward(self, output, target, source):
-        # L1 loss
-
-        l1_loss = self.criterion_l1(output, target)
-
-        # Perceptual loss
-        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
-
-        # Adversarial loss
-        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
-        fake_prediction = self.discriminator_model(fake_input)
-
-        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
-
-        # Combine losses
-        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
-                         self.adversarial_weight * adversarial_loss
-
-        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
-
-class DiscriminatorLoss(nn.Module):
-    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
-        super(DiscriminatorLoss, self).__init__()
-        self.discriminator_model = discriminator_model
-        self.criterion_adversarial = nn.BCEWithLogitsLoss()
-        self.fake_weight = fake_weight
-        self.real_weight = real_weight
-        self.mock_weight = mock_weight
-
-    def forward(self, output, target, source):
-        # Adversarial loss
-        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
-        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
-
-        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
-        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
-        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
-        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with target age
-
-        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
-        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
-                                                          self.discriminator_model(mock_input2),
-                                                          self.discriminator_model(mock_input3),
-                                                          self.discriminator_model(mock_input4))
-
-        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
-                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
-                              )
-
-        return discriminator_loss
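These two losses were the training objective for the U-Net generator: L1 plus LPIPS perceptual distance plus a small adversarial term, against a discriminator trained to reject fakes and age-mismatched pairs. A hypothetical single optimization step wiring them together (optimizers, learning rates, and the dummy batch are illustrative, not from this repo):

```python
# Hypothetical training step for the deleted losses; shapes follow the
# forward() signatures above (source = RGB + source-age + target-age channels).
import torch
from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss

generator = UNet()
discriminator = PatchGANDiscriminator(input_channels=4)  # 3 RGB + 1 age channel
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)  # lr illustrative
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
gen_loss_fn = GeneratorLoss(discriminator)
disc_loss_fn = DiscriminatorLoss(discriminator)

source = torch.rand(2, 5, 512, 512)  # dummy batch
target = torch.rand(2, 3, 512, 512)

output = generator(source)

# Discriminator step (generator output detached)
d_loss = disc_loss_fn(output.detach(), target, source)
d_opt.zero_grad()
d_loss.backward()
d_opt.step()

# Generator step; also returns the individual terms for logging
g_loss, l1, perceptual, adversarial = gen_loss_fn(output, target, source)
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
```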
model/models.py DELETED
@@ -1,99 +0,0 @@
-import torch
-import torch.nn as nn
-import antialiased_cnns
-
-
-class DownLayer(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(DownLayer, self).__init__()
-        self.layer = nn.Sequential(
-            nn.MaxPool2d(kernel_size=2, stride=1),
-            antialiased_cnns.BlurPool(in_channels, stride=2),
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True)
-        )
-
-    def forward(self, x):
-        return self.layer(x)
-
-
-class UpLayer(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(UpLayer, self).__init__()
-        # Conv transpose upsampling
-
-        self.blur_upsample = nn.Sequential(
-            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
-            antialiased_cnns.BlurPool(out_channels, stride=1)
-        )
-
-        self.layer = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True)
-        )
-
-    def forward(self, x, skip):
-        x = self.blur_upsample(x)
-        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
-        return self.layer(x)
-
-
-class UNet(nn.Module):
-    def __init__(self):
-        super(UNet, self).__init__()
-        self.init_conv = nn.Sequential(
-            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
-            nn.LeakyReLU(inplace=True)
-        )
-
-        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
-        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
-        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
-        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
-        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
-        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
-        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
-        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
-        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
-
-    def forward(self, x):
-        x0 = self.init_conv(x)
-        x1 = self.down1(x0)
-        x2 = self.down2(x1)
-        x3 = self.down3(x2)
-        x4 = self.down4(x3)
-        x = self.up1(x4, x3)
-        x = self.up2(x, x2)
-        x = self.up3(x, x1)
-        x = self.up4(x, x0)
-        x = self.final_conv(x)
-        return x
-
-
-class PatchGANDiscriminator(nn.Module):
-    def __init__(self, input_channels=3):
-        super(PatchGANDiscriminator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(128),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
-            # Output layer with 1 channel for binary classification
-        )
-
-    def forward(self, x):
-        return self.model(x)
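As a sanity check on the deleted architecture: UNet consumes the 5-channel input (RGB plus source-age and target-age planes) and emits a 3-channel output at the same resolution, which scripts/test_functions.py adds back onto the original crop as a residual. A quick shape-check sketch, assuming antialiased_cnns is installed:

```python
# Shape check for the deleted models (sketch; requires antialiased_cnns).
import torch
from model.models import UNet, PatchGANDiscriminator

net = UNet().eval()
x = torch.rand(1, 5, 512, 512)  # RGB + source-age + target-age channels
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 512, 512]) -- residual over the input crop

disc = PatchGANDiscriminator(input_channels=4)  # image + one age channel
print(disc(torch.rand(1, 4, 512, 512)).shape)   # patch logits, [1, 1, 63, 63]
```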
requirements.txt CHANGED
@@ -2,4 +2,6 @@ torch
 torchvision
 antialiased_cnns
 face_recognition
+ffmpy
+av
 gitpython
scripts/__init__.py DELETED
File without changes
scripts/test_functions.py DELETED
@@ -1,141 +0,0 @@
-import face_recognition
-import numpy as np
-from PIL import Image
-import torch
-from torch.autograd import Variable
-from torchvision import transforms
-from torchvision.io import write_video
-import tempfile
-
-mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
-small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
-
-def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
-    """
-    Apply aging operation on input tensor using a sliding-window method. This operation is done on the GPU, if available.
-    """
-
-    input_tensor = input_tensor.to(next(your_model.parameters()).device)
-    mask = mask.to(next(your_model.parameters()).device)
-    small_mask = small_mask.to(next(your_model.parameters()).device)
-
-    n, c, h, w = input_tensor.size()
-    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
-
-    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
-
-    add = 2 if window_size % stride != 0 else 1
-
-    for y in range(0, h - window_size + add, stride):
-        for x in range(0, w - window_size + add, stride):
-            window = input_tensor[:, :, y:y + window_size, x:x + window_size]
-
-            # Apply the same preprocessing as during training
-            input_variable = Variable(window, requires_grad=False)
-
-            # Forward pass
-            with torch.no_grad():
-                output = your_model(input_variable)
-
-            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
-            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
-
-    count_tensor = torch.clamp(count_tensor, min=1.0)
-
-    # Average the overlapping regions
-    output_tensor /= count_tensor
-
-    # Apply mask
-    output_tensor *= mask
-
-    return output_tensor.cpu()
-
-
-def process_image(your_model, image, video, source_age, target_age=0,
-                  window_size=512, stride=256, steps=18):
-    """
-    Aging the person in the image.
-    If video=False, we age from source_age to target_age, and return an image.
-    If video=True, we age from source_age to a range of target ages, and return this as the path to a video.
-    """
-    if video:
-        target_age = 0
-    input_size = (1024, 1024)
-
-    # image = face_recognition.load_image_file(filename)
-    image = np.array(image)
-    if video:  # h264 codec requires frame size to be divisible by 2.
-        width, height, depth = image.shape
-        new_width = width if width % 2 == 0 else width - 1
-        new_height = height if height % 2 == 0 else height - 1
-        image.resize((new_width, new_height, depth))
-
-    fl = face_recognition.face_locations(image)[0]
-
-    # calculate margins
-    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
-    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
-    margin_x = int((fl[1] - fl[3]) // (2 / .85))
-    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure square is preserved
-
-    l_y = max([fl[0] - margin_y_t, 0])
-    r_y = min([fl[2] + margin_y_b, image.shape[0]])
-    l_x = max([fl[3] - margin_x, 0])
-    r_x = min([fl[1] + margin_x, image.shape[1]])
-
-    # crop image
-    cropped_image = image[l_y:r_y, l_x:r_x, :]
-
-    # Resizing
-    orig_size = cropped_image.shape[:2]
-
-    cropped_image = transforms.ToTensor()(cropped_image)
-
-    cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)
-
-    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
-    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
-    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
-
-    image = transforms.ToTensor()(image)
-
-    if video:
-        # aging in steps
-        interval = .8 / steps
-        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
-        for i in range(0, steps):
-            input_tensor[:, -1, :, :] += interval
-
-            # performing actions on image
-            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)
-
-        # resize back to original size
-        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
-            aged_cropped_images)
-
-        # re-apply
-        image = image.repeat(steps, 1, 1, 1)
-
-        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
-        image = torch.clamp(image, 0, 1)
-        image = (image * 255).to(torch.uint8)
-
-        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
-
-        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)
-
-        return output_file.name
-
-    else:
-        # performing actions on image
-        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
-
-        # resize back to original size
-        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
-            aged_cropped_image)
-
-        # re-apply
-        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
-        image = torch.clamp(image, 0, 1)
-
-        return transforms.functional.to_pil_image(image)
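The heart of the deleted sliding_window_tensor is overlap-averaging: every window is predicted independently, accumulated under a feathering mask, and normalized by the per-pixel hit count. A stripped-down sketch of just that accumulation (identity "model", no masks) to make the mechanism concrete:

```python
# Sketch of the overlap-averaging used by sliding_window_tensor (no masks).
import torch

def overlap_average(x, window, stride, predict):
    n, c, h, w = x.shape
    out = torch.zeros_like(x)
    count = torch.zeros_like(x)
    add = 2 if window % stride != 0 else 1  # same edge handling as above
    for y in range(0, h - window + add, stride):
        for x0 in range(0, w - window + add, stride):
            patch = x[:, :, y:y + window, x0:x0 + window]
            out[:, :, y:y + window, x0:x0 + window] += predict(patch)
            count[:, :, y:y + window, x0:x0 + window] += 1
    return out / count.clamp(min=1.0)  # average where windows overlap

# With an identity "model", overlap-averaging reproduces the input exactly.
x = torch.rand(1, 3, 1024, 1024)
assert torch.allclose(overlap_average(x, 512, 256, lambda p: p), x)
```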
utils/__init__.py DELETED
File without changes
utils/dataloader.py DELETED
@@ -1,63 +0,0 @@
-import torch
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
-from pathlib import Path
-
-
-# Define the transformations
-transform = transforms.Compose([
-    transforms.RandomRotation(degrees=10),
-    transforms.RandomCrop(512),
-    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
-    transforms.ToTensor(),
-])
-
-class CustomDataset(Dataset):
-    def __init__(self, root_dir, transform=None):
-        self.root_dir = root_dir
-        self.transform = transform
-        self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
-
-    def __len__(self):
-        return len(self.image_folders)
-
-    def __getitem__(self, idx):
-        folder_name = self.image_folders[idx]
-        folder_path = os.path.join(self.root_dir, folder_name)
-
-        # # Get the list of image filenames in the folder
-        # image_filenames = [f"{i}.jpg" for i in range(0, 101, 10)]
-        image_filenames = os.listdir(folder_path)
-
-        # Pick two random images from the folder
-        source_image_name, target_image_name = random.sample(image_filenames, 2)
-        # source_image_name, target_image_name = '20.jpg', '80.jpg'
-
-        source_age = int(Path(source_image_name).stem) / 100
-        target_age = int(Path(target_image_name).stem) / 100
-
-        # Build the full paths for the two selected images
-        source_image_path = os.path.join(folder_path, source_image_name)
-        target_image_path = os.path.join(folder_path, target_image_name)
-
-        source_image = Image.open(source_image_path).convert('RGB')
-        target_image = Image.open(target_image_path).convert('RGB')
-
-        # Apply the same random crop and augmentations to both images
-        if self.transform:
-            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
-            torch.manual_seed(seed)
-            source_image = self.transform(source_image)
-            torch.manual_seed(seed)
-            target_image = self.transform(target_image)
-
-        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
-        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
-
-        # Concatenate the age channels with the source_image
-        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
-
-        return source_image, target_image
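For context, the deleted dataset expects one folder per identity, with filenames encoding the age (e.g. 20.jpg, 80.jpg); it samples two ages of the same person and stacks the normalized age planes onto the source image. A hypothetical usage sketch, assuming a data/ directory laid out that way:

```python
# Hypothetical usage of the deleted dataset; data/<person>/<age>.jpg layout assumed.
from torch.utils.data import DataLoader
from utils.dataloader import CustomDataset, transform

dataset = CustomDataset("data", transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

source, target = next(iter(loader))
print(source.shape)  # [8, 5, 512, 512]: RGB + source-age + target-age planes
print(target.shape)  # [8, 3, 512, 512]
```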