File size: 2,016 Bytes
97a6728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import tops
import numpy as np
from sg3_torch_utils.ops import conv2d_gradfix

# Module-level scalar (0-dim) zero tensor on the default (CPU) device.
# NOTE(review): not referenced by any code visible in this file;
# PLRegularization keeps its own per-instance ``pl_mean``. Confirm whether
# anything else imports this before removing it.
pl_mean_total = torch.zeros(())


class PLRegularization:
    """Path-length regularization penalty for a generator ``G``.

    Each call generates images for a (possibly shrunk) sub-batch. It measures
    the norm of the gradient of a randomly projected image with respect to
    the latent codes ``w``. It then penalizes the squared deviation of that
    norm from a running mean kept in ``self.pl_mean``, which is updated by
    ``lerp`` with weight ``pl_decay``.
    """

    def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs):
        """
        Args:
            weight: Multiplier applied to the per-sample penalty.
            batch_shrink: Integer divisor of the batch size; only the first
                ``batch["img"].shape[0] // batch_shrink`` samples are used.
            pl_decay: Interpolation weight used by ``lerp`` to move the
                running mean toward the current batch's mean path length.
            scale_by_mask: If True, divide the gradients by the per-sample
                mean of ``batch["mask"]``.
            **kwargs: Accepted and ignored.
        """
        # Running mean of the path lengths: a scalar tensor on the device
        # returned by tops.get_device().
        self.pl_mean = torch.zeros([], device=tops.get_device())
        self.pl_weight = weight
        self.batch_shrink = batch_shrink
        self.pl_decay = pl_decay
        self.scale_by_mask = scale_by_mask

    @staticmethod
    def _subsample_batch(batch: dict, batch_shrink: int) -> dict:
        """Return a new dict with the first ``N // batch_shrink`` samples of each entry.

        ``N`` is ``batch["img"].shape[0]``. The ``"embed_map"`` entry, if
        present, is not sliced and is forwarded unchanged (presumably it is
        not indexed by sample -- TODO confirm against the generator).
        """
        batch_size = batch["img"].shape[0] // batch_shrink
        sub_batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
        # Look up "embed_map" in the *original* batch: the filtered dict above
        # never contains it, so checking that one would always drop it.
        if "embed_map" in batch:
            sub_batch["embed_map"] = batch["embed_map"]
        return sub_batch

    def _update_pl_mean(self, pl_lengths: torch.Tensor) -> torch.Tensor:
        """Interpolate the running mean toward ``pl_lengths.mean()`` and return it.

        The stored ``self.pl_mean`` is only overwritten (with a detached copy)
        when the new value is finite. A NaN or inf batch therefore cannot
        permanently corrupt the running statistic. The returned tensor is not
        detached, matching how it is used in the penalty.
        """
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        if torch.isfinite(pl_mean).all():
            self.pl_mean.copy_(pl_mean.detach())
        return pl_mean

    def __call__(self, G, batch, grad_scaler):
        """Compute the weighted path-length penalty for a sub-batch.

        Args:
            G: Generator providing ``get_z``, ``style_net`` and a call that
                takes the batch entries plus ``w`` and returns a dict with key
                ``"img"``.
            batch: Dict of batched entries. It must contain ``"img"`` and,
                when ``scale_by_mask`` is set, ``"mask"``. ``"embed_map"`` is
                passed to ``G`` unsliced.
            grad_scaler: AMP gradient scaler. Its ``scale``/``get_scale`` are
                used to scale and then unscale the gradient computation.

        Returns:
            Tuple ``(penalty, to_log)``. ``penalty`` is the flattened tensor
            of per-sample penalties multiplied by ``weight``. ``to_log`` maps
            ``"pl_penalty"`` to the detached mean of the unweighted penalties.
        """
        batch = self._subsample_batch(batch, self.batch_shrink)
        z = G.get_z(batch["img"])

        with torch.cuda.amp.autocast(tops.AMP()):
            gen_ws = G.style_net(z)
            gen_img = G(**batch, w=gen_ws)["img"].float()
        # Random projection, normalized by sqrt(dim2 * dim3) of the image
        # (assumes NCHW layout, i.e. sqrt(H * W) -- TODO confirm).
        pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
        with conv2d_gradfix.no_weight_gradients():
            # Sums over HWC: with all-ones grad_outputs this is the gradient of
            # sum(gen_img * pl_noise) w.r.t. gen_ws. create_graph=True keeps
            # it differentiable so the penalty can be backpropagated.
            pl_grads = torch.autograd.grad(
                outputs=[grad_scaler.scale(gen_img * pl_noise)],
                inputs=[gen_ws],
                create_graph=True,
                grad_outputs=torch.ones_like(gen_img),
                only_inputs=True)[0]

        # Undo the loss scaling applied to the outputs above.
        pl_grads = pl_grads.float() / grad_scaler.get_scale()
        if self.scale_by_mask:
            # Per-sample mean of the mask (fraction of known pixels).
            scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
            pl_grads = pl_grads / scaling
        # Per-sample L2 norm over dim 1 (assumes gen_ws is (N, w_dim) -- TODO confirm).
        pl_lengths = pl_grads.square().sum(1).sqrt()
        pl_mean = self._update_pl_mean(pl_lengths)
        pl_penalty = (pl_lengths - pl_mean).square()
        to_log = dict(pl_penalty=pl_penalty.mean().detach())
        return pl_penalty.view(-1) * self.pl_weight, to_log