File size: 5,704 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# python3.7
"""Defines loss functions for encoder training."""

import torch
import torch.nn.functional as F

from models import build_perceptual

__all__ = ['EncoderLoss']


class EncoderLoss(object):
    """Computes losses for GAN-inversion encoder training.

    Provides `d_loss` (logistic discriminator loss with optional R1/R2
    gradient penalties) and `e_loss` (pixel + perceptual reconstruction
    loss plus an adversarial term) for training an encoder that maps
    images into a generator's latent space (`z`, `wp`, or `y`).
    """

    def __init__(self,
                 runner,
                 d_loss_kwargs=None,
                 e_loss_kwargs=None,
                 perceptual_kwargs=None):
        """Initializes with models and arguments for computing losses.

        Args:
            runner: Training runner holding models, config, and running stats.
            d_loss_kwargs: Optional dict; supports `r1_gamma` (default 10.0)
                and `r2_gamma` (default 0.0) gradient-penalty weights.
            e_loss_kwargs: Optional dict; supports `perceptual_lw`
                (default 5e-5) and `adv_lw` (default 0.1) loss weights.
            perceptual_kwargs: Optional kwargs forwarded to
                `build_perceptual`.
        """
        self.d_loss_kwargs = d_loss_kwargs or dict()
        self.e_loss_kwargs = e_loss_kwargs or dict()
        # R1 penalizes discriminator gradients on real images, R2 on the
        # reconstructed (fake) images.
        self.r1_gamma = self.d_loss_kwargs.get('r1_gamma', 10.0)
        self.r2_gamma = self.d_loss_kwargs.get('r2_gamma', 0.0)

        self.perceptual_lw = self.e_loss_kwargs.get('perceptual_lw', 5e-5)
        self.adv_lw = self.e_loss_kwargs.get('adv_lw', 0.1)

        # Guard against `perceptual_kwargs=None`, consistent with the other
        # two kwargs above (previously this raised TypeError on `**None`).
        self.perceptual_model = build_perceptual(
            **(perceptual_kwargs or dict())).cuda()
        # The perceptual network is a frozen feature extractor; no gradients
        # should ever flow into it.
        self.perceptual_model.eval()
        for param in self.perceptual_model.parameters():
            param.requires_grad = False

        runner.space_of_latent = runner.config.space_of_latent

        runner.running_stats.add(
            'recon_loss', log_format='.3f', log_strategy='AVERAGE')
        runner.running_stats.add(
            'adv_loss', log_format='.3f', log_strategy='AVERAGE')
        runner.running_stats.add(
            'loss_fake', log_format='.3f', log_strategy='AVERAGE')
        runner.running_stats.add(
            'loss_real', log_format='.3f', log_strategy='AVERAGE')
        if self.r1_gamma != 0:
            runner.running_stats.add(
                'real_grad_penalty', log_format='.3f', log_strategy='AVERAGE')
        if self.r2_gamma != 0:
            runner.running_stats.add(
                'fake_grad_penalty', log_format='.3f', log_strategy='AVERAGE')

    @staticmethod
    def compute_grad_penalty(images, scores):
        """Computes gradient penalty.

        Returns the mean (over the batch) of the squared L2 norm of
        d(scores.sum())/d(images), keeping the graph so the penalty itself
        is differentiable.
        """
        image_grad = torch.autograd.grad(
            outputs=scores.sum(),
            inputs=images,
            create_graph=True,
            retain_graph=True)[0].view(images.shape[0], -1)
        penalty = image_grad.pow(2).sum(dim=1).mean()
        return penalty

    @staticmethod
    def _get_generator(runner):
        """Returns the (smoothed, if available) generator module in eval mode."""
        if 'generator_smooth' in runner.models:
            G = runner.get_module(runner.models['generator_smooth'])
        else:
            G = runner.get_module(runner.models['generator'])
        G.eval()
        return G

    @staticmethod
    def _reconstruct(runner, G, latents):
        """Decodes `latents` back to an image according to the latent space."""
        space = runner.space_of_latent
        if space == 'z':
            return G(latents, **runner.G_kwargs_val)['image']
        if space == 'wp':
            return G.synthesis(latents, **runner.G_kwargs_val)['image']
        if space == 'y':
            G.set_space_of_latent('y')
            return G.synthesis(latents, **runner.G_kwargs_val)['image']
        # Previously an unknown space left the result unbound and raised a
        # confusing NameError; fail fast with an explicit message instead.
        raise ValueError(f'Unsupported latent space: `{space}`!')

    def d_loss(self, runner, data):
        """Computes logistic loss (plus gradient penalties) for discriminator."""
        G = self._get_generator(runner)
        D = runner.models['discriminator']
        E = runner.models['encoder']

        reals = data['image']
        reals.requires_grad = True  # Needed for the R1 gradient penalty.

        # Reconstructions act as fixed fake samples here: no gradients flow
        # back into the encoder or generator.
        with torch.no_grad():
            latents = E(reals)
            reals_rec = self._reconstruct(runner, G, latents)
        real_scores = D(reals, **runner.D_kwargs_train)
        fake_scores = D(reals_rec, **runner.D_kwargs_train)
        loss_fake = F.softplus(fake_scores).mean()
        loss_real = F.softplus(-real_scores).mean()
        d_loss = loss_fake + loss_real

        runner.running_stats.update({'loss_fake': loss_fake.item()})
        runner.running_stats.update({'loss_real': loss_real.item()})

        real_grad_penalty = torch.zeros_like(d_loss)
        fake_grad_penalty = torch.zeros_like(d_loss)
        if self.r1_gamma:
            real_grad_penalty = self.compute_grad_penalty(reals, real_scores)
            runner.running_stats.update(
                {'real_grad_penalty': real_grad_penalty.item()})
        if self.r2_gamma:
            # NOTE(review): `reals_rec` is produced under `torch.no_grad()`,
            # so it does not require grad and this penalty would fail if
            # `r2_gamma` were nonzero — confirm before enabling R2.
            fake_grad_penalty = self.compute_grad_penalty(
                reals_rec, fake_scores)
            runner.running_stats.update(
                {'fake_grad_penalty': fake_grad_penalty.item()})

        return (d_loss +
                real_grad_penalty * (self.r1_gamma * 0.5) +
                fake_grad_penalty * (self.r2_gamma * 0.5))

    def e_loss(self, runner, data):
        """Computes reconstruction plus adversarial loss for the encoder."""
        G = self._get_generator(runner)
        D = runner.models['discriminator']
        E = runner.models['encoder']
        P = self.perceptual_model

        # Fetch data.
        reals = data['image']

        latents = E(reals)
        reals_rec = self._reconstruct(runner, G, latents)
        # Pixel-wise MSE plus weighted perceptual-feature MSE.
        loss_pix = F.mse_loss(reals_rec, reals, reduction='mean')
        loss_feat = self.perceptual_lw * F.mse_loss(
            P(reals_rec), P(reals), reduction='mean')
        loss_rec = loss_pix + loss_feat
        # Non-saturating adversarial term pushing reconstructions to look real.
        fake_scores = D(reals_rec, **runner.D_kwargs_train)
        adv_loss = self.adv_lw * F.softplus(-fake_scores).mean()
        e_loss = loss_rec + adv_loss

        runner.running_stats.update({'recon_loss': loss_rec.item()})
        runner.running_stats.update({'adv_loss': adv_loss.item()})

        return e_loss