File size: 3,817 Bytes
c690b8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
import lpips  # LPIPS library for perceptual loss

class GeneratorLoss(nn.Module):
    """Combined generator objective: weighted L1 + LPIPS perceptual + adversarial terms.

    Args:
        discriminator_model: Callable that scores an image concatenated (along
            dim 1) with a conditioning channel and returns logits. It is assigned
            as an attribute, so an ``nn.Module`` discriminator becomes a
            registered submodule of this loss.
        l1_weight: Multiplier applied to the pixel-wise L1 term.
        perceptual_weight: Multiplier applied to the LPIPS (VGG backbone) term.
        adversarial_weight: Multiplier applied to the adversarial BCE-with-logits term.
        device: Device the LPIPS network is moved to at construction time.
    """

    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
                 device="cpu"):
        super().__init__()
        self.discriminator_model = discriminator_model
        self.l1_weight, self.perceptual_weight, self.adversarial_weight = (
            l1_weight, perceptual_weight, adversarial_weight)
        self.criterion_l1 = nn.L1Loss()
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        # NOTE(review): LPIPS expects inputs scaled to [-1, 1] unless constructed
        # with normalize=True -- confirm the generator output range matches.
        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)

    def forward(self, output, target, source):
        """Compute the total generator loss and its unweighted components.

        Args:
            output: Generator prediction, compared against ``target``.
            target: Ground-truth image.
            source: Conditioning tensor. Channel 4 is concatenated onto
                ``output`` for the discriminator (presumably the target age --
                confirm against the data pipeline). Assumes at least 5 channels.

        Returns:
            Tuple ``(total, l1, perceptual, adversarial)``. Only ``total`` has the
            per-term weights applied.
        """
        target_age = source[:, 4:5, :, :]

        pixel_term = self.criterion_l1(output, target)
        # Collapse whatever LPIPS returns (per-sample distances) into one scalar.
        perceptual_term = self.criterion_perceptual(output, target).mean()

        # The generator is rewarded when the discriminator calls its output real (1).
        # NOTE(review): gradients also reach the discriminator's parameters here
        # unless the caller freezes them or zeroes them before the D step.
        logits = self.discriminator_model(torch.cat([output, target_age], dim=1))
        adversarial_term = self.criterion_adversarial(logits, torch.ones_like(logits))

        weighted_l1 = self.l1_weight * pixel_term
        weighted_perceptual = self.perceptual_weight * perceptual_term
        weighted_adversarial = self.adversarial_weight * adversarial_term
        total = weighted_l1 + weighted_perceptual + weighted_adversarial

        return total, pixel_term, perceptual_term, adversarial_term

class DiscriminatorLoss(nn.Module):
    """Weighted BCE-with-logits loss for training an age-conditioned discriminator.

    Every discriminator input is an image concatenated along dim 1 with a
    single age channel sliced from ``source``. Only ``(target, target age)`` is
    labelled real (1). The generator output and the image/age mismatches
    built in ``forward`` are labelled fake (0).

    Args:
        discriminator_model: Callable mapping a concatenated ``(image, age)``
            tensor to logits. It is assigned as an attribute, so an
            ``nn.Module`` discriminator becomes a registered submodule.
        fake_weight: Weight of the ``(output, target age)`` term.
        real_weight: Weight of the ``(target, target age)`` term.
        mock_weight: Weight of each mismatched-pair term. The
            ``(target, source age)`` pair is weighted ``2 * mock_weight``
            (see the note in ``forward``).
    """

    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
        super().__init__()
        self.discriminator_model = discriminator_model
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        self.fake_weight = fake_weight
        self.real_weight = real_weight
        self.mock_weight = mock_weight

    def forward(self, output, target, source):
        """Return the scalar discriminator loss.

        Args:
            output: Generator prediction. NOTE(review): callers presumably pass
                a detached tensor for the D step -- confirm.
            target: Ground-truth image.
            source: Conditioning tensor. Per the slicing below:
                - channels ``0:3`` are the source image;
                - channel ``3`` is the source age;
                - channel ``4`` is the target age.
                Assumes at least 5 channels.

        Returns:
            Weighted sum of BCE-with-logits terms (a tensor).
        """
        source_img = source[:, :3, :, :]
        source_age = source[:, 3:4, :, :]
        target_age = source[:, 4:5, :, :]

        def score(image, age):
            # Condition the discriminator by stacking the age map onto the image channels.
            return self.discriminator_model(torch.cat([image, age], dim=1))

        def bce(logits, is_real):
            labels = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
            return self.criterion_adversarial(logits, labels)

        # Keep the call order fake, real, mock1-3 so any stateful layers in the
        # discriminator see inputs in a consistent sequence.
        fake_pred = score(output, target_age)        # prediction img with target age
        real_pred = score(target, target_age)        # target img with target age
        mock_pred1 = score(source_img, target_age)   # source img with target age
        mock_pred2 = score(target, source_age)       # target img with source age
        mock_pred3 = score(output, source_age)       # prediction img with source age

        # NOTE(review): a fourth "mock" input used to be built here. It was
        # commented "target img with target age" but actually paired the target
        # with the *source* age, making it an exact duplicate of mock_pred2; the
        # real target-age pair is already the positive term above. That cost an
        # extra discriminator forward pass and counted this term twice. The
        # double weight is kept explicitly so the loss value is unchanged for a
        # deterministic discriminator. Confirm whether a single mock_weight was
        # intended.
        discriminator_loss = (self.fake_weight * bce(fake_pred, False) +
                              self.real_weight * bce(real_pred, True) +
                              self.mock_weight * bce(mock_pred1, False) +
                              2 * self.mock_weight * bce(mock_pred2, False) +
                              self.mock_weight * bce(mock_pred3, False))

        return discriminator_loss