This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- dp2/__init__.py +0 -0
- dp2/anonymizer/__init__.py +1 -0
- dp2/anonymizer/anonymizer.py +159 -0
- dp2/data/__init__.py +0 -0
- dp2/data/build.py +148 -0
- dp2/data/datasets/__init__.py +0 -0
- dp2/data/datasets/coco_cse.py +148 -0
- dp2/data/datasets/fdf.py +129 -0
- dp2/data/datasets/fdh.py +104 -0
- dp2/data/transforms/__init__.py +2 -0
- dp2/data/transforms/functional.py +61 -0
- dp2/data/transforms/stylegan2_transform.py +394 -0
- dp2/data/transforms/transforms.py +247 -0
- dp2/data/utils.py +102 -0
- dp2/detection/__init__.py +3 -0
- dp2/detection/base.py +45 -0
- dp2/detection/box_utils.py +104 -0
- dp2/detection/box_utils_fdf.py +203 -0
- dp2/detection/cse_mask_face_detector.py +116 -0
- dp2/detection/face_detector.py +62 -0
- dp2/detection/models/__init__.py +0 -0
- dp2/detection/models/cse.py +135 -0
- dp2/detection/models/keypoint_maskrcnn.py +111 -0
- dp2/detection/models/mask_rcnn.py +78 -0
- dp2/detection/person_detector.py +135 -0
- dp2/detection/structures.py +463 -0
- dp2/detection/utils.py +174 -0
- dp2/discriminator/__init__.py +1 -0
- dp2/discriminator/sg2_discriminator.py +76 -0
- dp2/gan_trainer.py +324 -0
- dp2/generator/__init__.py +0 -0
- dp2/generator/base.py +144 -0
- dp2/generator/dummy_generators.py +47 -0
- dp2/generator/imagen3_old.py +1210 -0
- dp2/generator/stylegan_unet.py +208 -0
- dp2/generator/utils.py +48 -0
- dp2/infer.py +72 -0
- dp2/layers/__init__.py +20 -0
- dp2/layers/sg2_layers.py +227 -0
- dp2/loss/__init__.py +1 -0
- dp2/loss/pl_regularization.py +48 -0
- dp2/loss/r1_regularization.py +31 -0
- dp2/loss/sg2_loss.py +94 -0
- dp2/loss/utils.py +25 -0
- dp2/metrics/__init__.py +3 -0
- dp2/metrics/fid.py +72 -0
- dp2/metrics/fid_clip.py +84 -0
- dp2/metrics/lpips.py +76 -0
- dp2/metrics/ppl.py +110 -0
- dp2/metrics/torch_metrics.py +176 -0
dp2/__init__.py
ADDED
File without changes
dp2/anonymizer/__init__.py
ADDED
@@ -0,0 +1 @@
from .anonymizer import Anonymizer
dp2/anonymizer/anonymizer.py
ADDED
@@ -0,0 +1,159 @@
from pathlib import Path
from typing import Union, Optional
import numpy as np
import torch
import tops
import torchvision.transforms.functional as F
from motpy import Detection, MultiObjectTracker
from dp2.utils import load_config
from dp2.infer import build_trained_generator
from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection


def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
    cfg = load_config(cfg_path)
    G = build_trained_generator(cfg)
    tops.logger.log(f"Loaded generator from: {cfg_path}")
    return G


def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
    img = F.resize(img, imsize, antialias=True)
    mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
    maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()

    condition = img * mask
    return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)


class Anonymizer:

    def __init__(
            self,
            detector,
            load_cache: bool,
            person_G_cfg: Optional[Union[str, Path]] = None,
            cse_person_G_cfg: Optional[Union[str, Path]] = None,
            face_G_cfg: Optional[Union[str, Path]] = None,
            car_G_cfg: Optional[Union[str, Path]] = None,
            ) -> None:
        self.detector = detector
        self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
        self.tracker = MultiObjectTracker(dt=1/fps)
        self.track_to_z_idx = dict()
        self.cur_z_idx = 0

    @torch.no_grad()
    def anonymize_detections(self,
            im, detection, truncation_value: float,
            multi_modal_truncation: bool, amp: bool, z_idx,
            all_styles=None,
            update_identity=None,
            ):
        G = self.generators[type(detection)]
        if G is None:
            return im
        C, H, W = im.shape
        orig_im = im.clone()
        if update_identity is None:
            update_identity = [True for i in range(len(detection))]
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
            batch["img"] = batch["img"].float()
            batch["condition"] = batch["mask"] * batch["img"]
            orig_shape = None
            if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]:
                orig_shape = batch["img"].shape[-2:]
                batch = resize_batch(**batch, imsize=G.imsize)
            with torch.cuda.amp.autocast(amp):
                if all_styles is not None:
                    anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
                elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
                    w_indices = None
                    if z_idx is not None:
                        w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                    anonymized_im = G.multi_modal_truncate(
                        **batch, truncation_value=truncation_value,
                        w_indices=w_indices)["img"]
                else:
                    z = None
                    if z_idx is not None:
                        state = np.random.RandomState(seed=z_idx[idx])
                        z = state.normal(size=(1, G.z_channels))
                        z = tops.to_cuda(torch.from_numpy(z))
                    anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
            if orig_shape is not None:
                anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
            anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()

            # Resize and denormalize image
            gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
            mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # Remove padding
            pad = [max(-x0, 0), max(-y0, 0)]
            pad = [*pad, max(x1-W, 0), max(y1-H, 0)]
            remove_pad = lambda x: x[..., pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
            gim = remove_pad(gim)
            mask = remove_pad(mask)
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask]

        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        for det in all_detections:
            im = det.visualize(im)
        return im

    @torch.no_grad()
    def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor:
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        if hasattr(self, "tracker") and track:
            [_.pre_process() for _ in all_detections]
            boxes = np.concatenate([_.boxes for _ in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = self.cur_z_idx
                    self.cur_z_idx += 1
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
            idx_offset = 0

        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                zs = z_idx[idx_offset:idx_offset+len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)

        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
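Usage sketch for the Anonymizer above. This is illustrative only and not part of the diff: it assumes an already-constructed anonymizer (detector and generator configs are set up elsewhere in dp2) and a local input image path.

# Illustrative sketch, not part of the diff. Assumes `anonymizer` was built with a
# detector and at least one generator config elsewhere in dp2.
import torchvision

im = torchvision.io.read_image("input.jpg")  # uint8 CHW tensor, as forward() asserts
anonymizer.initialize_tracker(fps=30)        # optional: keeps z indices stable across video frames
out = anonymizer(im, truncation_value=0.5, multi_modal_truncation=False, amp=True)
torchvision.io.write_jpeg(out, "anonymized.jpg")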
dp2/data/__init__.py
ADDED
File without changes
dp2/data/build.py
ADDED
@@ -0,0 +1,148 @@
import io
import torch
import tops
from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder

def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
        ):
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle
        )
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
        dl_kwargs["drop_last"] = drop_last
    else:
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last
    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers, prefetch_factor=prefetch_factor,
        **dl_kwargs
    )
    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
    return dataloader


def get_dataloader_places2_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
        ):
    import webdataset as wds
    import os
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())

    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode("torchrgb8"),
        wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader


def get_dataloader_celebAHQ_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
        ):
    import webdataset as wds
    import os
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())

    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(wds.handle_extension(".png", png_decoder)),
        wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
    return loader
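A hedged usage sketch of the CelebA-HQ webdataset loader above; the shard pattern and the identity GPU transform are assumptions for illustration, not values from this diff.

# Illustrative sketch; the shard pattern and gpu_transform below are assumptions.
import torch

loader = get_dataloader_celebAHQ_wds(
    path="data/celebAHQ_wds/train-{000000..000099}.tar",  # hypothetical shard pattern
    batch_size=32, num_workers=8,
    transform=None,                     # optional CPU-side transform
    gpu_transform=torch.nn.Identity(),  # applied on the GPU by tops.DataPrefetcher
    infinite=True, shuffle=True, partial_batches=False,
)
batch = next(iter(loader))              # dict with an "img" tensor batched by collate_fn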
dp2/data/datasets/__init__.py
ADDED
File without changes
dp2/data/datasets/coco_cse.py
ADDED
@@ -0,0 +1,148 @@
import pickle
import torchvision
import torch
import pathlib
import numpy as np
from typing import Callable, Optional, Union
from torch.hub import get_dir as get_hub_dir


def cache_embed_stats(embed_map: torch.Tensor):
    mean = embed_map.mean(dim=0, keepdim=True)
    rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()

    cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
    path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
    path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(cache, path)


class CocoCSE(torch.utils.data.Dataset):

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 normalize_E: bool,):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath

        self.transform = transform
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        self.image_paths, self.embedding_paths = self._load_impaths()
        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
        mean = self.embed_map.mean(dim=0, keepdim=True)
        rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
        self.embed_map = (self.embed_map - mean) * rstd
        cache_embed_stats(self.embed_map)

    def _load_impaths(self):
        image_dir = self.dirpath.joinpath("images")
        image_paths = list(image_dir.glob("*.png"))
        image_paths.sort()
        embedding_paths = [
            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
        ]
        return image_paths, embedding_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        im = torchvision.io.read_image(str(self.image_paths[idx]))
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()
        E_mask = 1 - mask - border
        batch = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "embed_map": self.embed_map,
            "border": border[None],
            "E_mask": E_mask[None]
        }
        if self.transform is None:
            return batch
        return self.transform(batch)


class CocoCSEWithFace(CocoCSE):

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 **kwargs):
        super().__init__(dirpath, transform, **kwargs)
        with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
            self.face_boxes = pickle.load(fp)

    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
        return item


class CocoCSESemantic(torch.utils.data.Dataset):

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 **kwargs):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath

        self.transform = transform
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        self.image_paths, self.embedding_paths = self._load_impaths()
        self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy")))
        self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))

    def _load_impaths(self):
        image_dir = self.dirpath.joinpath("images")
        image_paths = list(image_dir.glob("*.png"))
        image_paths.sort()
        embedding_paths = [
            self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
        ]
        return image_paths, embedding_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        im = torchvision.io.read_image(str(self.image_paths[idx]))
        vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
        vertices = torch.from_numpy(vertices.squeeze()).long()
        mask = torch.from_numpy(mask.squeeze()).float()
        border = torch.from_numpy(border.squeeze()).float()
        E_mask = 1 - mask - border
        batch = {
            "img": im,
            "vertices": vertices[None],
            "mask": mask[None],
            "border": border[None],
            "vertx2cat": self.vertx2cat,
            "embed_map": self.embed_map,
        }
        if self.transform is None:
            return batch
        return self.transform(batch)


class CocoCSESemanticWithFace(CocoCSESemantic):

    def __init__(self,
                 dirpath: Union[str, pathlib.Path],
                 transform: Optional[Callable],
                 **kwargs):
        super().__init__(dirpath, transform, **kwargs)
        with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
            self.face_boxes = pickle.load(fp)

    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
        return item
dp2/data/datasets/fdf.py
ADDED
@@ -0,0 +1,129 @@
import pathlib
from typing import Tuple
import numpy as np
import torch
try:
    import pyspng
    PYSPNG_IMPORTED = True
except ImportError:
    PYSPNG_IMPORTED = False
    print("Could not load pyspng. Defaulting to pillow image backend.")
    from PIL import Image
from tops import logger


class FDFDataset:

    def __init__(self,
                 dirpath,
                 imsize: Tuple[int],
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.imsize = imsize[0]
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images", str(self.imsize))
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)

        self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")

    def get_mask(self, idx):
        mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)


class FDF256Dataset:

    def __init__(self,
                 dirpath,
                 load_keypoints: bool,
                 transform):
        dirpath = pathlib.Path(dirpath)
        self.dirpath = dirpath
        self.transform = transform
        self.load_keypoints = load_keypoints
        assert self.dirpath.is_dir(),\
            f"Did not find dataset at: {dirpath}"
        image_dir = self.dirpath.joinpath("images")
        self.image_paths = list(image_dir.glob("*.png"))
        assert len(self.image_paths) > 0,\
            f"Did not find images in: {image_dir}"
        self.image_paths.sort(key=lambda x: int(x.stem))
        self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
        self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
        assert len(self.image_paths) == len(self.bounding_boxes)
        assert len(self.image_paths) == len(self.landmarks)
        logger.log(
            f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")

    def get_mask(self, idx):
        mask = torch.ones((1, 256, 256), dtype=torch.bool)
        bounding_box = self.bounding_boxes[idx]
        x0, y0, x1, y1 = bounding_box
        mask[:, y0:y1, x0:x1] = 0
        return mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        impath = self.image_paths[index]
        if PYSPNG_IMPORTED:
            with open(impath, "rb") as fp:
                im = pyspng.load(fp.read())
        else:
            with Image.open(impath) as fp:
                im = np.array(fp)
        im = torch.from_numpy(np.rollaxis(im, -1, 0))
        masks = self.get_mask(index)
        landmark = self.landmarks[index]
        batch = {
            "img": im,
            "mask": masks,
        }
        if self.load_keypoints:
            batch["keypoints"] = landmark
        if self.transform is None:
            return batch
        return self.transform(batch)
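A brief usage sketch for FDF256Dataset above; the dataset directory is an assumed local path, not part of the diff.

# Illustrative sketch; "data/fdf256/train" is an assumed dataset location.
dataset = FDF256Dataset("data/fdf256/train", load_keypoints=True, transform=None)
sample = dataset[0]
print(sample["img"].shape, sample["mask"].shape, sample["keypoints"].shape)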
dp2/data/datasets/fdh.py
ADDED
@@ -0,0 +1,104 @@
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn


def kp_decoder(x):
    # Keypoints are between [0, 1] for webdataset
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    keypoints[:, 0] /= 160
    keypoints[:, 1] /= 288
    check_outside = lambda x: (x < 0).logical_or(x > 1)
    is_outside = check_outside(keypoints[:, 0]).logical_or(
        check_outside(keypoints[:, 1])
    )
    keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints


def vertices_decoder(x):
    vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
    return vertices.squeeze()[None]


def get_dataloader_fdh_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        load_embedding: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        read_condition=False,
        channels_last=False,
        ):
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]

    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])

    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
        wds.rename_keys(*rename_keys),
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader
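A hedged usage sketch of get_dataloader_fdh_wds above; the shard pattern and the identity GPU transform are placeholders for illustration, not values from this diff.

# Illustrative sketch; the shard pattern and gpu_transform are assumptions.
import torch

loader = get_dataloader_fdh_wds(
    path="data/fdh_wds/train-{000000..000499}.tar",  # hypothetical shard pattern
    batch_size=16, num_workers=8,
    transform=None, gpu_transform=torch.nn.Identity(),
    infinite=True, shuffle=True, partial_batches=False,
    load_embedding=True,      # also decode vertices.npy and E_mask.png
)
batch = next(iter(loader))    # keys: img, mask, maskrcnn_mask, keypoints (+ vertices, E_mask)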
dp2/data/transforms/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
from .stylegan2_transform import StyleGANAugmentPipe
dp2/data/transforms/functional.py
ADDED
@@ -0,0 +1,61 @@
import torchvision.transforms.functional as F
import torch
import pickle
from tops import download_file, assert_shape
from typing import Dict
from functools import lru_cache

global symmetry_transform

@lru_cache(maxsize=1)
def get_symmetry_transform(symmetry_url):
    file_name = download_file(symmetry_url)
    with open(file_name, "rb") as fp:
        symmetry = pickle.load(fp)
    return torch.from_numpy(symmetry["vertex_transforms"]).long()


hflip_handled_cases = set([
    "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
    "embedding", "vertx2cat", "maskrcnn_mask", "__key__",
    "img_hr", "condition_hr", "mask_hr"])

def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
    container["img"] = F.hflip(container["img"])
    if "condition" in container:
        container["condition"] = F.hflip(container["condition"])
    if "embedding" in container:
        container["embedding"] = F.hflip(container["embedding"])
    assert all([key in hflip_handled_cases for key in container]), container.keys()
    if "keypoints" in container:
        assert flip_map is not None
        if container["keypoints"].ndim == 3:
            keypoints = container["keypoints"][:, flip_map, :]
            keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
        else:
            assert_shape(container["keypoints"], (None, 3))
            keypoints = container["keypoints"][flip_map, :]
            keypoints[:, 0] = 1 - keypoints[:, 0]
        container["keypoints"] = keypoints
    if "mask" in container:
        container["mask"] = F.hflip(container["mask"])
    if "border" in container:
        container["border"] = F.hflip(container["border"])
    if "semantic_mask" in container:
        container["semantic_mask"] = F.hflip(container["semantic_mask"])
    if "vertices" in container:
        symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
        container["vertices"] = F.hflip(container["vertices"])
        symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
        container["vertices"] = symmetry_transform_[container["vertices"].long()]
    if "E_mask" in container:
        container["E_mask"] = F.hflip(container["E_mask"])
    if "maskrcnn_mask" in container:
        container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
    if "img_hr" in container:
        container["img_hr"] = F.hflip(container["img_hr"])
    if "condition_hr" in container:
        container["condition_hr"] = F.hflip(container["condition_hr"])
    if "mask_hr" in container:
        container["mask_hr"] = F.hflip(container["mask_hr"])
    return container
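A minimal check of the hflip helper above on a synthetic batch; the two-element flip_map is a toy permutation for illustration, not the repo's real keypoint flip map.

# Minimal sketch; the flip_map here is a toy two-joint permutation.
import torch

batch = {
    "img": torch.rand(3, 8, 8),
    "mask": torch.ones(1, 8, 8),
    "keypoints": torch.tensor([[0.25, 0.5, 1.0], [0.75, 0.5, 1.0]]),  # normalized (x, y, visible)
}
flipped = hflip(batch, flip_map=[1, 0])  # swap left/right joints and mirror x -> 1 - x
print(flipped["keypoints"])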
dp2/data/transforms/stylegan2_transform.py
ADDED
@@ -0,0 +1,394 @@
import numpy as np
import scipy.signal
import torch
try:
    from sg3_torch_utils import misc
    from sg3_torch_utils.ops import upfirdn2d
    from sg3_torch_utils.ops import grid_sample_gradfix
    from sg3_torch_utils.ops import conv2d_gradfix
except:
    pass
#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.

wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1': [0.7071067811865476, 0.7071067811865476],
    'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

#----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.


def matrix(*rows, device=None):
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    assert device is None or device == ref[0].device
    elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))


def translate2d(tx, ty, **kwargs):
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)


def translate3d(tx, ty, tz, **kwargs):
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)


def scale2d(sx, sy, **kwargs):
    return matrix(
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
        **kwargs)


def scale3d(sx, sy, sz, **kwargs):
    return matrix(
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
        **kwargs)


def rotate2d(theta, **kwargs):
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta), 0],
        [0, 0, 1],
        **kwargs)


def rotate3d(v, theta, **kwargs):
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
        [0, 0, 0, 1],
        **kwargs)


def translate2d_inv(tx, ty, **kwargs):
    return translate2d(-tx, -ty, **kwargs)


def scale2d_inv(sx, sy, **kwargs):
    return scale2d(1 / sx, 1 / sy, **kwargs)


def rotate2d_inv(theta, **kwargs):
    return rotate2d(-theta, **kwargs)


class StyleGANAugmentPipe(torch.nn.Module):
    def __init__(self,
        rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
        hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1, 1, 1, 1], imgfilter_std=1,
    ):
        super().__init__()
        self.register_buffer('p', torch.ones([]))  # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.rotate90 = float(rotate90)  # Probability multiplier for 90 degree rotations.
        self.xint = float(xint)  # Probability multiplier for integer translation.
        self.xint_max = float(xint_max)  # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale = float(scale)  # Probability multiplier for isotropic scaling.
        self.rotate = float(rotate)  # Probability multiplier for arbitrary rotation.
        self.aniso = float(aniso)  # Probability multiplier for anisotropic scaling.
        self.xfrac = float(xfrac)  # Probability multiplier for fractional translation.
        self.scale_std = float(scale_std)  # Log2 standard deviation of isotropic scaling.
        self.rotate_max = float(rotate_max)  # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std = float(aniso_std)  # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std = float(xfrac_std)  # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness = float(brightness)  # Probability multiplier for brightness.
        self.contrast = float(contrast)  # Probability multiplier for contrast.
        self.lumaflip = float(lumaflip)  # Probability multiplier for luma flip.
        self.hue = float(hue)  # Probability multiplier for hue rotation.
        self.saturation = float(saturation)  # Probability multiplier for saturation.
        self.brightness_std = float(brightness_std)  # Standard deviation of brightness.
        self.contrast_std = float(contrast_std)  # Log2 standard deviation of contrast.
        self.hue_max = float(hue_max)  # Range of hue rotation, 1 = full circle.
        self.saturation_std = float(saturation_std)  # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter = float(imgfilter)  # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(imgfilter_bands)  # Probability multipliers for individual frequency bands.
        self.imgfilter_std = float(imgfilter_std)  # Log2 standard deviation of image-space filter amplification.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])  # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size))  # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2  # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2  # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)  # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, batch, debug_percentile=None):
        images = batch["img"]
        batch["vertices"] = batch["vertices"].float()
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        self.Hz_fbank = self.Hz_fbank.to(device)
        self.Hz_geom = self.Hz_geom.to(device)
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width), torch.round(t[:, 1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1))  # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:
            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device)  # [idx, xyz]
            cp = G_inv @ cp.t()  # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0, mx1, my0, my1], mode='reflect')
            batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0, mx1, my0, my1], mode='constant', value=1.0)
            batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0, mx1, my0, my1], mode='constant', value=0.0)
            batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0, mx1, my0, my1], mode='constant', value=0.0)
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            batch["mask"] = torch.nn.functional.grid_sample(
                input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["E_mask"] = torch.nn.functional.grid_sample(
                input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
            batch["vertices"] = torch.nn.functional.grid_sample(
                input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
            batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
            batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)  # Luma axis.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C  # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device)  # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device)  # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device)  # Temporary gain vector.
                t[:, i] = t_i  # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt()  # Normalize power.
                g = g * t  # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank  # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])  # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1])  # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p, p, p, p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])
|
387 |
+
|
388 |
+
# ------------------------
|
389 |
+
# Image-space corruptions.
|
390 |
+
# ------------------------
|
391 |
+
batch["img"] = images
|
392 |
+
batch["vertices"] = batch["vertices"].long()
|
393 |
+
batch["border"] = 1 - batch["E_mask"] - batch["mask"]
|
394 |
+
return batch
|
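The color-transformation block above builds a homogeneous 4x4 matrix C and applies its upper 3x4 part to the flattened image batch. A minimal, self-contained sketch of that matmul (plain PyTorch, not using this repository's translate3d/scale3d helpers), covering brightness and contrast only:

import torch

def apply_color_matrix(images: torch.Tensor, brightness: float = 0.0, contrast: float = 1.0) -> torch.Tensor:
    # images: [N, 3, H, W], values roughly centered around zero.
    N, C_img, H, W = images.shape
    C = torch.eye(4, device=images.device).repeat(N, 1, 1)
    C[:, :3, 3] += brightness      # corresponds to translate3d(b, b, b) @ C
    C[:, :3, :3] *= contrast       # corresponds to scale3d(c, c, c) @ C
    flat = images.reshape(N, C_img, H * W)
    flat = C[:, :3, :3] @ flat + C[:, :3, 3:]   # same operation as the RGB branch above
    return flat.reshape(N, C_img, H, W)

out = apply_color_matrix(torch.randn(2, 3, 64, 64), brightness=0.1, contrast=1.2)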
dp2/data/transforms/transforms.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
from pathlib import Path
|
2 |
+
from typing import Dict, List
|
3 |
+
import torchvision
|
4 |
+
import torch
|
5 |
+
import tops
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
from .functional import hflip
|
8 |
+
|
9 |
+
|
10 |
+
class RandomHorizontalFlip(torch.nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, p: float, flip_map=None,**kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.flip_ratio = p
|
15 |
+
self.flip_map = flip_map
|
16 |
+
if self.flip_ratio is None:
|
17 |
+
self.flip_ratio = 0.5
|
18 |
+
assert 0 <= self.flip_ratio <= 1
|
19 |
+
|
20 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
21 |
+
if torch.rand(1) > self.flip_ratio:
|
22 |
+
return container
|
23 |
+
return hflip(container, self.flip_map)
|
24 |
+
|
25 |
+
|
26 |
+
class CenterCrop(torch.nn.Module):
|
27 |
+
"""
|
28 |
+
Performs the transform on the image.
|
29 |
+
NOTE: Does not transform the mask to improve runtime.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, size: List[int]):
|
33 |
+
super().__init__()
|
34 |
+
self.size = tuple(size)
|
35 |
+
|
36 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
37 |
+
min_size = min(container["img"].shape[1], container["img"].shape[2])
|
38 |
+
if min_size < self.size[0]:
|
39 |
+
container["img"] = F.center_crop(container["img"], min_size)
|
40 |
+
container["img"] = F.resize(container["img"], self.size)
|
41 |
+
return container
|
42 |
+
container["img"] = F.center_crop(container["img"], self.size)
|
43 |
+
return container
|
44 |
+
|
45 |
+
|
46 |
+
class Resize(torch.nn.Module):
|
47 |
+
"""
|
48 |
+
Performs the transform on the image.
|
49 |
+
NOTE: Does not transform the mask to improve runtime.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
|
53 |
+
super().__init__()
|
54 |
+
self.size = tuple(size)
|
55 |
+
self.interpolation = interpolation
|
56 |
+
|
57 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
58 |
+
container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
|
59 |
+
if "semantic_mask" in container:
|
60 |
+
container["semantic_mask"] = F.resize(
|
61 |
+
container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
|
62 |
+
if "embedding" in container:
|
63 |
+
container["embedding"] = F.resize(
|
64 |
+
container["embedding"], self.size, self.interpolation)
|
65 |
+
if "mask" in container:
|
66 |
+
container["mask"] = F.resize(
|
67 |
+
container["mask"], self.size, F.InterpolationMode.NEAREST)
|
68 |
+
if "E_mask" in container:
|
69 |
+
container["E_mask"] = F.resize(
|
70 |
+
container["E_mask"], self.size, F.InterpolationMode.NEAREST)
|
71 |
+
if "maskrcnn_mask" in container:
|
72 |
+
container["maskrcnn_mask"] = F.resize(
|
73 |
+
container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
|
74 |
+
if "vertices" in container:
|
75 |
+
container["vertices"] = F.resize(
|
76 |
+
container["vertices"], self.size, F.InterpolationMode.NEAREST)
|
77 |
+
return container
|
78 |
+
|
79 |
+
def __repr__(self):
|
80 |
+
repr = super().__repr__()
|
81 |
+
vars_ = dict(size=self.size, interpolation=self.interpolation)
|
82 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
83 |
+
|
84 |
+
|
85 |
+
class InsertHRImage(torch.nn.Module):
|
86 |
+
"""
|
87 |
+
Resizes mask by maxpool and assumes condition is already created
|
88 |
+
"""
|
89 |
+
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
|
90 |
+
super().__init__()
|
91 |
+
self.size = tuple(size)
|
92 |
+
self.interpolation = interpolation
|
93 |
+
|
94 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
95 |
+
assert container["img"].dtype == torch.float32
|
96 |
+
container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
|
97 |
+
container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
|
98 |
+
mask = container["mask"] > 0
|
99 |
+
container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
|
100 |
+
container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"]
|
101 |
+
return container
|
102 |
+
|
103 |
+
def __repr__(self):
|
104 |
+
repr = super().__repr__()
|
105 |
+
vars_ = dict(size=self.size, interpolation=self.interpolation)
|
106 |
+
return repr + " "
|
107 |
+
|
108 |
+
|
109 |
+
class CopyHRImage(torch.nn.Module):
|
110 |
+
def __init__(self) -> None:
|
111 |
+
super().__init__()
|
112 |
+
|
113 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
114 |
+
container["img_hr"] = container["img"]
|
115 |
+
container["condition_hr"] = container["condition"]
|
116 |
+
container["mask_hr"] = container["mask"]
|
117 |
+
return container
|
118 |
+
|
119 |
+
|
120 |
+
class Resize2(torch.nn.Module):
|
121 |
+
"""
|
122 |
+
Resizes mask by maxpool and assumes condition is already created
|
123 |
+
"""
|
124 |
+
def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True):
|
125 |
+
super().__init__()
|
126 |
+
self.size = tuple(size)
|
127 |
+
self.interpolation = interpolation
|
128 |
+
self.downsample_condition = downsample_condition
|
129 |
+
self.mask_condition = mask_condition
|
130 |
+
|
131 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
132 |
+
# assert container["img"].dtype == torch.float32
|
133 |
+
container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
|
134 |
+
mask = container["mask"] > 0
|
135 |
+
container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
|
136 |
+
|
137 |
+
if self.downsample_condition:
|
138 |
+
container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
|
139 |
+
if self.mask_condition:
|
140 |
+
container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"]
|
141 |
+
return container
|
142 |
+
|
143 |
+
def __repr__(self):
|
144 |
+
repr = super().__repr__()
|
145 |
+
vars_ = dict(size=self.size, interpolation=self.interpolation)
|
146 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
class Normalize(torch.nn.Module):
|
151 |
+
"""
|
152 |
+
Performs the transform on the image.
|
153 |
+
NOTE: Does not transform the mask to improve runtime.
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, mean, std, inplace, keys=["img"]):
|
157 |
+
super().__init__()
|
158 |
+
self.mean = mean
|
159 |
+
self.std = std
|
160 |
+
self.inplace = inplace
|
161 |
+
self.keys = keys
|
162 |
+
|
163 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
164 |
+
for key in self.keys:
|
165 |
+
container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
|
166 |
+
return container
|
167 |
+
|
168 |
+
def __repr__(self):
|
169 |
+
repr = super().__repr__()
|
170 |
+
vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
|
171 |
+
return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
|
172 |
+
|
173 |
+
|
174 |
+
class ToFloat(torch.nn.Module):
|
175 |
+
|
176 |
+
def __init__(self, keys=["img"], norm=True) -> None:
|
177 |
+
super().__init__()
|
178 |
+
self.keys = keys
|
179 |
+
self.gain = 255 if norm else 1
|
180 |
+
|
181 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
182 |
+
for key in self.keys:
|
183 |
+
container[key] = container[key].float() / self.gain
|
184 |
+
return container
|
185 |
+
|
186 |
+
|
187 |
+
class RandomCrop(torchvision.transforms.RandomCrop):
|
188 |
+
"""
|
189 |
+
Performs the transform on the image.
|
190 |
+
NOTE: Does not transform the mask to improve runtime.
|
191 |
+
"""
|
192 |
+
|
193 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
194 |
+
container["img"] = super().forward(container["img"])
|
195 |
+
return container
|
196 |
+
|
197 |
+
|
198 |
+
class CreateCondition(torch.nn.Module):
|
199 |
+
|
200 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
201 |
+
if container["img"].dtype == torch.uint8:
|
202 |
+
container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
|
203 |
+
return container
|
204 |
+
container["condition"] = container["img"] * container["mask"]
|
205 |
+
return container
|
206 |
+
|
207 |
+
|
208 |
+
class CreateEmbedding(torch.nn.Module):
|
209 |
+
|
210 |
+
def __init__(self, embed_path: Path, cuda=True) -> None:
|
211 |
+
super().__init__()
|
212 |
+
self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
|
213 |
+
if cuda:
|
214 |
+
self.embed_map = tops.to_cuda(self.embed_map)
|
215 |
+
|
216 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
217 |
+
vertices = container["vertices"]
|
218 |
+
if vertices.ndim == 3:
|
219 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=0)
|
220 |
+
embedding = embedding.permute(2, 0, 1) * container["E_mask"]
|
221 |
+
pass
|
222 |
+
else:
|
223 |
+
assert vertices.ndim == 4
|
224 |
+
embedding = self.embed_map[vertices.long()].squeeze(dim=1)
|
225 |
+
embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
|
226 |
+
container["embedding"] = embedding
|
227 |
+
container["embed_map"] = self.embed_map.clone()
|
228 |
+
return container
|
229 |
+
|
230 |
+
|
231 |
+
class UpdateMask(torch.nn.Module):
|
232 |
+
|
233 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
234 |
+
container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float()
|
235 |
+
return container
|
236 |
+
|
237 |
+
|
238 |
+
class LoadClassEmbedding(torch.nn.Module):
|
239 |
+
|
240 |
+
def __init__(self, embedding_path: Path) -> None:
|
241 |
+
super().__init__()
|
242 |
+
self.embedding = torch.load(embedding_path, map_location="cpu")
|
243 |
+
|
244 |
+
def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
245 |
+
key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1])
|
246 |
+
container["class_embedding"] = self.embedding[key].view(-1)
|
247 |
+
return container
|
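These transforms all operate on a dict ("container") rather than a bare tensor, so they can be chained with torchvision's Compose. A sketch of a typical chain, assuming the dp2 package is importable; the image size and keys are illustrative:

import torch
import torchvision
from dp2.data.transforms.transforms import ToFloat, CreateCondition, Normalize

transform = torchvision.transforms.Compose([
    ToFloat(keys=["img"]),                                   # uint8 [0, 255] -> float [0, 1]
    CreateCondition(),                                       # condition = img * mask (float branch)
    Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True, keys=["img", "condition"]),
])
container = {
    "img": torch.randint(0, 256, (3, 288, 160), dtype=torch.uint8),
    "mask": torch.ones(1, 288, 160),
}
container = transform(container)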
dp2/data/utils.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
import torch
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import multiprocessing
|
5 |
+
import io
|
6 |
+
from tops import logger
|
7 |
+
from torch.utils.data._utils.collate import default_collate
|
8 |
+
|
9 |
+
try:
|
10 |
+
import pyspng
|
11 |
+
|
12 |
+
PYSPNG_IMPORTED = True
|
13 |
+
except ImportError:
|
14 |
+
PYSPNG_IMPORTED = False
|
15 |
+
print("Could not load pyspng. Defaulting to pillow image backend.")
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def get_coco_keypoints():
|
20 |
+
return [
|
21 |
+
"nose",
|
22 |
+
"left_eye",
|
23 |
+
"right_eye",
|
24 |
+
"left_ear",
|
25 |
+
"right_ear",
|
26 |
+
"left_shoulder",
|
27 |
+
"right_shoulder",
|
28 |
+
"left_elbow",
|
29 |
+
"right_elbow",
|
30 |
+
"left_wrist",
|
31 |
+
"right_wrist",
|
32 |
+
"left_hip",
|
33 |
+
"right_hip",
|
34 |
+
"left_knee",
|
35 |
+
"right_knee",
|
36 |
+
"left_ankle",
|
37 |
+
"right_ankle",
|
38 |
+
]
|
39 |
+
|
40 |
+
|
41 |
+
def get_coco_flipmap():
|
42 |
+
keypoints = get_coco_keypoints()
|
43 |
+
keypoint_flip_map = {
|
44 |
+
"left_eye": "right_eye",
|
45 |
+
"left_ear": "right_ear",
|
46 |
+
"left_shoulder": "right_shoulder",
|
47 |
+
"left_elbow": "right_elbow",
|
48 |
+
"left_wrist": "right_wrist",
|
49 |
+
"left_hip": "right_hip",
|
50 |
+
"left_knee": "right_knee",
|
51 |
+
"left_ankle": "right_ankle",
|
52 |
+
}
|
53 |
+
for key, value in list(keypoint_flip_map.items()):
|
54 |
+
keypoint_flip_map[value] = key
|
55 |
+
keypoint_flip_map["nose"] = "nose"
|
56 |
+
keypoint_flip_map_idx = []
|
57 |
+
for source in keypoints:
|
58 |
+
keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
|
59 |
+
return keypoint_flip_map_idx
|
60 |
+
|
61 |
+
|
62 |
+
def mask_decoder(x):
|
63 |
+
mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
|
64 |
+
mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
|
65 |
+
return mask
|
66 |
+
|
67 |
+
|
68 |
+
def png_decoder(x):
|
69 |
+
if PYSPNG_IMPORTED:
|
70 |
+
return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
|
71 |
+
with Image.open(io.BytesIO(x)) as im:
|
72 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
73 |
+
return im
|
74 |
+
|
75 |
+
|
76 |
+
def jpg_decoder(x):
|
77 |
+
with Image.open(io.BytesIO(x)) as im:
|
78 |
+
im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
|
79 |
+
return im
|
80 |
+
|
81 |
+
|
82 |
+
def get_num_workers(num_workers: int):
|
83 |
+
n_cpus = multiprocessing.cpu_count()
|
84 |
+
if num_workers > n_cpus:
|
85 |
+
logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
|
86 |
+
return n_cpus
|
87 |
+
return num_workers
|
88 |
+
|
89 |
+
|
90 |
+
def collate_fn(batch):
|
91 |
+
elem = batch[0]
|
92 |
+
ignore_keys = set(["embed_map", "vertx2cat"])
|
93 |
+
batch_ = {
|
94 |
+
key: default_collate([d[key] for d in batch])
|
95 |
+
for key in elem
|
96 |
+
if key not in ignore_keys
|
97 |
+
}
|
98 |
+
if "embed_map" in elem:
|
99 |
+
batch_["embed_map"] = elem["embed_map"]
|
100 |
+
if "vertx2cat" in elem:
|
101 |
+
batch_["vertx2cat"] = elem["vertx2cat"]
|
102 |
+
return batch_
|
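collate_fn batches every key with default_collate except "embed_map" and "vertx2cat", which are shared across samples and passed through from the first element. A small sketch:

import torch
from dp2.data.utils import collate_fn

samples = [
    {"img": torch.zeros(3, 8, 8), "embed_map": torch.zeros(27554, 16)},
    {"img": torch.ones(3, 8, 8), "embed_map": torch.zeros(27554, 16)},
]
batch = collate_fn(samples)
# batch["img"]: [2, 3, 8, 8]; batch["embed_map"] keeps its original [27554, 16] shape.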
dp2/detection/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .cse_mask_face_detector import CSeMaskFaceDetector
|
2 |
+
from .person_detector import CSEPersonDetector
|
3 |
+
from .structures import PersonDetection, VehicleDetection, FaceDetection
|
dp2/detection/base.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
import pickle
|
2 |
+
import torch
|
3 |
+
import lzma
|
4 |
+
from pathlib import Path
|
5 |
+
from tops import logger
|
6 |
+
|
7 |
+
|
8 |
+
class BaseDetector:
|
9 |
+
|
10 |
+
|
11 |
+
def __init__(self, cache_directory: str) -> None:
|
12 |
+
if cache_directory is not None:
|
13 |
+
self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
|
14 |
+
self.cache_directory.mkdir(exist_ok=True, parents=True)
|
15 |
+
|
16 |
+
def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
|
17 |
+
logger.log(f"Caching detection to: {cache_path}")
|
18 |
+
with lzma.open(cache_path, "wb") as fp:
|
19 |
+
torch.save(
|
20 |
+
[det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
|
21 |
+
pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
22 |
+
|
23 |
+
def load_from_cache(self, cache_path: Path):
|
24 |
+
logger.log(f"Loading detection from cache path: {cache_path}")
|
25 |
+
with lzma.open(cache_path, "rb") as fp:
|
26 |
+
state_dict = torch.load(fp)
|
27 |
+
return [
|
28 |
+
state["cls"].from_state_dict(state_dict=state) for state in state_dict
|
29 |
+
]
|
30 |
+
|
31 |
+
def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
|
32 |
+
if cache_id is None:
|
33 |
+
return self.forward(im)
|
34 |
+
cache_path = self.cache_directory.joinpath(cache_id + ".torch")
|
35 |
+
if cache_path.is_file() and load_cache:
|
36 |
+
try:
|
37 |
+
return self.load_from_cache(cache_path)
|
38 |
+
except Exception as e:
|
39 |
+
logger.warn(f"The cache file was corrupted: {cache_path}")
|
40 |
+
exit()
|
41 |
+
detections = self.forward(im)
|
42 |
+
self.save_to_cache(detections, cache_path)
|
43 |
+
return detections
|
44 |
+
|
45 |
+
|
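The caching contract above: subclasses only implement forward(), and forward_and_cache() adds LZMA-compressed caching keyed by cache_id on top. A hypothetical minimal subclass (DummyDetector is not part of the repository):

import torch
from dp2.detection.base import BaseDetector

class DummyDetector(BaseDetector):
    def forward(self, im: torch.Tensor):
        return []  # a real detector returns detection objects exposing state_dict()

detector = DummyDetector(cache_directory="detection_cache")
im = torch.zeros((3, 256, 256), dtype=torch.uint8)
dets = detector.forward_and_cache(im, cache_id="frame_0000", load_cache=True)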
dp2/detection/box_utils.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
|
5 |
+
x0, y0, x1, y1 = [int(_) for _ in bbox]
|
6 |
+
h, w = y1 - y0, x1 - x0
|
7 |
+
cur_ratio = h / w
|
8 |
+
|
9 |
+
if cur_ratio == target_aspect_ratio:
|
10 |
+
return [x0, y0, x1, y1]
|
11 |
+
if cur_ratio < target_aspect_ratio:
|
12 |
+
target_height = int(w*target_aspect_ratio)
|
13 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
14 |
+
else:
|
15 |
+
target_width = int(h/target_aspect_ratio)
|
16 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
17 |
+
return x0, y0, x1, y1
|
18 |
+
|
19 |
+
|
20 |
+
def expand_axis(start, end, target_width, limit):
|
21 |
+
# Can return a bbox outside of limit
|
22 |
+
cur_width = end - start
|
23 |
+
start = start - (target_width-cur_width)//2
|
24 |
+
end = end + (target_width-cur_width)//2
|
25 |
+
if end - start != target_width:
|
26 |
+
end += 1
|
27 |
+
assert end - start == target_width
|
28 |
+
if start < 0 and end > limit:
|
29 |
+
return start, end
|
30 |
+
if start < 0 and end < limit:
|
31 |
+
to_shift = min(0 - start, limit - end)
|
32 |
+
start += to_shift
|
33 |
+
end += to_shift
|
34 |
+
if end > limit and start > 0:
|
35 |
+
to_shift = min(end - limit, start)
|
36 |
+
end -= to_shift
|
37 |
+
start -= to_shift
|
38 |
+
assert end - start == target_width
|
39 |
+
return start, end
|
40 |
+
|
41 |
+
|
42 |
+
def expand_box(bbox, imshape, mask, percentage_background: float):
|
43 |
+
assert isinstance(bbox[0], int)
|
44 |
+
assert 0 < percentage_background < 1
|
45 |
+
# Percentage in S
|
46 |
+
mask_pixels = mask.long().sum().cpu()
|
47 |
+
total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
48 |
+
percentage_mask = mask_pixels / total_pixels
|
49 |
+
if (1 - percentage_mask) > percentage_background:
|
50 |
+
return bbox
|
51 |
+
target_pixels = mask_pixels / (1 - percentage_background)
|
52 |
+
x0, y0, x1, y1 = bbox
|
53 |
+
H = y1 - y0
|
54 |
+
W = x1 - x0
|
55 |
+
p = np.sqrt(target_pixels/(H*W))
|
56 |
+
target_width = int(np.ceil(p * W))
|
57 |
+
target_height = int(np.ceil(p * H))
|
58 |
+
x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
|
59 |
+
y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
|
60 |
+
return [x0, y0, x1, y1]
|
61 |
+
|
62 |
+
|
63 |
+
def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
|
64 |
+
x0, y0, x1, y1 = bbox_XYXY
|
65 |
+
H = y1 - y0
|
66 |
+
W = x1 - x0
|
67 |
+
expansion = int(((H*W)**0.5) * percentage)
|
68 |
+
new_width = W + expansion
|
69 |
+
new_height = H + expansion
|
70 |
+
x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
|
71 |
+
y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
|
72 |
+
return [x0, y0, x1, y1]
|
73 |
+
|
74 |
+
|
75 |
+
def get_expanded_bbox(
|
76 |
+
bbox_XYXY,
|
77 |
+
imshape,
|
78 |
+
mask,
|
79 |
+
percentage_background: float,
|
80 |
+
axis_minimum_expansion: float,
|
81 |
+
target_aspect_ratio: float):
|
82 |
+
bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
|
83 |
+
# Expand each axis of the bounding box by a minimum percentage
|
84 |
+
bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
|
85 |
+
# Find the minimum bbox with the aspect ratio. Can be outside of imshape
|
86 |
+
bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
|
87 |
+
# Expands square box such that X% of the bbox is background
|
88 |
+
bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
|
89 |
+
assert isinstance(bbox_XYXY[0], (int, np.int64))
|
90 |
+
return bbox_XYXY
|
91 |
+
|
92 |
+
|
93 |
+
def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
|
94 |
+
def area_inside_ratio(bbox, imshape):
|
95 |
+
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
96 |
+
area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1]))
|
97 |
+
return area_inside / area
|
98 |
+
ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
|
99 |
+
area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
100 |
+
if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
|
101 |
+
return False
|
102 |
+
if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
|
103 |
+
return False
|
104 |
+
return True
|
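A quick numeric example of expand_bbox_to_ratio: a 40x20 (width x height) box with target height/width ratio 2 keeps its width and is grown symmetrically along the vertical axis.

from dp2.detection.box_utils import expand_bbox_to_ratio

# x0, y0, x1, y1 = 100, 100, 140, 120 -> width 40, height 20, ratio h/w = 0.5
print(expand_bbox_to_ratio([100, 100, 140, 120], imshape=(512, 512), target_aspect_ratio=2))
# -> (100, 70, 140, 150): height expanded to 80 around the box center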
dp2/detection/box_utils_fdf.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
"""
|
2 |
+
The FDF dataset expands bounding boxes differently from what is used for CSE.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def quadratic_bounding_box(x0, y0, width, height, imshape):
|
9 |
+
# We assume that we can create a image that is quadratic without
|
10 |
+
# minimizing any of the sides
|
11 |
+
assert width <= min(imshape[:2])
|
12 |
+
assert height <= min(imshape[:2])
|
13 |
+
min_side = min(height, width)
|
14 |
+
if height != width:
|
15 |
+
side_diff = abs(height - width)
|
16 |
+
# Want to extend the shortest side
|
17 |
+
if min_side == height:
|
18 |
+
# Vertical side
|
19 |
+
height += side_diff
|
20 |
+
if height > imshape[0]:
|
21 |
+
# Take full frame, and shrink width
|
22 |
+
y0 = 0
|
23 |
+
height = imshape[0]
|
24 |
+
|
25 |
+
side_diff = abs(height - width)
|
26 |
+
width -= side_diff
|
27 |
+
x0 += side_diff // 2
|
28 |
+
else:
|
29 |
+
y0 -= side_diff // 2
|
30 |
+
y0 = max(0, y0)
|
31 |
+
else:
|
32 |
+
# Horizontal side
|
33 |
+
width += side_diff
|
34 |
+
if width > imshape[1]:
|
35 |
+
# Take full frame width, and shrink height
|
36 |
+
x0 = 0
|
37 |
+
width = imshape[1]
|
38 |
+
|
39 |
+
side_diff = abs(height - width)
|
40 |
+
height -= side_diff
|
41 |
+
y0 += side_diff // 2
|
42 |
+
else:
|
43 |
+
x0 -= side_diff // 2
|
44 |
+
x0 = max(0, x0)
|
45 |
+
# Check that bbox goes outside image
|
46 |
+
x1 = x0 + width
|
47 |
+
y1 = y0 + height
|
48 |
+
if imshape[1] < x1:
|
49 |
+
diff = x1 - imshape[1]
|
50 |
+
x0 -= diff
|
51 |
+
if imshape[0] < y1:
|
52 |
+
diff = y1 - imshape[0]
|
53 |
+
y0 -= diff
|
54 |
+
assert x0 >= 0, "Bounding box outside image."
|
55 |
+
assert y0 >= 0, "Bounding box outside image."
|
56 |
+
assert x0 + width <= imshape[1], "Bounding box outside image."
|
57 |
+
assert y0 + height <= imshape[0], "Bounding box outside image."
|
58 |
+
return x0, y0, width, height
|
59 |
+
|
60 |
+
|
61 |
+
def expand_bounding_box(bbox, percentage, imshape):
|
62 |
+
orig_bbox = bbox.copy()
|
63 |
+
x0, y0, x1, y1 = bbox
|
64 |
+
width = x1 - x0
|
65 |
+
height = y1 - y0
|
66 |
+
x0, y0, width, height = quadratic_bounding_box(
|
67 |
+
x0, y0, width, height, imshape)
|
68 |
+
expanding_factor = int(max(height, width) * percentage)
|
69 |
+
|
70 |
+
possible_max_expansion = [(imshape[0] - width) // 2,
|
71 |
+
(imshape[1] - height) // 2,
|
72 |
+
expanding_factor]
|
73 |
+
|
74 |
+
expanding_factor = min(possible_max_expansion)
|
75 |
+
# Expand height
|
76 |
+
|
77 |
+
if expanding_factor > 0:
|
78 |
+
|
79 |
+
y0 = y0 - expanding_factor
|
80 |
+
y0 = max(0, y0)
|
81 |
+
|
82 |
+
height += expanding_factor * 2
|
83 |
+
if height > imshape[0]:
|
84 |
+
y0 -= (imshape[0] - height)
|
85 |
+
height = imshape[0]
|
86 |
+
|
87 |
+
if height + y0 > imshape[0]:
|
88 |
+
y0 -= (height + y0 - imshape[0])
|
89 |
+
|
90 |
+
# Expand width
|
91 |
+
x0 = x0 - expanding_factor
|
92 |
+
x0 = max(0, x0)
|
93 |
+
|
94 |
+
width += expanding_factor * 2
|
95 |
+
if width > imshape[1]:
|
96 |
+
x0 -= (imshape[1] - width)
|
97 |
+
width = imshape[1]
|
98 |
+
|
99 |
+
if width + x0 > imshape[1]:
|
100 |
+
x0 -= (width + x0 - imshape[1])
|
101 |
+
y1 = y0 + height
|
102 |
+
x1 = x0 + width
|
103 |
+
assert y0 >= 0, "y0 is negative"
|
104 |
+
assert height <= imshape[0], "Height is larger than image."
|
105 |
+
assert x0 + width <= imshape[1]
|
106 |
+
assert y0 + height <= imshape[0]
|
107 |
+
assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
|
108 |
+
assert x0 >= 0, "x0 is negative"
|
109 |
+
assert width <= imshape[1], "Width is larger than image."
|
110 |
+
# Check that original bbox is within new
|
111 |
+
x0_o, y0_o, x1_o, y1_o = orig_bbox
|
112 |
+
assert x0 <= x0_o, f"New bbox is outside of original. O:{x0_o}, N: {x0}"
|
113 |
+
assert x1 >= x1_o, f"New bbox is outside of original. O:{x1_o}, N: {x1}"
|
114 |
+
assert y0 <= y0_o, f"New bbox is outside of original. O:{y0_o}, N: {y0}"
|
115 |
+
assert y1 >= y1_o, f"New bbox is outside of original. O:{y1_o}, N: {y1}"
|
116 |
+
|
117 |
+
x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
|
118 |
+
x1 = x0 + width
|
119 |
+
y1 = y0 + height
|
120 |
+
return np.array([x0, y0, x1, y1])
|
121 |
+
|
122 |
+
|
123 |
+
def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
|
124 |
+
keypoint = keypoint[:, :3] # only nose + eyes are relevant
|
125 |
+
kp_X = keypoint[0, :]
|
126 |
+
kp_Y = keypoint[1, :]
|
127 |
+
within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
|
128 |
+
within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
|
129 |
+
return within_X and within_Y
|
130 |
+
|
131 |
+
|
132 |
+
def expand_bbox_simple(bbox, percentage):
|
133 |
+
x0, y0, x1, y1 = bbox.astype(float)
|
134 |
+
width = x1 - x0
|
135 |
+
height = y1 - y0
|
136 |
+
x_c = int(x0) + width // 2
|
137 |
+
y_c = int(y0) + height // 2
|
138 |
+
avg_size = max(width, height)
|
139 |
+
new_width = avg_size * (1 + percentage)
|
140 |
+
x0 = x_c - new_width // 2
|
141 |
+
y0 = y_c - new_width // 2
|
142 |
+
x1 = x_c + new_width // 2
|
143 |
+
y1 = y_c + new_width // 2
|
144 |
+
return np.array([x0, y0, x1, y1]).astype(int)
|
145 |
+
|
146 |
+
|
147 |
+
def pad_image(im, bbox, pad_value):
|
148 |
+
x0, y0, x1, y1 = bbox
|
149 |
+
if x0 < 0:
|
150 |
+
pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
|
151 |
+
dtype=np.uint8) + pad_value
|
152 |
+
im = np.concatenate((pad_im, im), axis=1)
|
153 |
+
x1 += abs(x0)
|
154 |
+
x0 = 0
|
155 |
+
if y0 < 0:
|
156 |
+
pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
|
157 |
+
dtype=np.uint8) + pad_value
|
158 |
+
im = np.concatenate((pad_im, im), axis=0)
|
159 |
+
y1 += abs(y0)
|
160 |
+
y0 = 0
|
161 |
+
if x1 >= im.shape[1]:
|
162 |
+
pad_im = np.zeros(
|
163 |
+
(im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
|
164 |
+
dtype=np.uint8) + pad_value
|
165 |
+
im = np.concatenate((im, pad_im), axis=1)
|
166 |
+
if y1 >= im.shape[0]:
|
167 |
+
pad_im = np.zeros(
|
168 |
+
(y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
|
169 |
+
dtype=np.uint8) + pad_value
|
170 |
+
im = np.concatenate((im, pad_im), axis=0)
|
171 |
+
return im[y0:y1, x0:x1]
|
172 |
+
|
173 |
+
|
174 |
+
def clip_box(bbox, im):
|
175 |
+
bbox[0] = max(0, bbox[0])
|
176 |
+
bbox[1] = max(0, bbox[1])
|
177 |
+
bbox[2] = min(im.shape[1] - 1, bbox[2])
|
178 |
+
bbox[3] = min(im.shape[0] - 1, bbox[3])
|
179 |
+
return bbox
|
180 |
+
|
181 |
+
|
182 |
+
def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
|
183 |
+
outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
|
184 |
+
if simple_expand or (outside_im and pad_im):
|
185 |
+
return pad_image(im, bbox, pad_value)
|
186 |
+
bbox = clip_box(bbox, im)
|
187 |
+
x0, y0, x1, y1 = bbox
|
188 |
+
return im[y0:y1, x0:x1]
|
189 |
+
|
190 |
+
|
191 |
+
def expand_bbox(
|
192 |
+
bbox_ltrb, imshape, simple_expand, default_to_simple=False,
|
193 |
+
expansion_factor=0.35):
|
194 |
+
assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
|
195 |
+
bbox = bbox_ltrb.astype(float)
|
196 |
+
# FDF256 uses simple expand with ratio 0.4
|
197 |
+
if simple_expand:
|
198 |
+
return expand_bbox_simple(bbox, 0.4)
|
199 |
+
try:
|
200 |
+
return expand_bounding_box(bbox, expansion_factor, imshape)
|
201 |
+
except AssertionError:
|
202 |
+
return expand_bbox_simple(bbox, expansion_factor * 2)
|
203 |
+
|
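expand_bbox is the entry point used for FDF-style face crops: with simple_expand=True (FDF256) the box is grown 40% around its center, otherwise the quadratic expansion above is attempted and simple expansion is the fallback. A small sketch:

import numpy as np
from dp2.detection.box_utils_fdf import expand_bbox

face_box = np.array([100, 100, 140, 140])               # x0, y0, x1, y1
expanded = expand_bbox(face_box, imshape=(512, 512, 3), simple_expand=True)
# expanded is a square box roughly 40% larger, still centered on the face.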
dp2/detection/cse_mask_face_detector.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
import torch
|
2 |
+
import lzma
|
3 |
+
import tops
|
4 |
+
from pathlib import Path
|
5 |
+
from dp2.detection.base import BaseDetector
|
6 |
+
from .utils import combine_cse_maskrcnn_dets
|
7 |
+
from face_detection import build_detector as build_face_detector
|
8 |
+
from .models.cse import CSEDetector
|
9 |
+
from .models.mask_rcnn import MaskRCNNDetector
|
10 |
+
from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
|
11 |
+
from tops import logger
|
12 |
+
|
13 |
+
|
14 |
+
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
|
15 |
+
assert len(box1.shape) == 2
|
16 |
+
assert len(box2.shape) == 2
|
17 |
+
box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
|
18 |
+
# This can be batched
|
19 |
+
for i, box in enumerate(box1):
|
20 |
+
is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
|
21 |
+
is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
|
22 |
+
is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
|
23 |
+
box1_inside[i] = is_outside.logical_not().any()
|
24 |
+
return box1_inside
|
25 |
+
|
26 |
+
|
27 |
+
class CSeMaskFaceDetector(BaseDetector):
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
mask_rcnn_cfg,
|
32 |
+
face_detector_cfg: dict,
|
33 |
+
cse_cfg: dict,
|
34 |
+
face_post_process_cfg: dict,
|
35 |
+
cse_post_process_cfg,
|
36 |
+
score_threshold: float,
|
37 |
+
**kwargs
|
38 |
+
) -> None:
|
39 |
+
super().__init__(**kwargs)
|
40 |
+
self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
|
41 |
+
if "confidence_threshold" not in face_detector_cfg:
|
42 |
+
face_detector_cfg["confidence_threshold"] = score_threshold
|
43 |
+
if "score_thres" not in cse_cfg:
|
44 |
+
cse_cfg["score_thres"] = score_threshold
|
45 |
+
self.cse_detector = CSEDetector(**cse_cfg)
|
46 |
+
self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
|
47 |
+
self.cse_post_process_cfg = cse_post_process_cfg
|
48 |
+
self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
|
49 |
+
self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
|
50 |
+
self.face_post_process_cfg = face_post_process_cfg
|
51 |
+
|
52 |
+
def __call__(self, *args, **kwargs):
|
53 |
+
return self.forward(*args, **kwargs)
|
54 |
+
|
55 |
+
def _detect_faces(self, im: torch.Tensor):
|
56 |
+
H, W = im.shape[1:]
|
57 |
+
im = im.float() - self.face_mean
|
58 |
+
im = self.face_detector.resize(im[None], 1.0)
|
59 |
+
boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
|
60 |
+
boxes_XYXY[:, [0, 2]] *= W
|
61 |
+
boxes_XYXY[:, [1, 3]] *= H
|
62 |
+
return boxes_XYXY.round().long()
|
63 |
+
|
64 |
+
def load_from_cache(self, cache_path: Path):
|
65 |
+
logger.log(f"Loading detection from cache path: {cache_path}",)
|
66 |
+
with lzma.open(cache_path, "rb") as fp:
|
67 |
+
state_dict = torch.load(fp, map_location="cpu")
|
68 |
+
kwargs = dict(
|
69 |
+
post_process_cfg=self.cse_post_process_cfg,
|
70 |
+
embed_map=self.cse_detector.embed_map,
|
71 |
+
**self.face_post_process_cfg
|
72 |
+
)
|
73 |
+
return [
|
74 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
75 |
+
for state in state_dict
|
76 |
+
]
|
77 |
+
|
78 |
+
@torch.no_grad()
|
79 |
+
def forward(self, im: torch.Tensor):
|
80 |
+
maskrcnn_dets = self.mask_rcnn(im)
|
81 |
+
cse_dets = self.cse_detector(im)
|
82 |
+
embed_map = self.cse_detector.embed_map
|
83 |
+
print("Calling face detector.")
|
84 |
+
face_boxes = self._detect_faces(im).cpu()
|
85 |
+
maskrcnn_person = {
|
86 |
+
k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
|
87 |
+
}
|
88 |
+
maskrcnn_other = {
|
89 |
+
k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
|
90 |
+
}
|
91 |
+
maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
|
92 |
+
combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
|
93 |
+
maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
|
94 |
+
|
95 |
+
persons_with_cse = CSEPersonDetection(
|
96 |
+
combined_segmentation, cse_dets, **self.cse_post_process_cfg,
|
97 |
+
embed_map=embed_map,orig_imshape_CHW=im.shape
|
98 |
+
)
|
99 |
+
persons_with_cse.pre_process()
|
100 |
+
not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
|
101 |
+
persons_without_cse = PersonDetection(
|
102 |
+
maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
|
103 |
+
orig_imshape_CHW=im.shape
|
104 |
+
)
|
105 |
+
persons_without_cse.pre_process()
|
106 |
+
|
107 |
+
face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
|
108 |
+
box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
|
109 |
+
)
|
110 |
+
face_boxes = face_boxes[face_boxes_covered.logical_not()]
|
111 |
+
face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
|
112 |
+
|
113 |
+
# Order matters. The anonymizer will anonymize FIFO.
|
114 |
+
# Later detections will overwrite.
|
115 |
+
all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
|
116 |
+
return all_detections
|
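The face/person merging above relies on box1_inside_box2: a face box is dropped only when it lies strictly inside one of the dilated person boxes. A quick check of that helper (assuming the module's detectron2/densepose/face_detection dependencies are installed so the import succeeds):

import torch
from dp2.detection.cse_mask_face_detector import box1_inside_box2

face_boxes = torch.tensor([[10, 10, 20, 20], [100, 100, 120, 120]])
person_boxes = torch.tensor([[0, 0, 50, 50]])
print(box1_inside_box2(face_boxes, person_boxes))        # tensor([ True, False])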
dp2/detection/face_detector.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
import torch
|
2 |
+
import lzma
|
3 |
+
import tops
|
4 |
+
from pathlib import Path
|
5 |
+
from dp2.detection.base import BaseDetector
|
6 |
+
from face_detection import build_detector as build_face_detector
|
7 |
+
from .structures import FaceDetection
|
8 |
+
from tops import logger
|
9 |
+
|
10 |
+
|
11 |
+
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
|
12 |
+
assert len(box1.shape) == 2
|
13 |
+
assert len(box2.shape) == 2
|
14 |
+
box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
|
15 |
+
# This can be batched
|
16 |
+
for i, box in enumerate(box1):
|
17 |
+
is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
|
18 |
+
is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
|
19 |
+
is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
|
20 |
+
box1_inside[i] = is_outside.logical_not().any()
|
21 |
+
return box1_inside
|
22 |
+
|
23 |
+
|
24 |
+
class FaceDetector(BaseDetector):
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
face_detector_cfg: dict,
|
29 |
+
score_threshold: float,
|
30 |
+
face_post_process_cfg: dict,
|
31 |
+
**kwargs
|
32 |
+
) -> None:
|
33 |
+
super().__init__(**kwargs)
|
34 |
+
self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
|
35 |
+
self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
|
36 |
+
self.face_post_process_cfg = face_post_process_cfg
|
37 |
+
|
38 |
+
def __call__(self, *args, **kwargs):
|
39 |
+
return self.forward(*args, **kwargs)
|
40 |
+
|
41 |
+
def _detect_faces(self, im: torch.Tensor):
|
42 |
+
H, W = im.shape[1:]
|
43 |
+
im = im.float() - self.face_mean
|
44 |
+
im = self.face_detector.resize(im[None], 1.0)
|
45 |
+
boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
|
46 |
+
boxes_XYXY[:, [0, 2]] *= W
|
47 |
+
boxes_XYXY[:, [1, 3]] *= H
|
48 |
+
return boxes_XYXY.round().long().cpu()
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def forward(self, im: torch.Tensor):
|
52 |
+
face_boxes = self._detect_faces(im)
|
53 |
+
face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
|
54 |
+
return [face_boxes]
|
55 |
+
|
56 |
+
def load_from_cache(self, cache_path: Path):
|
57 |
+
logger.log(f"Loading detection from cache path: {cache_path}")
|
58 |
+
with lzma.open(cache_path, "rb") as fp:
|
59 |
+
state_dict = torch.load(fp)
|
60 |
+
return [
|
61 |
+
state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
|
62 |
+
]
|
dp2/detection/models/__init__.py
ADDED
File without changes
|
dp2/detection/models/cse.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import torch
|
2 |
+
from typing import List
|
3 |
+
import tops
|
4 |
+
from torchvision.transforms.functional import InterpolationMode, resize
|
5 |
+
from densepose.data.utils import get_class_to_mesh_name_mapping
|
6 |
+
from densepose import add_densepose_config
|
7 |
+
from densepose.structures import DensePoseEmbeddingPredictorOutput
|
8 |
+
from densepose.vis.extractor import DensePoseOutputsExtractor
|
9 |
+
from densepose.modeling import build_densepose_embedder
|
10 |
+
from detectron2.config import get_cfg
|
11 |
+
from detectron2.data.transforms import ResizeShortestEdge
|
12 |
+
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
|
13 |
+
from detectron2.modeling import build_model
|
14 |
+
|
15 |
+
|
16 |
+
model_urls = {
|
17 |
+
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
|
18 |
+
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
|
23 |
+
assert len(S.shape) == 3
|
24 |
+
H, W = imshape
|
25 |
+
N = len(boxes_XYXY)
|
26 |
+
segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
|
27 |
+
boxes_XYXY = boxes_XYXY.long()
|
28 |
+
for i in range(N):
|
29 |
+
x0, y0, x1, y1 = boxes_XYXY[i]
|
30 |
+
assert x0 >= 0 and y0 >= 0
|
31 |
+
assert x1 <= imshape[1]
|
32 |
+
assert y1 <= imshape[0]
|
33 |
+
h = y1 - y0
|
34 |
+
w = x1 - x0
|
35 |
+
segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
|
36 |
+
return segmentation
|
37 |
+
|
38 |
+
|
39 |
+
class CSEDetector:
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
|
44 |
+
cfg_2_download: List[str] = [
|
45 |
+
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
|
46 |
+
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
|
47 |
+
"https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
|
48 |
+
score_thres: float = 0.9,
|
49 |
+
nms_thresh: float = None,
|
50 |
+
) -> None:
|
51 |
+
with tops.logger.capture_log_stdout():
|
52 |
+
cfg = get_cfg()
|
53 |
+
self.device = tops.get_device()
|
54 |
+
add_densepose_config(cfg)
|
55 |
+
cfg_path = tops.download_file(cfg_url)
|
56 |
+
for p in cfg_2_download:
|
57 |
+
tops.download_file(p)
|
58 |
+
with tops.logger.capture_log_stdout():
|
59 |
+
cfg.merge_from_file(cfg_path)
|
60 |
+
assert cfg_url in model_urls, cfg_url
|
61 |
+
model_path = tops.download_file(model_urls[cfg_url])
|
62 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
|
63 |
+
if nms_thresh is not None:
|
64 |
+
cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
|
65 |
+
cfg.MODEL.WEIGHTS = str(model_path)
|
66 |
+
cfg.MODEL.DEVICE = str(self.device)
|
67 |
+
cfg.freeze()
|
68 |
+
with tops.logger.capture_log_stdout():
|
69 |
+
self.model = build_model(cfg)
|
70 |
+
self.model.eval()
|
71 |
+
DetectionCheckpointer(self.model).load(str(model_path))
|
72 |
+
self.input_format = cfg.INPUT.FORMAT
|
73 |
+
self.densepose_extractor = DensePoseOutputsExtractor()
|
74 |
+
self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
|
75 |
+
|
76 |
+
self.embedder = build_densepose_embedder(cfg)
|
77 |
+
self.mesh_vertex_embeddings = {
|
78 |
+
mesh_name: self.embedder(mesh_name).to(self.device)
|
79 |
+
for mesh_name in self.class_to_mesh_name.values()
|
80 |
+
if self.embedder.has_embeddings(mesh_name)
|
81 |
+
}
|
82 |
+
self.cfg = cfg
|
83 |
+
self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
|
84 |
+
tops.logger.log("CSEDetector built.")
|
85 |
+
|
86 |
+
def __call__(self, *args, **kwargs):
|
87 |
+
return self.forward(*args, **kwargs)
|
88 |
+
|
89 |
+
def resize_im(self, im):
|
90 |
+
H, W = im.shape[1:]
|
91 |
+
newH, newW = ResizeShortestEdge.get_output_shape(
|
92 |
+
H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
|
93 |
+
return resize(
|
94 |
+
im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def forward(self, im):
|
98 |
+
assert im.dtype == torch.uint8
|
99 |
+
if self.input_format == "BGR":
|
100 |
+
im = im.flip(0)
|
101 |
+
H, W = im.shape[1:]
|
102 |
+
im = self.resize_im(im)
|
103 |
+
output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
|
104 |
+
scores = output.get("scores")
|
105 |
+
if len(scores) == 0:
|
106 |
+
return dict(
|
107 |
+
instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
|
108 |
+
instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
|
109 |
+
embed_map=self.mesh_vertex_embeddings["smpl_27554"],
|
110 |
+
bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
|
111 |
+
im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
|
112 |
+
scores=torch.empty((0), dtype=torch.float, device=im.device)
|
113 |
+
)
|
114 |
+
pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
|
115 |
+
assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
|
116 |
+
S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes)
|
117 |
+
E = pred_densepose.embedding
|
118 |
+
mesh_name = self.class_to_mesh_name[classes[0]]
|
119 |
+
assert mesh_name == "smpl_27554"
|
120 |
+
x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
|
121 |
+
boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
|
122 |
+
boxes_XYXY = boxes_XYXY.round_().long()
|
123 |
+
|
124 |
+
non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
|
125 |
+
S = S[non_empty_boxes]
|
126 |
+
E = E[non_empty_boxes]
|
127 |
+
boxes_XYXY = boxes_XYXY[non_empty_boxes]
|
128 |
+
scores = scores[non_empty_boxes]
|
129 |
+
im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
|
130 |
+
return dict(
|
131 |
+
instance_segmentation=S, instance_embedding=E,
|
132 |
+
bbox_XYXY=boxes_XYXY,
|
133 |
+
im_segmentation=im_segmentation,
|
134 |
+
scores=scores.view(-1))
|
135 |
+
|
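cse_det_to_global pastes each 112x112 coarse segmentation back into a full-resolution boolean mask via nearest-neighbour resizing into its box. A sketch with dummy data (the CSEDetector itself needs detectron2 and densepose installed):

import torch
from dp2.detection.models.cse import cse_det_to_global

boxes = torch.tensor([[10, 20, 74, 148]])                # one 64x128 person box
S = (torch.rand(1, 112, 112) > 0.5).float()              # dummy coarse segmentation
masks = cse_det_to_global(boxes, S, imshape=(256, 256))
# masks: [1, 256, 256] bool, non-zero only inside the box.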
dp2/detection/models/keypoint_maskrcnn.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
4 |
+
from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
|
5 |
+
from detectron2.data.transforms import ResizeShortestEdge
|
6 |
+
from detectron2.structures import Instances
|
7 |
+
from detectron2 import model_zoo
|
8 |
+
from detectron2.config import instantiate
|
9 |
+
from detectron2.config import LazyCall as L
|
10 |
+
from PIL import Image
|
11 |
+
import tops
|
12 |
+
import functools
|
13 |
+
from torchvision.transforms.functional import resize
|
14 |
+
|
15 |
+
|
16 |
+
def get_rn50_fpn_keypoint_rcnn(weight_path: str):
|
17 |
+
from detectron2.modeling.poolers import ROIPooler
|
18 |
+
from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
|
19 |
+
from detectron2.layers import ShapeSpec
|
20 |
+
model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
|
21 |
+
model.roi_heads.update(
|
22 |
+
num_classes=1,
|
23 |
+
keypoint_in_features=["p2", "p3", "p4", "p5"],
|
24 |
+
keypoint_pooler=L(ROIPooler)(
|
25 |
+
output_size=14,
|
26 |
+
scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
|
27 |
+
sampling_ratio=0,
|
28 |
+
pooler_type="ROIAlignV2",
|
29 |
+
),
|
30 |
+
keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
|
31 |
+
input_shape=ShapeSpec(channels=256, width=14, height=14),
|
32 |
+
num_keypoints=17,
|
33 |
+
conv_dims=[512] * 8,
|
34 |
+
loss_normalizer="visible",
|
35 |
+
),
|
36 |
+
)
|
37 |
+
|
38 |
+
# Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
|
39 |
+
# 1000 proposals per-image is found to hurt box AP.
|
40 |
+
# Therefore we increase it to 1500 per-image.
|
41 |
+
model.proposal_generator.post_nms_topk = (1500, 1000)
|
42 |
+
|
43 |
+
# Keypoint AP degrades (though box AP improves) when using plain L1 loss
|
44 |
+
model.roi_heads.box_predictor.smooth_l1_beta = 0.5
|
45 |
+
model = instantiate(model)
|
46 |
+
|
47 |
+
dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
|
48 |
+
test_transform = instantiate(dataloader.test.mapper.augmentations)
|
49 |
+
DetectionCheckpointer(model).load(weight_path)
|
50 |
+
return model, test_transform
|
51 |
+
|
52 |
+
|
53 |
+
models = {
|
54 |
+
"rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
class KeypointMaskRCNN:
|
61 |
+
|
62 |
+
def __init__(self, model_name: str, score_threshold: float) -> None:
|
63 |
+
assert model_name in models, f"Did not find {model_name} in models"
|
64 |
+
model, test_transform = models[model_name]()
|
65 |
+
self.model = model.eval().to(tops.get_device())
|
66 |
+
if isinstance(self.model.roi_heads, CascadeROIHeads):
|
67 |
+
for head in self.model.roi_heads.box_predictors:
|
68 |
+
assert hasattr(head, "test_score_thresh")
|
69 |
+
head.test_score_thresh = score_threshold
|
70 |
+
else:
|
71 |
+
assert isinstance(self.model.roi_heads, StandardROIHeads)
|
72 |
+
assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
|
73 |
+
self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
|
74 |
+
|
75 |
+
self.test_transform = test_transform
|
76 |
+
assert len(self.test_transform) == 1
|
77 |
+
self.test_transform = self.test_transform[0]
|
78 |
+
assert isinstance(self.test_transform, ResizeShortestEdge)
|
79 |
+
assert self.test_transform.interp == Image.BILINEAR
|
80 |
+
self.image_format = self.model.input_format
|
81 |
+
|
82 |
+
def resize_im(self, im):
|
83 |
+
H, W = im.shape[-2:]
|
84 |
+
if self.test_transform.is_range:
|
85 |
+
size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
|
86 |
+
else:
|
87 |
+
size = np.random.choice(self.test_transform.short_edge_length)
|
88 |
+
newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
|
89 |
+
return resize(
|
90 |
+
im, (newH, newW), antialias=True)
|
91 |
+
|
92 |
+
def __call__(self, *args, **kwargs):
|
93 |
+
return self.forward(*args, **kwargs)
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
def forward(self, im: torch.Tensor) -> Instances:
|
97 |
+
assert im.ndim == 3
|
98 |
+
if self.image_format == "BGR":
|
99 |
+
im = im.flip(0)
|
100 |
+
H, W = im.shape[-2:]
|
101 |
+
im = self.resize_im(im)
|
102 |
+
im = im.float()
|
103 |
+
inputs = dict(image=im, height=H, width=W)
|
104 |
+
# instances contains
|
105 |
+
# dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
|
106 |
+
instances = self.model([inputs])[0]["instances"]
|
107 |
+
return dict(
|
108 |
+
scores=instances.get("scores").cpu(),
|
109 |
+
segmentation=instances.get("pred_masks").cpu(),
|
110 |
+
keypoints=instances.get("pred_keypoints").cpu()
|
111 |
+
)
|
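A hedged usage sketch of the keypoint model; the checkpoint is pulled from the URL registered in models, and the output keys follow the forward() above:

import torch
from dp2.detection.models.keypoint_maskrcnn import KeypointMaskRCNN

model = KeypointMaskRCNN("rn50_fpn_maskrcnn", score_threshold=0.5)
im = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
out = model(im)
# out["keypoints"]: [N, 17, 3] COCO keypoints; out["segmentation"]: [N, 480, 640] boolean masks.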
dp2/detection/models/mask_rcnn.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import torch
|
2 |
+
import tops
|
3 |
+
from detectron2.modeling import build_model
|
4 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
5 |
+
from detectron2.structures import Boxes
|
6 |
+
from detectron2.data import MetadataCatalog
|
7 |
+
from detectron2 import model_zoo
|
8 |
+
from typing import Dict
|
9 |
+
from detectron2.data.transforms import ResizeShortestEdge
|
10 |
+
from torchvision.transforms.functional import resize
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
model_urls = {
|
15 |
+
"COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
|
16 |
+
|
17 |
+
}
|
18 |
+
class MaskRCNNDetector:
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
|
23 |
+
score_thres: float = 0.9,
|
24 |
+
class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"]
|
25 |
+
fp16_inference: bool = False
|
26 |
+
) -> None:
|
27 |
+
cfg = model_zoo.get_config(cfg_name)
|
28 |
+
cfg.MODEL.DEVICE = str(tops.get_device())
|
29 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
|
30 |
+
cfg.freeze()
|
31 |
+
self.cfg = cfg
|
32 |
+
with tops.logger.capture_log_stdout():
|
33 |
+
self.model = build_model(cfg)
|
34 |
+
DetectionCheckpointer(self.model).load(model_urls[cfg_name])
|
35 |
+
self.model.eval()
|
36 |
+
self.input_format = cfg.INPUT.FORMAT
|
37 |
+
self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
|
38 |
+
self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
|
39 |
+
self.person_class = self.class_names.index("person")
|
40 |
+
self.fp16_inference = fp16_inference
|
41 |
+
tops.logger.log("Mask R-CNN built.")
|
42 |
+
|
43 |
+
def __call__(self, *args, **kwargs):
|
44 |
+
return self.forward(*args, **kwargs)
|
45 |
+
|
46 |
+
def resize_im(self, im):
|
47 |
+
H, W = im.shape[1:]
|
48 |
+
newH, newW = ResizeShortestEdge.get_output_shape(
|
49 |
+
H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
|
50 |
+
return resize(
|
51 |
+
im, (newH, newW), antialias=True)
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def forward(self, im: torch.Tensor):
|
55 |
+
if self.input_format == "BGR":
|
56 |
+
im = im.flip(0)
|
57 |
+
else:
|
58 |
+
assert self.input_format == "RGB"
|
59 |
+
H, W = im.shape[-2:]
|
60 |
+
im = self.resize_im(im)
|
61 |
+
with torch.cuda.amp.autocast(enabled=self.fp16_inference):
|
62 |
+
output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
|
63 |
+
scores = output.get("scores")
|
64 |
+
N = len(scores)
|
65 |
+
classes = output.get("pred_classes")
|
66 |
+
idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
|
67 |
+
classes = classes[idx2keep]
|
68 |
+
assert isinstance(output.get("pred_boxes"), Boxes)
|
69 |
+
segmentation = output.get("pred_masks")[idx2keep]
|
70 |
+
assert segmentation.dtype == torch.bool
|
71 |
+
is_person = classes == self.person_class
|
72 |
+
return {
|
73 |
+
"scores": output.get("scores")[idx2keep],
|
74 |
+
"segmentation": segmentation,
|
75 |
+
"classes": output.get("pred_classes")[idx2keep],
|
76 |
+
"is_person": is_person
|
77 |
+
}
|
78 |
+
|
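A similar sketch for the MaskRCNNDetector above; the image path is a placeholder and the remaining arguments keep their defaults:

import torchvision
from dp2.detection.models.mask_rcnn import MaskRCNNDetector

detector = MaskRCNNDetector(score_thres=0.9, class_filter=["person"])
im = torchvision.io.read_image("street.jpg")   # uint8 RGB tensor, shape (3, H, W)
dets = detector(im)
# All outputs are aligned along the detection axis, so boolean indexing works directly:
person_masks = dets["segmentation"][dets["is_person"]]
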
dp2/detection/person_detector.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import torch
|
2 |
+
import lzma
|
3 |
+
from dp2.detection.base import BaseDetector
|
4 |
+
from .utils import combine_cse_maskrcnn_dets
|
5 |
+
from .models.cse import CSEDetector
|
6 |
+
from .models.mask_rcnn import MaskRCNNDetector
|
7 |
+
from .models.keypoint_maskrcnn import KeypointMaskRCNN
|
8 |
+
from .structures import CSEPersonDetection, PersonDetection
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
|
12 |
+
class CSEPersonDetector(BaseDetector):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
score_threshold: float,
|
16 |
+
mask_rcnn_cfg: dict,
|
17 |
+
cse_cfg: dict,
|
18 |
+
cse_post_process_cfg: dict,
|
19 |
+
**kwargs
|
20 |
+
) -> None:
|
21 |
+
super().__init__(**kwargs)
|
22 |
+
self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
|
23 |
+
self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
|
24 |
+
self.post_process_cfg = cse_post_process_cfg
|
25 |
+
self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")
|
26 |
+
|
27 |
+
def __call__(self, *args, **kwargs):
|
28 |
+
return self.forward(*args, **kwargs)
|
29 |
+
|
30 |
+
def load_from_cache(self, cache_path: Path):
|
31 |
+
with lzma.open(cache_path, "rb") as fp:
|
32 |
+
state_dict = torch.load(fp)
|
33 |
+
kwargs = dict(
|
34 |
+
post_process_cfg=self.post_process_cfg,
|
35 |
+
embed_map=self.cse_detector.embed_map,
|
36 |
+
)
|
37 |
+
return [
|
38 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
39 |
+
for state in state_dict
|
40 |
+
]
|
41 |
+
|
42 |
+
@torch.no_grad()
|
43 |
+
def forward(self, im: torch.Tensor, cse_dets=None):
|
44 |
+
mask_dets = self.mask_rcnn(im)
|
45 |
+
if cse_dets is None:
|
46 |
+
cse_dets = self.cse_detector(im)
|
47 |
+
segmentation = mask_dets["segmentation"]
|
48 |
+
segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
|
49 |
+
segmentation, cse_dets, self.iou_combine_threshold
|
50 |
+
)
|
51 |
+
det = CSEPersonDetection(
|
52 |
+
segmentation=segmentation,
|
53 |
+
cse_dets=cse_dets,
|
54 |
+
embed_map=self.cse_detector.embed_map,
|
55 |
+
orig_imshape_CHW=im.shape,
|
56 |
+
**self.post_process_cfg
|
57 |
+
)
|
58 |
+
return [det]
|
59 |
+
|
60 |
+
|
61 |
+
class MaskRCNNPersonDetector(BaseDetector):
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
score_threshold: float,
|
65 |
+
mask_rcnn_cfg: dict,
|
66 |
+
cse_post_process_cfg: dict,
|
67 |
+
**kwargs
|
68 |
+
) -> None:
|
69 |
+
super().__init__(**kwargs)
|
70 |
+
self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
|
71 |
+
self.post_process_cfg = cse_post_process_cfg
|
72 |
+
|
73 |
+
def __call__(self, *args, **kwargs):
|
74 |
+
return self.forward(*args, **kwargs)
|
75 |
+
|
76 |
+
def load_from_cache(self, cache_path: Path):
|
77 |
+
with lzma.open(cache_path, "rb") as fp:
|
78 |
+
state_dict = torch.load(fp)
|
79 |
+
kwargs = dict(
|
80 |
+
post_process_cfg=self.post_process_cfg,
|
81 |
+
)
|
82 |
+
return [
|
83 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
84 |
+
for state in state_dict
|
85 |
+
]
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def forward(self, im: torch.Tensor):
|
89 |
+
mask_dets = self.mask_rcnn(im)
|
90 |
+
segmentation = mask_dets["segmentation"]
|
91 |
+
det = PersonDetection(
|
92 |
+
segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
|
93 |
+
)
|
94 |
+
return [det]
|
95 |
+
|
96 |
+
|
97 |
+
class KeypointMaskRCNNPersonDetector(BaseDetector):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
score_threshold: float,
|
101 |
+
mask_rcnn_cfg: dict,
|
102 |
+
cse_post_process_cfg: dict,
|
103 |
+
**kwargs
|
104 |
+
) -> None:
|
105 |
+
super().__init__(**kwargs)
|
106 |
+
self.mask_rcnn = KeypointMaskRCNN(
|
107 |
+
**mask_rcnn_cfg, score_threshold=score_threshold
|
108 |
+
)
|
109 |
+
self.post_process_cfg = cse_post_process_cfg
|
110 |
+
|
111 |
+
def __call__(self, *args, **kwargs):
|
112 |
+
return self.forward(*args, **kwargs)
|
113 |
+
|
114 |
+
def load_from_cache(self, cache_path: Path):
|
115 |
+
with lzma.open(cache_path, "rb") as fp:
|
116 |
+
state_dict = torch.load(fp)
|
117 |
+
kwargs = dict(
|
118 |
+
post_process_cfg=self.post_process_cfg,
|
119 |
+
)
|
120 |
+
return [
|
121 |
+
state["cls"].from_state_dict(**kwargs, state_dict=state)
|
122 |
+
for state in state_dict
|
123 |
+
]
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def forward(self, im: torch.Tensor):
|
127 |
+
mask_dets = self.mask_rcnn(im)
|
128 |
+
segmentation = mask_dets["segmentation"]
|
129 |
+
det = PersonDetection(
|
130 |
+
segmentation,
|
131 |
+
**self.post_process_cfg,
|
132 |
+
orig_imshape_CHW=im.shape,
|
133 |
+
keypoints=mask_dets["keypoints"]
|
134 |
+
)
|
135 |
+
return [det]
|
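All three person detectors above expose the same interface: forward() returns a list containing a single detection structure, which is lazily pre-processed and yields per-person crops. A minimal sketch of that downstream loop, assuming a detector instance and a CHW image tensor already exist:

import torch

@torch.no_grad()
def get_person_crops(detector, im: torch.Tensor):
    crops = []
    for detection in detector(im):          # forward() returns a list of detection structures
        detection.pre_process()             # lazily builds masks, boxes and (optional) keypoints
        for idx in range(len(detection)):
            crops.append(detection.get_crop(idx, im))  # dict with img, mask, boxes, ...
    return crops
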
dp2/detection/structures.py
ADDED
@@ -0,0 +1,463 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from dp2 import utils
|
4 |
+
from dp2.utils import vis_utils, crop_box
|
5 |
+
from .utils import (
|
6 |
+
cut_pad_resize, masks_to_boxes,
|
7 |
+
get_kernel, transform_embedding, initialize_cse_boxes
|
8 |
+
)
|
9 |
+
from .box_utils import get_expanded_bbox, include_box
|
10 |
+
import torchvision
|
11 |
+
import tops
|
12 |
+
from .box_utils_fdf import expand_bbox as expand_bbox_fdf
|
13 |
+
|
14 |
+
|
15 |
+
class VehicleDetection:
|
16 |
+
|
17 |
+
def __init__(self, segmentation: torch.BoolTensor) -> None:
|
18 |
+
self.segmentation = segmentation
|
19 |
+
self.boxes = masks_to_boxes(segmentation)
|
20 |
+
assert self.boxes.shape[1] == 4, self.boxes.shape
|
21 |
+
self.n_detections = self.segmentation.shape[0]
|
22 |
+
area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
|
23 |
+
|
24 |
+
sorted_idx = torch.argsort(area, descending=True)
|
25 |
+
self.segmentation = self.segmentation[sorted_idx]
|
26 |
+
self.boxes = self.boxes[sorted_idx].cpu()
|
27 |
+
|
28 |
+
def pre_process(self):
|
29 |
+
pass
|
30 |
+
|
31 |
+
def get_crop(self, idx: int, im):
|
32 |
+
assert idx < len(self)
|
33 |
+
box = self.boxes[idx]
|
34 |
+
im = crop_box(self.im, box)
|
35 |
+
mask = crop_box(self.segmentation[idx])
|
36 |
+
mask = mask == 0
|
37 |
+
return dict(img=im, mask=mask.float(), boxes=box)
|
38 |
+
|
39 |
+
def visualize(self, im):
|
40 |
+
if len(self) == 0:
|
41 |
+
return im
|
42 |
+
im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
|
43 |
+
return im
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return self.n_detections
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def from_state_dict(state_dict, **kwargs):
|
50 |
+
numel = np.prod(state_dict["shape"])
|
51 |
+
arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
|
52 |
+
segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
|
53 |
+
return VehicleDetection(segmentation)
|
54 |
+
|
55 |
+
def state_dict(self, **kwargs):
|
56 |
+
segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
|
57 |
+
return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
|
58 |
+
|
59 |
+
|
60 |
+
class FaceDetection:
|
61 |
+
|
62 |
+
def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
|
63 |
+
self.boxes = boxes_ltrb.cpu()
|
64 |
+
assert self.boxes.shape[1] == 4, self.boxes.shape
|
65 |
+
self.target_imsize = tuple(target_imsize)
|
66 |
+
# Sort by area so the largest faces are pasted in last
|
67 |
+
area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
|
68 |
+
idx = area.argsort(descending=False)
|
69 |
+
self.boxes = self.boxes[idx]
|
70 |
+
self.fdf128_expand = fdf128_expand
|
71 |
+
|
72 |
+
def visualize(self, im):
|
73 |
+
if len(self) == 0:
|
74 |
+
return im
|
75 |
+
orig_device = im.device
|
76 |
+
for box in self.boxes:
|
77 |
+
simple_expand = not self.fdf128_expand
|
78 |
+
e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
|
79 |
+
im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
|
80 |
+
im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
|
81 |
+
|
82 |
+
return im.to(device=orig_device)
|
83 |
+
|
84 |
+
def get_crop(self, idx: int, im):
|
85 |
+
assert idx < len(self)
|
86 |
+
box = self.boxes[idx].numpy()
|
87 |
+
expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], True)
|
88 |
+
im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
|
89 |
+
area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
|
90 |
+
|
91 |
+
# Find the square mask corresponding to box.
|
92 |
+
box_mask = box.copy().astype(float)
|
93 |
+
box_mask[[0, 2]] -= expanded_boxes[0]
|
94 |
+
box_mask[[1, 3]] -= expanded_boxes[1]
|
95 |
+
|
96 |
+
width = expanded_boxes[2] - expanded_boxes[0]
|
97 |
+
resize_factor = self.target_imsize[0] / width
|
98 |
+
box_mask = (box_mask * resize_factor).astype(int)
|
99 |
+
mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
|
100 |
+
crop_box(mask, box_mask).fill_(0)
|
101 |
+
return dict(
|
102 |
+
img=im[None], mask=mask[None],
|
103 |
+
boxes=torch.from_numpy(expanded_boxes).view(1, -1))
|
104 |
+
|
105 |
+
def __len__(self):
|
106 |
+
return len(self.boxes)
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def from_state_dict(state_dict, **kwargs):
|
110 |
+
return FaceDetection(state_dict["boxes"].cpu(), **kwargs)
|
111 |
+
|
112 |
+
def state_dict(self, **kwargs):
|
113 |
+
return dict(boxes=self.boxes, cls=self.__class__)
|
114 |
+
|
115 |
+
def pre_process(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
|
119 |
+
def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
|
120 |
+
"""
|
121 |
+
Dilation happens after padding, which could place dilation in the padded area.
|
122 |
+
Remove this.
|
123 |
+
"""
|
124 |
+
x0, y0, x1, y1 = exp_box
|
125 |
+
H, W = orig_imshape
|
126 |
+
# Padding in original image space
|
127 |
+
p_y0 = max(0, -y0)
|
128 |
+
p_y1 = max(y1 - H, 0)
|
129 |
+
p_x0 = max(0, -x0)
|
130 |
+
p_x1 = max(x1 - W, 0)
|
131 |
+
resize_ratio = mask.shape[-2] / (y1-y0)
|
132 |
+
p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
|
133 |
+
mask[..., :p_y0, :] = 0
|
134 |
+
mask[..., :p_x0] = 0
|
135 |
+
mask[..., mask.shape[-2] - p_y1:, :] = 0
|
136 |
+
mask[..., mask.shape[-1] - p_x1:] = 0
|
137 |
+
|
138 |
+
|
139 |
+
class CSEPersonDetection:
|
140 |
+
|
141 |
+
def __init__(self,
|
142 |
+
segmentation, cse_dets,
|
143 |
+
target_imsize,
|
144 |
+
exp_bbox_cfg, exp_bbox_filter,
|
145 |
+
dilation_percentage: float,
|
146 |
+
embed_map: torch.Tensor,
|
147 |
+
orig_imshape_CHW,
|
148 |
+
normalize_embedding: bool) -> None:
|
149 |
+
self.segmentation = segmentation
|
150 |
+
self.cse_dets = cse_dets
|
151 |
+
self.target_imsize = list(target_imsize)
|
152 |
+
self.pre_processed = False
|
153 |
+
self.exp_bbox_cfg = exp_bbox_cfg
|
154 |
+
self.exp_bbox_filter = exp_bbox_filter
|
155 |
+
self.dilation_percentage = dilation_percentage
|
156 |
+
self.embed_map = embed_map
|
157 |
+
self.normalize_embedding = normalize_embedding
|
158 |
+
if self.normalize_embedding:
|
159 |
+
embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
|
160 |
+
embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
|
161 |
+
self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
|
162 |
+
self.orig_imshape_CHW = orig_imshape_CHW
|
163 |
+
|
164 |
+
@torch.no_grad()
|
165 |
+
def pre_process(self):
|
166 |
+
if self.pre_processed:
|
167 |
+
return
|
168 |
+
boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
|
169 |
+
expanded_boxes = []
|
170 |
+
included_boxes = []
|
171 |
+
for i in range(len(boxes)):
|
172 |
+
exp_box = get_expanded_bbox(
|
173 |
+
boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
|
174 |
+
target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
|
175 |
+
if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
|
176 |
+
continue
|
177 |
+
included_boxes.append(i)
|
178 |
+
expanded_boxes.append(exp_box)
|
179 |
+
expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
|
180 |
+
self.segmentation = self.segmentation[included_boxes]
|
181 |
+
self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
|
182 |
+
|
183 |
+
self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
|
184 |
+
area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
|
185 |
+
for i, box in enumerate(expanded_boxes):
|
186 |
+
self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
|
187 |
+
|
188 |
+
dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
|
189 |
+
self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
|
190 |
+
self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
|
191 |
+
[remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
|
192 |
+
self.boxes = expanded_boxes.cpu()
|
193 |
+
self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
|
194 |
+
|
195 |
+
self.pre_processed = True
|
196 |
+
self.n_detections = len(self.boxes)
|
197 |
+
self.mask = self.mask.logical_not()
|
198 |
+
|
199 |
+
E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
|
200 |
+
self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
|
201 |
+
for i in range(self.n_detections):
|
202 |
+
E_, E_mask[i] = transform_embedding(
|
203 |
+
self.cse_dets["instance_embedding"][i],
|
204 |
+
self.cse_dets["instance_segmentation"][i],
|
205 |
+
self.boxes[i],
|
206 |
+
self.cse_dets["bbox_XYXY"][i].cpu(),
|
207 |
+
self.target_imsize
|
208 |
+
)
|
209 |
+
self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
|
210 |
+
self.E_mask = E_mask
|
211 |
+
|
212 |
+
sorted_idx = torch.argsort(area, descending=False)
|
213 |
+
self.mask = self.mask[sorted_idx]
|
214 |
+
self.boxes = self.boxes[sorted_idx.cpu()]
|
215 |
+
self.vertices = self.vertices[sorted_idx]
|
216 |
+
self.E_mask = self.E_mask[sorted_idx]
|
217 |
+
self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
|
218 |
+
|
219 |
+
def get_crop(self, idx: int, im):
|
220 |
+
self.pre_process()
|
221 |
+
assert idx < len(self)
|
222 |
+
box = self.boxes[idx]
|
223 |
+
mask = self.mask[idx]
|
224 |
+
im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
|
225 |
+
|
226 |
+
vertices_ = self.vertices[idx]
|
227 |
+
E_mask_ = self.E_mask[idx].float()
|
228 |
+
if self.normalize_embedding:
|
229 |
+
embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
|
230 |
+
else:
|
231 |
+
embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
|
232 |
+
|
233 |
+
return dict(
|
234 |
+
img=im,
|
235 |
+
mask=mask.float()[None],
|
236 |
+
boxes=box.reshape(1, -1),
|
237 |
+
E_mask=E_mask_[None],
|
238 |
+
vertices=vertices_[None],
|
239 |
+
embed_map=self.embed_map,
|
240 |
+
embedding=embedding[None],
|
241 |
+
maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
|
242 |
+
)
|
243 |
+
|
244 |
+
def __len__(self):
|
245 |
+
self.pre_process()
|
246 |
+
return self.n_detections
|
247 |
+
|
248 |
+
def state_dict(self, after_preprocess=False):
|
249 |
+
"""
|
250 |
+
The processed annotations occupy more space than the original detections.
|
251 |
+
"""
|
252 |
+
if not after_preprocess:
|
253 |
+
return {
|
254 |
+
"combined_segmentation": self.segmentation.bool(),
|
255 |
+
"cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
|
256 |
+
"cse_instance_embedding": self.cse_dets["instance_embedding"],
|
257 |
+
"cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
|
258 |
+
"cls": self.__class__,
|
259 |
+
"orig_imshape_CHW": self.orig_imshape_CHW
|
260 |
+
}
|
261 |
+
self.pre_process()
|
262 |
+
return dict(
|
263 |
+
E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())),
|
264 |
+
mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())),
|
265 |
+
maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())),
|
266 |
+
vertices=self.vertices.to(torch.int16).cpu(),
|
267 |
+
cls=self.__class__,
|
268 |
+
boxes=self.boxes,
|
269 |
+
orig_imshape_CHW=self.orig_imshape_CHW,
|
270 |
+
)
|
271 |
+
|
272 |
+
@staticmethod
|
273 |
+
def from_state_dict(
|
274 |
+
state_dict, embed_map,
|
275 |
+
post_process_cfg, **kwargs):
|
276 |
+
after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
|
277 |
+
if after_preprocess:
|
278 |
+
detection = CSEPersonDetection(
|
279 |
+
segmentation=None, cse_dets=None, embed_map=embed_map,
|
280 |
+
orig_imshape_CHW=state_dict["orig_imshape_CHW"],
|
281 |
+
**post_process_cfg)
|
282 |
+
detection.vertices = tops.to_cuda(state_dict["vertices"].long())
|
283 |
+
numel = np.prod(detection.vertices.shape)
|
284 |
+
detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
|
285 |
+
detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape)
|
286 |
+
detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
|
287 |
+
detection.n_detections = len(detection.mask)
|
288 |
+
detection.pre_processed = True
|
289 |
+
|
290 |
+
if isinstance(state_dict["boxes"], np.ndarray):
|
291 |
+
state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
|
292 |
+
detection.boxes = state_dict["boxes"]
|
293 |
+
return detection
|
294 |
+
|
295 |
+
cse_dets = dict(
|
296 |
+
instance_segmentation=state_dict["cse_instance_segmentation"],
|
297 |
+
instance_embedding=state_dict["cse_instance_embedding"],
|
298 |
+
embed_map=embed_map,
|
299 |
+
bbox_XYXY=state_dict["cse_bbox_XYXY"])
|
300 |
+
cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
|
301 |
+
|
302 |
+
segmentation = state_dict["combined_segmentation"]
|
303 |
+
return CSEPersonDetection(
|
304 |
+
segmentation, cse_dets, embed_map=embed_map,
|
305 |
+
orig_imshape_CHW=state_dict["orig_imshape_CHW"],
|
306 |
+
**post_process_cfg)
|
307 |
+
|
308 |
+
def visualize(self, im):
|
309 |
+
self.pre_process()
|
310 |
+
if len(self) == 0:
|
311 |
+
return im
|
312 |
+
im = vis_utils.draw_cropped_masks(
|
313 |
+
im.clone(), self.mask, self.boxes, visualize_instances=False)
|
314 |
+
E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2)
|
315 |
+
im = im.to(E.device)
|
316 |
+
im = vis_utils.draw_cse_all(
|
317 |
+
E, self.E_mask.squeeze(1).bool(), im,
|
318 |
+
self.boxes, self.embed_map)
|
319 |
+
im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
|
320 |
+
return im
|
321 |
+
|
322 |
+
|
323 |
+
def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
|
324 |
+
keypoints = keypoints.clone()
|
325 |
+
N = boxes.shape[0]
|
326 |
+
tops.assert_shape(keypoints, (N, None, 3))
|
327 |
+
tops.assert_shape(boxes, (N, 4))
|
328 |
+
x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
|
329 |
+
|
330 |
+
w = x1 - x0
|
331 |
+
h = y1 - y0
|
332 |
+
keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
|
333 |
+
keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
|
334 |
+
check_outside = lambda x: (x < 0).logical_or(x > 1)
|
335 |
+
is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
|
336 |
+
keypoints[:, :, 2] = keypoints[:, :, 2] >= 0
|
337 |
+
keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
|
338 |
+
return keypoints
|
339 |
+
|
340 |
+
|
341 |
+
class PersonDetection:
|
342 |
+
|
343 |
+
def __init__(
|
344 |
+
self,
|
345 |
+
segmentation,
|
346 |
+
target_imsize,
|
347 |
+
exp_bbox_cfg, exp_bbox_filter,
|
348 |
+
dilation_percentage: float,
|
349 |
+
orig_imshape_CHW,
|
350 |
+
keypoints=None,
|
351 |
+
**kwargs) -> None:
|
352 |
+
self.segmentation = segmentation
|
353 |
+
self.target_imsize = list(target_imsize)
|
354 |
+
self.pre_processed = False
|
355 |
+
self.exp_bbox_cfg = exp_bbox_cfg
|
356 |
+
self.exp_bbox_filter = exp_bbox_filter
|
357 |
+
self.dilation_percentage = dilation_percentage
|
358 |
+
self.orig_imshape_CHW = orig_imshape_CHW
|
359 |
+
self.keypoints = keypoints
|
360 |
+
|
361 |
+
@torch.no_grad()
|
362 |
+
def pre_process(self):
|
363 |
+
if self.pre_processed:
|
364 |
+
return
|
365 |
+
boxes = masks_to_boxes(self.segmentation).cpu()
|
366 |
+
expanded_boxes = []
|
367 |
+
included_boxes = []
|
368 |
+
for i in range(len(boxes)):
|
369 |
+
exp_box = get_expanded_bbox(
|
370 |
+
boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
|
371 |
+
target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
|
372 |
+
if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
|
373 |
+
continue
|
374 |
+
included_boxes.append(i)
|
375 |
+
expanded_boxes.append(exp_box)
|
376 |
+
expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
|
377 |
+
self.segmentation = self.segmentation[included_boxes]
|
378 |
+
if self.keypoints is not None:
|
379 |
+
self.keypoints = self.keypoints[included_boxes]
|
380 |
+
area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
|
381 |
+
self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
|
382 |
+
for i, box in enumerate(expanded_boxes):
|
383 |
+
self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
|
384 |
+
if self.keypoints is not None:
|
385 |
+
self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
|
386 |
+
dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
|
387 |
+
self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
|
388 |
+
self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
|
389 |
+
|
390 |
+
[remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
|
391 |
+
self.boxes = expanded_boxes
|
392 |
+
self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
|
393 |
+
|
394 |
+
self.pre_processed = True
|
395 |
+
self.n_detections = len(self.boxes)
|
396 |
+
self.mask = self.mask.logical_not()
|
397 |
+
|
398 |
+
sorted_idx = torch.argsort(area, descending=False)
|
399 |
+
self.mask = self.mask[sorted_idx]
|
400 |
+
self.boxes = self.boxes[sorted_idx.cpu()]
|
401 |
+
self.segmentation = self.segmentation[sorted_idx]
|
402 |
+
self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
|
403 |
+
if self.keypoints is not None:
|
404 |
+
self.keypoints = self.keypoints[sorted_idx]
|
405 |
+
|
406 |
+
def get_crop(self, idx: int, im: torch.Tensor):
|
407 |
+
assert idx < len(self)
|
408 |
+
self.pre_process()
|
409 |
+
box = self.boxes[idx]
|
410 |
+
mask = self.mask[idx][None].float()
|
411 |
+
im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
|
412 |
+
batch = dict(
|
413 |
+
img=im, mask=mask, boxes=box.reshape(1, -1),
|
414 |
+
maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
|
415 |
+
if self.keypoints is not None:
|
416 |
+
batch["keypoints"] = self.keypoints[idx:idx+1]
|
417 |
+
return batch
|
418 |
+
|
419 |
+
def __len__(self):
|
420 |
+
self.pre_process()
|
421 |
+
return self.n_detections
|
422 |
+
|
423 |
+
def state_dict(self, **kwargs):
|
424 |
+
return dict(
|
425 |
+
segmentation=self.segmentation.bool(),
|
426 |
+
cls=self.__class__,
|
427 |
+
orig_imshape_CHW=self.orig_imshape_CHW,
|
428 |
+
keypoints=self.keypoints
|
429 |
+
)
|
430 |
+
|
431 |
+
@staticmethod
|
432 |
+
def from_state_dict(
|
433 |
+
state_dict,
|
434 |
+
post_process_cfg, **kwargs):
|
435 |
+
return PersonDetection(
|
436 |
+
state_dict["segmentation"],
|
437 |
+
orig_imshape_CHW=state_dict["orig_imshape_CHW"],
|
438 |
+
**post_process_cfg,
|
439 |
+
keypoints=state_dict["keypoints"])
|
440 |
+
|
441 |
+
def visualize(self, im):
|
442 |
+
self.pre_process()
|
443 |
+
im = im.cpu()
|
444 |
+
if len(self) == 0:
|
445 |
+
return im
|
446 |
+
im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False)
|
447 |
+
im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
|
448 |
+
return im
|
449 |
+
|
450 |
+
|
451 |
+
def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
|
452 |
+
"""
|
453 |
+
mask: resized mask
|
454 |
+
"""
|
455 |
+
assert exp_bbox.shape[0] == mask.shape[0]
|
456 |
+
boxes = masks_to_boxes(mask.squeeze(1)).cpu()
|
457 |
+
H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
|
458 |
+
boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
|
459 |
+
boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
|
460 |
+
boxes[:, [0, 2]] += exp_bbox[:, 0:1]
|
461 |
+
boxes[:, [1, 3]] += exp_bbox[:, 1:2]
|
462 |
+
return boxes
|
463 |
+
|
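The detection structures above serialize boolean masks with numpy bit-packing to keep cached detections small. A self-contained round-trip sketch of that scheme (the random mask is synthetic):

import numpy as np
import torch

segmentation = torch.rand(2, 64, 64) > 0.5                    # (N, H, W) boolean masks
packed = torch.from_numpy(np.packbits(segmentation.numpy()))  # ~8x smaller uint8 buffer
state = dict(segmentation=packed, shape=segmentation.shape)

numel = int(np.prod(state["shape"]))
arr = np.unpackbits(state["segmentation"].numpy(), count=numel)
restored = torch.from_numpy(arr).view(state["shape"]).bool()
assert torch.equal(restored, segmentation)
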
dp2/detection/utils.py
ADDED
@@ -0,0 +1,174 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import tops
|
5 |
+
from skimage.morphology import disk
|
6 |
+
from torchvision.transforms.functional import resize, InterpolationMode
|
7 |
+
from functools import lru_cache
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache(maxsize=200)
|
11 |
+
def get_kernel(n: int):
|
12 |
+
kernel = disk(n, dtype=bool)
|
13 |
+
return tops.to_cuda(torch.from_numpy(kernel).bool())
|
14 |
+
|
15 |
+
|
16 |
+
def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
|
17 |
+
"""
|
18 |
+
Transforms the detected embedding/mask directly to the target image shape
|
19 |
+
"""
|
20 |
+
|
21 |
+
C, HE, WE = E.shape
|
22 |
+
assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
|
23 |
+
assert E_bbox[2] >= exp_bbox[0]
|
24 |
+
assert E_bbox[1] >= exp_bbox[1]
|
25 |
+
assert E_bbox[3] >= exp_bbox[1]
|
26 |
+
assert E_bbox[2] <= exp_bbox[2]
|
27 |
+
assert E_bbox[3] <= exp_bbox[3]
|
28 |
+
|
29 |
+
x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
|
30 |
+
x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
|
31 |
+
y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
|
32 |
+
y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
|
33 |
+
new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
|
34 |
+
new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool)
|
35 |
+
|
36 |
+
E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
|
37 |
+
new_E[:, y0:y1, x0:x1] = E
|
38 |
+
S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
|
39 |
+
new_S[y0:y1, x0:x1] = S
|
40 |
+
return new_E, new_S
|
41 |
+
|
42 |
+
|
43 |
+
def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
|
44 |
+
"""
|
45 |
+
mask: shape [N, H, W]
|
46 |
+
"""
|
47 |
+
assert len(mask1.shape) == 3
|
48 |
+
assert len(mask2.shape) == 3
|
49 |
+
assert mask1.device == mask2.device, (mask1.device, mask2.device)
|
50 |
+
assert mask1.dtype == mask2.dtype
|
51 |
+
assert mask1.dtype == torch.bool
|
52 |
+
assert mask1.shape[1:] == mask2.shape[1:]
|
53 |
+
N1, H1, W1 = mask1.shape
|
54 |
+
N2, H2, W2 = mask2.shape
|
55 |
+
iou = torch.zeros((N1, N2), dtype=torch.float32)
|
56 |
+
for i in range(N1):
|
57 |
+
cur = mask1[i:i+1]
|
58 |
+
inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
|
59 |
+
union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
|
60 |
+
iou[i] = inter / union
|
61 |
+
return iou
|
62 |
+
|
63 |
+
|
64 |
+
def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
|
65 |
+
N1 = mask1.shape[0]
|
66 |
+
N2 = mask2.shape[0]
|
67 |
+
ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
|
68 |
+
indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
|
69 |
+
ious = ious.flatten()
|
70 |
+
mask = ious >= iou_threshold
|
71 |
+
ious = ious[mask]
|
72 |
+
indices = indices[mask]
|
73 |
+
|
74 |
+
# Do not sort by IoU, to preserve the original Mask R-CNN / CSE ordering.
|
75 |
+
taken1 = np.zeros((N1), dtype=bool)
|
76 |
+
taken2 = np.zeros((N2), dtype=bool)
|
77 |
+
matches = []
|
78 |
+
for i, j in indices:
|
79 |
+
if taken1[i].any() or taken2[j].any():
|
80 |
+
continue
|
81 |
+
matches.append((i, j))
|
82 |
+
taken1[i] = True
|
83 |
+
taken2[j] = True
|
84 |
+
return matches
|
85 |
+
|
86 |
+
|
87 |
+
def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
|
88 |
+
assert 0 < iou_threshold <= 1
|
89 |
+
matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
|
90 |
+
H, W = segmentation.shape[1:]
|
91 |
+
new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
|
92 |
+
cse_im_seg = cse_dets["im_segmentation"]
|
93 |
+
for idx, (i, j) in enumerate(matches):
|
94 |
+
new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
|
95 |
+
cse_dets = dict(
|
96 |
+
instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]],
|
97 |
+
instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]],
|
98 |
+
bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]],
|
99 |
+
scores=cse_dets["scores"][[j for (i, j) in matches]],
|
100 |
+
)
|
101 |
+
return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
|
102 |
+
|
103 |
+
|
104 |
+
def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
|
105 |
+
"""
|
106 |
+
cse_boxes can be outside of segmentation.
|
107 |
+
"""
|
108 |
+
boxes = masks_to_boxes(segmentation)
|
109 |
+
|
110 |
+
assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
|
111 |
+
combined = torch.stack((boxes, cse_boxes), dim=-1)
|
112 |
+
boxes = torch.cat((
|
113 |
+
combined[:, :2].min(dim=2).values,
|
114 |
+
combined[:, 2:].max(dim=2).values,
|
115 |
+
), dim=1)
|
116 |
+
return boxes
|
117 |
+
|
118 |
+
|
119 |
+
def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
|
120 |
+
"""
|
121 |
+
Crops or pads x to fit in the bbox and resize to target shape.
|
122 |
+
"""
|
123 |
+
C, H, W = x.shape
|
124 |
+
x0, y0, x1, y1 = bbox
|
125 |
+
|
126 |
+
if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
|
127 |
+
new_x = x[:, y0:y1, x0:x1]
|
128 |
+
else:
|
129 |
+
new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device)
|
130 |
+
y0_t = max(0, -y0)
|
131 |
+
y1_t = min(y1-y0, (y1-y0)-(y1-H))
|
132 |
+
x0_t = max(0, -x0)
|
133 |
+
x1_t = min(x1-x0, (x1-x0)-(x1-W))
|
134 |
+
x0 = max(0, x0)
|
135 |
+
y0 = max(0, y0)
|
136 |
+
x1 = min(x1, W)
|
137 |
+
y1 = min(y1, H)
|
138 |
+
new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
|
139 |
+
if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
|
140 |
+
return new_x
|
141 |
+
if x.dtype == torch.bool:
|
142 |
+
new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
|
143 |
+
elif x.dtype == torch.float32:
|
144 |
+
new_x = resize(new_x, target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True)
|
145 |
+
elif x.dtype == torch.uint8:
|
146 |
+
if fdf_resize: # FDF dataset is created with cv2 INTER_AREA.
|
147 |
+
# Incorrect resizing produces noticeably poorer inpaintings.
|
148 |
+
upsampling = ((y1-y0) *(x1-x0)) < (target_shape[0] * target_shape[1])
|
149 |
+
if upsampling:
|
150 |
+
new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, antialias=True).round().clamp(0, 255).byte()
|
151 |
+
else:
|
152 |
+
device = new_x.device
|
153 |
+
new_x = new_x.permute(1, 2, 0).cpu().numpy()
|
154 |
+
new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
|
155 |
+
new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
|
156 |
+
else:
|
157 |
+
new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True).round().clamp(0, 255).byte()
|
158 |
+
else:
|
159 |
+
raise ValueError(f"Not supported dtype: {x.dtype}")
|
160 |
+
return new_x
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
def masks_to_boxes(segmentation: torch.Tensor):
|
165 |
+
assert len(segmentation.shape) == 3
|
166 |
+
x = segmentation.any(dim=1).byte() # Compress rows
|
167 |
+
x0 = x.argmax(dim=1)
|
168 |
+
|
169 |
+
x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
|
170 |
+
y = segmentation.any(dim=2).byte()
|
171 |
+
y0 = y.argmax(dim=1)
|
172 |
+
y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
|
173 |
+
return torch.stack([x0, y0, x1, y1], dim=1)
|
174 |
+
|
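A tiny worked example of masks_to_boxes above: boxes are returned as (x0, y0, x1, y1) with exclusive right/bottom edges.

import torch
from dp2.detection.utils import masks_to_boxes

mask = torch.zeros(1, 8, 8, dtype=torch.bool)
mask[0, 2:5, 3:7] = True            # rows 2..4, columns 3..6
print(masks_to_boxes(mask))         # tensor([[3, 2, 7, 5]])
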
dp2/discriminator/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .sg2_discriminator import SG2Discriminator
|
dp2/discriminator/sg2_discriminator.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
from sg3_torch_utils.ops import upfirdn2d
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn as nn
|
5 |
+
from .. import layers
|
6 |
+
from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block
|
7 |
+
|
8 |
+
|
9 |
+
class SG2Discriminator(layers.Module):
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
cnum: int,
|
14 |
+
max_cnum_mul: int,
|
15 |
+
imsize,
|
16 |
+
min_fmap_resolution: int,
|
17 |
+
im_channels: int,
|
18 |
+
input_condition: bool,
|
19 |
+
conv_clamp: int,
|
20 |
+
input_cse: bool,
|
21 |
+
cse_nc: int):
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
cse_nc = 0 if cse_nc is None else cse_nc
|
25 |
+
self._max_imsize = max(imsize)
|
26 |
+
self._cnum = cnum
|
27 |
+
self._max_cnum_mul = max_cnum_mul
|
28 |
+
self._min_fmap_resolution = min_fmap_resolution
|
29 |
+
self._input_condition = input_condition
|
30 |
+
self.input_cse = input_cse
|
31 |
+
self.layers = nn.ModuleList()
|
32 |
+
|
33 |
+
out_ch = self.get_chsize(self._max_imsize)
|
34 |
+
self.from_rgb = Block(
|
35 |
+
im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1),
|
36 |
+
out_ch, conv_clamp=conv_clamp
|
37 |
+
)
|
38 |
+
n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
|
39 |
+
|
40 |
+
for i in range(n_levels):
|
41 |
+
resolution = [x//2**i for x in imsize]
|
42 |
+
in_ch = self.get_chsize(max(resolution))
|
43 |
+
out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution))
|
44 |
+
|
45 |
+
down = 2
|
46 |
+
if i == 0:
|
47 |
+
down = 1
|
48 |
+
block = ResidualBlock(
|
49 |
+
in_ch, out_ch, down=down, conv_clamp=conv_clamp
|
50 |
+
)
|
51 |
+
self.layers.append(block)
|
52 |
+
self.output_layer = DiscriminatorEpilogue(
|
53 |
+
out_ch, resolution, conv_clamp=conv_clamp)
|
54 |
+
|
55 |
+
self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
|
56 |
+
|
57 |
+
def forward(self, img, condition, mask, embedding=None, E_mask=None, **kwargs):
|
58 |
+
to_cat = [img]
|
59 |
+
if self._input_condition:
|
60 |
+
to_cat.extend([condition, mask,])
|
61 |
+
if self.input_cse:
|
62 |
+
to_cat.extend([embedding, E_mask])
|
63 |
+
x = torch.cat(to_cat, dim=1)
|
64 |
+
x = self.from_rgb(x)
|
65 |
+
|
66 |
+
for i, layer in enumerate(self.layers):
|
67 |
+
x = layer(x)
|
68 |
+
|
69 |
+
x = self.output_layer(x)
|
70 |
+
return dict(score=x)
|
71 |
+
|
72 |
+
def get_chsize(self, imsize):
|
73 |
+
n = int(np.log2(self._max_imsize) - np.log2(imsize))
|
74 |
+
mul = min(2 ** n, self._max_cnum_mul)
|
75 |
+
ch = self._cnum * mul
|
76 |
+
return int(ch)
|
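The get_chsize method above implements the usual StyleGAN2-style channel schedule: the channel count doubles each time the resolution halves, capped at cnum * max_cnum_mul. A standalone sketch with illustrative (assumed) hyperparameters:

import numpy as np

def channel_schedule(cnum=64, max_cnum_mul=8, max_imsize=256, min_fmap=8):
    schedule = []
    s = max_imsize
    while s >= min_fmap:
        mul = min(2 ** int(np.log2(max_imsize) - np.log2(s)), max_cnum_mul)
        schedule.append((s, int(cnum * mul)))
        s //= 2
    return schedule

print(channel_schedule())
# [(256, 64), (128, 128), (64, 256), (32, 512), (16, 512), (8, 512)]
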
dp2/gan_trainer.py
ADDED
@@ -0,0 +1,324 @@
1 |
+
import atexit
|
2 |
+
from collections import defaultdict
|
3 |
+
import logging
|
4 |
+
import typing
|
5 |
+
import torch
|
6 |
+
import time
|
7 |
+
from dp2.utils import vis_utils
|
8 |
+
from dp2 import utils
|
9 |
+
from tops import logger, checkpointer
|
10 |
+
import tops
|
11 |
+
from easydict import EasyDict
|
12 |
+
|
13 |
+
|
14 |
+
def accumulate_gradients(params, fp16_ddp_accumulate):
|
15 |
+
if len(params) == 0:
|
16 |
+
return
|
17 |
+
params = [param for param in params if param.grad is not None]
|
18 |
+
flat = torch.cat([param.grad.flatten() for param in params])
|
19 |
+
orig_dtype = flat.dtype
|
20 |
+
if tops.world_size() > 1:
|
21 |
+
if fp16_ddp_accumulate:
|
22 |
+
flat = flat.half() / tops.world_size()
|
23 |
+
else:
|
24 |
+
flat /= tops.world_size()
|
25 |
+
torch.distributed.all_reduce(flat)
|
26 |
+
flat = flat.to(orig_dtype)
|
27 |
+
grads = flat.split([param.numel() for param in params])
|
28 |
+
for param, grad in zip(params, grads):
|
29 |
+
param.grad = grad.reshape(param.shape)
|
30 |
+
|
31 |
+
|
32 |
+
def accumulate_buffers(module: torch.nn.Module):
|
33 |
+
buffers = [buf for buf in module.buffers()]
|
34 |
+
if len(buffers) == 0:
|
35 |
+
return
|
36 |
+
flat = torch.cat([buf.flatten() for buf in buffers])
|
37 |
+
if tops.world_size() > 1:
|
38 |
+
torch.distributed.all_reduce(flat)
|
39 |
+
flat /= tops.world_size()
|
40 |
+
bufs = flat.split([buf.numel() for buf in buffers])
|
41 |
+
for old, new in zip(buffers, bufs):
|
42 |
+
old.copy_(new.reshape(old.shape), non_blocking=True)
|
43 |
+
|
44 |
+
|
45 |
+
def check_ddp_consistency(module):
|
46 |
+
if tops.world_size() == 1:
|
47 |
+
return
|
48 |
+
assert isinstance(module, torch.nn.Module)
|
49 |
+
assert isinstance(module, torch.nn.Module)
|
50 |
+
params_buffs = list(module.named_parameters()) + list(module.named_buffers())
|
51 |
+
for name, tensor in params_buffs:
|
52 |
+
fullname = type(module).__name__ + '.' + name
|
53 |
+
tensor = tensor.detach()
|
54 |
+
if tensor.is_floating_point():
|
55 |
+
tensor = torch.nan_to_num(tensor)
|
56 |
+
other = tensor.clone()
|
57 |
+
torch.distributed.broadcast(tensor=other, src=0)
|
58 |
+
assert (tensor == other).all(), fullname
|
59 |
+
|
60 |
+
class AverageMeter():
|
61 |
+
def __init__(self) -> None:
|
62 |
+
self.to_log = dict()
|
63 |
+
self.n = defaultdict(int)
|
64 |
+
pass
|
65 |
+
|
66 |
+
@torch.no_grad()
|
67 |
+
def update(self, values: dict):
|
68 |
+
for key, value in values.items():
|
69 |
+
self.n[key] += 1
|
70 |
+
if key in self.to_log:
|
71 |
+
self.to_log[key] += value.mean().detach()
|
72 |
+
else:
|
73 |
+
self.to_log[key] = value.mean().detach()
|
74 |
+
|
75 |
+
def get_average(self):
|
76 |
+
return {key: value / self.n[key] for key, value in self.to_log.items()}
|
77 |
+
|
78 |
+
|
79 |
+
class GANTrainer:
|
80 |
+
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
G: torch.nn.Module,
|
84 |
+
D: torch.nn.Module,
|
85 |
+
G_EMA: torch.nn.Module,
|
86 |
+
D_optim: torch.optim.Optimizer,
|
87 |
+
G_optim: torch.optim.Optimizer,
|
88 |
+
dl_train: typing.Iterator,
|
89 |
+
dl_val: typing.Iterable,
|
90 |
+
scaler_D: torch.cuda.amp.GradScaler,
|
91 |
+
scaler_G: torch.cuda.amp.GradScaler,
|
92 |
+
ims_per_log: int,
|
93 |
+
max_images_to_train: int,
|
94 |
+
loss_handler,
|
95 |
+
ims_per_val: int,
|
96 |
+
evaluate_fn,
|
97 |
+
batch_size: int,
|
98 |
+
broadcast_buffers: bool,
|
99 |
+
fp16_ddp_accumulate: bool,
|
100 |
+
save_state: bool,
|
101 |
+
*args, **kwargs):
|
102 |
+
super().__init__(*args, **kwargs)
|
103 |
+
|
104 |
+
self.G = G
|
105 |
+
self.D = D
|
106 |
+
self.G_EMA = G_EMA
|
107 |
+
self.D_optim = D_optim
|
108 |
+
self.G_optim = G_optim
|
109 |
+
self.dl_train = dl_train
|
110 |
+
self.dl_val = dl_val
|
111 |
+
self.scaler_D = scaler_D
|
112 |
+
self.scaler_G = scaler_G
|
113 |
+
self.loss_handler = loss_handler
|
114 |
+
self.max_images_to_train = max_images_to_train
|
115 |
+
self.images_per_val = ims_per_val
|
116 |
+
self.images_per_log = ims_per_log
|
117 |
+
self.evaluate_fn = evaluate_fn
|
118 |
+
self.batch_size = batch_size
|
119 |
+
self.broadcast_buffers = broadcast_buffers
|
120 |
+
self.fp16_ddp_accumulate = fp16_ddp_accumulate
|
121 |
+
|
122 |
+
self.train_state = EasyDict(
|
123 |
+
next_log_step=0,
|
124 |
+
next_val_step=ims_per_val,
|
125 |
+
total_time=0
|
126 |
+
)
|
127 |
+
|
128 |
+
checkpointer.register_models(dict(
|
129 |
+
generator=G, discriminator=D, EMA_generator=G_EMA,
|
130 |
+
D_optimizer=D_optim,
|
131 |
+
G_optimizer=G_optim,
|
132 |
+
train_state=self.train_state,
|
133 |
+
scaler_D=self.scaler_D,
|
134 |
+
scaler_G=self.scaler_G
|
135 |
+
))
|
136 |
+
if checkpointer.has_checkpoint():
|
137 |
+
checkpointer.load_registered_models()
|
138 |
+
logger.log(f"Resuming training from: global step: {logger.global_step()}")
|
139 |
+
else:
|
140 |
+
logger.add_dict({
|
141 |
+
"stats/discriminator_parameters": tops.num_parameters(self.D),
|
142 |
+
"stats/generator_parameters": tops.num_parameters(self.G),
|
143 |
+
}, commit=False)
|
144 |
+
if save_state:
|
145 |
+
# If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint.
|
146 |
+
atexit.register(checkpointer.save_registered_models)
|
147 |
+
|
148 |
+
self._ims_per_log = ims_per_log
|
149 |
+
|
150 |
+
self.to_log = AverageMeter()
|
151 |
+
self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
|
152 |
+
self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
|
153 |
+
logger.add_dict({
|
154 |
+
"stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
|
155 |
+
"stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
|
156 |
+
}, commit=False, level=logging.INFO)
|
157 |
+
check_ddp_consistency(self.D)
|
158 |
+
check_ddp_consistency(self.G)
|
159 |
+
check_ddp_consistency(self.G_EMA.generator)
|
160 |
+
|
161 |
+
def train_loop(self):
|
162 |
+
self.log_time()
|
163 |
+
while logger.global_step() <= self.max_images_to_train:
|
164 |
+
batch = next(self.dl_train)
|
165 |
+
self.G_EMA.update_beta()
|
166 |
+
self.to_log.update(self.step_D(batch))
|
167 |
+
self.to_log.update(self.step_G(batch))
|
168 |
+
self.G_EMA.update(self.G)
|
169 |
+
|
170 |
+
if logger.global_step() >= self.train_state.next_log_step:
|
171 |
+
to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
|
172 |
+
to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
|
173 |
+
to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
|
174 |
+
self.to_log = AverageMeter()
|
175 |
+
logger.add_dict(to_log, commit=True)
|
176 |
+
self.train_state.next_log_step += self.images_per_log
|
177 |
+
if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
|
178 |
+
print("Stopping training as gradient scale < 1e-8")
|
179 |
+
logger.log("Stopping training as gradient scale < 1e-8")
|
180 |
+
break
|
181 |
+
|
182 |
+
if logger.global_step() >= self.train_state.next_val_step:
|
183 |
+
self.evaluate()
|
184 |
+
self.log_time()
|
185 |
+
self.save_images()
|
186 |
+
self.train_state.next_val_step += self.images_per_val
|
187 |
+
logger.step(self.batch_size*tops.world_size())
|
188 |
+
logger.log(f"Reached end of training at step {logger.global_step()}.")
|
189 |
+
checkpointer.save_registered_models()
|
190 |
+
|
191 |
+
def estimate_ims_per_hour(self):
|
192 |
+
batch = next(self.dl_train)
|
193 |
+
n_ims = int(100e3)
|
194 |
+
n_steps = int(n_ims / (self.batch_size * tops.world_size()))
|
195 |
+
n_ims = n_steps * self.batch_size * tops.world_size()
|
196 |
+
for i in range(10): # Warmup
|
197 |
+
self.G_EMA.update_beta()
|
198 |
+
self.step_D(batch)
|
199 |
+
self.step_G(batch)
|
200 |
+
self.G_EMA.update(self.G)
|
201 |
+
start_time = time.time()
|
202 |
+
for i in utils.tqdm_(list(range(n_steps))):
|
203 |
+
self.G_EMA.update_beta()
|
204 |
+
self.step_D(batch)
|
205 |
+
self.step_G(batch)
|
206 |
+
self.G_EMA.update(self.G)
|
207 |
+
total_time = time.time() - start_time
|
208 |
+
ims_per_sec = n_ims / total_time
|
209 |
+
ims_per_hour = ims_per_sec * 60*60
|
210 |
+
ims_per_day = ims_per_hour * 24
|
211 |
+
logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
|
212 |
+
logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
|
213 |
+
import math
|
214 |
+
ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
|
215 |
+
logger.log(f"Images per 4 days: {ims_per_4_day}")
|
216 |
+
logger.add_dict({
|
217 |
+
"stats/ims_per_day": ims_per_day,
|
218 |
+
"stats/ims_per_4_day": ims_per_4_day
|
219 |
+
})
|
220 |
+
|
221 |
+
def log_time(self):
|
222 |
+
if not hasattr(self, "start_time"):
|
223 |
+
self.start_time = time.time()
|
224 |
+
self.last_time_step = logger.global_step()
|
225 |
+
return
|
226 |
+
n_images = logger.global_step() - self.last_time_step
|
227 |
+
if n_images == 0:
|
228 |
+
return
|
229 |
+
n_secs = time.time() - self.start_time
|
230 |
+
n_ims_per_sec = n_images / n_secs
|
231 |
+
training_time_hours = n_secs / 60 / 60
|
232 |
+
self.train_state.total_time += training_time_hours
|
233 |
+
remaining_images = self.max_images_to_train - logger.global_step()
|
234 |
+
remaining_time = remaining_images / n_ims_per_sec / 60 / 60
|
235 |
+
logger.add_dict({
|
236 |
+
"stats/n_ims_per_sec": n_ims_per_sec,
|
237 |
+
"stats/total_traing_time_hours": self.train_state.total_time,
|
238 |
+
"stats/remaining_time_hours": remaining_time
|
239 |
+
})
|
240 |
+
self.last_time_step = logger.global_step()
|
241 |
+
self.start_time = time.time()
|
242 |
+
|
243 |
+
def save_images(self):
|
244 |
+
dl_val = iter(self.dl_val)
|
245 |
+
batch = next(dl_val)
|
246 |
+
# TRUNCATED visualization
|
247 |
+
ims_to_log = 8
|
248 |
+
self.G_EMA.eval()
|
249 |
+
z = self.G.get_z(batch["img"])
|
250 |
+
fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
|
251 |
+
fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
|
252 |
+
if "__key__" in batch:
|
253 |
+
batch.pop("__key__")
|
254 |
+
real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
|
255 |
+
to_vis = torch.cat((real, fakes_truncated))
|
256 |
+
logger.add_images("images/truncated", to_vis, nrow=2)
|
257 |
+
|
258 |
+
# Diverse images
|
259 |
+
ims_diverse = 3
|
260 |
+
batch = next(dl_val)
|
261 |
+
to_vis = []
|
262 |
+
|
263 |
+
for i in range(ims_diverse):
|
264 |
+
z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
|
265 |
+
fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
|
266 |
+
to_vis.append(fakes)
|
267 |
+
if "__key__" in batch:
|
268 |
+
batch.pop("__key__")
|
269 |
+
reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
|
270 |
+
to_vis.insert(0, reals)
|
271 |
+
to_vis = torch.cat(to_vis)
|
272 |
+
logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1)
|
273 |
+
|
274 |
+
self.G_EMA.train()
|
275 |
+
pass
|
276 |
+
|
277 |
+
def evaluate(self):
|
278 |
+
logger.log("Stating evaluation.")
|
279 |
+
self.G_EMA.eval()
|
280 |
+
try:
|
281 |
+
checkpointer.save_registered_models(max_keep=3)
|
282 |
+
except Exception:
|
283 |
+
logger.log("Could not save checkpoint.")
|
284 |
+
if self.broadcast_buffers:
|
285 |
+
check_ddp_consistency(self.G)
|
286 |
+
check_ddp_consistency(self.D)
|
287 |
+
metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val)
|
288 |
+
metrics = {f"metrics/{k}": v for k,v in metrics.items()}
|
289 |
+
logger.add_dict(metrics, level=logger.logger.INFO)
|
290 |
+
|
291 |
+
def step_D(self, batch):
|
292 |
+
utils.set_requires_grad(self.trainable_params_D, True)
|
293 |
+
utils.set_requires_grad(self.trainable_params_G, False)
|
294 |
+
tops.zero_grad(self.D)
|
295 |
+
loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
|
296 |
+
with torch.autograd.profiler.record_function("D_step"):
|
297 |
+
self.scaler_D.scale(loss).backward()
|
298 |
+
accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
|
299 |
+
if self.broadcast_buffers:
|
300 |
+
accumulate_buffers(self.D)
|
301 |
+
accumulate_buffers(self.G)
|
302 |
+
# Step will not unscale if unscale is called previously.
|
303 |
+
self.scaler_D.step(self.D_optim)
|
304 |
+
self.scaler_D.update()
|
305 |
+
utils.set_requires_grad(self.trainable_params_D, False)
|
306 |
+
utils.set_requires_grad(self.trainable_params_G, False)
|
307 |
+
return to_log
|
308 |
+
|
309 |
+
def step_G(self, batch):
|
310 |
+
utils.set_requires_grad(self.trainable_params_D, False)
|
311 |
+
utils.set_requires_grad(self.trainable_params_G, True)
|
312 |
+
tops.zero_grad(self.G)
|
313 |
+
loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
|
314 |
+
with torch.autograd.profiler.record_function("G_step"):
|
315 |
+
self.scaler_G.scale(loss).backward()
|
316 |
+
accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
|
317 |
+
if self.broadcast_buffers:
|
318 |
+
accumulate_buffers(self.G)
|
319 |
+
accumulate_buffers(self.D)
|
320 |
+
self.scaler_G.step(self.G_optim)
|
321 |
+
self.scaler_G.update()
|
322 |
+
utils.set_requires_grad(self.trainable_params_D, False)
|
323 |
+
utils.set_requires_grad(self.trainable_params_G, False)
|
324 |
+
return to_log
|
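A small illustration of the AverageMeter used for loss logging above; the values are synthetic and only show that each key is summed per update and divided by its update count on read:

import torch
from dp2.gan_trainer import AverageMeter

meter = AverageMeter()
meter.update({"loss/G": torch.tensor([1.0, 3.0])})   # stored as the batch mean, 2.0
meter.update({"loss/G": torch.tensor(4.0)})
print(meter.get_average())                           # {'loss/G': tensor(3.)}
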
dp2/generator/__init__.py
ADDED
File without changes
|
dp2/generator/base.py
ADDED
@@ -0,0 +1,144 @@
import torch
import numpy as np
import tqdm
import tops
from ..layers import Module
from ..layers.sg2_layers import FullyConnectedLayer
from dp2 import utils


class BaseGenerator(Module):

    def __init__(self, z_channels: int):
        super().__init__()
        self.z_channels = z_channels
        self.latent_space = "Z"

    @torch.no_grad()
    def get_z(
            self,
            x: torch.Tensor = None,
            z: torch.Tensor = None,
            truncation_value: float = None,
            batch_size: int = None,
            dtype=None, device=None) -> torch.Tensor:
        """Generates a latent variable for generator.
        """
        if z is not None:
            return z
        if x is not None:
            batch_size = x.shape[0]
            dtype = x.dtype
            device = x.device
        if device is None:
            device = utils.get_device()
        if truncation_value == 0:
            return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
        z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
        if truncation_value is None:
            return z
        while z.abs().max() > truncation_value:
            m = z.abs() > truncation_value
            z[m] = torch.rand_like(z)[m]
        return z

    def sample(self, truncation_value, z=None, **kwargs):
        """
        Samples via interpolating to the mean (0).
        """
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        if z is None:
            z = self.get_z(kwargs["condition"])
        z = z * truncation_value
        return self.forward(**kwargs, z=z)


class SG2StyleNet(torch.nn.Module):
    def __init__(self,
                 z_dim,                 # Input latent (Z) dimensionality.
                 w_dim,                 # Intermediate latent (W) dimensionality.
                 num_layers=2,          # Number of mapping layers.
                 lr_multiplier=0.01,    # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.998,      # Decay for tracking the moving average of W during training.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        # Construct layers.
        features = [self.z_dim] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, update_emas=False, y=None):
        tops.assert_shape(z, [None, self.z_dim])

        # Embed, normalize, and concatenate inputs.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)
        # Update moving average of W.
        if update_emas:
            self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'

    def update_w(self, n=int(10e3), batch_size=32):
        """
        Calculate w_ema over n iterations.
        Useful in cases where w_ema is calculated incorrectly during training.
        """
        n = n // batch_size
        for i in tqdm.trange(n, desc="Updating w"):
            z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
            self(z, update_emas=True)


class BaseStyleGAN(BaseGenerator):

    def __init__(self, z_channels: int, w_dim: int):
        super().__init__(z_channels)
        self.style_net = SG2StyleNet(z_channels, w_dim)
        self.latent_space = "W"

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)
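
The truncation used by sample() and multi_modal_truncate() above is a plain linear interpolation from a mean latent toward the sampled one. A minimal standalone sketch with made-up dimensions:

import torch

w_dim = 512
w_avg = torch.zeros(w_dim)        # EMA of W tracked during training
w = torch.randn(4, w_dim)         # mapped latents for a batch

truncation_value = 0.5            # 0 -> always w_avg, 1 -> untruncated w
w_trunc = w_avg.lerp(w, truncation_value)
assert torch.allclose(w_trunc, w_avg + truncation_value * (w - w_avg))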
dp2/generator/dummy_generators.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
from .base import BaseGenerator


class PixelationGenerator(BaseGenerator):

    def __init__(self, pixelation_size, **kwargs):
        super().__init__(z_channels=0)
        self.pixelation_size = pixelation_size
        self.z_channels = 0
        self.latent_space = None

    def forward(self, img, condition, mask, **kwargs):
        old_shape = img.shape[-2:]
        img = nn.functional.interpolate(img, size=(self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True)
        img = nn.functional.interpolate(img, size=old_shape, mode="bilinear", align_corners=True)
        out = img*(1-mask) + condition*mask
        return {"img": out}


class MaskOutGenerator(BaseGenerator):

    def __init__(self, noise: str, **kwargs):
        super().__init__(z_channels=0)
        self.noise = noise
        self.z_channels = 0
        assert self.noise in ["rand", "constant"]
        self.latent_space = None

    def forward(self, img, condition, mask, **kwargs):
        if self.noise == "constant":
            img = torch.zeros_like(img)
        elif self.noise == "rand":
            img = torch.rand_like(img)
        out = img*(1-mask) + condition*mask
        return {"img": out}


class IdentityGenerator(BaseGenerator):

    def __init__(self):
        super().__init__(z_channels=0)

    def forward(self, img, condition, mask, **kwargs):
        return dict(img=img)
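
A short usage sketch for the pixelation fallback above. The shapes, mask layout, and values are assumptions for illustration: mask == 1 marks pixels that are kept, mask == 0 the region to anonymize, and the import path follows this repository's layout.

import torch
from dp2.generator.dummy_generators import PixelationGenerator

img = torch.rand(1, 3, 64, 64) * 2 - 1    # image normalized to [-1, 1]
mask = torch.ones(1, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 0              # region to anonymize

G = PixelationGenerator(pixelation_size=8)
out = G(img=img, condition=img * mask, mask=mask)["img"]
print(out.shape)  # torch.Size([1, 3, 64, 64]); pixelated inside the masked-out box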
dp2/generator/imagen3_old.py
ADDED
@@ -0,0 +1,1210 @@
# What is missing from this implementation
# 1. Global context in res block
# 2. Cross attention of conditional information in resnet block
#
from functools import partial
import tops
from tops.config import instantiate
import warnings
from typing import Iterable, List, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange
from dp2 import infer, utils
from .base import BaseGenerator
from sg3_torch_utils.ops import bias_act
from dp2.layers import Sequential
import torch.nn.functional as F
from torchvision.transforms.functional import resize, InterpolationMode
from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d


class Upfirdn2d(torch.nn.Module):

    def __init__(self, down=1, up=1, fix_gain=True):
        super().__init__()
        self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1]))
        fw, fh = upfirdn2d._get_filter_size(self.resample_filter)
        px0, px1, py0, py1 = upfirdn2d._parse_padding(0)
        self.down = down
        self.up = up
        if up > 1:
            px0 += (fw + up - 1) // 2
            px1 += (fw - up) // 2
            py0 += (fh + up - 1) // 2
            py1 += (fh - up) // 2
        if down > 1:
            px0 += (fw - down + 1) // 2
            px1 += (fw - down) // 2
            py0 += (fh - down + 1) // 2
            py1 += (fh - down) // 2
        self.padding = [px0, px1, py0, py1]
        self.gain = up**2 if fix_gain else 1

    def forward(self, x, *args):
        if isinstance(x, dict):
            x = {k: v for k, v in x.items()}
            x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
            return x
        x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
        if len(args) == 0:
            return x
        return (x, *args)


@torch.no_grad()
def spatial_embed_keypoints(keypoints: torch.Tensor, x):
    tops.assert_shape(keypoints, (None, None, 3))
    B, N_K, _ = keypoints.shape
    H, W = x.shape[-2:]
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32)
    x, y, visible = keypoints.chunk(3, dim=2)
    x = (x * W).round().long().clamp(0, W-1)
    y = (y * H).round().long().clamp(0, H-1)
    kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
    pos = (kp_idx*(H*W) + y*W + x + 1)
    # Offset all by 1 to index invisible keypoints as 0
    pos = (pos * visible.round().long()).squeeze(dim=-1)
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
    keypoint_spatial.scatter_(1, pos, 1)
    keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
    return keypoint_spatial

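spatial_embed_keypoints above scatters each (x, y, visible) keypoint into its own one-hot spatial channel, with invisible keypoints routed to a dummy index that is sliced away. An equivalent, easier-to-read standalone sketch of the same idea using plain indexing instead of scatter_ (shapes are made up):

import torch

B, N_K, H, W = 2, 17, 32, 32
keypoints = torch.rand(B, N_K, 3)              # (x, y, visible), coordinates in [0, 1]
keypoints[..., 2] = (keypoints[..., 2] > 0.5).float()

embedding = torch.zeros(B, N_K, H, W)
x = (keypoints[..., 0] * W).round().long().clamp(0, W - 1)
y = (keypoints[..., 1] * H).round().long().clamp(0, H - 1)
for b in range(B):
    for k in range(N_K):
        if keypoints[b, k, 2] > 0:             # only visible keypoints get a spike
            embedding[b, k, y[b, k], x[b, k]] = 1
print(embedding.shape)  # torch.Size([2, 17, 32, 32])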

def modulated_conv2d(
    x,                      # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                 # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                 # Modulation coefficients of shape [batch_size, in_channels].
    noise=None,             # Optional noise tensor to add to the output activations.
    up=1,                   # Integer upsampling factor.
    down=1,                 # Integer downsampling factor.
    padding=0,              # Padding with respect to the upsampled image.
    resample_filter=None,   # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate=True,        # Apply weight demodulation?
    flip_weight=True,       # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv=True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    tops.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    tops.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    tops.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    with tops.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    # Execute as one fused op using grouped convolution.
    tops.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x

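The core of modulated_conv2d is: scale the convolution weight per sample with a style vector (modulation), then rescale so every output filter has roughly unit norm (demodulation). A standalone sketch of the non-fused path using only plain torch ops, assuming a square kernel and no up/down-sampling; all shapes are illustrative.

import torch
import torch.nn.functional as F

B, C_in, C_out, k, H, W = 2, 8, 16, 3, 32, 32
x = torch.randn(B, C_in, H, W)
weight = torch.randn(C_out, C_in, k, k)
styles = torch.randn(B, C_in)

# Modulate: scale each input channel of the weight by the per-sample style.
w = weight.unsqueeze(0) * styles.reshape(B, 1, C_in, 1, 1)       # [B, C_out, C_in, k, k]
# Demodulate: normalize each output filter to unit norm per sample.
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()          # [B, C_out]

# Non-fused execution: scale activations before/after a shared convolution.
y = F.conv2d(x * styles.reshape(B, C_in, 1, 1), weight, padding=k // 2)
y = y * dcoefs.reshape(B, C_out, 1, 1)
print(y.shape)  # torch.Size([2, 16, 32, 32])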

class Identity(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class LayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g


class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
                 in_features,           # Number of input features.
                 out_features,          # Number of output features.
                 bias=True,             # Apply additive bias before the activation function?
                 activation='linear',   # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=1,       # Learning rate multiplier.
                 bias_init=0,           # Initial value for the additive bias.
                 ):
        super().__init__()
        self.repr = dict(
            in_features=in_features, out_features=out_features, bias=bias,
            activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None:
            if self.bias_gain != 1:
                b = b * self.bias_gain
        x = F.linear(x, w)
        x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


def checkpoint_fn(fn, *args, **kwargs):
    warnings.simplefilter("ignore")
    return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)


class Conv2d(torch.nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            activation='lrelu',
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            bias=True,
            norm=None,
            lr_multiplier=1,
            bias_init=0,
            w_dim=None,
            gradient_checkpoint_norm=False,
            gain=1,
    ):
        super().__init__()
        self.fused_modconv = False
        if norm == torch.nn.InstanceNorm2d:
            self.norm = torch.nn.InstanceNorm2d(None)
        elif isinstance(norm, torch.nn.Module):
            self.norm = norm
        elif norm == "fused_modconv":
            self.fused_modconv = True
        elif norm:
            self.norm = torch.nn.InstanceNorm2d(None)
        elif norm is not None:
            raise ValueError(f"norm not supported: {norm}")
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.padding = kernel_size // 2
        self.repr = dict(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size,
            activation=activation, conv_clamp=conv_clamp, bias=bias,
            fused_modconv=self.fused_modconv
        )
        self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
        self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
        self.bias_gain = lr_multiplier
        if w_dim is not None:
            if self.fused_modconv:
                self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            else:
                self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
                self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
        self.gradient_checkpoint_norm = gradient_checkpoint_norm

    def forward(self, x, w=None, gain=1, **kwargs):
        if self.fused_modconv:
            styles = self.affine(w)
            with torch.cuda.amp.autocast(enabled=False):
                x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None,
                                     padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype)
        else:
            if hasattr(self, "affine"):
                gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
                beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
                x = fma.fma(x, gamma, beta)
            w = self.weight * self.weight_gain
            x = F.conv2d(input=x, weight=w, padding=self.padding)

        if hasattr(self, "norm"):
            if self.gradient_checkpoint_norm:
                x = checkpoint_fn(self.norm, x)
            else:
                x = self.norm(x)
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        b = self.bias * self.bias_gain if self.bias is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class CrossAttention(nn.Module):
    def __init__(
            self,
            dim,
            context_dim,
            dim_head=64,
            heads=8,
            norm_context=False,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.InstanceNorm1d(dim)
        self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.InstanceNorm1d(None)
        )

    def forward(self, x, w):
        x = self.norm(x)
        w = self.norm_context(w)

        q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim=-1))
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale
        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim=-1, dtype=torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class SG2ResidualBlock(torch.nn.Module):
    def __init__(
            self,
            in_channels,                # Number of input channels, 0 = first block.
            out_channels,               # Number of output channels.
            conv_clamp=None,            # Clamp the output of convolution layers to +-X, None = disable clamping.
            skip_gain=np.sqrt(.5),
            cross_attention: bool = False,
            cross_attention_len: int = None,
            use_adain: bool = True,
            **layer_kwargs,             # Arguments for conv layer.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None
        if use_adain:
            layer_kwargs["w_dim"] = w_dim

        self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain)

        self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain)
        if cross_attention and w_dim is not None:
            self.cross_attention_len = cross_attention_len
            self.cross_attn = CrossAttention(
                dim=out_channels, context_dim=w_dim//self.cross_attention_len,
                gain=skip_gain)

    def forward(self, x, w=None, **layer_kwargs):
        y = self.skip(x)
        x = self.conv0(x, w, **layer_kwargs)
        x = self.conv1(x, w, **layer_kwargs)
        if hasattr(self, "cross_attn"):
            h = x.shape[-2]
            x = rearrange(x, "b c h w -> b (h w) c")
            w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len)
            x = self.cross_attn(x, w=w) + x
            x = rearrange(x, "b (h w) c -> b c h w", h=h)
        return y + x


def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d


def cast_tuple(val, length=None):
    if isinstance(val, Iterable) and not isinstance(val, str):
        val = tuple(val)
    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
    if length is not None:
        assert len(output) == length, (output, length)
    return output


class Attention(nn.Module):
    # This is a version of Multi-Query Attention ()
    # Fast Transformer Decoding: One Write-Head is All You Need
    # Ablated in: https://arxiv.org/pdf/2203.07814.pdf
    # and https://arxiv.org/pdf/2204.02311.pdf
    def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None):
        super().__init__()
        self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
        self.gradient_checkpoint = gradient_checkpoint
        self.heads = heads
        self.dim = dim
        self.fix_attention_again = fix_attention_again
        inner_dim = dim_head * heads
        if norm == "LN":
            self.norm = LayerNorm(dim)
        elif norm == "IN":
            self.norm = nn.InstanceNorm1d(dim)
        elif norm is None:
            self.norm = nn.Identity()
        else:
            raise ValueError(f"Norm not supported: {norm}")

        self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False)
        self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False)

        self.to_out = nn.Sequential(
            FullyConnectedLayer(inner_dim, dim, bias=False),
            LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim)
        )
        if fix_attention_again:
            assert gain is not None
            self.gain = gain
        else:
            self.gain = np.sqrt(.5) if attn_fix_gain else 1

    def run_function(self, x, attn_bias):
        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        in_ = x
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale

        # calculate query / key similarities
        sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale

        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias

        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        out = rearrange(out, "b h n d -> b n (h d)")
        if self.fix_attention_again:
            out = self.to_out(out)*self.gain + in_
        else:
            out = (self.to_out(out) + in_) * self.gain
        out = rearrange(out, "b (h w) c -> b c h w", h=h)
        return out

    def forward(self, x, *args, attn_bias=None, **kwargs):
        if self.gradient_checkpoint:
            return checkpoint_fn(self.run_function, x, attn_bias)
        return self.run_function(x, attn_bias)

    def get_attention(self, x, attn_bias=None):
        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        in_ = x
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale

        # calculate query / key similarities
        sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale

        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias

        attn = sim.softmax(dim=-1)
        return attn, None


class BiasedAttention(Attention):

    def __init__(self, *args, head_wise: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        out_ch = self.heads if head_wise else 1
        self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0)
        nn.init.zeros_(self.conv.weight.data)

    def forward(self, x, mask):
        mask = resize(mask, size=x.shape[-2:])
        bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
        return super().forward(x=x, attn_bias=bias)

    def get_attention(self, x, mask):
        mask = resize(mask, size=x.shape[-2:])
        bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
        return super().get_attention(x, bias)[0], bias

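Attention above is multi-query: every head gets its own queries, but all heads share a single key/value head (to_kv projects to dim_head * 2, not heads * dim_head * 2), which is what the einsum pattern "b h i d, b j d -> b h i j" encodes. A standalone shape check with made-up sizes:

import torch
from einops import rearrange

B, N, dim_head, heads = 2, 64, 16, 8
q = torch.randn(B, heads, N, dim_head)   # per-head queries
k = torch.randn(B, N, dim_head)          # one shared key head
v = torch.randn(B, N, dim_head)          # one shared value head

sim = torch.einsum("b h i d, b j d -> b h i j", q, k) * dim_head ** -0.5
attn = sim.softmax(dim=-1)
out = torch.einsum("b h i j, b j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
print(out.shape)  # torch.Size([2, 64, 128])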

class UNet(BaseGenerator):

    def __init__(
            self,
            im_channels: int,
            dim: int,
            dim_mults: tuple,
            num_resnet_blocks,      # Number of resnet blocks per resolution
            n_middle_blocks: int,
            z_channels: int,
            conv_clamp: int,
            layer_attn,
            w_dim: int,
            norm_enc: bool,
            norm_dec: str,
            stylenet: nn.Module,
            enc_style: bool,        # Toggle style injection in encoder
            use_maskrcnn_mask: bool,
            skip_all_unets: bool,
            fix_resize: bool,
            comodulate: bool,
            comod_net: nn.Module,
            lr_comod: float,
            dec_style: bool,
            input_keypoints: bool,
            n_keypoints: int,
            input_keypoint_indices: Tuple[int],
            use_adain: bool,
            cross_attention: bool,
            cross_attention_len: int,
            gradient_checkpoint_norm: bool,
            attn_cls: partial,
            mask_out_train: bool,
            fix_gain_again: bool,
    ) -> None:
        super().__init__(z_channels)
        self.enc_style = enc_style
        self.n_keypoints = n_keypoints
        self.input_keypoint_indices = list(input_keypoint_indices)
        self.input_keypoints = input_keypoints
        self.mask_out_train = mask_out_train
        n_layers = len(dim_mults)
        self.n_layers = n_layers
        layer_attn = cast_tuple(layer_attn, n_layers)
        num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers)
        self._cnum = dim
        self._image_channels = im_channels
        self._z_channels = z_channels
        encoder_layers = []
        condition_ch = im_channels
        self.from_rgb = Conv2d(
            condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices),
            dim, 7)

        self.use_maskrcnn_mask = use_maskrcnn_mask
        self.skip_all_unets = skip_all_unets
        dims = [dim*m for m in dim_mults]
        enc_blk = partial(
            SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc,
            use_adain=use_adain and self.enc_style,
            w_dim=w_dim,
            cross_attention=cross_attention,
            cross_attention_len=cross_attention_len,
            gradient_checkpoint_norm=gradient_checkpoint_norm
        )
        dec_blk = partial(
            SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec,
            use_adain=use_adain and dec_style,
            w_dim=w_dim,
            cross_attention=cross_attention,
            cross_attention_len=cross_attention_len,
            gradient_checkpoint_norm=gradient_checkpoint_norm
        )
        # Currently up/down sampling is done by bilinear upsampling.
        # This can be simplified by replacing it with a strided upsampling layer...
        self.encoder_attns = nn.ModuleList()
        for lidx in range(n_layers):
            gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5)
            dim_in = dims[lidx]
            dim_out = dims[min(lidx+1, n_layers-1)]
            res_blocks = nn.ModuleList()
            for i in range(num_resnet_blocks[lidx]):
                is_last = num_resnet_blocks[lidx] - 1 == i
                cur_dim = dim_out if is_last else dim_in
                block = enc_blk(dim_in, cur_dim, skip_gain=gain)
                res_blocks.append(block)
            if layer_attn[lidx]:
                self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
            else:
                self.encoder_attns.append(Identity())
            encoder_layers.append(res_blocks)
        self.encoder = torch.nn.ModuleList(encoder_layers)

        # initialize decoder
        decoder_layers = []
        self.unet_layers = torch.nn.ModuleList()
        self.decoder_attns = torch.nn.ModuleList()
        for lidx in range(n_layers):
            dim_in = dims[min(-lidx, -1)]
            dim_out = dims[-1-lidx]
            res_blocks = nn.ModuleList()
            unet_skips = nn.ModuleList()
            for i in range(num_resnet_blocks[-lidx-1]):
                is_first = i == 0
                has_unet = is_first or skip_all_unets
                is_last = i == num_resnet_blocks[-lidx-1] - 1
                cur_dim = dim_in if is_first else dim_out
                if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again:  # x + residual + unet + layer attn
                    gain = np.sqrt(1/4)
                elif has_unet:  # x + residual + unet
                    gain = np.sqrt(1/3)
                elif layer_attn[-lidx-1] and fix_gain_again:  # x + residual + attention
                    gain = np.sqrt(1/3)
                else:  # x + residual
                    gain = np.sqrt(1/2)  # Only residual block
                block = dec_blk(cur_dim, dim_out, skip_gain=gain)
                res_blocks.append(block)
                if has_unet:
                    unet_block = Conv2d(
                        cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp,
                        norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=gradient_checkpoint_norm,
                        gain=gain)
                    unet_skips.append(unet_block)
                else:
                    unet_skips.append(torch.nn.Identity())
            if layer_attn[-lidx-1]:
                self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
            else:
                self.decoder_attns.append(Identity())

            decoder_layers.append(res_blocks)
            self.unet_layers.append(unet_skips)

        middle_blocks = []
        for i in range(n_middle_blocks):
            block = dec_blk(dims[-1], dims[-1])
            middle_blocks.append(block)
        if n_middle_blocks != 0:
            self.middle_blocks = Sequential(*middle_blocks)
        self.decoder = torch.nn.ModuleList(decoder_layers)
        self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
        self.stylenet = stylenet
        self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize)
        self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize)
        self.comodulate = comodulate
        if comodulate:
            assert not self.enc_style
            self.to_y = nn.Sequential(
                Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod)
            )
            self.comod_net = comod_net

    def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs):
        if z is None:
            z = self.get_z(condition)
        if w is None:
            w = self.stylenet(z, update_emas=update_emas)
        if self.use_maskrcnn_mask:
            x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
        else:
            x = torch.cat((condition, mask, 1-mask), dim=1)

        if self.input_keypoints:
            keypoints = keypoints[:, self.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        x, unet_features = self.forward_enc(x, mask, w)
        x, decoder_features = self.forward_dec(x, mask, w, unet_features)
        x = self.to_rgb(x)
        unmasked = x
        if self.mask_out_train:
            x = mask * condition + (1-mask) * x
        out = dict(img=x, unmasked=unmasked)
        if return_decoder_features:
            out["decoder_features"] = decoder_features
        return out

    def forward_enc(self, x, mask, w):
        unet_features = []
        for i, res_blocks in enumerate(self.encoder):
            is_last = i == len(self.encoder) - 1
            for block in res_blocks:
                x = block(x, w=w)
                unet_features.append(x)
            x = self.encoder_attns[i](x, mask=mask)
            if not is_last:
                x = self.downsample(x)
        if self.comodulate:
            y = self.to_y(x)
            y = torch.cat((w, y), dim=-1)
            w = self.comod_net(y)
        return x, unet_features

    def forward_dec(self, x, mask, w, unet_features):
        if hasattr(self, "middle_blocks"):
            x = self.middle_blocks(x, w=w)
        features = []
        unet_features = iter(reversed(unet_features))
        for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
            is_last = i == len(self.decoder) - 1
            for skip, block in zip(unet_skip, res_blocks):
                skip_x = next(unet_features)
                if not isinstance(skip, torch.nn.Identity):
                    skip_x = skip(skip_x)
                    x = x + skip_x
                x = block(x, w=w)
            x = self.decoder_attns[i](x, mask=mask)
            features.append(x)
            if not is_last:
                x = self.upsample(x)
        return x, features

    def get_w(self, z, update_emas):
        return self.stylenet(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    @property
    def style_net(self):
        return self.stylenet

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

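multi_modal_truncate above interpolates toward per-cluster centers of W (w_centers) rather than a single global mean, which preserves more diversity under strong truncation. A standalone sketch with made-up centers and dimensions:

import torch

n_centers, w_dim, batch = 16, 512, 4
w_centers = torch.randn(n_centers, w_dim)   # e.g. cluster centers of W, computed offline
w = torch.randn(batch, w_dim)

truncation_value = 0.3
idx = torch.randint(0, n_centers, (batch,))
w_trunc = w_centers[idx].lerp(w, truncation_value)
print(w_trunc.shape)  # torch.Size([4, 512])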

def get_stem_unet_kwargs(cfg):
    if "stem_cfg" in cfg.generator:  # If the stem has another stem, recursively apply get_stem_unet_kwargs
        return get_stem_unet_kwargs(cfg.generator.stem_cfg)
    return dict(cfg.generator)


class GrowingUnet(BaseGenerator):

    def __init__(
            self,
            coarse_stem_cfg: str,   # This can be a coarse generator or None
            sr_cfg: str,            # Can be a previous progressive u-net, Unet or None
            residual: bool,
            new_dataset: bool,      # The "new dataset" creates condition first -> resizes
            **unet_kwargs):
        kwargs = dict()
        if coarse_stem_cfg is not None:
            coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
            kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
        if sr_cfg is not None:
            sr_cfg = utils.load_config(sr_cfg)
            sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg)
            kwargs.update(sr_stem_unet_kwargs)
        kwargs.update(unet_kwargs)
        kwargs["stylenet"] = None
        kwargs.pop("_target_")
        if "sr_cfg" in kwargs:  # Unet kwargs are inherited, do not pass this to the new u-net
            del kwargs["sr_cfg"]
        if "coarse_stem_cfg" in kwargs:
            del kwargs["coarse_stem_cfg"]
        super().__init__(z_channels=kwargs["z_channels"])
        if coarse_stem_cfg is not None:
            z_channels = coarse_stem_cfg.generator.z_channels
            super().__init__(z_channels)
            self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
            self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
            utils.set_requires_grad(self.coarse_stem, False)
        else:
            assert not residual

        if sr_cfg is not None:
            self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval()
            del self.sr_stem.from_rgb
            del self.sr_stem.to_rgb
            if hasattr(self.sr_stem, "coarse_stem"):
                del self.sr_stem.coarse_stem
            if isinstance(self.sr_stem, UNet):
                del self.sr_stem.encoder[0][0]  # Delete first residual block
                del self.sr_stem.decoder[-1][-1]  # Delete last residual block
            else:
                assert isinstance(self.sr_stem, GrowingUnet)
                del self.sr_stem.unet.encoder[0][0]  # Delete first residual block
                del self.sr_stem.unet.decoder[-1][-1]  # Delete last residual block
            utils.set_requires_grad(self.sr_stem, False)

        args = kwargs.pop("_args_")
        if hasattr(self, "sr_stem"):  # Growing the SR stem - Add a new layer to match sr
            n_layers = len(kwargs["dim_mults"])
            dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"]))
            kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)]
            kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False]
            kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1]
        self.unet = UNet(*args, **kwargs)
        self.from_rgb = self.unet.from_rgb
        self.to_rgb = self.unet.to_rgb
        self.residual = residual
        self.new_dataset = new_dataset
        if residual:
            nn.init.zeros_(self.to_rgb.weight.data)
        del self.unet.from_rgb, self.unet.to_rgb

    def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs):
        # Downsample for stem
        if z is None:
            z = self.get_z(img)
        if w is None:
            w = self.style_net(z)
        if hasattr(self, "coarse_stem"):
            with torch.no_grad():
                if self.new_dataset:
                    img_stem = utils.denormalize_img(img)*255
                    condition_stem = img_stem * mask + (1-mask)*127
                    condition_stem = condition_stem.round()
                    condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
                    condition_stem = condition_stem / 255 * 2 - 1
                    mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
                    maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
                else:
                    mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float()
                    maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float()
                    img_stem = utils.denormalize_img(img)*255
                    img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round()
                    img_stem = img_stem / 255 * 2 - 1
                    condition_stem = img_stem * mask_stem
                stem_out = self.coarse_stem(
                    condition=condition_stem, mask=mask_stem,
                    maskrcnn_mask=maskrcnn_stem, w=w,
                    keypoints=keypoints)
                x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
                condition = condition*mask + (1-mask) * x_lr
        if self.unet.use_maskrcnn_mask:
            x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
        else:
            x = torch.cat((condition, mask, 1-mask), dim=1)
        if self.unet.input_keypoints:
            keypoints = keypoints[:, self.unet.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        x, unet_features = self.forward_enc(x, mask, w)
        x = self.forward_dec(x, mask, w, unet_features)
        if self.residual:
            x = self.to_rgb(x) + condition
        else:
            x = self.to_rgb(x)
        return dict(
            img=condition * mask + (1-mask) * x,
            unmasked=x,
            x_lowres=[condition]
        )

    def forward_enc(self, x, mask, w):
        x, unet_features = self.unet.forward_enc(x, mask, w)
        if hasattr(self, "sr_stem"):
            x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w)
        else:
            unet_features_stem = None
        return x, [unet_features, unet_features_stem]

    def forward_dec(self, x, mask, w, unet_features):
        unet_features, unet_features_stem = unet_features
        if hasattr(self, "sr_stem"):
            x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem)
        x, unet_features = self.unet.forward_dec(x, mask, w, unet_features)
        return x

    def get_z(self, *args, **kwargs):
        if hasattr(self, "coarse_stem"):
            return self.coarse_stem.get_z(*args, **kwargs)
        if hasattr(self, "sr_stem"):
            return self.sr_stem.get_z(*args, **kwargs)
        raise AttributeError()

    @property
    def style_net(self):
        if hasattr(self, "coarse_stem"):
            return self.coarse_stem.style_net
        if hasattr(self, "sr_stem"):
            return self.sr_stem.style_net
        raise AttributeError()

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)


class CascadedUnet(BaseGenerator):

    def __init__(
            self,
            coarse_stem_cfg: str,   # This can be a coarse generator or None
            residual: bool,
            new_dataset: bool,      # The "new dataset" creates condition first -> resizes
            imsize: tuple,
            cascade: bool,
            **unet_kwargs):
        kwargs = dict()
        coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
        kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
        kwargs.update(unet_kwargs)
        super().__init__(z_channels=kwargs["z_channels"])

        self.input_keypoints = kwargs["input_keypoints"]
        self.input_keypoint_indices = kwargs["input_keypoint_indices"]
        self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"]
        self.imsize = imsize
        self.residual = residual
        self.new_dataset = new_dataset

        # Setup coarse stem
        stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults]
        self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
        self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
        utils.set_requires_grad(self.coarse_stem, False)

        self.stem_res_to_layer_idx = {
            self.coarse_stem.imsize[0] // 2**i: stem_dims[i]
            for i in range(len(stem_dims))
        }

        dim = kwargs["dim"]
        dim_mults = kwargs["dim_mults"]
        n_layers = len(dim_mults)
        dims = [dim*s for s in dim_mults]
        layer_attn = cast_tuple(kwargs["layer_attn"], n_layers)
        num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers)
        attn_cls = kwargs["attn_cls"]
        if not isinstance(attn_cls, partial):
            attn_cls = instantiate(attn_cls)

        dec_blk = partial(
            SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"],
            use_adain=kwargs["use_adain"] and kwargs["dec_style"],
            w_dim=kwargs["w_dim"],
            cross_attention=kwargs["cross_attention"],
            cross_attention_len=kwargs["cross_attention_len"],
            gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
        )
        enc_blk = partial(
            SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"],
            use_adain=kwargs["use_adain"] and kwargs["enc_style"],
            w_dim=kwargs["w_dim"],
            cross_attention=kwargs["cross_attention"],
            cross_attention_len=kwargs["cross_attention_len"],
            gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
        )

        # Currently up/down sampling is done by bilinear upsampling.
        # This can be simplified by replacing it with a strided upsampling layer...
        self.encoder_attns = nn.ModuleList()
        self.encoder_unet_skips = nn.ModuleDict()
        self.encoder = nn.ModuleList()
        for lidx in range(n_layers):
            has_stem_feature = imsize[0]//2**lidx in self.stem_res_to_layer_idx and cascade
            next_layer_has_stem_features = lidx+1 < n_layers and imsize[0]//2**(lidx+1) in self.stem_res_to_layer_idx and cascade

            dim_in = dims[lidx]
            dim_out = dims[min(lidx+1, n_layers-1)]
            res_blocks = nn.ModuleList()
            if has_stem_feature:
                prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1]
                stem_lidx = self.stem_res_to_layer_idx[imsize[0]//2**lidx]
                self.encoder_unet_skips.add_module(
                    str(imsize[0]//2**lidx),
                    Conv2d(
                        stem_dims[stem_lidx], dim_in, kernel_size=1,
                        conv_clamp=kwargs["conv_clamp"],
                        norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
                        gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3)  # This + previous residual + attention
                    )
                )
            for i in range(num_resnet_blocks[lidx]):
                is_last = num_resnet_blocks[lidx] - 1 == i
                cur_dim = dim_out if is_last else dim_in
                if not is_last:
                    gain = np.sqrt(.5)
                elif next_layer_has_stem_features and layer_attn[lidx]:
                    gain = np.sqrt(1/4)
                elif layer_attn[lidx] or next_layer_has_stem_features:
                    gain = np.sqrt(1/3)
                else:
                    gain = np.sqrt(.5)
                block = enc_blk(dim_in, cur_dim, skip_gain=gain)
                res_blocks.append(block)
            if layer_attn[lidx]:
                self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True))
            else:
                self.encoder_attns.append(Identity())
            self.encoder.append(res_blocks)

        # initialize decoder
        self.decoder = torch.nn.ModuleList()
        self.unet_layers = torch.nn.ModuleList()
        self.decoder_attns = torch.nn.ModuleList()
        for lidx in range(n_layers):
            dim_in = dims[min(-lidx, -1)]
            dim_out = dims[-1-lidx]
            res_blocks = nn.ModuleList()
            unet_skips = nn.ModuleList()
            for i in range(num_resnet_blocks[-lidx-1]):
                is_first = i == 0
                has_unet = is_first or kwargs["skip_all_unets"]
                is_last = i == num_resnet_blocks[-lidx-1] - 1
                cur_dim = dim_in if is_first else dim_out
                if has_unet and is_last and layer_attn[-lidx-1]:  # x + residual + unet + layer attn
                    gain = np.sqrt(1/4)
                elif has_unet:  # x + residual + unet
                    gain = np.sqrt(1/3)
                elif layer_attn[-lidx-1]:  # x + residual + attention
                    gain = np.sqrt(1/3)
                else:  # x + residual
                    gain = np.sqrt(1/2)  # Only residual block
                block = dec_blk(cur_dim, dim_out, skip_gain=gain)
                res_blocks.append(block)
                if kwargs["skip_all_unets"] or is_first:
                    unet_block = Conv2d(
                        cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"],
                        norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
                        gain=gain)
                    unet_skips.append(unet_block)
                else:
                    unet_skips.append(torch.nn.Identity())
            if layer_attn[-lidx-1]:
                self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain))
            else:
                self.decoder_attns.append(Identity())

            self.decoder.append(res_blocks)
            self.unet_layers.append(unet_skips)

        self.from_rgb = Conv2d(
            3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"]),
            dim, 7)
        self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"])

        self.downsample = Upfirdn2d(down=2, fix_gain=True)
self.downsample = Upfirdn2d(down=2, fix_gain=True)
|
1088 |
+
self.upsample = Upfirdn2d(up=2, fix_gain=True)
|
1089 |
+
self.cascade = cascade
|
1090 |
+
if residual:
|
1091 |
+
nn.init.zeros_(self.to_rgb.weight.data)
|
1092 |
+
|
1093 |
+
def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, return_decoder_features=False, **kwargs):
|
1094 |
+
# Downsample for stem
|
1095 |
+
if z is None:
|
1096 |
+
z = self.get_z(img)
|
1097 |
+
|
1098 |
+
with torch.no_grad(): # Forward pass stem
|
1099 |
+
if w is None:
|
1100 |
+
w = self.style_net(z)
|
1101 |
+
img_stem = utils.denormalize_img(img)*255
|
1102 |
+
condition_stem = img_stem * mask + (1-mask)*127
|
1103 |
+
condition_stem = condition_stem.round()
|
1104 |
+
condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
|
1105 |
+
condition_stem = condition_stem / 255 *2 - 1
|
1106 |
+
mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
|
1107 |
+
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
|
1108 |
+
stem_out = self.coarse_stem(
|
1109 |
+
condition=condition_stem, mask=mask_stem,
|
1110 |
+
maskrcnn_mask=maskrcnn_stem, w=w,
|
1111 |
+
keypoints=keypoints,
|
1112 |
+
return_decoder_features=True)
|
1113 |
+
stem_features = stem_out["decoder_features"]
|
1114 |
+
x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
|
1115 |
+
condition = condition*mask + (1-mask) * x_lr
|
1116 |
+
|
1117 |
+
if self.use_maskrcnn_mask:
|
1118 |
+
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
|
1119 |
+
else:
|
1120 |
+
x = torch.cat((condition, mask, 1-mask), dim=1)
|
1121 |
+
if self.input_keypoints:
|
1122 |
+
keypoints = keypoints[:, self.input_keypoint_indices]
|
1123 |
+
one_hot_pose = spatial_embed_keypoints(keypoints, x)
|
1124 |
+
x = torch.cat((x, one_hot_pose), dim=1)
|
1125 |
+
x = self.from_rgb(x)
|
1126 |
+
x, unet_features = self.forward_enc(x, mask, w, stem_features)
|
1127 |
+
x, decoder_features = self.forward_dec(x, mask, w, unet_features)
|
1128 |
+
if self.residual:
|
1129 |
+
x = self.to_rgb(x) + condition
|
1130 |
+
else:
|
1131 |
+
x = self.to_rgb(x)
|
1132 |
+
out= dict(
|
1133 |
+
img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ??
|
1134 |
+
unmasked=x,
|
1135 |
+
x_lowres=[condition]
|
1136 |
+
)
|
1137 |
+
if return_decoder_features:
|
1138 |
+
out["decoder_features"] = decoder_features
|
1139 |
+
return out
|
1140 |
+
|
1141 |
+
def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]):
|
1142 |
+
unet_features = []
|
1143 |
+
stem_features.reverse()
|
1144 |
+
for i, res_blocks in enumerate(self.encoder):
|
1145 |
+
is_last = i == len(self.encoder) - 1
|
1146 |
+
res = self.imsize[0]//2^i
|
1147 |
+
if str(res) in self.encoder_unet_skips.keys() and self.cascade:
|
1148 |
+
y = stem_features[self.stem_res_to_layer_idx[res]]
|
1149 |
+
y = self.encoder_unet_skips[i](y)
|
1150 |
+
x = y + x
|
1151 |
+
for block in res_blocks:
|
1152 |
+
x = block(x, w=w)
|
1153 |
+
unet_features.append(x)
|
1154 |
+
x = self.encoder_attns[i](x, mask)
|
1155 |
+
if not is_last:
|
1156 |
+
x = self.downsample(x)
|
1157 |
+
return x, unet_features
|
1158 |
+
|
1159 |
+
def forward_dec(self, x, mask, w, unet_features):
|
1160 |
+
features = []
|
1161 |
+
unet_features = iter(reversed(unet_features))
|
1162 |
+
for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
|
1163 |
+
is_last = i == len(self.decoder) - 1
|
1164 |
+
for skip, block in zip(unet_skip, res_blocks):
|
1165 |
+
skip_x = next(unet_features)
|
1166 |
+
if not isinstance(skip, torch.nn.Identity):
|
1167 |
+
skip_x = skip(skip_x)
|
1168 |
+
x = x + skip_x
|
1169 |
+
x = block(x, w=w)
|
1170 |
+
x = self.decoder_attns[i](x, mask)
|
1171 |
+
features.append(x)
|
1172 |
+
if not is_last:
|
1173 |
+
x = self.upsample(x)
|
1174 |
+
return x, features
|
1175 |
+
|
1176 |
+
def get_z(self, *args, **kwargs):
|
1177 |
+
return self.coarse_stem.get_z(*args, **kwargs)
|
1178 |
+
|
1179 |
+
@property
|
1180 |
+
def style_net(self):
|
1181 |
+
return self.coarse_stem.style_net
|
1182 |
+
|
1183 |
+
def update_w(self, *args, **kwargs):
|
1184 |
+
self.style_net.update_w(*args, **kwargs)
|
1185 |
+
|
1186 |
+
def get_w(self, z, update_emas):
|
1187 |
+
return self.style_net(z, update_emas=update_emas)
|
1188 |
+
|
1189 |
+
@torch.no_grad()
|
1190 |
+
def sample(self, truncation_value, **kwargs):
|
1191 |
+
if truncation_value is None:
|
1192 |
+
return self.forward(**kwargs)
|
1193 |
+
truncation_value = max(0, truncation_value)
|
1194 |
+
truncation_value = min(truncation_value, 1)
|
1195 |
+
w = self.get_w(self.get_z(kwargs["condition"]), False)
|
1196 |
+
w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
|
1197 |
+
return self.forward(**kwargs, w=w)
|
1198 |
+
|
1199 |
+
@torch.no_grad()
|
1200 |
+
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
|
1201 |
+
if truncation_value is None:
|
1202 |
+
return self.forward(**kwargs)
|
1203 |
+
truncation_value = max(0, truncation_value)
|
1204 |
+
truncation_value = min(truncation_value, 1)
|
1205 |
+
w = self.get_w(self.get_z(kwargs["condition"]), False)
|
1206 |
+
if w_indices is None:
|
1207 |
+
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
|
1208 |
+
w_centers = self.style_net.w_centers[w_indices].to(w.device)
|
1209 |
+
w = w_centers.to(w.dtype).lerp(w, truncation_value)
|
1210 |
+
return self.forward(**kwargs, w=w)
|
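The skip gains used above (sqrt(1/2), sqrt(1/3), sqrt(1/4)) follow a variance-preservation rule: a block that sums k roughly unit-variance branches scales them by sqrt(1/k) so the output variance stays near one. The snippet below is a standalone sketch (not part of the diff) illustrating that rule.

import torch

k = 3  # e.g. input + residual + U-Net skip
branches = [torch.randn(4, 64, 32, 32) for _ in range(k)]
scaled = sum(b * (1 / k) ** 0.5 for b in branches)   # gain = sqrt(1/k)
unscaled = sum(branches)
print(scaled.var().item())    # close to 1.0
print(unscaled.var().item())  # close to k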
dp2/generator/stylegan_unet.py
ADDED
@@ -0,0 +1,208 @@
import torch
import numpy as np
from dp2.layers import Sequential
from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
from .base import BaseStyleGAN
from typing import List, Tuple
from .utils import spatial_embed_keypoints, mask_output


def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
    n = int(np.log2(max_imsize) - np.log2(imsize))
    mul = min(2**n, max_cnum_mul)
    ch = cnum * mul
    return int(ch)

class StyleGANUnet(BaseStyleGAN):
    def __init__(
            self,
            scale_grad: bool,
            im_channels: int,
            min_fmap_resolution: int,
            imsize: List[int],
            cnum: int,
            max_cnum_mul: int,
            mask_output: bool,
            conv_clamp: int,
            input_cse: bool,
            cse_nc: int,
            n_middle_blocks: int,
            input_keypoints: bool,
            n_keypoints: int,
            input_keypoint_indices: Tuple[int],
            fix_errors: bool,
            **kwargs
            ) -> None:
        super().__init__(**kwargs)
        self.n_keypoints = n_keypoints
        self.input_keypoint_indices = list(input_keypoint_indices)
        self.input_keypoints = input_keypoints
        assert not (input_cse and input_keypoints)
        cse_nc = 0 if cse_nc is None else cse_nc
        self.imsize = imsize
        self._cnum = cnum
        self._max_cnum_mul = max_cnum_mul
        self._min_fmap_resolution = min_fmap_resolution
        self._image_channels = im_channels
        self._max_imsize = max(imsize)
        self.input_cse = input_cse
        self.gain_unet = np.sqrt(1/3)
        n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
        encoder_layers = []
        self.from_rgb = Conv2d(
            im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
            cnum, 1
        )
        for i in range(n_levels):  # Encoder layers
            resolution = [x//2**i for x in imsize]
            in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
            second_ch = in_ch
            out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
            down = 2

            if i == 0:  # first (lowest) block. Downsampling is performed at the start of the block
                down = 1
            if i == n_levels - 1:
                out_ch = second_ch
            block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
            encoder_layers.append(block)
        self._encoder_out_shape = [
            get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
            *resolution]

        self.encoder = torch.nn.ModuleList(encoder_layers)

        # initialize decoder
        decoder_layers = []
        for i in range(n_levels):
            resolution = [x//2**(n_levels-1-i) for x in imsize]
            in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
            out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
            if i == 0:  # first (lowest) block
                in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)

            up = 1
            if i != n_levels - 1:
                up = 2
            block = ResidualBlock(
                in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
                w_dim=self.style_net.w_dim, norm=True, up=up,
                fix_residual=fix_errors
            )
            decoder_layers.append(block)
            if i != 0:
                unet_block = Conv2d(
                    in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
                    gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
                setattr(self, f"unet_block{i}", unet_block)

        # Initialize "middle blocks" that do not have down/up sample
        middle_blocks = []
        for i in range(n_middle_blocks):
            ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
            block = ResidualBlock(
                ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
                w_dim=self.style_net.w_dim, norm=True,
            )
            middle_blocks.append(block)
        if n_middle_blocks != 0:
            self.middle_blocks = Sequential(*middle_blocks)
        self.decoder = torch.nn.ModuleList(decoder_layers)
        self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
        self.scale_grad = scale_grad
        self.mask_output = mask_output

    def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
        for i, layer in enumerate(self.decoder):
            if i != 0:
                unet_layer = getattr(self, f"unet_block{i}")
                x = x + unet_layer(unet_features[-i])
            x = layer(x, w=w, s=s)
        x = self.to_rgb(x)
        if self.mask_output:
            x = mask_output(True, condition, x, mask)
        return dict(img=x)

    def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
        if self.input_cse:
            x = torch.cat((condition, mask, embedding, E_mask), dim=1)
        else:
            x = torch.cat((condition, mask), dim=1)
        if self.input_keypoints:
            keypoints = keypoints[:, self.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)

        unet_features = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i != len(self.encoder)-1:
                unet_features.append(x)
        if hasattr(self, "middle_blocks"):
            for layer in self.middle_blocks:
                x = layer(x)
        return x, unet_features

    def forward(
            self, condition, mask,
            z=None, embedding=None, w=None, update_emas=False, x=None,
            s=None,
            keypoints=None,
            unet_features=None,
            E_mask=None,
            **kwargs):
        # Used to skip sampling from encoder in inference. E.g. for w projection.
        if x is not None and unet_features is not None:
            assert not self.training
        else:
            x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
        if w is None:
            if z is None:
                z = self.get_z(condition)
            w = self.get_w(z, update_emas=update_emas)
        return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)

class ComodStyleUNet(StyleGANUnet):

    def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
        super().__init__(**kwargs)
        min_fmap = min(self._encoder_out_shape[1:])
        enc_out_ch = self._encoder_out_shape[0]
        n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
        comod_layers = []
        in_ch = enc_out_ch
        for i in range(n_down):
            comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
            in_ch = 256
        if n_down == 0:
            comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
        comod_layers.append(torch.nn.Flatten())
        out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
        in_ch_fc = np.prod(out_res) * 256
        comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
        self.comod_block = Sequential(*comod_layers)
        self.comod_fc = FullyConnectedLayer(512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)

    def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
        y = self.comod_block(x)
        y = torch.cat((y, w), dim=1)
        y = self.comod_fc(y)
        for i, layer in enumerate(self.decoder):
            if i != 0:
                unet_layer = getattr(self, f"unet_block{i}")
                x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
            x = layer(x, w=y)
        x = self.to_rgb(x)
        if self.mask_output:
            x = mask_output(True, condition, x, mask)
        return dict(img=x)

    def get_comod_y(self, batch, w):
        x, unet_features = self.forward_enc(**batch)
        y = self.comod_block(x)
        y = torch.cat((y, w), dim=1)
        y = self.comod_fc(y)
        return y
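A quick illustration of how get_chsize above doubles the channel count each time the resolution halves, capped by max_cnum_mul. The numbers here are hypothetical (cnum=64, max_imsize=256, max_cnum_mul=8), not taken from any shipped config.

from dp2.generator.stylegan_unet import get_chsize

for res in [256, 128, 64, 32, 16, 8]:
    print(res, get_chsize(res, cnum=64, max_imsize=256, max_cnum_mul=8))
# 256 -> 64, 128 -> 128, 64 -> 256, then capped at 512 for 32, 16 and 8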
dp2/generator/utils.py
ADDED
@@ -0,0 +1,48 @@
import torch
import tops
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


@torch.no_grad()
def spatial_embed_keypoints(keypoints: torch.Tensor, x):
    tops.assert_shape(keypoints, (None, None, 3))
    B, N_K, _ = keypoints.shape
    H, W = x.shape[-2:]
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32)
    x, y, visible = keypoints.chunk(3, dim=2)
    x = (x * W).round().long().clamp(0, W-1)
    y = (y * H).round().long().clamp(0, H-1)
    kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
    pos = (kp_idx*(H*W) + y*W + x + 1)
    # Offset all by 1 to index invisible keypoints as 0
    pos = (pos * visible.round().long()).squeeze(dim=-1)
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
    keypoint_spatial.scatter_(1, pos, 1)
    keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
    return keypoint_spatial

class MaskOutput(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x_real, x_fake, mask):
        ctx.save_for_backward(mask)
        out = x_real * mask + (1-mask) * x_fake
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        fake_grad = grad_output
        mask, = ctx.saved_tensors
        fake_grad = fake_grad * (1 - mask)
        known_percentage = mask.view(mask.shape[0], -1).mean(dim=1)
        fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1)
        return None, fake_grad, None


def mask_output(scale_grad, x_real, x_fake, mask):
    if scale_grad:
        return MaskOutput.apply(x_real, x_fake, mask)
    return x_real * mask + (1-mask) * x_fake
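A small usage sketch (not part of the diff) of spatial_embed_keypoints: each keypoint gets its own channel with a single one at the rounded pixel position, and keypoints marked as not visible stay all-zero. It assumes the tops package used above is installed.

import torch
from dp2.generator.utils import spatial_embed_keypoints

keypoints = torch.tensor([[[0.25, 0.50, 1.0],    # (x, y, visible) in [0, 1]
                           [0.75, 0.75, 0.0]]])  # not visible
features = torch.zeros(1, 8, 16, 16)             # only the spatial size is used
heatmap = spatial_embed_keypoints(keypoints, features)
print(heatmap.shape)        # torch.Size([1, 2, 16, 16])
print(heatmap[0, 0].sum())  # 1.0: a single one at the rounded pixel
print(heatmap[0, 1].sum())  # 0.0: invisible keypoint stays empty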
dp2/infer.py
ADDED
@@ -0,0 +1,72 @@
import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn

def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
    state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    print(ckpt.keys())
    if "w_centers" in ckpt:
        print("Has w_centers!")
        G.style_net.w_centers = ckpt["w_centers"]
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")


def build_trained_generator(cfg, map_location=None):
    map_location = map_location if map_location is not None else tops.get_device()
    G = instantiate(cfg.generator).to(map_location)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    if hasattr(cfg, "ckpt_mapper"):
        ckpt_mapper = instantiate(cfg.ckpt_mapper)
    else:
        ckpt_mapper = None
    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum)
        load_generator_state(ckpt, G, ckpt_mapper)
        return G
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        load_generator_state(ckpt, G, ckpt_mapper)
    except FileNotFoundError as e:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
    return G


def build_trained_discriminator(cfg, map_location=None):
    map_location = map_location if map_location is not None else tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        if hasattr(cfg, "ckpt_mapper_D"):
            ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
        D.load_state_dict(ckpt["discriminator"])
    except FileNotFoundError as e:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
    return D


def load_state_dict(module: torch.nn.Module, state_dict: dict):
    module_sd = module.state_dict()
    to_remove = []
    for key, item in state_dict.items():
        if key not in module_sd:
            continue
        if item.shape != module_sd[key].shape:
            to_remove.append(key)
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
    for key in to_remove:
        state_dict.pop(key)
    for key, item in state_dict.items():
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    for key, item in module_sd.items():
        if key not in state_dict:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(state_dict, strict=False)
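A standalone sketch (not part of the diff) of the shape-mismatch handling in load_state_dict above: entries with incompatible shapes are dropped with a warning, and the remaining entries are loaded with strict=False.

import torch
from dp2.infer import load_state_dict

src = torch.nn.Linear(4, 4)
dst = torch.nn.Linear(8, 4)      # weight shape differs, bias shape matches
state = src.state_dict()
load_state_dict(dst, state)      # warns about the mismatched weight, keeps the bias
assert torch.allclose(dst.bias, src.bias)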
dp2/layers/__init__.py
ADDED
@@ -0,0 +1,20 @@
from typing import Dict
import torch
import tops
import torch.nn as nn

class Sequential(nn.Sequential):

    def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
        for module in self:
            x = module(x, **kwargs)
        return x

class Module(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()

    def extra_repr(self):
        num_params = tops.num_parameters(self) / 10**6
        return f"Num params: {num_params:.3f}M"
dp2/layers/sg2_layers.py
ADDED
@@ -0,0 +1,227 @@
from typing import List
import numpy as np
import torch
import tops
import torch.nn.functional as F
from sg3_torch_utils.ops import conv2d_resample
from sg3_torch_utils.ops import upfirdn2d
from sg3_torch_utils.ops import bias_act
from sg3_torch_utils.ops.fma import fma


class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
            in_features,              # Number of input features.
            out_features,             # Number of output features.
            bias=True,                # Apply additive bias before the activation function?
            activation='linear',      # Activation function: 'relu', 'lrelu', etc.
            lr_multiplier=1,          # Learning rate multiplier.
            bias_init=0,              # Initial value for the additive bias.
            ):
        super().__init__()
        self.repr = dict(
            in_features=in_features, out_features=out_features, bias=bias,
            activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None and self.bias_gain != 1:
            b = b * self.bias_gain
        x = F.linear(x, w)
        x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Conv2d(torch.nn.Module):
    def __init__(self,
            in_channels,                    # Number of input channels.
            out_channels,                   # Number of output channels.
            kernel_size=3,                  # Convolution kernel size.
            up=1,                           # Integer upsampling factor.
            down=1,                         # Integer downsampling factor
            activation='lrelu',             # Activation function: 'relu', 'lrelu', etc.
            resample_filter=[1, 3, 3, 1],   # Low-pass filter to apply when resampling activations.
            conv_clamp=None,                # Clamp the output of convolution layers to +-X, None = disable clamping.
            bias=True,
            norm=False,
            lr_multiplier=1,
            bias_init=0,
            w_dim=None,
            gain=1,
            ):
        super().__init__()
        if norm:
            self.norm = torch.nn.InstanceNorm2d(None)
        assert norm in [True, False]
        self.up = up
        self.down = down
        self.activation = activation
        self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.padding = kernel_size // 2

        self.repr = dict(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, up=up, down=down,
            activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
        )

        if self.up == 1 and self.down == 1:
            self.resample_filter = None
        else:
            self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
        self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
        self.bias_gain = lr_multiplier
        if w_dim is not None:
            self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)

    def forward(self, x, w=None, s=None):
        tops.assert_shape(x, [None, self.weight.shape[1], None, None])
        if s is not None:
            s = s[..., :self.in_channels*2]
            gamma, beta = s.view(-1, self.in_channels*2, 1, 1).chunk(2, dim=1)
            x = fma(x, gamma, beta)
        elif hasattr(self, "affine"):
            gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
            beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
            x = fma(x, gamma, beta)
        w = self.weight * self.weight_gain
        # Removing flip weight is not safe.
        x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, self.down, self.padding, flip_weight=self.up == 1)
        if hasattr(self, "norm"):
            x = self.norm(x)
        b = self.bias * self.bias_gain if self.bias is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Block(torch.nn.Module):
    def __init__(self,
            in_channels,        # Number of input channels, 0 = first block.
            out_channels,       # Number of output channels.
            conv_clamp=None,    # Clamp the output of convolution layers to +-X, None = disable clamping.
            up=1,
            down=1,
            **layer_kwargs,     # Arguments for SynthesisLayer.
            ):
        super().__init__()
        self.in_channels = in_channels
        self.down = down
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)

    def forward(self, x, **layer_kwargs):
        x = self.conv0(x, **layer_kwargs)
        x = self.conv1(x, **layer_kwargs)
        return x


class ResidualBlock(torch.nn.Module):
    def __init__(self,
            in_channels,        # Number of input channels, 0 = first block.
            out_channels,       # Number of output channels.
            conv_clamp=None,    # Clamp the output of convolution layers to +-X, None = disable clamping.
            up=1,
            down=1,
            gain_out=np.sqrt(0.5),
            fix_residual: bool = False,
            **layer_kwargs,     # Arguments for conv layer.
            ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.down = down
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)

        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs)

        self.skip = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
            activation="linear" if fix_residual else "lrelu",
            gain=gain_out
        )
        self.gain_out = gain_out

    def forward(self, x, w=None, s=None, **layer_kwargs):
        y = self.skip(x)
        s_ = next(s) if s is not None else None
        x = self.conv0(x, w, s=s_, **layer_kwargs)
        s_ = next(s) if s is not None else None
        x = self.conv1(x, w, s=s_, **layer_kwargs)
        x = y + x
        return x


class MinibatchStdLayer(torch.nn.Module):
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with tops.suppress_tracer_warnings():  # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F

        y = x.reshape(G, -1, F, c, H, W)  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = y - y.mean(dim=0)             # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)        # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()             # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2, 3, 4])         # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)        # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)          # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)      # [NCHW]   Append to input as new channels.
        return x

#----------------------------------------------------------------------------

class DiscriminatorEpilogue(torch.nn.Module):
    def __init__(self,
            in_channels,             # Number of input channels.
            resolution: List[int],   # Resolution of this block.
            mbstd_group_size=4,      # Group size for the minibatch standard deviation layer, None = entire minibatch.
            mbstd_num_channels=1,    # Number of features for the minibatch standard deviation layer, 0 = disable.
            activation='lrelu',      # Activation function: 'relu', 'lrelu', etc.
            conv_clamp=None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
            ):
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2d(
            in_channels + mbstd_num_channels, in_channels,
            kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * resolution[0]*resolution[1], in_channels, activation=activation)
        self.out = FullyConnectedLayer(in_channels, 1)

    def forward(self, x):
        tops.assert_shape(x, [None, self.in_channels, *self.resolution])  # [NCHW]
        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)
        return x
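A small usage sketch (not part of the diff), assuming tops and the sg3_torch_utils ops above are importable: MinibatchStdLayer appends one extra channel holding the average per-group standard deviation, so the channel count grows from 16 to 17 here.

import torch
from dp2.layers.sg2_layers import MinibatchStdLayer

layer = MinibatchStdLayer(group_size=4, num_channels=1)
x = torch.randn(8, 16, 4, 4)
print(layer(x).shape)  # torch.Size([8, 17, 4, 4])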
dp2/loss/__init__.py
ADDED
@@ -0,0 +1 @@
from .sg2_loss import StyleGAN2Loss
dp2/loss/pl_regularization.py
ADDED
@@ -0,0 +1,48 @@
import torch
import tops
import numpy as np
from sg3_torch_utils.ops import conv2d_gradfix

pl_mean_total = torch.zeros([])

class PLRegularization:

    def __init__(self, weight: float, batch_shrink: int, pl_decay: float, scale_by_mask: bool, **kwargs):
        self.pl_mean = torch.zeros([], device=tops.get_device())
        self.pl_weight = weight
        self.batch_shrink = batch_shrink
        self.pl_decay = pl_decay
        self.scale_by_mask = scale_by_mask

    def __call__(self, G, batch, grad_scaler):
        batch_size = batch["img"].shape[0] // self.batch_shrink
        batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
        if "embed_map" in batch:
            batch["embed_map"] = batch["embed_map"]
        z = G.get_z(batch["img"])

        with torch.cuda.amp.autocast(tops.AMP()):
            gen_ws = G.style_net(z)
            gen_img = G(**batch, w=gen_ws)["img"].float()
            pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
            with conv2d_gradfix.no_weight_gradients():
                # Sums over HWC
                pl_grads = torch.autograd.grad(
                    outputs=[grad_scaler.scale(gen_img * pl_noise)],
                    inputs=[gen_ws],
                    create_graph=True,
                    grad_outputs=torch.ones_like(gen_img),
                    only_inputs=True)[0]

        pl_grads = pl_grads.float() / grad_scaler.get_scale()
        if self.scale_by_mask:
            # Percentage of pixels known
            scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
            pl_grads = pl_grads / scaling
        pl_lengths = pl_grads.square().sum(1).sqrt()
        pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
        if not torch.isnan(pl_mean).any():
            self.pl_mean.copy_(pl_mean.detach())
        pl_penalty = (pl_lengths - pl_mean).square()
        to_log = dict(pl_penalty=pl_penalty.mean().detach())
        return pl_penalty.view(-1) * self.pl_weight, to_log
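The path-length penalty above boils down to a Jacobian-vector product between the generated image and the latent codes. Below is a toy, standalone sketch of that core computation (without AMP, gradient scaling, masking or the EMA baseline), using a plain linear mapping as a stand-in generator.

import torch

mapping = torch.nn.Linear(8, 3 * 16 * 16)          # stand-in for the generator
w = torch.randn(4, 8, requires_grad=True)
img = mapping(w).view(4, 3, 16, 16)
noise = torch.randn_like(img) / (16 * 16) ** 0.5    # same scaling as pl_noise above
grads = torch.autograd.grad((img * noise).sum(), w, create_graph=True)[0]
pl_lengths = grads.square().sum(1).sqrt()            # one path length per sample
print(pl_lengths.shape)                              # torch.Size([4])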
dp2/loss/r1_regularization.py
ADDED
@@ -0,0 +1,31 @@
import torch
import tops

def r1_regularization(
        real_img, real_score, mask, lambd: float, lazy_reg_interval: int,
        lazy_regularization: bool,
        scaler: torch.cuda.amp.GradScaler, mask_out: bool,
        mask_out_scale: bool,
        **kwargs
        ):
    grad = torch.autograd.grad(
        outputs=scaler.scale(real_score),
        inputs=real_img,
        grad_outputs=torch.ones_like(real_score),
        create_graph=True,
        only_inputs=True,
    )[0]
    inv_scale = 1.0 / scaler.get_scale()
    grad = grad * inv_scale
    with torch.cuda.amp.autocast(tops.AMP()):
        if mask_out:
            grad = grad * (1 - mask)
        grad = grad.square().sum(dim=[1, 2, 3])
        if mask_out and mask_out_scale:
            total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3]
            n_fake = (1-mask).sum(dim=[1, 2, 3])
            scaling = total_pixels / n_fake
            grad = grad * scaling
    if lazy_regularization:
        lambd_ = lambd * lazy_reg_interval / 2  # From stylegan2, lazy regularization
    return grad * lambd_, grad.detach()
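A toy, standalone sketch of the R1 penalty computed above, without gradient scaling, AMP or masking: the squared gradient of the real score with respect to the real image.

import torch

D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 16 * 16, 1))
real = torch.randn(4, 3, 16, 16, requires_grad=True)
score = D(real)
grad = torch.autograd.grad(score.sum(), real, create_graph=True)[0]
r1 = grad.square().sum(dim=[1, 2, 3])  # one penalty value per image
print(r1.shape)                         # torch.Size([4])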
dp2/loss/sg2_loss.py
ADDED
@@ -0,0 +1,94 @@
import functools
import torch
import tops
from tops import logger
from dp2.utils import forward_D_fake
from .utils import nsgan_d_loss, nsgan_g_loss
from .r1_regularization import r1_regularization
from .pl_regularization import PLRegularization

class StyleGAN2Loss:

    def __init__(
            self,
            D,
            G,
            r1_opts: dict,
            EP_lambd: float,
            lazy_reg_interval: int,
            lazy_regularization: bool,
            pl_reg_opts: dict,
            ) -> None:
        self.gradient_step_D = 0
        self._lazy_reg_interval = lazy_reg_interval
        self.D = D
        self.G = G
        self.EP_lambd = EP_lambd
        self.lazy_regularization = lazy_regularization
        self.r1_reg = functools.partial(
            r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
            lazy_regularization=lazy_regularization)
        self.do_PL_Reg = False
        if pl_reg_opts.weight > 0:
            self.pl_reg = PLRegularization(**pl_reg_opts)
            self.do_PL_Reg = True
            self.pl_start_nimg = pl_reg_opts.start_nimg

    def D_loss(self, batch: dict, grad_scaler):
        to_log = {}
        # Forward through G and D
        do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
        if do_GP:
            batch["img"] = batch["img"].detach().requires_grad_(True)
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            with torch.no_grad():
                G_fake = self.G(**batch, update_emas=True)
            D_out_real = self.D(**batch)

            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)

            # Non saturating loss
            nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
            tops.assert_shape(nsgan_loss, (batch["img"].shape[0], ))
            to_log["d_loss"] = nsgan_loss.mean()
            total_loss = nsgan_loss
            epsilon_penalty = D_out_real["score"].pow(2).view(-1)
            to_log["epsilon_penalty"] = epsilon_penalty.mean()
            tops.assert_shape(epsilon_penalty, total_loss.shape)
            total_loss = total_loss + epsilon_penalty * self.EP_lambd

        # Improved gradient penalty with lazy regularization
        # Gradient penalty applies specialized autocast.
        if do_GP:
            gradient_pen, grad_unscaled = self.r1_reg(batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
            to_log["r1_gradient_penalty"] = grad_unscaled.mean()
            tops.assert_shape(gradient_pen, total_loss.shape)
            total_loss = total_loss + gradient_pen

        batch["img"] = batch["img"].detach().requires_grad_(False)
        if "score" in D_out_real:
            to_log["real_scores"] = D_out_real["score"]
            to_log["real_logits_sign"] = D_out_real["score"].sign()
            to_log["fake_logits_sign"] = D_out_fake["score"].sign()
            to_log["fake_scores"] = D_out_fake["score"]
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        self.gradient_step_D += 1
        return total_loss.mean(), to_log

    def G_loss(self, batch: dict, grad_scaler):
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            to_log = {}
            # Forward through G and D
            G_fake = self.G(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
            # Adversarial Loss
            total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
            to_log["g_loss"] = total_loss.mean()
            tops.assert_shape(total_loss, (batch["img"].shape[0], ))

        if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
            pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
            total_loss = total_loss + pl_reg.mean()
            to_log.update(to_log_)
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        return total_loss.mean(), to_log
dp2/loss/utils.py
ADDED
@@ -0,0 +1,25 @@
import torch
import torch.nn.functional as F

def nsgan_g_loss(fake_score):
    """
    Non-saturating criterion from Goodfellow et al. 2014
    """
    return torch.nn.functional.softplus(-fake_score)


def nsgan_d_loss(real_score, fake_score):
    """
    Non-saturating criterion from Goodfellow et al. 2014
    """
    d_loss = F.softplus(-real_score) + F.softplus(fake_score)
    return d_loss.view(-1)


def smooth_masked_l1_loss(x, target, mask):
    """
    Pixel-wise l1 loss for the area indicated by mask
    """
    # Beta=.1 <-> square loss if pixel difference <= 12.8
    l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1, 2, 3]) / mask.sum(dim=[1, 2, 3])
    return l1
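A quick numeric illustration (not part of the diff) of the non-saturating losses above: the discriminator loss is low when real scores are high and fake scores are low, while the generator loss is low when the discriminator is fooled.

import torch
from dp2.loss.utils import nsgan_d_loss, nsgan_g_loss

real_score = torch.tensor([2.0, -1.0])
fake_score = torch.tensor([-3.0, 0.5])
print(nsgan_d_loss(real_score, fake_score))  # softplus(-real) + softplus(fake)
print(nsgan_g_loss(fake_score))              # softplus(-fake)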
dp2/metrics/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .torch_metrics import compute_metrics_iteratively
from .fid import compute_fid
from .ppl import calculate_ppl
dp2/metrics/fid.py
ADDED
@@ -0,0 +1,72 @@
import tops
from dp2 import utils
from pathlib import Path
from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
import torch
import torch_fidelity


class GeneratorIteratorWrapper(GenerativeModelModuleWrapper):

    def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int):
        if isinstance(generator, utils.EMA):
            generator = generator.generator
        z_size = generator.z_channels
        super().__init__(generator, z_size, "normal", 0)
        self.zero_z = zero_z
        self.dataloader = iter(dataloader)
        self.n_diverse = n_diverse
        self.cur_div_idx = 0

    @torch.no_grad()
    def forward(self, z, **kwargs):
        if self.cur_div_idx == 0:
            self.batch = next(self.dataloader)
        if self.zero_z:
            z = z.zero_()
        self.cur_div_idx += 1
        self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            img = self.module(**self.batch)["img"]
            img = (utils.denormalize_img(img)*255).byte()
        return img


def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse):
    generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse)
    batch_size = dataloader.batch_size
    num_samples = (n_source * n_diverse) // batch_size * batch_size
    assert n_diverse >= 1
    assert (not zero_z) or n_diverse == 1
    assert num_samples % batch_size == 0
    assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse)
    metrics = torch_fidelity.calculate_metrics(
        input1=generator,
        input2=real_directory,
        cuda=torch.cuda.is_available(),
        fid=True,
        input2_cache_name="_".join(Path(real_directory).parts) + "_cached",
        input1_model_num_samples=int(num_samples),
        batch_size=dataloader.batch_size
    )
    return metrics["frechet_inception_distance"]


if __name__ == "__main__":
    import click
    from dp2.config import Config
    from dp2.data import build_dataloader_val
    from dp2.infer import build_trained_generator

    @click.command()
    @click.argument("config_path")
    @click.option("--n_source", default=200, type=int)
    @click.option("--n_diverse", default=5, type=int)
    @click.option("--zero_z", default=False, is_flag=True)
    def run(config_path, n_source: int, n_diverse: int, zero_z: bool):
        cfg = Config.fromfile(config_path)
        dataloader = build_dataloader_val(cfg)
        generator, _ = build_trained_generator(cfg)
        print(compute_fid(
            generator, dataloader, cfg.fid_real_directory, n_source, zero_z, n_diverse))

    run()
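A small arithmetic sketch (hypothetical numbers) of how compute_fid above chooses the number of generated samples: n_source conditions times n_diverse samples each, rounded down to a whole number of batches.

n_source, n_diverse, batch_size = 200, 5, 16
num_samples = (n_source * n_diverse) // batch_size * batch_size
print(num_samples)  # 992: 1000 requested, rounded down to 62 full batches of 16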
dp2/metrics/fid_clip.py
ADDED
@@ -0,0 +1,84 @@
import pickle
import torch
import torchvision
from pathlib import Path
from dp2 import utils
import tops
try:
    import clip
except ImportError:
    print("Could not import clip.")
from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric

clip_model = None
clip_preprocess = None


@torch.no_grad()
def compute_fid_clip(
        dataloader, generator,
        cache_directory,
        data_len=None,
        **kwargs
        ) -> dict:
    """
    FID CLIP following the description in The Role of ImageNet Classes in Frechet Inception Distance, Tuomas Kynkäänniemi et al.
    Args:
        n_samples (int): Creates N samples from same image to calculate stats
    """
    global clip_model, clip_preprocess
    if clip_model is None:
        clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
        normalize_fn = preprocess.transforms[-1]
        img_mean = normalize_fn.mean
        img_std = normalize_fn.std
        clip_model = tops.to_cuda(clip_model.visual)
        clip_preprocess = tops.to_cuda(torch.nn.Sequential(
            torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
            torchvision.transforms.Normalize(img_mean, img_std)
        ))
    cache_directory = Path(cache_directory)
    if data_len is None:
        data_len = len(dataloader)*dataloader.batch_size
    fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
    has_fid_cache = fid_cache_path.is_file()
    if not has_fid_cache:
        fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
    eidx = 0
    n_samples_seen = 0
    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
        sidx = eidx
        eidx = sidx + batch["img"].shape[0]
        n_samples_seen += batch["img"].shape[0]
        with torch.cuda.amp.autocast(tops.AMP()):
            fakes = generator(**batch)["img"]
            real_data = batch["img"]
            fakes = utils.denormalize_img(fakes)
            real_data = utils.denormalize_img(real_data)
        if not has_fid_cache:
            real_data = clip_preprocess(real_data)
            fid_features_real[sidx:eidx] = clip_model(real_data)
        fakes = clip_preprocess(fakes)
        fid_features_fake[sidx:eidx] = clip_model(fakes)
    fid_features_fake = fid_features_fake[:n_samples_seen]
    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
    if has_fid_cache:
        if tops.rank() == 0:
            with open(fid_cache_path, "rb") as fp:
                fid_stat_real = pickle.load(fp)
    else:
        fid_features_real = fid_features_real[:n_samples_seen]
        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
        assert fid_features_real.shape == fid_features_fake.shape
        if tops.rank() == 0:
            fid_stat_real = fid_features_to_statistics(fid_features_real)
            cache_directory.mkdir(exist_ok=True, parents=True)
            with open(fid_cache_path, "wb") as fp:
                pickle.dump(fid_stat_real, fp)

    if tops.rank() == 0:
        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
        return dict(fid_clip=fid_)
    return dict(fid_clip=-1)
dp2/metrics/lpips.py
ADDED
@@ -0,0 +1,76 @@
+import torch
+import tops
+import sys
+from contextlib import redirect_stdout
+from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
+
+class SampleSimilarityLPIPS(torch.nn.Module):
+    SUPPORTED_DTYPES = {
+        'uint8': torch.uint8,
+        'float32': torch.float32,
+    }
+
+    def __init__(self):
+
+        super().__init__()
+        self.chns = [64, 128, 256, 512, 512]
+        self.L = len(self.chns)
+        self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
+        self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
+        self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
+        self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
+        self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
+        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+        with redirect_stdout(sys.stderr):
+            fp = tops.download_file(URL_VGG16_LPIPS)
+            state_dict = torch.load(fp, map_location="cpu")
+        self.load_state_dict(state_dict)
+        self.net = VGG16features()
+        self.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+        mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
+        inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
+        self.register_buffer("mean", mean_rescaled)
+        self.register_buffer("std", inv_std_rescaled)
+
+    def normalize(self, x):
+        # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
+        x = (x.float() - self.mean) * self.std
+        return x
+
+    @staticmethod
+    def resize(x, size):
+        if x.shape[-1] > size and x.shape[-2] > size:
+            x = torch.nn.functional.interpolate(x, (size, size), mode='area')
+        else:
+            x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
+        return x
+
+    def lpips_from_feats(self, feats0, feats1):
+        diffs = {}
+        for kk in range(self.L):
+            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+        res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
+        val = sum(res)
+        return val
+
+    def get_feats(self, x):
+        assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
+        if x.shape[-2] < 16:  # Resize images < 16x16
+            f = 16 / x.shape[-2]
+            size = tuple([int(f*_) for _ in x.shape[-2:]])
+            x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
+        in0_input = self.normalize(x)
+        outs0 = self.net.forward(in0_input)
+
+        feats = {}
+        for kk in range(self.L):
+            feats[kk] = normalize_tensor(outs0[kk])
+        return feats
+
+    def forward(self, in0, in1):
+        feats0 = self.get_feats(in0)
+        feats1 = self.get_feats(in1)
+        return self.lpips_from_feats(feats0, feats1), feats0, feats1
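
Note (illustrative, not part of the diff): SampleSimilarityLPIPS downloads the pretrained LPIPS linear weights on first construction, normalizes internally, and its forward pass returns the per-image distance together with both feature sets. A minimal usage sketch, assuming the dp2 package and its tops/torch_fidelity dependencies are installed; the image tensors are hypothetical.

# Sketch: one LPIPS distance per image pair in a batch.
import torch
from dp2.metrics.lpips import SampleSimilarityLPIPS

model = SampleSimilarityLPIPS().eval()
# The module expects Bx3xHxW images scaled to [0, 255]; it rescales internally.
img_a = torch.rand(4, 3, 128, 128) * 255
img_b = torch.rand(4, 3, 128, 128) * 255
with torch.no_grad():
    dist, feats_a, feats_b = model(img_a, img_b)
print(dist.view(-1))  # shape (4,): LPIPS distance for each image pair
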
dp2/metrics/ppl.py
ADDED
@@ -0,0 +1,110 @@
+import numpy as np
+import torch
+import tops
+from dp2 import utils
+from torch_fidelity.helpers import get_kwarg, vassert
+from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
+from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
+
+
+def slerp(a, b, t):
+    a = a / a.norm(dim=-1, keepdim=True)
+    b = b / b.norm(dim=-1, keepdim=True)
+    d = (a * b).sum(dim=-1, keepdim=True)
+    p = t * torch.acos(d)
+    c = b - d * a
+    c = c / c.norm(dim=-1, keepdim=True)
+    d = a * torch.cos(p) + c * torch.sin(p)
+    d = d / d.norm(dim=-1, keepdim=True)
+    return d
+
+
+@torch.no_grad()
+def calculate_ppl(
+        dataloader,
+        generator,
+        latent_space=None,
+        data_len=None,
+        **kwargs) -> dict:
+    """
+    Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
+    """
+    if latent_space is None:
+        latent_space = generator.latent_space
+    assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
+    epsilon = PPL_DEFAULTS["ppl_epsilon"]
+    interp = PPL_DEFAULTS['ppl_z_interp_mode']
+    similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
+    sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
+    sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
+    discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
+    discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
+
+    vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
+    vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
+    vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
+    if discard_percentile_lower is not None and discard_percentile_higher is not None:
+        vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')
+
+    sample_similarity = create_sample_similarity(
+        similarity_name,
+        sample_similarity_resize=sample_similarity_resize,
+        sample_similarity_dtype=sample_similarity_dtype,
+        cuda=False,
+        **kwargs
+    )
+    sample_similarity = tops.to_cuda(sample_similarity)
+    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
+    distances = []
+    if data_len is None:
+        data_len = len(dataloader) * dataloader.batch_size
+    z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
+    z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
+    if latent_space == "Z":
+        z1 = batch_interp(z0, z1, epsilon, interp)
+
+    distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
+    print(distances.shape)
+    end = 0
+    n_samples = 0
+    for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")):
+        start = end
+        end = start + batch["img"].shape[0]
+        n_samples += batch["img"].shape[0]
+        batch_lat_e0 = tops.to_cuda(z0[start:end])
+        batch_lat_e1 = tops.to_cuda(z1[start:end])
+        if latent_space == "W":
+            w0 = generator.get_w(batch_lat_e0, update_emas=False)
+            w1 = generator.get_w(batch_lat_e1, update_emas=False)
+            w1 = w0.lerp(w1, epsilon)  # PPL end
+            rgb1 = generator(**batch, w=w0)["img"]
+            rgb2 = generator(**batch, w=w1)["img"]
+        else:
+            rgb1 = generator(**batch, z=batch_lat_e0)["img"]
+            rgb2 = generator(**batch, z=batch_lat_e1)["img"]
+        rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
+        rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
+
+        sim = sample_similarity(rgb1, rgb2)
+        dist_lat_e01 = sim / (epsilon ** 2)
+        distances[start:end] = dist_lat_e01.view(-1)
+    distances = distances[:n_samples]
+    distances = tops.all_gather_uneven(distances).cpu().numpy()
+    if tops.rank() != 0:
+        return {"ppl/mean": -1, "ppl/std": -1}
+    if tops.rank() == 0:
+        cond, lo, hi = None, None, None
+        if discard_percentile_lower is not None:
+            lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
+            cond = lo <= distances
+        if discard_percentile_higher is not None:
+            hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
+            cond = np.logical_and(cond, distances <= hi)
+        if cond is not None:
+            distances = np.extract(cond, distances)
+        return {
+            "ppl/mean": float(np.mean(distances)),
+            "ppl/std": float(np.std(distances)),
+        }
+    else:
+        return {"ppl/mean": -1, "ppl/std": -1}
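
Note (illustrative, not part of the diff): calculate_ppl perturbs each latent code by the small epsilon from PPL_DEFAULTS, measures the perceptual (LPIPS) distance between the two generated images, and divides by epsilon squared; slerp is the spherical-interpolation helper used for endpoints on the hypersphere. A small sketch of slerp and the epsilon scaling on placeholder tensors, assuming the dp2 package is importable:

# Sketch: slerp keeps results on the unit hypersphere; PPL scales LPIPS by 1/eps^2.
import torch
from dp2.metrics.ppl import slerp

a = torch.randn(4, 512)
b = torch.randn(4, 512)
mid = slerp(a, b, 0.5)          # spherical interpolation of unit-normalized latents
print(mid.norm(dim=-1))          # ~1.0 for every row

epsilon = 1e-4                   # placeholder; the file reads the real value from PPL_DEFAULTS
lpips_per_pair = torch.rand(8)   # hypothetical LPIPS(fake_at_z, fake_at_z_plus_eps)
ppl_per_sample = lpips_per_pair / epsilon ** 2
print(ppl_per_sample.mean(), ppl_per_sample.std())
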
dp2/metrics/torch_metrics.py
ADDED
@@ -0,0 +1,176 @@
+import pickle
+import numpy as np
+import torch
+import time
+from pathlib import Path
+from dp2 import utils
+import tops
+from .lpips import SampleSimilarityLPIPS
+from torch_fidelity.defaults import DEFAULTS as trf_defaults
+from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
+from torch_fidelity.utils import create_feature_extractor
+lpips_model = None
+fid_model = None
+
+@torch.no_grad()
+def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+    se = (images1 - images2) ** 2
+    se = se.view(images1.shape[0], -1).mean(dim=1)
+    return se
+
+@torch.no_grad()
+def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+    mse_ = mse(images1, images2)
+    psnr = 10 * torch.log10(1 / mse_)
+    return psnr
+
+@torch.no_grad()
+def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+    return _lpips_w_grad(images1, images2)
+
+
+def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
+    global lpips_model
+    if lpips_model is None:
+        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+
+    images1 = images1.mul(255)
+    images2 = images2.mul(255)
+    with torch.cuda.amp.autocast(tops.AMP()):
+        dists = lpips_model(images1, images2)[0].view(-1)
+    return dists
+
+
+
+
+@torch.no_grad()
+def compute_metrics_iteratively(
+        dataloader, generator,
+        cache_directory,
+        data_len=None,
+        truncation_value: float=None,
+        ) -> dict:
+    """
+    Args:
+        n_samples (int): Creates N samples from same image to calculate stats
+        dataset_percentage (float): The percentage of the dataset to compute metrics on.
+    """
+
+    global lpips_model, fid_model
+    if lpips_model is None:
+        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+    if fid_model is None:
+        fid_model = create_feature_extractor(
+            trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
+        fid_model = tops.to_cuda(fid_model)
+    cache_directory = Path(cache_directory)
+    start_time = time.time()
+    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
+    diversity_total = torch.zeros_like(lpips_total)
+    fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
+    has_fid_cache = fid_cache_path.is_file()
+    if data_len is None:
+        data_len = len(dataloader)*dataloader.batch_size
+    if not has_fid_cache:
+        fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
+    fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
+    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
+    eidx = 0
+    for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
+        sidx = eidx
+        eidx = sidx + batch["img"].shape[0]
+        n_samples_seen += batch["img"].shape[0]
+        with torch.cuda.amp.autocast(tops.AMP()):
+            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+            fakes1 = utils.denormalize_img(fakes1).mul(255)
+            fakes2 = utils.denormalize_img(fakes2).mul(255)
+            real_data = utils.denormalize_img(batch["img"]).mul(255)
+            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
+            fake2_lpips_feats = lpips_model.get_feats(fakes2)
+            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
+
+            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
+            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
+            if not has_fid_cache:
+                fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
+            fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
+    fid_features_fake = fid_features_fake[:n_samples_seen]
+    if has_fid_cache:
+        if tops.rank() == 0:
+            with open(fid_cache_path, "rb") as fp:
+                fid_stat_real = pickle.load(fp)
+    else:
+        fid_features_real = fid_features_real[:n_samples_seen]
+        fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
+        if tops.rank() == 0:
+            fid_stat_real = fid_features_to_statistics(fid_features_real)
+            cache_directory.mkdir(exist_ok=True, parents=True)
+            with open(fid_cache_path, "wb") as fp:
+                pickle.dump(fid_stat_real, fp)
+    fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
+    if tops.rank() == 0:
+        print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
+        fid_stat_fake = fid_features_to_statistics(fid_features_fake)
+        fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
+    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
+    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
+    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
+    lpips_total = lpips_total / n_samples_seen
+    diversity_total = diversity_total / n_samples_seen
+    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
+    if tops.rank() == 0:
+        to_return["fid"] = fid_
+    else:
+        to_return["fid"] = -1
+    to_return["validation_time_s"] = time.time() - start_time
+    return to_return
+
+
+@torch.no_grad()
+def compute_lpips(
+        dataloader, generator,
+        truncation_value: float=None,
+        data_len=None,
+        ) -> dict:
+    """
+    Args:
+        n_samples (int): Creates N samples from same image to calculate stats
+        dataset_percentage (float): The percentage of the dataset to compute metrics on.
+    """
+    global lpips_model, fid_model
+    if lpips_model is None:
+        lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
+    start_time = time.time()
+    lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
+    diversity_total = torch.zeros_like(lpips_total)
+    if data_len is None:
+        data_len = len(dataloader) * dataloader.batch_size
+    eidx = 0
+    n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
+    for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
+        sidx = eidx
+        eidx = sidx + batch["img"].shape[0]
+        n_samples_seen += batch["img"].shape[0]
+        with torch.cuda.amp.autocast(tops.AMP()):
+            fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+            fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
+            real_data = batch["img"]
+            fakes1 = utils.denormalize_img(fakes1).mul(255)
+            fakes2 = utils.denormalize_img(fakes2).mul(255)
+            real_data = utils.denormalize_img(real_data).mul(255)
+            lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
+            fake2_lpips_feats = lpips_model.get_feats(fakes2)
+            lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
+
+            lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
+            diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
+    tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
+    tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
+    tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
+    lpips_total = lpips_total / n_samples_seen
+    diversity_total = diversity_total / n_samples_seen
+    to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
+    to_return = {k: v.cpu().item() for k, v in to_return.items()}
+    to_return["validation_time_s"] = time.time() - start_time
+    return to_return
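
Note (illustrative, not part of the diff): the pixel-level helpers above operate per image and assume inputs scaled to [0, 1], so psnr uses a peak value of 1. A minimal sketch on placeholder tensors, assuming the dp2 package and its dependencies are installed:

# Sketch: per-image MSE and PSNR on hypothetical predictions and targets.
import torch
from dp2.metrics.torch_metrics import mse, psnr

pred = torch.rand(2, 3, 64, 64)    # placeholder images in [0, 1]
target = torch.rand(2, 3, 64, 64)
print(mse(pred, target))           # shape (2,): mean squared error per image
print(psnr(pred, target))          # 10 * log10(1 / MSE), in dB
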