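"""Anonymization of detected regions with trained DeepPrivacy2 generators.

Defines the Anonymizer class, which routes each detection type (face, person,
CSE person, vehicle) to its registered generator and pastes the synthesized
crop back into the input image.
"""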
from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
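    """Load the config at cfg_path and build the trained generator it describes."""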
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
    """Resize a crop batch to the generator's input size.

    Masks are re-binarized after interpolation, and the condition image is
    recomputed from the resized image and mask (the incoming condition is
    stale after resizing).
    """
    img = F.resize(img, imsize, antialias=True)
    mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
    maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()
    condition = img * mask
    return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)

class Anonymizer:
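    """Anonymizes an image by replacing each detected region (face, person,
    CSE person, vehicle) with output from the generator registered for that
    detection type. Detection types without a loaded generator are left as-is.
    """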

    def __init__(
            self,
            detector,
            load_cache: bool,
            person_G_cfg: Optional[Union[str, Path]] = None,
            cse_person_G_cfg: Optional[Union[str, Path]] = None,
            face_G_cfg: Optional[Union[str, Path]] = None,
            car_G_cfg: Optional[Union[str, Path]] = None,
            ) -> None:
        self.detector = detector
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
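        """Initialize multi-object tracking so each tracked identity maps to a
        fixed latent index, keeping anonymized appearances consistent across
        video frames."""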
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()
        self.cur_z_idx = 0

    @torch.no_grad()
    def anonymize_detections(self,
            im, detection, truncation_value: float,
            multi_modal_truncation: bool, amp: bool, z_idx,
            all_styles=None,
            update_identity=None,
            ):
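        """Replace each detection of one type in `im` with generator output.

        z_idx gives a per-detection latent seed (stable per tracked identity);
        detections whose update_identity entry is False are skipped.
        """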
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        orig_im = im.clone()
        if update_identity is None:
            update_identity = [True] * len(detection)
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
            batch["img"] = batch["img"].float()
            batch["condition"] = batch["mask"] * batch["img"]
            orig_shape = None
            # Resize the crop if either dimension differs from the generator's
            # expected input size; the original shape is restored after synthesis.
            if G.imsize and (batch["img"].shape[-1] != G.imsize[-1] or batch["img"].shape[-2] != G.imsize[-2]):
                orig_shape = batch["img"].shape[-2:]
                batch = resize_batch(**batch, imsize=G.imsize)
            with torch.cuda.amp.autocast(amp):
                if all_styles is not None:
                    anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
                elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
                    w_indices = None
                    if z_idx is not None:
                        w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                    anonymized_im = G.multi_modal_truncate(
                        **batch, truncation_value=truncation_value,
                        w_indices=w_indices)["img"]
                else:
                    z = None
                    if z_idx is not None:
                        state = np.random.RandomState(seed=z_idx[idx])
                        z = state.normal(size=(1, G.z_channels))
                        z = tops.to_cuda(torch.from_numpy(z))
                    anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
            if orig_shape is not None:
                anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
            # Map generator output from [-1, 1] back to uint8 [0, 255].
            anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()

            # Resize the synthesized crop and its mask to the detection box size.
            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Boxes may extend past the image border; strip the padding that
            # get_crop added for the out-of-bounds region.
            pad = [max(-x0, 0), max(-y0, 0), max(x1 - W, 0), max(y1 - H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2] - pad[3], pad[0]:x.shape[-1] - pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask)
            # Clamp the box to the image, then paste synthesized pixels only
            # where the mask marks the region to anonymize (mask == 0).
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask]

        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
        """Draw all detections onto the image without anonymizing anything."""
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        for det in all_detections:
            im = det.visualize(im)
        return im

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track=True, **synthesis_kwargs) -> torch.Tensor:
        """Detect, optionally track, and anonymize all supported regions in a uint8 CHW image."""
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        if hasattr(self, "tracker") and track:
            [_.pre_process() for _ in all_detections]
            import numpy as np 
            boxes = np.concatenate([_.boxes for _ in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = self.cur_z_idx
                    self.cur_z_idx += 1
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0

        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset+len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)

        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
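
# Example usage (a minimal sketch; `my_detector` and the config path are
# hypothetical, not part of this module — any detector exposing
# forward_and_cache(im, cache_id, load_cache=...) and returning detection
# objects fits the interface used above):
#
#   anonymizer = Anonymizer(detector=my_detector, load_cache=False,
#                           face_G_cfg="configs/face_generator.py")
#   anonymizer.initialize_tracker(fps=30)
#   frame = torch.from_numpy(frame_np).permute(2, 0, 1)  # uint8 CHW tensor
#   out = anonymizer(frame, truncation_value=0.5,
#                    multi_modal_truncation=False, amp=True)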