import torch
import tops
import numpy as np
from kornia.color import rgb_to_hsv
from dp2 import utils
from kornia.enhance import histogram
from .anonymizer import Anonymizer
import torchvision.transforms.functional as F
from skimage.exposure import match_histograms
from kornia.filters import gaussian_blur2d


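class LatentHistogramMatchAnonymizer(Anonymizer):
    """Optimizes the latent code w so that the saturation and value histograms
    of the generated image match those of the input image."""
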
    def forward_G(
        self,
        G,
        batch,
        multi_modal_truncation: bool,
        amp: bool,
        z_idx: int,
        truncation_value: float,
        idx: int,
        n_sampling_steps: int = 1,
        all_styles=None,
    ):
        # Normalize pixel values from [0, 255] to [-1, 1].
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["condition"] = batch["mask"].float() * batch["img"]

        assert z_idx is None and all_styles is None, "z_idx and all_styles are not supported by this anonymizer."
        # kornia returns hue in radians in [0, 2*pi]; rescale it to [0, 1].
        real_hsv = rgb_to_hsv(utils.denormalize_img(batch["img"]))
        real_hsv[:, 0] /= 2 * torch.pi
        # Match only the saturation (1) and value (2) channels.
        indices = [1, 2]
        hist_kwargs = dict(
            bins=torch.linspace(0, 1, 256, dtype=torch.float32, device=tops.get_device()),
            bandwidth=torch.tensor(1., device=tops.get_device()))
        real_hist = [histogram(real_hsv[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
        for j in range(n_sampling_steps):
            if j == 0:
                if multi_modal_truncation:
                    w = G.style_net.multi_modal_truncate(
                        truncation_value=truncation_value, **batch, w_indices=None).detach()
                else:
                    w = G.style_net.get_truncated(truncation_value, **batch).detach()
                # Only the latent code w is optimized; the generator weights are not updated.
                w.requires_grad = True
                optim = torch.optim.Adam([w])
            with torch.set_grad_enabled(True):
                with torch.cuda.amp.autocast(amp):
                    anonymized_im = G(**batch, truncation_value=None, w=w)["img"]
                # Map the generated image from [-1, 1] to [0, 1] before converting to HSV.
                fake_hsv = rgb_to_hsv(anonymized_im*0.5 + 0.5)
                fake_hsv[:, 0] /= 2 * torch.pi
                fake_hist = [histogram(fake_hsv[:, i].flatten(start_dim=1), **hist_kwargs) for i in indices]
                dist = sum([utils.torch_wasserstein_loss(r, f) for r, f in zip(real_hist, fake_hist)])
                dist.backward()
                # Stop early when the gradient sums to zero, since no useful update is left.
                if w.grad.sum() == 0:
                    break
                optim.step()
                optim.zero_grad()
                if dist < 0.02:
                    break
        # Map the output from [-1, 1] back to [0, 255].
        anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255)
        return anonymized_im


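class HistogramMatchAnonymizer(Anonymizer):
    """Histogram-matches the generated image to the input image, then blends the
    matched result with the raw output using a blurred Mask R-CNN mask."""
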
    def forward_G(self, batch, *args, **kwargs):
        rimg = batch["img"]
        # Normalize pixel values from [0, 255] to [-1, 1].
        batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
        batch["condition"] = batch["mask"].float() * batch["img"]

        anonymized_im = super().forward_G(batch, *args, **kwargs)

        # Match the histogram of the generated image to that of the original input.
        equalized_gim = match_histograms(
            tops.im2numpy(anonymized_im.round().clamp(0, 255).byte()),
            tops.im2numpy(rimg))
        if equalized_gim.dtype != np.uint8:
            equalized_gim = equalized_gim.astype(np.float32)
            assert equalized_gim.dtype == np.float32, equalized_gim.dtype
            equalized_gim = tops.im2torch(equalized_gim, to_float=False)[0]
        else:
            equalized_gim = tops.im2torch(equalized_gim, to_float=False).float()[0]
        equalized_gim = equalized_gim.to(device=rimg.device)
        assert equalized_gim.dtype == torch.float32
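        # Blend weight: 1 where the Mask R-CNN mask is off, 0 where it is on.
        gaussian_mask = 1 - (batch["maskrcnn_mask"][0].repeat(3, 1, 1) > 0.5).float()
        # Blur the weight map so the two images blend smoothly at the mask border.
        gaussian_mask = gaussian_blur2d(gaussian_mask[None], kernel_size=[19, 19], sigma=[10, 10])[0]
        gaussian_mask = gaussian_mask / gaussian_mask.max()
        # The histogram-matched image is used where the weight is high, the raw output elsewhere.
        anonymized_im = gaussian_mask * equalized_gim + (1-gaussian_mask) * anonymized_im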
        return anonymized_im