Commit c690b8f · Parent: 1931503
add necessary files

Files changed:
- app.py (+0 -4)
- assets/mask1024.jpg (binary)
- assets/mask512.jpg (binary)
- model/__init__.py (+0 -0)
- model/losses.py (+70 -0)
- model/models.py (+99 -0)
- utils/__init__.py (+0 -0)
- utils/dataloader.py (+63 -0)
app.py CHANGED
@@ -3,12 +3,8 @@ import torch
 import argparse
 import git
 
-git.Repo.clone_from("https://github.com/timroelofs123/face_reaging.git", "./face_reaging")
 git.Repo.clone_from("https://huggingface.co/timroelofs123/face_re-aging", "./hf")
 
-import sys
-sys.path.append("./face_reaging")
-
 from model.models import UNet
 from scripts.test_functions import process_image
 
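Note on the change: app.py previously cloned the GitHub repo and patched sys.path so that model.models could be imported; with model/ and utils/ vendored into the Space by this commit, only the weights repo still needs cloning. A minimal sketch of the resulting startup flow, assuming a hypothetical checkpoint filename (this commit does not name one):

    # Sketch of the post-commit startup path. The checkpoint name
    # "best_unet_model.pth" is a placeholder, not taken from this commit.
    import git
    import torch

    from model.models import UNet  # vendored by this commit; no sys.path hack needed

    git.Repo.clone_from("https://huggingface.co/timroelofs123/face_re-aging", "./hf")

    model = UNet()
    state = torch.load("./hf/best_unet_model.pth", map_location="cpu")  # placeholder filename
    model.load_state_dict(state)
    model.eval()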
assets/mask1024.jpg ADDED
(binary image file)

assets/mask512.jpg ADDED
(binary image file)
model/__init__.py ADDED
(empty file)
model/losses.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import lpips  # LPIPS library for perceptual loss
+
+class GeneratorLoss(nn.Module):
+    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
+                 device="cpu"):
+        super(GeneratorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.l1_weight = l1_weight
+        self.perceptual_weight = perceptual_weight
+        self.adversarial_weight = adversarial_weight
+        self.criterion_l1 = nn.L1Loss()
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
+
+    def forward(self, output, target, source):
+        # L1 loss
+        l1_loss = self.criterion_l1(output, target)
+
+        # Perceptual loss
+        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
+
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
+        fake_prediction = self.discriminator_model(fake_input)
+        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
+
+        # Combine losses
+        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
+                         self.adversarial_weight * adversarial_loss
+
+        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
+
+class DiscriminatorLoss(nn.Module):
+    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
+        super(DiscriminatorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.fake_weight = fake_weight
+        self.real_weight = real_weight
+        self.mock_weight = mock_weight
+
+    def forward(self, output, target, source):
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
+        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
+
+        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
+        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
+        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
+        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age (duplicates mock_input2)
+
+        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
+        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
+                                                          self.discriminator_model(mock_input2),
+                                                          self.discriminator_model(mock_input3),
+                                                          self.discriminator_model(mock_input4))
+
+        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
+                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
+                              )
+
+        return discriminator_loss
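How the pieces fit: both losses expect source to be the 5-channel tensor built in utils/dataloader.py below (RGB, then a source-age channel at index 3 and a target-age channel at index 4), so the discriminator sees 4 channels (image plus one age map). A minimal training-step sketch, assuming random tensors in place of real batches and hypothetical Adam settings:

    # Training-step sketch. Batch tensors are random placeholders; the Adam
    # hyperparameters are assumptions, not values from this repository.
    import torch
    from model.models import UNet, PatchGANDiscriminator
    from model.losses import GeneratorLoss, DiscriminatorLoss

    device = "cpu"
    generator = UNet().to(device)
    # 4 input channels: 3-channel image + 1-channel age map (see torch.cat above)
    discriminator = PatchGANDiscriminator(input_channels=4).to(device)

    g_loss_fn = GeneratorLoss(discriminator, device=device)
    d_loss_fn = DiscriminatorLoss(discriminator)
    g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4)  # assumed lr
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

    source = torch.rand(1, 5, 512, 512, device=device)  # RGB + source-age + target-age
    target = torch.rand(1, 3, 512, 512, device=device)

    output = generator(source)

    # Discriminator step: detach the generator output so only D gets gradients
    d_opt.zero_grad()
    d_loss_fn(output.detach(), target, source).backward()
    d_opt.step()

    # Generator step: combined L1 + LPIPS + adversarial objective
    g_opt.zero_grad()
    g_loss, l1_loss, perc_loss, adv_loss = g_loss_fn(output, target, source)
    g_loss.backward()
    g_opt.step()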
model/models.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import antialiased_cnns
+
+
+class DownLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DownLayer, self).__init__()
+        self.layer = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=1),
+            antialiased_cnns.BlurPool(in_channels, stride=2),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class UpLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UpLayer, self).__init__()
+        # Conv transpose upsampling
+        self.blur_upsample = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+            antialiased_cnns.BlurPool(out_channels, stride=1)
+        )
+
+        self.layer = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x, skip):
+        x = self.blur_upsample(x)
+        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
+        return self.layer(x)
+
+
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.init_conv = nn.Sequential(
+            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True)
+        )
+
+        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
+        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
+        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
+        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
+        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
+        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
+        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
+        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
+        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
+
+    def forward(self, x):
+        x0 = self.init_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        x4 = self.down4(x3)
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        x = self.up4(x, x0)
+        x = self.final_conv(x)
+        return x
+
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, input_channels=3):
+        super(PatchGANDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+            # Output layer with 1 channel for binary classification
+        )
+
+    def forward(self, x):
+        return self.model(x)
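A quick shape check for the two networks. The 5-channel input matches the tensors produced by utils/dataloader.py; note that the discriminator must be built with input_channels=4 to accept the image-plus-age-map inputs that model/losses.py feeds it, and that it returns a grid of per-patch logits rather than a single scalar. A sketch, assuming the antialiased_cnns package is installed:

    # Shape sanity check for UNet and PatchGANDiscriminator on 512 x 512 inputs.
    import torch
    from model.models import UNet, PatchGANDiscriminator

    net = UNet()
    x = torch.rand(1, 5, 512, 512)  # RGB + source-age + target-age channels
    y = net(x)
    print(y.shape)  # torch.Size([1, 3, 512, 512])

    disc = PatchGANDiscriminator(input_channels=4)
    logits = disc(torch.cat([y, x[:, 4:5]], dim=1))
    print(logits.shape)  # torch.Size([1, 1, 63, 63]): one logit per patch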
utils/__init__.py ADDED
(empty file)
utils/dataloader.py ADDED
@@ -0,0 +1,63 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import random
+from pathlib import Path
+
+
+# Define the transformations
+transform = transforms.Compose([
+    transforms.RandomRotation(degrees=10),
+    transforms.RandomCrop(512),
+    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+    transforms.ToTensor(),
+])
+
+class CustomDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
+
+    def __len__(self):
+        return len(self.image_folders)
+
+    def __getitem__(self, idx):
+        folder_name = self.image_folders[idx]
+        folder_path = os.path.join(self.root_dir, folder_name)
+
+        # Get the list of image filenames in the folder
+        image_filenames = os.listdir(folder_path)
+
+        # Pick two random images from the folder
+        source_image_name, target_image_name = random.sample(image_filenames, 2)
+
+        # Filename stems encode the age (e.g. "20.jpg"), normalized to [0, 1]
+        source_age = int(Path(source_image_name).stem) / 100
+        target_age = int(Path(target_image_name).stem) / 100
+
+        # Build the full paths for the selected pair
+        source_image_path = os.path.join(folder_path, source_image_name)
+        target_image_path = os.path.join(folder_path, target_image_name)
+
+        source_image = Image.open(source_image_path).convert('RGB')
+        target_image = Image.open(target_image_path).convert('RGB')
+
+        # Apply the same random crop and augmentations to both images
+        if self.transform:
+            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
+            torch.manual_seed(seed)
+            source_image = self.transform(source_image)
+            torch.manual_seed(seed)
+            target_image = self.transform(target_image)
+
+        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
+        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
+
+        # Concatenate the age channels with the source image
+        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
+
+        return source_image, target_image
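Usage note: the dataset assumes one folder per subject, with each image named after the age it depicts (the stem is parsed as an integer and divided by 100), and at least two images per folder so random.sample can draw a pair. Re-seeding the global RNG before each transform call makes both images receive the identical random crop, rotation, and jitter. A minimal sketch with a hypothetical data/ directory:

    # Sketch; "data/" and its per-subject folders are hypothetical. Expected layout:
    #   data/<subject>/<age>.jpg, e.g. data/person_001/20.jpg, data/person_001/80.jpg
    from torch.utils.data import DataLoader
    from utils.dataloader import CustomDataset, transform

    dataset = CustomDataset("data", transform=transform)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    source, target = next(iter(loader))
    print(source.shape)  # torch.Size([4, 5, 512, 512]): RGB + two age channels
    print(target.shape)  # torch.Size([4, 3, 512, 512])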