haakohu committed on
Commit
548d634
·
1 Parent(s): e01e2cd
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. dp2/__init__.py +0 -0
  2. dp2/anonymizer/__init__.py +1 -0
  3. dp2/anonymizer/anonymizer.py +159 -0
  4. dp2/data/__init__.py +0 -0
  5. dp2/data/build.py +148 -0
  6. dp2/data/datasets/__init__.py +0 -0
  7. dp2/data/datasets/coco_cse.py +148 -0
  8. dp2/data/datasets/fdf.py +129 -0
  9. dp2/data/datasets/fdh.py +104 -0
  10. dp2/data/transforms/__init__.py +2 -0
  11. dp2/data/transforms/functional.py +61 -0
  12. dp2/data/transforms/stylegan2_transform.py +394 -0
  13. dp2/data/transforms/transforms.py +247 -0
  14. dp2/data/utils.py +102 -0
  15. dp2/detection/__init__.py +3 -0
  16. dp2/detection/base.py +45 -0
  17. dp2/detection/box_utils.py +104 -0
  18. dp2/detection/box_utils_fdf.py +203 -0
  19. dp2/detection/cse_mask_face_detector.py +116 -0
  20. dp2/detection/face_detector.py +62 -0
  21. dp2/detection/models/__init__.py +0 -0
  22. dp2/detection/models/cse.py +135 -0
  23. dp2/detection/models/keypoint_maskrcnn.py +111 -0
  24. dp2/detection/models/mask_rcnn.py +78 -0
  25. dp2/detection/person_detector.py +135 -0
  26. dp2/detection/structures.py +463 -0
  27. dp2/detection/utils.py +174 -0
  28. dp2/discriminator/__init__.py +1 -0
  29. dp2/discriminator/sg2_discriminator.py +76 -0
  30. dp2/gan_trainer.py +324 -0
  31. dp2/generator/__init__.py +0 -0
  32. dp2/generator/base.py +144 -0
  33. dp2/generator/dummy_generators.py +47 -0
  34. dp2/generator/imagen3_old.py +1210 -0
  35. dp2/generator/stylegan_unet.py +208 -0
  36. dp2/generator/utils.py +48 -0
  37. dp2/infer.py +72 -0
  38. dp2/layers/__init__.py +20 -0
  39. dp2/layers/sg2_layers.py +227 -0
  40. dp2/loss/__init__.py +1 -0
  41. dp2/loss/pl_regularization.py +48 -0
  42. dp2/loss/r1_regularization.py +31 -0
  43. dp2/loss/sg2_loss.py +94 -0
  44. dp2/loss/utils.py +25 -0
  45. dp2/metrics/__init__.py +3 -0
  46. dp2/metrics/fid.py +72 -0
  47. dp2/metrics/fid_clip.py +84 -0
  48. dp2/metrics/lpips.py +76 -0
  49. dp2/metrics/ppl.py +110 -0
  50. dp2/metrics/torch_metrics.py +176 -0
dp2/__init__.py ADDED
File without changes
dp2/anonymizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .anonymizer import Anonymizer
dp2/anonymizer/anonymizer.py ADDED
@@ -0,0 +1,159 @@
1
+ from pathlib import Path
2
+ from typing import Union, Optional
3
+ import numpy as np
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from motpy import Detection, MultiObjectTracker
8
+ from dp2.utils import load_config
9
+ from dp2.infer import build_trained_generator
10
+ from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection
11
+
12
+
13
+ def load_generator_from_cfg_path(cfg_path: Union[str, Path]):
14
+ cfg = load_config(cfg_path)
15
+ G = build_trained_generator(cfg)
16
+ tops.logger.log(f"Loaded generator from: {cfg_path}")
17
+ return G
18
+
19
+
20
+ def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs):
21
+ img = F.resize(img, imsize, antialias=True)
22
+ mask = (F.resize(mask, imsize, antialias=True) > 0.99).float()
23
+ maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float()
24
+
25
+ condition = img * mask
26
+ return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition)
27
+
28
+
29
+ class Anonymizer:
30
+
31
+ def __init__(
32
+ self,
33
+ detector,
34
+ load_cache: bool,
35
+ person_G_cfg: Optional[Union[str, Path]] = None,
36
+ cse_person_G_cfg: Optional[Union[str, Path]] = None,
37
+ face_G_cfg: Optional[Union[str, Path]] = None,
38
+ car_G_cfg: Optional[Union[str, Path]] = None,
39
+ ) -> None:
40
+ self.detector = detector
41
+ self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]}
42
+ self.load_cache = load_cache
43
+ if cse_person_G_cfg is not None:
44
+ self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
45
+ if person_G_cfg is not None:
46
+ self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
47
+ if face_G_cfg is not None:
48
+ self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
49
+ if car_G_cfg is not None:
50
+ self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)
51
+
52
+ def initialize_tracker(self, fps: float):
53
+ self.tracker = MultiObjectTracker(dt=1/fps)
54
+ self.track_to_z_idx = dict()
55
+ self.cur_z_idx = 0
56
+
57
+ @torch.no_grad()
58
+ def anonymize_detections(self,
59
+ im, detection, truncation_value: float,
60
+ multi_modal_truncation: bool, amp: bool, z_idx,
61
+ all_styles=None,
62
+ update_identity=None,
63
+ ):
64
+ G = self.generators[type(detection)]
65
+ if G is None:
66
+ return im
67
+ C, H, W = im.shape
68
+ orig_im = im.clone()
69
+ if update_identity is None:
70
+ update_identity = [True for i in range(len(detection))]
71
+ for idx in range(len(detection)):
72
+ if not update_identity[idx]:
73
+ continue
74
+ batch = detection.get_crop(idx, im)
75
+ x0, y0, x1, y1 = batch.pop("boxes")[0]
76
+ batch = {k: tops.to_cuda(v) for k, v in batch.items()}
77
+ batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255])
78
+ batch["img"] = batch["img"].float()
79
+ batch["condition"] = batch["mask"] * batch["img"]
80
+ orig_shape = None
81
+ if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]:
82
+ orig_shape = batch["img"].shape[-2:]
83
+ batch = resize_batch(**batch, imsize=G.imsize)
84
+ with torch.cuda.amp.autocast(amp):
85
+ if all_styles is not None:
86
+ anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
87
+ elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"):
88
+ w_indices = None
89
+ if z_idx is not None:
90
+ w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
91
+ anonymized_im = G.multi_modal_truncate(
92
+ **batch, truncation_value=truncation_value,
93
+ w_indices=w_indices)["img"]
94
+ else:
95
+ z = None
96
+ if z_idx is not None:
97
+ state = np.random.RandomState(seed=z_idx[idx])
98
+ z = state.normal(size=(1, G.z_channels))
99
+ z = tops.to_cuda(torch.from_numpy(z))
100
+ anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
101
+ if orig_shape is not None:
102
+ anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True)
103
+ anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte()
104
+
105
+ # Resize and denormalize image
106
+ gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True)
107
+ mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0)
108
+ # Remove padding
109
+ pad = [max(-x0,0), max(-y0,0)]
110
+ pad = [*pad, max(x1-W,0), max(y1-H,0)]
111
+ remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]]
112
+ gim = remove_pad(gim)
113
+ mask = remove_pad(mask)
114
+ x0, y0 = max(x0, 0), max(y0, 0)
115
+ x1, y1 = min(x1, W), min(y1, H)
116
+ mask = mask.logical_not()[None].repeat(3, 1, 1)
117
+ im[:, y0:y1, x0:x1][mask] = gim[mask]
118
+
119
+ return im
120
+
121
+ def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor:
122
+ all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
123
+ for det in all_detections:
124
+ im = det.visualize(im)
125
+ return im
126
+
127
+ @torch.no_grad()
128
+ def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor:
129
+ assert im.dtype == torch.uint8
130
+ im = tops.to_cuda(im)
131
+ all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
132
+ if hasattr(self, "tracker") and track:
133
+ [_.pre_process() for _ in all_detections]
134
+ import numpy as np
135
+ boxes = np.concatenate([_.boxes for _ in all_detections])
136
+ boxes = [Detection(box) for box in boxes]
137
+ self.tracker.step(boxes)
138
+ track_ids = self.tracker.detections_matched_ids
139
+ z_idx = []
140
+ for track_id in track_ids:
141
+ if track_id not in self.track_to_z_idx:
142
+ self.track_to_z_idx[track_id] = self.cur_z_idx
143
+ self.cur_z_idx += 1
144
+ z_idx.append(self.track_to_z_idx[track_id])
145
+ z_idx = np.array(z_idx)
146
+ idx_offset = 0
147
+
148
+ for detection in all_detections:
149
+ zs = None
150
+ if hasattr(self, "tracker") and track:
151
+ zs = z_idx[idx_offset:idx_offset+len(detection)]
152
+ idx_offset += len(detection)
153
+ im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
154
+
155
+ return im.cpu()
156
+
157
+ def __call__(self, *args, **kwargs):
158
+ return self.forward(*args, **kwargs)
159
+
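The sketch below shows one way the Anonymizer defined above could be driven. It is a minimal sketch, not part of the commit: the detector instance and the face-generator config path are assumed inputs (placeholders), and only the constructor and forward() arguments visible in this file are relied on.

```python
# Sketch only: `detector` and the config path are assumed inputs, not part of this diff.
import torch
from dp2.anonymizer import Anonymizer


def anonymize_frame(detector, frame_u8: torch.Tensor, face_G_cfg: str) -> torch.Tensor:
    """Anonymize a single uint8 CHW frame; forward() moves it to CUDA internally."""
    anonymizer = Anonymizer(detector=detector, load_cache=False, face_G_cfg=face_G_cfg)
    return anonymizer(
        frame_u8,
        track=False,                   # skip the motpy tracker for a single frame
        truncation_value=0.5,          # forwarded to the generator via anonymize_detections
        multi_modal_truncation=False,
        amp=True,
    )
```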
dp2/data/__init__.py ADDED
File without changes
dp2/data/build.py ADDED
@@ -0,0 +1,148 @@
1
+ import io
2
+ import torch
3
+ import tops
4
+ from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder
5
+
6
+ def get_dataloader(
7
+ dataset, gpu_transform: torch.nn.Module,
8
+ num_workers,
9
+ batch_size,
10
+ infinite: bool,
11
+ drop_last: bool,
12
+ prefetch_factor: int,
13
+ shuffle,
14
+ channels_last=False
15
+ ):
16
+ sampler = None
17
+ dl_kwargs = dict(
18
+ pin_memory=True,
19
+ )
20
+ if infinite:
21
+ sampler = tops.InfiniteSampler(
22
+ dataset, rank=tops.rank(),
23
+ num_replicas=tops.world_size(),
24
+ shuffle=shuffle
25
+ )
26
+ elif tops.world_size() > 1:
27
+ sampler = torch.utils.data.DistributedSampler(
28
+ dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
29
+ dl_kwargs["drop_last"] = drop_last
30
+ else:
31
+ dl_kwargs["shuffle"] = shuffle
32
+ dl_kwargs["drop_last"] = drop_last
33
+ dataloader = torch.utils.data.DataLoader(
34
+ dataset, sampler=sampler, collate_fn=collate_fn,
35
+ batch_size=batch_size,
36
+ num_workers=num_workers, prefetch_factor=prefetch_factor,
37
+ **dl_kwargs
38
+ )
39
+ dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
40
+ return dataloader
41
+
42
+
43
+ def get_dataloader_places2_wds(
44
+ path,
45
+ batch_size: int,
46
+ num_workers: int,
47
+ transform: torch.nn.Module,
48
+ gpu_transform: torch.nn.Module,
49
+ infinite: bool,
50
+ shuffle: bool,
51
+ partial_batches: bool,
52
+ sample_shuffle=10_000,
53
+ tar_shuffle=100,
54
+ channels_last=False,
55
+ ):
56
+ import webdataset as wds
57
+ import os
58
+ os.environ["RANK"] = str(tops.rank())
59
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
60
+
61
+ if infinite:
62
+ pipeline = [wds.ResampledShards(str(path))]
63
+ else:
64
+ pipeline = [wds.SimpleShardList(str(path))]
65
+ if shuffle:
66
+ pipeline.append(wds.shuffle(tar_shuffle))
67
+ pipeline.extend([
68
+ wds.split_by_node,
69
+ wds.split_by_worker,
70
+ ])
71
+ if shuffle:
72
+ pipeline.append(wds.shuffle(sample_shuffle))
73
+
74
+ pipeline.extend([
75
+ wds.tarfile_to_samples(),
76
+ wds.decode("torchrgb8"),
77
+ wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
78
+ ])
79
+ if transform is not None:
80
+ pipeline.append(wds.map(transform))
81
+ pipeline.extend([
82
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
83
+ ])
84
+ pipeline = wds.DataPipeline(*pipeline)
85
+ if infinite:
86
+ pipeline = pipeline.repeat(nepochs=1000000)
87
+ loader = wds.WebLoader(
88
+ pipeline, batch_size=None, shuffle=False,
89
+ num_workers=get_num_workers(num_workers),
90
+ persistent_workers=True,
91
+ )
92
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
93
+ return loader
94
+
95
+
96
+
97
+
98
+ def get_dataloader_celebAHQ_wds(
99
+ path,
100
+ batch_size: int,
101
+ num_workers: int,
102
+ transform: torch.nn.Module,
103
+ gpu_transform: torch.nn.Module,
104
+ infinite: bool,
105
+ shuffle: bool,
106
+ partial_batches: bool,
107
+ sample_shuffle=10_000,
108
+ tar_shuffle=100,
109
+ channels_last=False,
110
+ ):
111
+ import webdataset as wds
112
+ import os
113
+ os.environ["RANK"] = str(tops.rank())
114
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
115
+
116
+ if infinite:
117
+ pipeline = [wds.ResampledShards(str(path))]
118
+ else:
119
+ pipeline = [wds.SimpleShardList(str(path))]
120
+ if shuffle:
121
+ pipeline.append(wds.shuffle(tar_shuffle))
122
+ pipeline.extend([
123
+ wds.split_by_node,
124
+ wds.split_by_worker,
125
+ ])
126
+ if shuffle:
127
+ pipeline.append(wds.shuffle(sample_shuffle))
128
+
129
+ pipeline.extend([
130
+ wds.tarfile_to_samples(),
131
+ wds.decode(wds.handle_extension(".png", png_decoder)),
132
+ wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
133
+ ])
134
+ if transform is not None:
135
+ pipeline.append(wds.map(transform))
136
+ pipeline.extend([
137
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
138
+ ])
139
+ pipeline = wds.DataPipeline(*pipeline)
140
+ if infinite:
141
+ pipeline = pipeline.repeat(nepochs=1000000)
142
+ loader = wds.WebLoader(
143
+ pipeline, batch_size=None, shuffle=False,
144
+ num_workers=get_num_workers(num_workers),
145
+ persistent_workers=True,
146
+ )
147
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
148
+ return loader
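For reference, a minimal sketch of calling get_dataloader as defined above. The dataset and the GPU-side transform are assumed objects whose samples are compatible with the collate_fn imported at the top of this file; the argument values are illustrative.

```python
# Sketch under assumptions: `dataset` yields dict samples compatible with collate_fn,
# and `gpu_transform` is any torch.nn.Module applied on-device by tops.DataPrefetcher.
import torch
from dp2.data.build import get_dataloader


def build_train_loader(dataset, gpu_transform: torch.nn.Module):
    return get_dataloader(
        dataset,
        gpu_transform=gpu_transform,
        num_workers=4,
        batch_size=32,
        infinite=True,       # wraps the dataset in tops.InfiniteSampler across ranks
        drop_last=True,
        prefetch_factor=2,
        shuffle=True,
        channels_last=False,
    )
```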
dp2/data/datasets/__init__.py ADDED
File without changes
dp2/data/datasets/coco_cse.py ADDED
@@ -0,0 +1,148 @@
1
+ import pickle
2
+ import torchvision
3
+ import torch
4
+ import pathlib
5
+ import numpy as np
6
+ from typing import Callable, Optional, Union
7
+ from torch.hub import get_dir as get_hub_dir
8
+
9
+
10
+ def cache_embed_stats(embed_map: torch.Tensor):
11
+ mean = embed_map.mean(dim=0, keepdim=True)
12
+ rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
13
+
14
+ cache = dict(mean=mean, rstd=rstd, embed_map=embed_map)
15
+ path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch")
16
+ path.parent.mkdir(exist_ok=True, parents=True)
17
+ torch.save(cache, path)
18
+
19
+
20
+ class CocoCSE(torch.utils.data.Dataset):
21
+
22
+ def __init__(self,
23
+ dirpath: Union[str, pathlib.Path],
24
+ transform: Optional[Callable],
25
+ normalize_E: bool,):
26
+ dirpath = pathlib.Path(dirpath)
27
+ self.dirpath = dirpath
28
+
29
+ self.transform = transform
30
+ assert self.dirpath.is_dir(),\
31
+ f"Did not find dataset at: {dirpath}"
32
+ self.image_paths, self.embedding_paths = self._load_impaths()
33
+ self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
34
+ mean = self.embed_map.mean(dim=0, keepdim=True)
35
+ rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
36
+ self.embed_map = (self.embed_map - mean) * rstd
37
+ cache_embed_stats(self.embed_map)
38
+
39
+ def _load_impaths(self):
40
+ image_dir = self.dirpath.joinpath("images")
41
+ image_paths = list(image_dir.glob("*.png"))
42
+ image_paths.sort()
43
+ embedding_paths = [
44
+ self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
45
+ ]
46
+ return image_paths, embedding_paths
47
+
48
+ def __len__(self):
49
+ return len(self.image_paths)
50
+
51
+ def __getitem__(self, idx):
52
+ im = torchvision.io.read_image(str(self.image_paths[idx]))
53
+ vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
54
+ vertices = torch.from_numpy(vertices.squeeze()).long()
55
+ mask = torch.from_numpy(mask.squeeze()).float()
56
+ border = torch.from_numpy(border.squeeze()).float()
57
+ E_mask = 1 - mask - border
58
+ batch = {
59
+ "img": im,
60
+ "vertices": vertices[None],
61
+ "mask": mask[None],
62
+ "embed_map": self.embed_map,
63
+ "border": border[None],
64
+ "E_mask": E_mask[None]
65
+ }
66
+ if self.transform is None:
67
+ return batch
68
+ return self.transform(batch)
69
+
70
+
71
+ class CocoCSEWithFace(CocoCSE):
72
+
73
+ def __init__(self,
74
+ dirpath: Union[str, pathlib.Path],
75
+ transform: Optional[Callable],
76
+ **kwargs):
77
+ super().__init__(dirpath, transform, **kwargs)
78
+ with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
79
+ self.face_boxes = pickle.load(fp)
80
+
81
+ def __getitem__(self, idx):
82
+ item = super().__getitem__(idx)
83
+ item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
84
+ return item
85
+
86
+
87
+ class CocoCSESemantic(torch.utils.data.Dataset):
88
+
89
+ def __init__(self,
90
+ dirpath: Union[str, pathlib.Path],
91
+ transform: Optional[Callable],
92
+ **kwargs):
93
+ dirpath = pathlib.Path(dirpath)
94
+ self.dirpath = dirpath
95
+
96
+ self.transform = transform
97
+ assert self.dirpath.is_dir(),\
98
+ f"Did not find dataset at: {dirpath}"
99
+ self.image_paths, self.embedding_paths = self._load_impaths()
100
+ self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy")))
101
+ self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy")))
102
+
103
+ def _load_impaths(self):
104
+ image_dir = self.dirpath.joinpath("images")
105
+ image_paths = list(image_dir.glob("*.png"))
106
+ image_paths.sort()
107
+ embedding_paths = [
108
+ self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths
109
+ ]
110
+ return image_paths, embedding_paths
111
+
112
+ def __len__(self):
113
+ return len(self.image_paths)
114
+
115
+ def __getitem__(self, idx):
116
+ im = torchvision.io.read_image(str(self.image_paths[idx]))
117
+ vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1)
118
+ vertices = torch.from_numpy(vertices.squeeze()).long()
119
+ mask = torch.from_numpy(mask.squeeze()).float()
120
+ border = torch.from_numpy(border.squeeze()).float()
121
+ E_mask = 1 - mask - border
122
+ batch = {
123
+ "img": im,
124
+ "vertices": vertices[None],
125
+ "mask": mask[None],
126
+ "border": border[None],
127
+ "vertx2cat": self.vertx2cat,
128
+ "embed_map": self.embed_map,
129
+ }
130
+ if self.transform is None:
131
+ return batch
132
+ return self.transform(batch)
133
+
134
+
135
+ class CocoCSESemanticWithFace(CocoCSESemantic):
136
+
137
+ def __init__(self,
138
+ dirpath: Union[str, pathlib.Path],
139
+ transform: Optional[Callable],
140
+ **kwargs):
141
+ super().__init__(dirpath, transform, **kwargs)
142
+ with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp:
143
+ self.face_boxes = pickle.load(fp)
144
+
145
+ def __getitem__(self, idx):
146
+ item = super().__getitem__(idx)
147
+ item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name]
148
+ return item
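A short sketch of reading one CocoCSE sample: the dataset root is a placeholder, and the keys correspond to the batch dict assembled in __getitem__ above.

```python
# Placeholder path; requires images/, embedding/ and embed_map.npy as loaded above.
from dp2.data.datasets.coco_cse import CocoCSE

dataset = CocoCSE("/path/to/coco_cse", transform=None, normalize_E=True)
sample = dataset[0]
print(sample["img"].shape)       # uint8 image, C x H x W
print(sample["vertices"].shape)  # CSE vertex indices, 1 x H x W
print(sample["mask"].shape)      # person mask, 1 x H x W
print(sample["E_mask"].shape)    # 1 - mask - border
```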
dp2/data/datasets/fdf.py ADDED
@@ -0,0 +1,129 @@
1
+ import pathlib
2
+ from typing import Tuple
3
+ import numpy as np
4
+ import torch
5
+ import pathlib
6
+ try:
7
+ import pyspng
8
+ PYSPNG_IMPORTED = True
9
+ except ImportError:
10
+ PYSPNG_IMPORTED = False
11
+ print("Could not load pyspng. Defaulting to pillow image backend.")
12
+ from PIL import Image
13
+ from tops import logger
14
+
15
+
16
+ class FDFDataset:
17
+
18
+ def __init__(self,
19
+ dirpath,
20
+ imsize: Tuple[int],
21
+ load_keypoints: bool,
22
+ transform):
23
+ dirpath = pathlib.Path(dirpath)
24
+ self.dirpath = dirpath
25
+ self.transform = transform
26
+ self.imsize = imsize[0]
27
+ self.load_keypoints = load_keypoints
28
+ assert self.dirpath.is_dir(),\
29
+ f"Did not find dataset at: {dirpath}"
30
+ image_dir = self.dirpath.joinpath("images", str(self.imsize))
31
+ self.image_paths = list(image_dir.glob("*.png"))
32
+ assert len(self.image_paths) > 0,\
33
+ f"Did not find images in: {image_dir}"
34
+ self.image_paths.sort(key=lambda x: int(x.stem))
35
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
36
+
37
+ self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch"))
38
+ assert len(self.image_paths) == len(self.bounding_boxes)
39
+ assert len(self.image_paths) == len(self.landmarks)
40
+ logger.log(
41
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}, imsize={imsize}")
42
+
43
+ def get_mask(self, idx):
44
+ mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool)
45
+ bounding_box = self.bounding_boxes[idx]
46
+ x0, y0, x1, y1 = bounding_box
47
+ mask[:, y0:y1, x0:x1] = 0
48
+ return mask
49
+
50
+ def __len__(self):
51
+ return len(self.image_paths)
52
+
53
+ def __getitem__(self, index):
54
+ impath = self.image_paths[index]
55
+ if PYSPNG_IMPORTED:
56
+ with open(impath, "rb") as fp:
57
+ im = pyspng.load(fp.read())
58
+ else:
59
+ with Image.open(impath) as fp:
60
+ im = np.array(fp)
61
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
62
+ masks = self.get_mask(index)
63
+ landmark = self.landmarks[index]
64
+ batch = {
65
+ "img": im,
66
+ "mask": masks,
67
+ }
68
+ if self.load_keypoints:
69
+ batch["keypoints"] = landmark
70
+ if self.transform is None:
71
+ return batch
72
+ return self.transform(batch)
73
+
74
+
75
+ class FDF256Dataset:
76
+
77
+ def __init__(self,
78
+ dirpath,
79
+ load_keypoints: bool,
80
+ transform):
81
+ dirpath = pathlib.Path(dirpath)
82
+ self.dirpath = dirpath
83
+ self.transform = transform
84
+ self.load_keypoints = load_keypoints
85
+ assert self.dirpath.is_dir(),\
86
+ f"Did not find dataset at: {dirpath}"
87
+ image_dir = self.dirpath.joinpath("images")
88
+ self.image_paths = list(image_dir.glob("*.png"))
89
+ assert len(self.image_paths) > 0,\
90
+ f"Did not find images in: {image_dir}"
91
+ self.image_paths.sort(key=lambda x: int(x.stem))
92
+ self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32)
93
+ self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy")))
94
+ assert len(self.image_paths) == len(self.bounding_boxes)
95
+ assert len(self.image_paths) == len(self.landmarks)
96
+ logger.log(
97
+ f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}")
98
+
99
+ def get_mask(self, idx):
100
+ mask = torch.ones((1, 256, 256), dtype=torch.bool)
101
+ bounding_box = self.bounding_boxes[idx]
102
+ x0, y0, x1, y1 = bounding_box
103
+ mask[:, y0:y1, x0:x1] = 0
104
+ return mask
105
+
106
+ def __len__(self):
107
+ return len(self.image_paths)
108
+
109
+ def __getitem__(self, index):
110
+ impath = self.image_paths[index]
111
+ if PYSPNG_IMPORTED:
112
+ with open(impath, "rb") as fp:
113
+ im = pyspng.load(fp.read())
114
+ else:
115
+ with Image.open(impath) as fp:
116
+ im = np.array(fp)
117
+ im = torch.from_numpy(np.rollaxis(im, -1, 0))
118
+ masks = self.get_mask(index)
119
+ landmark = self.landmarks[index]
120
+ batch = {
121
+ "img": im,
122
+ "mask": masks,
123
+ }
124
+ if self.load_keypoints:
125
+ batch["keypoints"] = landmark
126
+ if self.transform is None:
127
+ return batch
128
+ return self.transform(batch)
129
+
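A usage sketch for the FDF256 variant under the same assumptions: the directory path is a placeholder, and the keys follow __getitem__ above.

```python
# Placeholder dataset root; expects images/, landmarks.npy and bounding_box.npy.
from dp2.data.datasets.fdf import FDF256Dataset

dataset = FDF256Dataset("/path/to/fdf256", load_keypoints=True, transform=None)
batch = dataset[0]
# "mask" is False inside the face bounding box and True elsewhere (1 x 256 x 256 bool).
print(batch["img"].shape, batch["mask"].float().mean())
print(batch["keypoints"].shape)  # 7 landmarks with (x, y) coordinates
```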
dp2/data/datasets/fdh.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ import io
5
+ import webdataset as wds
6
+ import os
7
+ from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
8
+
9
+
10
+ def kp_decoder(x):
11
+ # Keypoints are between [0, 1] for webdataset
12
+ keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
13
+ keypoints[:, 0] /= 160
14
+ keypoints[:, 1] /= 288
15
+ check_outside = lambda x: (x < 0).logical_or(x > 1)
16
+ is_outside = check_outside(keypoints[:, 0]).logical_or(
17
+ check_outside(keypoints[:, 1])
18
+ )
19
+ keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
20
+ return keypoints
21
+
22
+
23
+ def vertices_decoder(x):
24
+ vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
25
+ return vertices.squeeze()[None]
26
+
27
+
28
+ def get_dataloader_fdh_wds(
29
+ path,
30
+ batch_size: int,
31
+ num_workers: int,
32
+ transform: torch.nn.Module,
33
+ gpu_transform: torch.nn.Module,
34
+ infinite: bool,
35
+ shuffle: bool,
36
+ partial_batches: bool,
37
+ load_embedding: bool,
38
+ sample_shuffle=10_000,
39
+ tar_shuffle=100,
40
+ read_condition=False,
41
+ channels_last=False,
42
+ ):
43
+ # Need to set this for split_by_node to work.
44
+ os.environ["RANK"] = str(tops.rank())
45
+ os.environ["WORLD_SIZE"] = str(tops.world_size())
46
+ if infinite:
47
+ pipeline = [wds.ResampledShards(str(path))]
48
+ else:
49
+ pipeline = [wds.SimpleShardList(str(path))]
50
+ if shuffle:
51
+ pipeline.append(wds.shuffle(tar_shuffle))
52
+ pipeline.extend([
53
+ wds.split_by_node,
54
+ wds.split_by_worker,
55
+ ])
56
+ if shuffle:
57
+ pipeline.append(wds.shuffle(sample_shuffle))
58
+
59
+ decoder = [
60
+ wds.handle_extension("image.png", png_decoder),
61
+ wds.handle_extension("mask.png", mask_decoder),
62
+ wds.handle_extension("maskrcnn_mask.png", mask_decoder),
63
+ wds.handle_extension("keypoints.npy", kp_decoder),
64
+ ]
65
+
66
+ rename_keys = [
67
+ ["img", "image.png"], ["mask", "mask.png"],
68
+ ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"]
69
+ ]
70
+ if load_embedding:
71
+ decoder.extend([
72
+ wds.handle_extension("vertices.npy", vertices_decoder),
73
+ wds.handle_extension("E_mask.png", mask_decoder)
74
+ ])
75
+ rename_keys.extend([
76
+ ["vertices", "vertices.npy"],
77
+ ["E_mask", "e_mask.png"]
78
+ ])
79
+
80
+ if read_condition:
81
+ decoder.append(
82
+ wds.handle_extension("condition.png", png_decoder)
83
+ )
84
+ rename_keys.append(["condition", "condition.png"])
85
+
86
+ pipeline.extend([
87
+ wds.tarfile_to_samples(),
88
+ wds.decode(*decoder),
89
+ wds.rename_keys(*rename_keys),
90
+ wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
91
+ ])
92
+ if transform is not None:
93
+ pipeline.append(wds.map(transform))
94
+ pipeline = wds.DataPipeline(*pipeline)
95
+ if infinite:
96
+ pipeline = pipeline.repeat(nepochs=1000000)
97
+
98
+ loader = wds.WebLoader(
99
+ pipeline, batch_size=None, shuffle=False,
100
+ num_workers=get_num_workers(num_workers),
101
+ persistent_workers=True,
102
+ )
103
+ loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
104
+ return loader
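A hedged sketch of building the FDH webdataset loader defined above: the shard pattern is a placeholder, and torch.nn.Identity stands in for the GPU-side transform that tops.DataPrefetcher applies to each batch.

```python
# Sketch only: shard path is a placeholder; Identity() is a stand-in transform.
import torch
from dp2.data.datasets.fdh import get_dataloader_fdh_wds

loader = get_dataloader_fdh_wds(
    path="/path/to/fdh/train-{000000..000099}.tar",  # placeholder shard pattern
    batch_size=16,
    num_workers=4,
    transform=None,
    gpu_transform=torch.nn.Identity(),
    infinite=True,
    shuffle=True,
    partial_batches=False,
    load_embedding=True,    # also decodes vertices.npy and E_mask.png
    read_condition=False,
)
batch = next(iter(loader))
print(batch["img"].shape, batch["keypoints"].shape)
```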
dp2/data/transforms/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize
2
+ from .stylegan2_transform import StyleGANAugmentPipe
dp2/data/transforms/functional.py ADDED
@@ -0,0 +1,61 @@
1
+ import torchvision.transforms.functional as F
2
+ import torch
3
+ import pickle
4
+ from tops import download_file, assert_shape
5
+ from typing import Dict
6
+ from functools import lru_cache
7
+
8
+ global symmetry_transform
9
+
10
+ @lru_cache(maxsize=1)
11
+ def get_symmetry_transform(symmetry_url):
12
+ file_name = download_file(symmetry_url)
13
+ with open(file_name, "rb") as fp:
14
+ symmetry = pickle.load(fp)
15
+ return torch.from_numpy(symmetry["vertex_transforms"]).long()
16
+
17
+
18
+ hflip_handled_cases = set([
19
+ "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition",
20
+ "embedding", "vertx2cat", "maskrcnn_mask", "__key__",
21
+ "img_hr", "condition_hr", "mask_hr"])
22
+
23
+ def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]:
24
+ container["img"] = F.hflip(container["img"])
25
+ if "condition" in container:
26
+ container["condition"] = F.hflip(container["condition"])
27
+ if "embedding" in container:
28
+ container["embedding"] = F.hflip(container["embedding"])
29
+ assert all([key in hflip_handled_cases for key in container]), container.keys()
30
+ if "keypoints" in container:
31
+ assert flip_map is not None
32
+ if container["keypoints"].ndim == 3:
33
+ keypoints = container["keypoints"][:, flip_map, :]
34
+ keypoints[:, :, 0] = 1 - keypoints[:, :, 0]
35
+ else:
36
+ assert_shape(container["keypoints"], (None, 3))
37
+ keypoints = container["keypoints"][flip_map, :]
38
+ keypoints[:, 0] = 1 - keypoints[:, 0]
39
+ container["keypoints"] = keypoints
40
+ if "mask" in container:
41
+ container["mask"] = F.hflip(container["mask"])
42
+ if "border" in container:
43
+ container["border"] = F.hflip(container["border"])
44
+ if "semantic_mask" in container:
45
+ container["semantic_mask"] = F.hflip(container["semantic_mask"])
46
+ if "vertices" in container:
47
+ symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl")
48
+ container["vertices"] = F.hflip(container["vertices"])
49
+ symmetry_transform_ = symmetry_transform.to(container["vertices"].device)
50
+ container["vertices"] = symmetry_transform_[container["vertices"].long()]
51
+ if "E_mask" in container:
52
+ container["E_mask"] = F.hflip(container["E_mask"])
53
+ if "maskrcnn_mask" in container:
54
+ container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"])
55
+ if "img_hr" in container:
56
+ container["img_hr"] = F.hflip(container["img_hr"])
57
+ if "condition_hr" in container:
58
+ container["condition_hr"] = F.hflip(container["condition_hr"])
59
+ if "mask_hr" in container:
60
+ container["mask_hr"] = F.hflip(container["mask_hr"])
61
+ return container
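A minimal sketch of hflip on a toy container: keypoints are in normalized [0, 1] coordinates, and flip_map is the left/right index permutation (the identity here, since there is only one keypoint).

```python
# Toy container only; every key used here is in hflip_handled_cases above.
import torch
from dp2.data.transforms.functional import hflip

container = {
    "img": torch.rand(3, 8, 8),
    "mask": torch.ones(1, 8, 8),
    "keypoints": torch.tensor([[0.25, 0.5, 1.0]]),  # (x, y, visibility)
}
flipped = hflip(container, flip_map=[0])
print(flipped["keypoints"][0, 0])  # x mirrored: 1 - 0.25 = 0.75
```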
dp2/data/transforms/stylegan2_transform.py ADDED
@@ -0,0 +1,394 @@
1
+ import numpy as np
2
+ import scipy.signal
3
+ import torch
4
+ try:
5
+ from sg3_torch_utils import misc
6
+ from sg3_torch_utils.ops import upfirdn2d
7
+ from sg3_torch_utils.ops import grid_sample_gradfix
8
+ from sg3_torch_utils.ops import conv2d_gradfix
9
+ except:
10
+ pass
11
+ #----------------------------------------------------------------------------
12
+ # Coefficients of various wavelet decomposition low-pass filters.
13
+
14
+ wavelets = {
15
+ 'haar': [0.7071067811865476, 0.7071067811865476],
16
+ 'db1': [0.7071067811865476, 0.7071067811865476],
17
+ 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
18
+ 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
19
+ 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
20
+ 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
21
+ 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
22
+ 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
23
+ 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
24
+ 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
25
+ 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
26
+ 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
27
+ 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
28
+ 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
29
+ 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
30
+ 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+ # Helpers for constructing transformation matrices.
35
+
36
+
37
+ def matrix(*rows, device=None):
38
+ assert all(len(row) == len(rows[0]) for row in rows)
39
+ elems = [x for row in rows for x in row]
40
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
41
+ if len(ref) == 0:
42
+ return misc.constant(np.asarray(rows), device=device)
43
+ assert device is None or device == ref[0].device
44
+ elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
45
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
46
+
47
+
48
+ def translate2d(tx, ty, **kwargs):
49
+ return matrix(
50
+ [1, 0, tx],
51
+ [0, 1, ty],
52
+ [0, 0, 1],
53
+ **kwargs)
54
+
55
+
56
+ def translate3d(tx, ty, tz, **kwargs):
57
+ return matrix(
58
+ [1, 0, 0, tx],
59
+ [0, 1, 0, ty],
60
+ [0, 0, 1, tz],
61
+ [0, 0, 0, 1],
62
+ **kwargs)
63
+
64
+
65
+ def scale2d(sx, sy, **kwargs):
66
+ return matrix(
67
+ [sx, 0, 0],
68
+ [0, sy, 0],
69
+ [0, 0, 1],
70
+ **kwargs)
71
+
72
+
73
+ def scale3d(sx, sy, sz, **kwargs):
74
+ return matrix(
75
+ [sx, 0, 0, 0],
76
+ [0, sy, 0, 0],
77
+ [0, 0, sz, 0],
78
+ [0, 0, 0, 1],
79
+ **kwargs)
80
+
81
+
82
+ def rotate2d(theta, **kwargs):
83
+ return matrix(
84
+ [torch.cos(theta), torch.sin(-theta), 0],
85
+ [torch.sin(theta), torch.cos(theta), 0],
86
+ [0, 0, 1],
87
+ **kwargs)
88
+
89
+
90
+ def rotate3d(v, theta, **kwargs):
91
+ vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
92
+ s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
93
+ return matrix(
94
+ [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
95
+ [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
96
+ [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
97
+ [0, 0, 0, 1],
98
+ **kwargs)
99
+
100
+
101
+ def translate2d_inv(tx, ty, **kwargs):
102
+ return translate2d(-tx, -ty, **kwargs)
103
+
104
+
105
+ def scale2d_inv(sx, sy, **kwargs):
106
+ return scale2d(1 / sx, 1 / sy, **kwargs)
107
+
108
+
109
+ def rotate2d_inv(theta, **kwargs):
110
+ return rotate2d(-theta, **kwargs)
111
+
112
+
113
+ class StyleGANAugmentPipe(torch.nn.Module):
114
+ def __init__(self,
115
+ rotate90=0, xint=0, xint_max=0.125,
116
+ scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
117
+ brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5,
118
+ hue_max=1, saturation_std=1,
119
+ imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
120
+ ):
121
+ super().__init__()
122
+ self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
123
+
124
+ # Pixel blitting.
125
+ self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
126
+ self.xint = float(xint) # Probability multiplier for integer translation.
127
+ self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
128
+
129
+ # General geometric transformations.
130
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
131
+ self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
132
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
133
+ self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
134
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
135
+ self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
136
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
137
+ self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
138
+
139
+ # Color transformations.
140
+ self.brightness = float(brightness) # Probability multiplier for brightness.
141
+ self.contrast = float(contrast) # Probability multiplier for contrast.
142
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
143
+ self.hue = float(hue) # Probability multiplier for hue rotation.
144
+ self.saturation = float(saturation) # Probability multiplier for saturation.
145
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
146
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
147
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
148
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
149
+
150
+ # Image-space filtering.
151
+ self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
152
+ self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
153
+ self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
154
+
155
+ # Setup orthogonal lowpass filter for geometric augmentations.
156
+ self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
157
+
158
+ # Construct filter bank for image-space filtering.
159
+ Hz_lo = np.asarray(wavelets['sym2']) # H(z)
160
+ Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
161
+ Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
162
+ Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
163
+ Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
164
+ for i in range(1, Hz_fbank.shape[0]):
165
+ Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
166
+ Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
167
+ Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
168
+ self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
169
+
170
+ def forward(self, batch, debug_percentile=None):
171
+ images = batch["img"]
172
+ batch["vertices"] = batch["vertices"].float()
173
+ assert isinstance(images, torch.Tensor) and images.ndim == 4
174
+ batch_size, num_channels, height, width = images.shape
175
+ device = images.device
176
+ self.Hz_fbank = self.Hz_fbank.to(device)
177
+ self.Hz_geom = self.Hz_geom.to(device)
178
+ if debug_percentile is not None:
179
+ debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
180
+
181
+ # -------------------------------------
182
+ # Select parameters for pixel blitting.
183
+ # -------------------------------------
184
+
185
+ # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
186
+ I_3 = torch.eye(3, device=device)
187
+ G_inv = I_3
188
+
189
+ # Apply integer translation with probability (xint * strength).
190
+ if self.xint > 0:
191
+ t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
192
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
193
+ if debug_percentile is not None:
194
+ t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
195
+ G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
196
+
197
+ # --------------------------------------------------------
198
+ # Select parameters for general geometric transformations.
199
+ # --------------------------------------------------------
200
+
201
+ # Apply isotropic scaling with probability (scale * strength).
202
+ if self.scale > 0:
203
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
204
+ s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
205
+ if debug_percentile is not None:
206
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
207
+ G_inv = G_inv @ scale2d_inv(s, s)
208
+
209
+ # Apply pre-rotation with probability p_rot.
210
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
211
+ if self.rotate > 0:
212
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
213
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
214
+ if debug_percentile is not None:
215
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
216
+ G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
217
+
218
+ # Apply anisotropic scaling with probability (aniso * strength).
219
+ if self.aniso > 0:
220
+ s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
221
+ s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
222
+ if debug_percentile is not None:
223
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
224
+ G_inv = G_inv @ scale2d_inv(s, 1 / s)
225
+
226
+ # Apply post-rotation with probability p_rot.
227
+ if self.rotate > 0:
228
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
229
+ theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
230
+ if debug_percentile is not None:
231
+ theta = torch.zeros_like(theta)
232
+ G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
233
+
234
+ # Apply fractional translation with probability (xfrac * strength).
235
+ if self.xfrac > 0:
236
+ t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
237
+ t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
238
+ if debug_percentile is not None:
239
+ t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
240
+ G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
241
+
242
+ # ----------------------------------
243
+ # Execute geometric transformations.
244
+ # ----------------------------------
245
+
246
+ # Execute if the transform is not identity.
247
+ if G_inv is not I_3:
248
+ # Calculate padding.
249
+ cx = (width - 1) / 2
250
+ cy = (height - 1) / 2
251
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
252
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
253
+ Hz_pad = self.Hz_geom.shape[0] // 4
254
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
255
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
256
+ margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
257
+ margin = margin.max(misc.constant([0, 0] * 2, device=device))
258
+ margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
259
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
260
+
261
+ # Pad image and adjust origin.
262
+ images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
263
+ batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0)
264
+ batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
265
+ batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0)
266
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
267
+
268
+ # Upsample.
269
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
270
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest")
271
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest")
272
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest")
273
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
274
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
275
+
276
+ # Execute transformation.
277
+ shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
278
+ G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
279
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
280
+ images = grid_sample_gradfix.grid_sample(images, grid)
281
+
282
+ batch["mask"] = torch.nn.functional.grid_sample(
283
+ input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
284
+ batch["E_mask"] = torch.nn.functional.grid_sample(
285
+ input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
286
+ batch["vertices"] = torch.nn.functional.grid_sample(
287
+ input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False)
288
+
289
+
290
+ # Downsample and crop.
291
+ images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
292
+ batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
293
+ batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
294
+ batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False)
295
+ # --------------------------------------------
296
+ # Select parameters for color transformations.
297
+ # --------------------------------------------
298
+
299
+ # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
300
+ I_4 = torch.eye(4, device=device)
301
+ C = I_4
302
+
303
+ # Apply brightness with probability (brightness * strength).
304
+ if self.brightness > 0:
305
+ b = torch.randn([batch_size], device=device) * self.brightness_std
306
+ b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
307
+ if debug_percentile is not None:
308
+ b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
309
+ C = translate3d(b, b, b) @ C
310
+
311
+ # Apply contrast with probability (contrast * strength).
312
+ if self.contrast > 0:
313
+ c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
314
+ c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
315
+ if debug_percentile is not None:
316
+ c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
317
+ C = scale3d(c, c, c) @ C
318
+
319
+ # Apply luma flip with probability (lumaflip * strength).
320
+ v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
321
+
322
+ # Apply hue rotation with probability (hue * strength).
323
+ if self.hue > 0 and num_channels > 1:
324
+ theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
325
+ theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
326
+ if debug_percentile is not None:
327
+ theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
328
+ C = rotate3d(v, theta) @ C # Rotate around v.
329
+
330
+ # Apply saturation with probability (saturation * strength).
331
+ if self.saturation > 0 and num_channels > 1:
332
+ s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
333
+ s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
334
+ if debug_percentile is not None:
335
+ s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
336
+ C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
337
+
338
+ # ------------------------------
339
+ # Execute color transformations.
340
+ # ------------------------------
341
+
342
+ # Execute if the transform is not identity.
343
+ if C is not I_4:
344
+ images = images.reshape([batch_size, num_channels, height * width])
345
+ if num_channels == 3:
346
+ images = C[:, :3, :3] @ images + C[:, :3, 3:]
347
+ elif num_channels == 1:
348
+ C = C[:, :3, :].mean(dim=1, keepdims=True)
349
+ images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
350
+ else:
351
+ raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
352
+ images = images.reshape([batch_size, num_channels, height, width])
353
+
354
+ # ----------------------
355
+ # Image-space filtering.
356
+ # ----------------------
357
+
358
+ if self.imgfilter > 0:
359
+ num_bands = self.Hz_fbank.shape[0]
360
+ assert len(self.imgfilter_bands) == num_bands
361
+ expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
362
+
363
+ # Apply amplification for each band with probability (imgfilter * strength * band_strength).
364
+ g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
365
+ for i, band_strength in enumerate(self.imgfilter_bands):
366
+ t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
367
+ t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
368
+ if debug_percentile is not None:
369
+ t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
370
+ t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
371
+ t[:, i] = t_i # Replace i'th element.
372
+ t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
373
+ g = g * t # Accumulate into global gain.
374
+
375
+ # Construct combined amplification filter.
376
+ Hz_prime = g @ self.Hz_fbank # [batch, tap]
377
+ Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
378
+ Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
379
+
380
+ # Apply filter.
381
+ p = self.Hz_fbank.shape[1] // 2
382
+ images = images.reshape([1, batch_size * num_channels, height, width])
383
+ images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
384
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
385
+ images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
386
+ images = images.reshape([batch_size, num_channels, height, width])
387
+
388
+ # ------------------------
389
+ # Image-space corruptions.
390
+ # ------------------------
391
+ batch["img"] = images
392
+ batch["vertices"] = batch["vertices"].long()
393
+ batch["border"] = 1 - batch["E_mask"] - batch["mask"]
394
+ return batch
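To make the augmentation knobs concrete, a construction sketch with illustrative probability multipliers. Note that __init__ already needs the sg3_torch_utils ops imported at the top of this file (upfirdn2d builds the lowpass filter), and forward() expects a batch dict with "img", "mask", "E_mask" and "vertices" as used above.

```python
# Illustrative settings only; each value is a probability multiplier scaled by `p`.
from dp2.data.transforms.stylegan2_transform import StyleGANAugmentPipe

augment_pipe = StyleGANAugmentPipe(
    xint=1, scale=1, rotate=1, aniso=1, xfrac=1,     # geometric augmentations
    brightness=1, contrast=1, hue=0, saturation=1,   # color augmentations
    imgfilter=1,                                     # image-space filtering
)
augment_pipe.p.fill_(0.2)  # overall augmentation probability multiplier
```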
dp2/data/transforms/transforms.py ADDED
@@ -0,0 +1,247 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List
3
+ import torchvision
4
+ import torch
5
+ import tops
6
+ import torchvision.transforms.functional as F
7
+ from .functional import hflip
8
+
9
+
10
+ class RandomHorizontalFlip(torch.nn.Module):
11
+
12
+ def __init__(self, p: float, flip_map=None,**kwargs):
13
+ super().__init__()
14
+ self.flip_ratio = p
15
+ self.flip_map = flip_map
16
+ if self.flip_ratio is None:
17
+ self.flip_ratio = 0.5
18
+ assert 0 <= self.flip_ratio <= 1
19
+
20
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
21
+ if torch.rand(1) > self.flip_ratio:
22
+ return container
23
+ return hflip(container, self.flip_map)
24
+
25
+
26
+ class CenterCrop(torch.nn.Module):
27
+ """
28
+ Performs the transform on the image.
29
+ NOTE: Does not transform the mask to improve runtime.
30
+ """
31
+
32
+ def __init__(self, size: List[int]):
33
+ super().__init__()
34
+ self.size = tuple(size)
35
+
36
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
37
+ min_size = min(container["img"].shape[1], container["img"].shape[2])
38
+ if min_size < self.size[0]:
39
+ container["img"] = F.center_crop(container["img"], min_size)
40
+ container["img"] = F.resize(container["img"], self.size)
41
+ return container
42
+ container["img"] = F.center_crop(container["img"], self.size)
43
+ return container
44
+
45
+
46
+ class Resize(torch.nn.Module):
47
+ """
48
+ Performs the transform on the image.
49
+ NOTE: Does not transform the mask to improve runtime.
50
+ """
51
+
52
+ def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
53
+ super().__init__()
54
+ self.size = tuple(size)
55
+ self.interpolation = interpolation
56
+
57
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
58
+ container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
59
+ if "semantic_mask" in container:
60
+ container["semantic_mask"] = F.resize(
61
+ container["semantic_mask"], self.size, F.InterpolationMode.NEAREST)
62
+ if "embedding" in container:
63
+ container["embedding"] = F.resize(
64
+ container["embedding"], self.size, self.interpolation)
65
+ if "mask" in container:
66
+ container["mask"] = F.resize(
67
+ container["mask"], self.size, F.InterpolationMode.NEAREST)
68
+ if "E_mask" in container:
69
+ container["E_mask"] = F.resize(
70
+ container["E_mask"], self.size, F.InterpolationMode.NEAREST)
71
+ if "maskrcnn_mask" in container:
72
+ container["maskrcnn_mask"] = F.resize(
73
+ container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST)
74
+ if "vertices" in container:
75
+ container["vertices"] = F.resize(
76
+ container["vertices"], self.size, F.InterpolationMode.NEAREST)
77
+ return container
78
+
79
+ def __repr__(self):
80
+ repr = super().__repr__()
81
+ vars_ = dict(size=self.size, interpolation=self.interpolation)
82
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
83
+
84
+
85
+ class InsertHRImage(torch.nn.Module):
86
+ """
87
+ Resizes mask by maxpool and assumes condition is already created
88
+ """
89
+ def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
90
+ super().__init__()
91
+ self.size = tuple(size)
92
+ self.interpolation = interpolation
93
+
94
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
95
+ assert container["img"].dtype == torch.float32
96
+ container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
97
+ container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
98
+ mask = container["mask"] > 0
99
+ container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
100
+ container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"]
101
+ return container
102
+
103
+ def __repr__(self):
104
+ repr = super().__repr__()
105
+ vars_ = dict(size=self.size, interpolation=self.interpolation)
106
+ return repr + " "
107
+
108
+
109
+ class CopyHRImage(torch.nn.Module):
110
+ def __init__(self) -> None:
111
+ super().__init__()
112
+
113
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
114
+ container["img_hr"] = container["img"]
115
+ container["condition_hr"] = container["condition"]
116
+ container["mask_hr"] = container["mask"]
117
+ return container
118
+
119
+
120
+ class Resize2(torch.nn.Module):
121
+ """
122
+ Resizes mask by maxpool and assumes condition is already created
123
+ """
124
+ def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True):
125
+ super().__init__()
126
+ self.size = tuple(size)
127
+ self.interpolation = interpolation
128
+ self.downsample_condition = downsample_condition
129
+ self.mask_condition = mask_condition
130
+
131
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
132
+ # assert container["img"].dtype == torch.float32
133
+ container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
134
+ mask = container["mask"] > 0
135
+ container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float()
136
+
137
+ if self.downsample_condition:
138
+ container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
139
+ if self.mask_condition:
140
+ container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"]
141
+ return container
142
+
143
+ def __repr__(self):
144
+ repr = super().__repr__()
145
+ vars_ = dict(size=self.size, interpolation=self.interpolation)
146
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
147
+
148
+
149
+
150
+ class Normalize(torch.nn.Module):
151
+ """
152
+ Performs the transform on the image.
153
+ NOTE: Does not transform the mask to improve runtime.
154
+ """
155
+
156
+ def __init__(self, mean, std, inplace, keys=["img"]):
157
+ super().__init__()
158
+ self.mean = mean
159
+ self.std = std
160
+ self.inplace = inplace
161
+ self.keys = keys
162
+
163
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
164
+ for key in self.keys:
165
+ container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
166
+ return container
167
+
168
+ def __repr__(self):
169
+ repr = super().__repr__()
170
+ vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace)
171
+ return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()])
172
+
173
+
174
+ class ToFloat(torch.nn.Module):
175
+
176
+ def __init__(self, keys=["img"], norm=True) -> None:
177
+ super().__init__()
178
+ self.keys = keys
179
+ self.gain = 255 if norm else 1
180
+
181
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
182
+ for key in self.keys:
183
+ container[key] = container[key].float() / self.gain
184
+ return container
185
+
186
+
187
+ class RandomCrop(torchvision.transforms.RandomCrop):
188
+ """
189
+ Performs the transform on the image.
190
+ NOTE: Does not transform the mask to improve runtime.
191
+ """
192
+
193
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
194
+ container["img"] = super().forward(container["img"])
195
+ return container
196
+
197
+
198
+ class CreateCondition(torch.nn.Module):
199
+
200
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
201
+ if container["img"].dtype == torch.uint8:
202
+ container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127
203
+ return container
204
+ container["condition"] = container["img"] * container["mask"]
205
+ return container
206
+
207
+
208
+ class CreateEmbedding(torch.nn.Module):
209
+
210
+ def __init__(self, embed_path: Path, cuda=True) -> None:
211
+ super().__init__()
212
+ self.embed_map = torch.load(embed_path, map_location=torch.device("cpu"))
213
+ if cuda:
214
+ self.embed_map = tops.to_cuda(self.embed_map)
215
+
216
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
217
+ vertices = container["vertices"]
218
+ if vertices.ndim == 3:
219
+ embedding = self.embed_map[vertices.long()].squeeze(dim=0)
220
+ embedding = embedding.permute(2, 0, 1) * container["E_mask"]
221
+ pass
222
+ else:
223
+ assert vertices.ndim == 4
224
+ embedding = self.embed_map[vertices.long()].squeeze(dim=1)
225
+ embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"]
226
+ container["embedding"] = embedding
227
+ container["embed_map"] = self.embed_map.clone()
228
+ return container
229
+
230
+
231
+ class UpdateMask(torch.nn.Module):
232
+
233
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
234
+ container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float()
235
+ return container
236
+
237
+
238
+ class LoadClassEmbedding(torch.nn.Module):
239
+
240
+ def __init__(self, embedding_path: Path) -> None:
241
+ super().__init__()
242
+ self.embedding = torch.load(embedding_path, map_location="cpu")
243
+
244
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
245
+ key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1])
246
+ container["class_embedding"] = self.embedding[key].view(-1)
247
+ return container
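A hypothetical end-to-end use of the transforms above on a sample container dict. Keys and shapes follow the conventions in this file (uint8 CHW image, float mask in {0, 1}); the real pipelines are assembled from the repo's config files, so treat this as a sketch:

import torch
from dp2.data.transforms.transforms import ToFloat, Resize, CreateCondition, Normalize

pipeline = [
    ToFloat(keys=["img"]),                    # uint8 -> float in [0, 1]
    Resize(size=(288, 160)),                  # bilinear for img, nearest for mask
    CreateCondition(),                        # condition = img masked by mask
    Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True, keys=["img"]),
]
container = {
    "img": torch.randint(0, 256, (3, 512, 256), dtype=torch.uint8),
    "mask": torch.ones(1, 512, 256),
}
for t in pipeline:
    container = t(container)
print(container["img"].shape, container["condition"].shape)  # (3, 288, 160) for both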
dp2/data/utils.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ import multiprocessing
5
+ import io
6
+ from tops import logger
7
+ from torch.utils.data._utils.collate import default_collate
8
+
9
+ try:
10
+ import pyspng
11
+
12
+ PYSPNG_IMPORTED = True
13
+ except ImportError:
14
+ PYSPNG_IMPORTED = False
15
+ print("Could not load pyspng. Defaulting to pillow image backend.")
16
+ from PIL import Image
17
+
18
+
19
+ def get_coco_keypoints():
20
+ return [
21
+ "nose",
22
+ "left_eye",
23
+ "right_eye",
24
+ "left_ear",
25
+ "right_ear",
26
+ "left_shoulder",
27
+ "right_shoulder",
28
+ "left_elbow",
29
+ "right_elbow",
30
+ "left_wrist",
31
+ "right_wrist",
32
+ "left_hip",
33
+ "right_hip",
34
+ "left_knee",
35
+ "right_knee",
36
+ "left_ankle",
37
+ "right_ankle",
38
+ ]
39
+
40
+
41
+ def get_coco_flipmap():
42
+ keypoints = get_coco_keypoints()
43
+ keypoint_flip_map = {
44
+ "left_eye": "right_eye",
45
+ "left_ear": "right_ear",
46
+ "left_shoulder": "right_shoulder",
47
+ "left_elbow": "right_elbow",
48
+ "left_wrist": "right_wrist",
49
+ "left_hip": "right_hip",
50
+ "left_knee": "right_knee",
51
+ "left_ankle": "right_ankle",
52
+ }
53
+ for key, value in list(keypoint_flip_map.items()):
54
+ keypoint_flip_map[value] = key
55
+ keypoint_flip_map["nose"] = "nose"
56
+ keypoint_flip_map_idx = []
57
+ for source in keypoints:
58
+ keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source]))
59
+ return keypoint_flip_map_idx
60
+
61
+
62
+ def mask_decoder(x):
63
+ mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None]
64
+ mask = mask > 0 # This fixes a bug causing mask.float().max() == 255.
65
+ return mask
66
+
67
+
68
+ def png_decoder(x):
69
+ if PYSPNG_IMPORTED:
70
+ return torch.from_numpy(np.rollaxis(pyspng.load(x), 2))
71
+ with Image.open(io.BytesIO(x)) as im:
72
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
73
+ return im
74
+
75
+
76
+ def jpg_decoder(x):
77
+ with Image.open(io.BytesIO(x)) as im:
78
+ im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2))
79
+ return im
80
+
81
+
82
+ def get_num_workers(num_workers: int):
83
+ n_cpus = multiprocessing.cpu_count()
84
+ if num_workers > n_cpus:
85
+ logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}")
86
+ return n_cpus
87
+ return num_workers
88
+
89
+
90
+ def collate_fn(batch):
91
+ elem = batch[0]
92
+ ignore_keys = set(["embed_map", "vertx2cat"])
93
+ batch_ = {
94
+ key: default_collate([d[key] for d in batch])
95
+ for key in elem
96
+ if key not in ignore_keys
97
+ }
98
+ if "embed_map" in elem:
99
+ batch_["embed_map"] = elem["embed_map"]
100
+ if "vertx2cat" in elem:
101
+ batch_["vertx2cat"] = elem["vertx2cat"]
102
+ return batch_
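A hypothetical sketch of how collate_fn and get_num_workers plug into a DataLoader: per-sample tensors are stacked, while shared entries such as "embed_map" are passed through once per batch. The dummy dataset exists only for illustration:

import torch
from torch.utils.data import DataLoader, Dataset
from dp2.data.utils import collate_fn, get_num_workers

class DummyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "img": torch.zeros(3, 64, 64),
            "embed_map": torch.zeros(27554, 16),  # identical for every sample
        }

loader = DataLoader(
    DummyDataset(), batch_size=4,
    num_workers=get_num_workers(2), collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["img"].shape, batch["embed_map"].shape)  # (4, 3, 64, 64), (27554, 16)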
dp2/detection/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .cse_mask_face_detector import CSeMaskFaceDetector
2
+ from .person_detector import CSEPersonDetector
3
+ from .structures import PersonDetection, VehicleDetection, FaceDetection
dp2/detection/base.py ADDED
@@ -0,0 +1,45 @@
1
+ import pickle
2
+ import torch
3
+ import lzma
4
+ from pathlib import Path
5
+ from tops import logger
6
+
7
+
8
+ class BaseDetector:
9
+
10
+
11
+ def __init__(self, cache_directory: str) -> None:
12
+ if cache_directory is not None:
13
+ self.cache_directory = Path(cache_directory, str(self.__class__.__name__))
14
+ self.cache_directory.mkdir(exist_ok=True, parents=True)
15
+
16
+ def save_to_cache(self, detection, cache_path: Path, after_preprocess=True):
17
+ logger.log(f"Caching detection to: {cache_path}")
18
+ with lzma.open(cache_path, "wb") as fp:
19
+ torch.save(
20
+ [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp,
21
+ pickle_protocol=pickle.HIGHEST_PROTOCOL)
22
+
23
+ def load_from_cache(self, cache_path: Path):
24
+ logger.log(f"Loading detection from cache path: {cache_path}")
25
+ with lzma.open(cache_path, "rb") as fp:
26
+ state_dict = torch.load(fp)
27
+ return [
28
+ state["cls"].from_state_dict(state_dict=state) for state in state_dict
29
+ ]
30
+
31
+ def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool):
32
+ if cache_id is None:
33
+ return self.forward(im)
34
+ cache_path = self.cache_directory.joinpath(cache_id + ".torch")
35
+ if cache_path.is_file() and load_cache:
36
+ try:
37
+ return self.load_from_cache(cache_path)
38
+ except Exception as e:
39
+ logger.warn(f"The cache file was corrupted: {cache_path}")
40
+ exit()
41
+ detections = self.forward(im)
42
+ self.save_to_cache(detections, cache_path)
43
+ return detections
44
+
45
+
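A minimal sketch of the contract BaseDetector expects from subclasses: forward() returns detection objects exposing state_dict()/from_state_dict(), and forward_and_cache() handles the lzma-compressed caching. The detector and detection classes below are hypothetical stand-ins, not part of the repo:

import torch
from dp2.detection.base import BaseDetector

class DummyDetection:
    def __init__(self, boxes):
        self.boxes = boxes

    def state_dict(self, **kwargs):
        return dict(cls=self.__class__, boxes=self.boxes)

    @staticmethod
    def from_state_dict(state_dict, **kwargs):
        return DummyDetection(state_dict["boxes"])

class DummyDetector(BaseDetector):
    def forward(self, im: torch.Tensor):
        return [DummyDetection(torch.zeros(0, 4))]

detector = DummyDetector(cache_directory="detection_cache")
dets = detector.forward_and_cache(
    torch.zeros(3, 64, 64, dtype=torch.uint8), cache_id="frame_0000", load_cache=True)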
dp2/detection/box_utils.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+
3
+
4
+ def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio):
5
+ x0, y0, x1, y1 = [int(_) for _ in bbox]
6
+ h, w = y1 - y0, x1 - x0
7
+ cur_ratio = h / w
8
+
9
+ if cur_ratio == target_aspect_ratio:
10
+ return [x0, y0, x1, y1]
11
+ if cur_ratio < target_aspect_ratio:
12
+ target_height = int(w*target_aspect_ratio)
13
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
14
+ else:
15
+ target_width = int(h/target_aspect_ratio)
16
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
17
+ return x0, y0, x1, y1
18
+
19
+
20
+ def expand_axis(start, end, target_width, limit):
21
+ # Can return a bbox outside of limit
22
+ cur_width = end - start
23
+ start = start - (target_width-cur_width)//2
24
+ end = end + (target_width-cur_width)//2
25
+ if end - start != target_width:
26
+ end += 1
27
+ assert end - start == target_width
28
+ if start < 0 and end > limit:
29
+ return start, end
30
+ if start < 0 and end < limit:
31
+ to_shift = min(0 - start, limit - end)
32
+ start += to_shift
33
+ end += to_shift
34
+ if end > limit and start > 0:
35
+ to_shift = min(end - limit, start)
36
+ end -= to_shift
37
+ start -= to_shift
38
+ assert end - start == target_width
39
+ return start, end
40
+
41
+
42
+ def expand_box(bbox, imshape, mask, percentage_background: float):
43
+ assert isinstance(bbox[0], int)
44
+ assert 0 < percentage_background < 1
45
+ # Percentage in S
46
+ mask_pixels = mask.long().sum().cpu()
47
+ total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
48
+ percentage_mask = mask_pixels / total_pixels
49
+ if (1 - percentage_mask) > percentage_background:
50
+ return bbox
51
+ target_pixels = mask_pixels / (1 - percentage_background)
52
+ x0, y0, x1, y1 = bbox
53
+ H = y1 - y0
54
+ W = x1 - x0
55
+ p = np.sqrt(target_pixels/(H*W))
56
+ target_width = int(np.ceil(p * W))
57
+ target_height = int(np.ceil(p * H))
58
+ x0, x1 = expand_axis(x0, x1, target_width, imshape[1])
59
+ y0, y1 = expand_axis(y0, y1, target_height, imshape[0])
60
+ return [x0, y0, x1, y1]
61
+
62
+
63
+ def expand_axises_by_percentage(bbox_XYXY, imshape, percentage):
64
+ x0, y0, x1, y1 = bbox_XYXY
65
+ H = y1 - y0
66
+ W = x1 - x0
67
+ expansion = int(((H*W)**0.5) * percentage)
68
+ new_width = W + expansion
69
+ new_height = H + expansion
70
+ x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1])
71
+ y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0])
72
+ return [x0, y0, x1, y1]
73
+
74
+
75
+ def get_expanded_bbox(
76
+ bbox_XYXY,
77
+ imshape,
78
+ mask,
79
+ percentage_background: float,
80
+ axis_minimum_expansion: float,
81
+ target_aspect_ratio: float):
82
+ bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist()
83
+ # Expand each axis of the bounding box by a minimum percentage
84
+ bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion)
85
+ # Find the minimum bbox with the aspect ratio. Can be outside of imshape
86
+ bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio)
87
+ # Expands square box such that X% of the bbox is background
88
+ bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background)
89
+ assert isinstance(bbox_XYXY[0], (int, np.int64))
90
+ return bbox_XYXY
91
+
92
+
93
+ def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape):
94
+ def area_inside_ratio(bbox, imshape):
95
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
96
+ area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1]))
97
+ return area_inside / area
98
+ ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0])
99
+ area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
100
+ if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside:
101
+ return False
102
+ if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area:
103
+ return False
104
+ return True
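Illustrative use of the two helpers above: expand a detected box to the generator's aspect ratio while reserving roughly 30% background, then filter out boxes that are too small or lie mostly outside the image. The thresholds are placeholders, not the repo's configured values:

import torch
from dp2.detection.box_utils import get_expanded_bbox, include_box

imshape = (720, 1280)  # (H, W)
mask = torch.zeros(720, 1280, dtype=torch.bool)
mask[200:500, 600:750] = True
bbox_XYXY = torch.tensor([600, 200, 750, 500])

exp_box = get_expanded_bbox(
    bbox_XYXY, imshape, mask,
    percentage_background=0.3,
    axis_minimum_expansion=0.1,
    target_aspect_ratio=288 / 160)
keep = include_box(
    exp_box, minimum_area=32 * 32, aspect_ratio_range=(0.25, 4),
    min_bbox_ratio_inside=0.5, imshape=imshape)
print(exp_box, keep)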
dp2/detection/box_utils_fdf.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ The FDF dataset expands bound boxes differently from what is used for CSE.
3
+ """
4
+
5
+ import numpy as np
6
+
7
+
8
+ def quadratic_bounding_box(x0, y0, width, height, imshape):
9
+ # We assume that we can create a image that is quadratic without
10
+ # minimizing any of the sides
11
+ assert width <= min(imshape[:2])
12
+ assert height <= min(imshape[:2])
13
+ min_side = min(height, width)
14
+ if height != width:
15
+ side_diff = abs(height - width)
16
+ # Want to extend the shortest side
17
+ if min_side == height:
18
+ # Vertical side
19
+ height += side_diff
20
+ if height > imshape[0]:
21
+ # Take full frame, and shrink width
22
+ y0 = 0
23
+ height = imshape[0]
24
+
25
+ side_diff = abs(height - width)
26
+ width -= side_diff
27
+ x0 += side_diff // 2
28
+ else:
29
+ y0 -= side_diff // 2
30
+ y0 = max(0, y0)
31
+ else:
32
+ # Horizontal side
33
+ width += side_diff
34
+ if width > imshape[1]:
35
+ # Take full frame width, and shrink height
36
+ x0 = 0
37
+ width = imshape[1]
38
+
39
+ side_diff = abs(height - width)
40
+ height -= side_diff
41
+ y0 += side_diff // 2
42
+ else:
43
+ x0 -= side_diff // 2
44
+ x0 = max(0, x0)
45
+ # Check that bbox goes outside image
46
+ x1 = x0 + width
47
+ y1 = y0 + height
48
+ if imshape[1] < x1:
49
+ diff = x1 - imshape[1]
50
+ x0 -= diff
51
+ if imshape[0] < y1:
52
+ diff = y1 - imshape[0]
53
+ y0 -= diff
54
+ assert x0 >= 0, "Bounding box outside image."
55
+ assert y0 >= 0, "Bounding box outside image."
56
+ assert x0 + width <= imshape[1], "Bounding box outside image."
57
+ assert y0 + height <= imshape[0], "Bounding box outside image."
58
+ return x0, y0, width, height
59
+
60
+
61
+ def expand_bounding_box(bbox, percentage, imshape):
62
+ orig_bbox = bbox.copy()
63
+ x0, y0, x1, y1 = bbox
64
+ width = x1 - x0
65
+ height = y1 - y0
66
+ x0, y0, width, height = quadratic_bounding_box(
67
+ x0, y0, width, height, imshape)
68
+ expanding_factor = int(max(height, width) * percentage)
69
+
70
+ possible_max_expansion = [(imshape[0] - width) // 2,
71
+ (imshape[1] - height) // 2,
72
+ expanding_factor]
73
+
74
+ expanding_factor = min(possible_max_expansion)
75
+ # Expand height
76
+
77
+ if expanding_factor > 0:
78
+
79
+ y0 = y0 - expanding_factor
80
+ y0 = max(0, y0)
81
+
82
+ height += expanding_factor * 2
83
+ if height > imshape[0]:
84
+ y0 -= (imshape[0] - height)
85
+ height = imshape[0]
86
+
87
+ if height + y0 > imshape[0]:
88
+ y0 -= (height + y0 - imshape[0])
89
+
90
+ # Expand width
91
+ x0 = x0 - expanding_factor
92
+ x0 = max(0, x0)
93
+
94
+ width += expanding_factor * 2
95
+ if width > imshape[1]:
96
+ x0 -= (imshape[1] - width)
97
+ width = imshape[1]
98
+
99
+ if width + x0 > imshape[1]:
100
+ x0 -= (width + x0 - imshape[1])
101
+ y1 = y0 + height
102
+ x1 = x0 + width
103
+ assert y0 >= 0, "Y0 is minus"
104
+ assert height <= imshape[0], "Height is larger than image."
105
+ assert x0 + width <= imshape[1]
106
+ assert y0 + height <= imshape[0]
107
+ assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!"
108
+ assert x0 >= 0, "Y0 is minus"
109
+ assert width <= imshape[1], "Height is larger than image."
110
+ # Check that original bbox is within new
111
+ x0_o, y0_o, x1_o, y1_o = orig_bbox
112
+ assert x0 <= x0_o, f"New bbox is outside of original. O:{x0_o}, N: {x0}"
113
+ assert x1 >= x1_o, f"New bbox is outside of original. O:{x1_o}, N: {x1}"
114
+ assert y0 <= y0_o, f"New bbox is outside of original. O:{y0_o}, N: {y0}"
115
+ assert y1 >= y1_o, f"New bbox is outside of original. O:{y1_o}, N: {y1}"
116
+
117
+ x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]]
118
+ x1 = x0 + width
119
+ y1 = y0 + height
120
+ return np.array([x0, y0, x1, y1])
121
+
122
+
123
+ def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint):
124
+ keypoint = keypoint[:, :3] # only nose + eyes are relevant
125
+ kp_X = keypoint[0, :]
126
+ kp_Y = keypoint[1, :]
127
+ within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1)
128
+ within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1)
129
+ return within_X and within_Y
130
+
131
+
132
+ def expand_bbox_simple(bbox, percentage):
133
+ x0, y0, x1, y1 = bbox.astype(float)
134
+ width = x1 - x0
135
+ height = y1 - y0
136
+ x_c = int(x0) + width // 2
137
+ y_c = int(y0) + height // 2
138
+ avg_size = max(width, height)
139
+ new_width = avg_size * (1 + percentage)
140
+ x0 = x_c - new_width // 2
141
+ y0 = y_c - new_width // 2
142
+ x1 = x_c + new_width // 2
143
+ y1 = y_c + new_width // 2
144
+ return np.array([x0, y0, x1, y1]).astype(int)
145
+
146
+
147
+ def pad_image(im, bbox, pad_value):
148
+ x0, y0, x1, y1 = bbox
149
+ if x0 < 0:
150
+ pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]),
151
+ dtype=np.uint8) + pad_value
152
+ im = np.concatenate((pad_im, im), axis=1)
153
+ x1 += abs(x0)
154
+ x0 = 0
155
+ if y0 < 0:
156
+ pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]),
157
+ dtype=np.uint8) + pad_value
158
+ im = np.concatenate((pad_im, im), axis=0)
159
+ y1 += abs(y0)
160
+ y0 = 0
161
+ if x1 >= im.shape[1]:
162
+ pad_im = np.zeros(
163
+ (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]),
164
+ dtype=np.uint8) + pad_value
165
+ im = np.concatenate((im, pad_im), axis=1)
166
+ if y1 >= im.shape[0]:
167
+ pad_im = np.zeros(
168
+ (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]),
169
+ dtype=np.uint8) + pad_value
170
+ im = np.concatenate((im, pad_im), axis=0)
171
+ return im[y0:y1, x0:x1]
172
+
173
+
174
+ def clip_box(bbox, im):
175
+ bbox[0] = max(0, bbox[0])
176
+ bbox[1] = max(0, bbox[1])
177
+ bbox[2] = min(im.shape[1] - 1, bbox[2])
178
+ bbox[3] = min(im.shape[0] - 1, bbox[3])
179
+ return bbox
180
+
181
+
182
+ def cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True):
183
+ outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0]
184
+ if simple_expand or (outside_im and pad_im):
185
+ return pad_image(im, bbox, pad_value)
186
+ bbox = clip_box(bbox, im)
187
+ x0, y0, x1, y1 = bbox
188
+ return im[y0:y1, x0:x1]
189
+
190
+
191
+ def expand_bbox(
192
+ bbox_ltrb, imshape, simple_expand, default_to_simple=False,
193
+ expansion_factor=0.35):
194
+ assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox_ltrb.shape}"
195
+ bbox = bbox_ltrb.astype(float)
196
+ # FDF256 uses simple expand with ratio 0.4
197
+ if simple_expand:
198
+ return expand_bbox_simple(bbox, 0.4)
199
+ try:
200
+ return expand_bounding_box(bbox, expansion_factor, imshape)
201
+ except AssertionError:
202
+ return expand_bbox_simple(bbox, expansion_factor * 2)
203
+
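Illustrative sketch of the FDF-style expansion and cropping above: expand a tight face box with the simple 0.4-ratio expansion used for FDF256 and cut the (possibly padded) region out of an HWC numpy image. The values are made up:

import numpy as np
from dp2.detection.box_utils_fdf import expand_bbox, cut_face

im = np.zeros((480, 640, 3), dtype=np.uint8)   # HWC image
tight_box = np.array([300, 120, 360, 200])     # x0, y0, x1, y1
exp_box = expand_bbox(tight_box, im.shape, simple_expand=True)
face = cut_face(im, exp_box, simple_expand=True)
print(exp_box, face.shape)                     # square crop, e.g. (112, 112, 3)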
dp2/detection/cse_mask_face_detector.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ import lzma
3
+ import tops
4
+ from pathlib import Path
5
+ from dp2.detection.base import BaseDetector
6
+ from .utils import combine_cse_maskrcnn_dets
7
+ from face_detection import build_detector as build_face_detector
8
+ from .models.cse import CSEDetector
9
+ from .models.mask_rcnn import MaskRCNNDetector
10
+ from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
11
+ from tops import logger
12
+
13
+
14
+ def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
15
+ assert len(box1.shape) == 2
16
+ assert len(box2.shape) == 2
17
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
18
+ # This can be batched
19
+ for i, box in enumerate(box1):
20
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
21
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
22
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
23
+ box1_inside[i] = is_outside.logical_not().any()
24
+ return box1_inside
25
+
26
+
27
+ class CSeMaskFaceDetector(BaseDetector):
28
+
29
+ def __init__(
30
+ self,
31
+ mask_rcnn_cfg,
32
+ face_detector_cfg: dict,
33
+ cse_cfg: dict,
34
+ face_post_process_cfg: dict,
35
+ cse_post_process_cfg,
36
+ score_threshold: float,
37
+ **kwargs
38
+ ) -> None:
39
+ super().__init__(**kwargs)
40
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
41
+ if "confidence_threshold" not in face_detector_cfg:
42
+ face_detector_cfg["confidence_threshold"] = score_threshold
43
+ if "score_thres" not in cse_cfg:
44
+ cse_cfg["score_thres"] = score_threshold
45
+ self.cse_detector = CSEDetector(**cse_cfg)
46
+ self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
47
+ self.cse_post_process_cfg = cse_post_process_cfg
48
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
49
+ self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
50
+ self.face_post_process_cfg = face_post_process_cfg
51
+
52
+ def __call__(self, *args, **kwargs):
53
+ return self.forward(*args, **kwargs)
54
+
55
+ def _detect_faces(self, im: torch.Tensor):
56
+ H, W = im.shape[1:]
57
+ im = im.float() - self.face_mean
58
+ im = self.face_detector.resize(im[None], 1.0)
59
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
60
+ boxes_XYXY[:, [0, 2]] *= W
61
+ boxes_XYXY[:, [1, 3]] *= H
62
+ return boxes_XYXY.round().long()
63
+
64
+ def load_from_cache(self, cache_path: Path):
65
+ logger.log(f"Loading detection from cache path: {cache_path}",)
66
+ with lzma.open(cache_path, "rb") as fp:
67
+ state_dict = torch.load(fp, map_location="cpu")
68
+ kwargs = dict(
69
+ post_process_cfg=self.cse_post_process_cfg,
70
+ embed_map=self.cse_detector.embed_map,
71
+ **self.face_post_process_cfg
72
+ )
73
+ return [
74
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
75
+ for state in state_dict
76
+ ]
77
+
78
+ @torch.no_grad()
79
+ def forward(self, im: torch.Tensor):
80
+ maskrcnn_dets = self.mask_rcnn(im)
81
+ cse_dets = self.cse_detector(im)
82
+ embed_map = self.cse_detector.embed_map
83
+ print("Calling face detector.")
84
+ face_boxes = self._detect_faces(im).cpu()
85
+ maskrcnn_person = {
86
+ k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
87
+ }
88
+ maskrcnn_other = {
89
+ k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
90
+ }
91
+ maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
92
+ combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
93
+ maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
94
+
95
+ persons_with_cse = CSEPersonDetection(
96
+ combined_segmentation, cse_dets, **self.cse_post_process_cfg,
97
+ embed_map=embed_map,orig_imshape_CHW=im.shape
98
+ )
99
+ persons_with_cse.pre_process()
100
+ not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
101
+ persons_without_cse = PersonDetection(
102
+ maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
103
+ orig_imshape_CHW=im.shape
104
+ )
105
+ persons_without_cse.pre_process()
106
+
107
+ face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
108
+ box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
109
+ )
110
+ face_boxes = face_boxes[face_boxes_covered.logical_not()]
111
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
112
+
113
+ # Order matters. The anonymizer will anonymize FIFO.
114
+ # Later detections will overwrite.
115
+ all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
116
+ return all_detections
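A small check of the box1_inside_box2 helper defined above, which is how face detections get dropped when they are already covered by a (dilated) person box. Importing the module pulls in the detection dependencies, so this is only a sketch:

import torch
from dp2.detection.cse_mask_face_detector import box1_inside_box2

face_boxes = torch.tensor([[110, 110, 150, 160], [400, 40, 450, 90]])
person_boxes = torch.tensor([[100, 100, 200, 300]])
print(box1_inside_box2(face_boxes, person_boxes))  # tensor([ True, False])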
dp2/detection/face_detector.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+ import lzma
3
+ import tops
4
+ from pathlib import Path
5
+ from dp2.detection.base import BaseDetector
6
+ from face_detection import build_detector as build_face_detector
7
+ from .structures import FaceDetection
8
+ from tops import logger
9
+
10
+
11
+ def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
12
+ assert len(box1.shape) == 2
13
+ assert len(box2.shape) == 2
14
+ box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool)
15
+ # This can be batched
16
+ for i, box in enumerate(box1):
17
+ is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1)
18
+ is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1)
19
+ is_outside = is_outside_lefttop.logical_or(is_outside_rightbot)
20
+ box1_inside[i] = is_outside.logical_not().any()
21
+ return box1_inside
22
+
23
+
24
+ class FaceDetector(BaseDetector):
25
+
26
+ def __init__(
27
+ self,
28
+ face_detector_cfg: dict,
29
+ score_threshold: float,
30
+ face_post_process_cfg: dict,
31
+ **kwargs
32
+ ) -> None:
33
+ super().__init__(**kwargs)
34
+ self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold)
35
+ self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
36
+ self.face_post_process_cfg = face_post_process_cfg
37
+
38
+ def __call__(self, *args, **kwargs):
39
+ return self.forward(*args, **kwargs)
40
+
41
+ def _detect_faces(self, im: torch.Tensor):
42
+ H, W = im.shape[1:]
43
+ im = im.float() - self.face_mean
44
+ im = self.face_detector.resize(im[None], 1.0)
45
+ boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score
46
+ boxes_XYXY[:, [0, 2]] *= W
47
+ boxes_XYXY[:, [1, 3]] *= H
48
+ return boxes_XYXY.round().long().cpu()
49
+
50
+ @torch.no_grad()
51
+ def forward(self, im: torch.Tensor):
52
+ face_boxes = self._detect_faces(im)
53
+ face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
54
+ return [face_boxes]
55
+
56
+ def load_from_cache(self, cache_path: Path):
57
+ logger.log(f"Loading detection from cache path: {cache_path}")
58
+ with lzma.open(cache_path, "rb") as fp:
59
+ state_dict = torch.load(fp)
60
+ return [
61
+ state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict
62
+ ]
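Hypothetical construction of the FaceDetector above. The detector name and config values are assumptions (they depend on the face_detection package and the repo's configs), and in practice a CUDA device plus downloaded weights are needed:

import torch
import tops
from dp2.detection.face_detector import FaceDetector

detector = FaceDetector(
    face_detector_cfg=dict(name="DSFDDetector", clip_boxes=True),  # assumed name
    score_threshold=0.3,
    face_post_process_cfg=dict(target_imsize=(256, 256), fdf128_expand=False),
    cache_directory=None,
)
im = tops.to_cuda(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
detections = detector(im)  # list with a single FaceDetection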
dp2/detection/models/__init__.py ADDED
File without changes
dp2/detection/models/cse.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ from typing import List
3
+ import tops
4
+ from torchvision.transforms.functional import InterpolationMode, resize
5
+ from densepose.data.utils import get_class_to_mesh_name_mapping
6
+ from densepose import add_densepose_config
7
+ from densepose.structures import DensePoseEmbeddingPredictorOutput
8
+ from densepose.vis.extractor import DensePoseOutputsExtractor
9
+ from densepose.modeling import build_densepose_embedder
10
+ from detectron2.config import get_cfg
11
+ from detectron2.data.transforms import ResizeShortestEdge
12
+ from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
13
+ from detectron2.modeling import build_model
14
+
15
+
16
+ model_urls = {
17
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl",
18
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl",
19
+ }
20
+
21
+
22
+ def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape):
23
+ assert len(S.shape) == 3
24
+ H, W = imshape
25
+ N = len(boxes_XYXY)
26
+ segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device)
27
+ boxes_XYXY = boxes_XYXY.long()
28
+ for i in range(N):
29
+ x0, y0, x1, y1 = boxes_XYXY[i]
30
+ assert x0 >= 0 and y0 >= 0
31
+ assert x1 <= imshape[1]
32
+ assert y1 <= imshape[0]
33
+ h = y1 - y0
34
+ w = x1 - x0
35
+ segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0
36
+ return segmentation
37
+
38
+
39
+ class CSEDetector:
40
+
41
+ def __init__(
42
+ self,
43
+ cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
44
+ cfg_2_download: List[str] = [
45
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml",
46
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml",
47
+ "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"],
48
+ score_thres: float = 0.9,
49
+ nms_thresh: float = None,
50
+ ) -> None:
51
+ with tops.logger.capture_log_stdout():
52
+ cfg = get_cfg()
53
+ self.device = tops.get_device()
54
+ add_densepose_config(cfg)
55
+ cfg_path = tops.download_file(cfg_url)
56
+ for p in cfg_2_download:
57
+ tops.download_file(p)
58
+ with tops.logger.capture_log_stdout():
59
+ cfg.merge_from_file(cfg_path)
60
+ assert cfg_url in model_urls, cfg_url
61
+ model_path = tops.download_file(model_urls[cfg_url])
62
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
63
+ if nms_thresh is not None:
64
+ cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh
65
+ cfg.MODEL.WEIGHTS = str(model_path)
66
+ cfg.MODEL.DEVICE = str(self.device)
67
+ cfg.freeze()
68
+ with tops.logger.capture_log_stdout():
69
+ self.model = build_model(cfg)
70
+ self.model.eval()
71
+ DetectionCheckpointer(self.model).load(str(model_path))
72
+ self.input_format = cfg.INPUT.FORMAT
73
+ self.densepose_extractor = DensePoseOutputsExtractor()
74
+ self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
75
+
76
+ self.embedder = build_densepose_embedder(cfg)
77
+ self.mesh_vertex_embeddings = {
78
+ mesh_name: self.embedder(mesh_name).to(self.device)
79
+ for mesh_name in self.class_to_mesh_name.values()
80
+ if self.embedder.has_embeddings(mesh_name)
81
+ }
82
+ self.cfg = cfg
83
+ self.embed_map = self.mesh_vertex_embeddings["smpl_27554"]
84
+ tops.logger.log("CSEDetector built.")
85
+
86
+ def __call__(self, *args, **kwargs):
87
+ return self.forward(*args, **kwargs)
88
+
89
+ def resize_im(self, im):
90
+ H, W = im.shape[1:]
91
+ newH, newW = ResizeShortestEdge.get_output_shape(
92
+ H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
93
+ return resize(
94
+ im, (newH, newW), InterpolationMode.BILINEAR, antialias=True)
95
+
96
+ @torch.no_grad()
97
+ def forward(self, im):
98
+ assert im.dtype == torch.uint8
99
+ if self.input_format == "BGR":
100
+ im = im.flip(0)
101
+ H, W = im.shape[1:]
102
+ im = self.resize_im(im)
103
+ output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
104
+ scores = output.get("scores")
105
+ if len(scores) == 0:
106
+ return dict(
107
+ instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device),
108
+ instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device),
109
+ embed_map=self.mesh_vertex_embeddings["smpl_27554"],
110
+ bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device),
111
+ im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device),
112
+ scores=torch.empty((0), dtype=torch.float, device=im.device)
113
+ )
114
+ pred_densepose, boxes_xywh, classes = self.densepose_extractor(output)
115
+ assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose
116
+ S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes)
117
+ E = pred_densepose.embedding
118
+ mesh_name = self.class_to_mesh_name[classes[0]]
119
+ assert mesh_name == "smpl_27554"
120
+ x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)]
121
+ boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1)
122
+ boxes_XYXY = boxes_XYXY.round_().long()
123
+
124
+ non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not()
125
+ S = S[non_empty_boxes]
126
+ E = E[non_empty_boxes]
127
+ boxes_XYXY = boxes_XYXY[non_empty_boxes]
128
+ scores = scores[non_empty_boxes]
129
+ im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W])
130
+ return dict(
131
+ instance_segmentation=S, instance_embedding=E,
132
+ bbox_XYXY=boxes_XYXY,
133
+ im_segmentation=im_segmentation,
134
+ scores=scores.view(-1))
135
+
dp2/detection/models/keypoint_maskrcnn.py ADDED
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+ import torch
3
+ from detectron2.checkpoint import DetectionCheckpointer
4
+ from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads
5
+ from detectron2.data.transforms import ResizeShortestEdge
6
+ from detectron2.structures import Instances
7
+ from detectron2 import model_zoo
8
+ from detectron2.config import instantiate
9
+ from detectron2.config import LazyCall as L
10
+ from PIL import Image
11
+ import tops
12
+ import functools
13
+ from torchvision.transforms.functional import resize
14
+
15
+
16
+ def get_rn50_fpn_keypoint_rcnn(weight_path: str):
17
+ from detectron2.modeling.poolers import ROIPooler
18
+ from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead
19
+ from detectron2.layers import ShapeSpec
20
+ model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model
21
+ model.roi_heads.update(
22
+ num_classes=1,
23
+ keypoint_in_features=["p2", "p3", "p4", "p5"],
24
+ keypoint_pooler=L(ROIPooler)(
25
+ output_size=14,
26
+ scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32),
27
+ sampling_ratio=0,
28
+ pooler_type="ROIAlignV2",
29
+ ),
30
+ keypoint_head=L(KRCNNConvDeconvUpsampleHead)(
31
+ input_shape=ShapeSpec(channels=256, width=14, height=14),
32
+ num_keypoints=17,
33
+ conv_dims=[512] * 8,
34
+ loss_normalizer="visible",
35
+ ),
36
+ )
37
+
38
+ # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2.
39
+ # 1000 proposals per-image is found to hurt box AP.
40
+ # Therefore we increase it to 1500 per-image.
41
+ model.proposal_generator.post_nms_topk = (1500, 1000)
42
+
43
+ # Keypoint AP degrades (though box AP improves) when using plain L1 loss
44
+ model.roi_heads.box_predictor.smooth_l1_beta = 0.5
45
+ model = instantiate(model)
46
+
47
+ dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader
48
+ test_transform = instantiate(dataloader.test.mapper.augmentations)
49
+ DetectionCheckpointer(model).load(weight_path)
50
+ return model, test_transform
51
+
52
+
53
+ models = {
54
+ "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth")
55
+ }
56
+
57
+
58
+
59
+
60
+ class KeypointMaskRCNN:
61
+
62
+ def __init__(self, model_name: str, score_threshold: float) -> None:
63
+ assert model_name in models, f"Did not find {model_name} in models"
64
+ model, test_transform = models[model_name]()
65
+ self.model = model.eval().to(tops.get_device())
66
+ if isinstance(self.model.roi_heads, CascadeROIHeads):
67
+ for head in self.model.roi_heads.box_predictors:
68
+ assert hasattr(head, "test_score_thresh")
69
+ head.test_score_thresh = score_threshold
70
+ else:
71
+ assert isinstance(self.model.roi_heads, StandardROIHeads)
72
+ assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh")
73
+ self.model.roi_heads.box_predictor.test_score_thresh = score_threshold
74
+
75
+ self.test_transform = test_transform
76
+ assert len(self.test_transform) == 1
77
+ self.test_transform = self.test_transform[0]
78
+ assert isinstance(self.test_transform, ResizeShortestEdge)
79
+ assert self.test_transform.interp == Image.BILINEAR
80
+ self.image_format = self.model.input_format
81
+
82
+ def resize_im(self, im):
83
+ H, W = im.shape[-2:]
84
+ if self.test_transform.is_range:
85
+ size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1)
86
+ else:
87
+ size = np.random.choice(self.test_transform.short_edge_length)
88
+ newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size)
89
+ return resize(
90
+ im, (newH, newW), antialias=True)
91
+
92
+ def __call__(self, *args, **kwargs):
93
+ return self.forward(*args, **kwargs)
94
+
95
+ @torch.no_grad()
96
+ def forward(self, im: torch.Tensor) -> Instances:
97
+ assert im.ndim == 3
98
+ if self.image_format == "BGR":
99
+ im = im.flip(0)
100
+ H, W = im.shape[-2:]
101
+ im = self.resize_im(im)
102
+ im = im.float()
103
+ inputs = dict(image=im, height=H, width=W)
104
+ # instances contains
105
+ # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps'])
106
+ instances = self.model([inputs])[0]["instances"]
107
+ return dict(
108
+ scores=instances.get("scores").cpu(),
109
+ segmentation=instances.get("pred_masks").cpu(),
110
+ keypoints=instances.get("pred_keypoints").cpu()
111
+ )
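Hedged usage sketch for the keypoint Mask R-CNN wrapper above; detectron2 must be installed and the linked checkpoint is fetched on first use:

import torch
from dp2.detection.models.keypoint_maskrcnn import KeypointMaskRCNN

model = KeypointMaskRCNN("rn50_fpn_maskrcnn", score_threshold=0.5)
im = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # CHW, RGB
out = model(im)
print(out["segmentation"].shape, out["keypoints"].shape)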
dp2/detection/models/mask_rcnn.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import tops
3
+ from detectron2.modeling import build_model
4
+ from detectron2.checkpoint import DetectionCheckpointer
5
+ from detectron2.structures import Boxes
6
+ from detectron2.data import MetadataCatalog
7
+ from detectron2 import model_zoo
8
+ from typing import Dict
9
+ from detectron2.data.transforms import ResizeShortestEdge
10
+ from torchvision.transforms.functional import resize
11
+
12
+
13
+
14
+ model_urls = {
15
+ "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl",
16
+
17
+ }
18
+ class MaskRCNNDetector:
19
+
20
+ def __init__(
21
+ self,
22
+ cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml",
23
+ score_thres: float = 0.9,
24
+ class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"]
25
+ fp16_inference: bool = False
26
+ ) -> None:
27
+ cfg = model_zoo.get_config(cfg_name)
28
+ cfg.MODEL.DEVICE = str(tops.get_device())
29
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres
30
+ cfg.freeze()
31
+ self.cfg = cfg
32
+ with tops.logger.capture_log_stdout():
33
+ self.model = build_model(cfg)
34
+ DetectionCheckpointer(self.model).load(model_urls[cfg_name])
35
+ self.model.eval()
36
+ self.input_format = cfg.INPUT.FORMAT
37
+ self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
38
+ self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter])
39
+ self.person_class = self.class_names.index("person")
40
+ self.fp16_inference = fp16_inference
41
+ tops.logger.log("Mask R-CNN built.")
42
+
43
+ def __call__(self, *args, **kwargs):
44
+ return self.forward(*args, **kwargs)
45
+
46
+ def resize_im(self, im):
47
+ H, W = im.shape[1:]
48
+ newH, newW = ResizeShortestEdge.get_output_shape(
49
+ H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
50
+ return resize(
51
+ im, (newH, newW), antialias=True)
52
+
53
+ @torch.no_grad()
54
+ def forward(self, im: torch.Tensor):
55
+ if self.input_format == "BGR":
56
+ im = im.flip(0)
57
+ else:
58
+ assert self.input_format == "RGB"
59
+ H, W = im.shape[-2:]
60
+ im = self.resize_im(im)
61
+ with torch.cuda.amp.autocast(enabled=self.fp16_inference):
62
+ output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"]
63
+ scores = output.get("scores")
64
+ N = len(scores)
65
+ classes = output.get("pred_classes")
66
+ idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep]
67
+ classes = classes[idx2keep]
68
+ assert isinstance(output.get("pred_boxes"), Boxes)
69
+ segmentation = output.get("pred_masks")[idx2keep]
70
+ assert segmentation.dtype == torch.bool
71
+ is_person = classes == self.person_class
72
+ return {
73
+ "scores": output.get("scores")[idx2keep],
74
+ "segmentation": segmentation,
75
+ "classes": output.get("pred_classes")[idx2keep],
76
+ "is_person": is_person
77
+ }
78
+
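Hedged usage sketch for MaskRCNNDetector: the input is a CHW uint8 tensor and the output dict contains per-instance masks plus an is_person flag. Requires detectron2; the weights are downloaded on first use:

import torch
import tops
from dp2.detection.models.mask_rcnn import MaskRCNNDetector

detector = MaskRCNNDetector(score_thres=0.5)
im = tops.to_cuda(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
out = detector(im)
print(out["segmentation"].shape, out["is_person"], out["scores"])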
dp2/detection/person_detector.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ import lzma
3
+ from dp2.detection.base import BaseDetector
4
+ from .utils import combine_cse_maskrcnn_dets
5
+ from .models.cse import CSEDetector
6
+ from .models.mask_rcnn import MaskRCNNDetector
7
+ from .models.keypoint_maskrcnn import KeypointMaskRCNN
8
+ from .structures import CSEPersonDetection, PersonDetection
9
+ from pathlib import Path
10
+
11
+
12
+ class CSEPersonDetector(BaseDetector):
13
+ def __init__(
14
+ self,
15
+ score_threshold: float,
16
+ mask_rcnn_cfg: dict,
17
+ cse_cfg: dict,
18
+ cse_post_process_cfg: dict,
19
+ **kwargs
20
+ ) -> None:
21
+ super().__init__(**kwargs)
22
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
23
+ self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold)
24
+ self.post_process_cfg = cse_post_process_cfg
25
+ self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold")
26
+
27
+ def __call__(self, *args, **kwargs):
28
+ return self.forward(*args, **kwargs)
29
+
30
+ def load_from_cache(self, cache_path: Path):
31
+ with lzma.open(cache_path, "rb") as fp:
32
+ state_dict = torch.load(fp)
33
+ kwargs = dict(
34
+ post_process_cfg=self.post_process_cfg,
35
+ embed_map=self.cse_detector.embed_map,
36
+ )
37
+ return [
38
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
39
+ for state in state_dict
40
+ ]
41
+
42
+ @torch.no_grad()
43
+ def forward(self, im: torch.Tensor, cse_dets=None):
44
+ mask_dets = self.mask_rcnn(im)
45
+ if cse_dets is None:
46
+ cse_dets = self.cse_detector(im)
47
+ segmentation = mask_dets["segmentation"]
48
+ segmentation, cse_dets, _ = combine_cse_maskrcnn_dets(
49
+ segmentation, cse_dets, self.iou_combine_threshold
50
+ )
51
+ det = CSEPersonDetection(
52
+ segmentation=segmentation,
53
+ cse_dets=cse_dets,
54
+ embed_map=self.cse_detector.embed_map,
55
+ orig_imshape_CHW=im.shape,
56
+ **self.post_process_cfg
57
+ )
58
+ return [det]
59
+
60
+
61
+ class MaskRCNNPersonDetector(BaseDetector):
62
+ def __init__(
63
+ self,
64
+ score_threshold: float,
65
+ mask_rcnn_cfg: dict,
66
+ cse_post_process_cfg: dict,
67
+ **kwargs
68
+ ) -> None:
69
+ super().__init__(**kwargs)
70
+ self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
71
+ self.post_process_cfg = cse_post_process_cfg
72
+
73
+ def __call__(self, *args, **kwargs):
74
+ return self.forward(*args, **kwargs)
75
+
76
+ def load_from_cache(self, cache_path: Path):
77
+ with lzma.open(cache_path, "rb") as fp:
78
+ state_dict = torch.load(fp)
79
+ kwargs = dict(
80
+ post_process_cfg=self.post_process_cfg,
81
+ )
82
+ return [
83
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
84
+ for state in state_dict
85
+ ]
86
+
87
+ @torch.no_grad()
88
+ def forward(self, im: torch.Tensor):
89
+ mask_dets = self.mask_rcnn(im)
90
+ segmentation = mask_dets["segmentation"]
91
+ det = PersonDetection(
92
+ segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape
93
+ )
94
+ return [det]
95
+
96
+
97
+ class KeypointMaskRCNNPersonDetector(BaseDetector):
98
+ def __init__(
99
+ self,
100
+ score_threshold: float,
101
+ mask_rcnn_cfg: dict,
102
+ cse_post_process_cfg: dict,
103
+ **kwargs
104
+ ) -> None:
105
+ super().__init__(**kwargs)
106
+ self.mask_rcnn = KeypointMaskRCNN(
107
+ **mask_rcnn_cfg, score_threshold=score_threshold
108
+ )
109
+ self.post_process_cfg = cse_post_process_cfg
110
+
111
+ def __call__(self, *args, **kwargs):
112
+ return self.forward(*args, **kwargs)
113
+
114
+ def load_from_cache(self, cache_path: Path):
115
+ with lzma.open(cache_path, "rb") as fp:
116
+ state_dict = torch.load(fp)
117
+ kwargs = dict(
118
+ post_process_cfg=self.post_process_cfg,
119
+ )
120
+ return [
121
+ state["cls"].from_state_dict(**kwargs, state_dict=state)
122
+ for state in state_dict
123
+ ]
124
+
125
+ @torch.no_grad()
126
+ def forward(self, im: torch.Tensor):
127
+ mask_dets = self.mask_rcnn(im)
128
+ segmentation = mask_dets["segmentation"]
129
+ det = PersonDetection(
130
+ segmentation,
131
+ **self.post_process_cfg,
132
+ orig_imshape_CHW=im.shape,
133
+ keypoints=mask_dets["keypoints"]
134
+ )
135
+ return [det]
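Hedged sketch of wiring a CSEPersonDetector together. The post-processing values below are illustrative placeholders (the real values come from the repo's config files), and detectron2 + densepose and a CUDA device are needed:

import torch
import tops
from dp2.detection.person_detector import CSEPersonDetector

detector = CSEPersonDetector(
    score_threshold=0.3,
    mask_rcnn_cfg=dict(),
    cse_cfg=dict(),
    cse_post_process_cfg=dict(
        target_imsize=(288, 160),
        exp_bbox_cfg=dict(percentage_background=0.3, axis_minimum_expansion=0.1),
        exp_bbox_filter=dict(minimum_area=32 * 32, aspect_ratio_range=(0.1, 8), min_bbox_ratio_inside=0.5),
        dilation_percentage=0.02,
        normalize_embedding=True,
        iou_combine_threshold=0.4,
    ),
    cache_directory=None,
)
im = tops.to_cuda(torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8))
detections = detector(im)   # list with one CSEPersonDetection
detections[0].pre_process()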
dp2/detection/structures.py ADDED
@@ -0,0 +1,463 @@
1
+ import torch
2
+ import numpy as np
3
+ from dp2 import utils
4
+ from dp2.utils import vis_utils, crop_box
5
+ from .utils import (
6
+ cut_pad_resize, masks_to_boxes,
7
+ get_kernel, transform_embedding, initialize_cse_boxes
8
+ )
9
+ from .box_utils import get_expanded_bbox, include_box
10
+ import torchvision
11
+ import tops
12
+ from .box_utils_fdf import expand_bbox as expand_bbox_fdf
13
+
14
+
15
+ class VehicleDetection:
16
+
17
+ def __init__(self, segmentation: torch.BoolTensor) -> None:
18
+ self.segmentation = segmentation
19
+ self.boxes = masks_to_boxes(segmentation)
20
+ assert self.boxes.shape[1] == 4, self.boxes.shape
21
+ self.n_detections = self.segmentation.shape[0]
22
+ area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
23
+
24
+ sorted_idx = torch.argsort(area, descending=True)
25
+ self.segmentation = self.segmentation[sorted_idx]
26
+ self.boxes = self.boxes[sorted_idx].cpu()
27
+
28
+ def pre_process(self):
29
+ pass
30
+
31
+ def get_crop(self, idx: int, im):
32
+ assert idx < len(self)
33
+ box = self.boxes[idx]
34
+ im = crop_box(self.im, box)
35
+ mask = crop_box(self.segmentation[idx])
36
+ mask = mask == 0
37
+ return dict(img=im, mask=mask.float(), boxes=box)
38
+
39
+ def visualize(self, im):
40
+ if len(self) == 0:
41
+ return im
42
+ im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
43
+ return im
44
+
45
+ def __len__(self):
46
+ return self.n_detections
47
+
48
+ @staticmethod
49
+ def from_state_dict(state_dict, **kwargs):
50
+ numel = np.prod(state_dict["shape"])
51
+ arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
52
+ segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
53
+ return VehicleDetection(segmentation)
54
+
55
+ def state_dict(self, **kwargs):
56
+ segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
57
+ return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
58
+
59
+
60
+ class FaceDetection:
61
+
62
+ def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
63
+ self.boxes = boxes_ltrb.cpu()
64
+ assert self.boxes.shape[1] == 4, self.boxes.shape
65
+ self.target_imsize = tuple(target_imsize)
66
+ # Sory by area to paste in largest faces last
67
+ area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
68
+ idx = area.argsort(descending=False)
69
+ self.boxes = self.boxes[idx]
70
+ self.fdf128_expand = fdf128_expand
71
+
72
+ def visualize(self, im):
73
+ if len(self) == 0:
74
+ return im
75
+ orig_device = im.device
76
+ for box in self.boxes:
77
+ simple_expand = not self.fdf128_expand
78
+ e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
79
+ im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
80
+ im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
81
+
82
+ return im.to(device=orig_device)
83
+
84
+ def get_crop(self, idx: int, im):
85
+ assert idx < len(self)
86
+ box = self.boxes[idx].numpy()
87
+ expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], True)
88
+ im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
89
+ area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
90
+
91
+ # Find the square mask corresponding to box.
92
+ box_mask = box.copy().astype(float)
93
+ box_mask[[0, 2]] -= expanded_boxes[0]
94
+ box_mask[[1, 3]] -= expanded_boxes[1]
95
+
96
+ width = expanded_boxes[2] - expanded_boxes[0]
97
+ resize_factor = self.target_imsize[0] / width
98
+ box_mask = (box_mask * resize_factor).astype(int)
99
+ mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32)
100
+ crop_box(mask, box_mask).fill_(0)
101
+ return dict(
102
+ img=im[None], mask=mask[None],
103
+ boxes=torch.from_numpy(expanded_boxes).view(1, -1))
104
+
105
+ def __len__(self):
106
+ return len(self.boxes)
107
+
108
+ @staticmethod
109
+ def from_state_dict(state_dict, **kwargs):
110
+ return FaceDetection(state_dict["boxes"].cpu(), **kwargs)
111
+
112
+ def state_dict(self, **kwargs):
113
+ return dict(boxes=self.boxes, cls=self.__class__)
114
+
115
+ def pre_process(self):
116
+ pass
117
+
118
+
119
+ def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape):
120
+ """
121
+ Dilation happens after padding, which could place dilation in the padded area.
122
+ Remove this.
123
+ """
124
+ x0, y0, x1, y1 = exp_box
125
+ H, W = orig_imshape
126
+ # Padding in original image space
127
+ p_y0 = max(0, -y0)
128
+ p_y1 = max(y1 - H, 0)
129
+ p_x0 = max(0, -x0)
130
+ p_x1 = max(x1 - W, 0)
131
+ resize_ratio = mask.shape[-2] / (y1-y0)
132
+ p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]]
133
+ mask[..., :p_y0, :] = 0
134
+ mask[..., :p_x0] = 0
135
+ mask[..., mask.shape[-2] - p_y1:, :] = 0
136
+ mask[..., mask.shape[-1] - p_x1:] = 0
137
+
138
+
139
+ class CSEPersonDetection:
140
+
141
+ def __init__(self,
142
+ segmentation, cse_dets,
143
+ target_imsize,
144
+ exp_bbox_cfg, exp_bbox_filter,
145
+ dilation_percentage: float,
146
+ embed_map: torch.Tensor,
147
+ orig_imshape_CHW,
148
+ normalize_embedding: bool) -> None:
149
+ self.segmentation = segmentation
150
+ self.cse_dets = cse_dets
151
+ self.target_imsize = list(target_imsize)
152
+ self.pre_processed = False
153
+ self.exp_bbox_cfg = exp_bbox_cfg
154
+ self.exp_bbox_filter = exp_bbox_filter
155
+ self.dilation_percentage = dilation_percentage
156
+ self.embed_map = embed_map
157
+ self.normalize_embedding = normalize_embedding
158
+ if self.normalize_embedding:
159
+ embed_map_mean = self.embed_map.mean(dim=0, keepdim=True)
160
+ embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt()
161
+ self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd
162
+ self.orig_imshape_CHW = orig_imshape_CHW
163
+
164
+ @torch.no_grad()
165
+ def pre_process(self):
166
+ if self.pre_processed:
167
+ return
168
+ boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu()
169
+ expanded_boxes = []
170
+ included_boxes = []
171
+ for i in range(len(boxes)):
172
+ exp_box = get_expanded_bbox(
173
+ boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
174
+ target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
175
+ if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
176
+ continue
177
+ included_boxes.append(i)
178
+ expanded_boxes.append(exp_box)
179
+ expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
180
+ self.segmentation = self.segmentation[included_boxes]
181
+ self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()}
182
+
183
+ self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
184
+ area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
185
+ for i, box in enumerate(expanded_boxes):
186
+ self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
187
+
188
+ dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
189
+ self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
190
+ self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
191
+ [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
192
+ self.boxes = expanded_boxes.cpu()
193
+ self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
194
+
195
+ self.pre_processed = True
196
+ self.n_detections = len(self.boxes)
197
+ self.mask = self.mask.logical_not()
198
+
199
+ E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool)
200
+ self.vertices = torch.zeros_like(E_mask, dtype=torch.long)
201
+ for i in range(self.n_detections):
202
+ E_, E_mask[i] = transform_embedding(
203
+ self.cse_dets["instance_embedding"][i],
204
+ self.cse_dets["instance_segmentation"][i],
205
+ self.boxes[i],
206
+ self.cse_dets["bbox_XYXY"][i].cpu(),
207
+ self.target_imsize
208
+ )
209
+ self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None]
210
+ self.E_mask = E_mask
211
+
212
+ sorted_idx = torch.argsort(area, descending=False)
213
+ self.mask = self.mask[sorted_idx]
214
+ self.boxes = self.boxes[sorted_idx.cpu()]
215
+ self.vertices = self.vertices[sorted_idx]
216
+ self.E_mask = self.E_mask[sorted_idx]
217
+ self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
218
+
219
+ def get_crop(self, idx: int, im):
220
+ self.pre_process()
221
+ assert idx < len(self)
222
+ box = self.boxes[idx]
223
+ mask = self.mask[idx]
224
+ im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
225
+
226
+ vertices_ = self.vertices[idx]
227
+ E_mask_ = self.E_mask[idx].float()
228
+ if self.normalize_embedding:
229
+ embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
230
+ else:
231
+ embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_
232
+
233
+ return dict(
234
+ img=im,
235
+ mask=mask.float()[None],
236
+ boxes=box.reshape(1, -1),
237
+ E_mask=E_mask_[None],
238
+ vertices=vertices_[None],
239
+ embed_map=self.embed_map,
240
+ embedding=embedding[None],
241
+ maskrcnn_mask=self.maskrcnn_mask[idx].float()[None]
242
+ )
243
+
244
+ def __len__(self):
245
+ self.pre_process()
246
+ return self.n_detections
247
+
248
+ def state_dict(self, after_preprocess=False):
249
+ """
250
+ The pre-processed annotations occupy more space than the original detections; by default, the raw detections are stored.
251
+ """
252
+ if not after_preprocess:
253
+ return {
254
+ "combined_segmentation": self.segmentation.bool(),
255
+ "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(),
256
+ "cse_instance_embedding": self.cse_dets["instance_embedding"],
257
+ "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(),
258
+ "cls": self.__class__,
259
+ "orig_imshape_CHW": self.orig_imshape_CHW
260
+ }
261
+ self.pre_process()
262
+ return dict(
263
+ E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())),
264
+ mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())),
265
+ maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())),
266
+ vertices=self.vertices.to(torch.int16).cpu(),
267
+ cls=self.__class__,
268
+ boxes=self.boxes,
269
+ orig_imshape_CHW=self.orig_imshape_CHW,
270
+ )
271
+
272
+ @staticmethod
273
+ def from_state_dict(
274
+ state_dict, embed_map,
275
+ post_process_cfg, **kwargs):
276
+ after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict
277
+ if after_preprocess:
278
+ detection = CSEPersonDetection(
279
+ segmentation=None, cse_dets=None, embed_map=embed_map,
280
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
281
+ **post_process_cfg)
282
+ detection.vertices = tops.to_cuda(state_dict["vertices"].long())
283
+ numel = np.prod(detection.vertices.shape)
284
+ detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
285
+ detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape)
286
+ detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape)
287
+ detection.n_detections = len(detection.mask)
288
+ detection.pre_processed = True
289
+
290
+ if isinstance(state_dict["boxes"], np.ndarray):
291
+ state_dict["boxes"] = torch.from_numpy(state_dict["boxes"])
292
+ detection.boxes = state_dict["boxes"]
293
+ return detection
294
+
295
+ cse_dets = dict(
296
+ instance_segmentation=state_dict["cse_instance_segmentation"],
297
+ instance_embedding=state_dict["cse_instance_embedding"],
298
+ embed_map=embed_map,
299
+ bbox_XYXY=state_dict["cse_bbox_XYXY"])
300
+ cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()}
301
+
302
+ segmentation = state_dict["combined_segmentation"]
303
+ return CSEPersonDetection(
304
+ segmentation, cse_dets, embed_map=embed_map,
305
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
306
+ **post_process_cfg)
307
+
308
+ def visualize(self, im):
309
+ self.pre_process()
310
+ if len(self) == 0:
311
+ return im
312
+ im = vis_utils.draw_cropped_masks(
313
+ im.clone(), self.mask, self.boxes, visualize_instances=False)
314
+ E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2)
315
+ im = im.to(E.device)
316
+ im = vis_utils.draw_cse_all(
317
+ E, self.E_mask.squeeze(1).bool(), im,
318
+ self.boxes, self.embed_map)
319
+ im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
320
+ return im
321
+
322
+
323
+ def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes):
324
+ keypoints = keypoints.clone()
325
+ N = boxes.shape[0]
326
+ tops.assert_shape(keypoints, (N, None, 3))
327
+ tops.assert_shape(boxes, (N, 4))
328
+ x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T]
329
+
330
+ w = x1 - x0
331
+ h = y1 - y0
332
+ keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w
333
+ keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h
334
+ check_outside = lambda x: (x < 0).logical_or(x > 1)
335
+ is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1]))
336
+ keypoints[:, :, 2] = keypoints[:, :, 2] >= 0
337
+ keypoints[:, :, 2] = (keypoints[:, :, 2] > 0).logical_and(is_outside.logical_not())
338
+ return keypoints
339
+
340
+
341
+ class PersonDetection:
342
+
343
+ def __init__(
344
+ self,
345
+ segmentation,
346
+ target_imsize,
347
+ exp_bbox_cfg, exp_bbox_filter,
348
+ dilation_percentage: float,
349
+ orig_imshape_CHW,
350
+ keypoints=None,
351
+ **kwargs) -> None:
352
+ self.segmentation = segmentation
353
+ self.target_imsize = list(target_imsize)
354
+ self.pre_processed = False
355
+ self.exp_bbox_cfg = exp_bbox_cfg
356
+ self.exp_bbox_filter = exp_bbox_filter
357
+ self.dilation_percentage = dilation_percentage
358
+ self.orig_imshape_CHW = orig_imshape_CHW
359
+ self.keypoints = keypoints
360
+
361
+ @torch.no_grad()
362
+ def pre_process(self):
363
+ if self.pre_processed:
364
+ return
365
+ boxes = masks_to_boxes(self.segmentation).cpu()
366
+ expanded_boxes = []
367
+ included_boxes = []
368
+ for i in range(len(boxes)):
369
+ exp_box = get_expanded_bbox(
370
+ boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg,
371
+ target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1])
372
+ if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter):
373
+ continue
374
+ included_boxes.append(i)
375
+ expanded_boxes.append(exp_box)
376
+ expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4)
377
+ self.segmentation = self.segmentation[included_boxes]
378
+ if self.keypoints is not None:
379
+ self.keypoints = self.keypoints[included_boxes]
380
+ area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes))
381
+ self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool)
382
+ for i, box in enumerate(expanded_boxes):
383
+ self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0]
384
+ if self.keypoints is not None:
385
+ self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes)
386
+ dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage))
387
+ self.maskrcnn_mask = self.mask.clone().logical_not()[:, None]
388
+ self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel)
389
+
390
+ [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))]
391
+ self.boxes = expanded_boxes
392
+ self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask)
393
+
394
+ self.pre_processed = True
395
+ self.n_detections = len(self.boxes)
396
+ self.mask = self.mask.logical_not()
397
+
398
+ sorted_idx = torch.argsort(area, descending=False)
399
+ self.mask = self.mask[sorted_idx]
400
+ self.boxes = self.boxes[sorted_idx.cpu()]
401
+ self.segmentation = self.segmentation[sorted_idx]
402
+ self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx]
403
+ if self.keypoints is not None:
404
+ self.keypoints = self.keypoints[sorted_idx]
405
+
406
+ def get_crop(self, idx: int, im: torch.Tensor):
407
+ assert idx < len(self)
408
+ self.pre_process()
409
+ box = self.boxes[idx]
410
+ mask = self.mask[idx][None].float()
411
+ im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0)
412
+ batch = dict(
413
+ img=im, mask=mask, boxes=box.reshape(1, -1),
414
+ maskrcnn_mask=self.maskrcnn_mask[idx][None].float())
415
+ if self.keypoints is not None:
416
+ batch["keypoints"] = self.keypoints[idx:idx+1]
417
+ return batch
418
+
419
+ def __len__(self):
420
+ self.pre_process()
421
+ return self.n_detections
422
+
423
+ def state_dict(self, **kwargs):
424
+ return dict(
425
+ segmentation=self.segmentation.bool(),
426
+ cls=self.__class__,
427
+ orig_imshape_CHW=self.orig_imshape_CHW,
428
+ keypoints=self.keypoints
429
+ )
430
+
431
+ @staticmethod
432
+ def from_state_dict(
433
+ state_dict,
434
+ post_process_cfg, **kwargs):
435
+ return PersonDetection(
436
+ state_dict["segmentation"],
437
+ orig_imshape_CHW=state_dict["orig_imshape_CHW"],
438
+ **post_process_cfg,
439
+ keypoints=state_dict["keypoints"])
440
+
441
+ def visualize(self, im):
442
+ self.pre_process()
443
+ im = im.cpu()
444
+ if len(self) == 0:
445
+ return im
446
+ im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False)
447
+ im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes)
448
+ return im
449
+
450
+
451
+ def get_dilated_boxes(exp_bbox: torch.LongTensor, mask):
452
+ """
453
+ mask: dilated masks at the resized (target) resolution.
454
+ """
455
+ assert exp_bbox.shape[0] == mask.shape[0]
456
+ boxes = masks_to_boxes(mask.squeeze(1)).cpu()
457
+ H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0]
458
+ boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long()
459
+ boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long()
460
+ boxes[:, [0, 2]] += exp_bbox[:, 0:1]
461
+ boxes[:, [1, 3]] += exp_bbox[:, 1:2]
462
+ return boxes
463
+
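For reference, a minimal numeric sketch of the box-to-crop transform used in FaceDetection.get_crop above: the face box is shifted into the expanded crop's frame and scaled by target_size / crop_width. All values below are made up for illustration.

import numpy as np

target_size = 128                      # assumed square target size, as in FDF
box = np.array([40., 60., 80., 100.])  # hypothetical face box (x0, y0, x1, y1)
exp_box = np.array([20, 40, 120, 140]) # hypothetical expanded square crop

box_in_crop = box.copy()
box_in_crop[[0, 2]] -= exp_box[0]      # shift x-coordinates into the crop frame
box_in_crop[[1, 3]] -= exp_box[1]      # shift y-coordinates into the crop frame
resize_factor = target_size / (exp_box[2] - exp_box[0])
box_in_crop = (box_in_crop * resize_factor).astype(int)
print(box_in_crop)                     # [25 25 76 76]: the region zeroed in the mask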
dp2/detection/utils.py ADDED
@@ -0,0 +1,174 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import tops
5
+ from skimage.morphology import disk
6
+ from torchvision.transforms.functional import resize, InterpolationMode
7
+ from functools import lru_cache
8
+
9
+
10
+ @lru_cache(maxsize=200)
11
+ def get_kernel(n: int):
12
+ kernel = disk(n, dtype=bool)
13
+ return tops.to_cuda(torch.from_numpy(kernel).bool())
14
+
15
+
16
+ def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape):
17
+ """
18
+ Transforms the detected embedding/mask directly to the target image shape
19
+ """
20
+
21
+ C, HE, WE = E.shape
22
+ assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox)
23
+ assert E_bbox[2] >= exp_bbox[0]
24
+ assert E_bbox[1] >= exp_bbox[1]
25
+ assert E_bbox[3] >= exp_bbox[1]
26
+ assert E_bbox[2] <= exp_bbox[2]
27
+ assert E_bbox[3] <= exp_bbox[3]
28
+
29
+ x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
30
+ x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1]))
31
+ y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
32
+ y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0]))
33
+ new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32)
34
+ new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool)
35
+
36
+ E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)
37
+ new_E[:, y0:y1, x0:x1] = E
38
+ S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0
39
+ new_S[y0:y1, x0:x1] = S
40
+ return new_E, new_S
41
+
42
+
43
+ def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor):
44
+ """
45
+ mask: shape [N, H, W]
46
+ """
47
+ assert len(mask1.shape) == 3
48
+ assert len(mask2.shape) == 3
49
+ assert mask1.device == mask2.device, (mask1.device, mask2.device)
50
+ assert mask1.dtype == mask2.dtype
51
+ assert mask1.dtype == torch.bool
52
+ assert mask1.shape[1:] == mask2.shape[1:]
53
+ N1, H1, W1 = mask1.shape
54
+ N2, H2, W2 = mask2.shape
55
+ iou = torch.zeros((N1, N2), dtype=torch.float32)
56
+ for i in range(N1):
57
+ cur = mask1[i:i+1]
58
+ inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
59
+ union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu()
60
+ iou[i] = inter / union
61
+ return iou
62
+
63
+
64
+ def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float):
65
+ N1 = mask1.shape[0]
66
+ N2 = mask2.shape[0]
67
+ ious = pairwise_mask_iou(mask1, mask2).cpu().numpy()
68
+ indices = np.array([idx for idx, iou in np.ndenumerate(ious)])
69
+ ious = ious.flatten()
70
+ mask = ious >= iou_threshold
71
+ ious = ious[mask]
72
+ indices = indices[mask]
73
+
74
+ # do not sort by iou to keep ordering of mask rcnn / cse sorting.
75
+ taken1 = np.zeros((N1), dtype=bool)
76
+ taken2 = np.zeros((N2), dtype=bool)
77
+ matches = []
78
+ for i, j in indices:
79
+ if taken1[i].any() or taken2[j].any():
80
+ continue
81
+ matches.append((i, j))
82
+ taken1[i] = True
83
+ taken2[j] = True
84
+ return matches
85
+
86
+
87
+ def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float):
88
+ assert 0 < iou_threshold <= 1
89
+ matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold)
90
+ H, W = segmentation.shape[1:]
91
+ new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device)
92
+ cse_im_seg = cse_dets["im_segmentation"]
93
+ for idx, (i, j) in enumerate(matches):
94
+ new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j])
95
+ cse_dets = dict(
96
+ instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]],
97
+ instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]],
98
+ bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]],
99
+ scores=cse_dets["scores"][[j for (i, j) in matches]],
100
+ )
101
+ return new_seg, cse_dets, np.array(matches).reshape(-1, 2)
102
+
103
+
104
+ def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor):
105
+ """
106
+ cse_boxes can be outside of segmentation.
107
+ """
108
+ boxes = masks_to_boxes(segmentation)
109
+
110
+ assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape)
111
+ combined = torch.stack((boxes, cse_boxes), dim=-1)
112
+ boxes = torch.cat((
113
+ combined[:, :2].min(dim=2).values,
114
+ combined[:, 2:].max(dim=2).values,
115
+ ), dim=1)
116
+ return boxes
117
+
118
+
119
+ def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False):
120
+ """
121
+ Crops or pads x to fit in the bbox and resize to target shape.
122
+ """
123
+ C, H, W = x.shape
124
+ x0, y0, x1, y1 = bbox
125
+
126
+ if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H:
127
+ new_x = x[:, y0:y1, x0:x1]
128
+ else:
129
+ new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device)
130
+ y0_t = max(0, -y0)
131
+ y1_t = min(y1-y0, (y1-y0)-(y1-H))
132
+ x0_t = max(0, -x0)
133
+ x1_t = min(x1-x0, (x1-x0)-(x1-W))
134
+ x0 = max(0, x0)
135
+ y0 = max(0, y0)
136
+ x1 = min(x1, W)
137
+ y1 = min(y1, H)
138
+ new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1]
139
+ if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]:
140
+ return new_x
141
+ if x.dtype == torch.bool:
142
+ new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5
143
+ elif x.dtype == torch.float32:
144
+ new_x = resize(new_x, target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True)
145
+ elif x.dtype == torch.uint8:
146
+ if fdf_resize: # FDF dataset is created with cv2 INTER_AREA.
147
+ # Incorrect resizing generates noticeably poorer inpaintings.
148
+ upsampling = ((y1-y0) *(x1-x0)) < (target_shape[0] * target_shape[1])
149
+ if upsampling:
150
+ new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, antialias=True).round().clamp(0, 255).byte()
151
+ else:
152
+ device = new_x.device
153
+ new_x = new_x.permute(1, 2, 0).cpu().numpy()
154
+ new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA)
155
+ new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device)
156
+ else:
157
+ new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True).round().clamp(0, 255).byte()
158
+ else:
159
+ raise ValueError(f"Not supported dtype: {x.dtype}")
160
+ return new_x
161
+
162
+
163
+
164
+ def masks_to_boxes(segmentation: torch.Tensor):
165
+ assert len(segmentation.shape) == 3
166
+ x = segmentation.any(dim=1).byte() # Compress rows
167
+ x0 = x.argmax(dim=1)
168
+
169
+ x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1)
170
+ y = segmentation.any(dim=2).byte()
171
+ y0 = y.argmax(dim=1)
172
+ y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1)
173
+ return torch.stack([x0, y0, x1, y1], dim=1)
174
+
dp2/discriminator/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sg2_discriminator import SG2Discriminator
dp2/discriminator/sg2_discriminator.py ADDED
@@ -0,0 +1,76 @@
1
+ from sg3_torch_utils.ops import upfirdn2d
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ from .. import layers
6
+ from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block
7
+
8
+
9
+ class SG2Discriminator(layers.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ cnum: int,
14
+ max_cnum_mul: int,
15
+ imsize,
16
+ min_fmap_resolution: int,
17
+ im_channels: int,
18
+ input_condition: bool,
19
+ conv_clamp: int,
20
+ input_cse: bool,
21
+ cse_nc: int):
22
+ super().__init__()
23
+
24
+ cse_nc = 0 if cse_nc is None else cse_nc
25
+ self._max_imsize = max(imsize)
26
+ self._cnum = cnum
27
+ self._max_cnum_mul = max_cnum_mul
28
+ self._min_fmap_resolution = min_fmap_resolution
29
+ self._input_condition = input_condition
30
+ self.input_cse = input_cse
31
+ self.layers = nn.ModuleList()
32
+
33
+ out_ch = self.get_chsize(self._max_imsize)
34
+ self.from_rgb = Block(
35
+ im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1),
36
+ out_ch, conv_clamp=conv_clamp
37
+ )
38
+ n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
39
+
40
+ for i in range(n_levels):
41
+ resolution = [x//2**i for x in imsize]
42
+ in_ch = self.get_chsize(max(resolution))
43
+ out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution))
44
+
45
+ down = 2
46
+ if i == 0:
47
+ down = 1
48
+ block = ResidualBlock(
49
+ in_ch, out_ch, down=down, conv_clamp=conv_clamp
50
+ )
51
+ self.layers.append(block)
52
+ self.output_layer = DiscriminatorEpilogue(
53
+ out_ch, resolution, conv_clamp=conv_clamp)
54
+
55
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
56
+
57
+ def forward(self, img, condition, mask, embedding=None, E_mask=None,**kwargs):
58
+ to_cat = [img]
59
+ if self._input_condition:
60
+ to_cat.extend([condition, mask,])
61
+ if self.input_cse:
62
+ to_cat.extend([embedding, E_mask])
63
+ x = torch.cat(to_cat, dim=1)
64
+ x = self.from_rgb(x)
65
+
66
+ for i, layer in enumerate(self.layers):
67
+ x = layer(x)
68
+
69
+ x = self.output_layer(x)
70
+ return dict(score=x)
71
+
72
+ def get_chsize(self, imsize):
73
+ n = int(np.log2(self._max_imsize) - np.log2(imsize))
74
+ mul = min(2 ** n, self._max_cnum_mul)
75
+ ch = self._cnum * mul
76
+ return int(ch)
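For reference, a hedged sketch of the channel schedule implied by SG2Discriminator.get_chsize; cnum=64, max_cnum_mul=8 and max_imsize=256 are assumed values for illustration, not necessarily the configuration used in this repository.

import numpy as np

cnum, max_cnum_mul, max_imsize = 64, 8, 256   # assumed values

def get_chsize(imsize):
    n = int(np.log2(max_imsize) - np.log2(imsize))
    return int(cnum * min(2 ** n, max_cnum_mul))

for res in [256, 128, 64, 32, 16]:
    print(res, get_chsize(res))   # 256->64, 128->128, 64->256, 32->512, 16->512 (capped by max_cnum_mul)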
dp2/gan_trainer.py ADDED
@@ -0,0 +1,324 @@
1
+ import atexit
2
+ from collections import defaultdict
3
+ import logging
4
+ import typing
5
+ import torch
6
+ import time
7
+ from dp2.utils import vis_utils
8
+ from dp2 import utils
9
+ from tops import logger, checkpointer
10
+ import tops
11
+ from easydict import EasyDict
12
+
13
+
14
+ def accumulate_gradients(params, fp16_ddp_accumulate):
15
+ if len(params) == 0:
16
+ return
17
+ params = [param for param in params if param.grad is not None]
18
+ flat = torch.cat([param.grad.flatten() for param in params])
19
+ orig_dtype = flat.dtype
20
+ if tops.world_size() > 1:
21
+ if fp16_ddp_accumulate:
22
+ flat = flat.half() / tops.world_size()
23
+ else:
24
+ flat /= tops.world_size()
25
+ torch.distributed.all_reduce(flat)
26
+ flat = flat.to(orig_dtype)
27
+ grads = flat.split([param.numel() for param in params])
28
+ for param, grad in zip(params, grads):
29
+ param.grad = grad.reshape(param.shape)
30
+
31
+
32
+ def accumulate_buffers(module: torch.nn.Module):
33
+ buffers = [buf for buf in module.buffers()]
34
+ if len(buffers) == 0:
35
+ return
36
+ flat = torch.cat([buf.flatten() for buf in buffers])
37
+ if tops.world_size() > 1:
38
+ torch.distributed.all_reduce(flat)
39
+ flat /= tops.world_size()
40
+ bufs = flat.split([buf.numel() for buf in buffers])
41
+ for old, new in zip(buffers, bufs):
42
+ old.copy_(new.reshape(old.shape), non_blocking=True)
43
+
44
+
45
+ def check_ddp_consistency(module):
46
+ if tops.world_size() == 1:
47
+ return
48
+ assert isinstance(module, torch.nn.Module)
50
+ params_buffs = list(module.named_parameters()) + list(module.named_buffers())
51
+ for name, tensor in params_buffs:
52
+ fullname = type(module).__name__ + '.' + name
53
+ tensor = tensor.detach()
54
+ if tensor.is_floating_point():
55
+ tensor = torch.nan_to_num(tensor)
56
+ other = tensor.clone()
57
+ torch.distributed.broadcast(tensor=other, src=0)
58
+ assert (tensor == other).all(), fullname
59
+
60
+ class AverageMeter():
61
+ def __init__(self) -> None:
62
+ self.to_log = dict()
63
+ self.n = defaultdict(int)
64
+ pass
65
+
66
+ @torch.no_grad()
67
+ def update(self, values: dict):
68
+ for key, value in values.items():
69
+ self.n[key] += 1
70
+ if key in self.to_log:
71
+ self.to_log[key] += value.mean().detach()
72
+ else:
73
+ self.to_log[key] = value.mean().detach()
74
+
75
+ def get_average(self):
76
+ return {key: value / self.n[key] for key, value in self.to_log.items()}
77
+
78
+
79
+ class GANTrainer:
80
+
81
+ def __init__(
82
+ self,
83
+ G: torch.nn.Module,
84
+ D: torch.nn.Module,
85
+ G_EMA: torch.nn.Module,
86
+ D_optim: torch.optim.Optimizer,
87
+ G_optim: torch.optim.Optimizer,
88
+ dl_train: typing.Iterator,
89
+ dl_val: typing.Iterable,
90
+ scaler_D: torch.cuda.amp.GradScaler,
91
+ scaler_G: torch.cuda.amp.GradScaler,
92
+ ims_per_log: int,
93
+ max_images_to_train: int,
94
+ loss_handler,
95
+ ims_per_val: int,
96
+ evaluate_fn,
97
+ batch_size: int,
98
+ broadcast_buffers: bool,
99
+ fp16_ddp_accumulate: bool,
100
+ save_state: bool,
101
+ *args, **kwargs):
102
+ super().__init__(*args, **kwargs)
103
+
104
+ self.G = G
105
+ self.D = D
106
+ self.G_EMA = G_EMA
107
+ self.D_optim = D_optim
108
+ self.G_optim = G_optim
109
+ self.dl_train = dl_train
110
+ self.dl_val = dl_val
111
+ self.scaler_D = scaler_D
112
+ self.scaler_G = scaler_G
113
+ self.loss_handler = loss_handler
114
+ self.max_images_to_train = max_images_to_train
115
+ self.images_per_val = ims_per_val
116
+ self.images_per_log = ims_per_log
117
+ self.evaluate_fn = evaluate_fn
118
+ self.batch_size = batch_size
119
+ self.broadcast_buffers = broadcast_buffers
120
+ self.fp16_ddp_accumulate = fp16_ddp_accumulate
121
+
122
+ self.train_state = EasyDict(
123
+ next_log_step=0,
124
+ next_val_step=ims_per_val,
125
+ total_time=0
126
+ )
127
+
128
+ checkpointer.register_models(dict(
129
+ generator=G, discriminator=D, EMA_generator=G_EMA,
130
+ D_optimizer=D_optim,
131
+ G_optimizer=G_optim,
132
+ train_state=self.train_state,
133
+ scaler_D=self.scaler_D,
134
+ scaler_G=self.scaler_G
135
+ ))
136
+ if checkpointer.has_checkpoint():
137
+ checkpointer.load_registered_models()
138
+ logger.log(f"Resuming training from: global step: {logger.global_step()}")
139
+ else:
140
+ logger.add_dict({
141
+ "stats/discriminator_parameters": tops.num_parameters(self.D),
142
+ "stats/generator_parameters": tops.num_parameters(self.G),
143
+ }, commit=False)
144
+ if save_state:
145
+ # If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint.
146
+ atexit.register(checkpointer.save_registered_models)
147
+
148
+ self._ims_per_log = ims_per_log
149
+
150
+ self.to_log = AverageMeter()
151
+ self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad]
152
+ self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad]
153
+ logger.add_dict({
154
+ "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D),
155
+ "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G),
156
+ }, commit=False, level=logging.INFO)
157
+ check_ddp_consistency(self.D)
158
+ check_ddp_consistency(self.G)
159
+ check_ddp_consistency(self.G_EMA.generator)
160
+
161
+ def train_loop(self):
162
+ self.log_time()
163
+ while logger.global_step() <= self.max_images_to_train:
164
+ batch = next(self.dl_train)
165
+ self.G_EMA.update_beta()
166
+ self.to_log.update(self.step_D(batch))
167
+ self.to_log.update(self.step_G(batch))
168
+ self.G_EMA.update(self.G)
169
+
170
+ if logger.global_step() >= self.train_state.next_log_step:
171
+ to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()}
172
+ to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()})
173
+ to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()})
174
+ self.to_log = AverageMeter()
175
+ logger.add_dict(to_log, commit=True)
176
+ self.train_state.next_log_step += self.images_per_log
177
+ if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8:
178
+ print("Stopping training as gradient scale < 1e-8")
179
+ logger.log("Stopping training as gradient scale < 1e-8")
180
+ break
181
+
182
+ if logger.global_step() >= self.train_state.next_val_step:
183
+ self.evaluate()
184
+ self.log_time()
185
+ self.save_images()
186
+ self.train_state.next_val_step += self.images_per_val
187
+ logger.step(self.batch_size*tops.world_size())
188
+ logger.log(f"Reached end of training at step {logger.global_step()}.")
189
+ checkpointer.save_registered_models()
190
+
191
+ def estimate_ims_per_hour(self):
192
+ batch = next(self.dl_train)
193
+ n_ims = int(100e3)
194
+ n_steps = int(n_ims / (self.batch_size * tops.world_size()))
195
+ n_ims = n_steps * self.batch_size * tops.world_size()
196
+ for i in range(10): # Warmup
197
+ self.G_EMA.update_beta()
198
+ self.step_D(batch)
199
+ self.step_G(batch)
200
+ self.G_EMA.update(self.G)
201
+ start_time = time.time()
202
+ for i in utils.tqdm_(list(range(n_steps))):
203
+ self.G_EMA.update_beta()
204
+ self.step_D(batch)
205
+ self.step_G(batch)
206
+ self.G_EMA.update(self.G)
207
+ total_time = time.time() - start_time
208
+ ims_per_sec = n_ims / total_time
209
+ ims_per_hour = ims_per_sec * 60*60
210
+ ims_per_day = ims_per_hour * 24
211
+ logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M")
212
+ logger.log(f"Images per day: {ims_per_day/1e6:.3f}M")
213
+ import math
214
+ ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4))
215
+ logger.log(f"Images per 4 days: {ims_per_4_day}")
216
+ logger.add_dict({
217
+ "stats/ims_per_day": ims_per_day,
218
+ "stats/ims_per_4_day": ims_per_4_day
219
+ })
220
+
221
+ def log_time(self):
222
+ if not hasattr(self, "start_time"):
223
+ self.start_time = time.time()
224
+ self.last_time_step = logger.global_step()
225
+ return
226
+ n_images = logger.global_step() - self.last_time_step
227
+ if n_images == 0:
228
+ return
229
+ n_secs = time.time() - self.start_time
230
+ n_ims_per_sec = n_images / n_secs
231
+ training_time_hours = n_secs / 60 / 60
232
+ self.train_state.total_time += training_time_hours
233
+ remaining_images = self.max_images_to_train - logger.global_step()
234
+ remaining_time = remaining_images / n_ims_per_sec / 60 / 60
235
+ logger.add_dict({
236
+ "stats/n_ims_per_sec": n_ims_per_sec,
237
+ "stats/total_traing_time_hours": self.train_state.total_time,
238
+ "stats/remaining_time_hours": remaining_time
239
+ })
240
+ self.last_time_step = logger.global_step()
241
+ self.start_time = time.time()
242
+
243
+ def save_images(self):
244
+ dl_val = iter(self.dl_val)
245
+ batch = next(dl_val)
246
+ # TRUNCATED visualization
247
+ ims_to_log = 8
248
+ self.G_EMA.eval()
249
+ z = self.G.get_z(batch["img"])
250
+ fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"]
251
+ fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu()
252
+ if "__key__" in batch:
253
+ batch.pop("__key__")
254
+ real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
255
+ to_vis = torch.cat((real, fakes_truncated))
256
+ logger.add_images("images/truncated", to_vis, nrow=2)
257
+
258
+ # Diverse images
259
+ ims_diverse = 3
260
+ batch = next(dl_val)
261
+ to_vis = []
262
+
263
+ for i in range(ims_diverse):
264
+ z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1)
265
+ fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu()
266
+ to_vis.append(fakes)
267
+ if "__key__" in batch:
268
+ batch.pop("__key__")
269
+ reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log]
270
+ to_vis.insert(0, reals)
271
+ to_vis = torch.cat(to_vis)
272
+ logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1)
273
+
274
+ self.G_EMA.train()
275
+ pass
276
+
277
+ def evaluate(self):
278
+ logger.log("Stating evaluation.")
279
+ self.G_EMA.eval()
280
+ try:
281
+ checkpointer.save_registered_models(max_keep=3)
282
+ except Exception:
283
+ logger.log("Could not save checkpoint.")
284
+ if self.broadcast_buffers:
285
+ check_ddp_consistency(self.G)
286
+ check_ddp_consistency(self.D)
287
+ metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val)
288
+ metrics = {f"metrics/{k}": v for k,v in metrics.items()}
289
+ logger.add_dict(metrics, level=logger.logger.INFO)
290
+
291
+ def step_D(self, batch):
292
+ utils.set_requires_grad(self.trainable_params_D, True)
293
+ utils.set_requires_grad(self.trainable_params_G, False)
294
+ tops.zero_grad(self.D)
295
+ loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D)
296
+ with torch.autograd.profiler.record_function("D_step"):
297
+ self.scaler_D.scale(loss).backward()
298
+ accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
299
+ if self.broadcast_buffers:
300
+ accumulate_buffers(self.D)
301
+ accumulate_buffers(self.G)
302
+ # Step will not unscale if unscale is called previously.
303
+ self.scaler_D.step(self.D_optim)
304
+ self.scaler_D.update()
305
+ utils.set_requires_grad(self.trainable_params_D, False)
306
+ utils.set_requires_grad(self.trainable_params_G, False)
307
+ return to_log
308
+
309
+ def step_G(self, batch):
310
+ utils.set_requires_grad(self.trainable_params_D, False)
311
+ utils.set_requires_grad(self.trainable_params_G, True)
312
+ tops.zero_grad(self.G)
313
+ loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G)
314
+ with torch.autograd.profiler.record_function("G_step"):
315
+ self.scaler_G.scale(loss).backward()
316
+ accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate)
317
+ if self.broadcast_buffers:
318
+ accumulate_buffers(self.G)
319
+ accumulate_buffers(self.D)
320
+ self.scaler_G.step(self.G_optim)
321
+ self.scaler_G.update()
322
+ utils.set_requires_grad(self.trainable_params_D, False)
323
+ utils.set_requires_grad(self.trainable_params_G, False)
324
+ return to_log
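A minimal single-process sketch of the flatten, reduce, split pattern used by accumulate_gradients above; no DDP is involved here and the parameter shapes are arbitrary stand-ins.

import torch

params = [torch.nn.Parameter(torch.randn(s)) for s in [(3, 4), (5,), (2, 2)]]
for p in params:
    p.grad = torch.randn_like(p)            # pretend backward() already ran

flat = torch.cat([p.grad.flatten() for p in params])
# ... in multi-GPU training this single flat tensor would be all-reduced here ...
grads = flat.split([p.numel() for p in params])
for p, g in zip(params, grads):
    p.grad = g.reshape(p.shape)             # original shapes are restored exactly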
dp2/generator/__init__.py ADDED
File without changes
dp2/generator/base.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import numpy as np
3
+ import tqdm
4
+ import tops
5
+ from ..layers import Module
6
+ from ..layers.sg2_layers import FullyConnectedLayer
7
+ from dp2 import utils
8
+
9
+
10
+ class BaseGenerator(Module):
11
+
12
+ def __init__(self, z_channels: int):
13
+ super().__init__()
14
+ self.z_channels = z_channels
15
+ self.latent_space = "Z"
16
+
17
+ @torch.no_grad()
18
+ def get_z(
19
+ self,
20
+ x: torch.Tensor = None,
21
+ z: torch.Tensor = None,
22
+ truncation_value: float = None,
23
+ batch_size: int = None,
24
+ dtype=None, device=None) -> torch.Tensor:
25
+ """Generates a latent variable for generator.
26
+ """
27
+ if z is not None:
28
+ return z
29
+ if x is not None:
30
+ batch_size = x.shape[0]
31
+ dtype = x.dtype
32
+ device = x.device
33
+ if device is None:
34
+ device = utils.get_device()
35
+ if truncation_value == 0:
36
+ return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
37
+ z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
38
+ if truncation_value is None:
39
+ return z
40
+ while z.abs().max() > truncation_value:
41
+ m = z.abs() > truncation_value
42
+ z[m] = torch.rand_like(z)[m]
43
+ return z
44
+
45
+ def sample(self, truncation_value, z=None, **kwargs):
46
+ """
47
+ Samples via interpolating to the mean (0).
48
+ """
49
+ if truncation_value is None:
50
+ return self.forward(**kwargs)
51
+ truncation_value = max(0, truncation_value)
52
+ truncation_value = min(truncation_value, 1)
53
+ if z is None:
54
+ z = self.get_z(kwargs["condition"])
55
+ z = z * truncation_value
56
+ return self.forward(**kwargs, z=z)
57
+
58
+
59
+
60
+ class SG2StyleNet(torch.nn.Module):
61
+ def __init__(self,
62
+ z_dim, # Input latent (Z) dimensionality.
63
+ w_dim, # Intermediate latent (W) dimensionality.
64
+ num_layers = 2, # Number of mapping layers.
65
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
66
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
67
+ ):
68
+ super().__init__()
69
+ self.z_dim = z_dim
70
+ self.w_dim = w_dim
71
+ self.num_layers = num_layers
72
+ self.w_avg_beta = w_avg_beta
73
+ # Construct layers.
74
+ features = [self.z_dim] + [self.w_dim] * self.num_layers
75
+ for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
76
+ layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
77
+ setattr(self, f'fc{idx}', layer)
78
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
79
+
80
+ def forward(self, z, update_emas=False, y=None):
81
+ tops.assert_shape(z, [None, self.z_dim])
82
+
83
+ # Embed, normalize, and concatenate inputs.
84
+ x = z.to(torch.float32)
85
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
86
+ # Execute layers.
87
+ for idx in range(self.num_layers):
88
+ x = getattr(self, f'fc{idx}')(x)
89
+ # Update moving average of W.
90
+ if update_emas:
91
+ self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
92
+
93
+ return x
94
+
95
+ def extra_repr(self):
96
+ return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'
97
+
98
+ def update_w(self, n=int(10e3), batch_size=32):
99
+ """
100
+ Calculate w_ema over n iterations.
101
+ Useful in cases where w_ema is calculated incorrectly during training.
102
+ """
103
+ n = n // batch_size
104
+ for i in tqdm.trange(n, desc="Updating w"):
105
+ z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
106
+ self(z, update_emas=True)
107
+
108
+
109
+ class BaseStyleGAN(BaseGenerator):
110
+
111
+ def __init__(self, z_channels: int, w_dim: int):
112
+ super().__init__(z_channels)
113
+ self.style_net = SG2StyleNet(z_channels, w_dim)
114
+ self.latent_space = "W"
115
+
116
+ def get_w(self, z, update_emas):
117
+ return self.style_net(z, update_emas=update_emas)
118
+
119
+ @torch.no_grad()
120
+ def sample(self, truncation_value, **kwargs):
121
+ if truncation_value is None:
122
+ return self.forward(**kwargs)
123
+ truncation_value = max(0, truncation_value)
124
+ truncation_value = min(truncation_value, 1)
125
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
126
+ w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
127
+ return self.forward(**kwargs, w=w)
128
+
129
+ def update_w(self, *args, **kwargs):
130
+ self.style_net.update_w(*args, **kwargs)
131
+
132
+
133
+ @torch.no_grad()
134
+ def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
135
+ if truncation_value is None:
136
+ return self.forward(**kwargs)
137
+ truncation_value = max(0, truncation_value)
138
+ truncation_value = min(truncation_value, 1)
139
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
140
+ if w_indices is None:
141
+ w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
142
+ w_centers = self.style_net.w_centers[w_indices].to(w.device)
143
+ w = w_centers.to(w.dtype).lerp(w, truncation_value)
144
+ return self.forward(**kwargs, w=w)
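A small sketch of the truncation trick used in BaseStyleGAN.sample: the style vector is interpolated from the tracked average toward the sampled w, so truncation_value=0 yields the average and 1 yields the untruncated sample. The tensors below are stand-ins for the real module state.

import torch

w_dim = 512
w_avg = torch.zeros(w_dim)                    # stands in for style_net.w_avg
w = torch.randn(2, w_dim)                     # two sampled style vectors

for t in (0.0, 0.5, 1.0):
    w_trunc = w_avg.to(w.dtype).lerp(w, t)    # same call as in sample()
    print(t, w_trunc.abs().mean().item())     # magnitude grows with t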
dp2/generator/dummy_generators.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .base import BaseGenerator
4
+
5
+
6
+ class PixelationGenerator(BaseGenerator):
7
+
8
+ def __init__(self, pixelation_size, **kwargs):
9
+ super().__init__(z_channels=0)
10
+ self.pixelation_size = pixelation_size
11
+ self.z_channels = 0
12
+ self.latent_space=None
13
+
14
+ def forward(self, img, condition, mask, **kwargs):
15
+ old_shape = img.shape[-2:]
16
+ img = nn.functional.interpolate(img, size=(self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True)
17
+ img = nn.functional.interpolate(img, size=old_shape, mode="bilinear", align_corners=True)
18
+ out = img*(1-mask) + condition*mask
19
+ return {"img": out}
20
+
21
+
22
+ class MaskOutGenerator(BaseGenerator):
23
+
24
+ def __init__(self, noise: str, **kwargs):
25
+ super().__init__(z_channels=0)
26
+ self.noise = noise
27
+ self.z_channels = 0
28
+ assert self.noise in ["rand", "constant"]
29
+ self.latent_space = None
30
+
31
+ def forward(self, img, condition, mask, **kwargs):
32
+
33
+ if self.noise == "constant":
34
+ img = torch.zeros_like(img)
35
+ elif self.noise == "rand":
36
+ img = torch.rand_like(img)
37
+ out = img*(1-mask) + condition*mask
38
+ return {"img": out}
39
+
40
+
41
+ class IdentityGenerator(BaseGenerator):
42
+
43
+ def __init__(self):
44
+ super().__init__(z_channels=0)
45
+
46
+ def forward(self, img, condition, mask, **kwargs):
47
+ return dict(img=img)
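A hedged sketch of what PixelationGenerator.forward computes: the image is down- and up-sampled to pixelate it, then composited with the condition via the mask (mask=1 keeps the condition, mask=0 receives pixelated content). Shapes and values are illustrative.

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 64, 64)          # hypothetical input batch
condition = img.clone()                 # masked-out input, as built by the dataloader
mask = torch.ones(1, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 0            # region to anonymize

pix = F.interpolate(img, size=(8, 8), mode="bilinear", align_corners=True)
pix = F.interpolate(pix, size=img.shape[-2:], mode="bilinear", align_corners=True)
out = pix * (1 - mask) + condition * mask   # same compositing as in the generator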
dp2/generator/imagen3_old.py ADDED
@@ -0,0 +1,1210 @@
1
+ # What is missing from this implementation
2
+ # 1. Global context in res block
3
+ # 2. Cross attention of conditional information in resnet block
4
+ #
5
+ from functools import partial
6
+ import tops
7
+ from tops.config import instantiate
8
+ import warnings
9
+ from typing import Iterable, List, Tuple
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import einsum
14
+ from einops import rearrange
15
+ from dp2 import infer, utils
16
+ from .base import BaseGenerator
17
+ from sg3_torch_utils.ops import bias_act
18
+ from dp2.layers import Sequential
19
+ import torch.nn.functional as F
20
+ from torchvision.transforms.functional import resize, InterpolationMode
21
+ from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d
22
+
23
+
24
+
25
+
26
+ class Upfirdn2d(torch.nn.Module):
27
+
28
+
29
+ def __init__(self, down=1, up=1, fix_gain=True):
30
+ super().__init__()
31
+ self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1]))
32
+ fw, fh = upfirdn2d._get_filter_size(self.resample_filter)
33
+ px0, px1, py0, py1 = upfirdn2d._parse_padding(0)
34
+ self.down = down
35
+ self.up = up
36
+ if up > 1:
37
+ px0 += (fw + up - 1) // 2
38
+ px1 += (fw - up) // 2
39
+ py0 += (fh + up - 1) // 2
40
+ py1 += (fh - up) // 2
41
+ if down > 1:
42
+ px0 += (fw - down + 1) // 2
43
+ px1 += (fw - down) // 2
44
+ py0 += (fh - down + 1) // 2
45
+ py1 += (fh - down) // 2
46
+ self.padding = [px0,px1,py0,py1]
47
+ self.gain = up**2 if fix_gain else 1
48
+
49
+ def forward(self, x, *args):
50
+ if isinstance(x, dict):
51
+ x = {k: v for k, v in x.items()}
52
+ x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
53
+ return x
54
+ x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
55
+ if len(args) == 0:
56
+ return x
57
+ return (x, *args)
58
+ @torch.no_grad()
59
+ def spatial_embed_keypoints(keypoints: torch.Tensor, x):
60
+ tops.assert_shape(keypoints, (None, None, 3))
61
+ B, N_K, _ = keypoints.shape
62
+ H, W = x.shape[-2:]
63
+ keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32)
64
+ x, y, visible = keypoints.chunk(3, dim=2)
65
+ x = (x * W).round().long().clamp(0, W-1)
66
+ y = (y * H).round().long().clamp(0, H-1)
67
+ kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
68
+ pos = (kp_idx*(H*W) + y*W + x + 1)
69
+ # Offset all by 1 to index invisible keypoints as 0
70
+ pos = (pos * visible.round().long()).squeeze(dim=-1)
71
+ keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
72
+ keypoint_spatial.scatter_(1, pos, 1)
73
+ keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
74
+ return keypoint_spatial
75
+
76
+
77
+ def modulated_conv2d(
78
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
79
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
80
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
81
+ noise = None, # Optional noise tensor to add to the output activations.
82
+ up = 1, # Integer upsampling factor.
83
+ down = 1, # Integer downsampling factor.
84
+ padding = 0, # Padding with respect to the upsampled image.
85
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
86
+ demodulate = True, # Apply weight demodulation?
87
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
88
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
89
+ ):
90
+ batch_size = x.shape[0]
91
+ out_channels, in_channels, kh, kw = weight.shape
92
+ tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
93
+ tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
94
+ tops.assert_shape(styles, [batch_size, in_channels]) # [NI]
95
+
96
+ # Pre-normalize inputs to avoid FP16 overflow.
97
+ if x.dtype == torch.float16 and demodulate:
98
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
99
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
100
+
101
+ # Calculate per-sample weights and demodulation coefficients.
102
+ w = None
103
+ dcoefs = None
104
+ if demodulate or fused_modconv:
105
+ w = weight.unsqueeze(0) # [NOIkk]
106
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
107
+ if demodulate:
108
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
109
+ if demodulate and fused_modconv:
110
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
111
+
112
+ # Execute by scaling the activations before and after the convolution.
113
+ if not fused_modconv:
114
+ x = x * styles.reshape(batch_size, -1, 1, 1)
115
+ x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
116
+ if demodulate and noise is not None:
117
+ x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
118
+ elif demodulate:
119
+ x = x * dcoefs.reshape(batch_size, -1, 1, 1)
120
+ elif noise is not None:
121
+ x = x.add_(noise.to(x.dtype))
122
+ return x
123
+
124
+ with tops.suppress_tracer_warnings(): # this value will be treated as a constant
125
+ batch_size = int(batch_size)
126
+ # Execute as one fused op using grouped convolution.
127
+ tops.assert_shape(x, [batch_size, in_channels, None, None])
128
+ x = x.reshape(1, -1, *x.shape[2:])
129
+ w = w.reshape(-1, in_channels, kh, kw)
130
+ x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
131
+ x = x.reshape(batch_size, -1, *x.shape[2:])
132
+ if noise is not None:
133
+ x = x.add_(noise)
134
+ return x
135
+
136
+
137
+ class Identity(nn.Module):
138
+
139
+ def __init__(self) -> None:
140
+ super().__init__()
141
+
142
+ def forward(self, x, *args, **kwargs):
143
+ return x
144
+
145
+
146
+ class LayerNorm(nn.Module):
147
+ def __init__(self, dim, stable=False):
148
+ super().__init__()
149
+ self.stable = stable
150
+ self.g = nn.Parameter(torch.ones(dim))
151
+
152
+ def forward(self, x):
153
+ if self.stable:
154
+ x = x / x.amax(dim=-1, keepdim=True).detach()
155
+
156
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
157
+ var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
158
+ mean = torch.mean(x, dim=-1, keepdim=True)
159
+ return (x - mean) * (var + eps).rsqrt() * self.g
160
+
161
+
162
+ class FullyConnectedLayer(torch.nn.Module):
163
+ def __init__(self,
164
+ in_features, # Number of input features.
165
+ out_features, # Number of output features.
166
+ bias = True, # Apply additive bias before the activation function?
167
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
168
+ lr_multiplier = 1, # Learning rate multiplier.
169
+ bias_init = 0, # Initial value for the additive bias.
170
+ ):
171
+ super().__init__()
172
+ self.repr = dict(
173
+ in_features=in_features, out_features=out_features, bias=bias,
174
+ activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
175
+ self.activation = activation
176
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
177
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
178
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
179
+ self.bias_gain = lr_multiplier
180
+ self.in_features = in_features
181
+ self.out_features = out_features
182
+
183
+ def forward(self, x):
184
+ w = self.weight * self.weight_gain
185
+ b = self.bias
186
+ if b is not None:
187
+ if self.bias_gain != 1:
188
+ b = b * self.bias_gain
189
+ x = F.linear(x, w)
190
+ x = bias_act.bias_act(x, b, act=self.activation)
191
+ return x
192
+
193
+ def extra_repr(self) -> str:
194
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
195
+
196
+
197
+
198
+ def checkpoint_fn(fn, *args, **kwargs):
199
+ warnings.simplefilter("ignore")
200
+ return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
201
+
202
+ class Conv2d(torch.nn.Module):
203
+ def __init__(
204
+ self,
205
+ in_channels,
206
+ out_channels,
207
+ kernel_size=3,
208
+ activation='lrelu',
209
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
210
+ bias=True,
211
+ norm=None,
212
+ lr_multiplier=1,
213
+ bias_init=0,
214
+ w_dim=None,
215
+ gradient_checkpoint_norm=False,
216
+ gain=1,
217
+ ):
218
+ super().__init__()
219
+ self.fused_modconv = False
220
+ if norm == torch.nn.InstanceNorm2d:
221
+ self.norm = torch.nn.InstanceNorm2d(None)
222
+ elif isinstance(norm, torch.nn.Module):
223
+ self.norm = norm
224
+ elif norm == "fused_modconv":
225
+ self.fused_modconv = True
226
+ elif norm:
227
+ self.norm = torch.nn.InstanceNorm2d(None)
228
+ elif norm is not None:
229
+ raise ValueError(f"norm not supported: {norm}")
230
+ self.activation = activation
231
+ self.conv_clamp = conv_clamp
232
+ self.out_channels = out_channels
233
+ self.in_channels = in_channels
234
+ self.padding = kernel_size // 2
235
+ self.repr = dict(
236
+ in_channels=in_channels, out_channels=out_channels,
237
+ kernel_size=kernel_size,
238
+ activation=activation, conv_clamp=conv_clamp, bias=bias,
239
+ fused_modconv=self.fused_modconv
240
+ )
241
+ self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
242
+ self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
243
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
244
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
245
+ self.bias_gain = lr_multiplier
246
+ if w_dim is not None:
247
+ if self.fused_modconv:
248
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
249
+ else:
250
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
251
+ self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
252
+ self.gradient_checkpoint_norm = gradient_checkpoint_norm
253
+
254
+ def forward(self, x, w=None, gain=1, **kwargs):
255
+ if self.fused_modconv:
256
+ styles = self.affine(w)
257
+ with torch.cuda.amp.autocast(enabled=False):
258
+ x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None,
259
+ padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype)
260
+ else:
261
+ if hasattr(self, "affine"):
262
+ gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
263
+ beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
264
+ x = fma.fma(x, gamma ,beta)
265
+ w = self.weight * self.weight_gain
266
+ x = F.conv2d(input=x, weight=w, padding=self.padding,)
267
+
268
+ if hasattr(self, "norm"):
269
+ if self.gradient_checkpoint_norm:
270
+ x = checkpoint_fn(self.norm, x)
271
+ else:
272
+ x = self.norm(x)
273
+ act_gain = self.act_gain * gain
274
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
275
+ b = self.bias * self.bias_gain if self.bias is not None else None
276
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
277
+ return x
278
+
279
+ def extra_repr(self) -> str:
280
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
281
+
282
+
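When fused_modconv is disabled, the Conv2d above conditions its input with a per-channel scale and shift predicted from the latent w (the affine / affine_beta pair), a FiLM/AdaIN-style modulation applied before the ordinary convolution. The snippet below is a minimal, self-contained sketch of that idea only; the plain nn.Linear projections and the tensor sizes are illustrative assumptions, not the layer's actual implementation.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FiLMConvSketch(nn.Module):
        """Minimal sketch: scale/shift features with (gamma, beta) predicted from w, then convolve."""
        def __init__(self, in_ch, out_ch, w_dim, kernel_size=3):
            super().__init__()
            self.to_gamma = nn.Linear(w_dim, in_ch)   # plays the role of `affine`
            self.to_beta = nn.Linear(w_dim, in_ch)    # plays the role of `affine_beta`
            self.weight = nn.Parameter(torch.randn(out_ch, in_ch, kernel_size, kernel_size))

        def forward(self, x, w):
            gamma = self.to_gamma(w).view(-1, x.shape[1], 1, 1)
            beta = self.to_beta(w).view(-1, x.shape[1], 1, 1)
            x = x * gamma + beta                      # the fma.fma(x, gamma, beta) step above
            return F.conv2d(x, self.weight, padding=self.weight.shape[-1] // 2)

    out = FiLMConvSketch(8, 4, 32)(torch.randn(2, 8, 16, 16), torch.randn(2, 32))
    print(out.shape)  # torch.Size([2, 4, 16, 16])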
283
+ class CrossAttention(nn.Module):
284
+ def __init__(
285
+ self,
286
+ dim,
287
+ context_dim,
288
+ dim_head=64,
289
+ heads=8,
290
+ norm_context=False,
291
+ ):
292
+ super().__init__()
293
+ self.scale = dim_head ** -0.5
294
+
295
+ self.heads = heads
296
+ inner_dim = dim_head * heads
297
+
298
+ self.norm = nn.InstanceNorm1d(dim)
299
+ self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity()
300
+
301
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
302
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
303
+
304
+ self.to_out = nn.Sequential(
305
+ nn.Linear(inner_dim, dim, bias=False),
306
+ nn.InstanceNorm1d(None)
307
+ )
308
+
309
+ def forward(self, x, w):
310
+ x = self.norm(x)
311
+ w = self.norm_context(w)
312
+
313
+ q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1))
314
+ q = rearrange(q, "b n (h d) -> b h n d", h = self.heads)
315
+ k = rearrange(k, "b n (h d) -> b h n d", h = self.heads)
316
+ v = rearrange(v, "b n (h d) -> b h n d", h = self.heads)
317
+ q = q * self.scale
318
+ # similarities
319
+ sim = einsum('b h i d, b h j d -> b h i j', q, k)
320
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
321
+
322
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
323
+ out = rearrange(out, 'b h n d -> b n (h d)')
324
+ return self.to_out(out)
325
+
326
+
327
+ class SG2ResidualBlock(torch.nn.Module):
328
+ def __init__(
329
+ self,
330
+ in_channels, # Number of input channels, 0 = first block.
331
+ out_channels, # Number of output channels.
332
+ conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
333
+ skip_gain=np.sqrt(.5),
334
+ cross_attention: bool = False,
335
+ cross_attention_len: int = None,
336
+ use_adain: bool = True,
337
+ **layer_kwargs, # Arguments for conv layer.
338
+ ):
339
+ super().__init__()
340
+ self.in_channels = in_channels
341
+ self.out_channels = out_channels
342
+ w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None
343
+ if use_adain:
344
+ layer_kwargs["w_dim"] = w_dim
345
+
346
+ self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs)
347
+ self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain)
348
+
349
+ self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain)
350
+ if cross_attention and w_dim is not None:
351
+ self.cross_attention_len = cross_attention_len
352
+ self.cross_attn = CrossAttention(
353
+ dim=out_channels, context_dim=w_dim//self.cross_attention_len,
354
+ )
355
+
356
+ def forward(self, x, w=None, **layer_kwargs):
357
+ y = self.skip(x)
358
+ x = self.conv0(x, w, **layer_kwargs)
359
+ x = self.conv1(x, w, **layer_kwargs)
360
+ if hasattr(self, "cross_attn"):
361
+ h = x.shape[-2]
362
+ x = rearrange(x, "b c h w -> b (h w) c")
363
+ w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len)
364
+ x = self.cross_attn(x, w=w) + x
365
+ x = rearrange(x, "b (h w) c -> b c h w", h=h)
366
+ return y + x
367
+
368
+
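The skip_gain=np.sqrt(.5) default above, and the sqrt(1/3) and sqrt(1/4) variants used further down, follow the usual variance-preserving argument: when n roughly independent unit-variance branches are summed, scaling by 1/sqrt(n) keeps the output variance near one. A quick numerical check of that claim (generic tensors, not tied to any layer here):

    import torch

    n_branches, gain = 2, 0.5 ** 0.5                 # skip + residual, gain = sqrt(1/2)
    branches = [torch.randn(100_000) for _ in range(n_branches)]
    print(f"unscaled var ~ {sum(branches).var().item():.2f}")                     # ~2.0
    print(f"scaled var   ~ {sum(b * gain for b in branches).var().item():.2f}")   # ~1.0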
369
+ def default(val, d):
370
+ if val is not None:
371
+ return val
372
+ return d() if callable(d) else d
373
+
374
+
375
+ def cast_tuple(val, length=None):
376
+ if isinstance(val, Iterable) and not isinstance(val, str):
377
+ val = tuple(val)
378
+ output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
379
+ if length is not None:
380
+ assert len(output) == length, (output, length)
381
+ return output
382
+
383
+
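cast_tuple normalizes per-level options such as layer_attn and num_resnet_blocks: a scalar is broadcast to every resolution level, while an explicit list is converted to a tuple and length-checked. Expected behaviour, assuming the helpers defined above:

    # usage sketch, assuming default/cast_tuple as defined above
    print(cast_tuple(2, 4))             # (2, 2, 2, 2)
    print(cast_tuple([1, 2, 3, 4], 4))  # (1, 2, 3, 4)
    print(cast_tuple(True))             # (True,)
    # cast_tuple([1, 2], 4) raises AssertionError: an explicit list must match the number of levels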
384
+ class Attention(nn.Module):
385
+ # This is a version of Multi-Query Attention, proposed in:
386
+ # Fast Transformer Decoding: One Write-Head is All You Need
387
+ # Ablated in: https://arxiv.org/pdf/2203.07814.pdf
388
+ # and https://arxiv.org/pdf/2204.02311.pdf
389
+ def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None):
390
+ super().__init__()
391
+ self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0
392
+ self.cosine_sim_attn = cosine_sim_attn
393
+ self.cosine_sim_scale = 16 if cosine_sim_attn else 1
394
+ self.gradient_checkpoint = gradient_checkpoint
395
+ self.heads = heads
396
+ self.dim = dim
397
+ self.fix_attention_again = fix_attention_again
398
+ inner_dim = dim_head * heads
399
+ if norm == "LN":
400
+ self.norm = LayerNorm(dim)
401
+ elif norm == "IN":
402
+ self.norm = nn.InstanceNorm1d(dim)
403
+ elif norm is None:
404
+ self.norm = nn.Identity()
405
+ else:
406
+ raise ValueError(f"Norm not supported: {norm}")
407
+
408
+ self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False)
409
+ self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False)
410
+
411
+ self.to_out = nn.Sequential(
412
+ FullyConnectedLayer(inner_dim, dim, bias=False),
413
+ LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim)
414
+ )
415
+ if fix_attention_again:
416
+ assert gain is not None
417
+ self.gain = gain
418
+ else:
419
+ self.gain = np.sqrt(.5) if attn_fix_gain else 1
420
+
421
+ def run_function(self, x, attn_bias):
422
+ b, c, h, w = x.shape
423
+ x = rearrange(x, "b c h w -> b (h w) c")
424
+ in_ = x
425
+ b, n, device = *x.shape[:2], x.device
426
+ x = self.norm(x)
427
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
428
+
429
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
430
+ q = q * self.scale
431
+
432
+ # calculate query / key similarities
433
+ sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
434
+
435
+ if attn_bias is not None:
436
+ attn_bias = attn_bias
437
+ attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
438
+ sim = sim + attn_bias
439
+
440
+ attn = sim.softmax(dim=-1)
441
+
442
+ out = einsum("b h i j, b j d -> b h i d", attn, v)
443
+
444
+ out = rearrange(out, "b h n d -> b n (h d)")
445
+ if self.fix_attention_again:
446
+ out = self.to_out(out)*self.gain + in_
447
+ else:
448
+ out = (self.to_out(out) + in_) * self.gain
449
+ out = rearrange(out, "b (h w) c -> b c h w", h=h)
450
+ return out
451
+
452
+ def forward(self, x, *args, attn_bias=None, **kwargs):
453
+ if self.gradient_checkpoint:
454
+ return checkpoint_fn(self.run_function, x, attn_bias)
455
+ return self.run_function(x, attn_bias)
456
+
457
+ def get_attention(self, x, attn_bias=None):
458
+ b, c, h, w = x.shape
459
+ x = rearrange(x, "b c h w -> b (h w) c")
460
+ in_ = x
461
+ b, n, device = *x.shape[:2], x.device
462
+ x = self.norm(x)
463
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
464
+
465
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
466
+ q = q * self.scale
467
+
468
+ # calculate query / key similarities
469
+ sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
470
+
471
+ if attn_bias is not None:
472
+ attn_bias = attn_bias
473
+ attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
474
+ sim = sim + attn_bias
475
+
476
+ attn = sim.softmax(dim=-1)
477
+ return attn, None
478
+
479
+
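As the comment above says, this is multi-query attention: each head gets its own queries, but a single key/value head is shared by all of them, which is why to_kv projects to dim_head * 2 rather than heads * dim_head * 2. A small stand-alone shape sketch of that sharing (einops and torch only; the sizes are made up for illustration):

    import torch
    from einops import rearrange
    from torch import einsum

    b, n, heads, dim_head = 2, 64, 8, 32
    q = torch.randn(b, n, heads * dim_head)                   # per-head queries, as produced by to_q
    k, v = torch.randn(b, n, dim_head * 2).chunk(2, dim=-1)   # one shared key/value head, as produced by to_kv

    q = rearrange(q, "b n (h d) -> b h n d", h=heads)
    sim = einsum("b h i d, b j d -> b h i j", q, k)           # the single k is broadcast over all heads
    attn = sim.softmax(dim=-1)
    out = einsum("b h i j, b j d -> b h i d", attn, v)
    print(out.shape)  # torch.Size([2, 8, 64, 32])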
480
+ class BiasedAttention(Attention):
481
+
482
+ def __init__(self, *args, head_wise: bool=True, **kwargs):
483
+ super().__init__(*args, **kwargs)
484
+ out_ch = self.heads if head_wise else 1
485
+ self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0)
486
+ nn.init.zeros_(self.conv.weight.data)
487
+
488
+ def forward(self, x, mask):
489
+ mask = resize(mask, size=x.shape[-2:])
490
+ bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
491
+ return super().forward(x=x, attn_bias=bias)
492
+
493
+ def get_attention(self, x, mask):
494
+ mask = resize(mask, size=x.shape[-2:])
495
+ bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
496
+ return super().get_attention(x, bias)[0], bias
497
+
498
+ class UNet(BaseGenerator):
499
+
500
+ def __init__(
501
+ self,
502
+ im_channels: int,
503
+ dim: int,
504
+ dim_mults: tuple,
505
+ num_resnet_blocks, # Number of resnet blocks per resolution
506
+ n_middle_blocks: int,
507
+ z_channels: int,
508
+ conv_clamp: int,
509
+ layer_attn,
510
+ w_dim: int,
511
+ norm_enc: bool,
512
+ norm_dec: str,
513
+ stylenet: nn.Module,
514
+ enc_style: bool, # Toggle style injection in encoder
515
+ use_maskrcnn_mask: bool,
516
+ skip_all_unets: bool,
517
+ fix_resize:bool,
518
+ comodulate: bool,
519
+ comod_net: nn.Module,
520
+ lr_comod: float,
521
+ dec_style: bool,
522
+ input_keypoints: bool,
523
+ n_keypoints: int,
524
+ input_keypoint_indices: Tuple[int],
525
+ use_adain: bool,
526
+ cross_attention: bool,
527
+ cross_attention_len: int,
528
+ gradient_checkpoint_norm: bool,
529
+ attn_cls: partial,
530
+ mask_out_train: bool,
531
+ fix_gain_again: bool,
532
+ ) -> None:
533
+ super().__init__(z_channels)
534
+ self.enc_style = enc_style
535
+ self.n_keypoints = n_keypoints
536
+ self.input_keypoint_indices = list(input_keypoint_indices)
537
+ self.input_keypoints = input_keypoints
538
+ self.mask_out_train = mask_out_train
539
+ n_layers = len(dim_mults)
540
+ self.n_layers = n_layers
541
+ layer_attn = cast_tuple(layer_attn, n_layers)
542
+ num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers)
543
+ self._cnum = dim
544
+ self._image_channels = im_channels
545
+ self._z_channels = z_channels
546
+ encoder_layers = []
547
+ condition_ch = im_channels
548
+ self.from_rgb = Conv2d(
549
+ condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices)
550
+ , dim, 7)
551
+
552
+ self.use_maskrcnn_mask = use_maskrcnn_mask
553
+ self.skip_all_unets = skip_all_unets
554
+ dims = [dim*m for m in dim_mults]
555
+ enc_blk = partial(
556
+ SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc,
557
+ use_adain=use_adain and self.enc_style,
558
+ w_dim=w_dim,
559
+ cross_attention=cross_attention,
560
+ cross_attention_len=cross_attention_len,
561
+ gradient_checkpoint_norm=gradient_checkpoint_norm
562
+ )
563
+ dec_blk = partial(
564
+ SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec,
565
+ use_adain=use_adain and dec_style,
566
+ w_dim=w_dim,
567
+ cross_attention=cross_attention,
568
+ cross_attention_len=cross_attention_len,
569
+ gradient_checkpoint_norm=gradient_checkpoint_norm
570
+ )
571
+ # Currently up/down sampling is done by bilinear upsampling.
572
+ # This can be simplified by replacing it with a strided upsampling layer...
573
+ self.encoder_attns = nn.ModuleList()
574
+ for lidx in range(n_layers):
575
+ gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5)
576
+ dim_in = dims[lidx]
577
+ dim_out = dims[min(lidx+1, n_layers-1)]
578
+ res_blocks = nn.ModuleList()
579
+ for i in range(num_resnet_blocks[lidx]):
580
+ is_last = num_resnet_blocks[lidx] - 1 == i
581
+ cur_dim = dim_out if is_last else dim_in
582
+ block = enc_blk(dim_in, cur_dim, skip_gain=gain)
583
+ res_blocks.append(block)
584
+ if layer_attn[lidx]:
585
+ self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
586
+ else:
587
+ self.encoder_attns.append(Identity())
588
+ encoder_layers.append(res_blocks)
589
+ self.encoder = torch.nn.ModuleList(encoder_layers)
590
+
591
+ # initialize decoder
592
+ decoder_layers = []
593
+ self.unet_layers = torch.nn.ModuleList()
594
+ self.decoder_attns = torch.nn.ModuleList()
595
+ for lidx in range(n_layers):
596
+ dim_in = dims[min(-lidx, -1)]
597
+ dim_out = dims[-1-lidx]
598
+ res_blocks = nn.ModuleList()
599
+ unet_skips = nn.ModuleList()
600
+ for i in range(num_resnet_blocks[-lidx-1]):
601
+ is_first = i == 0
602
+ has_unet = is_first or skip_all_unets
603
+ is_last = i == num_resnet_blocks[-lidx-1] - 1
604
+ cur_dim = dim_in if is_first else dim_out
605
+ if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn
606
+ gain = np.sqrt(1/4)
607
+ elif has_unet: # x + residual + unet
608
+ gain = np.sqrt(1/3)
609
+ elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention
610
+ gain = np.sqrt(1/3)
611
+ else: # x + residual
612
+ gain = np.sqrt(1/2) # Only residual block
613
+ block = dec_blk(cur_dim, dim_out, skip_gain=gain)
614
+ res_blocks.append(block)
615
+ if has_unet:
616
+ unet_block = Conv2d(
617
+ cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp,
618
+ norm=nn.InstanceNorm2d(None),
619
+ gradient_checkpoint_norm=gradient_checkpoint_norm,
620
+ gain=gain)
621
+ unet_skips.append(unet_block)
622
+ else:
623
+ unet_skips.append(torch.nn.Identity())
624
+ if layer_attn[-lidx-1]:
625
+ self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
626
+ else:
627
+ self.decoder_attns.append(Identity())
628
+
629
+ decoder_layers.append(res_blocks)
630
+ self.unet_layers.append(unet_skips)
631
+
632
+ middle_blocks = []
633
+ for i in range(n_middle_blocks):
634
+ block = dec_blk(dims[-1], dims[-1])
635
+ middle_blocks.append(block)
636
+ if n_middle_blocks != 0:
637
+ self.middle_blocks = Sequential(*middle_blocks)
638
+ self.decoder = torch.nn.ModuleList(decoder_layers)
639
+ self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
640
+ self.stylenet = stylenet
641
+ self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize)
642
+ self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize)
643
+ self.comodulate = comodulate
644
+ if comodulate:
645
+ assert not self.enc_style
646
+ self.to_y = nn.Sequential(
647
+ Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm),
648
+ nn.AdaptiveAvgPool2d(1),
649
+ nn.Flatten(),
650
+ FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod)
651
+ )
652
+ self.comod_net = comod_net
653
+
654
+
655
+ def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs):
656
+ if z is None:
657
+ z = self.get_z(condition)
658
+ if w is None:
659
+ w = self.stylenet(z, update_emas=update_emas)
660
+ if self.use_maskrcnn_mask:
661
+ x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
662
+ else:
663
+ x = torch.cat((condition, mask, 1-mask), dim=1)
664
+
665
+ if self.input_keypoints:
666
+ keypoints = keypoints[:, self.input_keypoint_indices]
667
+ one_hot_pose = spatial_embed_keypoints(keypoints, x)
668
+ x = torch.cat((x, one_hot_pose), dim=1)
669
+ x = self.from_rgb(x)
670
+ x, unet_features = self.forward_enc(x, mask, w)
671
+ x, decoder_features = self.forward_dec(x, mask, w, unet_features)
672
+ x = self.to_rgb(x)
673
+ unmasked = x
674
+ if self.mask_out_train:
675
+ x = mask * condition + (1-mask) * x
676
+ out = dict(img=x, unmasked=unmasked)
677
+ if return_decoder_features:
678
+ out["decoder_features"] = decoder_features
679
+ return out
680
+
681
+ def forward_enc(self, x, mask, w):
682
+ unet_features = []
683
+ for i, res_blocks in enumerate(self.encoder):
684
+ is_last = i == len(self.encoder) - 1
685
+ for block in res_blocks:
686
+ x = block(x, w=w)
687
+ unet_features.append(x)
688
+ x = self.encoder_attns[i](x, mask=mask)
689
+ if not is_last:
690
+ x = self.downsample(x)
691
+ if self.comodulate:
692
+ y = self.to_y(x)
693
+ y = torch.cat((w, y), dim=-1)
694
+ w = self.comod_net(y)
695
+ return x, unet_features
696
+
697
+ def forward_dec(self, x, mask, w, unet_features):
698
+ if hasattr(self, "middle_blocks"):
699
+ x = self.middle_blocks(x, w=w)
700
+ features = []
701
+ unet_features = iter(reversed(unet_features))
702
+ for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
703
+ is_last = i == len(self.decoder) - 1
704
+ for skip, block in zip(unet_skip, res_blocks):
705
+ skip_x = next(unet_features)
706
+ if not isinstance(skip, torch.nn.Identity):
707
+ skip_x = skip(skip_x)
708
+ x = x + skip_x
709
+ x = block(x, w=w)
710
+ x = self.decoder_attns[i](x, mask=mask)
711
+ features.append(x)
712
+ if not is_last:
713
+ x = self.upsample(x)
714
+ return x, features
715
+
716
+ def get_w(self, z, update_emas):
717
+ return self.stylenet(z, update_emas=update_emas)
718
+
719
+ @torch.no_grad()
720
+ def sample(self, truncation_value, **kwargs):
721
+ if truncation_value is None:
722
+ return self.forward(**kwargs)
723
+ truncation_value = max(0, truncation_value)
724
+ truncation_value = min(truncation_value, 1)
725
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
726
+ w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value)
727
+ return self.forward(**kwargs, w=w)
728
+
729
+ def update_w(self, *args, **kwargs):
730
+ self.style_net.update_w(*args, **kwargs)
731
+
732
+ @property
733
+ def style_net(self):
734
+ return self.stylenet
735
+
736
+ @torch.no_grad()
737
+ def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
738
+ if truncation_value is None:
739
+ return self.forward(**kwargs)
740
+ truncation_value = max(0, truncation_value)
741
+ truncation_value = min(truncation_value, 1)
742
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
743
+ if w_indices is None:
744
+ w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
745
+ w_centers = self.style_net.w_centers[w_indices].to(w.device)
746
+ w = w_centers.to(w.dtype).lerp(w, truncation_value)
747
+ return self.forward(**kwargs, w=w)
748
+
749
+
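sample above applies the standard truncation trick: the mapped latent w is linearly interpolated toward the running average w_avg, trading diversity for fidelity. A minimal sketch of just the interpolation (w_avg and w are random stand-ins):

    import torch

    w_avg = torch.zeros(512)                          # running average latent (stand-in)
    w = torch.randn(4, 512)                           # mapped latents (stand-in)
    truncation_value = 0.5                            # 0 = always w_avg, 1 = untruncated

    w_truncated = w_avg.lerp(w, truncation_value)     # w_avg + t * (w - w_avg)
    print(w_truncated.std().item() < w.std().item())  # True: truncated latents sit closer to the mean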
750
+ def get_stem_unet_kwargs(cfg):
751
+ if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs
752
+ return get_stem_unet_kwargs(cfg.generator.stem_cfg)
753
+ return dict(cfg.generator)
754
+
755
+
756
+ class GrowingUnet(BaseGenerator):
757
+
758
+ def __init__(
759
+ self,
760
+ coarse_stem_cfg: str, # This can be a coarse generator or None
761
+ sr_cfg: str, # Can be a previous progressive u-net, Unet or None
762
+ residual: bool,
763
+ new_dataset: bool, # The "new dataset" creates condition first -> resizes
764
+ **unet_kwargs):
765
+ kwargs = dict()
766
+ if coarse_stem_cfg is not None:
767
+ coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
768
+ kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
769
+ if sr_cfg is not None:
770
+ sr_cfg = utils.load_config(sr_cfg)
771
+ sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg)
772
+ kwargs.update(sr_stem_unet_kwargs)
773
+ kwargs.update(unet_kwargs)
774
+ kwargs["stylenet"] = None
775
+ kwargs.pop("_target_")
776
+ if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net
777
+ del kwargs["sr_cfg"]
778
+ if "coarse_stem_cfg" in kwargs:
779
+ del kwargs["coarse_stem_cfg"]
780
+ super().__init__(z_channels=kwargs["z_channels"])
781
+ if coarse_stem_cfg is not None:
782
+ z_channels = coarse_stem_cfg.generator.z_channels
783
+ super().__init__(z_channels)
784
+ self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
785
+ self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
786
+ utils.set_requires_grad(self.coarse_stem, False)
787
+ else:
788
+ assert not residual
789
+
790
+ if sr_cfg is not None:
791
+ self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval()
792
+ del self.sr_stem.from_rgb
793
+ del self.sr_stem.to_rgb
794
+ if hasattr(self.sr_stem, "coarse_stem"):
795
+ del self.sr_stem.coarse_stem
796
+ if isinstance(self.sr_stem, UNet):
797
+ del self.sr_stem.encoder[0][0] # Delete first residual block
798
+ del self.sr_stem.decoder[-1][-1] # Delete last residual block
799
+ else:
800
+ assert isinstance(self.sr_stem, GrowingUnet)
801
+ del self.sr_stem.unet.encoder[0][0] # Delete first residual block
802
+ del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block
803
+ utils.set_requires_grad(self.sr_stem, False)
804
+
805
+
806
+ args = kwargs.pop("_args_")
807
+ if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr
808
+ n_layers = len(kwargs["dim_mults"])
809
+ dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"]))
810
+ kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)]
811
+ kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False]
812
+ kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1]
813
+ self.unet = UNet(
814
+ *args,
815
+ **kwargs
816
+ )
817
+ self.from_rgb = self.unet.from_rgb
818
+ self.to_rgb = self.unet.to_rgb
819
+ self.residual = residual
820
+ self.new_dataset = new_dataset
821
+ if residual:
822
+ nn.init.zeros_(self.to_rgb.weight.data)
823
+ del self.unet.from_rgb, self.unet.to_rgb
824
+
825
+ def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs):
826
+ # Downsample for stem
827
+ if z is None:
828
+ z = self.get_z(img)
829
+ if w is None:
830
+ w = self.style_net(z)
831
+ if hasattr(self, "coarse_stem"):
832
+ with torch.no_grad():
833
+ if self.new_dataset:
834
+ img_stem = utils.denormalize_img(img)*255
835
+ condition_stem = img_stem * mask + (1-mask)*127
836
+ condition_stem = condition_stem.round()
837
+ condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
838
+ condition_stem = condition_stem / 255 *2 - 1
839
+ mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
840
+ maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
841
+ else:
842
+ mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float()
843
+ maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float()
844
+ img_stem = utils.denormalize_img(img)*255
845
+ img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round()
846
+ img_stem = img_stem / 255 * 2 - 1
847
+ condition_stem = img_stem * mask_stem
848
+ stem_out = self.coarse_stem(
849
+ condition=condition_stem, mask=mask_stem,
850
+ maskrcnn_mask=maskrcnn_stem, w=w,
851
+ keypoints=keypoints)
852
+ x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
853
+ condition = condition*mask + (1-mask) * x_lr
854
+ if self.unet.use_maskrcnn_mask:
855
+ x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
856
+ else:
857
+ x = torch.cat((condition, mask, 1-mask), dim=1)
858
+ if self.unet.input_keypoints:
859
+ keypoints = keypoints[:, self.unet.input_keypoint_indices]
860
+ one_hot_pose = spatial_embed_keypoints(keypoints, x)
861
+ x = torch.cat((x, one_hot_pose), dim=1)
862
+ x = self.from_rgb(x)
863
+ x, unet_features = self.forward_enc(x, mask, w)
864
+ x = self.forward_dec(x, mask, w, unet_features)
865
+ if self.residual:
866
+ x = self.to_rgb(x) + condition
867
+ else:
868
+ x = self.to_rgb(x)
869
+ return dict(
870
+ img=condition * mask + (1-mask) * x,
871
+ unmasked=x,
872
+ x_lowres=[condition]
873
+ )
874
+
875
+ def forward_enc(self, x, mask, w):
876
+ x, unet_features = self.unet.forward_enc(x, mask, w)
877
+ if hasattr(self, "sr_stem"):
878
+ x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w)
879
+ else:
880
+ unet_features_stem = None
881
+ return x, [unet_features, unet_features_stem]
882
+
883
+ def forward_dec(self, x, mask, w, unet_features):
884
+ unet_features, unet_features_stem = unet_features
885
+ if hasattr(self, "sr_stem"):
886
+ x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem)
887
+ x, unet_features = self.unet.forward_dec(x, mask, w, unet_features)
888
+ return x
889
+
890
+ def get_z(self, *args, **kwargs):
891
+ if hasattr(self, "coarse_stem"):
892
+ return self.coarse_stem.get_z(*args, **kwargs)
893
+ if hasattr(self, "sr_stem"):
894
+ return self.sr_stem.get_z(*args, **kwargs)
895
+ raise AttributeError()
896
+
897
+ @property
898
+ def style_net(self):
899
+ if hasattr(self, "coarse_stem"):
900
+ return self.coarse_stem.style_net
901
+ if hasattr(self, "sr_stem"):
902
+ return self.sr_stem.style_net
903
+ raise AttributeError()
904
+
905
+ def update_w(self, *args, **kwargs):
906
+ self.style_net.update_w(*args, **kwargs)
907
+
908
+ def get_w(self, z, update_emas):
909
+ return self.style_net(z, update_emas=update_emas)
910
+
911
+ @torch.no_grad()
912
+ def sample(self, truncation_value, **kwargs):
913
+ if truncation_value is None:
914
+ return self.forward(**kwargs)
915
+ truncation_value = max(0, truncation_value)
916
+ truncation_value = min(truncation_value, 1)
917
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
918
+ w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
919
+ return self.forward(**kwargs, w=w)
920
+
921
+ @torch.no_grad()
922
+ def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
923
+ if truncation_value is None:
924
+ return self.forward(**kwargs)
925
+ truncation_value = max(0, truncation_value)
926
+ truncation_value = min(truncation_value, 1)
927
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
928
+ if w_indices is None:
929
+ w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
930
+ w_centers = self.style_net.w_centers[w_indices].to(w.device)
931
+ w = w_centers.to(w.dtype).lerp(w, truncation_value)
932
+ return self.forward(**kwargs, w=w)
933
+
934
+
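The coarse_stem path above is a coarse-to-fine recipe: downscale the masked condition to the stem's resolution, let the frozen coarse generator fill the hole, upsample its output, and splice it into the unknown region before the refinement U-Net runs. Below is a minimal sketch of that blending step with a dummy callable standing in for the frozen stem; all names and sizes are illustrative assumptions, not the repository's API.

    import torch
    import torch.nn.functional as F

    def blend_with_coarse_stem(condition, mask, coarse_stem, stem_size=(64, 64)):
        """condition: (B, C, H, W); mask: (B, 1, H, W) with 1 = known pixel."""
        cond_lr = F.interpolate(condition, size=stem_size, mode="bilinear", align_corners=False, antialias=True)
        mask_lr = (F.interpolate(mask, size=stem_size, mode="bilinear", align_corners=False, antialias=True) > 0.99).float()
        x_lr = coarse_stem(cond_lr, mask_lr)                                   # coarse in-filled image
        x_up = F.interpolate(x_lr, size=condition.shape[-2:], mode="bilinear", align_corners=False, antialias=True)
        return condition * mask + (1 - mask) * x_up                            # keep known pixels, paste the coarse fill

    condition = torch.randn(1, 3, 256, 256)
    mask = torch.ones(1, 1, 256, 256)
    mask[..., 64:192, 64:192] = 0                                              # hole to fill
    out = blend_with_coarse_stem(condition, mask, lambda c, m: torch.tanh(c))
    print(out.shape)  # torch.Size([1, 3, 256, 256])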
935
+ class CascadedUnet(BaseGenerator):
936
+
937
+ def __init__(
938
+ self,
939
+ coarse_stem_cfg: str, # This can be a coarse generator or None
940
+ residual: bool,
941
+ new_dataset: bool, # The "new dataset" creates condition first -> resizes
942
+ imsize: tuple,
943
+ cascade:bool,
944
+ **unet_kwargs):
945
+ kwargs = dict()
946
+ coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
947
+ kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
948
+ kwargs.update(unet_kwargs)
949
+ super().__init__(z_channels=kwargs["z_channels"])
950
+
951
+ self.input_keypoints = kwargs["input_keypoints"]
952
+ self.input_keypoint_indices = kwargs["input_keypoint_indices"]
953
+ self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"]
954
+ self.imsize = imsize
955
+ self.residual = residual
956
+ self.new_dataset = new_dataset
957
+
958
+
959
+ # Setup coarse stem
960
+ stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults]
961
+ self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
962
+ self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
963
+ utils.set_requires_grad(self.coarse_stem, False)
964
+
965
+ self.stem_res_to_layer_idx = {
966
+ self.coarse_stem.imsize[0] // 2**i: i
967
+ for i in range(len(stem_dims))
968
+ }
969
+
970
+ dim = kwargs["dim"]
971
+ dim_mults = kwargs["dim_mults"]
972
+ n_layers = len(dim_mults)
973
+ dims = [dim*s for s in dim_mults]
974
+ layer_attn = cast_tuple(kwargs["layer_attn"], n_layers)
975
+ num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers)
976
+ attn_cls = kwargs["attn_cls"]
977
+ if not isinstance(attn_cls, partial):
978
+ attn_cls = instantiate(attn_cls)
979
+
980
+ dec_blk = partial(
981
+ SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"],
982
+ use_adain=kwargs["use_adain"] and kwargs["dec_style"],
983
+ w_dim=kwargs["w_dim"],
984
+ cross_attention=kwargs["cross_attention"],
985
+ cross_attention_len=kwargs["cross_attention_len"],
986
+ gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
987
+ )
988
+ enc_blk = partial(
989
+ SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"],
990
+ use_adain=kwargs["use_adain"] and kwargs["enc_style"],
991
+ w_dim=kwargs["w_dim"],
992
+ cross_attention=kwargs["cross_attention"],
993
+ cross_attention_len=kwargs["cross_attention_len"],
994
+ gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
995
+ )
996
+
997
+ # Currently up/down sampling is done by bilinear upsampling.
998
+ # This can be simplified by replacing it with a strided upsampling layer...
999
+ self.encoder_attns = nn.ModuleList()
1000
+ self.encoder_unet_skips = nn.ModuleDict()
1001
+ self.encoder = nn.ModuleList()
1002
+ for lidx in range(n_layers):
1003
+ has_stem_feature = imsize[0]//2**lidx in self.stem_res_to_layer_idx and cascade
1004
+ next_layer_has_stem_features = lidx+1 < n_layers and imsize[0]//2**(lidx+1) in self.stem_res_to_layer_idx and cascade
1005
+
1006
+ dim_in = dims[lidx]
1007
+ dim_out = dims[min(lidx+1, n_layers-1)]
1008
+ res_blocks = nn.ModuleList()
1009
+ if has_stem_feature:
1010
+ prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1]
1011
+ stem_lidx = self.stem_res_to_layer_idx[imsize[0]//2**lidx]
1012
+ self.encoder_unet_skips.add_module(
1013
+ str(imsize[0]//2**lidx),
1014
+ Conv2d(
1015
+ stem_dims[stem_lidx], dim_in, kernel_size=1,
1016
+ conv_clamp=kwargs["conv_clamp"],
1017
+ norm=nn.InstanceNorm2d(None),
1018
+ gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
1019
+ gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention
1020
+ )
1021
+ )
1022
+ for i in range(num_resnet_blocks[lidx]):
1023
+ is_last = num_resnet_blocks[lidx] - 1 == i
1024
+ cur_dim = dim_out if is_last else dim_in
1025
+ if not is_last:
1026
+ gain = np.sqrt(.5)
1027
+ elif next_layer_has_stem_features and layer_attn[lidx]:
1028
+ gain = np.sqrt(1/4)
1029
+ elif layer_attn[lidx] or next_layer_has_stem_features:
1030
+ gain = np.sqrt(1/3)
1031
+ else:
1032
+ gain = np.sqrt(.5)
1033
+ block = enc_blk(dim_in, cur_dim, skip_gain=gain)
1034
+ res_blocks.append(block)
1035
+ if layer_attn[lidx]:
1036
+ self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True))
1037
+ else:
1038
+ self.encoder_attns.append(Identity())
1039
+ self.encoder.append(res_blocks)
1040
+
1041
+ # initialize decoder
1042
+ self.decoder = torch.nn.ModuleList()
1043
+ self.unet_layers = torch.nn.ModuleList()
1044
+ self.decoder_attns = torch.nn.ModuleList()
1045
+ for lidx in range(n_layers):
1046
+ dim_in = dims[min(-lidx, -1)]
1047
+ dim_out = dims[-1-lidx]
1048
+ res_blocks = nn.ModuleList()
1049
+ unet_skips = nn.ModuleList()
1050
+ for i in range(num_resnet_blocks[-lidx-1]):
1051
+ is_first = i == 0
1052
+ has_unet = is_first or kwargs["skip_all_unets"]
1053
+ is_last = i == num_resnet_blocks[-lidx-1] - 1
1054
+ cur_dim = dim_in if is_first else dim_out
1055
+ if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn
1056
+ gain = np.sqrt(1/4)
1057
+ elif has_unet: # x + residual + unet
1058
+ gain = np.sqrt(1/3)
1059
+ elif layer_attn[-lidx-1]: # x + residual + attention
1060
+ gain = np.sqrt(1/3)
1061
+ else: # x + residual
1062
+ gain = np.sqrt(1/2) # Only residual block
1063
+ block = dec_blk(cur_dim, dim_out, skip_gain=gain)
1064
+ res_blocks.append(block)
1065
+ if kwargs["skip_all_unets"] or is_first:
1066
+ unet_block = Conv2d(
1067
+ cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"],
1068
+ norm=nn.InstanceNorm2d(None),
1069
+ gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
1070
+ gain=gain)
1071
+ unet_skips.append(unet_block)
1072
+ else:
1073
+ unet_skips.append(torch.nn.Identity())
1074
+ if layer_attn[-lidx-1]:
1075
+ self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain))
1076
+ else:
1077
+ self.decoder_attns.append(Identity())
1078
+
1079
+ self.decoder.append(res_blocks)
1080
+ self.unet_layers.append(unet_skips)
1081
+
1082
+ self.from_rgb = Conv2d(
1083
+ 3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"])
1084
+ , dim, 7)
1085
+ self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"])
1086
+
1087
+ self.downsample = Upfirdn2d(down=2, fix_gain=True)
1088
+ self.upsample = Upfirdn2d(up=2, fix_gain=True)
1089
+ self.cascade = cascade
1090
+ if residual:
1091
+ nn.init.zeros_(self.to_rgb.weight.data)
1092
+
1093
+ def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, return_decoder_features=False, **kwargs):
1094
+ # Downsample for stem
1095
+ if z is None:
1096
+ z = self.get_z(img)
1097
+
1098
+ with torch.no_grad(): # Forward pass stem
1099
+ if w is None:
1100
+ w = self.style_net(z)
1101
+ img_stem = utils.denormalize_img(img)*255
1102
+ condition_stem = img_stem * mask + (1-mask)*127
1103
+ condition_stem = condition_stem.round()
1104
+ condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
1105
+ condition_stem = condition_stem / 255 *2 - 1
1106
+ mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
1107
+ maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
1108
+ stem_out = self.coarse_stem(
1109
+ condition=condition_stem, mask=mask_stem,
1110
+ maskrcnn_mask=maskrcnn_stem, w=w,
1111
+ keypoints=keypoints,
1112
+ return_decoder_features=True)
1113
+ stem_features = stem_out["decoder_features"]
1114
+ x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
1115
+ condition = condition*mask + (1-mask) * x_lr
1116
+
1117
+ if self.use_maskrcnn_mask:
1118
+ x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
1119
+ else:
1120
+ x = torch.cat((condition, mask, 1-mask), dim=1)
1121
+ if self.input_keypoints:
1122
+ keypoints = keypoints[:, self.input_keypoint_indices]
1123
+ one_hot_pose = spatial_embed_keypoints(keypoints, x)
1124
+ x = torch.cat((x, one_hot_pose), dim=1)
1125
+ x = self.from_rgb(x)
1126
+ x, unet_features = self.forward_enc(x, mask, w, stem_features)
1127
+ x, decoder_features = self.forward_dec(x, mask, w, unet_features)
1128
+ if self.residual:
1129
+ x = self.to_rgb(x) + condition
1130
+ else:
1131
+ x = self.to_rgb(x)
1132
+ out = dict(
1133
+ img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ??
1134
+ unmasked=x,
1135
+ x_lowres=[condition]
1136
+ )
1137
+ if return_decoder_features:
1138
+ out["decoder_features"] = decoder_features
1139
+ return out
1140
+
1141
+ def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]):
1142
+ unet_features = []
1143
+ stem_features.reverse()
1144
+ for i, res_blocks in enumerate(self.encoder):
1145
+ is_last = i == len(self.encoder) - 1
1146
+ res = self.imsize[0]//2**i
1147
+ if str(res) in self.encoder_unet_skips.keys() and self.cascade:
1148
+ y = stem_features[self.stem_res_to_layer_idx[res]]
1149
+ y = self.encoder_unet_skips[str(res)](y)
1150
+ x = y + x
1151
+ for block in res_blocks:
1152
+ x = block(x, w=w)
1153
+ unet_features.append(x)
1154
+ x = self.encoder_attns[i](x, mask)
1155
+ if not is_last:
1156
+ x = self.downsample(x)
1157
+ return x, unet_features
1158
+
1159
+ def forward_dec(self, x, mask, w, unet_features):
1160
+ features = []
1161
+ unet_features = iter(reversed(unet_features))
1162
+ for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
1163
+ is_last = i == len(self.decoder) - 1
1164
+ for skip, block in zip(unet_skip, res_blocks):
1165
+ skip_x = next(unet_features)
1166
+ if not isinstance(skip, torch.nn.Identity):
1167
+ skip_x = skip(skip_x)
1168
+ x = x + skip_x
1169
+ x = block(x, w=w)
1170
+ x = self.decoder_attns[i](x, mask)
1171
+ features.append(x)
1172
+ if not is_last:
1173
+ x = self.upsample(x)
1174
+ return x, features
1175
+
1176
+ def get_z(self, *args, **kwargs):
1177
+ return self.coarse_stem.get_z(*args, **kwargs)
1178
+
1179
+ @property
1180
+ def style_net(self):
1181
+ return self.coarse_stem.style_net
1182
+
1183
+ def update_w(self, *args, **kwargs):
1184
+ self.style_net.update_w(*args, **kwargs)
1185
+
1186
+ def get_w(self, z, update_emas):
1187
+ return self.style_net(z, update_emas=update_emas)
1188
+
1189
+ @torch.no_grad()
1190
+ def sample(self, truncation_value, **kwargs):
1191
+ if truncation_value is None:
1192
+ return self.forward(**kwargs)
1193
+ truncation_value = max(0, truncation_value)
1194
+ truncation_value = min(truncation_value, 1)
1195
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
1196
+ w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
1197
+ return self.forward(**kwargs, w=w)
1198
+
1199
+ @torch.no_grad()
1200
+ def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
1201
+ if truncation_value is None:
1202
+ return self.forward(**kwargs)
1203
+ truncation_value = max(0, truncation_value)
1204
+ truncation_value = min(truncation_value, 1)
1205
+ w = self.get_w(self.get_z(kwargs["condition"]), False)
1206
+ if w_indices is None:
1207
+ w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
1208
+ w_centers = self.style_net.w_centers[w_indices].to(w.device)
1209
+ w = w_centers.to(w.dtype).lerp(w, truncation_value)
1210
+ return self.forward(**kwargs, w=w)
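multi_modal_truncate, shared by the generators in this file, differs from plain truncation in that each sample is pulled toward a randomly chosen cluster center of W rather than the single global mean, which keeps several modes alive at strong truncation. A stand-alone sketch of that selection and interpolation (the cluster centers are random placeholders):

    import numpy as np
    import torch

    w_centers = torch.randn(16, 512)                 # placeholder W-space cluster centers
    w = torch.randn(4, 512)                          # freshly mapped latents
    truncation_value = 0.3

    idx = np.random.randint(0, len(w_centers), size=len(w))
    w_truncated = w_centers[idx].lerp(w, truncation_value)   # pull every sample toward "its own" center
    print(w_truncated.shape)  # torch.Size([4, 512])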
dp2/generator/stylegan_unet.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import numpy as np
3
+ from dp2.layers import Sequential
4
+ from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock
5
+ from .base import BaseStyleGAN
6
+ from typing import List, Tuple
7
+ from .utils import spatial_embed_keypoints, mask_output
8
+
9
+
10
+ def get_chsize(imsize, cnum, max_imsize, max_cnum_mul):
11
+ n = int(np.log2(max_imsize) - np.log2(imsize))
12
+ mul = min(2**n, max_cnum_mul)
13
+ ch = cnum * mul
14
+ return int(ch)
15
+
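get_chsize doubles the channel count each time the resolution halves, capped at cnum * max_cnum_mul. For example, with cnum=64, max_imsize=256 and max_cnum_mul=8 (values chosen only to illustrate the formula):

    # assuming get_chsize as defined above
    for imsize in [256, 128, 64, 32, 16, 8]:
        print(imsize, get_chsize(imsize, cnum=64, max_imsize=256, max_cnum_mul=8))
    # -> 64, 128, 256, 512, 512, 512: channels double as resolution halves, then saturate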
16
+ class StyleGANUnet(BaseStyleGAN):
17
+ def __init__(
18
+ self,
19
+ scale_grad: bool,
20
+ im_channels: int,
21
+ min_fmap_resolution: int,
22
+ imsize: List[int],
23
+ cnum: int,
24
+ max_cnum_mul: int,
25
+ mask_output: bool,
26
+ conv_clamp: int,
27
+ input_cse: bool,
28
+ cse_nc: int,
29
+ n_middle_blocks: int,
30
+ input_keypoints: bool,
31
+ n_keypoints: int,
32
+ input_keypoint_indices: Tuple[int],
33
+ fix_errors: bool,
34
+ **kwargs
35
+ ) -> None:
36
+ super().__init__(**kwargs)
37
+ self.n_keypoints = n_keypoints
38
+ self.input_keypoint_indices = list(input_keypoint_indices)
39
+ self.input_keypoints = input_keypoints
40
+ assert not (input_cse and input_keypoints)
41
+ cse_nc = 0 if cse_nc is None else cse_nc
42
+ self.imsize = imsize
43
+ self._cnum = cnum
44
+ self._max_cnum_mul = max_cnum_mul
45
+ self._min_fmap_resolution = min_fmap_resolution
46
+ self._image_channels = im_channels
47
+ self._max_imsize = max(imsize)
48
+ self.input_cse = input_cse
49
+ self.gain_unet = np.sqrt(1/3)
50
+ n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1
51
+ encoder_layers = []
52
+ self.from_rgb = Conv2d(
53
+ im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices),
54
+ cnum, 1
55
+ )
56
+ for i in range(n_levels): # Encoder layers
57
+ resolution = [x//2**i for x in imsize]
58
+ in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
59
+ second_ch = in_ch
60
+ out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
61
+ down = 2
62
+
63
+ if i == 0: # first (lowest) block. Downsampling is performed at the start of the block
64
+ down = 1
65
+ if i == n_levels - 1:
66
+ out_ch = second_ch
67
+ block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors)
68
+ encoder_layers.append(block)
69
+ self._encoder_out_shape = [
70
+ get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul),
71
+ *resolution]
72
+
73
+ self.encoder = torch.nn.ModuleList(encoder_layers)
74
+
75
+ # initialize decoder
76
+ decoder_layers = []
77
+ for i in range(n_levels):
78
+ resolution = [x//2**(n_levels-1-i) for x in imsize]
79
+ in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul)
80
+ out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
81
+ if i == 0: # first (lowest) block
82
+ in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul)
83
+
84
+ up = 1
85
+ if i != n_levels - 1:
86
+ up = 2
87
+ block = ResidualBlock(
88
+ in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3),
89
+ w_dim=self.style_net.w_dim, norm=True, up=up,
90
+ fix_residual=fix_errors
91
+ )
92
+ decoder_layers.append(block)
93
+ if i != 0:
94
+ unet_block = Conv2d(
95
+ in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True,
96
+ gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5))
97
+ setattr(self, f"unet_block{i}", unet_block)
98
+
99
+ # Initialize "middle blocks" that do not have down/up sample
100
+ middle_blocks = []
101
+ for i in range(n_middle_blocks):
102
+ ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul)
103
+ block = ResidualBlock(
104
+ ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3),
105
+ w_dim=self.style_net.w_dim, norm=True,
106
+ )
107
+ middle_blocks.append(block)
108
+ if n_middle_blocks != 0:
109
+ self.middle_blocks = Sequential(*middle_blocks)
110
+ self.decoder = torch.nn.ModuleList(decoder_layers)
111
+ self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
112
+ # Initialize "middle blocks" that do not have down/up sample
113
+ self.decoder = torch.nn.ModuleList(decoder_layers)
114
+ self.scale_grad = scale_grad
115
+ self.mask_output = mask_output
116
+
117
+ def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs):
118
+ for i, layer in enumerate(self.decoder):
119
+ if i != 0:
120
+ unet_layer = getattr(self, f"unet_block{i}")
121
+ x = x + unet_layer(unet_features[-i])
122
+ x = layer(x, w=w, s=s)
123
+ x = self.to_rgb(x)
124
+ if self.mask_output:
125
+ x = mask_output(True, condition, x, mask)
126
+ return dict(img=x)
127
+
128
+ def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs):
129
+ if self.input_cse:
130
+ x = torch.cat((condition, mask, embedding, E_mask), dim=1)
131
+ else:
132
+ x = torch.cat((condition, mask), dim=1)
133
+ if self.input_keypoints:
134
+ keypoints = keypoints[:, self.input_keypoint_indices]
135
+ one_hot_pose = spatial_embed_keypoints(keypoints, x)
136
+ x = torch.cat((x, one_hot_pose), dim=1)
137
+ x = self.from_rgb(x)
138
+
139
+ unet_features = []
140
+ for i, layer in enumerate(self.encoder):
141
+ x = layer(x)
142
+ if i != len(self.encoder)-1:
143
+ unet_features.append(x)
144
+ if hasattr(self, "middle_blocks"):
145
+ for layer in self.middle_blocks:
146
+ x = layer(x)
147
+ return x, unet_features
148
+
149
+ def forward(
150
+ self, condition, mask,
151
+ z=None, embedding=None, w=None, update_emas=False, x=None,
152
+ s=None,
153
+ keypoints=None,
154
+ unet_features=None,
155
+ E_mask=None,
156
+ **kwargs):
157
+ # Used to skip sampling from encoder in inference. E.g. for w projection.
158
+ if x is not None and unet_features is not None:
159
+ assert not self.training
160
+ else:
161
+ x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs)
162
+ if w is None:
163
+ if z is None:
164
+ z = self.get_z(condition)
165
+ w = self.get_w(z, update_emas=update_emas)
166
+ return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs)
167
+
168
+ class ComodStyleUNet(StyleGANUnet):
169
+
170
+ def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None:
171
+ super().__init__(**kwargs)
172
+ min_fmap = min(self._encoder_out_shape[1:])
173
+ enc_out_ch = self._encoder_out_shape[0]
174
+ n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res)))
175
+ comod_layers = []
176
+ in_ch = enc_out_ch
177
+ for i in range(n_down):
178
+ comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod))
179
+ in_ch = 256
180
+ if n_down == 0:
181
+ comod_layers = [Conv2d(in_ch, 256, kernel_size=3)]
182
+ comod_layers.append(torch.nn.Flatten())
183
+ out_res = [x//2**n_down for x in self._encoder_out_shape[1:]]
184
+ in_ch_fc = np.prod(out_res) * 256
185
+ comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod))
186
+ self.comod_block = Sequential(*comod_layers)
187
+ self.comod_fc = FullyConnectedLayer(512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod)
188
+
189
+ def forward_dec(self, x, w, unet_features, condition, mask, **kwargs):
190
+ y = self.comod_block(x)
191
+ y = torch.cat((y, w), dim=1)
192
+ y = self.comod_fc(y)
193
+ for i, layer in enumerate(self.decoder):
194
+ if i != 0:
195
+ unet_layer = getattr(self, f"unet_block{i}")
196
+ x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5))
197
+ x = layer(x, w=y)
198
+ x = self.to_rgb(x)
199
+ if self.mask_output:
200
+ x = mask_output(True, condition, x, mask)
201
+ return dict(img=x)
202
+
203
+ def get_comod_y(self, batch, w):
204
+ x, unet_features = self.forward_enc(**batch)
205
+ y = self.comod_block(x)
206
+ y = torch.cat((y, w), dim=1)
207
+ y = self.comod_fc(y)
208
+ return y
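ComodStyleUNet follows the co-modulation idea: the bottleneck features are pooled into a global vector, concatenated with the mapped latent w, and projected back to a single vector that modulates the decoder. A compact, self-contained sketch of that flow; the module names and dimensions are illustrative assumptions:

    import torch
    import torch.nn as nn

    class CoModulationSketch(nn.Module):
        def __init__(self, enc_ch=512, w_dim=512):
            super().__init__()
            self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())  # encoder features -> vector
            self.to_y = nn.Linear(enc_ch, 512)
            self.comod_fc = nn.Linear(512 + w_dim, w_dim)                     # fuse the image code with w

        def forward(self, encoder_features, w):
            y = self.to_y(self.pool(encoder_features))
            return self.comod_fc(torch.cat((y, w), dim=1))                    # modulation vector for the decoder

    feat = torch.randn(2, 512, 4, 4)
    w = torch.randn(2, 512)
    print(CoModulationSketch()(feat, w).shape)  # torch.Size([2, 512])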
dp2/generator/utils.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import tops
3
+ import torch
4
+ from torch.cuda.amp import custom_bwd, custom_fwd
5
+
6
+
7
+ @torch.no_grad()
8
+ def spatial_embed_keypoints(keypoints: torch.Tensor, x):
9
+ tops.assert_shape(keypoints, (None, None, 3))
10
+ B, N_K, _ = keypoints.shape
11
+ H, W = x.shape[-2:]
12
+ keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32)
13
+ x, y, visible = keypoints.chunk(3, dim=2)
14
+ x = (x * W).round().long().clamp(0, W-1)
15
+ y = (y * H).round().long().clamp(0, H-1)
16
+ kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
17
+ pos = (kp_idx*(H*W) + y*W + x + 1)
18
+ # Offset all by 1 to index invisible keypoints as 0
19
+ pos = (pos * visible.round().long()).squeeze(dim=-1)
20
+ keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
21
+ keypoint_spatial.scatter_(1, pos, 1)
22
+ keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
23
+ return keypoint_spatial
24
+
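spatial_embed_keypoints turns normalized (x, y, visible) keypoints into one binary heatmap channel per keypoint, writing invisible keypoints into a dummy slot that is discarded. A simplified, loop-based sketch of the same idea (no tops dependency; the scatter-based version above is the efficient equivalent):

    import torch

    def embed_keypoints_sketch(keypoints: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """keypoints: (B, N_K, 3) with x, y in [0, 1] and a visibility flag."""
        B, N_K, _ = keypoints.shape
        out = torch.zeros(B, N_K, H, W)
        for b in range(B):
            for k in range(N_K):
                x, y, visible = keypoints[b, k]
                if visible < 0.5:
                    continue
                xi = int((x * W).round().clamp(0, W - 1))
                yi = int((y * H).round().clamp(0, H - 1))
                out[b, k, yi, xi] = 1.0
        return out

    kps = torch.tensor([[[0.5, 0.5, 1.0], [0.1, 0.9, 0.0]]])   # one visible, one invisible keypoint
    maps = embed_keypoints_sketch(kps, H=8, W=8)
    print(maps.sum().item())  # 1.0: only the visible keypoint produces a spatial one-hot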
25
+ class MaskOutput(torch.autograd.Function):
26
+
27
+ @staticmethod
28
+ @custom_fwd
29
+ def forward(ctx, x_real, x_fake, mask):
30
+ ctx.save_for_backward(mask)
31
+ out = x_real * mask + (1-mask) * x_fake
32
+ return out
33
+
34
+ @staticmethod
35
+ @custom_bwd
36
+ def backward(ctx, grad_output):
37
+ fake_grad = grad_output
38
+ mask, = ctx.saved_tensors
39
+ fake_grad = fake_grad * (1 - mask)
40
+ known_percentage = mask.view(mask.shape[0], -1).mean(dim=1)
41
+ fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1)
42
+ return None, fake_grad, None
43
+
44
+
45
+ def mask_output(scale_grad, x_real, x_fake, mask):
46
+ if scale_grad:
47
+ return MaskOutput.apply(x_real, x_fake, mask)
48
+ return x_real * mask + (1-mask) * x_fake
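MaskOutput leaves the forward blend untouched and only rewrites the backward pass: gradients are zeroed on the known region and divided by the unknown fraction, so images with small holes do not receive proportionally smaller generator gradients. The check below re-implements the same rule in a fresh autograd Function to make the effect visible; it is a stand-alone sketch, not the class above.

    import torch

    class MaskOutputSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x_real, x_fake, mask):
            ctx.save_for_backward(mask)
            return x_real * mask + (1 - mask) * x_fake

        @staticmethod
        def backward(ctx, grad_output):
            mask, = ctx.saved_tensors
            fake_grad = grad_output * (1 - mask)                     # no gradient for known pixels
            known = mask.view(mask.shape[0], -1).mean(dim=1)         # fraction of known pixels per image
            return None, fake_grad / (1 - known).view(-1, 1, 1, 1), None

    x_real = torch.randn(1, 3, 8, 8)
    x_fake = torch.randn(1, 3, 8, 8, requires_grad=True)
    mask = torch.zeros(1, 1, 8, 8)
    mask[..., :, :6] = 1                                             # 75% known, 25% unknown
    MaskOutputSketch.apply(x_real, x_fake, mask).sum().backward()
    print(x_fake.grad[..., :, :6].abs().sum().item())  # 0.0 in the known region
    print(x_fake.grad[..., :, 6:].mean().item())       # 4.0 = 1 / (1 - 0.75) in the unknown region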
dp2/infer.py ADDED
@@ -0,0 +1,72 @@
1
+ import tops
2
+ import torch
3
+ from tops import checkpointer
4
+ from tops.config import instantiate
5
+ from tops.logger import warn
6
+
7
+ def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
8
+ state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"]
9
+ if ckpt_mapper is not None:
10
+ state = ckpt_mapper(state)
11
+ load_state_dict(G, state)
12
+ tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
13
+ print(ckpt.keys())
14
+ if "w_centers" in ckpt:
15
+ print("Has w_centers!")
16
+ G.style_net.w_centers = ckpt["w_centers"]
17
+ tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
18
+
19
+
20
+ def build_trained_generator(cfg, map_location=None):
21
+ map_location = map_location if map_location is not None else tops.get_device()
22
+ G = instantiate(cfg.generator).to(map_location)
23
+ G.eval()
24
+ G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
25
+ if hasattr(cfg, "ckpt_mapper"):
26
+ ckpt_mapper = instantiate(cfg.ckpt_mapper)
27
+ else:
28
+ ckpt_mapper = None
29
+ if "model_url" in cfg.common:
30
+ ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum)
31
+ load_generator_state(ckpt, G, ckpt_mapper)
32
+ return G
33
+ try:
34
+ ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
35
+ load_generator_state(ckpt, G, ckpt_mapper)
36
+ except FileNotFoundError as e:
37
+ tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
38
+ return G
39
+
40
+
41
+ def build_trained_discriminator(cfg, map_location=None):
42
+ map_location = map_location if map_location is not None else tops.get_device()
43
+ D = instantiate(cfg.discriminator).to(map_location)
44
+ D.eval()
45
+ try:
46
+ ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
47
+ if hasattr(cfg, "ckpt_mapper_D"):
48
+ ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"])
49
+ D.load_state_dict(ckpt["discriminator"])
50
+ except FileNotFoundError as e:
51
+ tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
52
+ return D
53
+
54
+
55
+ def load_state_dict(module: torch.nn.Module, state_dict: dict):
56
+ module_sd = module.state_dict()
57
+ to_remove = []
58
+ for key, item in state_dict.items():
59
+ if key not in module_sd:
60
+ continue
61
+ if item.shape != module_sd[key].shape:
62
+ to_remove.append(key)
63
+ warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
64
+ for key in to_remove:
65
+ state_dict.pop(key)
66
+ for key, item in state_dict.items():
67
+ if key not in module_sd:
68
+ warn(f"Did not fin key in model state dict: {key}")
69
+ for key, item in module_sd.items():
70
+ if key not in state_dict:
71
+ warn(f"Did not find key in state dict: {key}")
72
+ module.load_state_dict(state_dict, strict=False)
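load_state_dict above drops checkpoint entries whose shapes disagree with the current module and then loads non-strictly, which is what keeps checkpoints usable after small architecture tweaks. A minimal sketch of the same filtering idea with plain torch modules (the Linear sizes are arbitrary):

    import torch.nn as nn

    old = nn.Linear(8, 4)                    # "checkpointed" module
    new = nn.Linear(16, 4)                   # same key names, but the weight shape changed
    target = new.state_dict()

    filtered = {k: v for k, v in old.state_dict().items()
                if k in target and v.shape == target[k].shape}
    missing, unexpected = new.load_state_dict(filtered, strict=False)
    print(sorted(filtered))  # ['bias']    -- only the shape-compatible tensor survives
    print(missing)           # ['weight']  -- reported instead of crashing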
dp2/layers/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from typing import Dict
2
+ import torch
3
+ import tops
4
+ import torch.nn as nn
5
+
6
+ class Sequential(nn.Sequential):
7
+
8
+ def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
9
+ for module in self:
10
+ x = module(x, **kwargs)
11
+ return x
12
+
13
+ class Module(nn.Module):
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__()
17
+
18
+ def extra_repr(self):
19
+ num_params = tops.num_parameters(self) / 10**6
20
+ return f"Num params: {num_params:.3f}M"
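Unlike torch.nn.Sequential, the Sequential above forwards extra keyword arguments (for example the latent w) to every sub-module, which is what allows stacks of style-modulated blocks to be chained. A toy, self-contained sketch of the same pattern (the AddW module is made up for illustration):

    import torch
    import torch.nn as nn

    class AddW(nn.Module):
        """Toy block whose forward accepts the extra keyword, like the style-modulated layers."""
        def forward(self, x, w=None, **kwargs):
            return x + w

    class KwargSequential(nn.Sequential):    # re-implemented here so the snippet is self-contained
        def forward(self, x, **kwargs):
            for module in self:
                x = module(x, **kwargs)
            return x

    stack = KwargSequential(AddW(), AddW())
    print(stack(torch.zeros(3), w=torch.ones(3)))  # tensor([2., 2., 2.])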
dp2/layers/sg2_layers.py ADDED
@@ -0,0 +1,227 @@
1
+ from typing import List
2
+ import numpy as np
3
+ import torch
4
+ import tops
5
+ import torch.nn.functional as F
6
+ from sg3_torch_utils.ops import conv2d_resample
7
+ from sg3_torch_utils.ops import upfirdn2d
8
+ from sg3_torch_utils.ops import bias_act
9
+ from sg3_torch_utils.ops.fma import fma
10
+
11
+
12
+ class FullyConnectedLayer(torch.nn.Module):
13
+ def __init__(self,
14
+ in_features, # Number of input features.
15
+ out_features, # Number of output features.
16
+ bias = True, # Apply additive bias before the activation function?
17
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
18
+ lr_multiplier = 1, # Learning rate multiplier.
19
+ bias_init = 0, # Initial value for the additive bias.
20
+ ):
21
+ super().__init__()
22
+ self.repr = dict(
23
+ in_features=in_features, out_features=out_features, bias=bias,
24
+ activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
25
+ self.activation = activation
26
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
27
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
28
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
29
+ self.bias_gain = lr_multiplier
30
+ self.in_features = in_features
31
+ self.out_features = out_features
32
+
33
+ def forward(self, x):
34
+ w = self.weight * self.weight_gain
35
+ b = self.bias
36
+ if b is not None and self.bias_gain != 1:
37
+ b = b * self.bias_gain
38
+ x = F.linear(x, w)
39
+ x = bias_act.bias_act(x, b, act=self.activation)
40
+ return x
41
+
42
+ def extra_repr(self) -> str:
43
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
44
+
45
+
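FullyConnectedLayer stores an unscaled weight and multiplies it by weight_gain = lr_multiplier / sqrt(fan_in) at run time, the StyleGAN-style equalized learning rate, so activations and gradients stay comparable across layers of different fan-in. A short numerical sketch of the scaling (plain tensors only):

    import numpy as np
    import torch

    in_features, out_features, lr_multiplier = 512, 256, 1.0
    weight = torch.randn(out_features, in_features)      # stored at unit variance
    weight_gain = lr_multiplier / np.sqrt(in_features)   # applied in forward, not baked into the init

    x = torch.randn(8, in_features)
    y = x @ (weight * weight_gain).t()
    print(f"output std ~ {y.std().item():.2f}")          # ~1.0: activations stay near unit variance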
46
+ class Conv2d(torch.nn.Module):
47
+ def __init__(self,
48
+ in_channels, # Number of input channels.
49
+ out_channels, # Number of output channels.
50
+ kernel_size = 3, # Convolution kernel size.
51
+ up = 1, # Integer upsampling factor.
52
+ down = 1, # Integer downsampling factor
53
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
54
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
55
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
56
+ bias = True,
57
+ norm = False,
58
+ lr_multiplier=1,
59
+ bias_init=0,
60
+ w_dim=None,
61
+ gain=1,
62
+ ):
63
+ super().__init__()
64
+ if norm:
65
+ self.norm = torch.nn.InstanceNorm2d(None)
66
+ assert norm in [True, False]
67
+ self.up = up
68
+ self.down = down
69
+ self.activation = activation
70
+ self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
71
+ self.out_channels = out_channels
72
+ self.in_channels = in_channels
73
+ self.padding = kernel_size // 2
74
+
75
+ self.repr = dict(
76
+ in_channels=in_channels, out_channels=out_channels,
77
+ kernel_size=kernel_size, up=up, down=down,
78
+ activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
79
+ )
80
+
81
+ if self.up == 1 and self.down == 1:
82
+ self.resample_filter = None
83
+ else:
84
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
85
+
86
+ self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
87
+ self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
88
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
89
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
90
+ self.bias_gain = lr_multiplier
91
+ if w_dim is not None:
92
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
93
+ self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
94
+
95
+ def forward(self, x, w=None, s=None):
96
+ tops.assert_shape(x, [None, self.weight.shape[1], None, None])
97
+ if s is not None:
98
+ s = s[..., :self.in_channels*2]
99
+ gamma, beta = s.view(-1, self.in_channels*2, 1, 1).chunk(2, dim=1)
100
+ x = fma(x, gamma, beta)
101
+ elif hasattr(self, "affine"):
102
+ gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
103
+ beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
104
+ x = fma(x, gamma, beta)
105
+ w = self.weight * self.weight_gain
106
+ # Removing flip weight is not safe.
107
+ x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, self.down, self.padding, flip_weight=self.up==1)
108
+ if hasattr(self, "norm"):
109
+ x = self.norm(x)
110
+ b = self.bias * self.bias_gain if self.bias is not None else None
111
+ x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
112
+ return x
113
+
114
+ def extra_repr(self) -> str:
115
+ return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
116
+
117
+
118
+ class Block(torch.nn.Module):
119
+ def __init__(self,
120
+ in_channels, # Number of input channels, 0 = first block.
121
+ out_channels, # Number of output channels.
122
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
123
+ up = 1,
124
+ down = 1,
125
+ **layer_kwargs, # Arguments for SynthesisLayer.
126
+ ):
127
+ super().__init__()
128
+ self.in_channels = in_channels
129
+ self.down = down
130
+ self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
131
+ self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)
132
+
133
+ def forward(self, x, **layer_kwargs):
134
+ x = self.conv0(x, **layer_kwargs)
135
+ x = self.conv1(x, **layer_kwargs)
136
+ return x
137
+
138
+
139
+ class ResidualBlock(torch.nn.Module):
140
+ def __init__(self,
141
+ in_channels, # Number of input channels, 0 = first block.
142
+ out_channels, # Number of output channels.
143
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
144
+ up = 1,
145
+ down = 1,
146
+ gain_out=np.sqrt(0.5),
147
+ fix_residual: bool = False,
148
+ **layer_kwargs, # Arguments for conv layer.
149
+ ):
150
+ super().__init__()
151
+ self.in_channels = in_channels
152
+ self.out_channels = out_channels
153
+ self.down = down
154
+ self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
155
+
156
+ self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out,**layer_kwargs)
157
+
158
+ self.skip = Conv2d(
159
+ in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
160
+ activation="linear" if fix_residual else "lrelu",
161
+ gain=gain_out
162
+ )
163
+ self.gain_out = gain_out
164
+
165
+ def forward(self, x, w=None, s=None, **layer_kwargs):
166
+ y = self.skip(x)
167
+ s_ = next(s) if s is not None else None
168
+ x = self.conv0(x, w, s=s_, **layer_kwargs)
169
+ s_ = next(s) if s is not None else None
170
+ x = self.conv1(x, w, s=s_, **layer_kwargs)
171
+ x = y + x
172
+ return x
173
+
174
+
175
+ class MinibatchStdLayer(torch.nn.Module):
176
+ def __init__(self, group_size, num_channels=1):
177
+ super().__init__()
178
+ self.group_size = group_size
179
+ self.num_channels = num_channels
180
+
181
+ def forward(self, x):
182
+ N, C, H, W = x.shape
183
+ with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants
184
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
185
+ F = self.num_channels
186
+ c = C // F
187
+
188
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
189
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
190
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
191
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
192
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
193
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
194
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
195
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
196
+ return x
197
+
198
+ #----------------------------------------------------------------------------
199
+
200
+ class DiscriminatorEpilogue(torch.nn.Module):
201
+ def __init__(self,
202
+ in_channels, # Number of input channels.
203
+ resolution: List[int], # Resolution of this block.
204
+ mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
205
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
206
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
207
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
208
+ ):
209
+ super().__init__()
210
+ self.in_channels = in_channels
211
+ self.resolution = resolution
212
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
213
+ self.conv = Conv2d(
214
+ in_channels + mbstd_num_channels, in_channels,
215
+ kernel_size=3, activation=activation, conv_clamp=conv_clamp)
216
+ self.fc = FullyConnectedLayer(in_channels * resolution[0]*resolution[1], in_channels, activation=activation)
217
+ self.out = FullyConnectedLayer(in_channels, 1)
218
+
219
+ def forward(self, x):
220
+ tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW]
221
+ # Main layers.
222
+ if self.mbstd is not None:
223
+ x = self.mbstd(x)
224
+ x = self.conv(x)
225
+ x = self.fc(x.flatten(1))
226
+ x = self.out(x)
227
+ return x
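
The MinibatchStdLayer above appends the per-group feature standard deviation as an extra channel so the discriminator can detect low-variety batches. A minimal standalone sketch of that computation on toy tensors (the function name, shapes and group size here are chosen only for illustration):

import torch

def minibatch_std(x: torch.Tensor, group_size: int = 4, num_channels: int = 1) -> torch.Tensor:
    # Mirror of MinibatchStdLayer.forward: stddev over each group, averaged over
    # channels and pixels, broadcast back and appended as extra feature channels.
    N, C, H, W = x.shape
    G = min(group_size, N)
    F = num_channels
    y = x.reshape(G, -1, F, C // F, H, W)
    y = y - y.mean(dim=0)
    y = y.square().mean(dim=0).add(1e-8).sqrt()
    y = y.mean(dim=[2, 3, 4]).reshape(-1, F, 1, 1)
    return torch.cat([x, y.repeat(G, 1, H, W)], dim=1)

x = torch.randn(8, 16, 4, 4)
print(minibatch_std(x).shape)  # torch.Size([8, 17, 4, 4])
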
dp2/loss/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sg2_loss import StyleGAN2Loss
dp2/loss/pl_regularization.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import tops
3
+ import numpy as np
4
+ from sg3_torch_utils.ops import conv2d_gradfix
5
+
6
+ pl_mean_total = torch.zeros([])
7
+
8
+ class PLRegularization:
9
+
10
+ def __init__(self, weight: float, batch_shrink: int, pl_decay:float, scale_by_mask: bool,**kwargs):
11
+ self.pl_mean = torch.zeros([], device=tops.get_device())
12
+ self.pl_weight = weight
13
+ self.batch_shrink = batch_shrink
14
+ self.pl_decay = pl_decay
15
+ self.scale_by_mask = scale_by_mask
16
+
17
+ def __call__(self, G, batch, grad_scaler):
18
+ batch_size = batch["img"].shape[0] // self.batch_shrink
19
+ batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"}
20
+ if "embed_map" in batch:
21
+ batch["embed_map"] = batch["embed_map"]
22
+ z = G.get_z(batch["img"])
23
+
24
+ with torch.cuda.amp.autocast(tops.AMP()):
25
+ gen_ws = G.style_net(z)
26
+ gen_img = G(**batch, w=gen_ws)["img"].float()
27
+ pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
28
+ with conv2d_gradfix.no_weight_gradients():
29
+ # Sums over HWC
30
+ pl_grads = torch.autograd.grad(
31
+ outputs=[grad_scaler.scale(gen_img * pl_noise)],
32
+ inputs=[gen_ws],
33
+ create_graph=True,
34
+ grad_outputs=torch.ones_like(gen_img),
35
+ only_inputs=True)[0]
36
+
37
+ pl_grads = pl_grads.float() / grad_scaler.get_scale()
38
+ if self.scale_by_mask:
39
+ # Percentage of pixels known
40
+ scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1)
41
+ pl_grads = pl_grads / scaling
42
+ pl_lengths = pl_grads.square().sum(1).sqrt()
43
+ pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
44
+ if not torch.isnan(pl_mean).any():
45
+ self.pl_mean.copy_(pl_mean.detach())
46
+ pl_penalty = (pl_lengths - pl_mean).square()
47
+ to_log = dict(pl_penalty=pl_penalty.mean().detach())
48
+ return pl_penalty.view(-1) * self.pl_weight, to_log
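
The penalty at the end of PLRegularization reduces to a squared deviation of the per-sample Jacobian norm from its exponential moving average. A self-contained sketch on synthetic gradients (the decay and weight values are only illustrative, and the helper name is not part of the diff):

import torch

pl_decay, pl_weight = 0.01, 2.0
pl_mean = torch.zeros([])  # running EMA of the mean path length

def pl_penalty(pl_grads: torch.Tensor) -> torch.Tensor:
    global pl_mean
    pl_lengths = pl_grads.square().sum(1).sqrt()                  # per-sample gradient norm
    pl_mean = pl_mean.lerp(pl_lengths.mean(), pl_decay).detach()  # update the EMA
    return (pl_lengths - pl_mean).square() * pl_weight            # penalize deviation from the EMA

print(pl_penalty(torch.randn(4, 512)).shape)  # torch.Size([4])
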
dp2/loss/r1_regularization.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import tops
3
+
4
+ def r1_regularization(
5
+ real_img, real_score, mask, lambd: float, lazy_reg_interval: int,
6
+ lazy_regularization: bool,
7
+ scaler: torch.cuda.amp.GradScaler, mask_out: bool,
8
+ mask_out_scale: bool,
9
+ **kwargs
10
+ ):
11
+ grad = torch.autograd.grad(
12
+ outputs=scaler.scale(real_score),
13
+ inputs=real_img,
14
+ grad_outputs=torch.ones_like(real_score),
15
+ create_graph=True,
16
+ only_inputs=True,
17
+ )[0]
18
+ inv_scale = 1.0 / scaler.get_scale()
19
+ grad = grad * inv_scale
20
+ with torch.cuda.amp.autocast(tops.AMP()):
21
+ if mask_out:
22
+ grad = grad * (1 - mask)
23
+ grad = grad.square().sum(dim=[1, 2, 3])
24
+ if mask_out and mask_out_scale:
25
+ total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3]
26
+ n_fake = (1-mask).sum(dim=[1, 2, 3])
27
+ scaling = total_pixels / n_fake
28
+ grad = grad * scaling
29
+ lambd_ = lambd
30
+ if lazy_regularization: lambd_ = lambd * lazy_reg_interval / 2 # From stylegan2, lazy regularization
31
+ return grad * lambd_, grad.detach()
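
The same R1 penalty, stripped of AMP scaling and mask handling, can be computed directly. A standalone sketch with a toy linear discriminator (the model and the lambda/interval values are assumptions for illustration, not the project's configuration):

import torch

D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 1))  # toy discriminator
real_img = torch.randn(4, 3, 8, 8, requires_grad=True)
real_score = D(real_img)
grad = torch.autograd.grad(
    outputs=real_score, inputs=real_img,
    grad_outputs=torch.ones_like(real_score), create_graph=True)[0]
r1 = grad.square().sum(dim=[1, 2, 3])               # squared gradient norm per sample
lambd, lazy_reg_interval = 5.0, 16
print((r1 * lambd * lazy_reg_interval / 2).mean())  # lazy-regularization weighting
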
dp2/loss/sg2_loss.py ADDED
@@ -0,0 +1,94 @@
1
+ import functools
2
+ import torch
3
+ import tops
4
+ from tops import logger
5
+ from dp2.utils import forward_D_fake
6
+ from .utils import nsgan_d_loss, nsgan_g_loss
7
+ from .r1_regularization import r1_regularization
8
+ from .pl_regularization import PLRegularization
9
+
10
+ class StyleGAN2Loss:
11
+
12
+ def __init__(
13
+ self,
14
+ D,
15
+ G,
16
+ r1_opts: dict,
17
+ EP_lambd: float,
18
+ lazy_reg_interval: int,
19
+ lazy_regularization: bool,
20
+ pl_reg_opts: dict,
21
+ ) -> None:
22
+ self.gradient_step_D = 0
23
+ self._lazy_reg_interval = lazy_reg_interval
24
+ self.D = D
25
+ self.G = G
26
+ self.EP_lambd = EP_lambd
27
+ self.lazy_regularization = lazy_regularization
28
+ self.r1_reg = functools.partial(
29
+ r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
30
+ lazy_regularization=lazy_regularization)
31
+ self.do_PL_Reg = False
32
+ if pl_reg_opts.weight > 0:
33
+ self.pl_reg = PLRegularization(**pl_reg_opts)
34
+ self.do_PL_Reg = True
35
+ self.pl_start_nimg = pl_reg_opts.start_nimg
36
+
37
+ def D_loss(self, batch: dict, grad_scaler):
38
+ to_log = {}
39
+ # Forward through G and D
40
+ do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
41
+ if do_GP:
42
+ batch["img"] = batch["img"].detach().requires_grad_(True)
43
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
44
+ with torch.no_grad():
45
+ G_fake = self.G(**batch, update_emas=True)
46
+ D_out_real = self.D(**batch)
47
+
48
+ D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
49
+
50
+ # Non saturating loss
51
+ nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
52
+ tops.assert_shape(nsgan_loss, (batch["img"].shape[0], ))
53
+ to_log["d_loss"] = nsgan_loss.mean()
54
+ total_loss = nsgan_loss
55
+ epsilon_penalty = D_out_real["score"].pow(2).view(-1)
56
+ to_log["epsilon_penalty"] = epsilon_penalty.mean()
57
+ tops.assert_shape(epsilon_penalty, total_loss.shape)
58
+ total_loss = total_loss + epsilon_penalty * self.EP_lambd
59
+
60
+ # Improved gradient penalty with lazy regularization
61
+ # Gradient penalty applies specialized autocast.
62
+ if do_GP:
63
+ gradient_pen, grad_unscaled = self.r1_reg(batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
64
+ to_log["r1_gradient_penalty"] = grad_unscaled.mean()
65
+ tops.assert_shape(gradient_pen, total_loss.shape)
66
+ total_loss = total_loss + gradient_pen
67
+
68
+ batch["img"] = batch["img"].detach().requires_grad_(False)
69
+ if "score" in D_out_real:
70
+ to_log["real_scores"] = D_out_real["score"]
71
+ to_log["real_logits_sign"] = D_out_real["score"].sign()
72
+ to_log["fake_logits_sign"] = D_out_fake["score"].sign()
73
+ to_log["fake_scores"] = D_out_fake["score"]
74
+ to_log = {key: item.mean().detach() for key, item in to_log.items()}
75
+ self.gradient_step_D += 1
76
+ return total_loss.mean(), to_log
77
+
78
+ def G_loss(self, batch: dict, grad_scaler):
79
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
80
+ to_log = {}
81
+ # Forward through G and D
82
+ G_fake = self.G(**batch)
83
+ D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
84
+ # Adversarial Loss
85
+ total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
86
+ to_log["g_loss"] = total_loss.mean()
87
+ tops.assert_shape(total_loss, (batch["img"].shape[0], ))
88
+
89
+ if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
90
+ pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
91
+ total_loss = total_loss + pl_reg.mean()
92
+ to_log.update(to_log_)
93
+ to_log = {key: item.mean().detach() for key, item in to_log.items()}
94
+ return total_loss.mean(), to_log
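
D_loss above only evaluates the R1 gradient penalty every lazy_reg_interval discriminator steps (StyleGAN2-style lazy regularization), compensating by scaling the penalty weight with the interval. A trivial sketch of that cadence, with an illustrative interval:

lazy_reg_interval = 16
for gradient_step_D in range(48):
    do_GP = gradient_step_D % lazy_reg_interval == 0
    if do_GP:
        print(f"step {gradient_step_D}: evaluate R1 gradient penalty")
# -> steps 0, 16, 32
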
dp2/loss/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def nsgan_g_loss(fake_score):
5
+ """
6
+ Non-saturating criterion from Goodfellow et al. 2014
7
+ """
8
+ return torch.nn.functional.softplus(-fake_score)
9
+
10
+
11
+ def nsgan_d_loss(real_score, fake_score):
12
+ """
13
+ Non-saturating criterion from Goodfellow et al. 2014
14
+ """
15
+ d_loss = F.softplus(-real_score) + F.softplus(fake_score)
16
+ return d_loss.view(-1)
17
+
18
+
19
+ def smooth_masked_l1_loss(x, target, mask):
20
+ """
21
+ Pixel-wise l1 loss for the area indicated by mask
22
+ """
23
+ # Beta=.1 <-> square loss if pixel difference <= 12.8
24
+ l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1,2,3]) / mask.sum(dim=[1, 2, 3])
25
+ return l1
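
The softplus formulation above is the usual non-saturating loss written as -log(sigmoid(.)). A quick standalone numeric check on random scores:

import torch
import torch.nn.functional as F

scores = torch.randn(5)
# Generator term: softplus(-s) = log(1 + exp(-s)) = -log(sigmoid(s))
assert torch.allclose(F.softplus(-scores), -torch.log(torch.sigmoid(scores)), atol=1e-5)
# Discriminator's fake term: softplus(s) = -log(1 - sigmoid(s))
assert torch.allclose(F.softplus(scores), -torch.log(1 - torch.sigmoid(scores)), atol=1e-5)
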
dp2/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .torch_metrics import compute_metrics_iteratively
2
+ from .fid import compute_fid
3
+ from .ppl import calculate_ppl
dp2/metrics/fid.py ADDED
@@ -0,0 +1,72 @@
1
+ import tops
2
+ from dp2 import utils
3
+ from pathlib import Path
4
+ from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
5
+ import torch
6
+ import torch_fidelity
7
+
8
+
9
+ class GeneratorIteratorWrapper(GenerativeModelModuleWrapper):
10
+
11
+ def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int):
12
+ if isinstance(generator, utils.EMA):
13
+ generator = generator.generator
14
+ z_size = generator.z_channels
15
+ super().__init__(generator, z_size, "normal", 0)
16
+ self.zero_z = zero_z
17
+ self.dataloader = iter(dataloader)
18
+ self.n_diverse = n_diverse
19
+ self.cur_div_idx = 0
20
+
21
+ @torch.no_grad()
22
+ def forward(self, z, **kwargs):
23
+ if self.cur_div_idx == 0:
24
+ self.batch = next(self.dataloader)
25
+ if self.zero_z:
26
+ z = z.zero_()
27
+ self.cur_div_idx += 1
28
+ self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx
29
+ with torch.cuda.amp.autocast(enabled=tops.AMP()):
30
+ img = self.module(**self.batch)["img"]
31
+ img = (utils.denormalize_img(img)*255).byte()
32
+ return img
33
+
34
+
35
+ def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse):
36
+ generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse)
37
+ batch_size = dataloader.batch_size
38
+ num_samples = (n_source * n_diverse) // batch_size * batch_size
39
+ assert n_diverse >= 1
40
+ assert (not zero_z) or n_diverse == 1
41
+ assert num_samples % batch_size == 0
42
+ assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse)
43
+ metrics = torch_fidelity.calculate_metrics(
44
+ input1=generator,
45
+ input2=real_directory,
46
+ cuda=torch.cuda.is_available(),
47
+ fid=True,
48
+ input2_cache_name="_".join(Path(real_directory).parts) + "_cached",
49
+ input1_model_num_samples=int(num_samples),
50
+ batch_size=dataloader.batch_size
51
+ )
52
+ return metrics["frechet_inception_distance"]
53
+
54
+
55
+ if __name__ == "__main__":
56
+ import click
57
+ from dp2.config import Config
58
+ from dp2.data import build_dataloader_val
59
+ from dp2.infer import build_trained_generator
60
+ @click.command()
61
+ @click.argument("config_path")
62
+ @click.option("--n_source", default=200, type=int)
63
+ @click.option("--n_diverse", default=5, type=int)
64
+ @click.option("--zero_z", default=False, is_flag=True)
65
+ def run(config_path, n_source: int, n_diverse: int, zero_z: bool):
66
+ cfg = Config.fromfile(config_path)
67
+ dataloader = build_dataloader_val(cfg)
68
+ generator, _ = build_trained_generator(cfg)
69
+ print(compute_fid(
70
+ generator, dataloader, cfg.fid_real_directory, n_source, zero_z, n_diverse))
71
+
72
+ run()
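
compute_fid above draws n_source conditioning batches, generates n_diverse samples per source, and rounds the total down to a whole number of batches before handing the wrapped generator to torch_fidelity. The arithmetic, with the CLI defaults and an assumed batch size:

n_source, n_diverse, batch_size = 200, 5, 16   # batch_size is an assumption
num_samples = (n_source * n_diverse) // batch_size * batch_size
assert num_samples % batch_size == 0
print(num_samples)  # 992 of the requested 1000 samples
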
dp2/metrics/fid_clip.py ADDED
@@ -0,0 +1,84 @@
1
+ import pickle
2
+ import torch
3
+ import torchvision
4
+ from pathlib import Path
5
+ from dp2 import utils
6
+ import tops
7
+ try:
8
+ import clip
9
+ except ImportError:
10
+ print("Could not import clip.")
11
+ from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
12
+ clip_model = None
13
+ clip_preprocess = None
14
+
15
+
16
+ @torch.no_grad()
17
+ def compute_fid_clip(
18
+ dataloader, generator,
19
+ cache_directory,
20
+ data_len=None,
21
+ **kwargs
22
+ ) -> dict:
23
+ """
24
+ FID CLIP following the description in The Role of ImageNet Classes in Fréchet Inception Distance, Tuomas Kynkäänniemi et al.
25
+ Args:
26
+ n_samples (int): Creates N samples from same image to calculate stats
27
+ """
28
+ global clip_model, clip_preprocess
29
+ if clip_model is None:
30
+ clip_model, preprocess = clip.load("ViT-B/32", device="cpu")
31
+ normalize_fn = preprocess.transforms[-1]
32
+ img_mean = normalize_fn.mean
33
+ img_std = normalize_fn.std
34
+ clip_model = tops.to_cuda(clip_model.visual)
35
+ clip_preprocess = tops.to_cuda(torch.nn.Sequential(
36
+ torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
37
+ torchvision.transforms.Normalize(img_mean, img_std)
38
+ ))
39
+ cache_directory = Path(cache_directory)
40
+ if data_len is None:
41
+ data_len = len(dataloader)*dataloader.batch_size
42
+ fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl")
43
+ has_fid_cache = fid_cache_path.is_file()
44
+ if not has_fid_cache:
45
+ fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
46
+ fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device())
47
+ eidx = 0
48
+ n_samples_seen = 0
49
+ for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."):
50
+ sidx = eidx
51
+ eidx = sidx + batch["img"].shape[0]
52
+ n_samples_seen += batch["img"].shape[0]
53
+ with torch.cuda.amp.autocast(tops.AMP()):
54
+ fakes = generator(**batch)["img"]
55
+ real_data = batch["img"]
56
+ fakes = utils.denormalize_img(fakes)
57
+ real_data = utils.denormalize_img(real_data)
58
+ if not has_fid_cache:
59
+ real_data = clip_preprocess(real_data)
60
+ fid_features_real[sidx:eidx] = clip_model(real_data)
61
+ fakes = clip_preprocess(fakes)
62
+ fid_features_fake[sidx:eidx] = clip_model(fakes)
63
+ fid_features_fake = fid_features_fake[:n_samples_seen]
64
+ fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
65
+ if has_fid_cache:
66
+ if tops.rank() == 0:
67
+ with open(fid_cache_path, "rb") as fp:
68
+ fid_stat_real = pickle.load(fp)
69
+ else:
70
+ fid_features_real = fid_features_real[:n_samples_seen]
71
+ fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
72
+ assert fid_features_real.shape == fid_features_fake.shape
73
+ if tops.rank() == 0:
74
+ fid_stat_real = fid_features_to_statistics(fid_features_real)
75
+ cache_directory.mkdir(exist_ok=True, parents=True)
76
+ with open(fid_cache_path, "wb") as fp:
77
+ pickle.dump(fid_stat_real, fp)
78
+
79
+ if tops.rank() == 0:
80
+ print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
81
+ fid_stat_fake = fid_features_to_statistics(fid_features_fake)
82
+ fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
83
+ return dict(fid_clip=fid_)
84
+ return dict(fid_clip=-1)
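
fid_features_to_statistics and fid_statistics_to_metric come from torch_fidelity; conceptually they fit a Gaussian to each feature set and evaluate the Fréchet distance ||mu1 - mu2||^2 + Tr(S1 + S2 - 2(S1 S2)^{1/2}). A rough standalone sketch (random features stand in for the CLIP activations, and scipy is assumed to be available):

import numpy as np
from scipy import linalg

def fid_from_features(f1: np.ndarray, f2: np.ndarray) -> float:
    mu1, mu2 = f1.mean(axis=0), f2.mean(axis=0)
    s1, s2 = np.cov(f1, rowvar=False), np.cov(f2, rowvar=False)
    covmean = linalg.sqrtm(s1 @ s2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # numerical noise can produce tiny imaginary parts
    return float(((mu1 - mu2) ** 2).sum() + np.trace(s1 + s2 - 2.0 * covmean))

rng = np.random.default_rng(0)
print(fid_from_features(rng.normal(size=(256, 16)), rng.normal(size=(256, 16))))
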
dp2/metrics/lpips.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import tops
3
+ import sys
4
+ from contextlib import redirect_stdout
5
+ from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average
6
+
7
+ class SampleSimilarityLPIPS(torch.nn.Module):
8
+ SUPPORTED_DTYPES = {
9
+ 'uint8': torch.uint8,
10
+ 'float32': torch.float32,
11
+ }
12
+
13
+ def __init__(self):
14
+
15
+ super().__init__()
16
+ self.chns = [64, 128, 256, 512, 512]
17
+ self.L = len(self.chns)
18
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
19
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
20
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
21
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
22
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
23
+ self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
24
+ with redirect_stdout(sys.stderr):
25
+ fp = tops.download_file(URL_VGG16_LPIPS)
26
+ state_dict = torch.load(fp, map_location="cpu")
27
+ self.load_state_dict(state_dict)
28
+ self.net = VGG16features()
29
+ self.eval()
30
+ for param in self.parameters():
31
+ param.requires_grad = False
32
+ mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2
33
+ inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255)
34
+ self.register_buffer("mean", mean_rescaled)
35
+ self.register_buffer("std", inv_std_rescaled)
36
+
37
+ def normalize(self, x):
38
+ # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
39
+ x = (x.float() - self.mean) * self.std
40
+ return x
41
+
42
+ @staticmethod
43
+ def resize(x, size):
44
+ if x.shape[-1] > size and x.shape[-2] > size:
45
+ x = torch.nn.functional.interpolate(x, (size, size), mode='area')
46
+ else:
47
+ x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False)
48
+ return x
49
+
50
+ def lpips_from_feats(self, feats0, feats1):
51
+ diffs = {}
52
+ for kk in range(self.L):
53
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
54
+
55
+ res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
56
+ val = sum(res)
57
+ return val
58
+
59
+ def get_feats(self, x):
60
+ assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW'
61
+ if x.shape[-2] < 16: # Resize images < 16x16
62
+ f = 16 / x.shape[-2]
63
+ size = tuple([int(f*_) for _ in x.shape[-2:]])
64
+ x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False)
65
+ in0_input = self.normalize(x)
66
+ outs0 = self.net.forward(in0_input)
67
+
68
+ feats = {}
69
+ for kk in range(self.L):
70
+ feats[kk] = normalize_tensor(outs0[kk])
71
+ return feats
72
+
73
+ def forward(self, in0, in1):
74
+ feats0 = self.get_feats(in0)
75
+ feats1 = self.get_feats(in1)
76
+ return self.lpips_from_feats(feats0, feats1), feats0, feats1
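
The distance above follows the LPIPS recipe: unit-normalize the VGG features per channel, square the difference, project with a learned 1x1 convolution, and average spatially. A toy sketch with random feature maps and an untrained projection (both are stand-ins, not the real VGG/LPIPS weights):

import torch

def unit_normalize(feat: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    return feat / (feat.square().sum(dim=1, keepdim=True).sqrt() + eps)

f0, f1 = torch.randn(2, 64, 8, 8), torch.randn(2, 64, 8, 8)  # pretend VGG activations
lin = torch.nn.Conv2d(64, 1, kernel_size=1, bias=False)      # stand-in for NetLinLayer
diff = (unit_normalize(f0) - unit_normalize(f1)) ** 2
print(lin(diff).mean(dim=[2, 3]).view(-1))                   # spatial average -> per-image distance
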
dp2/metrics/ppl.py ADDED
@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import torch
3
+ import tops
4
+ from dp2 import utils
5
+ from torch_fidelity.helpers import get_kwarg, vassert
6
+ from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS
7
+ from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity
8
+
9
+
10
+ def slerp(a, b, t):
11
+ a = a / a.norm(dim=-1, keepdim=True)
12
+ b = b / b.norm(dim=-1, keepdim=True)
13
+ d = (a * b).sum(dim=-1, keepdim=True)
14
+ p = t * torch.acos(d)
15
+ c = b - d * a
16
+ c = c / c.norm(dim=-1, keepdim=True)
17
+ d = a * torch.cos(p) + c * torch.sin(p)
18
+ d = d / d.norm(dim=-1, keepdim=True)
19
+ return d
20
+
21
+
22
+ @torch.no_grad()
23
+ def calculate_ppl(
24
+ dataloader,
25
+ generator,
26
+ latent_space=None,
27
+ data_len=None,
28
+ **kwargs) -> dict:
29
+ """
30
+ Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
31
+ """
32
+ if latent_space is None:
33
+ latent_space = generator.latent_space
34
+ assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}"
35
+ epsilon = PPL_DEFAULTS["ppl_epsilon"]
36
+ interp = PPL_DEFAULTS['ppl_z_interp_mode']
37
+ similarity_name = PPL_DEFAULTS['ppl_sample_similarity']
38
+ sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize']
39
+ sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype']
40
+ discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower']
41
+ discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher']
42
+
43
+ vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number')
44
+ vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile')
45
+ vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile')
46
+ if discard_percentile_lower is not None and discard_percentile_higher is not None:
47
+ vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles')
48
+
49
+ sample_similarity = create_sample_similarity(
50
+ similarity_name,
51
+ sample_similarity_resize=sample_similarity_resize,
52
+ sample_similarity_dtype=sample_similarity_dtype,
53
+ cuda=False,
54
+ **kwargs
55
+ )
56
+ sample_similarity = tops.to_cuda(sample_similarity)
57
+ rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))
58
+ distances = []
59
+ if data_len is None:
60
+ data_len = len(dataloader) * dataloader.batch_size
61
+ z0 = sample_random(rng, (data_len, generator.z_channels), "normal")
62
+ z1 = sample_random(rng, (data_len, generator.z_channels), "normal")
63
+ if latent_space == "Z":
64
+ z1 = batch_interp(z0, z1, epsilon, interp)
65
+
66
+ distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device())
67
+ print(distances.shape)
68
+ end = 0
69
+ n_samples = 0
70
+ for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")):
71
+ start = end
72
+ end = start + batch["img"].shape[0]
73
+ n_samples += batch["img"].shape[0]
74
+ batch_lat_e0 = tops.to_cuda(z0[start:end])
75
+ batch_lat_e1 = tops.to_cuda(z1[start:end])
76
+ if latent_space == "W":
77
+ w0 = generator.get_w(batch_lat_e0, update_emas=False)
78
+ w1 = generator.get_w(batch_lat_e1, update_emas=False)
79
+ w1 = w0.lerp(w1, epsilon) # PPL end
80
+ rgb1 = generator(**batch, w=w0)["img"]
81
+ rgb2 = generator(**batch, w=w1)["img"]
82
+ else:
83
+ rgb1 = generator(**batch, z=batch_lat_e0)["img"]
84
+ rgb2 = generator(**batch, z=batch_lat_e1)["img"]
85
+ rgb1 = utils.denormalize_img(rgb1).mul(255).byte()
86
+ rgb2 = utils.denormalize_img(rgb2).mul(255).byte()
87
+
88
+ sim = sample_similarity(rgb1, rgb2)
89
+ dist_lat_e01 = sim / (epsilon ** 2)
90
+ distances[start:end] = dist_lat_e01.view(-1)
91
+ distances = distances[:n_samples]
92
+ distances = tops.all_gather_uneven(distances).cpu().numpy()
93
+ if tops.rank() != 0:
94
+ return {"ppl/mean": -1, "ppl/std": -1}
95
+ if tops.rank() == 0:
96
+ cond, lo, hi = None, None, None
97
+ if discard_percentile_lower is not None:
98
+ lo = np.percentile(distances, discard_percentile_lower, interpolation='lower')
99
+ cond = lo <= distances
100
+ if discard_percentile_higher is not None:
101
+ hi = np.percentile(distances, discard_percentile_higher, interpolation='higher')
102
+ cond = np.logical_and(cond, distances <= hi)
103
+ if cond is not None:
104
+ distances = np.extract(cond, distances)
105
+ return {
106
+ "ppl/mean": float(np.mean(distances)),
107
+ "ppl/std": float(np.std(distances)),
108
+ }
109
+ else:
110
+ return {"ppl/mean": -1, "ppl/std": -1}
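
The slerp helper above interpolates on the unit sphere. A quick check of its endpoints and norms, assuming the dp2 package from this commit and its dependencies are importable:

import torch
from dp2.metrics.ppl import slerp  # the helper defined above

a, b = torch.randn(4, 512), torch.randn(4, 512)
assert torch.allclose(slerp(a, b, 0.0), a / a.norm(dim=-1, keepdim=True), atol=1e-5)
assert torch.allclose(slerp(a, b, 1.0), b / b.norm(dim=-1, keepdim=True), atol=1e-5)
print(slerp(a, b, 0.5).norm(dim=-1))  # interpolants stay on the unit sphere (all ones)
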
dp2/metrics/torch_metrics.py ADDED
@@ -0,0 +1,176 @@
1
+ import pickle
2
+ import numpy as np
3
+ import torch
4
+ import time
5
+ from pathlib import Path
6
+ from dp2 import utils
7
+ import tops
8
+ from .lpips import SampleSimilarityLPIPS
9
+ from torch_fidelity.defaults import DEFAULTS as trf_defaults
10
+ from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric
11
+ from torch_fidelity.utils import create_feature_extractor
12
+ lpips_model = None
13
+ fid_model = None
14
+
15
+ @torch.no_grad()
16
+ def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
17
+ se = (images1 - images2) ** 2
18
+ se = se.view(images1.shape[0], -1).mean(dim=1)
19
+ return se
20
+
21
+ @torch.no_grad()
22
+ def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
23
+ mse_ = mse(images1, images2)
24
+ psnr = 10 * torch.log10(1 / mse_)
25
+ return psnr
26
+
27
+ @torch.no_grad()
28
+ def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
29
+ return _lpips_w_grad(images1, images2)
30
+
31
+
32
+ def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor:
33
+ global lpips_model
34
+ if lpips_model is None:
35
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
36
+
37
+ images1 = images1.mul(255)
38
+ images2 = images2.mul(255)
39
+ with torch.cuda.amp.autocast(tops.AMP()):
40
+ dists = lpips_model(images1, images2)[0].view(-1)
41
+ return dists
42
+
43
+
44
+
45
+
46
+ @torch.no_grad()
47
+ def compute_metrics_iteratively(
48
+ dataloader, generator,
49
+ cache_directory,
50
+ data_len=None,
51
+ truncation_value: float=None,
52
+ ) -> dict:
53
+ """
54
+ Args:
55
+ n_samples (int): Creates N samples from same image to calculate stats
56
+ dataset_percentage (float): The percentage of the dataset to compute metrics on.
57
+ """
58
+
59
+ global lpips_model, fid_model
60
+ if lpips_model is None:
61
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
62
+ if fid_model is None:
63
+ fid_model = create_feature_extractor(
64
+ trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False)
65
+ fid_model = tops.to_cuda(fid_model)
66
+ cache_directory = Path(cache_directory)
67
+ start_time = time.time()
68
+ lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
69
+ diversity_total = torch.zeros_like(lpips_total)
70
+ fid_cache_path = cache_directory.joinpath("fid_stats.pkl")
71
+ has_fid_cache = fid_cache_path.is_file()
72
+ if data_len is None:
73
+ data_len = len(dataloader)*dataloader.batch_size
74
+ if not has_fid_cache:
75
+ fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
76
+ fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device())
77
+ n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
78
+ eidx = 0
79
+ for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"):
80
+ sidx = eidx
81
+ eidx = sidx + batch["img"].shape[0]
82
+ n_samples_seen += batch["img"].shape[0]
83
+ with torch.cuda.amp.autocast(tops.AMP()):
84
+ fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
85
+ fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
86
+ fakes1 = utils.denormalize_img(fakes1).mul(255)
87
+ fakes2 = utils.denormalize_img(fakes2).mul(255)
88
+ real_data = utils.denormalize_img(batch["img"]).mul(255)
89
+ lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
90
+ fake2_lpips_feats = lpips_model.get_feats(fakes2)
91
+ lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
92
+
93
+ lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
94
+ diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
95
+ if not has_fid_cache:
96
+ fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0]
97
+ fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0]
98
+ fid_features_fake = fid_features_fake[:n_samples_seen]
99
+ if has_fid_cache:
100
+ if tops.rank() == 0:
101
+ with open(fid_cache_path, "rb") as fp:
102
+ fid_stat_real = pickle.load(fp)
103
+ else:
104
+ fid_features_real = fid_features_real[:n_samples_seen]
105
+ fid_features_real = tops.all_gather_uneven(fid_features_real).cpu()
106
+ if tops.rank() == 0:
107
+ fid_stat_real = fid_features_to_statistics(fid_features_real)
108
+ cache_directory.mkdir(exist_ok=True, parents=True)
109
+ with open(fid_cache_path, "wb") as fp:
110
+ pickle.dump(fid_stat_real, fp)
111
+ fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu()
112
+ if tops.rank() == 0:
113
+ print("Starting calculation of fid from features of shape:", fid_features_fake.shape)
114
+ fid_stat_fake = fid_features_to_statistics(fid_features_fake)
115
+ fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"]
116
+ tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
117
+ tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
118
+ tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
119
+ lpips_total = lpips_total / n_samples_seen
120
+ diversity_total = diversity_total / n_samples_seen
121
+ to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
122
+ if tops.rank() == 0:
123
+ to_return["fid"] = fid_
124
+ else:
125
+ to_return["fid"] = -1
126
+ to_return["validation_time_s"] = time.time() - start_time
127
+ return to_return
128
+
129
+
130
+ @torch.no_grad()
131
+ def compute_lpips(
132
+ dataloader, generator,
133
+ truncation_value: float=None,
134
+ data_len=None,
135
+ ) -> dict:
136
+ """
137
+ Args:
138
+ n_samples (int): Creates N samples from same image to calculate stats
139
+ dataset_percentage (float): The percentage of the dataset to compute metrics on.
140
+ """
141
+ global lpips_model, fid_model
142
+ if lpips_model is None:
143
+ lpips_model = tops.to_cuda(SampleSimilarityLPIPS())
144
+ start_time = time.time()
145
+ lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device())
146
+ diversity_total = torch.zeros_like(lpips_total)
147
+ if data_len is None:
148
+ data_len = len(dataloader) * dataloader.batch_size
149
+ eidx = 0
150
+ n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device())
151
+ for batch in utils.tqdm_(dataloader, desc="Validating on dataset."):
152
+ sidx = eidx
153
+ eidx = sidx + batch["img"].shape[0]
154
+ n_samples_seen += batch["img"].shape[0]
155
+ with torch.cuda.amp.autocast(tops.AMP()):
156
+ fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"]
157
+ fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"]
158
+ real_data = batch["img"]
159
+ fakes1 = utils.denormalize_img(fakes1).mul(255)
160
+ fakes2 = utils.denormalize_img(fakes2).mul(255)
161
+ real_data = utils.denormalize_img(real_data).mul(255)
162
+ lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1)
163
+ fake2_lpips_feats = lpips_model.get_feats(fakes2)
164
+ lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats)
165
+
166
+ lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2)
167
+ diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum()
168
+ tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM)
169
+ tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM)
170
+ tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM)
171
+ lpips_total = lpips_total / n_samples_seen
172
+ diversity_total = diversity_total / n_samples_seen
173
+ to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total)
174
+ to_return = {k: v.cpu().item() for k, v in to_return.items()}
175
+ to_return["validation_time_s"] = time.time() - start_time
176
+ return to_return
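
The mse/psnr helpers above assume images scaled to [0, 1]. A quick numeric check, again assuming the dp2 package from this commit is importable:

import torch
from dp2.metrics.torch_metrics import mse, psnr

x = torch.zeros(1, 3, 4, 4)
y = torch.full_like(x, 0.1)
print(mse(x, y))   # tensor([0.0100])
print(psnr(x, y))  # 10 * log10(1 / 0.01) = tensor([20.])
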