timroelofs123 committed
Commit c690b8f · 1 Parent(s): 1931503

add necessary files

app.py CHANGED
@@ -3,12 +3,8 @@ import torch
 import argparse
 import git
 
-git.Repo.clone_from("https://github.com/timroelofs123/face_reaging.git", "./face_reaging")
 git.Repo.clone_from("https://huggingface.co/timroelofs123/face_re-aging", "./hf")
 
-import sys
-sys.path.append("./face_reaging")
-
 from model.models import UNet
 from scripts.test_functions import process_image
assets/mask1024.jpg ADDED
assets/mask512.jpg ADDED
model/__init__.py ADDED
File without changes
model/losses.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import lpips  # LPIPS library for perceptual loss
+
+
+class GeneratorLoss(nn.Module):
+    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
+                 device="cpu"):
+        super(GeneratorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.l1_weight = l1_weight
+        self.perceptual_weight = perceptual_weight
+        self.adversarial_weight = adversarial_weight
+        self.criterion_l1 = nn.L1Loss()
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
+
+    def forward(self, output, target, source):
+        # L1 loss
+        l1_loss = self.criterion_l1(output, target)
+
+        # Perceptual loss
+        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
+
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
+        fake_prediction = self.discriminator_model(fake_input)
+        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
+
+        # Combine losses
+        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
+                         self.adversarial_weight * adversarial_loss
+
+        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
+
+
+class DiscriminatorLoss(nn.Module):
+    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
+        super(DiscriminatorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.fake_weight = fake_weight
+        self.real_weight = real_weight
+        self.mock_weight = mock_weight
+
+    def forward(self, output, target, source):
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
+        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
+
+        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
+        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
+        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
+        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age (duplicates mock_input2)
+
+        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
+        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
+                                                          self.discriminator_model(mock_input2),
+                                                          self.discriminator_model(mock_input3),
+                                                          self.discriminator_model(mock_input4))
+
+        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
+                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4)))
+
+        return discriminator_loss
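Both losses wrap a shared discriminator and expect source to carry five channels: RGB plus a source-age and a target-age plane (channels 3 and 4). Since the discriminator only ever sees an image concatenated with one age plane, it needs four input channels. A minimal training-step sketch under those assumptions; the optimizer settings, batch shapes, and random tensors are illustrative only, not part of this commit:

import torch
from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss

generator = UNet()
discriminator = PatchGANDiscriminator(input_channels=4)  # 3 RGB + 1 age plane

g_loss_fn = GeneratorLoss(discriminator)
d_loss_fn = DiscriminatorLoss(discriminator)
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)  # assumed hyperparameters
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

source = torch.randn(2, 5, 512, 512)  # stand-in batch: RGB + source/target age planes
target = torch.randn(2, 3, 512, 512)

# Generator step
output = generator(source)
g_loss, l1, perceptual, adversarial = g_loss_fn(output, target, source)
g_opt.zero_grad()
g_loss.backward()
g_opt.step()

# Discriminator step; detach the prediction so no gradients reach the generator
d_loss = d_loss_fn(output.detach(), target, source)
d_opt.zero_grad()
d_loss.backward()
d_opt.step()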
model/models.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import antialiased_cnns
+
+
+class DownLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DownLayer, self).__init__()
+        self.layer = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=1),
+            antialiased_cnns.BlurPool(in_channels, stride=2),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class UpLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UpLayer, self).__init__()
+        # Conv transpose upsampling
+        self.blur_upsample = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+            antialiased_cnns.BlurPool(out_channels, stride=1)
+        )
+
+        self.layer = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x, skip):
+        x = self.blur_upsample(x)
+        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
+        return self.layer(x)
+
+
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.init_conv = nn.Sequential(
+            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True)
+        )
+
+        self.down1 = DownLayer(64, 128)    # output: 256 x 256 x 128
+        self.down2 = DownLayer(128, 256)   # output: 128 x 128 x 256
+        self.down3 = DownLayer(256, 512)   # output: 64 x 64 x 512
+        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
+        self.up1 = UpLayer(1024, 512)      # output: 64 x 64 x 512
+        self.up2 = UpLayer(512, 256)       # output: 128 x 128 x 256
+        self.up3 = UpLayer(256, 128)       # output: 256 x 256 x 128
+        self.up4 = UpLayer(128, 64)        # output: 512 x 512 x 64
+        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
+
+    def forward(self, x):
+        x0 = self.init_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        x4 = self.down4(x3)
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        x = self.up4(x, x0)
+        x = self.final_conv(x)
+        return x
+
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, input_channels=3):
+        super(PatchGANDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+            # Output layer with 1 channel for binary classification
+        )
+
+    def forward(self, x):
+        return self.model(x)
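The UNet maps a 5-channel 512 x 512 input (RGB plus the two age planes produced by the dataloader below) to a 3-channel image of the same size, while the PatchGAN discriminator emits a grid of per-patch logits rather than a single score. A quick shape sanity check, a sketch assuming the antialiased_cnns package is installed:

import torch
from model.models import UNet, PatchGANDiscriminator

net = UNet()
out = net(torch.randn(1, 5, 512, 512))  # RGB + source-age + target-age planes
print(out.shape)  # torch.Size([1, 3, 512, 512])

disc = PatchGANDiscriminator(input_channels=4)  # image + one age plane
logits = disc(torch.randn(1, 4, 512, 512))
print(logits.shape)  # torch.Size([1, 1, 63, 63]) -- one logit per receptive patch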
utils/__init__.py ADDED
File without changes
utils/dataloader.py ADDED
@@ -0,0 +1,63 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from PIL import Image
+import os
+import random
+from pathlib import Path
+
+
+# Define the transformations
+transform = transforms.Compose([
+    transforms.RandomRotation(degrees=10),
+    transforms.RandomCrop(512),
+    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+    transforms.ToTensor(),
+])
+
+
+class CustomDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.image_folders = [folder for folder in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, folder))]
+
+    def __len__(self):
+        return len(self.image_folders)
+
+    def __getitem__(self, idx):
+        folder_name = self.image_folders[idx]
+        folder_path = os.path.join(self.root_dir, folder_name)
+
+        # # Get the list of image filenames in the folder
+        # image_filenames = [f"{i}.jpg" for i in range(0, 101, 10)]
+        image_filenames = os.listdir(folder_path)
+
+        # Pick two random images from the folder
+        source_image_name, target_image_name = random.sample(image_filenames, 2)
+        # source_image_name, target_image_name = '20.jpg', '80.jpg'
+
+        source_age = int(Path(source_image_name).stem) / 100
+        target_age = int(Path(target_image_name).stem) / 100
+
+        # Build the full paths of the two selected images
+        source_image_path = os.path.join(folder_path, source_image_name)
+        target_image_path = os.path.join(folder_path, target_image_name)
+
+        source_image = Image.open(source_image_path).convert('RGB')
+        target_image = Image.open(target_image_path).convert('RGB')
+
+        # Apply the same random crop and augmentations to both images
+        if self.transform:
+            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
+            torch.manual_seed(seed)
+            source_image = self.transform(source_image)
+            torch.manual_seed(seed)
+            target_image = self.transform(target_image)
+
+        source_age_channel = torch.full_like(source_image[:1, :, :], source_age)
+        target_age_channel = torch.full_like(source_image[:1, :, :], target_age)
+
+        # Concatenate the age channels with the source image
+        source_image = torch.cat([source_image, source_age_channel, target_age_channel], dim=0)
+
+        return source_image, target_image
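CustomDataset assumes one sub-folder per identity whose files are named after the age they depict (e.g. 20.jpg, 80.jpg), since the age values are parsed from the filename stems and scaled to [0, 1]. A minimal loading sketch; the data/train path is a hypothetical example, not a directory from this commit:

from torch.utils.data import DataLoader
from utils.dataloader import CustomDataset, transform

dataset = CustomDataset("data/train", transform=transform)  # hypothetical root directory
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

source, target = next(iter(loader))
print(source.shape)  # torch.Size([4, 5, 512, 512]): RGB + source-age + target-age planes
print(target.shape)  # torch.Size([4, 3, 512, 512])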