Commit e2a98a8 · "new demos"
Parent(s): 7efa9d0

Files changed:
- app.py (+70, -11)
- assets/gradio_example_images/1.png (+0, -0)
- assets/mask1024.jpg (+0, -0)
- assets/mask512.jpg (+0, -0)
- model/__init__.py (+0, -0)
- model/losses.py (+0, -70)
- model/models.py (+0, -99)
- requirements.txt (+2, -0)
- scripts/__init__.py (+0, -0)
- scripts/test_functions.py (+0, -141)
- utils/__init__.py (+0, -0)
- utils/dataloader.py (+0, -63)

app.py
CHANGED
@@ -1,12 +1,25 @@
 import gradio as gr
 import torch
-import argparse
 import git
+import os, shutil

 git.Repo.clone_from("https://huggingface.co/timroelofs123/face_re-aging", "./hf")

+git.Repo.clone_from("https://github.com/timroelofs123/face_reaging", "./fr")
+
+shutil.move('./fr/assets', '.')
+shutil.move('./fr/models', '.')
+shutil.move('./fr/scripts', '.')
+shutil.move('./fr/utils', '.')
+
 from model.models import UNet
-from scripts.test_functions import process_image
+from scripts.test_functions import process_image, process_video
+
+# default settings
+window_size = 512
+stride = 256
+steps = 18
+frame_count = 100


 model_path = "hf/best_unet_model.pth"
@@ -15,12 +28,20 @@ unet_model = UNet().to(device)
 unet_model.load_state_dict(torch.load(model_path, map_location=device))
 unet_model.eval()

-def …
+def block_img(image, source_age, target_age):
     return process_image(unet_model, image, video=False, source_age=source_age,
-…
+                         target_age=target_age, window_size=window_size, stride=stride)
+
+def block_img_vid(image, source_age):
+    return process_image(unet_model, image, video=True, source_age=source_age,
+                         target_age=0, window_size=window_size, stride=stride, steps=steps)
+
+def block_vid(video_path, source_age, target_age):
+    return process_video(unet_model, video_path, source_age, target_age,
+                         window_size=window_size, stride=stride, frame_count=frame_count)

-…
-    fn=…
+demo_img = gr.Interface(
+    fn=block_img,
     inputs=[
         gr.Image(type="pil"),
         gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
@@ -29,12 +50,50 @@ demo = gr.Interface(
     outputs="image",
     examples=[
         ['assets/gradio_example_images/1.png', 20, 80],
-…
-…
-…
-…
-…
+        ['assets/gradio_example_images/2.png', 75, 40],
+        ['assets/gradio_example_images/3.png', 30, 70],
+        ['assets/gradio_example_images/4.png', 22, 60],
+        ['assets/gradio_example_images/5.png', 28, 75],
+        ['assets/gradio_example_images/6.png', 35, 15]
+    ],
+    description="Input an image of a person and age them from the source age to the target age."
+)
+
+demo_img_vid = gr.Interface(
+    fn=block_img_vid,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+    ],
+    outputs=gr.Video(),
+    examples=[
+        ['assets/gradio_example_images/1.png', 20],
+        ['assets/gradio_example_images/2.png', 75],
+        ['assets/gradio_example_images/3.png', 30],
+        ['assets/gradio_example_images/4.png', 22],
+        ['assets/gradio_example_images/5.png', 28],
+        ['assets/gradio_example_images/6.png', 35]
+    ],
+    description="Input an image of a person and a video will be returned of the person at different ages."
+)
+
+demo_vid = gr.Interface(
+    fn=block_vid,
+    inputs=[
+        gr.Video(),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
     ],
+    outputs=gr.Video(),
+    # examples=[
+    #     ['assets/gradio_example_images/orig.mp4', 35, 60],
+    # ],
+    description="Input a video of a person, and it will be aged frame-by-frame."
 )

+demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
+                          tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
+                          title="Face Re-Aging Demo",
+                          )
+
 demo.launch()

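The new startup block re-clones both repositories and moves four folders into the working directory on every launch. That is fine in a fresh Space container, but git.Repo.clone_from fails on a non-empty target directory and shutil.move raises when the destination folder already exists, so a re-run-safe variant may be convenient for local development. The sketch below is illustrative only and not part of the commit; the clone_once helper is a hypothetical name.

# Sketch: a re-run-safe variant of app.py's startup block (assumption, not in the commit).
import os
import shutil
import git


def clone_once(url, path):
    # Only clone when the target directory is missing; clone_from errors otherwise.
    if not os.path.exists(path):
        git.Repo.clone_from(url, path)


clone_once("https://huggingface.co/timroelofs123/face_re-aging", "./hf")
clone_once("https://github.com/timroelofs123/face_reaging", "./fr")

for folder in ("assets", "models", "scripts", "utils"):
    # shutil.move into "." fails if ./<folder> already exists, so skip in that case.
    if not os.path.exists(folder):
        shutil.move(os.path.join("./fr", folder), ".")
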
assets/gradio_example_images/1.png
DELETED
Binary file (987 kB)

assets/mask1024.jpg
DELETED
Binary file (207 kB)

assets/mask512.jpg
DELETED
Binary file (10.5 kB)

model/__init__.py
DELETED
File without changes

model/losses.py
DELETED
@@ -1,70 +0,0 @@
-import torch
-import torch.nn as nn
-import lpips  # LPIPS library for perceptual loss
-
-class GeneratorLoss(nn.Module):
-    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
-                 device="cpu"):
-        super(GeneratorLoss, self).__init__()
-        self.discriminator_model = discriminator_model
-        self.l1_weight = l1_weight
-        self.perceptual_weight = perceptual_weight
-        self.adversarial_weight = adversarial_weight
-        self.criterion_l1 = nn.L1Loss()
-        self.criterion_adversarial = nn.BCEWithLogitsLoss()
-        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
-
-    def forward(self, output, target, source):
-        # L1 loss
-
-        l1_loss = self.criterion_l1(output, target)
-
-        # Perceptual loss
-        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
-
-        # Adversarial loss
-        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
-        fake_prediction = self.discriminator_model(fake_input)
-
-        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
-
-        # Combine losses
-        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
-                         self.adversarial_weight * adversarial_loss
-
-        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
-
-class DiscriminatorLoss(nn.Module):
-    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
-        super(DiscriminatorLoss, self).__init__()
-        self.discriminator_model = discriminator_model
-        self.criterion_adversarial = nn.BCEWithLogitsLoss()
-        self.fake_weight = fake_weight
-        self.real_weight = real_weight
-        self.mock_weight = mock_weight
-
-    def forward(self, output, target, source):
-        # Adversarial loss
-        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
-        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
-
-        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
-        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
-        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
-        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with target age
-
-        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
-        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
-                                                          self.discriminator_model(mock_input2),
-                                                          self.discriminator_model(mock_input3),
-                                                          self.discriminator_model(mock_input4))
-
-        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
-                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
-                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
-                              )
-
-        return discriminator_loss

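For reference, a sketch of how these deleted loss modules could be wired into a training step. The discriminator channel count (three RGB channels plus one age plane, matching the torch.cat calls above) and the 5-channel source layout are inferred from the code; the batch size, random tensors, and the assumption that model.losses remains importable from the cloned GitHub repo are placeholders.

# Illustrative wiring of GeneratorLoss / DiscriminatorLoss (a sketch, not the author's training script).
import torch
from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss

device = "cuda" if torch.cuda.is_available() else "cpu"
generator = UNet().to(device)
# fake_input = cat([3-channel output, 1 age plane]) -> 4 discriminator input channels
discriminator = PatchGANDiscriminator(input_channels=4).to(device)

gen_loss_fn = GeneratorLoss(discriminator, device=device)
disc_loss_fn = DiscriminatorLoss(discriminator)

# source: RGB + source-age plane + target-age plane; target: plain RGB (placeholder tensors)
source = torch.rand(2, 5, 512, 512, device=device)
target = torch.rand(2, 3, 512, 512, device=device)

output = generator(source)
gen_total, l1, perceptual, adversarial = gen_loss_fn(output, target, source)
disc_total = disc_loss_fn(output.detach(), target, source)
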
model/models.py
DELETED
@@ -1,99 +0,0 @@
-import torch
-import torch.nn as nn
-import antialiased_cnns
-
-
-class DownLayer(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(DownLayer, self).__init__()
-        self.layer = nn.Sequential(
-            nn.MaxPool2d(kernel_size=2, stride=1),
-            antialiased_cnns.BlurPool(in_channels, stride=2),
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True)
-        )
-
-    def forward(self, x):
-        return self.layer(x)
-
-
-class UpLayer(nn.Module):
-    def __init__(self, in_channels, out_channels):
-        super(UpLayer, self).__init__()
-        # Conv transpose upsampling
-
-        self.blur_upsample = nn.Sequential(
-            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
-            antialiased_cnns.BlurPool(out_channels, stride=1)
-        )
-
-        self.layer = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True)
-        )
-
-    def forward(self, x, skip):
-        x = self.blur_upsample(x)
-        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
-        return self.layer(x)
-
-
-class UNet(nn.Module):
-    def __init__(self):
-        super(UNet, self).__init__()
-        self.init_conv = nn.Sequential(
-            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
-            nn.LeakyReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
-            nn.LeakyReLU(inplace=True)
-        )
-
-        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
-        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
-        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
-        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
-        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
-        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
-        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
-        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
-        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
-
-    def forward(self, x):
-        x0 = self.init_conv(x)
-        x1 = self.down1(x0)
-        x2 = self.down2(x1)
-        x3 = self.down3(x2)
-        x4 = self.down4(x3)
-        x = self.up1(x4, x3)
-        x = self.up2(x, x2)
-        x = self.up3(x, x1)
-        x = self.up4(x, x0)
-        x = self.final_conv(x)
-        return x
-
-
-class PatchGANDiscriminator(nn.Module):
-    def __init__(self, input_channels=3):
-        super(PatchGANDiscriminator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(128),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
-            nn.BatchNorm2d(256),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
-            # Output layer with 1 channel for binary classification
-        )
-
-    def forward(self, x):
-        return self.model(x)

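A quick shape check of the deleted UNet, as a sketch: it takes a 5-channel 512x512 input (RGB plus source-age and target-age planes, as built in scripts/test_functions.py) and returns a 3-channel output of the same size, which the test functions add back onto the original image. It assumes model.models is importable after app.py's startup clone/move step.

# Shape sanity check for the UNet (illustrative only).
import torch
from model.models import UNet

net = UNet().eval()
dummy = torch.rand(1, 5, 512, 512)  # RGB + source-age plane + target-age plane
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # expected: torch.Size([1, 3, 512, 512])
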
requirements.txt
CHANGED
@@ -2,4 +2,6 @@ torch
 torchvision
 antialiased_cnns
 face_recognition
+ffmpy
+av
 gitpython

scripts/__init__.py
DELETED
File without changes

scripts/test_functions.py
DELETED
@@ -1,141 +0,0 @@
-import face_recognition
-import numpy as np
-from PIL import Image
-import torch
-from torch.autograd import Variable
-from torchvision import transforms
-from torchvision.io import write_video
-import tempfile
-
-mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
-small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
-
-def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
-    """
-    Apply aging operation on input tensor using a sliding-window method. This operation is done on the GPU, if available.
-    """
-
-    input_tensor = input_tensor.to(next(your_model.parameters()).device)
-    mask = mask.to(next(your_model.parameters()).device)
-    small_mask = small_mask.to(next(your_model.parameters()).device)
-
-    n, c, h, w = input_tensor.size()
-    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
-
-    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
-
-    add = 2 if window_size % stride != 0 else 1
-
-    for y in range(0, h - window_size + add, stride):
-        for x in range(0, w - window_size + add, stride):
-            window = input_tensor[:, :, y:y + window_size, x:x + window_size]
-
-            # Apply the same preprocessing as during training
-            input_variable = Variable(window, requires_grad=False)  # Assuming GPU is available
-
-            # Forward pass
-            with torch.no_grad():
-                output = your_model(input_variable)
-
-            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
-            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
-
-    count_tensor = torch.clamp(count_tensor, min=1.0)
-
-    # Average the overlapping regions
-    output_tensor /= count_tensor
-
-    # Apply mask
-    output_tensor *= mask
-
-    return output_tensor.cpu()
-
-
-def process_image(your_model, image, video, source_age, target_age=0,
-                  window_size=512, stride=256, steps=18):
-    """
-    Aging the person in the image.
-    If video=False, we age as from source_age to target_age, and return an image.
-    If video=True, we age from source_age to a range of target ages, and return this as the path to a video.
-    """
-    if video:
-        target_age = 0
-    input_size = (1024, 1024)
-
-    # image = face_recognition.load_image_file(filename)
-    image = np.array(image)
-    if video:  # h264 codec requires frame size to be divisible by 2.
-        width, height, depth = image.shape
-        new_width = width if width % 2 == 0 else width - 1
-        new_height = height if height % 2 == 0 else height - 1
-        image.resize((new_width, new_height, depth))
-
-    fl = face_recognition.face_locations(image)[0]
-
-    # calculate margins
-    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
-    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
-    margin_x = int((fl[1] - fl[3]) // (2 / .85))
-    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure square is preserved
-
-    l_y = max([fl[0] - margin_y_t, 0])
-    r_y = min([fl[2] + margin_y_b, image.shape[0]])
-    l_x = max([fl[3] - margin_x, 0])
-    r_x = min([fl[1] + margin_x, image.shape[1]])
-
-    # crop image
-    cropped_image = image[l_y:r_y, l_x:r_x, :]
-
-    # Resizing
-    orig_size = cropped_image.shape[:2]
-
-    cropped_image = transforms.ToTensor()(cropped_image)
-
-    cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)
-
-    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
-    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
-    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
-
-    image = transforms.ToTensor()(image)
-
-    if video:
-        # aging in steps
-        interval = .8 / steps
-        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
-        for i in range(0, steps):
-            input_tensor[:, -1, :, :] += interval
-
-            # performing actions on image
-            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)
-
-        # resize back to original size
-        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
-            aged_cropped_images)
-
-        # re-apply
-        image = image.repeat(steps, 1, 1, 1)
-
-        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
-        image = torch.clamp(image, 0, 1)
-        image = (image * 255).to(torch.uint8)
-
-        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
-
-        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)
-
-        return output_file.name
-
-    else:
-        # performing actions on image
-        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
-
-        # resize back to original size
-        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
-            aged_cropped_image)
-
-        # re-apply
-        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
-        image = torch.clamp(image, 0, 1)
-
-        return transforms.functional.to_pil_image(image)

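These functions now come from the GitHub repo that app.py clones at startup, and the Gradio blocks call them exactly as shown in the diff above. Calling process_image directly, outside Gradio, would presumably look like the sketch below; the weights path, the example image, and the mask files under assets/ are assumed to be present locally, and the output filename is a placeholder.

# Sketch: direct use of process_image, mirroring app.py's block_img and block_img_vid.
import torch
from PIL import Image
from model.models import UNet
from scripts.test_functions import process_image

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
model.load_state_dict(torch.load("hf/best_unet_model.pth", map_location=device))
model.eval()

img = Image.open("assets/gradio_example_images/1.png")

# Image mode: returns a PIL image aged from 20 to 80.
aged = process_image(model, img, video=False, source_age=20, target_age=80,
                     window_size=512, stride=256)
aged.save("aged.png")  # hypothetical output path

# Video mode: returns the path of a temporary mp4 stepping through target ages.
video_path = process_image(model, img, video=True, source_age=20,
                           window_size=512, stride=256, steps=18)
print(video_path)
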
utils/__init__.py
DELETED
File without changes

utils/dataloader.py
DELETED
@@ -1,63 +0,0 @@
-import torch
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from PIL import Image
-import os
-import random
-from pathlib import Path
-
-
-# Define the transformations
-transform = transforms.Compose([
-    transforms.RandomRotation(degrees=10),
-    transforms.RandomCrop(512),
-    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
-    transforms.ToTensor(),
-])
-
-class CustomDataset(Dataset):
-    def __init__(self, root_dir, transform=None):
-        self.root_dir = root_dir
-        self.transform = transform
-        self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
-
-    def __len__(self):
-        return len(self.image_folders)
-
-    def __getitem__(self, idx):
-        folder_name = self.image_folders[idx]
-        folder_path = os.path.join(self.root_dir, folder_name)
-
-        # # Get the list of image filenames in the folder
-        # image_filenames = [f"{i}.jpg" for i in range(0, 101, 10)]
-        image_filenames = os.listdir(folder_path)
-
-        # Pick two random assets from the folder
-        source_image_name, target_image_name = random.sample(image_filenames, 2)
-        # source_image_name, target_image_name = '20.jpg', '80.jpg'
-
-        source_age = int(Path(source_image_name).stem) / 100
-        target_age = int(Path(target_image_name).stem) / 100
-
-        # Randomly select two assets from the folder
-        source_image_path = os.path.join(folder_path, source_image_name)
-        target_image_path = os.path.join(folder_path, target_image_name)
-
-        source_image = Image.open(source_image_path).convert('RGB')
-        target_image = Image.open(target_image_path).convert('RGB')
-
-        # Apply the same random crop and augmentations to both assets
-        if self.transform:
-            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
-            torch.manual_seed(seed)
-            source_image = self.transform(source_image)
-            torch.manual_seed(seed)
-            target_image = self.transform(target_image)
-
-        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
-        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
-
-        # Concatenate the age channels with the source_image
-        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
-
-        return source_image, target_image

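For completeness, a sketch of how the deleted CustomDataset would be consumed for training. The expected layout (one sub-folder per identity, with images named after the age they depict, e.g. 20.jpg) is inferred from __getitem__; the root path and loader settings below are placeholders, and the module is assumed to remain importable from the cloned GitHub repo.

# Sketch: building a training loader around CustomDataset (illustrative assumptions only).
from torch.utils.data import DataLoader
from utils.dataloader import CustomDataset, transform

dataset = CustomDataset("path/to/training_faces", transform=transform)  # hypothetical root folder
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

source, target = next(iter(loader))
# source: RGB + source-age plane + target-age plane; target: plain RGB crop
print(source.shape, target.shape)  # e.g. torch.Size([4, 5, 512, 512]) torch.Size([4, 3, 512, 512])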