diff --git a/dp2/__init__.py b/dp2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/anonymizer/__init__.py b/dp2/anonymizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb33d7e6ad3b247938dc20ab2311728f286eb14 --- /dev/null +++ b/dp2/anonymizer/__init__.py @@ -0,0 +1 @@ +from .anonymizer import Anonymizer \ No newline at end of file diff --git a/dp2/anonymizer/anonymizer.py b/dp2/anonymizer/anonymizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d850384b3fa33b08b5d5b6770b584c5b64adce44 --- /dev/null +++ b/dp2/anonymizer/anonymizer.py @@ -0,0 +1,159 @@ +from pathlib import Path +from typing import Union, Optional +import numpy as np +import torch +import tops +import torchvision.transforms.functional as F +from motpy import Detection, MultiObjectTracker +from dp2.utils import load_config +from dp2.infer import build_trained_generator +from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection + + +def load_generator_from_cfg_path(cfg_path: Union[str, Path]): + cfg = load_config(cfg_path) + G = build_trained_generator(cfg) + tops.logger.log(f"Loaded generator from: {cfg_path}") + return G + + +def resize_batch(img, mask, maskrcnn_mask, condition, imsize, **kwargs): + img = F.resize(img, imsize, antialias=True) + mask = (F.resize(mask, imsize, antialias=True) > 0.99).float() + maskrcnn_mask = (F.resize(maskrcnn_mask, imsize, antialias=True) > 0.5).float() + + condition = img * mask + return dict(img=img, mask=mask, maskrcnn_mask=maskrcnn_mask, condition=condition) + + +class Anonymizer: + + def __init__( + self, + detector, + load_cache: bool, + person_G_cfg: Optional[Union[str, Path]] = None, + cse_person_G_cfg: Optional[Union[str, Path]] = None, + face_G_cfg: Optional[Union[str, Path]] = None, + car_G_cfg: Optional[Union[str, Path]] = None, + ) -> None: + self.detector = detector + self.generators = {k: None for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]} + self.load_cache = load_cache + if cse_person_G_cfg is not None: + self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg) + if person_G_cfg is not None: + self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg) + if face_G_cfg is not None: + self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg) + if car_G_cfg is not None: + self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg) + + def initialize_tracker(self, fps: float): + self.tracker = MultiObjectTracker(dt=1/fps) + self.track_to_z_idx = dict() + self.cur_z_idx = 0 + + @torch.no_grad() + def anonymize_detections(self, + im, detection, truncation_value: float, + multi_modal_truncation: bool, amp: bool, z_idx, + all_styles=None, + update_identity=None, + ): + G = self.generators[type(detection)] + if G is None: + return im + C, H, W = im.shape + orig_im = im.clone() + if update_identity is None: + update_identity = [True for i in range(len(detection))] + for idx in range(len(detection)): + if not update_identity[idx]: + continue + batch = detection.get_crop(idx, im) + x0, y0, x1, y1 = batch.pop("boxes")[0] + batch = {k: tops.to_cuda(v) for k, v in batch.items()} + batch["img"] = F.normalize(batch["img"].float(), [0.5*255, 0.5*255, 0.5*255], [0.5*255, 0.5*255, 0.5*255]) + batch["img"] = batch["img"].float() + batch["condition"] = batch["mask"] * batch["img"] + 
orig_shape = None + if G.imsize and batch["img"].shape[-1] != G.imsize[-1] and batch["img"].shape[-2] != G.imsize[-2]: + orig_shape = batch["img"].shape[-2:] + batch = resize_batch(**batch, imsize=G.imsize) + with torch.cuda.amp.autocast(amp): + if all_styles is not None: + anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"] + elif multi_modal_truncation and hasattr(G, "multi_modal_truncate") and hasattr(G.style_net, "w_centers"): + w_indices = None + if z_idx is not None: + w_indices = [z_idx[idx] % len(G.style_net.w_centers)] + anonymized_im = G.multi_modal_truncate( + **batch, truncation_value=truncation_value, + w_indices=w_indices)["img"] + else: + z = None + if z_idx is not None: + state = np.random.RandomState(seed=z_idx[idx]) + z = state.normal(size=(1, G.z_channels)) + z = tops.to_cuda(torch.from_numpy(z)) + anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"] + if orig_shape is not None: + anonymized_im = F.resize(anonymized_im, orig_shape, antialias=True) + anonymized_im = (anonymized_im+1).div(2).clamp(0, 1).mul(255).round().byte() + + # Resize and denormalize image + gim = F.resize(anonymized_im[0], (y1-y0, x1-x0), antialias=True) + mask = F.resize(batch["mask"][0], (y1-y0, x1-x0), interpolation=F.InterpolationMode.NEAREST).squeeze(0) + # Remove padding + pad = [max(-x0,0), max(-y0,0)] + pad = [*pad, max(x1-W,0), max(y1-H,0)] + remove_pad = lambda x: x[...,pad[1]:x.shape[-2]-pad[3], pad[0]:x.shape[-1]-pad[2]] + gim = remove_pad(gim) + mask = remove_pad(mask) + x0, y0 = max(x0, 0), max(y0, 0) + x1, y1 = min(x1, W), min(y1, H) + mask = mask.logical_not()[None].repeat(3, 1, 1) + im[:, y0:y1, x0:x1][mask] = gim[mask] + + return im + + def visualize_detection(self, im: torch.Tensor, cache_id: str = None) -> torch.Tensor: + all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache) + for det in all_detections: + im = det.visualize(im) + return im + + @torch.no_grad() + def forward(self, im: torch.Tensor, cache_id: str = None, track=True, **synthesis_kwargs) -> torch.Tensor: + assert im.dtype == torch.uint8 + im = tops.to_cuda(im) + all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache) + if hasattr(self, "tracker") and track: + [_.pre_process() for _ in all_detections] + import numpy as np + boxes = np.concatenate([_.boxes for _ in all_detections]) + boxes = [Detection(box) for box in boxes] + self.tracker.step(boxes) + track_ids = self.tracker.detections_matched_ids + z_idx = [] + for track_id in track_ids: + if track_id not in self.track_to_z_idx: + self.track_to_z_idx[track_id] = self.cur_z_idx + self.cur_z_idx += 1 + z_idx.append(self.track_to_z_idx[track_id]) + z_idx = np.array(z_idx) + idx_offset = 0 + + for detection in all_detections: + zs = None + if hasattr(self, "tracker") and track: + zs = z_idx[idx_offset:idx_offset+len(detection)] + idx_offset += len(detection) + im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs) + + return im.cpu() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + diff --git a/dp2/data/__init__.py b/dp2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/data/build.py b/dp2/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..07f2a4e3630b405bf5a84f6926a733044345d741 --- /dev/null +++ b/dp2/data/build.py @@ -0,0 +1,148 @@ +import io +import torch +import tops +from .utils 
import collate_fn, jpg_decoder, get_num_workers, png_decoder + +def get_dataloader( + dataset, gpu_transform: torch.nn.Module, + num_workers, + batch_size, + infinite: bool, + drop_last: bool, + prefetch_factor: int, + shuffle, + channels_last=False + ): + sampler = None + dl_kwargs = dict( + pin_memory=True, + ) + if infinite: + sampler = tops.InfiniteSampler( + dataset, rank=tops.rank(), + num_replicas=tops.world_size(), + shuffle=shuffle + ) + elif tops.world_size() > 1: + sampler = torch.utils.data.DistributedSampler( + dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank()) + dl_kwargs["drop_last"] = drop_last + else: + dl_kwargs["shuffle"] = shuffle + dl_kwargs["drop_last"] = drop_last + dataloader = torch.utils.data.DataLoader( + dataset, sampler=sampler, collate_fn=collate_fn, + batch_size=batch_size, + num_workers=num_workers, prefetch_factor=prefetch_factor, + **dl_kwargs + ) + dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last) + return dataloader + + +def get_dataloader_places2_wds( + path, + batch_size: int, + num_workers: int, + transform: torch.nn.Module, + gpu_transform: torch.nn.Module, + infinite: bool, + shuffle: bool, + partial_batches: bool, + sample_shuffle=10_000, + tar_shuffle=100, + channels_last=False, + ): + import webdataset as wds + import os + os.environ["RANK"] = str(tops.rank()) + os.environ["WORLD_SIZE"] = str(tops.world_size()) + + if infinite: + pipeline = [wds.ResampledShards(str(path))] + else: + pipeline = [wds.SimpleShardList(str(path))] + if shuffle: + pipeline.append(wds.shuffle(tar_shuffle)) + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + ]) + if shuffle: + pipeline.append(wds.shuffle(sample_shuffle)) + + pipeline.extend([ + wds.tarfile_to_samples(), + wds.decode("torchrgb8"), + wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]), + ]) + if transform is not None: + pipeline.append(wds.map(transform)) + pipeline.extend([ + wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), + ]) + pipeline = wds.DataPipeline(*pipeline) + if infinite: + pipeline = pipeline.repeat(nepochs=1000000) + loader = wds.WebLoader( + pipeline, batch_size=None, shuffle=False, + num_workers=get_num_workers(num_workers), + persistent_workers=True, + ) + loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) + return loader + + + + +def get_dataloader_celebAHQ_wds( + path, + batch_size: int, + num_workers: int, + transform: torch.nn.Module, + gpu_transform: torch.nn.Module, + infinite: bool, + shuffle: bool, + partial_batches: bool, + sample_shuffle=10_000, + tar_shuffle=100, + channels_last=False, + ): + import webdataset as wds + import os + os.environ["RANK"] = str(tops.rank()) + os.environ["WORLD_SIZE"] = str(tops.world_size()) + + if infinite: + pipeline = [wds.ResampledShards(str(path))] + else: + pipeline = [wds.SimpleShardList(str(path))] + if shuffle: + pipeline.append(wds.shuffle(tar_shuffle)) + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + ]) + if shuffle: + pipeline.append(wds.shuffle(sample_shuffle)) + + pipeline.extend([ + wds.tarfile_to_samples(), + wds.decode(wds.handle_extension(".png", png_decoder)), + wds.rename_keys(["img", "png"], ["__key__", "__key__"]), + ]) + if transform is not None: + pipeline.append(wds.map(transform)) + pipeline.extend([ + wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), + ]) + pipeline = wds.DataPipeline(*pipeline) + if infinite: + pipeline = 
pipeline.repeat(nepochs=1000000) + loader = wds.WebLoader( + pipeline, batch_size=None, shuffle=False, + num_workers=get_num_workers(num_workers), + persistent_workers=True, + ) + loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last) + return loader diff --git a/dp2/data/datasets/__init__.py b/dp2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/data/datasets/coco_cse.py b/dp2/data/datasets/coco_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..27fa6dfb94f118b939788e0134e6ac42c613b297 --- /dev/null +++ b/dp2/data/datasets/coco_cse.py @@ -0,0 +1,148 @@ +import pickle +import torchvision +import torch +import pathlib +import numpy as np +from typing import Callable, Optional, Union +from torch.hub import get_dir as get_hub_dir + + +def cache_embed_stats(embed_map: torch.Tensor): + mean = embed_map.mean(dim=0, keepdim=True) + rstd = ((embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + + cache = dict(mean=mean, rstd=rstd, embed_map=embed_map) + path = pathlib.Path(get_hub_dir(), f"embed_map_stats.torch") + path.parent.mkdir(exist_ok=True, parents=True) + torch.save(cache, path) + + +class CocoCSE(torch.utils.data.Dataset): + + def __init__(self, + dirpath: Union[str, pathlib.Path], + transform: Optional[Callable], + normalize_E: bool,): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + + self.transform = transform + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + self.image_paths, self.embedding_paths = self._load_impaths() + self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) + mean = self.embed_map.mean(dim=0, keepdim=True) + rstd = ((self.embed_map - mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + self.embed_map = (self.embed_map - mean) * rstd + cache_embed_stats(self.embed_map) + + def _load_impaths(self): + image_dir = self.dirpath.joinpath("images") + image_paths = list(image_dir.glob("*.png")) + image_paths.sort() + embedding_paths = [ + self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths + ] + return image_paths, embedding_paths + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + im = torchvision.io.read_image(str(self.image_paths[idx])) + vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) + vertices = torch.from_numpy(vertices.squeeze()).long() + mask = torch.from_numpy(mask.squeeze()).float() + border = torch.from_numpy(border.squeeze()).float() + E_mask = 1 - mask - border + batch = { + "img": im, + "vertices": vertices[None], + "mask": mask[None], + "embed_map": self.embed_map, + "border": border[None], + "E_mask": E_mask[None] + } + if self.transform is None: + return batch + return self.transform(batch) + + +class CocoCSEWithFace(CocoCSE): + + def __init__(self, + dirpath: Union[str, pathlib.Path], + transform: Optional[Callable], + **kwargs): + super().__init__(dirpath, transform, **kwargs) + with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp: + self.face_boxes = pickle.load(fp) + + def __getitem__(self, idx): + item = super().__getitem__(idx) + item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name] + return item + + +class CocoCSESemantic(torch.utils.data.Dataset): + + def __init__(self, + dirpath: Union[str, pathlib.Path], + transform: Optional[Callable], + **kwargs): + dirpath = pathlib.Path(dirpath) + 
self.dirpath = dirpath + + self.transform = transform + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + self.image_paths, self.embedding_paths = self._load_impaths() + self.vertx2cat = torch.from_numpy(np.load(self.dirpath.parent.joinpath("vertx2cat.npy"))) + self.embed_map = torch.from_numpy(np.load(self.dirpath.joinpath("embed_map.npy"))) + + def _load_impaths(self): + image_dir = self.dirpath.joinpath("images") + image_paths = list(image_dir.glob("*.png")) + image_paths.sort() + embedding_paths = [ + self.dirpath.joinpath("embedding", x.stem + ".npy") for x in image_paths + ] + return image_paths, embedding_paths + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + im = torchvision.io.read_image(str(self.image_paths[idx])) + vertices, mask, border = np.split(np.load(self.embedding_paths[idx]), 3, axis=-1) + vertices = torch.from_numpy(vertices.squeeze()).long() + mask = torch.from_numpy(mask.squeeze()).float() + border = torch.from_numpy(border.squeeze()).float() + E_mask = 1 - mask - border + batch = { + "img": im, + "vertices": vertices[None], + "mask": mask[None], + "border": border[None], + "vertx2cat": self.vertx2cat, + "embed_map": self.embed_map, + } + if self.transform is None: + return batch + return self.transform(batch) + + +class CocoCSESemanticWithFace(CocoCSESemantic): + + def __init__(self, + dirpath: Union[str, pathlib.Path], + transform: Optional[Callable], + **kwargs): + super().__init__(dirpath, transform, **kwargs) + with open(self.dirpath.joinpath("face_boxes_XYXY.pickle"), "rb") as fp: + self.face_boxes = pickle.load(fp) + + def __getitem__(self, idx): + item = super().__getitem__(idx) + item["boxes_XYXY"] = self.face_boxes[self.image_paths[idx].name] + return item diff --git a/dp2/data/datasets/fdf.py b/dp2/data/datasets/fdf.py new file mode 100644 index 0000000000000000000000000000000000000000..b05c75692515c5e294af72417371af9fcefbfad8 --- /dev/null +++ b/dp2/data/datasets/fdf.py @@ -0,0 +1,129 @@ +import pathlib +from typing import Tuple +import numpy as np +import torch +import pathlib +try: + import pyspng + PYSPNG_IMPORTED = True +except ImportError: + PYSPNG_IMPORTED = False + print("Could not load pyspng. Defaulting to pillow image backend.") + from PIL import Image +from tops import logger + + +class FDFDataset: + + def __init__(self, + dirpath, + imsize: Tuple[int], + load_keypoints: bool, + transform): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + self.transform = transform + self.imsize = imsize[0] + self.load_keypoints = load_keypoints + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + image_dir = self.dirpath.joinpath("images", str(self.imsize)) + self.image_paths = list(image_dir.glob("*.png")) + assert len(self.image_paths) > 0,\ + f"Did not find images in: {image_dir}" + self.image_paths.sort(key=lambda x: int(x.stem)) + self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) + + self.bounding_boxes = torch.load(self.dirpath.joinpath("bounding_box", f"{self.imsize}.torch")) + assert len(self.image_paths) == len(self.bounding_boxes) + assert len(self.image_paths) == len(self.landmarks) + logger.log( + f"Dataset loaded from: {dirpath}. 
Number of samples:{len(self)}, imsize={imsize}") + + def get_mask(self, idx): + mask = torch.ones((1, self.imsize, self.imsize), dtype=torch.bool) + bounding_box = self.bounding_boxes[idx] + x0, y0, x1, y1 = bounding_box + mask[:, y0:y1, x0:x1] = 0 + return mask + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, index): + impath = self.image_paths[index] + if PYSPNG_IMPORTED: + with open(impath, "rb") as fp: + im = pyspng.load(fp.read()) + else: + with Image.open(impath) as fp: + im = np.array(fp) + im = torch.from_numpy(np.rollaxis(im, -1, 0)) + masks = self.get_mask(index) + landmark = self.landmarks[index] + batch = { + "img": im, + "mask": masks, + } + if self.load_keypoints: + batch["keypoints"] = landmark + if self.transform is None: + return batch + return self.transform(batch) + + +class FDF256Dataset: + + def __init__(self, + dirpath, + load_keypoints: bool, + transform): + dirpath = pathlib.Path(dirpath) + self.dirpath = dirpath + self.transform = transform + self.load_keypoints = load_keypoints + assert self.dirpath.is_dir(),\ + f"Did not find dataset at: {dirpath}" + image_dir = self.dirpath.joinpath("images") + self.image_paths = list(image_dir.glob("*.png")) + assert len(self.image_paths) > 0,\ + f"Did not find images in: {image_dir}" + self.image_paths.sort(key=lambda x: int(x.stem)) + self.landmarks = np.load(self.dirpath.joinpath("landmarks.npy")).reshape(-1, 7, 2).astype(np.float32) + self.bounding_boxes = torch.from_numpy(np.load(self.dirpath.joinpath("bounding_box.npy"))) + assert len(self.image_paths) == len(self.bounding_boxes) + assert len(self.image_paths) == len(self.landmarks) + logger.log( + f"Dataset loaded from: {dirpath}. Number of samples:{len(self)}") + + def get_mask(self, idx): + mask = torch.ones((1, 256, 256), dtype=torch.bool) + bounding_box = self.bounding_boxes[idx] + x0, y0, x1, y1 = bounding_box + mask[:, y0:y1, x0:x1] = 0 + return mask + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, index): + impath = self.image_paths[index] + if PYSPNG_IMPORTED: + with open(impath, "rb") as fp: + im = pyspng.load(fp.read()) + else: + with Image.open(impath) as fp: + im = np.array(fp) + im = torch.from_numpy(np.rollaxis(im, -1, 0)) + masks = self.get_mask(index) + landmark = self.landmarks[index] + batch = { + "img": im, + "mask": masks, + } + if self.load_keypoints: + batch["keypoints"] = landmark + if self.transform is None: + return batch + return self.transform(batch) + diff --git a/dp2/data/datasets/fdh.py b/dp2/data/datasets/fdh.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb654ba71604f1b088df586e6683418033b4b16 --- /dev/null +++ b/dp2/data/datasets/fdh.py @@ -0,0 +1,104 @@ +import torch +import tops +import numpy as np +import io +import webdataset as wds +import os +from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn + + +def kp_decoder(x): + # Keypoints are between [0, 1] for webdataset + keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float() + keypoints[:, 0] /= 160 + keypoints[:, 1] /= 288 + check_outside = lambda x: (x < 0).logical_or(x > 1) + is_outside = check_outside(keypoints[:, 0]).logical_or( + check_outside(keypoints[:, 1]) + ) + keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not()) + return keypoints + + +def vertices_decoder(x): + vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32)) + return vertices.squeeze()[None] + + +def get_dataloader_fdh_wds( + path, + batch_size: int, + num_workers: 
int, + transform: torch.nn.Module, + gpu_transform: torch.nn.Module, + infinite: bool, + shuffle: bool, + partial_batches: bool, + load_embedding: bool, + sample_shuffle=10_000, + tar_shuffle=100, + read_condition=False, + channels_last=False, + ): + # Need to set this for split_by_node to work. + os.environ["RANK"] = str(tops.rank()) + os.environ["WORLD_SIZE"] = str(tops.world_size()) + if infinite: + pipeline = [wds.ResampledShards(str(path))] + else: + pipeline = [wds.SimpleShardList(str(path))] + if shuffle: + pipeline.append(wds.shuffle(tar_shuffle)) + pipeline.extend([ + wds.split_by_node, + wds.split_by_worker, + ]) + if shuffle: + pipeline.append(wds.shuffle(sample_shuffle)) + + decoder = [ + wds.handle_extension("image.png", png_decoder), + wds.handle_extension("mask.png", mask_decoder), + wds.handle_extension("maskrcnn_mask.png", mask_decoder), + wds.handle_extension("keypoints.npy", kp_decoder), + ] + + rename_keys = [ + ["img", "image.png"], ["mask", "mask.png"], + ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"] + ] + if load_embedding: + decoder.extend([ + wds.handle_extension("vertices.npy", vertices_decoder), + wds.handle_extension("E_mask.png", mask_decoder) + ]) + rename_keys.extend([ + ["vertices", "vertices.npy"], + ["E_mask", "e_mask.png"] + ]) + + if read_condition: + decoder.append( + wds.handle_extension("condition.png", png_decoder) + ) + rename_keys.append(["condition", "condition.png"]) + + pipeline.extend([ + wds.tarfile_to_samples(), + wds.decode(*decoder), + wds.rename_keys(*rename_keys), + wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), + ]) + if transform is not None: + pipeline.append(wds.map(transform)) + pipeline = wds.DataPipeline(*pipeline) + if infinite: + pipeline = pipeline.repeat(nepochs=1000000) + + loader = wds.WebLoader( + pipeline, batch_size=None, shuffle=False, + num_workers=get_num_workers(num_workers), + persistent_workers=True, + ) + loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) + return loader diff --git a/dp2/data/transforms/__init__.py b/dp2/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66a9e160b6513cc81a00b0a62606721f37b70c61 --- /dev/null +++ b/dp2/data/transforms/__init__.py @@ -0,0 +1,2 @@ +from .transforms import RandomCrop, CreateCondition, CreateEmbedding, Resize, ToFloat, Normalize +from .stylegan2_transform import StyleGANAugmentPipe \ No newline at end of file diff --git a/dp2/data/transforms/functional.py b/dp2/data/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5c695944a47d67cac7e03a8a7f5b400c94417b --- /dev/null +++ b/dp2/data/transforms/functional.py @@ -0,0 +1,61 @@ +import torchvision.transforms.functional as F +import torch +import pickle +from tops import download_file, assert_shape +from typing import Dict +from functools import lru_cache + +global symmetry_transform + +@lru_cache(maxsize=1) +def get_symmetry_transform(symmetry_url): + file_name = download_file(symmetry_url) + with open(file_name, "rb") as fp: + symmetry = pickle.load(fp) + return torch.from_numpy(symmetry["vertex_transforms"]).long() + + +hflip_handled_cases = set([ + "keypoints", "img", "mask", "border", "semantic_mask", "vertices", "E_mask", "embed_map", "condition", + "embedding", "vertx2cat", "maskrcnn_mask", "__key__", + "img_hr", "condition_hr", "mask_hr"]) + +def hflip(container: Dict[str, torch.Tensor], flip_map=None) -> Dict[str, torch.Tensor]: + 
container["img"] = F.hflip(container["img"]) + if "condition" in container: + container["condition"] = F.hflip(container["condition"]) + if "embedding" in container: + container["embedding"] = F.hflip(container["embedding"]) + assert all([key in hflip_handled_cases for key in container]), container.keys() + if "keypoints" in container: + assert flip_map is not None + if container["keypoints"].ndim == 3: + keypoints = container["keypoints"][:, flip_map, :] + keypoints[:, :, 0] = 1 - keypoints[:, :, 0] + else: + assert_shape(container["keypoints"], (None, 3)) + keypoints = container["keypoints"][flip_map, :] + keypoints[:, 0] = 1 - keypoints[:, 0] + container["keypoints"] = keypoints + if "mask" in container: + container["mask"] = F.hflip(container["mask"]) + if "border" in container: + container["border"] = F.hflip(container["border"]) + if "semantic_mask" in container: + container["semantic_mask"] = F.hflip(container["semantic_mask"]) + if "vertices" in container: + symmetry_transform = get_symmetry_transform("https://dl.fbaipublicfiles.com/densepose/meshes/symmetry/symmetry_smpl_27554.pkl") + container["vertices"] = F.hflip(container["vertices"]) + symmetry_transform_ = symmetry_transform.to(container["vertices"].device) + container["vertices"] = symmetry_transform_[container["vertices"].long()] + if "E_mask" in container: + container["E_mask"] = F.hflip(container["E_mask"]) + if "maskrcnn_mask" in container: + container["maskrcnn_mask"] = F.hflip(container["maskrcnn_mask"]) + if "img_hr" in container: + container["img_hr"] = F.hflip(container["img_hr"]) + if "condition_hr" in container: + container["condition_hr"] = F.hflip(container["condition_hr"]) + if "mask_hr" in container: + container["mask_hr"] = F.hflip(container["mask_hr"]) + return container diff --git a/dp2/data/transforms/stylegan2_transform.py b/dp2/data/transforms/stylegan2_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..49a143cddf9673d079b87ac7d725c433713e54c5 --- /dev/null +++ b/dp2/data/transforms/stylegan2_transform.py @@ -0,0 +1,394 @@ +import numpy as np +import scipy.signal +import torch +try: + from sg3_torch_utils import misc + from sg3_torch_utils.ops import upfirdn2d + from sg3_torch_utils.ops import grid_sample_gradfix + from sg3_torch_utils.ops import conv2d_gradfix +except: + pass +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. 
+ +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. 
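+# The helpers below compose 2D (3x3) and 3D (4x4) homogeneous transformation
+# matrices, so chains of translations, scalings and rotations can be applied
+# as a single matrix product (G_inv for geometry, C for color in forward()).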
+ + +def matrix(*rows, device=None): + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + + +def translate2d(tx, ty, **kwargs): + return matrix( + [1, 0, tx], + [0, 1, ty], + [0, 0, 1], + **kwargs) + + +def translate3d(tx, ty, tz, **kwargs): + return matrix( + [1, 0, 0, tx], + [0, 1, 0, ty], + [0, 0, 1, tz], + [0, 0, 0, 1], + **kwargs) + + +def scale2d(sx, sy, **kwargs): + return matrix( + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1], + **kwargs) + + +def scale3d(sx, sy, sz, **kwargs): + return matrix( + [sx, 0, 0, 0], + [0, sy, 0, 0], + [0, 0, sz, 0], + [0, 0, 0, 1], + **kwargs) + + +def rotate2d(theta, **kwargs): + return matrix( + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + **kwargs) + + +def rotate3d(v, theta, **kwargs): + vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] + s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c + return matrix( + [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], + [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], + [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], + [0, 0, 0, 1], + **kwargs) + + +def translate2d_inv(tx, ty, **kwargs): + return translate2d(-tx, -ty, **kwargs) + + +def scale2d_inv(sx, sy, **kwargs): + return scale2d(1 / sx, 1 / sy, **kwargs) + + +def rotate2d_inv(theta, **kwargs): + return rotate2d(-theta, **kwargs) + + +class StyleGANAugmentPipe(torch.nn.Module): + def __init__(self, + rotate90=0, xint=0, xint_max=0.125, + scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, + hue_max=1, saturation_std=1, + imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1, + ): + super().__init__() + self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float(xint) # Probability multiplier for integer translation. + self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate = float(rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. + self.xfrac = float(xfrac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. 
+ self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. + Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, batch, debug_percentile=None): + images = batch["img"] + batch["vertices"] = batch["vertices"].float() + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + self.Hz_fbank = self.Hz_fbank.to(device) + self.Hz_geom = self.Hz_geom.to(device) + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply integer translation with probability (xint * strength). + if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std) + s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. 
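+        # Rotation is attempted twice below (before and after anisotropic
+        # scaling), each independently with probability p_rot. Solving
+        # 1 - (1 - p_rot)^2 = rotate * p for p_rot gives the expression below.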
+ p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. + + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std) + s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max + theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + # Calculate padding. + cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([width-1, height-1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + batch["mask"] = torch.nn.functional.pad(input=batch["mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=1.0) + batch["E_mask"] = torch.nn.functional.pad(input=batch["E_mask"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) + batch["vertices"] = torch.nn.functional.pad(input=batch["vertices"], pad=[mx0,mx1,my0,my1], mode='constant', value=0.0) + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. 
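+            # The image is upsampled 2x with the orthogonal lowpass filter before
+            # the warp and downsampled again afterwards (as in StyleGAN2-ADA) to
+            # reduce aliasing; masks and vertices use nearest-neighbour
+            # interpolation so their discrete labels are preserved.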
+ images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + batch["mask"] = torch.nn.functional.interpolate(batch["mask"], scale_factor=2, mode="nearest") + batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"], scale_factor=2, mode="nearest") + batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"], scale_factor=2, mode="nearest") + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. + shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + batch["mask"] = torch.nn.functional.grid_sample( + input=batch["mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + batch["E_mask"] = torch.nn.functional.grid_sample( + input=batch["E_mask"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + batch["vertices"] = torch.nn.functional.grid_sample( + input=batch["vertices"], grid=grid, mode='nearest', padding_mode="border", align_corners=False) + + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True) + batch["mask"] = torch.nn.functional.interpolate(batch["mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + batch["E_mask"] = torch.nn.functional.interpolate(batch["E_mask"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + batch["vertices"] = torch.nn.functional.interpolate(batch["vertices"][:, :, Hz_pad*2:-Hz_pad*2, Hz_pad*2:-Hz_pad*2], scale_factor=.5, mode="nearest", recompute_scale_factor=False) + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). + if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std) + c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. + + # Apply hue rotation with probability (hue * strength). 
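+        # Hue rotation is a 3D rotation of the RGB values around the luma axis v,
+        # which leaves the channel mean (and hence overall brightness) unchanged.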
+ if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max + theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). + if self.saturation > 0 and num_channels > 1: + s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std) + s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity). + for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std) + t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap] + + # Apply filter. + p = self.Hz_fbank.shape[1] // 2 + images = images.reshape([1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect') + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels) + images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. 
+ # ------------------------ + batch["img"] = images + batch["vertices"] = batch["vertices"].long() + batch["border"] = 1 - batch["E_mask"] - batch["mask"] + return batch diff --git a/dp2/data/transforms/transforms.py b/dp2/data/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..1221a9121d7f59b1ac33c28e4189339e8df6dadf --- /dev/null +++ b/dp2/data/transforms/transforms.py @@ -0,0 +1,247 @@ +from pathlib import Path +from typing import Dict, List +import torchvision +import torch +import tops +import torchvision.transforms.functional as F +from .functional import hflip + + +class RandomHorizontalFlip(torch.nn.Module): + + def __init__(self, p: float, flip_map=None,**kwargs): + super().__init__() + self.flip_ratio = p + self.flip_map = flip_map + if self.flip_ratio is None: + self.flip_ratio = 0.5 + assert 0 <= self.flip_ratio <= 1 + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if torch.rand(1) > self.flip_ratio: + return container + return hflip(container, self.flip_map) + + +class CenterCrop(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. + """ + + def __init__(self, size: List[int]): + super().__init__() + self.size = tuple(size) + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + min_size = min(container["img"].shape[1], container["img"].shape[2]) + if min_size < self.size[0]: + container["img"] = F.center_crop(container["img"], min_size) + container["img"] = F.resize(container["img"], self.size) + return container + container["img"] = F.center_crop(container["img"], self.size) + return container + + +class Resize(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. 
+ """ + + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR): + super().__init__() + self.size = tuple(size) + self.interpolation = interpolation + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) + if "semantic_mask" in container: + container["semantic_mask"] = F.resize( + container["semantic_mask"], self.size, F.InterpolationMode.NEAREST) + if "embedding" in container: + container["embedding"] = F.resize( + container["embedding"], self.size, self.interpolation) + if "mask" in container: + container["mask"] = F.resize( + container["mask"], self.size, F.InterpolationMode.NEAREST) + if "E_mask" in container: + container["E_mask"] = F.resize( + container["E_mask"], self.size, F.InterpolationMode.NEAREST) + if "maskrcnn_mask" in container: + container["maskrcnn_mask"] = F.resize( + container["maskrcnn_mask"], self.size, F.InterpolationMode.NEAREST) + if "vertices" in container: + container["vertices"] = F.resize( + container["vertices"], self.size, F.InterpolationMode.NEAREST) + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(size=self.size, interpolation=self.interpolation) + return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) + + +class InsertHRImage(torch.nn.Module): + """ + Resizes mask by maxpool and assumes condition is already created + """ + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR): + super().__init__() + self.size = tuple(size) + self.interpolation = interpolation + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + assert container["img"].dtype == torch.float32 + container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) + container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True) + mask = container["mask"] > 0 + container["mask_hr"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), output_size=self.size) > 0).logical_not().float() + container["condition_hr"] = container["condition_hr"] * (1 - container["mask_hr"]) + container["img_hr"] * container["mask_hr"] + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(size=self.size, interpolation=self.interpolation) + return repr + " " + + +class CopyHRImage(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["img_hr"] = container["img"] + container["condition_hr"] = container["condition"] + container["mask_hr"] = container["mask"] + return container + + +class Resize2(torch.nn.Module): + """ + Resizes mask by maxpool and assumes condition is already created + """ + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True): + super().__init__() + self.size = tuple(size) + self.interpolation = interpolation + self.downsample_condition = downsample_condition + self.mask_condition = mask_condition + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: +# assert container["img"].dtype == torch.float32 + container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True) + mask = container["mask"] > 0 + container["mask"] = (torch.nn.functional.adaptive_max_pool2d(mask.logical_not().float(), 
output_size=self.size) > 0).logical_not().float() + + if self.downsample_condition: + container["condition"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True) + if self.mask_condition: + container["condition"] = container["condition"] * (1 - container["mask"]) + container["img"] * container["mask"] + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(size=self.size, interpolation=self.interpolation) + return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) + + + +class Normalize(torch.nn.Module): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. + """ + + def __init__(self, mean, std, inplace, keys=["img"]): + super().__init__() + self.mean = mean + self.std = std + self.inplace = inplace + self.keys = keys + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for key in self.keys: + container[key] = F.normalize(container[key], self.mean, self.std, self.inplace) + return container + + def __repr__(self): + repr = super().__repr__() + vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace) + return repr + " " + " ".join([f"{k}: {v}" for k, v in vars_.items()]) + + +class ToFloat(torch.nn.Module): + + def __init__(self, keys=["img"], norm=True) -> None: + super().__init__() + self.keys = keys + self.gain = 255 if norm else 1 + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for key in self.keys: + container[key] = container[key].float() / self.gain + return container + + +class RandomCrop(torchvision.transforms.RandomCrop): + """ + Performs the transform on the image. + NOTE: Does not transform the mask to improve runtime. + """ + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["img"] = super().forward(container["img"]) + return container + + +class CreateCondition(torch.nn.Module): + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if container["img"].dtype == torch.uint8: + container["condition"] = container["img"] * container["mask"].byte() + (1-container["mask"].byte()) * 127 + return container + container["condition"] = container["img"] * container["mask"] + return container + + +class CreateEmbedding(torch.nn.Module): + + def __init__(self, embed_path: Path, cuda=True) -> None: + super().__init__() + self.embed_map = torch.load(embed_path, map_location=torch.device("cpu")) + if cuda: + self.embed_map = tops.to_cuda(self.embed_map) + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + vertices = container["vertices"] + if vertices.ndim == 3: + embedding = self.embed_map[vertices.long()].squeeze(dim=0) + embedding = embedding.permute(2, 0, 1) * container["E_mask"] + pass + else: + assert vertices.ndim == 4 + embedding = self.embed_map[vertices.long()].squeeze(dim=1) + embedding = embedding.permute(0, 3, 1, 2) * container["E_mask"] + container["embedding"] = embedding + container["embed_map"] = self.embed_map.clone() + return container + + +class UpdateMask(torch.nn.Module): + + def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + container["mask"] = (container["img"] == container["condition"]).any(dim=1, keepdims=True).float() + return container + + +class LoadClassEmbedding(torch.nn.Module): + + def __init__(self, embedding_path: Path) -> None: + super().__init__() + self.embedding = torch.load(embedding_path, map_location="cpu") + 
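+    # The webdataset __key__ is presumably of the form ".../train/<class dirs>/<sample>";
+    # the directory components after "train/" are joined with "_" to look up the
+    # class embedding.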
+ def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + key = "_".join(container["__key__"].split("train/")[-1].split("/")[:-1]) + container["class_embedding"] = self.embedding[key].view(-1) + return container diff --git a/dp2/data/utils.py b/dp2/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ec03f0ec0091e3263f5aa9b2962ad97a1c36a0f --- /dev/null +++ b/dp2/data/utils.py @@ -0,0 +1,102 @@ +import torch +from PIL import Image +import numpy as np +import multiprocessing +import io +from tops import logger +from torch.utils.data._utils.collate import default_collate + +try: + import pyspng + + PYSPNG_IMPORTED = True +except ImportError: + PYSPNG_IMPORTED = False + print("Could not load pyspng. Defaulting to pillow image backend.") + from PIL import Image + + +def get_coco_keypoints(): + return [ + "nose", + "left_eye", + "right_eye", + "left_ear", + "right_ear", + "left_shoulder", + "right_shoulder", + "left_elbow", + "right_elbow", + "left_wrist", + "right_wrist", + "left_hip", + "right_hip", + "left_knee", + "right_knee", + "left_ankle", + "right_ankle", + ] + + +def get_coco_flipmap(): + keypoints = get_coco_keypoints() + keypoint_flip_map = { + "left_eye": "right_eye", + "left_ear": "right_ear", + "left_shoulder": "right_shoulder", + "left_elbow": "right_elbow", + "left_wrist": "right_wrist", + "left_hip": "right_hip", + "left_knee": "right_knee", + "left_ankle": "right_ankle", + } + for key, value in list(keypoint_flip_map.items()): + keypoint_flip_map[value] = key + keypoint_flip_map["nose"] = "nose" + keypoint_flip_map_idx = [] + for source in keypoints: + keypoint_flip_map_idx.append(keypoints.index(keypoint_flip_map[source])) + return keypoint_flip_map_idx + + +def mask_decoder(x): + mask = torch.from_numpy(np.array(Image.open(io.BytesIO(x)))).squeeze()[None] + mask = mask > 0 # This fixes bug causing maskf.loat().max() == 255. 
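+    # i.e. without the comparison above, a uint8 PNG mask would give
+    # mask.float().max() == 255 instead of a boolean 0/1 mask.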
+ return mask + + +def png_decoder(x): + if PYSPNG_IMPORTED: + return torch.from_numpy(np.rollaxis(pyspng.load(x), 2)) + with Image.open(io.BytesIO(x)) as im: + im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) + return im + + +def jpg_decoder(x): + with Image.open(io.BytesIO(x)) as im: + im = torch.from_numpy(np.rollaxis(np.array(im.convert("RGB")), 2)) + return im + + +def get_num_workers(num_workers: int): + n_cpus = multiprocessing.cpu_count() + if num_workers > n_cpus: + logger.warn(f"Setting the number of workers to match cpu count: {n_cpus}") + return n_cpus + return num_workers + + +def collate_fn(batch): + elem = batch[0] + ignore_keys = set(["embed_map", "vertx2cat"]) + batch_ = { + key: default_collate([d[key] for d in batch]) + for key in elem + if key not in ignore_keys + } + if "embed_map" in elem: + batch_["embed_map"] = elem["embed_map"] + if "vertx2cat" in elem: + batch_["vertx2cat"] = elem["vertx2cat"] + return batch_ diff --git a/dp2/detection/__init__.py b/dp2/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..613969b28384cd1c64fc8db685e7622f4cc02615 --- /dev/null +++ b/dp2/detection/__init__.py @@ -0,0 +1,3 @@ +from .cse_mask_face_detector import CSeMaskFaceDetector +from .person_detector import CSEPersonDetector +from .structures import PersonDetection, VehicleDetection, FaceDetection diff --git a/dp2/detection/base.py b/dp2/detection/base.py new file mode 100644 index 0000000000000000000000000000000000000000..32ebba893878e95ddc1da451a9f2027799ffa044 --- /dev/null +++ b/dp2/detection/base.py @@ -0,0 +1,45 @@ +import pickle +import torch +import lzma +from pathlib import Path +from tops import logger + + +class BaseDetector: + + + def __init__(self, cache_directory: str) -> None: + if cache_directory is not None: + self.cache_directory = Path(cache_directory, str(self.__class__.__name__)) + self.cache_directory.mkdir(exist_ok=True, parents=True) + + def save_to_cache(self, detection, cache_path: Path, after_preprocess=True): + logger.log(f"Caching detection to: {cache_path}") + with lzma.open(cache_path, "wb") as fp: + torch.save( + [det.state_dict(after_preprocess=after_preprocess) for det in detection], fp, + pickle_protocol=pickle.HIGHEST_PROTOCOL) + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}") + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + return [ + state["cls"].from_state_dict(state_dict=state) for state in state_dict + ] + + def forward_and_cache(self, im: torch.Tensor, cache_id: str, load_cache: bool): + if cache_id is None: + return self.forward(im) + cache_path = self.cache_directory.joinpath(cache_id + ".torch") + if cache_path.is_file() and load_cache: + try: + return self.load_from_cache(cache_path) + except Exception as e: + logger.warn(f"The cache file was corrupted: {cache_path}") + exit() + detections = self.forward(im) + self.save_to_cache(detections, cache_path) + return detections + + \ No newline at end of file diff --git a/dp2/detection/box_utils.py b/dp2/detection/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6091b122a1e72d05b9cad2a25f4111b425eabb93 --- /dev/null +++ b/dp2/detection/box_utils.py @@ -0,0 +1,104 @@ +import numpy as np + + +def expand_bbox_to_ratio(bbox, imshape, target_aspect_ratio): + x0, y0, x1, y1 = [int(_) for _ in bbox] + h, w = y1 - y0, x1 - x0 + cur_ratio = h / w + + if cur_ratio == target_aspect_ratio: + return [x0, y0, x1, y1] + if 
cur_ratio < target_aspect_ratio: + target_height = int(w*target_aspect_ratio) + y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) + else: + target_width = int(h/target_aspect_ratio) + x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) + return x0, y0, x1, y1 + + +def expand_axis(start, end, target_width, limit): + # Can return a bbox outside of limit + cur_width = end - start + start = start - (target_width-cur_width)//2 + end = end + (target_width-cur_width)//2 + if end - start != target_width: + end += 1 + assert end - start == target_width + if start < 0 and end > limit: + return start, end + if start < 0 and end < limit: + to_shift = min(0 - start, limit - end) + start += to_shift + end += to_shift + if end > limit and start > 0: + to_shift = min(end - limit, start) + end -= to_shift + start -= to_shift + assert end - start == target_width + return start, end + + +def expand_box(bbox, imshape, mask, percentage_background: float): + assert isinstance(bbox[0], int) + assert 0 < percentage_background < 1 + # Percentage in S + mask_pixels = mask.long().sum().cpu() + total_pixels = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + percentage_mask = mask_pixels / total_pixels + if (1 - percentage_mask) > percentage_background: + return bbox + target_pixels = mask_pixels / (1 - percentage_background) + x0, y0, x1, y1 = bbox + H = y1 - y0 + W = x1 - x0 + p = np.sqrt(target_pixels/(H*W)) + target_width = int(np.ceil(p * W)) + target_height = int(np.ceil(p * H)) + x0, x1 = expand_axis(x0, x1, target_width, imshape[1]) + y0, y1 = expand_axis(y0, y1, target_height, imshape[0]) + return [x0, y0, x1, y1] + + +def expand_axises_by_percentage(bbox_XYXY, imshape, percentage): + x0, y0, x1, y1 = bbox_XYXY + H = y1 - y0 + W = x1 - x0 + expansion = int(((H*W)**0.5) * percentage) + new_width = W + expansion + new_height = H + expansion + x0, x1 = expand_axis(x0, x1, min(new_width, imshape[1]), imshape[1]) + y0, y1 = expand_axis(y0, y1, min(new_height, imshape[0]), imshape[0]) + return [x0, y0, x1, y1] + + +def get_expanded_bbox( + bbox_XYXY, + imshape, + mask, + percentage_background: float, + axis_minimum_expansion: float, + target_aspect_ratio: float): + bbox_XYXY = bbox_XYXY.long().cpu().numpy().tolist() + # Expand each axis of the bounding box by a minimum percentage + bbox_XYXY = expand_axises_by_percentage(bbox_XYXY, imshape, axis_minimum_expansion) + # Find the minimum bbox with the aspect ratio. 
Can be outside of imshape + bbox_XYXY = expand_bbox_to_ratio(bbox_XYXY, imshape, target_aspect_ratio) + # Expands square box such that X% of the bbox is background + bbox_XYXY = expand_box(bbox_XYXY, imshape, mask, percentage_background) + assert isinstance(bbox_XYXY[0], (int, np.int64)) + return bbox_XYXY + + +def include_box(bbox, minimum_area, aspect_ratio_range, min_bbox_ratio_inside, imshape): + def area_inside_ratio(bbox, imshape): + area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + area_inside = (min(bbox[2], imshape[1]) - max(0,bbox[0])) * (min(imshape[0],bbox[3]) - max(0,bbox[1])) + return area_inside / area + ratio = (bbox[3] - bbox[1]) / (bbox[2] - bbox[0]) + area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + if area_inside_ratio(bbox, imshape) < min_bbox_ratio_inside: + return False + if ratio <= aspect_ratio_range[0] or ratio >= aspect_ratio_range[1] or area < minimum_area: + return False + return True diff --git a/dp2/detection/box_utils_fdf.py b/dp2/detection/box_utils_fdf.py new file mode 100644 index 0000000000000000000000000000000000000000..48e4e8c6ef067eb495ff8a021d2d236606e2e906 --- /dev/null +++ b/dp2/detection/box_utils_fdf.py @@ -0,0 +1,203 @@ +""" +The FDF dataset expands bound boxes differently from what is used for CSE. +""" + +import numpy as np + + +def quadratic_bounding_box(x0, y0, width, height, imshape): + # We assume that we can create a image that is quadratic without + # minimizing any of the sides + assert width <= min(imshape[:2]) + assert height <= min(imshape[:2]) + min_side = min(height, width) + if height != width: + side_diff = abs(height - width) + # Want to extend the shortest side + if min_side == height: + # Vertical side + height += side_diff + if height > imshape[0]: + # Take full frame, and shrink width + y0 = 0 + height = imshape[0] + + side_diff = abs(height - width) + width -= side_diff + x0 += side_diff // 2 + else: + y0 -= side_diff // 2 + y0 = max(0, y0) + else: + # Horizontal side + width += side_diff + if width > imshape[1]: + # Take full frame width, and shrink height + x0 = 0 + width = imshape[1] + + side_diff = abs(height - width) + height -= side_diff + y0 += side_diff // 2 + else: + x0 -= side_diff // 2 + x0 = max(0, x0) + # Check that bbox goes outside image + x1 = x0 + width + y1 = y0 + height + if imshape[1] < x1: + diff = x1 - imshape[1] + x0 -= diff + if imshape[0] < y1: + diff = y1 - imshape[0] + y0 -= diff + assert x0 >= 0, "Bounding box outside image." + assert y0 >= 0, "Bounding box outside image." + assert x0 + width <= imshape[1], "Bounding box outside image." + assert y0 + height <= imshape[0], "Bounding box outside image." 
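+    # At this point the box should be square (width == height) and lie fully inside the image.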
+ return x0, y0, width, height + + +def expand_bounding_box(bbox, percentage, imshape): + orig_bbox = bbox.copy() + x0, y0, x1, y1 = bbox + width = x1 - x0 + height = y1 - y0 + x0, y0, width, height = quadratic_bounding_box( + x0, y0, width, height, imshape) + expanding_factor = int(max(height, width) * percentage) + + possible_max_expansion = [(imshape[0] - width) // 2, + (imshape[1] - height) // 2, + expanding_factor] + + expanding_factor = min(possible_max_expansion) + # Expand height + + if expanding_factor > 0: + + y0 = y0 - expanding_factor + y0 = max(0, y0) + + height += expanding_factor * 2 + if height > imshape[0]: + y0 -= (imshape[0] - height) + height = imshape[0] + + if height + y0 > imshape[0]: + y0 -= (height + y0 - imshape[0]) + + # Expand width + x0 = x0 - expanding_factor + x0 = max(0, x0) + + width += expanding_factor * 2 + if width > imshape[1]: + x0 -= (imshape[1] - width) + width = imshape[1] + + if width + x0 > imshape[1]: + x0 -= (width + x0 - imshape[1]) + y1 = y0 + height + x1 = x0 + width + assert y0 >= 0, "Y0 is minus" + assert height <= imshape[0], "Height is larger than image." + assert x0 + width <= imshape[1] + assert y0 + height <= imshape[0] + assert width == height, "HEIGHT IS NOT EQUAL WIDTH!!" + assert x0 >= 0, "Y0 is minus" + assert width <= imshape[1], "Height is larger than image." + # Check that original bbox is within new + x0_o, y0_o, x1_o, y1_o = orig_bbox + assert x0 <= x0_o, f"New bbox is outisde of original. O:{x0_o}, N: {x0}" + assert x1 >= x1_o, f"New bbox is outisde of original. O:{x1_o}, N: {x1}" + assert y0 <= y0_o, f"New bbox is outisde of original. O:{y0_o}, N: {y0}" + assert y1 >= y1_o, f"New bbox is outisde of original. O:{y1_o}, N: {y1}" + + x0, y0, width, height = [int(_) for _ in [x0, y0, width, height]] + x1 = x0 + width + y1 = y0 + height + return np.array([x0, y0, x1, y1]) + + +def is_keypoint_within_bbox(x0, y0, x1, y1, keypoint): + keypoint = keypoint[:, :3] # only nose + eyes are relevant + kp_X = keypoint[0, :] + kp_Y = keypoint[1, :] + within_X = np.all(kp_X >= x0) and np.all(kp_X <= x1) + within_Y = np.all(kp_Y >= y0) and np.all(kp_Y <= y1) + return within_X and within_Y + + +def expand_bbox_simple(bbox, percentage): + x0, y0, x1, y1 = bbox.astype(float) + width = x1 - x0 + height = y1 - y0 + x_c = int(x0) + width // 2 + y_c = int(y0) + height // 2 + avg_size = max(width, height) + new_width = avg_size * (1 + percentage) + x0 = x_c - new_width // 2 + y0 = y_c - new_width // 2 + x1 = x_c + new_width // 2 + y1 = y_c + new_width // 2 + return np.array([x0, y0, x1, y1]).astype(int) + + +def pad_image(im, bbox, pad_value): + x0, y0, x1, y1 = bbox + if x0 < 0: + pad_im = np.zeros((im.shape[0], abs(x0), im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((pad_im, im), axis=1) + x1 += abs(x0) + x0 = 0 + if y0 < 0: + pad_im = np.zeros((abs(y0), im.shape[1], im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((pad_im, im), axis=0) + y1 += abs(y0) + y0 = 0 + if x1 >= im.shape[1]: + pad_im = np.zeros( + (im.shape[0], x1 - im.shape[1] + 1, im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((im, pad_im), axis=1) + if y1 >= im.shape[0]: + pad_im = np.zeros( + (y1 - im.shape[0] + 1, im.shape[1], im.shape[2]), + dtype=np.uint8) + pad_value + im = np.concatenate((im, pad_im), axis=0) + return im[y0:y1, x0:x1] + + +def clip_box(bbox, im): + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = min(im.shape[1] - 1, bbox[2]) + bbox[3] = min(im.shape[0] - 1, bbox[3]) + return bbox + + +def 
cut_face(im, bbox, simple_expand=False, pad_value=0, pad_im=True): + outside_im = (bbox < 0).any() or bbox[2] > im.shape[1] or bbox[3] > im.shape[0] + if simple_expand or (outside_im and pad_im): + return pad_image(im, bbox, pad_value) + bbox = clip_box(bbox, im) + x0, y0, x1, y1 = bbox + return im[y0:y1, x0:x1] + + +def expand_bbox( + bbox_ltrb, imshape, simple_expand, default_to_simple=False, + expansion_factor=0.35): + assert bbox_ltrb.shape == (4,), f"BBox shape was: {bbox.shape}" + bbox = bbox_ltrb.astype(float) + # FDF256 uses simple expand with ratio 0.4 + if simple_expand: + return expand_bbox_simple(bbox, 0.4) + try: + return expand_bounding_box(bbox, expansion_factor, imshape) + except AssertionError: + return expand_bbox_simple(bbox, expansion_factor * 2) + diff --git a/dp2/detection/cse_mask_face_detector.py b/dp2/detection/cse_mask_face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..74a8cf43eb35516e5e2c2c3354e15fc50ea88016 --- /dev/null +++ b/dp2/detection/cse_mask_face_detector.py @@ -0,0 +1,116 @@ +import torch +import lzma +import tops +from pathlib import Path +from dp2.detection.base import BaseDetector +from .utils import combine_cse_maskrcnn_dets +from face_detection import build_detector as build_face_detector +from .models.cse import CSEDetector +from .models.mask_rcnn import MaskRCNNDetector +from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection +from tops import logger + + +def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): + assert len(box1.shape) == 2 + assert len(box2.shape) == 2 + box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) + # This can be batched + for i, box in enumerate(box1): + is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) + is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) + is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) + box1_inside[i] = is_outside.logical_not().any() + return box1_inside + + +class CSeMaskFaceDetector(BaseDetector): + + def __init__( + self, + mask_rcnn_cfg, + face_detector_cfg: dict, + cse_cfg: dict, + face_post_process_cfg: dict, + cse_post_process_cfg, + score_threshold: float, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + if "confidence_threshold" not in face_detector_cfg: + face_detector_cfg["confidence_threshold"] = score_threshold + if "score_thres" not in cse_cfg: + cse_cfg["score_thres"] = score_threshold + self.cse_detector = CSEDetector(**cse_cfg) + self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True) + self.cse_post_process_cfg = cse_post_process_cfg + self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) + self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold") + self.face_post_process_cfg = face_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _detect_faces(self, im: torch.Tensor): + H, W = im.shape[1:] + im = im.float() - self.face_mean + im = self.face_detector.resize(im[None], 1.0) + boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score + boxes_XYXY[:, [0, 2]] *= W + boxes_XYXY[:, [1, 3]] *= H + return boxes_XYXY.round().long() + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}",) + with lzma.open(cache_path, "rb") as 
fp:
+            state_dict = torch.load(fp, map_location="cpu")
+        kwargs = dict(
+            post_process_cfg=self.cse_post_process_cfg,
+            embed_map=self.cse_detector.embed_map,
+            **self.face_post_process_cfg
+        )
+        return [
+            state["cls"].from_state_dict(**kwargs, state_dict=state)
+            for state in state_dict
+        ]
+
+    @torch.no_grad()
+    def forward(self, im: torch.Tensor):
+        maskrcnn_dets = self.mask_rcnn(im)
+        cse_dets = self.cse_detector(im)
+        embed_map = self.cse_detector.embed_map
+        logger.log("Calling face detector.")
+        face_boxes = self._detect_faces(im).cpu()
+        maskrcnn_person = {
+            k: v[maskrcnn_dets["is_person"]] for k, v in maskrcnn_dets.items()
+        }
+        maskrcnn_other = {
+            k: v[maskrcnn_dets["is_person"].logical_not()] for k, v in maskrcnn_dets.items()
+        }
+        maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])
+        combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
+            maskrcnn_person["segmentation"], cse_dets, self.mask_cse_iou_combine_threshold)
+
+        persons_with_cse = CSEPersonDetection(
+            combined_segmentation, cse_dets, **self.cse_post_process_cfg,
+            embed_map=embed_map, orig_imshape_CHW=im.shape
+        )
+        persons_with_cse.pre_process()
+        not_matched = [i for i in range(maskrcnn_person["segmentation"].shape[0]) if i not in matches[:, 0]]
+        persons_without_cse = PersonDetection(
+            maskrcnn_person["segmentation"][not_matched], **self.cse_post_process_cfg,
+            orig_imshape_CHW=im.shape
+        )
+        persons_without_cse.pre_process()
+
+        face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
+            box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
+        )
+        face_boxes = face_boxes[face_boxes_covered.logical_not()]
+        face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
+
+        # Order matters: the anonymizer processes detections FIFO,
+        # so later detections overwrite earlier ones where they overlap.
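+        # Added note: with the ordering below, faces are synthesized first, then other (vehicle)
+        # detections, then persons without CSE, and persons with CSE last, so person syntheses
+        # take precedence in overlapping regions.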
+ all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse] + return all_detections diff --git a/dp2/detection/face_detector.py b/dp2/detection/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..b05565bc3bc095edf1760c24c4238fe20b9962dc --- /dev/null +++ b/dp2/detection/face_detector.py @@ -0,0 +1,62 @@ +import torch +import lzma +import tops +from pathlib import Path +from dp2.detection.base import BaseDetector +from face_detection import build_detector as build_face_detector +from .structures import FaceDetection +from tops import logger + + +def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor): + assert len(box1.shape) == 2 + assert len(box2.shape) == 2 + box1_inside = torch.zeros(box1.shape[0], device=box1.device, dtype=torch.bool) + # This can be batched + for i, box in enumerate(box1): + is_outside_lefttop = (box[None, [0, 1]] <= box2[:, [0, 1]]).any(dim=1) + is_outside_rightbot = (box[None, [2, 3]] >= box2[:, [2, 3]]).any(dim=1) + is_outside = is_outside_lefttop.logical_or(is_outside_rightbot) + box1_inside[i] = is_outside.logical_not().any() + return box1_inside + + +class FaceDetector(BaseDetector): + + def __init__( + self, + face_detector_cfg: dict, + score_threshold: float, + face_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.face_detector = build_face_detector(**face_detector_cfg, confidence_threshold=score_threshold) + self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1)) + self.face_post_process_cfg = face_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _detect_faces(self, im: torch.Tensor): + H, W = im.shape[1:] + im = im.float() - self.face_mean + im = self.face_detector.resize(im[None], 1.0) + boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1] # Remove score + boxes_XYXY[:, [0, 2]] *= W + boxes_XYXY[:, [1, 3]] *= H + return boxes_XYXY.round().long().cpu() + + @torch.no_grad() + def forward(self, im: torch.Tensor): + face_boxes = self._detect_faces(im) + face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg) + return [face_boxes] + + def load_from_cache(self, cache_path: Path): + logger.log(f"Loading detection from cache path: {cache_path}") + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + return [ + state["cls"].from_state_dict(state_dict=state, **self.face_post_process_cfg) for state in state_dict + ] diff --git a/dp2/detection/models/__init__.py b/dp2/detection/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/detection/models/cse.py b/dp2/detection/models/cse.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6b0f6c876ed86d542604ea0f4d7274afe31584 --- /dev/null +++ b/dp2/detection/models/cse.py @@ -0,0 +1,135 @@ +import torch +from typing import List +import tops +from torchvision.transforms.functional import InterpolationMode, resize +from densepose.data.utils import get_class_to_mesh_name_mapping +from densepose import add_densepose_config +from densepose.structures import DensePoseEmbeddingPredictorOutput +from densepose.vis.extractor import DensePoseOutputsExtractor +from densepose.modeling import build_densepose_embedder +from detectron2.config import get_cfg +from detectron2.data.transforms import ResizeShortestEdge +from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer +from 
detectron2.modeling import build_model + + +model_urls = { + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x/250713061/model_final_1d3314.pkl", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_50_FPN_s1x.yaml": "https://dl.fbaipublicfiles.com/densepose/cse/densepose_rcnn_R_50_FPN_s1x/251155172/model_final_c4ea5f.pkl", +} + + +def cse_det_to_global(boxes_XYXY, S: torch.Tensor, imshape): + assert len(S.shape) == 3 + H, W = imshape + N = len(boxes_XYXY) + segmentation = torch.zeros((N, H, W), dtype=torch.bool, device=S.device) + boxes_XYXY = boxes_XYXY.long() + for i in range(N): + x0, y0, x1, y1 = boxes_XYXY[i] + assert x0 >= 0 and y0 >= 0 + assert x1 <= imshape[1] + assert y1 <= imshape[0] + h = y1 - y0 + w = x1 - x0 + segmentation[i:i+1, y0:y1, x0:x1] = resize(S[i:i+1], (h, w), interpolation=InterpolationMode.NEAREST) > 0 + return segmentation + + +class CSEDetector: + + def __init__( + self, + cfg_url: str = "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", + cfg_2_download: List[str] = [ + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/densepose_rcnn_R_101_FPN_DL_soft_s1x.yaml", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN.yaml", + "https://raw.githubusercontent.com/facebookresearch/detectron2/main/projects/DensePose/configs/cse/Base-DensePose-RCNN-FPN-Human.yaml"], + score_thres: float = 0.9, + nms_thresh: float = None, + ) -> None: + with tops.logger.capture_log_stdout(): + cfg = get_cfg() + self.device = tops.get_device() + add_densepose_config(cfg) + cfg_path = tops.download_file(cfg_url) + for p in cfg_2_download: + tops.download_file(p) + with tops.logger.capture_log_stdout(): + cfg.merge_from_file(cfg_path) + assert cfg_url in model_urls, cfg_url + model_path = tops.download_file(model_urls[cfg_url]) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres + if nms_thresh is not None: + cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = nms_thresh + cfg.MODEL.WEIGHTS = str(model_path) + cfg.MODEL.DEVICE = str(self.device) + cfg.freeze() + with tops.logger.capture_log_stdout(): + self.model = build_model(cfg) + self.model.eval() + DetectionCheckpointer(self.model).load(str(model_path)) + self.input_format = cfg.INPUT.FORMAT + self.densepose_extractor = DensePoseOutputsExtractor() + self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg) + + self.embedder = build_densepose_embedder(cfg) + self.mesh_vertex_embeddings = { + mesh_name: self.embedder(mesh_name).to(self.device) + for mesh_name in self.class_to_mesh_name.values() + if self.embedder.has_embeddings(mesh_name) + } + self.cfg = cfg + self.embed_map = self.mesh_vertex_embeddings["smpl_27554"] + tops.logger.log("CSEDetector built.") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def resize_im(self, im): + H, W = im.shape[1:] + newH, newW = ResizeShortestEdge.get_output_shape( + H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) + return resize( + im, (newH, newW), InterpolationMode.BILINEAR, antialias=True) + + @torch.no_grad() + def forward(self, im): + assert im.dtype == torch.uint8 + if self.input_format == "BGR": + im = 
im.flip(0) + H, W = im.shape[1:] + im = self.resize_im(im) + output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] + scores = output.get("scores") + if len(scores) == 0: + return dict( + instance_segmentation=torch.empty((0, 0, 112, 112), dtype=torch.bool, device=im.device), + instance_embedding=torch.empty((0, 16, 112, 112), dtype=torch.float32, device=im.device), + embed_map=self.mesh_vertex_embeddings["smpl_27554"], + bbox_XYXY=torch.empty((0, 4), dtype=torch.long, device=im.device), + im_segmentation=torch.empty((0, H, W), dtype=torch.bool, device=im.device), + scores=torch.empty((0), dtype=torch.float, device=im.device) + ) + pred_densepose, boxes_xywh, classes = self.densepose_extractor(output) + assert isinstance(pred_densepose, DensePoseEmbeddingPredictorOutput), pred_densepose + S = pred_densepose.coarse_segm.argmax(dim=1) # Segmentation channel Nx2xHxW (2 because only 2 classes) + E = pred_densepose.embedding + mesh_name = self.class_to_mesh_name[classes[0]] + assert mesh_name == "smpl_27554" + x0, y0, w, h = [boxes_xywh[:, i] for i in range(4)] + boxes_XYXY = torch.stack((x0, y0, x0+w, y0+h), dim=-1) + boxes_XYXY = boxes_XYXY.round_().long() + + non_empty_boxes = (boxes_XYXY[:, :2] == boxes_XYXY[:, 2:]).any(dim=1).logical_not() + S = S[non_empty_boxes] + E = E[non_empty_boxes] + boxes_XYXY = boxes_XYXY[non_empty_boxes] + scores = scores[non_empty_boxes] + im_segmentation = cse_det_to_global(boxes_XYXY, S, [H, W]) + return dict( + instance_segmentation=S, instance_embedding=E, + bbox_XYXY=boxes_XYXY, + im_segmentation=im_segmentation, + scores=scores.view(-1)) + diff --git a/dp2/detection/models/keypoint_maskrcnn.py b/dp2/detection/models/keypoint_maskrcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc3fd9e19aa8a023ad8135f6e6997135049c3db --- /dev/null +++ b/dp2/detection/models/keypoint_maskrcnn.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.modeling.roi_heads import CascadeROIHeads, StandardROIHeads +from detectron2.data.transforms import ResizeShortestEdge +from detectron2.structures import Instances +from detectron2 import model_zoo +from detectron2.config import instantiate +from detectron2.config import LazyCall as L +from PIL import Image +import tops +import functools +from torchvision.transforms.functional import resize + + +def get_rn50_fpn_keypoint_rcnn(weight_path: str): + from detectron2.modeling.poolers import ROIPooler + from detectron2.modeling.roi_heads import KRCNNConvDeconvUpsampleHead + from detectron2.layers import ShapeSpec + model = model_zoo.get_config("common/models/mask_rcnn_fpn.py").model + model.roi_heads.update( + num_classes=1, + keypoint_in_features=["p2", "p3", "p4", "p5"], + keypoint_pooler=L(ROIPooler)( + output_size=14, + scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), + sampling_ratio=0, + pooler_type="ROIAlignV2", + ), + keypoint_head=L(KRCNNConvDeconvUpsampleHead)( + input_shape=ShapeSpec(channels=256, width=14, height=14), + num_keypoints=17, + conv_dims=[512] * 8, + loss_normalizer="visible", + ), + ) + + # Detectron1 uses 2000 proposals per-batch, but this option is per-image in detectron2. + # 1000 proposals per-image is found to hurt box AP. + # Therefore we increase it to 1500 per-image. 
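+    # In detectron2's RPN, post_nms_topk is given as a (train, test) pair of post-NMS proposal counts.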
+ model.proposal_generator.post_nms_topk = (1500, 1000) + + # Keypoint AP degrades (though box AP improves) when using plain L1 loss + model.roi_heads.box_predictor.smooth_l1_beta = 0.5 + model = instantiate(model) + + dataloader = model_zoo.get_config("common/data/coco_keypoint.py").dataloader + test_transform = instantiate(dataloader.test.mapper.augmentations) + DetectionCheckpointer(model).load(weight_path) + return model, test_transform + + +models = { + "rn50_fpn_maskrcnn": functools.partial(get_rn50_fpn_keypoint_rcnn, weight_path="https://folk.ntnu.no/haakohu/checkpoints/maskrcnn_keypoint/keypoint_maskrcnn_R_50_FPN_1x.pth") +} + + + + +class KeypointMaskRCNN: + + def __init__(self, model_name: str, score_threshold: float) -> None: + assert model_name in models, f"Did not find {model_name} in models" + model, test_transform = models[model_name]() + self.model = model.eval().to(tops.get_device()) + if isinstance(self.model.roi_heads, CascadeROIHeads): + for head in self.model.roi_heads.box_predictors: + assert hasattr(head, "test_score_thresh") + head.test_score_thresh = score_threshold + else: + assert isinstance(self.model.roi_heads, StandardROIHeads) + assert hasattr(self.model.roi_heads.box_predictor, "test_score_thresh") + self.model.roi_heads.box_predictor.test_score_thresh = score_threshold + + self.test_transform = test_transform + assert len(self.test_transform) == 1 + self.test_transform = self.test_transform[0] + assert isinstance(self.test_transform, ResizeShortestEdge) + assert self.test_transform.interp == Image.BILINEAR + self.image_format = self.model.input_format + + def resize_im(self, im): + H, W = im.shape[-2:] + if self.test_transform.is_range: + size = np.random.randint(self.test_transform.short_edge_length[0], self.test_transform.short_edge_length[1] + 1) + else: + size = np.random.choice(self.test_transform.short_edge_length) + newH, newW = ResizeShortestEdge.get_output_shape(H, W, size, self.test_transform.max_size) + return resize( + im, (newH, newW), antialias=True) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def forward(self, im: torch.Tensor) -> Instances: + assert im.ndim == 3 + if self.image_format == "BGR": + im = im.flip(0) + H, W = im.shape[-2:] + im = self.resize_im(im) + im = im.float() + inputs = dict(image=im, height=H, width=W) + # instances contains + # dict_keys(['pred_boxes', 'scores', 'pred_classes', 'pred_masks', 'pred_keypoints', 'pred_keypoint_heatmaps']) + instances = self.model([inputs])[0]["instances"] + return dict( + scores=instances.get("scores").cpu(), + segmentation=instances.get("pred_masks").cpu(), + keypoints=instances.get("pred_keypoints").cpu() + ) diff --git a/dp2/detection/models/mask_rcnn.py b/dp2/detection/models/mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1f87d151709c8adede45aa00de0c4bce0287114e --- /dev/null +++ b/dp2/detection/models/mask_rcnn.py @@ -0,0 +1,78 @@ +import torch +import tops +from detectron2.modeling import build_model +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.structures import Boxes +from detectron2.data import MetadataCatalog +from detectron2 import model_zoo +from typing import Dict +from detectron2.data.transforms import ResizeShortestEdge +from torchvision.transforms.functional import resize + + + +model_urls = { + "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml": 
"https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl", + +} +class MaskRCNNDetector: + + def __init__( + self, + cfg_name: str = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml", + score_thres: float = 0.9, + class_filter=["person"], #["car", "bicycle","truck", "bus", "backpack"] + fp16_inference: bool = False + ) -> None: + cfg = model_zoo.get_config(cfg_name) + cfg.MODEL.DEVICE = str(tops.get_device()) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thres + cfg.freeze() + self.cfg = cfg + with tops.logger.capture_log_stdout(): + self.model = build_model(cfg) + DetectionCheckpointer(self.model).load(model_urls[cfg_name]) + self.model.eval() + self.input_format = cfg.INPUT.FORMAT + self.class_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes + self.class_to_keep = set([self.class_names.index(cls_) for cls_ in class_filter]) + self.person_class = self.class_names.index("person") + self.fp16_inference = fp16_inference + tops.logger.log("Mask R-CNN built.") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def resize_im(self, im): + H, W = im.shape[1:] + newH, newW = ResizeShortestEdge.get_output_shape( + H, W, self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST) + return resize( + im, (newH, newW), antialias=True) + + @torch.no_grad() + def forward(self, im: torch.Tensor): + if self.input_format == "BGR": + im = im.flip(0) + else: + assert self.input_format == "RGB" + H, W = im.shape[-2:] + im = self.resize_im(im) + with torch.cuda.amp.autocast(enabled=self.fp16_inference): + output = self.model([{"image": im, "height": H, "width": W}])[0]["instances"] + scores = output.get("scores") + N = len(scores) + classes = output.get("pred_classes") + idx2keep = [i for i in range(N) if classes[i].tolist() in self.class_to_keep] + classes = classes[idx2keep] + assert isinstance(output.get("pred_boxes"), Boxes) + segmentation = output.get("pred_masks")[idx2keep] + assert segmentation.dtype == torch.bool + is_person = classes == self.person_class + return { + "scores": output.get("scores")[idx2keep], + "segmentation": segmentation, + "classes": output.get("pred_classes")[idx2keep], + "is_person": is_person + } + diff --git a/dp2/detection/person_detector.py b/dp2/detection/person_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbd0df8c2aa44839a5de8bd9a6aeede054ff2ee --- /dev/null +++ b/dp2/detection/person_detector.py @@ -0,0 +1,135 @@ +import torch +import lzma +from dp2.detection.base import BaseDetector +from .utils import combine_cse_maskrcnn_dets +from .models.cse import CSEDetector +from .models.mask_rcnn import MaskRCNNDetector +from .models.keypoint_maskrcnn import KeypointMaskRCNN +from .structures import CSEPersonDetection, PersonDetection +from pathlib import Path + + +class CSEPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + self.cse_detector = CSEDetector(**cse_cfg, score_thres=score_threshold) + self.post_process_cfg = cse_post_process_cfg + self.iou_combine_threshold = self.post_process_cfg.pop("iou_combine_threshold") + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, 
"rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + embed_map=self.cse_detector.embed_map, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor, cse_dets=None): + mask_dets = self.mask_rcnn(im) + if cse_dets is None: + cse_dets = self.cse_detector(im) + segmentation = mask_dets["segmentation"] + segmentation, cse_dets, _ = combine_cse_maskrcnn_dets( + segmentation, cse_dets, self.iou_combine_threshold + ) + det = CSEPersonDetection( + segmentation=segmentation, + cse_dets=cse_dets, + embed_map=self.cse_detector.embed_map, + orig_imshape_CHW=im.shape, + **self.post_process_cfg + ) + return [det] + + +class MaskRCNNPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold) + self.post_process_cfg = cse_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + mask_dets = self.mask_rcnn(im) + segmentation = mask_dets["segmentation"] + det = PersonDetection( + segmentation, **self.post_process_cfg, orig_imshape_CHW=im.shape + ) + return [det] + + +class KeypointMaskRCNNPersonDetector(BaseDetector): + def __init__( + self, + score_threshold: float, + mask_rcnn_cfg: dict, + cse_post_process_cfg: dict, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.mask_rcnn = KeypointMaskRCNN( + **mask_rcnn_cfg, score_threshold=score_threshold + ) + self.post_process_cfg = cse_post_process_cfg + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def load_from_cache(self, cache_path: Path): + with lzma.open(cache_path, "rb") as fp: + state_dict = torch.load(fp) + kwargs = dict( + post_process_cfg=self.post_process_cfg, + ) + return [ + state["cls"].from_state_dict(**kwargs, state_dict=state) + for state in state_dict + ] + + @torch.no_grad() + def forward(self, im: torch.Tensor): + mask_dets = self.mask_rcnn(im) + segmentation = mask_dets["segmentation"] + det = PersonDetection( + segmentation, + **self.post_process_cfg, + orig_imshape_CHW=im.shape, + keypoints=mask_dets["keypoints"] + ) + return [det] diff --git a/dp2/detection/structures.py b/dp2/detection/structures.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e92083091506db637532154ee4d4131c5aa842 --- /dev/null +++ b/dp2/detection/structures.py @@ -0,0 +1,463 @@ +import torch +import numpy as np +from dp2 import utils +from dp2.utils import vis_utils, crop_box +from .utils import ( + cut_pad_resize, masks_to_boxes, + get_kernel, transform_embedding, initialize_cse_boxes + ) +from .box_utils import get_expanded_bbox, include_box +import torchvision +import tops +from .box_utils_fdf import expand_bbox as expand_bbox_fdf + + +class VehicleDetection: + + def __init__(self, segmentation: torch.BoolTensor) -> None: + self.segmentation = segmentation + self.boxes = masks_to_boxes(segmentation) + assert self.boxes.shape[1] == 4, self.boxes.shape + 
self.n_detections = self.segmentation.shape[0]
+        area = (self.boxes[:, 3] - self.boxes[:, 1]) * (self.boxes[:, 2] - self.boxes[:, 0])
+
+        sorted_idx = torch.argsort(area, descending=True)
+        self.segmentation = self.segmentation[sorted_idx]
+        self.boxes = self.boxes[sorted_idx].cpu()
+
+    def pre_process(self):
+        pass
+
+    def get_crop(self, idx: int, im):
+        assert idx < len(self)
+        box = self.boxes[idx]
+        # Crop the input image and the instance mask to the detection box.
+        im = crop_box(im, box)
+        mask = crop_box(self.segmentation[idx], box)
+        mask = mask == 0
+        return dict(img=im, mask=mask.float(), boxes=box)
+
+    def visualize(self, im):
+        if len(self) == 0:
+            return im
+        im = vis_utils.draw_mask(im.clone(), self.segmentation.logical_not())
+        return im
+
+    def __len__(self):
+        return self.n_detections
+
+    @staticmethod
+    def from_state_dict(state_dict, **kwargs):
+        numel = np.prod(state_dict["shape"])
+        arr = np.unpackbits(state_dict["segmentation"].numpy(), count=numel)
+        segmentation = tops.to_cuda(torch.from_numpy(arr)).view(state_dict["shape"])
+        return VehicleDetection(segmentation)
+
+    def state_dict(self, **kwargs):
+        segmentation = torch.from_numpy(np.packbits(self.segmentation.bool().cpu().numpy()))
+        return dict(segmentation=segmentation, cls=self.__class__, shape=self.segmentation.shape)
+
+
+class FaceDetection:
+
+    def __init__(self, boxes_ltrb: torch.LongTensor, target_imsize, fdf128_expand: bool, **kwargs) -> None:
+        self.boxes = boxes_ltrb.cpu()
+        assert self.boxes.shape[1] == 4, self.boxes.shape
+        self.target_imsize = tuple(target_imsize)
+        # Sort by area to paste the largest faces last
+        area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
+        idx = area.argsort(descending=False)
+        self.boxes = self.boxes[idx]
+        self.fdf128_expand = fdf128_expand
+
+    def visualize(self, im):
+        if len(self) == 0:
+            return im
+        orig_device = im.device
+        for box in self.boxes:
+            simple_expand = False if self.fdf128_expand else True
+            e_box = torch.from_numpy(expand_bbox_fdf(box.numpy(), im.shape[-2:], simple_expand))
+            im = torchvision.utils.draw_bounding_boxes(im.cpu(), e_box[None], colors=(0, 0, 255), width=2)
+        im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2)
+
+        return im.to(device=orig_device)
+
+    def get_crop(self, idx: int, im):
+        assert idx < len(self)
+        box = self.boxes[idx].numpy()
+        expanded_boxes = expand_bbox_fdf(box, im.shape[-2:], True)
+        im = cut_pad_resize(im, expanded_boxes, self.target_imsize, fdf_resize=True)
+        area = (self.boxes[:, 2] - self.boxes[:, 0]) * (self.boxes[:, 3] - self.boxes[:, 1]).view(-1)
+
+        # Find the square mask corresponding to box.
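+        # Shift the face box into the coordinates of the expanded crop and scale it to the
+        # generator resolution; the mask is 0 inside the face region (the area the generator
+        # fills in) and 1 elsewhere.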
+ box_mask = box.copy().astype(float) + box_mask[[0, 2]] -= expanded_boxes[0] + box_mask[[1, 3]] -= expanded_boxes[1] + + width = expanded_boxes[2] - expanded_boxes[0] + resize_factor = self.target_imsize[0] / width + box_mask = (box_mask * resize_factor).astype(int) + mask = torch.ones((1, *self.target_imsize), device=im.device, dtype=torch.float32) + crop_box(mask, box_mask).fill_(0) + return dict( + img=im[None], mask=mask[None], + boxes=torch.from_numpy(expanded_boxes).view(1, -1)) + + def __len__(self): + return len(self.boxes) + + @staticmethod + def from_state_dict(state_dict, **kwargs): + return FaceDetection(state_dict["boxes"].cpu(), **kwargs) + + def state_dict(self, **kwargs): + return dict(boxes=self.boxes, cls=self.__class__) + + def pre_process(self): + pass + + +def remove_dilate_in_pad(mask: torch.Tensor, exp_box, orig_imshape): + """ + Dilation happens after padding, which could place dilation in the padded area. + Remove this. + """ + x0, y0, x1, y1 = exp_box + H, W = orig_imshape + # Padding in original image space + p_y0 = max(0, -y0) + p_y1 = max(y1 - H, 0) + p_x0 = max(0, -x0) + p_x1 = max(x1 - W, 0) + resize_ratio = mask.shape[-2] / (y1-y0) + p_x0, p_y0, p_x1, p_y1 = [(_*resize_ratio).floor().long() for _ in [p_x0, p_y0, p_x1, p_y1]] + mask[..., :p_y0, :] = 0 + mask[..., :p_x0] = 0 + mask[..., mask.shape[-2] - p_y1:, :] = 0 + mask[..., mask.shape[-1] - p_x1:] = 0 + + +class CSEPersonDetection: + + def __init__(self, + segmentation, cse_dets, + target_imsize, + exp_bbox_cfg, exp_bbox_filter, + dilation_percentage: float, + embed_map: torch.Tensor, + orig_imshape_CHW, + normalize_embedding: bool) -> None: + self.segmentation = segmentation + self.cse_dets = cse_dets + self.target_imsize = list(target_imsize) + self.pre_processed = False + self.exp_bbox_cfg = exp_bbox_cfg + self.exp_bbox_filter = exp_bbox_filter + self.dilation_percentage = dilation_percentage + self.embed_map = embed_map + self.normalize_embedding = normalize_embedding + if self.normalize_embedding: + embed_map_mean = self.embed_map.mean(dim=0, keepdim=True) + embed_map_rstd = ((self.embed_map - embed_map_mean).square().mean(dim=0, keepdim=True)+1e-8).rsqrt() + self.embed_map_normalized = (self.embed_map - embed_map_mean) * embed_map_rstd + self.orig_imshape_CHW = orig_imshape_CHW + + @torch.no_grad() + def pre_process(self): + if self.pre_processed: + return + boxes = initialize_cse_boxes(self.segmentation, self.cse_dets["bbox_XYXY"]).cpu() + expanded_boxes = [] + included_boxes = [] + for i in range(len(boxes)): + exp_box = get_expanded_bbox( + boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, + target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) + if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): + continue + included_boxes.append(i) + expanded_boxes.append(exp_box) + expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) + self.segmentation = self.segmentation[included_boxes] + self.cse_dets = {k: v[included_boxes] for k, v in self.cse_dets.items()} + + self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) + area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) + for i, box in enumerate(expanded_boxes): + self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] + + dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) + self.maskrcnn_mask = 
self.mask.clone().logical_not()[:, None] + self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) + [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))] + self.boxes = expanded_boxes.cpu() + self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) + + self.pre_processed = True + self.n_detections = len(self.boxes) + self.mask = self.mask.logical_not() + + E_mask = torch.zeros((self.n_detections, 1, *self.target_imsize), device=self.mask.device, dtype=torch.bool) + self.vertices = torch.zeros_like(E_mask, dtype=torch.long) + for i in range(self.n_detections): + E_, E_mask[i] = transform_embedding( + self.cse_dets["instance_embedding"][i], + self.cse_dets["instance_segmentation"][i], + self.boxes[i], + self.cse_dets["bbox_XYXY"][i].cpu(), + self.target_imsize + ) + self.vertices[i] = utils.from_E_to_vertex(E_[None], E_mask[i:i+1].logical_not(), self.embed_map).squeeze()[None] + self.E_mask = E_mask + + sorted_idx = torch.argsort(area, descending=False) + self.mask = self.mask[sorted_idx] + self.boxes = self.boxes[sorted_idx.cpu()] + self.vertices = self.vertices[sorted_idx] + self.E_mask = self.E_mask[sorted_idx] + self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] + + def get_crop(self, idx: int, im): + self.pre_process() + assert idx < len(self) + box = self.boxes[idx] + mask = self.mask[idx] + im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) + + vertices_ = self.vertices[idx] + E_mask_ = self.E_mask[idx].float() + if self.normalize_embedding: + embedding = self.embed_map_normalized[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ + else: + embedding = self.embed_map[vertices_.squeeze(dim=0)].permute(2, 0, 1) * E_mask_ + + return dict( + img=im, + mask=mask.float()[None], + boxes=box.reshape(1, -1), + E_mask=E_mask_[None], + vertices=vertices_[None], + embed_map=self.embed_map, + embedding=embedding[None], + maskrcnn_mask=self.maskrcnn_mask[idx].float()[None] + ) + + def __len__(self): + self.pre_process() + return self.n_detections + + def state_dict(self, after_preprocess=False): + """ + The processed annotations occupy more space than the original detections. 
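+        With after_preprocess=False, only the raw detector outputs (combined segmentation,
+        CSE segmentation/embeddings and boxes) are stored; otherwise the preprocessed masks
+        are bit-packed and the vertices stored as int16 for compact caching.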
+ """ + if not after_preprocess: + return { + "combined_segmentation": self.segmentation.bool(), + "cse_instance_segmentation": self.cse_dets["instance_segmentation"].bool(), + "cse_instance_embedding": self.cse_dets["instance_embedding"], + "cse_bbox_XYXY": self.cse_dets["bbox_XYXY"].long(), + "cls": self.__class__, + "orig_imshape_CHW": self.orig_imshape_CHW + } + self.pre_process() + return dict( + E_mask=torch.from_numpy(np.packbits(self.E_mask.bool().cpu().numpy())), + mask=torch.from_numpy(np.packbits(self.mask.bool().cpu().numpy())), + maskrcnn_mask=torch.from_numpy(np.packbits(self.maskrcnn_mask.bool().cpu().numpy())), + vertices=self.vertices.to(torch.int16).cpu(), + cls=self.__class__, + boxes=self.boxes, + orig_imshape_CHW=self.orig_imshape_CHW, + ) + + @staticmethod + def from_state_dict( + state_dict, embed_map, + post_process_cfg, **kwargs): + after_preprocess = "segmentation" not in state_dict and "combined_segmentation" not in state_dict + if after_preprocess: + detection = CSEPersonDetection( + segmentation=None, cse_dets=None, embed_map=embed_map, + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg) + detection.vertices = tops.to_cuda(state_dict["vertices"].long()) + numel = np.prod(detection.vertices.shape) + detection.E_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["E_mask"].numpy(), count=numel))).view(*detection.vertices.shape) + detection.mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["mask"].numpy(), count=numel))).view(*detection.vertices.shape) + detection.maskrcnn_mask = tops.to_cuda(torch.from_numpy(np.unpackbits(state_dict["maskrcnn_mask"].numpy(), count=numel))).view(*detection.vertices.shape) + detection.n_detections = len(detection.mask) + detection.pre_processed = True + + if isinstance(state_dict["boxes"], np.ndarray): + state_dict["boxes"] = torch.from_numpy(state_dict["boxes"]) + detection.boxes = state_dict["boxes"] + return detection + + cse_dets = dict( + instance_segmentation=state_dict["cse_instance_segmentation"], + instance_embedding=state_dict["cse_instance_embedding"], + embed_map=embed_map, + bbox_XYXY=state_dict["cse_bbox_XYXY"]) + cse_dets = {k: tops.to_cuda(v) for k, v in cse_dets.items()} + + segmentation = state_dict["combined_segmentation"] + return CSEPersonDetection( + segmentation, cse_dets, embed_map=embed_map, + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg) + + def visualize(self, im): + self.pre_process() + if len(self) == 0: + return im + im = vis_utils.draw_cropped_masks( + im.clone(), self.mask, self.boxes, visualize_instances=False) + E = self.embed_map[self.vertices.long()].squeeze(1).permute(0,3, 1, 2) + im = im.to(E.device) + im = vis_utils.draw_cse_all( + E, self.E_mask.squeeze(1).bool(), im, + self.boxes, self.embed_map) + im = torchvision.utils.draw_bounding_boxes(im.cpu(), self.boxes, colors=(255, 0, 0), width=2) + return im + + +def shift_and_preprocess_keypoints(keypoints: torch.Tensor, boxes): + keypoints = keypoints.clone() + N = boxes.shape[0] + tops.assert_shape(keypoints, (N, None, 3)) + tops.assert_shape(boxes, (N, 4)) + x0, y0, x1, y1 = [_.view(-1, 1) for _ in boxes.T] + + w = x1 - x0 + h = y1 - y0 + keypoints[:, :, 0] = (keypoints[:, :, 0] - x0) / w + keypoints[:, :, 1] = (keypoints[:, :, 1] - y0) / h + check_outside = lambda x: (x < 0).logical_or(x > 1) + is_outside = check_outside(keypoints[:, :, 0]).logical_or(check_outside(keypoints[:, :, 1])) + keypoints[:, :, 2] = keypoints[:, :, 2] >= 0 + keypoints[:, :, 2] = (keypoints[:, :, 2] > 
0).logical_and(is_outside.logical_not()) + return keypoints + + +class PersonDetection: + + def __init__( + self, + segmentation, + target_imsize, + exp_bbox_cfg, exp_bbox_filter, + dilation_percentage: float, + orig_imshape_CHW, + keypoints=None, + **kwargs) -> None: + self.segmentation = segmentation + self.target_imsize = list(target_imsize) + self.pre_processed = False + self.exp_bbox_cfg = exp_bbox_cfg + self.exp_bbox_filter = exp_bbox_filter + self.dilation_percentage = dilation_percentage + self.orig_imshape_CHW = orig_imshape_CHW + self.keypoints = keypoints + + @torch.no_grad() + def pre_process(self): + if self.pre_processed: + return + boxes = masks_to_boxes(self.segmentation).cpu() + expanded_boxes = [] + included_boxes = [] + for i in range(len(boxes)): + exp_box = get_expanded_bbox( + boxes[i], self.orig_imshape_CHW[1:], self.segmentation[i], **self.exp_bbox_cfg, + target_aspect_ratio=self.target_imsize[0]/self.target_imsize[1]) + if not include_box(exp_box, imshape=self.orig_imshape_CHW[1:], **self.exp_bbox_filter): + continue + included_boxes.append(i) + expanded_boxes.append(exp_box) + expanded_boxes = torch.LongTensor(expanded_boxes).view(-1, 4) + self.segmentation = self.segmentation[included_boxes] + if self.keypoints is not None: + self.keypoints = self.keypoints[included_boxes] + area = self.segmentation.sum(dim=[1, 2]).view(len(expanded_boxes)) + self.mask = torch.empty((len(expanded_boxes), *self.target_imsize), device=tops.get_device(), dtype=torch.bool) + for i, box in enumerate(expanded_boxes): + self.mask[i] = cut_pad_resize(self.segmentation[i:i+1], box, self.target_imsize)[0] + if self.keypoints is not None: + self.keypoints = shift_and_preprocess_keypoints(self.keypoints, expanded_boxes) + dilation_kernel = get_kernel(int((self.target_imsize[0]*self.target_imsize[1])**0.5*self.dilation_percentage)) + self.maskrcnn_mask = self.mask.clone().logical_not()[:, None] + self.mask = utils.binary_dilation(self.mask[:, None], dilation_kernel) + + [remove_dilate_in_pad(self.mask[i], expanded_boxes[i], self.orig_imshape_CHW[1:]) for i in range(len(expanded_boxes))] + self.boxes = expanded_boxes + self.dilated_boxes = get_dilated_boxes(self.boxes, self.mask) + + self.pre_processed = True + self.n_detections = len(self.boxes) + self.mask = self.mask.logical_not() + + sorted_idx = torch.argsort(area, descending=False) + self.mask = self.mask[sorted_idx] + self.boxes = self.boxes[sorted_idx.cpu()] + self.segmentation = self.segmentation[sorted_idx] + self.maskrcnn_mask = self.maskrcnn_mask[sorted_idx] + if self.keypoints is not None: + self.keypoints = self.keypoints[sorted_idx] + + def get_crop(self, idx: int, im: torch.Tensor): + assert idx < len(self) + self.pre_process() + box = self.boxes[idx] + mask = self.mask[idx][None].float() + im = cut_pad_resize(im, box, self.target_imsize).unsqueeze(0) + batch = dict( + img=im, mask=mask, boxes=box.reshape(1, -1), + maskrcnn_mask=self.maskrcnn_mask[idx][None].float()) + if self.keypoints is not None: + batch["keypoints"] = self.keypoints[idx:idx+1] + return batch + + def __len__(self): + self.pre_process() + return self.n_detections + + def state_dict(self, **kwargs): + return dict( + segmentation=self.segmentation.bool(), + cls=self.__class__, + orig_imshape_CHW=self.orig_imshape_CHW, + keypoints=self.keypoints + ) + + @staticmethod + def from_state_dict( + state_dict, + post_process_cfg, **kwargs): + return PersonDetection( + state_dict["segmentation"], + orig_imshape_CHW=state_dict["orig_imshape_CHW"], + **post_process_cfg, + 
keypoints=state_dict["keypoints"]) + + def visualize(self, im): + self.pre_process() + im = im.cpu() + if len(self) == 0: + return im + im = vis_utils.draw_cropped_masks(im.clone(), self.mask, self.boxes, visualize_instances=False) + im = vis_utils.draw_cropped_keypoints(im, self.keypoints, self.boxes) + return im + + +def get_dilated_boxes(exp_bbox: torch.LongTensor, mask): + """ + mask: resized mask + """ + assert exp_bbox.shape[0] == mask.shape[0] + boxes = masks_to_boxes(mask.squeeze(1)).cpu() + H, W = exp_bbox[:, 3] - exp_bbox[:, 1], exp_bbox[:, 2] - exp_bbox[:, 0] + boxes[:, [0, 2]] = (boxes[:, [0, 2]] * W[:, None] / mask.shape[-1]).long() + boxes[:, [1, 3]] = (boxes[:, [1, 3]] * H[:, None] / mask.shape[-2]).long() + boxes[:, [0, 2]] += exp_bbox[:, 0:1] + boxes[:, [1, 3]] += exp_bbox[:, 1:2] + return boxes + diff --git a/dp2/detection/utils.py b/dp2/detection/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85dbd29c1832d2f48b20ed14c3d0357c958732f6 --- /dev/null +++ b/dp2/detection/utils.py @@ -0,0 +1,174 @@ +import cv2 +import numpy as np +import torch +import tops +from skimage.morphology import disk +from torchvision.transforms.functional import resize, InterpolationMode +from functools import lru_cache + + +@lru_cache(maxsize=200) +def get_kernel(n: int): + kernel = disk(n, dtype=bool) + return tops.to_cuda(torch.from_numpy(kernel).bool()) + + +def transform_embedding(E: torch.Tensor, S: torch.Tensor, exp_bbox, E_bbox, target_imshape): + """ + Transforms the detected embedding/mask directly to the target image shape + """ + + C, HE, WE = E.shape + assert E_bbox[0] >= exp_bbox[0], (E_bbox, exp_bbox) + assert E_bbox[2] >= exp_bbox[0] + assert E_bbox[1] >= exp_bbox[1] + assert E_bbox[3] >= exp_bbox[1] + assert E_bbox[2] <= exp_bbox[2] + assert E_bbox[3] <= exp_bbox[3] + + x0 = int(np.round((E_bbox[0] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) + x1 = int(np.round((E_bbox[2] - exp_bbox[0]) / (exp_bbox[2] - exp_bbox[0]) * target_imshape[1])) + y0 = int(np.round((E_bbox[1] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) + y1 = int(np.round((E_bbox[3] - exp_bbox[1]) / (exp_bbox[3] - exp_bbox[1]) * target_imshape[0])) + new_E = torch.zeros((C, *target_imshape), device=E.device, dtype=torch.float32) + new_S = torch.zeros((target_imshape), device=S.device, dtype=torch.bool) + + E = resize(E, (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR) + new_E[:, y0:y1, x0:x1] = E + S = resize(S[None].float(), (y1-y0, x1-x0), antialias=True, interpolation=InterpolationMode.BILINEAR)[0] > 0 + new_S[y0:y1, x0:x1] = S + return new_E, new_S + + +def pairwise_mask_iou(mask1: torch.Tensor, mask2: torch.Tensor): + """ + mask: shape [N, H, W] + """ + assert len(mask1.shape) == 3 + assert len(mask2.shape) == 3 + assert mask1.device == mask2.device, (mask1.device, mask2.device) + assert mask2.dtype == mask2.dtype + assert mask1.dtype == torch.bool + assert mask1.shape[1:] == mask2.shape[1:] + N1, H1, W1 = mask1.shape + N2, H2, W2 = mask2.shape + iou = torch.zeros((N1, N2), dtype=torch.float32) + for i in range(N1): + cur = mask1[i:i+1] + inter = torch.logical_and(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() + union = torch.logical_or(cur, mask2).flatten(start_dim=1).float().sum(dim=1).cpu() + iou[i] = inter / union + return iou + + +def find_best_matches(mask1: torch.Tensor, mask2: torch.Tensor, iou_threshold: float): + N1 = mask1.shape[0] + N2 = mask2.shape[0] + ious = pairwise_mask_iou(mask1, 
mask2).cpu().numpy() + indices = np.array([idx for idx, iou in np.ndenumerate(ious)]) + ious = ious.flatten() + mask = ious >= iou_threshold + ious = ious[mask] + indices = indices[mask] + + # do not sort by iou to keep ordering of mask rcnn / cse sorting. + taken1 = np.zeros((N1), dtype=bool) + taken2 = np.zeros((N2), dtype=bool) + matches = [] + for i, j in indices: + if taken1[i].any() or taken2[j].any(): + continue + matches.append((i, j)) + taken1[i] = True + taken2[j] = True + return matches + + +def combine_cse_maskrcnn_dets(segmentation: torch.Tensor, cse_dets: dict, iou_threshold: float): + assert 0 < iou_threshold <= 1 + matches = find_best_matches(segmentation, cse_dets["im_segmentation"], iou_threshold) + H, W = segmentation.shape[1:] + new_seg = torch.zeros((len(matches), H, W), dtype=torch.bool, device=segmentation.device) + cse_im_seg = cse_dets["im_segmentation"] + for idx, (i, j) in enumerate(matches): + new_seg[idx] = torch.logical_or(segmentation[i], cse_im_seg[j]) + cse_dets = dict( + instance_segmentation=cse_dets["instance_segmentation"][[j for (i, j) in matches]], + instance_embedding=cse_dets["instance_embedding"][[j for (i, j) in matches]], + bbox_XYXY=cse_dets["bbox_XYXY"][[j for (i, j) in matches]], + scores=cse_dets["scores"][[j for (i, j) in matches]], + ) + return new_seg, cse_dets, np.array(matches).reshape(-1, 2) + + +def initialize_cse_boxes(segmentation: torch.Tensor, cse_boxes: torch.Tensor): + """ + cse_boxes can be outside of segmentation. + """ + boxes = masks_to_boxes(segmentation) + + assert boxes.shape == cse_boxes.shape, (boxes.shape, cse_boxes.shape) + combined = torch.stack((boxes, cse_boxes), dim=-1) + boxes = torch.cat(( + combined[:, :2].min(dim=2).values, + combined[:, 2:].max(dim=2).values, + ), dim=1) + return boxes + + +def cut_pad_resize(x: torch.Tensor, bbox, target_shape, fdf_resize=False): + """ + Crops or pads x to fit in the bbox and resize to target shape. + """ + C, H, W = x.shape + x0, y0, x1, y1 = bbox + + if y0 > 0 and x0 > 0 and x1 <= W and y1 <= H: + new_x = x[:, y0:y1, x0:x1] + else: + new_x = torch.zeros(((C, y1-y0, x1-x0)), dtype=x.dtype, device=x.device) + y0_t = max(0, -y0) + y1_t = min(y1-y0, (y1-y0)-(y1-H)) + x0_t = max(0, -x0) + x1_t = min(x1-x0, (x1-x0)-(x1-W)) + x0 = max(0, x0) + y0 = max(0, y0) + x1 = min(x1, W) + y1 = min(y1, H) + new_x[:, y0_t:y1_t, x0_t:x1_t] = x[:, y0:y1, x0:x1] + if x1 - x0 == target_shape[1] and y1 - y0 == target_shape[0]: + return new_x + if x.dtype == torch.bool: + new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.NEAREST) > 0.5 + elif x.dtype == torch.float32: + new_x = resize(new_x, target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True) + elif x.dtype == torch.uint8: + if fdf_resize: # FDF dataset is created with cv2 INTER_AREA. + # Incorrect resizing generates noticeable poorer inpaintings. 
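+ # Upsampling keeps the GPU bicubic path below; downsampling moves the crop to CPU and
+ # uses cv2.INTER_AREA so crops match how the FDF dataset was generated, then copies the
+ # result back to the original device.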
+ upsampling = ((y1-y0) *(x1-x0)) < (target_shape[0] * target_shape[1]) + if upsampling: + new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BICUBIC, antialias=True).round().clamp(0, 255).byte() + else: + device = new_x.device + new_x = new_x.permute(1, 2, 0).cpu().numpy() + new_x = cv2.resize(new_x, target_shape[::-1], interpolation=cv2.INTER_AREA) + new_x = torch.from_numpy(np.rollaxis(new_x, 2)).to(device) + else: + new_x = resize(new_x.float(), target_shape, interpolation=InterpolationMode.BILINEAR, antialias=True).round().clamp(0, 255).byte() + else: + raise ValueError(f"Not supported dtype: {x.dtype}") + return new_x + + + +def masks_to_boxes(segmentation: torch.Tensor): + assert len(segmentation.shape) == 3 + x = segmentation.any(dim=1).byte() # Compress rows + x0 = x.argmax(dim=1) + + x1 = segmentation.shape[2] - x.flip(dims=(1,)).argmax(dim=1) + y = segmentation.any(dim=2).byte() + y0 = y.argmax(dim=1) + y1 = segmentation.shape[1] - y.flip(dims=(1,)).argmax(dim=1) + return torch.stack([x0, y0, x1, y1], dim=1) + diff --git a/dp2/discriminator/__init__.py b/dp2/discriminator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a4a773eafecde739caf086660063a19cf2160f --- /dev/null +++ b/dp2/discriminator/__init__.py @@ -0,0 +1 @@ +from .sg2_discriminator import SG2Discriminator \ No newline at end of file diff --git a/dp2/discriminator/sg2_discriminator.py b/dp2/discriminator/sg2_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..5f44e8f9d6cfb46c0763cc8ab5d96037fcd2c4e2 --- /dev/null +++ b/dp2/discriminator/sg2_discriminator.py @@ -0,0 +1,76 @@ +from sg3_torch_utils.ops import upfirdn2d +import torch +import numpy as np +import torch.nn as nn +from .. import layers +from ..layers.sg2_layers import DiscriminatorEpilogue, ResidualBlock, Block + + +class SG2Discriminator(layers.Module): + + def __init__( + self, + cnum: int, + max_cnum_mul: int, + imsize, + min_fmap_resolution: int, + im_channels: int, + input_condition: bool, + conv_clamp: int, + input_cse: bool, + cse_nc: int): + super().__init__() + + cse_nc = 0 if cse_nc is None else cse_nc + self._max_imsize = max(imsize) + self._cnum = cnum + self._max_cnum_mul = max_cnum_mul + self._min_fmap_resolution = min_fmap_resolution + self._input_condition = input_condition + self.input_cse = input_cse + self.layers = nn.ModuleList() + + out_ch = self.get_chsize(self._max_imsize) + self.from_rgb = Block( + im_channels + input_condition*(im_channels+1) + input_cse*(cse_nc+1), + out_ch, conv_clamp=conv_clamp + ) + n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 + + for i in range(n_levels): + resolution = [x//2**i for x in imsize] + in_ch = self.get_chsize(max(resolution)) + out_ch = self.get_chsize(max(max(resolution)//2, min_fmap_resolution)) + + down = 2 + if i == 0: + down = 1 + block = ResidualBlock( + in_ch, out_ch, down=down, conv_clamp=conv_clamp + ) + self.layers.append(block) + self.output_layer = DiscriminatorEpilogue( + out_ch, resolution, conv_clamp=conv_clamp) + + self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1])) + + def forward(self, img, condition, mask, embedding=None, E_mask=None,**kwargs): + to_cat = [img] + if self._input_condition: + to_cat.extend([condition, mask,]) + if self.input_cse: + to_cat.extend([embedding, E_mask]) + x = torch.cat(to_cat, dim=1) + x = self.from_rgb(x) + + for i, layer in enumerate(self.layers): + x = layer(x) + + x = self.output_layer(x) + return 
dict(score=x) + + def get_chsize(self, imsize): + n = int(np.log2(self._max_imsize) - np.log2(imsize)) + mul = min(2 ** n, self._max_cnum_mul) + ch = self._cnum * mul + return int(ch) diff --git a/dp2/gan_trainer.py b/dp2/gan_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cdbcd86f596c7be984b5463e486aafe6a0890e --- /dev/null +++ b/dp2/gan_trainer.py @@ -0,0 +1,324 @@ +import atexit +from collections import defaultdict +import logging +import typing +import torch +import time +from dp2.utils import vis_utils +from dp2 import utils +from tops import logger, checkpointer +import tops +from easydict import EasyDict + + +def accumulate_gradients(params, fp16_ddp_accumulate): + if len(params) == 0: + return + params = [param for param in params if param.grad is not None] + flat = torch.cat([param.grad.flatten() for param in params]) + orig_dtype = flat.dtype + if tops.world_size() > 1: + if fp16_ddp_accumulate: + flat = flat.half() / tops.world_size() + else: + flat /= tops.world_size() + torch.distributed.all_reduce(flat) + flat = flat.to(orig_dtype) + grads = flat.split([param.numel() for param in params]) + for param, grad in zip(params, grads): + param.grad = grad.reshape(param.shape) + + +def accumulate_buffers(module: torch.nn.Module): + buffers = [buf for buf in module.buffers()] + if len(buffers) == 0: + return + flat = torch.cat([buf.flatten() for buf in buffers]) + if tops.world_size() > 1: + torch.distributed.all_reduce(flat) + flat /= tops.world_size() + bufs = flat.split([buf.numel() for buf in buffers]) + for old, new in zip(buffers, bufs): + old.copy_(new.reshape(old.shape), non_blocking=True) + + +def check_ddp_consistency(module): + if tops.world_size() == 1: + return + assert isinstance(module, torch.nn.Module) + assert isinstance(module, torch.nn.Module) + params_buffs = list(module.named_parameters()) + list(module.named_buffers()) + for name, tensor in params_buffs: + fullname = type(module).__name__ + '.' 
+ name + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = torch.nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +class AverageMeter(): + def __init__(self) -> None: + self.to_log = dict() + self.n = defaultdict(int) + pass + + @torch.no_grad() + def update(self, values: dict): + for key, value in values.items(): + self.n[key] += 1 + if key in self.to_log: + self.to_log[key] += value.mean().detach() + else: + self.to_log[key] = value.mean().detach() + + def get_average(self): + return {key: value / self.n[key] for key, value in self.to_log.items()} + + +class GANTrainer: + + def __init__( + self, + G: torch.nn.Module, + D: torch.nn.Module, + G_EMA: torch.nn.Module, + D_optim: torch.optim.Optimizer, + G_optim: torch.optim.Optimizer, + dl_train: typing.Iterator, + dl_val: typing.Iterable, + scaler_D: torch.cuda.amp.GradScaler, + scaler_G: torch.cuda.amp.GradScaler, + ims_per_log: int, + max_images_to_train: int, + loss_handler, + ims_per_val: int, + evaluate_fn, + batch_size: int, + broadcast_buffers: bool, + fp16_ddp_accumulate: bool, + save_state: bool, + *args, **kwargs): + super().__init__(*args, **kwargs) + + self.G = G + self.D = D + self.G_EMA = G_EMA + self.D_optim = D_optim + self.G_optim = G_optim + self.dl_train = dl_train + self.dl_val = dl_val + self.scaler_D = scaler_D + self.scaler_G = scaler_G + self.loss_handler = loss_handler + self.max_images_to_train = max_images_to_train + self.images_per_val = ims_per_val + self.images_per_log = ims_per_log + self.evaluate_fn = evaluate_fn + self.batch_size = batch_size + self.broadcast_buffers = broadcast_buffers + self.fp16_ddp_accumulate = fp16_ddp_accumulate + + self.train_state = EasyDict( + next_log_step=0, + next_val_step=ims_per_val, + total_time=0 + ) + + checkpointer.register_models(dict( + generator=G, discriminator=D, EMA_generator=G_EMA, + D_optimizer=D_optim, + G_optimizer=G_optim, + train_state=self.train_state, + scaler_D=self.scaler_D, + scaler_G=self.scaler_G + )) + if checkpointer.has_checkpoint(): + checkpointer.load_registered_models() + logger.log(f"Resuming training from: global step: {logger.global_step()}") + else: + logger.add_dict({ + "stats/discriminator_parameters": tops.num_parameters(self.D), + "stats/generator_parameters": tops.num_parameters(self.G), + }, commit=False) + if save_state: + # If the job is unexpectedly killed, there could be a mismatch between previously saved checkpoint and the current checkpoint. 
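+ # atexit handlers only run on normal interpreter shutdown (including unhandled exceptions);
+ # they do not run if the process is terminated with SIGKILL or os._exit, so an abrupt kill
+ # can still leave a stale checkpoint on disk.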
+ atexit.register(checkpointer.save_registered_models) + + self._ims_per_log = ims_per_log + + self.to_log = AverageMeter() + self.trainable_params_D = [param for param in self.D.parameters() if param.requires_grad] + self.trainable_params_G = [param for param in self.G.parameters() if param.requires_grad] + logger.add_dict({ + "stats/discriminator_trainable_parameters": sum(p.numel() for p in self.trainable_params_D), + "stats/generator_trainable_parameters": sum(p.numel() for p in self.trainable_params_G), + }, commit=False, level=logging.INFO) + check_ddp_consistency(self.D) + check_ddp_consistency(self.G) + check_ddp_consistency(self.G_EMA.generator) + + def train_loop(self): + self.log_time() + while logger.global_step() <= self.max_images_to_train: + batch = next(self.dl_train) + self.G_EMA.update_beta() + self.to_log.update(self.step_D(batch)) + self.to_log.update(self.step_G(batch)) + self.G_EMA.update(self.G) + + if logger.global_step() >= self.train_state.next_log_step: + to_log = {f"loss/{key}": item.item() for key, item in self.to_log.get_average().items()} + to_log.update({"amp/grad_scale_G": self.scaler_G.get_scale()}) + to_log.update({"amp/grad_scale_D": self.scaler_D.get_scale()}) + self.to_log = AverageMeter() + logger.add_dict(to_log, commit=True) + self.train_state.next_log_step += self.images_per_log + if self.scaler_D.get_scale() < 1e-8 or self.scaler_G.get_scale() < 1e-8: + print("Stopping training as gradient scale < 1e-8") + logger.log("Stopping training as gradient scale < 1e-8") + break + + if logger.global_step() >= self.train_state.next_val_step: + self.evaluate() + self.log_time() + self.save_images() + self.train_state.next_val_step += self.images_per_val + logger.step(self.batch_size*tops.world_size()) + logger.log(f"Reached end of training at step {logger.global_step()}.") + checkpointer.save_registered_models() + + def estimate_ims_per_hour(self): + batch = next(self.dl_train) + n_ims = int(100e3) + n_steps = int(n_ims / (self.batch_size * tops.world_size())) + n_ims = n_steps * self.batch_size * tops.world_size() + for i in range(10): # Warmup + self.G_EMA.update_beta() + self.step_D(batch) + self.step_G(batch) + self.G_EMA.update(self.G) + start_time = time.time() + for i in utils.tqdm_(list(range(n_steps))): + self.G_EMA.update_beta() + self.step_D(batch) + self.step_G(batch) + self.G_EMA.update(self.G) + total_time = time.time() - start_time + ims_per_sec = n_ims / total_time + ims_per_hour = ims_per_sec * 60*60 + ims_per_day = ims_per_hour * 24 + logger.log(f"Images per hour: {ims_per_hour/1e6:.3f}M") + logger.log(f"Images per day: {ims_per_day/1e6:.3f}M") + import math + ims_per_4_day = int(math.ceil(ims_per_day / tops.world_size() * 4)) + logger.log(f"Images per 4 days: {ims_per_4_day}") + logger.add_dict({ + "stats/ims_per_day": ims_per_day, + "stats/ims_per_4_day": ims_per_4_day + }) + + def log_time(self): + if not hasattr(self, "start_time"): + self.start_time = time.time() + self.last_time_step = logger.global_step() + return + n_images = logger.global_step() - self.last_time_step + if n_images == 0: + return + n_secs = time.time() - self.start_time + n_ims_per_sec = n_images / n_secs + training_time_hours = n_secs / 60/ 60 + self.train_state.total_time += training_time_hours + remaining_images = self.max_images_to_train - logger.global_step() + remaining_time = remaining_images / n_ims_per_sec / 60 / 60 + logger.add_dict({ + "stats/n_ims_per_sec": n_ims_per_sec, + "stats/total_traing_time_hours": self.train_state.total_time, + 
"stats/remaining_time_hours": remaining_time + }) + self.last_time_step = logger.global_step() + self.start_time = time.time() + + def save_images(self): + dl_val = iter(self.dl_val) + batch = next(dl_val) + # TRUNCATED visualization + ims_to_log = 8 + self.G_EMA.eval() + z = self.G.get_z(batch["img"]) + fakes_truncated = self.G_EMA.sample(**batch, truncation_value=0)["img"] + fakes_truncated = utils.denormalize_img(fakes_truncated).mul(255).byte()[:ims_to_log].cpu() + if "__key__" in batch: + batch.pop("__key__") + real = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] + to_vis = torch.cat((real, fakes_truncated)) + logger.add_images("images/truncated", to_vis, nrow=2) + + # Diverse images + ims_diverse = 3 + batch = next(dl_val) + to_vis = [] + + for i in range(ims_diverse): + z = self.G.get_z(batch["img"])[:1].repeat(batch["img"].shape[0], 1) + fakes = utils.denormalize_img(self.G_EMA(**batch, z=z)["img"]).mul(255).byte()[:ims_to_log].cpu() + to_vis.append(fakes) + if "__key__" in batch: + batch.pop("__key__") + reals = vis_utils.visualize_batch(**tops.to_cpu(batch))[:ims_to_log] + to_vis.insert(0, reals) + to_vis = torch.cat(to_vis) + logger.add_images("images/diverse", to_vis, nrow=ims_diverse+1) + + self.G_EMA.train() + pass + + def evaluate(self): + logger.log("Stating evaluation.") + self.G_EMA.eval() + try: + checkpointer.save_registered_models(max_keep=3) + except Exception: + logger.log("Could not save checkpoint.") + if self.broadcast_buffers: + check_ddp_consistency(self.G) + check_ddp_consistency(self.D) + metrics = self.evaluate_fn(generator=self.G_EMA, dataloader=self.dl_val) + metrics = {f"metrics/{k}": v for k,v in metrics.items()} + logger.add_dict(metrics, level=logger.logger.INFO) + + def step_D(self, batch): + utils.set_requires_grad(self.trainable_params_D, True) + utils.set_requires_grad(self.trainable_params_G, False) + tops.zero_grad(self.D) + loss, to_log = self.loss_handler.D_loss(batch, grad_scaler=self.scaler_D) + with torch.autograd.profiler.record_function("D_step"): + self.scaler_D.scale(loss).backward() + accumulate_gradients(self.trainable_params_D, fp16_ddp_accumulate=self.fp16_ddp_accumulate) + if self.broadcast_buffers: + accumulate_buffers(self.D) + accumulate_buffers(self.G) + # Step will not unscale if unscale is called previously. 
+ self.scaler_D.step(self.D_optim) + self.scaler_D.update() + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, False) + return to_log + + def step_G(self, batch): + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, True) + tops.zero_grad(self.G) + loss, to_log = self.loss_handler.G_loss(batch, grad_scaler=self.scaler_G) + with torch.autograd.profiler.record_function("G_step"): + self.scaler_G.scale(loss).backward() + accumulate_gradients(self.trainable_params_G, fp16_ddp_accumulate=self.fp16_ddp_accumulate) + if self.broadcast_buffers: + accumulate_buffers(self.G) + accumulate_buffers(self.D) + self.scaler_G.step(self.G_optim) + self.scaler_G.update() + utils.set_requires_grad(self.trainable_params_D, False) + utils.set_requires_grad(self.trainable_params_G, False) + return to_log \ No newline at end of file diff --git a/dp2/generator/__init__.py b/dp2/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dp2/generator/base.py b/dp2/generator/base.py new file mode 100644 index 0000000000000000000000000000000000000000..851f00ee2fed4f8e0601405ffd635338280ec993 --- /dev/null +++ b/dp2/generator/base.py @@ -0,0 +1,144 @@ +import torch +import numpy as np +import tqdm +import tops +from ..layers import Module +from ..layers.sg2_layers import FullyConnectedLayer +from dp2 import utils + + +class BaseGenerator(Module): + + def __init__(self, z_channels: int): + super().__init__() + self.z_channels = z_channels + self.latent_space = "Z" + + @torch.no_grad() + def get_z( + self, + x: torch.Tensor = None, + z: torch.Tensor = None, + truncation_value: float = None, + batch_size: int = None, + dtype=None, device=None) -> torch.Tensor: + """Generates a latent variable for generator. + """ + if z is not None: + return z + if x is not None: + batch_size = x.shape[0] + dtype = x.dtype + device = x.device + if device is None: + device = utils.get_device() + if truncation_value == 0: + return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype) + z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype) + if truncation_value is None: + return z + while z.abs().max() > truncation_value: + m = z.abs() > truncation_value + z[m] = torch.rand_like(z)[m] + return z + + def sample(self, truncation_value, z=None, **kwargs): + """ + Samples via interpolating to the mean (0). + """ + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + if z is None: + z = self.get_z(kwargs["condition"]) + z = z * truncation_value + return self.forward(**kwargs, z=z) + + + +class SG2StyleNet(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality. + w_dim, # Intermediate latent (W) dimensionality. + num_layers = 2, # Number of mapping layers. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training. + ): + super().__init__() + self.z_dim = z_dim + self.w_dim = w_dim + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + # Construct layers. 
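+ # Mapping network: z is pixel-normalized and passed through num_layers fully connected
+ # lrelu layers to produce w. w_avg tracks an EMA of w (decay w_avg_beta) and is used for
+ # truncation in BaseStyleGAN.sample below.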
+ features = [self.z_dim] + [self.w_dim] * self.num_layers + for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]): + layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, update_emas=False, y=None): + tops.assert_shape(z, [None, self.z_dim]) + + # Embed, normalize, and concatenate inputs. + x = z.to(torch.float32) + x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt() + # Execute layers. + for idx in range(self.num_layers): + x = getattr(self, f'fc{idx}')(x) + # Update moving average of W. + if update_emas: + self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}' + + def update_w(self, n=int(10e3), batch_size=32): + """ + Calculate w_ema over n iterations. + Useful in cases where w_ema is calculated incorrectly during training. + """ + n = n // batch_size + for i in tqdm.trange(n, desc="Updating w"): + z = torch.randn((batch_size, self.z_dim), device=tops.get_device()) + self(z, update_emas=True) + + +class BaseStyleGAN(BaseGenerator): + + def __init__(self, z_channels: int, w_dim: int): + super().__init__(z_channels) + self.style_net = SG2StyleNet(z_channels, w_dim) + self.latent_space = "W" + + def get_w(self, z, update_emas): + return self.style_net(z, update_emas=update_emas) + + @torch.no_grad() + def sample(self, truncation_value, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + def update_w(self, *args, **kwargs): + self.style_net.update_w(*args, **kwargs) + + + @torch.no_grad() + def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + if w_indices is None: + w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) + w_centers = self.style_net.w_centers[w_indices].to(w.device) + w = w_centers.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) diff --git a/dp2/generator/dummy_generators.py b/dp2/generator/dummy_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..c319e65c8191bf3f77d9d348698bff10c44e0b21 --- /dev/null +++ b/dp2/generator/dummy_generators.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +from .base import BaseGenerator + + +class PixelationGenerator(BaseGenerator): + + def __init__(self, pixelation_size, **kwargs): + super().__init__(z_channels=0) + self.pixelation_size = pixelation_size + self.z_channels = 0 + self.latent_space=None + + def forward(self, img, condition, mask, **kwargs): + old_shape = img.shape[-2:] + img = nn.functional.interpolate(img, size=(self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True) + img = nn.functional.interpolate(img, size=old_shape, mode="bilinear", align_corners=True) + out = img*(1-mask) + condition*mask + return {"img": out} + + +class MaskOutGenerator(BaseGenerator): + + def __init__(self, 
noise: str, **kwargs): + super().__init__(z_channels=0) + self.noise = noise + self.z_channels = 0 + assert self.noise in ["rand", "constant"] + self.latent_space = None + + def forward(self, img, condition, mask, **kwargs): + + if self.noise == "constant": + img = torch.zeros_like(img) + elif self.noise == "rand": + img = torch.rand_like(img) + out = img*(1-mask) + condition*mask + return {"img": out} + + +class IdentityGenerator(BaseGenerator): + + def __init__(self): + super().__init__(z_channels=0) + + def forward(self, img, condition, mask, **kwargs): + return dict(img=img) \ No newline at end of file diff --git a/dp2/generator/imagen3_old.py b/dp2/generator/imagen3_old.py new file mode 100644 index 0000000000000000000000000000000000000000..87f35d7595dd018f1d0db2baddcfeb3789596ed0 --- /dev/null +++ b/dp2/generator/imagen3_old.py @@ -0,0 +1,1210 @@ +# What is missing from this implementation +# 1. Global context in res block +# 2. Cross attention of conditional information in resnet block +# +from functools import partial +import tops +from tops.config import instantiate +import warnings +from typing import Iterable, List, Tuple +import numpy as np +import torch +import torch.nn as nn +from torch import einsum +from einops import rearrange +from dp2 import infer, utils +from .base import BaseGenerator +from sg3_torch_utils.ops import bias_act +from dp2.layers import Sequential +import torch.nn.functional as F +from torchvision.transforms.functional import resize, InterpolationMode +from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d + + + + +class Upfirdn2d(torch.nn.Module): + + + def __init__(self, down=1, up=1, fix_gain=True): + super().__init__() + self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1])) + fw, fh = upfirdn2d._get_filter_size(self.resample_filter) + px0, px1, py0, py1 = upfirdn2d._parse_padding(0) + self.down = down + self.up = up + if up > 1: + px0 += (fw + up - 1) // 2 + px1 += (fw - up) // 2 + py0 += (fh + up - 1) // 2 + py1 += (fh - up) // 2 + if down > 1: + px0 += (fw - down + 1) // 2 + px1 += (fw - down) // 2 + py0 += (fh - down + 1) // 2 + py1 += (fh - down) // 2 + self.padding = [px0,px1,py0,py1] + self.gain = up**2 if fix_gain else 1 + + def forward(self, x, *args): + if isinstance(x, dict): + x = {k: v for k, v in x.items()} + x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) + return x + x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) + if len(args) == 0: + return x + return (x, *args) +@torch.no_grad() +def spatial_embed_keypoints(keypoints: torch.Tensor, x): + tops.assert_shape(keypoints, (None, None, 3)) + B, N_K, _ = keypoints.shape + H, W = x.shape[-2:] + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32) + x, y, visible = keypoints.chunk(3, dim=2) + x = (x * W).round().long().clamp(0, W-1) + y = (y * H).round().long().clamp(0, H-1) + kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) + pos = (kp_idx*(H*W) + y*W + x + 1) + # Offset all by 1 to index invisible keypoints as 0 + pos = (pos * visible.round().long()).squeeze(dim=-1) + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) + keypoint_spatial.scatter_(1, pos, 1) + keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) + return keypoint_spatial + + 
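+# Shape sketch for spatial_embed_keypoints (illustrative values, not taken from the repo):
+#   keypoints: [B, N_K, 3] with normalized (x, y, visible) entries in [0, 1]
+#   x:         any tensor whose last two dims give the target spatial size (H, W)
+#   returns:   [B, N_K, H, W] one-hot maps; rows for invisible keypoints remain all zero
+# e.g. spatial_embed_keypoints(torch.rand(2, 17, 3), torch.zeros(2, 3, 64, 64)) -> [2, 17, 64, 64]
+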
+def modulated_conv2d( + x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. + weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. + styles, # Modulation coefficients of shape [batch_size, in_channels]. + noise = None, # Optional noise tensor to add to the output activations. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + padding = 0, # Padding with respect to the upsampled image. + resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). + demodulate = True, # Apply weight demodulation? + flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). + fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? +): + batch_size = x.shape[0] + out_channels, in_channels, kh, kw = weight.shape + tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] + tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] + tops.assert_shape(styles, [batch_size, in_channels]) # [NI] + + # Pre-normalize inputs to avoid FP16 overflow. + if x.dtype == torch.float16 and demodulate: + weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk + styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I + + # Calculate per-sample weights and demodulation coefficients. + w = None + dcoefs = None + if demodulate or fused_modconv: + w = weight.unsqueeze(0) # [NOIkk] + w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] + if demodulate: + dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] + if demodulate and fused_modconv: + w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] + + # Execute by scaling the activations before and after the convolution. + if not fused_modconv: + x = x * styles.reshape(batch_size, -1, 1, 1) + x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) + if demodulate and noise is not None: + x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) + elif demodulate: + x = x * dcoefs.reshape(batch_size, -1, 1, 1) + elif noise is not None: + x = x.add_(noise.to(x.dtype)) + return x + + with tops.suppress_tracer_warnings(): # this value will be treated as a constant + batch_size = int(batch_size) + # Execute as one fused op using grouped convolution. 
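+ # The fused path folds the batch into the channel dimension and runs a single conv2d with
+ # groups=batch_size, so every sample is convolved with its own modulated weights.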
+ tops.assert_shape(x, [batch_size, in_channels, None, None]) + x = x.reshape(1, -1, *x.shape[2:]) + w = w.reshape(-1, in_channels, kh, kw) + x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) + x = x.reshape(batch_size, -1, *x.shape[2:]) + if noise is not None: + x = x.add_(noise) + return x + + +class Identity(nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, stable=False): + super().__init__() + self.stable = stable + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + if self.stable: + x = x / x.amax(dim=-1, keepdim=True).detach() + + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim=-1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=-1, keepdim=True) + return (x - mean) * (var + eps).rsqrt() * self.g + + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.repr = dict( + in_features=in_features, out_features=out_features, bias=bias, + activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + self.in_features = in_features + self.out_features = out_features + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None: + if self.bias_gain != 1: + b = b * self.bias_gain + x = F.linear(x, w) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + + +def checkpoint_fn(fn, *args, **kwargs): + warnings.simplefilter("ignore") + return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) + +class Conv2d(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + activation='lrelu', + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
+ bias=True, + norm=None, + lr_multiplier=1, + bias_init=0, + w_dim=None, + gradient_checkpoint_norm=False, + gain=1, + ): + super().__init__() + self.fused_modconv = False + if norm == torch.nn.InstanceNorm2d: + self.norm = torch.nn.InstanceNorm2d(None) + elif isinstance(norm, torch.nn.Module): + self.norm = norm + elif norm == "fused_modconv": + self.fused_modconv = True + elif norm: + self.norm = torch.nn.InstanceNorm2d(None) + elif norm is not None: + raise ValueError(f"norm not supported: {norm}") + self.activation = activation + self.conv_clamp = conv_clamp + self.out_channels = out_channels + self.in_channels = in_channels + self.padding = kernel_size // 2 + self.repr = dict( + in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, + activation=activation, conv_clamp=conv_clamp, bias=bias, + fused_modconv=self.fused_modconv + ) + self.act_gain = bias_act.activation_funcs[activation].def_gain * gain + self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None + self.bias_gain = lr_multiplier + if w_dim is not None: + if self.fused_modconv: + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + else: + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) + self.gradient_checkpoint_norm = gradient_checkpoint_norm + + def forward(self, x, w=None, gain=1, **kwargs): + if self.fused_modconv: + styles = self.affine(w) + with torch.cuda.amp.autocast(enabled=False): + x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None, + padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype) + else: + if hasattr(self, "affine"): + gamma = self.affine(w).view(-1, self.in_channels, 1, 1) + beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) + x = fma.fma(x, gamma ,beta) + w = self.weight * self.weight_gain + x = F.conv2d(input=x, weight=w, padding=self.padding,) + + if hasattr(self, "norm"): + if self.gradient_checkpoint_norm: + x = checkpoint_fn(self.norm, x) + else: + x = self.norm(x) + act_gain = self.act_gain * gain + act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None + b = self.bias * self.bias_gain if self.bias is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + context_dim, + dim_head=64, + heads=8, + norm_context=False, + ): + super().__init__() + self.scale = dim_head ** -0.5 + + self.heads = heads + inner_dim = dim_head * heads + + self.norm = nn.InstanceNorm1d(dim) + self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity() + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim, bias=False), + nn.InstanceNorm1d(None) + ) + + def forward(self, x, w): + x = self.norm(x) + w = self.norm_context(w) + + q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1)) + q = rearrange(q, "b n (h d) -> b h n d", h = self.heads) + k = rearrange(k, "b n (h d) -> b h n d", h = self.heads) + v = rearrange(v, "b n (h 
d) -> b h n d", h = self.heads) + q = q * self.scale + # similarities + sim = einsum('b h i d, b h j d -> b h i j', q, k) + attn = sim.softmax(dim = -1, dtype = torch.float32) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class SG2ResidualBlock(torch.nn.Module): + def __init__( + self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. + skip_gain=np.sqrt(.5), + cross_attention: bool = False, + cross_attention_len: int = None, + use_adain: bool = True, + **layer_kwargs, # Arguments for conv layer. + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None + if use_adain: + layer_kwargs["w_dim"] = w_dim + + self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs) + self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain) + + self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain) + if cross_attention and w_dim is not None: + self.cross_attention_len = cross_attention_len + self.cross_attn = CrossAttention( + dim=out_channels, context_dim=w_dim//self.cross_attention_len, + gain=skip_gain) + + def forward(self, x, w=None, **layer_kwargs): + y = self.skip(x) + x = self.conv0(x, w, **layer_kwargs) + x = self.conv1(x, w, **layer_kwargs) + if hasattr(self, "cross_attn"): + h = x.shape[-2] + x = rearrange(x, "b c h w -> b (h w) c") + w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len) + x = self.cross_attn(x, w=w) + x + x = rearrange(x, "b (h w) c -> b c h w", h=h) + return y + x + + +def default(val, d): + if val is not None: + return val + return d() if callable(d) else d + + +def cast_tuple(val, length=None): + if isinstance(val, Iterable) and not isinstance(val, str): + val = tuple(val) + output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) + if length is not None: + assert len(output) == length, (output, length) + return output + + +class Attention(nn.Module): + # This is a version of Multi-Query Attention () + # Fast Transformer Decoding: One Write-Head is All You Need + # Ablated in: https://arxiv.org/pdf/2203.07814.pdf + # and https://arxiv.org/pdf/2204.02311.pdf + def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None): + super().__init__() + self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0 + self.cosine_sim_attn = cosine_sim_attn + self.cosine_sim_scale = 16 if cosine_sim_attn else 1 + self.gradient_checkpoint = gradient_checkpoint + self.heads = heads + self.dim = dim + self.fix_attention_again = fix_attention_again + inner_dim = dim_head * heads + if norm == "LN": + self.norm = LayerNorm(dim) + elif norm == "IN": + self.norm = nn.InstanceNorm1d(dim) + elif norm is None: + self.norm = nn.Identity() + else: + raise ValueError(f"Norm not supported: {norm}") + + self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False) + self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False) + + self.to_out = nn.Sequential( + FullyConnectedLayer(inner_dim, dim, bias=False), + LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim) + ) + if fix_attention_again: + assert gain is not None + self.gain = gain + else: + 
self.gain = np.sqrt(.5) if attn_fix_gain else 1 + + def run_function(self, x, attn_bias): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + in_ = x + b, n, device = *x.shape[:2], x.device + x = self.norm(x) + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) + q = q * self.scale + + # calculate query / key similarities + sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale + + if attn_bias is not None: + attn_bias = attn_bias + attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)") + sim = sim + attn_bias + + attn = sim.softmax(dim=-1) + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + out = rearrange(out, "b h n d -> b n (h d)") + if self.fix_attention_again: + out = self.to_out(out)*self.gain + in_ + else: + out = (self.to_out(out) + in_) * self.gain + out = rearrange(out, "b (h w) c -> b c h w", h=h) + return out + + def forward(self, x, *args, attn_bias=None, **kwargs): + if self.gradient_checkpoint: + return checkpoint_fn(self.run_function, x, attn_bias) + return self.run_function(x, attn_bias) + + def get_attention(self, x, attn_bias=None): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + in_ = x + b, n, device = *x.shape[:2], x.device + x = self.norm(x) + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) + q = q * self.scale + + # calculate query / key similarities + sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale + + if attn_bias is not None: + attn_bias = attn_bias + attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)") + sim = sim + attn_bias + + attn = sim.softmax(dim=-1) + return attn, None + + +class BiasedAttention(Attention): + + def __init__(self, *args, head_wise: bool=True, **kwargs): + super().__init__(*args, **kwargs) + out_ch = self.heads if head_wise else 1 + self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0) + nn.init.zeros_(self.conv.weight.data) + + def forward(self, x, mask): + mask = resize(mask, size=x.shape[-2:]) + bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) + return super().forward(x=x, attn_bias=bias) + + def get_attention(self, x, mask): + mask = resize(mask, size=x.shape[-2:]) + bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) + return super().get_attention(x, bias)[0], bias + +class UNet(BaseGenerator): + + def __init__( + self, + im_channels: int, + dim: int, + dim_mults: tuple, + num_resnet_blocks, # Number of resnet blocks per resolution + n_middle_blocks: int, + z_channels: int, + conv_clamp: int, + layer_attn, + w_dim: int, + norm_enc: bool, + norm_dec: str, + stylenet: nn.Module, + enc_style: bool, # Toggle style injection in encoder + use_maskrcnn_mask: bool, + skip_all_unets: bool, + fix_resize:bool, + comodulate: bool, + comod_net: nn.Module, + lr_comod: float, + dec_style: bool, + input_keypoints: bool, + n_keypoints: int, + input_keypoint_indices: Tuple[int], + use_adain: bool, + cross_attention: bool, + cross_attention_len: int, + gradient_checkpoint_norm: bool, + attn_cls: partial, + mask_out_train: bool, + fix_gain_again: bool, + ) -> None: + super().__init__(z_channels) + self.enc_style = enc_style + self.n_keypoints = n_keypoints + self.input_keypoint_indices = list(input_keypoint_indices) + self.input_keypoints = input_keypoints + self.mask_out_train = mask_out_train + n_layers = len(dim_mults) + self.n_layers = n_layers + layer_attn = cast_tuple(layer_attn, 
n_layers) + num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers) + self._cnum = dim + self._image_channels = im_channels + self._z_channels = z_channels + encoder_layers = [] + condition_ch = im_channels + self.from_rgb = Conv2d( + condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices) + , dim, 7) + + self.use_maskrcnn_mask = use_maskrcnn_mask + self.skip_all_unets = skip_all_unets + dims = [dim*m for m in dim_mults] + enc_blk = partial( + SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc, + use_adain=use_adain and self.enc_style, + w_dim=w_dim, + cross_attention=cross_attention, + cross_attention_len=cross_attention_len, + gradient_checkpoint_norm=gradient_checkpoint_norm + ) + dec_blk = partial( + SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec, + use_adain=use_adain and dec_style, + w_dim=w_dim, + cross_attention=cross_attention, + cross_attention_len=cross_attention_len, + gradient_checkpoint_norm=gradient_checkpoint_norm + ) + # Currently up/down sampling is done by bilinear upsampling. + # This can be simplified by replacing it with a strided upsampling layer... + self.encoder_attns = nn.ModuleList() + for lidx in range(n_layers): + gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5) + dim_in = dims[lidx] + dim_out = dims[min(lidx+1, n_layers-1)] + res_blocks = nn.ModuleList() + for i in range(num_resnet_blocks[lidx]): + is_last = num_resnet_blocks[lidx] - 1 == i + cur_dim = dim_out if is_last else dim_in + block = enc_blk(dim_in, cur_dim, skip_gain=gain) + res_blocks.append(block) + if layer_attn[lidx]: + self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) + else: + self.encoder_attns.append(Identity()) + encoder_layers.append(res_blocks) + self.encoder = torch.nn.ModuleList(encoder_layers) + + # initialize decoder + decoder_layers = [] + self.unet_layers = torch.nn.ModuleList() + self.decoder_attns = torch.nn.ModuleList() + for lidx in range(n_layers): + dim_in = dims[min(-lidx, -1)] + dim_out = dims[-1-lidx] + res_blocks = nn.ModuleList() + unet_skips = nn.ModuleList() + for i in range(num_resnet_blocks[-lidx-1]): + is_first = i == 0 + has_unet = is_first or skip_all_unets + is_last = i == num_resnet_blocks[-lidx-1] - 1 + cur_dim = dim_in if is_first else dim_out + if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn + gain = np.sqrt(1/4) + elif has_unet: # x + residual + unet + gain = np.sqrt(1/3) + elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention + gain = np.sqrt(1/3) + else: # x + residual + gain = np.sqrt(1/2) # Only residual block + block = dec_blk(cur_dim, dim_out, skip_gain=gain) + res_blocks.append(block) + if has_unet: + unet_block = Conv2d( + cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp, + norm=nn.InstanceNorm2d(None), + gradient_checkpoint_norm=gradient_checkpoint_norm, + gain=gain) + unet_skips.append(unet_block) + else: + unet_skips.append(torch.nn.Identity()) + if layer_attn[-lidx-1]: + self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) + else: + self.decoder_attns.append(Identity()) + + decoder_layers.append(res_blocks) + self.unet_layers.append(unet_skips) + + middle_blocks = [] + for i in range(n_middle_blocks): + block = dec_blk(dims[-1], dims[-1]) + middle_blocks.append(block) + if n_middle_blocks != 0: + self.middle_blocks = Sequential(*middle_blocks) + self.decoder = torch.nn.ModuleList(decoder_layers) + 
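+ # skip_gain is set to sqrt(1/k), where k counts the terms summed at that point (input,
+ # residual, unet skip, layer attention), so that the sum keeps roughly unit variance.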
self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp) + self.stylenet = stylenet + self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize) + self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize) + self.comodulate = comodulate + if comodulate: + assert not self.enc_style + self.to_y = nn.Sequential( + Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm), + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod) + ) + self.comod_net = comod_net + + + def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs): + if z is None: + z = self.get_z(condition) + if w is None: + w = self.stylenet(z, update_emas=update_emas) + if self.use_maskrcnn_mask: + x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) + else: + x = torch.cat((condition, mask, 1-mask), dim=1) + + if self.input_keypoints: + keypoints = keypoints[:, self.input_keypoint_indices] + one_hot_pose = spatial_embed_keypoints(keypoints, x) + x = torch.cat((x, one_hot_pose), dim=1) + x = self.from_rgb(x) + x, unet_features = self.forward_enc(x, mask, w) + x, decoder_features = self.forward_dec(x, mask, w, unet_features) + x = self.to_rgb(x) + unmasked = x + if self.mask_out_train: + x = mask * condition + (1-mask) * x + out = dict(img=x, unmasked=unmasked) + if return_decoder_features: + out["decoder_features"] = decoder_features + return out + + def forward_enc(self, x, mask, w): + unet_features = [] + for i, res_blocks in enumerate(self.encoder): + is_last = i == len(self.encoder) - 1 + for block in res_blocks: + x = block(x, w=w) + unet_features.append(x) + x = self.encoder_attns[i](x, mask=mask) + if not is_last: + x = self.downsample(x) + if self.comodulate: + y = self.to_y(x) + y = torch.cat((w, y), dim=-1) + w = self.comod_net(y) + return x, unet_features + + def forward_dec(self, x, mask, w, unet_features): + if hasattr(self, "middle_blocks"): + x = self.middle_blocks(x, w=w) + features = [] + unet_features = iter(reversed(unet_features)) + for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): + is_last = i == len(self.decoder) - 1 + for skip, block in zip(unet_skip, res_blocks): + skip_x = next(unet_features) + if not isinstance(skip, torch.nn.Identity): + skip_x = skip(skip_x) + x = x + skip_x + x = block(x, w=w) + x = self.decoder_attns[i](x, mask=mask) + features.append(x) + if not is_last: + x = self.upsample(x) + return x, features + + def get_w(self, z, update_emas): + return self.stylenet(z, update_emas=update_emas) + + @torch.no_grad() + def sample(self, truncation_value, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + def update_w(self, *args, **kwargs): + self.style_net.update_w(*args, **kwargs) + + @property + def style_net(self): + return self.stylenet + + @torch.no_grad() + def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + 
if w_indices is None: + w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) + w_centers = self.style_net.w_centers[w_indices].to(w.device) + w = w_centers.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + +def get_stem_unet_kwargs(cfg): + if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs + return get_stem_unet_kwargs(cfg.generator.stem_cfg) + return dict(cfg.generator) + + +class GrowingUnet(BaseGenerator): + + def __init__( + self, + coarse_stem_cfg: str, # This can be a coarse generator or None + sr_cfg: str, # Can be a previous progressive u-net, Unet or None + residual: bool, + new_dataset: bool, # The "new dataset" creates condition first -> resizes + **unet_kwargs): + kwargs = dict() + if coarse_stem_cfg is not None: + coarse_stem_cfg = utils.load_config(coarse_stem_cfg) + kwargs = get_stem_unet_kwargs(coarse_stem_cfg) + if sr_cfg is not None: + sr_cfg = utils.load_config(sr_cfg) + sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg) + kwargs.update(sr_stem_unet_kwargs) + kwargs.update(unet_kwargs) + kwargs["stylenet"] = None + kwargs.pop("_target_") + if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net + del kwargs["sr_cfg"] + if "coarse_stem_cfg" in kwargs: + del kwargs["coarse_stem_cfg"] + super().__init__(z_channels=kwargs["z_channels"]) + if coarse_stem_cfg is not None: + z_channels = coarse_stem_cfg.generator.z_channels + super().__init__(z_channels) + self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() + self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) + utils.set_requires_grad(self.coarse_stem, False) + else: + assert not residual + + if sr_cfg is not None: + self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval() + del self.sr_stem.from_rgb + del self.sr_stem.to_rgb + if hasattr(self.sr_stem, "coarse_stem"): + del self.sr_stem.coarse_stem + if isinstance(self.sr_stem, UNet): + del self.sr_stem.encoder[0][0] # Delete first residual block + del self.sr_stem.decoder[-1][-1] # Delete last residual block + else: + assert isinstance(self.sr_stem, GrowingUnet) + del self.sr_stem.unet.encoder[0][0] # Delete first residual block + del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block + utils.set_requires_grad(self.sr_stem, False) + + + args = kwargs.pop("_args_") + if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr + n_layers = len(kwargs["dim_mults"]) + dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"])) + kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)] + kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False] + kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1] + self.unet = UNet( + *args, + **kwargs + ) + self.from_rgb = self.unet.from_rgb + self.to_rgb = self.unet.to_rgb + self.residual = residual + self.new_dataset = new_dataset + if residual: + nn.init.zeros_(self.to_rgb.weight.data) + del self.unet.from_rgb, self.unet.to_rgb + + def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs): + # Downsample for stem + if z is None: + z = self.get_z(img) + if w is None: + w = self.style_net(z) + if hasattr(self, "coarse_stem"): + with torch.no_grad(): + if self.new_dataset: + img_stem = utils.denormalize_img(img)*255 + condition_stem = img_stem * mask + (1-mask)*127 + condition_stem = 
condition_stem.round() + condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) + condition_stem = condition_stem / 255 *2 - 1 + mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() + maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() + else: + mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float() + maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float() + img_stem = utils.denormalize_img(img)*255 + img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round() + img_stem = img_stem / 255 * 2 - 1 + condition_stem = img_stem * mask_stem + stem_out = self.coarse_stem( + condition=condition_stem, mask=mask_stem, + maskrcnn_mask=maskrcnn_stem, w=w, + keypoints=keypoints) + x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) + condition = condition*mask + (1-mask) * x_lr + if self.unet.use_maskrcnn_mask: + x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) + else: + x = torch.cat((condition, mask, 1-mask), dim=1) + if self.unet.input_keypoints: + keypoints = keypoints[:, self.unet.input_keypoint_indices] + one_hot_pose = spatial_embed_keypoints(keypoints, x) + x = torch.cat((x, one_hot_pose), dim=1) + x = self.from_rgb(x) + x, unet_features = self.forward_enc(x, mask, w) + x = self.forward_dec(x, mask, w, unet_features) + if self.residual: + x = self.to_rgb(x) + condition + else: + x = self.to_rgb(x) + return dict( + img=condition * mask + (1-mask) * x, + unmasked=x, + x_lowres=[condition] + ) + + def forward_enc(self, x, mask, w): + x, unet_features = self.unet.forward_enc(x, mask, w) + if hasattr(self, "sr_stem"): + x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w) + else: + unet_features_stem = None + return x, [unet_features, unet_features_stem] + + def forward_dec(self, x, mask, w, unet_features): + unet_features, unet_features_stem = unet_features + if hasattr(self, "sr_stem"): + x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem) + x, unet_features = self.unet.forward_dec(x, mask, w, unet_features) + return x + + def get_z(self, *args, **kwargs): + if hasattr(self, "coarse_stem"): + return self.coarse_stem.get_z(*args, **kwargs) + if hasattr(self, "sr_stem"): + return self.sr_stem.get_z(*args, **kwargs) + raise AttributeError() + + @property + def style_net(self): + if hasattr(self, "coarse_stem"): + return self.coarse_stem.style_net + if hasattr(self, "sr_stem"): + return self.sr_stem.style_net + raise AttributeError() + + def update_w(self, *args, **kwargs): + self.style_net.update_w(*args, **kwargs) + + def get_w(self, z, update_emas): + return self.style_net(z, update_emas=update_emas) + + @torch.no_grad() + def sample(self, truncation_value, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + @torch.no_grad() + def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + if 
w_indices is None: + w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) + w_centers = self.style_net.w_centers[w_indices].to(w.device) + w = w_centers.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + +class CascadedUnet(BaseGenerator): + + def __init__( + self, + coarse_stem_cfg: str, # This can be a coarse generator or None + residual: bool, + new_dataset: bool, # The "new dataset" creates condition first -> resizes + imsize: tuple, + cascade: bool, + **unet_kwargs): + kwargs = dict() + coarse_stem_cfg = utils.load_config(coarse_stem_cfg) + kwargs = get_stem_unet_kwargs(coarse_stem_cfg) + kwargs.update(unet_kwargs) + super().__init__(z_channels=kwargs["z_channels"]) + + self.input_keypoints = kwargs["input_keypoints"] + self.input_keypoint_indices = kwargs["input_keypoint_indices"] + self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"] + self.imsize = imsize + self.residual = residual + self.new_dataset = new_dataset + + + # Setup coarse stem + stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults] + self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() + self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) + utils.set_requires_grad(self.coarse_stem, False) + + # Maps stem resolution -> stem decoder layer index (used for the cascade skip connections) + self.stem_res_to_layer_idx = { + self.coarse_stem.imsize[0] // 2**i: i + for i in range(len(stem_dims)) + } + + dim = kwargs["dim"] + dim_mults = kwargs["dim_mults"] + n_layers = len(dim_mults) + dims = [dim*s for s in dim_mults] + layer_attn = cast_tuple(kwargs["layer_attn"], n_layers) + num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers) + attn_cls = kwargs["attn_cls"] + if not isinstance(attn_cls, partial): + attn_cls = instantiate(attn_cls) + + dec_blk = partial( + SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"], + use_adain=kwargs["use_adain"] and kwargs["dec_style"], + w_dim=kwargs["w_dim"], + cross_attention=kwargs["cross_attention"], + cross_attention_len=kwargs["cross_attention_len"], + gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] + ) + enc_blk = partial( + SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"], + use_adain=kwargs["use_adain"] and kwargs["enc_style"], + w_dim=kwargs["w_dim"], + cross_attention=kwargs["cross_attention"], + cross_attention_len=kwargs["cross_attention_len"], + gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] + ) + + # Currently up/down sampling is done by bilinear upsampling. + # This can be simplified by replacing it with a strided upsampling layer...
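+ # Illustrative example (hypothetical values, not from any shipped config): a coarse stem trained at imsize (128, 64) with dim_mults (1, 2, 4) gives stem_res_to_layer_idx == {128: 0, 64: 1, 32: 2}, so an encoder level whose feature-map height matches one of these resolutions reads the corresponding stem decoder feature through a 1x1 Conv2d skip defined below.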
+ self.encoder_attns = nn.ModuleList() + self.encoder_unet_skips = nn.ModuleDict() + self.encoder = nn.ModuleList() + for lidx in range(n_layers): + has_stem_feature = imsize[0]//2**lidx in self.stem_res_to_layer_idx and cascade + next_layer_has_stem_features = lidx+1 < n_layers and imsize[0]//2**(lidx+1) in self.stem_res_to_layer_idx and cascade + + dim_in = dims[lidx] + dim_out = dims[min(lidx+1, n_layers-1)] + res_blocks = nn.ModuleList() + if has_stem_feature: + prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1] + stem_lidx = self.stem_res_to_layer_idx[imsize[0]//2**lidx] + self.encoder_unet_skips.add_module( + str(imsize[0]//2**lidx), + Conv2d( + stem_dims[stem_lidx], dim_in, kernel_size=1, + conv_clamp=kwargs["conv_clamp"], + norm=nn.InstanceNorm2d(None), + gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], + gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention + ) + ) + for i in range(num_resnet_blocks[lidx]): + is_last = num_resnet_blocks[lidx] - 1 == i + cur_dim = dim_out if is_last else dim_in + if not is_last: + gain = np.sqrt(.5) + elif next_layer_has_stem_features and layer_attn[lidx]: + gain = np.sqrt(1/4) + elif layer_attn[lidx] or next_layer_has_stem_features: + gain = np.sqrt(1/3) + else: + gain = np.sqrt(.5) + block = enc_blk(dim_in, cur_dim, skip_gain=gain) + res_blocks.append(block) + if layer_attn[lidx]: + self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True)) + else: + self.encoder_attns.append(Identity()) + self.encoder.append(res_blocks) + + # initialize decoder + self.decoder = torch.nn.ModuleList() + self.unet_layers = torch.nn.ModuleList() + self.decoder_attns = torch.nn.ModuleList() + for lidx in range(n_layers): + dim_in = dims[min(-lidx, -1)] + dim_out = dims[-1-lidx] + res_blocks = nn.ModuleList() + unet_skips = nn.ModuleList() + for i in range(num_resnet_blocks[-lidx-1]): + is_first = i == 0 + has_unet = is_first or kwargs["skip_all_unets"] + is_last = i == num_resnet_blocks[-lidx-1] - 1 + cur_dim = dim_in if is_first else dim_out + if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn + gain = np.sqrt(1/4) + elif has_unet: # x + residual + unet + gain = np.sqrt(1/3) + elif layer_attn[-lidx-1]: # x + residual + attention + gain = np.sqrt(1/3) + else: # x + residual + gain = np.sqrt(1/2) # Only residual block + block = dec_blk(cur_dim, dim_out, skip_gain=gain) + res_blocks.append(block) + if kwargs["skip_all_unets"] or is_first: + unet_block = Conv2d( + cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"], + norm=nn.InstanceNorm2d(None), + gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], + gain=gain) + unet_skips.append(unet_block) + else: + unet_skips.append(torch.nn.Identity()) + if layer_attn[-lidx-1]: + self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain)) + else: + self.decoder_attns.append(Identity()) + + self.decoder.append(res_blocks) + self.unet_layers.append(unet_skips) + + self.from_rgb = Conv2d( + 3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"]) + , dim, 7) + self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"]) + + self.downsample = Upfirdn2d(down=2, fix_gain=True) + self.upsample = Upfirdn2d(up=2, fix_gain=True) + self.cascade = cascade + if residual: + nn.init.zeros_(self.to_rgb.weight.data) + + def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, 
keypoints=None, return_decoder_features=False, **kwargs): + # Downsample for stem + if z is None: + z = self.get_z(img) + + with torch.no_grad(): # Forward pass stem + if w is None: + w = self.style_net(z) + img_stem = utils.denormalize_img(img)*255 + condition_stem = img_stem * mask + (1-mask)*127 + condition_stem = condition_stem.round() + condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) + condition_stem = condition_stem / 255 * 2 - 1 + mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() + maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() + stem_out = self.coarse_stem( + condition=condition_stem, mask=mask_stem, + maskrcnn_mask=maskrcnn_stem, w=w, + keypoints=keypoints, + return_decoder_features=True) + stem_features = stem_out["decoder_features"] + x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) + condition = condition*mask + (1-mask) * x_lr + + if self.use_maskrcnn_mask: + x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) + else: + x = torch.cat((condition, mask, 1-mask), dim=1) + if self.input_keypoints: + keypoints = keypoints[:, self.input_keypoint_indices] + one_hot_pose = spatial_embed_keypoints(keypoints, x) + x = torch.cat((x, one_hot_pose), dim=1) + x = self.from_rgb(x) + x, unet_features = self.forward_enc(x, mask, w, stem_features) + x, decoder_features = self.forward_dec(x, mask, w, unet_features) + if self.residual: + x = self.to_rgb(x) + condition + else: + x = self.to_rgb(x) + out = dict( + img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ?? + unmasked=x, + x_lowres=[condition] + ) + if return_decoder_features: + out["decoder_features"] = decoder_features + return out + + def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]): + unet_features = [] + stem_features.reverse() + for i, res_blocks in enumerate(self.encoder): + is_last = i == len(self.encoder) - 1 + res = self.imsize[0]//2**i + if str(res) in self.encoder_unet_skips.keys() and self.cascade: + y = stem_features[self.stem_res_to_layer_idx[res]] + y = self.encoder_unet_skips[str(res)](y) # ModuleDict is keyed by the resolution string + x = y + x + for block in res_blocks: + x = block(x, w=w) + unet_features.append(x) + x = self.encoder_attns[i](x, mask) + if not is_last: + x = self.downsample(x) + return x, unet_features + + def forward_dec(self, x, mask, w, unet_features): + features = [] + unet_features = iter(reversed(unet_features)) + for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): + is_last = i == len(self.decoder) - 1 + for skip, block in zip(unet_skip, res_blocks): + skip_x = next(unet_features) + if not isinstance(skip, torch.nn.Identity): + skip_x = skip(skip_x) + x = x + skip_x + x = block(x, w=w) + x = self.decoder_attns[i](x, mask) + features.append(x) + if not is_last: + x = self.upsample(x) + return x, features + + def get_z(self, *args, **kwargs): + return self.coarse_stem.get_z(*args, **kwargs) + + @property + def style_net(self): + return self.coarse_stem.style_net + + def update_w(self, *args, **kwargs): + self.style_net.update_w(*args, **kwargs) + + def get_w(self, z, update_emas): + return self.style_net(z, update_emas=update_emas) + + @torch.no_grad() + def sample(self, truncation_value, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = 
min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) + + @torch.no_grad() + def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): + if truncation_value is None: + return self.forward(**kwargs) + truncation_value = max(0, truncation_value) + truncation_value = min(truncation_value, 1) + w = self.get_w(self.get_z(kwargs["condition"]), False) + if w_indices is None: + w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) + w_centers = self.style_net.w_centers[w_indices].to(w.device) + w = w_centers.to(w.dtype).lerp(w, truncation_value) + return self.forward(**kwargs, w=w) diff --git a/dp2/generator/stylegan_unet.py b/dp2/generator/stylegan_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..68c7f7706d601e4ae2eb80a4b3fd03bc127b2164 --- /dev/null +++ b/dp2/generator/stylegan_unet.py @@ -0,0 +1,208 @@ +import torch +import numpy as np +from dp2.layers import Sequential +from dp2.layers.sg2_layers import Conv2d, FullyConnectedLayer, ResidualBlock +from .base import BaseStyleGAN +from typing import List, Tuple +from .utils import spatial_embed_keypoints, mask_output + + +def get_chsize(imsize, cnum, max_imsize, max_cnum_mul): + n = int(np.log2(max_imsize) - np.log2(imsize)) + mul = min(2**n, max_cnum_mul) + ch = cnum * mul + return int(ch) + +class StyleGANUnet(BaseStyleGAN): + def __init__( + self, + scale_grad: bool, + im_channels: int, + min_fmap_resolution: int, + imsize: List[int], + cnum: int, + max_cnum_mul: int, + mask_output: bool, + conv_clamp: int, + input_cse: bool, + cse_nc: int, + n_middle_blocks: int, + input_keypoints: bool, + n_keypoints: int, + input_keypoint_indices: Tuple[int], + fix_errors: bool, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.n_keypoints = n_keypoints + self.input_keypoint_indices = list(input_keypoint_indices) + self.input_keypoints = input_keypoints + assert not (input_cse and input_keypoints) + cse_nc = 0 if cse_nc is None else cse_nc + self.imsize = imsize + self._cnum = cnum + self._max_cnum_mul = max_cnum_mul + self._min_fmap_resolution = min_fmap_resolution + self._image_channels = im_channels + self._max_imsize = max(imsize) + self.input_cse = input_cse + self.gain_unet = np.sqrt(1/3) + n_levels = int(np.log2(self._max_imsize) - np.log2(min_fmap_resolution))+1 + encoder_layers = [] + self.from_rgb = Conv2d( + im_channels + 1 + input_cse*(cse_nc+1) + input_keypoints*len(self.input_keypoint_indices), + cnum, 1 + ) + for i in range(n_levels): # Encoder layers + resolution = [x//2**i for x in imsize] + in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + second_ch = in_ch + out_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) + down = 2 + + if i == 0: # first (lowest) block. 
Downsampling is performed at the start of the block + down = 1 + if i == n_levels - 1: + out_ch = second_ch + block = ResidualBlock(in_ch, out_ch, down=down, conv_clamp=conv_clamp, fix_residual=fix_errors) + encoder_layers.append(block) + self._encoder_out_shape = [ + get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul), + *resolution] + + self.encoder = torch.nn.ModuleList(encoder_layers) + + # initialize decoder + decoder_layers = [] + for i in range(n_levels): + resolution = [x//2**(n_levels-1-i) for x in imsize] + in_ch = get_chsize(max(resolution)//2, cnum, self._max_imsize, max_cnum_mul) + out_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + if i == 0: # first (lowest) block + in_ch = get_chsize(max(resolution), cnum, self._max_imsize, max_cnum_mul) + + up = 1 + if i != n_levels - 1: + up = 2 + block = ResidualBlock( + in_ch, out_ch, conv_clamp=conv_clamp, gain_out=np.sqrt(1/3), + w_dim=self.style_net.w_dim, norm=True, up=up, + fix_residual=fix_errors + ) + decoder_layers.append(block) + if i != 0: + unet_block = Conv2d( + in_ch, in_ch, kernel_size=1, conv_clamp=conv_clamp, norm=True, + gain=np.sqrt(1/3) if fix_errors else np.sqrt(.5)) + setattr(self, f"unet_block{i}", unet_block) + + # Initialize "middle blocks" that do not have down/up sample + middle_blocks = [] + for i in range(n_middle_blocks): + ch = get_chsize(min_fmap_resolution, cnum, self._max_imsize, max_cnum_mul) + block = ResidualBlock( + ch, ch, conv_clamp=conv_clamp, gain_out=np.sqrt(.5) if fix_errors else np.sqrt(1/3), + w_dim=self.style_net.w_dim, norm=True, + ) + middle_blocks.append(block) + if n_middle_blocks != 0: + self.middle_blocks = Sequential(*middle_blocks) + self.decoder = torch.nn.ModuleList(decoder_layers) + self.to_rgb = Conv2d(cnum, im_channels, 1, activation="linear", conv_clamp=conv_clamp) + # Initialize "middle blocks" that do not have down/up sample + self.decoder = torch.nn.ModuleList(decoder_layers) + self.scale_grad = scale_grad + self.mask_output = mask_output + + def forward_dec(self, x, w, unet_features, condition, mask, s, **kwargs): + for i, layer in enumerate(self.decoder): + if i != 0: + unet_layer = getattr(self, f"unet_block{i}") + x = x + unet_layer(unet_features[-i]) + x = layer(x, w=w, s=s) + x = self.to_rgb(x) + if self.mask_output: + x = mask_output(True, condition, x, mask) + return dict(img=x) + + def forward_enc(self, condition, mask, embedding, keypoints, E_mask, **kwargs): + if self.input_cse: + x = torch.cat((condition, mask, embedding, E_mask), dim=1) + else: + x = torch.cat((condition, mask), dim=1) + if self.input_keypoints: + keypoints = keypoints[:, self.input_keypoint_indices] + one_hot_pose = spatial_embed_keypoints(keypoints, x) + x = torch.cat((x, one_hot_pose), dim=1) + x = self.from_rgb(x) + + unet_features = [] + for i, layer in enumerate(self.encoder): + x = layer(x) + if i != len(self.encoder)-1: + unet_features.append(x) + if hasattr(self, "middle_blocks"): + for layer in self.middle_blocks: + x = layer(x) + return x, unet_features + + def forward( + self, condition, mask, + z=None, embedding=None, w=None, update_emas=False, x=None, + s=None, + keypoints=None, + unet_features=None, + E_mask=None, + **kwargs): + # Used to skip sampling from encoder in inference. E.g. for w projection. 
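+ # Illustrative (w projection sketch): run x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask) once, then call forward(condition, mask, w=w, x=x, unet_features=unet_features) repeatedly while optimizing w, skipping the encoder on every step.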
+ if x is not None and unet_features is not None: + assert not self.training + else: + x, unet_features = self.forward_enc(condition, mask, embedding, keypoints, E_mask, **kwargs) + if w is None: + if z is None: + z = self.get_z(condition) + w = self.get_w(z, update_emas=update_emas) + return self.forward_dec(x, w, unet_features, condition, mask, s, **kwargs) + +class ComodStyleUNet(StyleGANUnet): + + def __init__(self, min_comod_res=4, lr_multiplier_comod=1, **kwargs) -> None: + super().__init__(**kwargs) + min_fmap = min(self._encoder_out_shape[1:]) + enc_out_ch = self._encoder_out_shape[0] + n_down = int(np.ceil(np.log2(min_fmap) - np.log2(min_comod_res))) + comod_layers = [] + in_ch = enc_out_ch + for i in range(n_down): + comod_layers.append(Conv2d(enc_out_ch, 256, kernel_size=3, down=2, lr_multiplier=lr_multiplier_comod)) + in_ch = 256 + if n_down == 0: + comod_layers = [Conv2d(in_ch, 256, kernel_size=3)] + comod_layers.append(torch.nn.Flatten()) + out_res = [x//2**n_down for x in self._encoder_out_shape[1:]] + in_ch_fc = np.prod(out_res) * 256 + comod_layers.append(FullyConnectedLayer(in_ch_fc, 512, lr_multiplier=lr_multiplier_comod)) + self.comod_block = Sequential(*comod_layers) + self.comod_fc = FullyConnectedLayer(512+self.style_net.w_dim, self.style_net.w_dim, lr_multiplier=lr_multiplier_comod) + + def forward_dec(self, x, w, unet_features, condition, mask, **kwargs): + y = self.comod_block(x) + y = torch.cat((y, w), dim=1) + y = self.comod_fc(y) + for i, layer in enumerate(self.decoder): + if i != 0: + unet_layer = getattr(self, f"unet_block{i}") + x = x + unet_layer(unet_features[-i], gain=np.sqrt(.5)) + x = layer(x, w=y) + x = self.to_rgb(x) + if self.mask_output: + x = mask_output(True, condition, x, mask) + return dict(img=x) + + def get_comod_y(self, batch, w): + x, unet_features = self.forward_enc(**batch) + y = self.comod_block(x) + y = torch.cat((y, w), dim=1) + y = self.comod_fc(y) + return y diff --git a/dp2/generator/utils.py b/dp2/generator/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5732b2c511a42f4bffd4b512244cf790bec96ef0 --- /dev/null +++ b/dp2/generator/utils.py @@ -0,0 +1,48 @@ +import torch +import tops +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +@torch.no_grad() +def spatial_embed_keypoints(keypoints: torch.Tensor, x): + tops.assert_shape(keypoints, (None, None, 3)) + B, N_K, _ = keypoints.shape + H, W = x.shape[-2:] + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32) + x, y, visible = keypoints.chunk(3, dim=2) + x = (x * W).round().long().clamp(0, W-1) + y = (y * H).round().long().clamp(0, H-1) + kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) + pos = (kp_idx*(H*W) + y*W + x + 1) + # Offset all by 1 to index invisible keypoints as 0 + pos = (pos * visible.round().long()).squeeze(dim=-1) + keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) + keypoint_spatial.scatter_(1, pos, 1) + keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) + return keypoint_spatial + +class MaskOutput(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x_real, x_fake, mask): + ctx.save_for_backward(mask) + out = x_real * mask + (1-mask) * x_fake + return out + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + fake_grad = grad_output + mask, = ctx.saved_tensors + fake_grad = fake_grad * (1 - mask) + 
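# The division below rescales the gradient by the per-sample fraction of generated (mask == 0) pixels, so samples where only a small region is inpainted receive gradients of magnitude comparable to samples with large holes. +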
known_percentage = mask.view(mask.shape[0], -1).mean(dim=1) + fake_grad = fake_grad / (1-known_percentage).view(-1, 1, 1, 1) + return None, fake_grad, None + + +def mask_output(scale_grad, x_real, x_fake, mask): + if scale_grad: + return MaskOutput.apply(x_real, x_fake, mask) + return x_real * mask + (1-mask) * x_fake diff --git a/dp2/infer.py b/dp2/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac05047c0c7ced2b6de2fd816ed9f15808b781f --- /dev/null +++ b/dp2/infer.py @@ -0,0 +1,72 @@ +import tops +import torch +from tops import checkpointer +from tops.config import instantiate +from tops.logger import warn + +def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None): + state = ckpt["EMA_generator"] if "EMA_generator" in ckpt else ckpt["running_average_generator"] + if ckpt_mapper is not None: + state = ckpt_mapper(state) + load_state_dict(G, state) + tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M") + print(ckpt.keys()) + if "w_centers" in ckpt: + print("Has w_centers!") + G.style_net.w_centers = ckpt["w_centers"] + tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}") + + +def build_trained_generator(cfg, map_location=None): + map_location = map_location if map_location is not None else tops.get_device() + G = instantiate(cfg.generator).to(map_location) + G.eval() + G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None + if hasattr(cfg, "ckpt_mapper"): + ckpt_mapper = instantiate(cfg.ckpt_mapper) + else: + ckpt_mapper = None + if "model_url" in cfg.common: + ckpt = tops.load_file_or_url(cfg.common.model_url, md5sum=cfg.common.model_md5sum) + load_generator_state(ckpt, G, ckpt_mapper) + return G + try: + ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") + load_generator_state(ckpt, G, ckpt_mapper) + except FileNotFoundError as e: + tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}") + return G + + +def build_trained_discriminator(cfg, map_location=None): + map_location = map_location if map_location is not None else tops.get_device() + D = instantiate(cfg.discriminator).to(map_location) + D.eval() + try: + ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu") + if hasattr(cfg, "ckpt_mapper_D"): + ckpt["discriminator"] = instantiate(cfg.ckpt_mapper_D)(ckpt["discriminator"]) + D.load_state_dict(ckpt["discriminator"]) + except FileNotFoundError as e: + tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}") + return D + + +def load_state_dict(module: torch.nn.Module, state_dict: dict): + module_sd = module.state_dict() + to_remove = [] + for key, item in state_dict.items(): + if key not in module_sd: + continue + if item.shape != module_sd[key].shape: + to_remove.append(key) + warn(f"Incorrect shape. 
Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}") + for key in to_remove: + state_dict.pop(key) + for key, item in state_dict.items(): + if key not in module_sd: + warn(f"Did not find key in model state dict: {key}") + for key, item in module_sd.items(): + if key not in state_dict: + warn(f"Did not find key in state dict: {key}") + module.load_state_dict(state_dict, strict=False) \ No newline at end of file diff --git a/dp2/layers/__init__.py b/dp2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f54bed21f71c7facaa8129e1b689f16c5f9d13d --- /dev/null +++ b/dp2/layers/__init__.py @@ -0,0 +1,20 @@ +from typing import Dict +import torch +import tops +import torch.nn as nn + +class Sequential(nn.Sequential): + + def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: + for module in self: + x = module(x, **kwargs) + return x + +class Module(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def extra_repr(self): + num_params = tops.num_parameters(self) / 10**6 + return f"Num params: {num_params:.3f}M" diff --git a/dp2/layers/sg2_layers.py b/dp2/layers/sg2_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3aac03935d0cb6ec172f2cd0aabefbdc74d0141e --- /dev/null +++ b/dp2/layers/sg2_layers.py @@ -0,0 +1,227 @@ +from typing import List +import numpy as np +import torch +import tops +import torch.nn.functional as F +from sg3_torch_utils.ops import conv2d_resample +from sg3_torch_utils.ops import upfirdn2d +from sg3_torch_utils.ops import bias_act +from sg3_torch_utils.ops.fma import fma + + +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.repr = dict( + in_features=in_features, out_features=out_features, bias=bias, + activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + self.in_features = in_features + self.out_features = out_features + + def forward(self, x): + w = self.weight * self.weight_gain + b = self.bias + if b is not None and self.bias_gain != 1: + b = b * self.bias_gain + x = F.linear(x, w) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + +class Conv2d(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + out_channels, # Number of output channels. + kernel_size = 3, # Convolution kernel size. + up = 1, # Integer upsampling factor. + down = 1, # Integer downsampling factor. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
+ bias = True, + norm = False, + lr_multiplier=1, + bias_init=0, + w_dim=None, + gain=1, + ): + super().__init__() + if norm: + self.norm = torch.nn.InstanceNorm2d(None) + assert norm in [True, False] + self.up = up + self.down = down + self.activation = activation + self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain + self.out_channels = out_channels + self.in_channels = in_channels + self.padding = kernel_size // 2 + + self.repr = dict( + in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, up=up, down=down, + activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias, + ) + + if self.up == 1 and self.down == 1: + self.resample_filter = None + else: + self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter)) + + self.act_gain = bias_act.activation_funcs[activation].def_gain * gain + self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) + self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) + self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None + self.bias_gain = lr_multiplier + if w_dim is not None: + self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) + self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) + + def forward(self, x, w=None, s=None): + tops.assert_shape(x, [None, self.weight.shape[1], None, None]) + if s is not None: + s = s[..., :self.in_channels*2] + gamma, beta = s.view(-1, self.in_channels*2, 1, 1).chunk(2, dim=1) + x = fma(x, gamma, beta) + elif hasattr(self, "affine"): + gamma = self.affine(w).view(-1, self.in_channels, 1, 1) + beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) + x = fma(x, gamma, beta) + w = self.weight * self.weight_gain + # Removing flip weight is not safe. + x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up, self.down, self.padding, flip_weight=self.up==1) + if hasattr(self, "norm"): + x = self.norm(x) + b = self.bias * self.bias_gain if self.bias is not None else None + x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp) + return x + + def extra_repr(self) -> str: + return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) + + +class Block(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + up = 1, + down = 1, + **layer_kwargs, # Arguments for SynthesisLayer. + ): + super().__init__() + self.in_channels = in_channels + self.down = down + self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) + self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs) + + def forward(self, x, **layer_kwargs): + x = self.conv0(x, **layer_kwargs) + x = self.conv1(x, **layer_kwargs) + return x + + +class ResidualBlock(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels, 0 = first block. + out_channels, # Number of output channels. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + up = 1, + down = 1, + gain_out=np.sqrt(0.5), + fix_residual: bool = False, + **layer_kwargs, # Arguments for conv layer. 
+ ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.down = down + self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs) + + self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out,**layer_kwargs) + + self.skip = Conv2d( + in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down, + activation="linear" if fix_residual else "lrelu", + gain=gain_out + ) + self.gain_out = gain_out + + def forward(self, x, w=None, s=None, **layer_kwargs): + y = self.skip(x) + s_ = next(s) if s is not None else None + x = self.conv0(x, w, s=s_, **layer_kwargs) + s_ = next(s) if s is not None else None + x = self.conv1(x, w, s=s_, **layer_kwargs) + x = y + x + return x + + +class MinibatchStdLayer(torch.nn.Module): + def __init__(self, group_size, num_channels=1): + super().__init__() + self.group_size = group_size + self.num_channels = num_channels + + def forward(self, x): + N, C, H, W = x.shape + with tops.suppress_tracer_warnings(): # as_tensor results are registered as constants + G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N + F = self.num_channels + c = C // F + + y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. + y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. + y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. + y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. + y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels. + y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. + y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. + x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels. + return x + +#---------------------------------------------------------------------------- + +class DiscriminatorEpilogue(torch.nn.Module): + def __init__(self, + in_channels, # Number of input channels. + resolution: List[int], # Resolution of this block. + mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch. + mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping. + ): + super().__init__() + self.in_channels = in_channels + self.resolution = resolution + self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None + self.conv = Conv2d( + in_channels + mbstd_num_channels, in_channels, + kernel_size=3, activation=activation, conv_clamp=conv_clamp) + self.fc = FullyConnectedLayer(in_channels * resolution[0]*resolution[1], in_channels, activation=activation) + self.out = FullyConnectedLayer(in_channels, 1) + + def forward(self, x): + tops.assert_shape(x, [None, self.in_channels, *self.resolution]) # [NCHW] + # Main layers. 
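+ # The minibatch-std layer (when enabled) appends cross-sample stddev statistics as extra channel(s); the conv + flatten + fc head then reduces the block to a single realism score per sample.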
+ if self.mbstd is not None: + x = self.mbstd(x) + x = self.conv(x) + x = self.fc(x.flatten(1)) + x = self.out(x) + return x diff --git a/dp2/loss/__init__.py b/dp2/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79165e22915aeeebabbcddebcccb281deedb40c2 --- /dev/null +++ b/dp2/loss/__init__.py @@ -0,0 +1 @@ +from .sg2_loss import StyleGAN2Loss \ No newline at end of file diff --git a/dp2/loss/pl_regularization.py b/dp2/loss/pl_regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..3a557a8b92cbcb2572e8f25633e7c89e996ef906 --- /dev/null +++ b/dp2/loss/pl_regularization.py @@ -0,0 +1,48 @@ +import torch +import tops +import numpy as np +from sg3_torch_utils.ops import conv2d_gradfix + +pl_mean_total = torch.zeros([]) + +class PLRegularization: + + def __init__(self, weight: float, batch_shrink: int, pl_decay:float, scale_by_mask: bool,**kwargs): + self.pl_mean = torch.zeros([], device=tops.get_device()) + self.pl_weight = weight + self.batch_shrink = batch_shrink + self.pl_decay = pl_decay + self.scale_by_mask = scale_by_mask + + def __call__(self, G, batch, grad_scaler): + batch_size = batch["img"].shape[0] // self.batch_shrink + batch = {k: v[:batch_size] for k, v in batch.items() if k != "embed_map"} + if "embed_map" in batch: + batch["embed_map"] = batch["embed_map"] + z = G.get_z(batch["img"]) + + with torch.cuda.amp.autocast(tops.AMP()): + gen_ws = G.style_net(z) + gen_img = G(**batch, w=gen_ws)["img"].float() + pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) + with conv2d_gradfix.no_weight_gradients(): + # Sums over HWC + pl_grads = torch.autograd.grad( + outputs=[grad_scaler.scale(gen_img * pl_noise)], + inputs=[gen_ws], + create_graph=True, + grad_outputs=torch.ones_like(gen_img), + only_inputs=True)[0] + + pl_grads = pl_grads.float() / grad_scaler.get_scale() + if self.scale_by_mask: + # Percentage of pixels known + scaling = batch["mask"].flatten(start_dim=1).mean(dim=1).view(-1, 1) + pl_grads = pl_grads / scaling + pl_lengths = pl_grads.square().sum(1).sqrt() + pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) + if not torch.isnan(pl_mean).any(): + self.pl_mean.copy_(pl_mean.detach()) + pl_penalty = (pl_lengths - pl_mean).square() + to_log = dict(pl_penalty=pl_penalty.mean().detach()) + return pl_penalty.view(-1) * self.pl_weight, to_log \ No newline at end of file diff --git a/dp2/loss/r1_regularization.py b/dp2/loss/r1_regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..2098d792a591e4652259703085abe9d0bc55489b --- /dev/null +++ b/dp2/loss/r1_regularization.py @@ -0,0 +1,31 @@ +import torch +import tops + +def r1_regularization( + real_img, real_score, mask, lambd: float, lazy_reg_interval: int, + lazy_regularization: bool, + scaler: torch.cuda.amp.GradScaler, mask_out: bool, + mask_out_scale: bool, + **kwargs + ): + grad = torch.autograd.grad( + outputs=scaler.scale(real_score), + inputs=real_img, + grad_outputs=torch.ones_like(real_score), + create_graph=True, + only_inputs=True, + )[0] + inv_scale = 1.0 / scaler.get_scale() + grad = grad * inv_scale + with torch.cuda.amp.autocast(tops.AMP()): + if mask_out: + grad = grad * (1 - mask) + grad = grad.square().sum(dim=[1, 2, 3]) + if mask_out and mask_out_scale: + total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3] + n_fake = (1-mask).sum(dim=[1, 2, 3]) + scaling = total_pixels / n_fake + grad = grad * scaling + if lazy_regularization: + lambd_ = lambd * 
lazy_reg_interval / 2 # From stylegan2, lazy regularization + return grad * lambd_, grad.detach() \ No newline at end of file diff --git a/dp2/loss/sg2_loss.py b/dp2/loss/sg2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0d1150a66bbff4d7f408d9b68612adac179cdb --- /dev/null +++ b/dp2/loss/sg2_loss.py @@ -0,0 +1,94 @@ +import functools +import torch +import tops +from tops import logger +from dp2.utils import forward_D_fake +from .utils import nsgan_d_loss, nsgan_g_loss +from .r1_regularization import r1_regularization +from .pl_regularization import PLRegularization + +class StyleGAN2Loss: + + def __init__( + self, + D, + G, + r1_opts: dict, + EP_lambd: float, + lazy_reg_interval: int, + lazy_regularization: bool, + pl_reg_opts: dict, + ) -> None: + self.gradient_step_D = 0 + self._lazy_reg_interval = lazy_reg_interval + self.D = D + self.G = G + self.EP_lambd = EP_lambd + self.lazy_regularization = lazy_regularization + self.r1_reg = functools.partial( + r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval, + lazy_regularization=lazy_regularization) + self.do_PL_Reg = False + if pl_reg_opts.weight > 0: + self.pl_reg = PLRegularization(**pl_reg_opts) + self.do_PL_Reg = True + self.pl_start_nimg = pl_reg_opts.start_nimg + + def D_loss(self, batch: dict, grad_scaler): + to_log = {} + # Forward through G and D + do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0 + if do_GP: + batch["img"] = batch["img"].detach().requires_grad_(True) + with torch.cuda.amp.autocast(enabled=tops.AMP()): + with torch.no_grad(): + G_fake = self.G(**batch, update_emas=True) + D_out_real = self.D(**batch) + + D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) + + # Non saturating loss + nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"]) + tops.assert_shape(nsgan_loss, (batch["img"].shape[0], )) + to_log["d_loss"] = nsgan_loss.mean() + total_loss = nsgan_loss + epsilon_penalty = D_out_real["score"].pow(2).view(-1) + to_log["epsilon_penalty"] = epsilon_penalty.mean() + tops.assert_shape(epsilon_penalty, total_loss.shape) + total_loss = total_loss + epsilon_penalty * self.EP_lambd + + # Improved gradient penalty with lazy regularization + # Gradient penalty applies specialized autocast. 
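+ # Note: do_GP is only True every lazy_reg_interval steps; r1_regularization compensates by scaling the penalty weight with the interval (lambd * lazy_reg_interval / 2), keeping the effective regularization strength roughly constant.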
+ if do_GP: + gradient_pen, grad_unscaled = self.r1_reg(batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler) + to_log["r1_gradient_penalty"] = grad_unscaled.mean() + tops.assert_shape(gradient_pen, total_loss.shape) + total_loss = total_loss + gradient_pen + + batch["img"] = batch["img"].detach().requires_grad_(False) + if "score" in D_out_real: + to_log["real_scores"] = D_out_real["score"] + to_log["real_logits_sign"] = D_out_real["score"].sign() + to_log["fake_logits_sign"] = D_out_fake["score"].sign() + to_log["fake_scores"] = D_out_fake["score"] + to_log = {key: item.mean().detach() for key, item in to_log.items()} + self.gradient_step_D += 1 + return total_loss.mean(), to_log + + def G_loss(self, batch: dict, grad_scaler): + with torch.cuda.amp.autocast(enabled=tops.AMP()): + to_log = {} + # Forward through G and D + G_fake = self.G(**batch) + D_out_fake = forward_D_fake(batch, G_fake["img"], self.D) + # Adversarial Loss + total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1) + to_log["g_loss"] = total_loss.mean() + tops.assert_shape(total_loss, (batch["img"].shape[0], )) + + if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg: + pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler) + total_loss = total_loss + pl_reg.mean() + to_log.update(to_log_) + to_log = {key: item.mean().detach() for key, item in to_log.items()} + return total_loss.mean(), to_log diff --git a/dp2/loss/utils.py b/dp2/loss/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9fe96156f40ffcc08ad84652d50327263db2fd --- /dev/null +++ b/dp2/loss/utils.py @@ -0,0 +1,25 @@ +import torch +import torch.nn.functional as F + +def nsgan_g_loss(fake_score): + """ + Non-saturating criterion from Goodfellow et al. 2014 + """ + return torch.nn.functional.softplus(-fake_score) + + +def nsgan_d_loss(real_score, fake_score): + """ + Non-saturating criterion from Goodfellow et al. 
2014 + """ + d_loss = F.softplus(-real_score) + F.softplus(fake_score) + return d_loss.view(-1) + + +def smooth_masked_l1_loss(x, target, mask): + """ + Pixel-wise l1 loss for the area indicated by mask + """ + # Beta=.1 <-> square loss if pixel difference <= 12.8 + l1 = F.smooth_l1_loss(x*mask, target*mask, beta=.1, reduction="none").sum(dim=[1,2,3]) / mask.sum(dim=[1, 2, 3]) + return l1 diff --git a/dp2/metrics/__init__.py b/dp2/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc224b42cc5ceeaf2e68d6371ab9da1dade557ed --- /dev/null +++ b/dp2/metrics/__init__.py @@ -0,0 +1,3 @@ +from .torch_metrics import compute_metrics_iteratively +from .fid import compute_fid +from .ppl import calculate_ppl \ No newline at end of file diff --git a/dp2/metrics/fid.py b/dp2/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb5e0060d60294c4cdf80254e583cf8fad8bc2 --- /dev/null +++ b/dp2/metrics/fid.py @@ -0,0 +1,72 @@ +import tops +from dp2 import utils +from pathlib import Path +from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper +import torch +import torch_fidelity + + +class GeneratorIteratorWrapper(GenerativeModelModuleWrapper): + + def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int): + if isinstance(generator, utils.EMA): + generator = generator.generator + z_size = generator.z_channels + super().__init__(generator, z_size, "normal", 0) + self.zero_z = zero_z + self.dataloader = iter(dataloader) + self.n_diverse = n_diverse + self.cur_div_idx = 0 + + @torch.no_grad() + def forward(self, z, **kwargs): + if self.cur_div_idx == 0: + self.batch = next(self.dataloader) + if self.zero_z: + z = z.zero_() + self.cur_div_idx += 1 + self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx + with torch.cuda.amp.autocast(enabled=tops.AMP()): + img = self.module(**self.batch)["img"] + img = (utils.denormalize_img(img)*255).byte() + return img + + +def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse): + generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse) + batch_size = dataloader.batch_size + num_samples = (n_source * n_diverse) // batch_size * batch_size + assert n_diverse >= 1 + assert (not zero_z) or n_diverse == 1 + assert num_samples % batch_size == 0 + assert n_source <= batch_size * len(dataloader), (batch_size*len(dataloader), n_source, n_diverse) + metrics = torch_fidelity.calculate_metrics( + input1=generator, + input2=real_directory, + cuda=torch.cuda.is_available(), + fid=True, + input2_cache_name="_".join(Path(real_directory).parts) + "_cached", + input1_model_num_samples=int(num_samples), + batch_size=dataloader.batch_size + ) + return metrics["frechet_inception_distance"] + + +if __name__ == "__main__": + import click + from dp2.config import Config + from dp2.data import build_dataloader_val + from dp2.infer import build_trained_generator + @click.command() + @click.argument("config_path") + @click.option("--n_source", default=200, type=int) + @click.option("--n_diverse", default=5, type=int) + @click.option("--zero_z", default=False, is_flag=True) + def run(config_path, n_source: int, n_diverse: int, zero_z: bool): + cfg = Config.fromfile(config_path) + dataloader = build_dataloader_val(cfg) + generator, _ = build_trained_generator(cfg) + print(compute_fid( + generator, dataloader, cfg.fid_real_directory, n_source, zero_z, n_diverse)) + + run() \ No newline at end of file diff --git 
a/dp2/metrics/fid_clip.py b/dp2/metrics/fid_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..6712fd48503b787bc8e2197a6a737bcd73546b35 --- /dev/null +++ b/dp2/metrics/fid_clip.py @@ -0,0 +1,84 @@ +import pickle +import torch +import torchvision +from pathlib import Path +from dp2 import utils +import tops +try: + import clip +except ImportError: + print("Could not import clip.") +from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric +clip_model = None +clip_preprocess = None + + +@torch.no_grad() +def compute_fid_clip( + dataloader, generator, + cache_directory, + data_len=None, + **kwargs + ) -> dict: + """ + FID CLIP following the description in The Role of ImageNet Classes in Frechet Inception Distance, Thomas Kynkaamniemi et al. + Args: + n_samples (int): Creates N samples from same image to calculate stats + """ + global clip_model, clip_preprocess + if clip_model is None: + clip_model, preprocess = clip.load("ViT-B/32", device="cpu") + normalize_fn = preprocess.transforms[-1] + img_mean = normalize_fn.mean + img_std = normalize_fn.std + clip_model = tops.to_cuda(clip_model.visual) + clip_preprocess = tops.to_cuda(torch.nn.Sequential( + torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC), + torchvision.transforms.Normalize(img_mean, img_std) + )) + cache_directory = Path(cache_directory) + if data_len is None: + data_len = len(dataloader)*dataloader.batch_size + fid_cache_path = cache_directory.joinpath("fid_stats_clip.pkl") + has_fid_cache = fid_cache_path.is_file() + if not has_fid_cache: + fid_features_real = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) + fid_features_fake = torch.zeros(data_len, 512, dtype=torch.float32, device=tops.get_device()) + eidx = 0 + n_samples_seen = 0 + for batch in utils.tqdm_(iter(dataloader), desc="Computing FID CLIP."): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes = generator(**batch)["img"] + real_data = batch["img"] + fakes = utils.denormalize_img(fakes) + real_data = utils.denormalize_img(real_data) + if not has_fid_cache: + real_data = clip_preprocess(real_data) + fid_features_real[sidx:eidx] = clip_model(real_data) + fakes = clip_preprocess(fakes) + fid_features_fake[sidx:eidx] = clip_model(fakes) + fid_features_fake = fid_features_fake[:n_samples_seen] + fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() + if has_fid_cache: + if tops.rank() == 0: + with open(fid_cache_path, "rb") as fp: + fid_stat_real = pickle.load(fp) + else: + fid_features_real = fid_features_real[:n_samples_seen] + fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() + assert fid_features_real.shape == fid_features_fake.shape + if tops.rank() == 0: + fid_stat_real = fid_features_to_statistics(fid_features_real) + cache_directory.mkdir(exist_ok=True, parents=True) + with open(fid_cache_path, "wb") as fp: + pickle.dump(fid_stat_real, fp) + + if tops.rank() == 0: + print("Starting calculation of fid from features of shape:", fid_features_fake.shape) + fid_stat_fake = fid_features_to_statistics(fid_features_fake) + fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] + return dict(fid_clip=fid_) + return dict(fid_clip=-1) diff --git a/dp2/metrics/lpips.py b/dp2/metrics/lpips.py new file mode 100644 index 
0000000000000000000000000000000000000000..cfda315da83fab7bb1bf5e8d1b09f3a721ded064 --- /dev/null +++ b/dp2/metrics/lpips.py @@ -0,0 +1,76 @@ +import torch +import tops +import sys +from contextlib import redirect_stdout +from torch_fidelity.sample_similarity_lpips import NetLinLayer, URL_VGG16_LPIPS, VGG16features, normalize_tensor, spatial_average + +class SampleSimilarityLPIPS(torch.nn.Module): + SUPPORTED_DTYPES = { + 'uint8': torch.uint8, + 'float32': torch.float32, + } + + def __init__(self): + + super().__init__() + self.chns = [64, 128, 256, 512, 512] + self.L = len(self.chns) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=True) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=True) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=True) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=True) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=True) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + with redirect_stdout(sys.stderr): + fp = tops.download_file(URL_VGG16_LPIPS) + state_dict = torch.load(fp, map_location="cpu") + self.load_state_dict(state_dict) + self.net = VGG16features() + self.eval() + for param in self.parameters(): + param.requires_grad = False + mean_rescaled = (1 + torch.tensor([-.030, -.088, -.188]).view(1, 3, 1, 1)) * 255 / 2 + inv_std_rescaled = 2 / (torch.tensor([.458, .448, .450]).view(1, 3, 1, 1) * 255) + self.register_buffer("mean", mean_rescaled) + self.register_buffer("std", inv_std_rescaled) + + def normalize(self, x): + # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] + x = (x.float() - self.mean) * self.std + return x + + @staticmethod + def resize(x, size): + if x.shape[-1] > size and x.shape[-2] > size: + x = torch.nn.functional.interpolate(x, (size, size), mode='area') + else: + x = torch.nn.functional.interpolate(x, (size, size), mode='bilinear', align_corners=False) + return x + + def lpips_from_feats(self, feats0, feats1): + diffs = {} + for kk in range(self.L): + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)] + val = sum(res) + return val + + def get_feats(self, x): + assert x.dim() == 4 and x.shape[1] == 3, 'Input 0 is not Bx3xHxW' + if x.shape[-2] < 16: # Resize images < 16x16 + f = 16 / x.shape[-2] + size = tuple([int(f*_) for _ in x.shape[-2:]]) + x = torch.nn.functional.interpolate(x, size=size, mode="bilinear", align_corners=False) + in0_input = self.normalize(x) + outs0 = self.net.forward(in0_input) + + feats = {} + for kk in range(self.L): + feats[kk] = normalize_tensor(outs0[kk]) + return feats + + def forward(self, in0, in1): + feats0 = self.get_feats(in0) + feats1 = self.get_feats(in1) + return self.lpips_from_feats(feats0, feats1), feats0, feats1 diff --git a/dp2/metrics/ppl.py b/dp2/metrics/ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..3d30b220546bcd8e44c36eede56361868e26879a --- /dev/null +++ b/dp2/metrics/ppl.py @@ -0,0 +1,110 @@ +import numpy as np +import torch +import tops +from dp2 import utils +from torch_fidelity.helpers import get_kwarg, vassert +from torch_fidelity.defaults import DEFAULTS as PPL_DEFAULTS +from torch_fidelity.utils import sample_random, batch_interp, create_sample_similarity + + +def slerp(a, b, t): + a = a / a.norm(dim=-1, keepdim=True) + b = b / b.norm(dim=-1, keepdim=True) + d = (a * b).sum(dim=-1, keepdim=True) + p = t * torch.acos(d) + c = b - d * a + c = c / c.norm(dim=-1, keepdim=True) + d = a * torch.cos(p) + 
c * torch.sin(p) + d = d / d.norm(dim=-1, keepdim=True) + return d + + +@torch.no_grad() +def calculate_ppl( + dataloader, + generator, + latent_space=None, + data_len=None, + **kwargs) -> dict: + """ + Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py + """ + if latent_space is None: + latent_space = generator.latent_space + assert latent_space in ["Z", "W"], f"Not supported latent space: {latent_space}" + epsilon = PPL_DEFAULTS["ppl_epsilon"] + interp = PPL_DEFAULTS['ppl_z_interp_mode'] + similarity_name = PPL_DEFAULTS['ppl_sample_similarity'] + sample_similarity_resize = PPL_DEFAULTS['ppl_sample_similarity_resize'] + sample_similarity_dtype = PPL_DEFAULTS['ppl_sample_similarity_dtype'] + discard_percentile_lower = PPL_DEFAULTS['ppl_discard_percentile_lower'] + discard_percentile_higher = PPL_DEFAULTS['ppl_discard_percentile_higher'] + + vassert(type(epsilon) is float and epsilon > 0, 'Epsilon must be a small positive floating point number') + vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, 'Invalid percentile') + vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, 'Invalid percentile') + if discard_percentile_lower is not None and discard_percentile_higher is not None: + vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, 'Invalid percentiles') + + sample_similarity = create_sample_similarity( + similarity_name, + sample_similarity_resize=sample_similarity_resize, + sample_similarity_dtype=sample_similarity_dtype, + cuda=False, + **kwargs + ) + sample_similarity = tops.to_cuda(sample_similarity) + rng = np.random.RandomState(get_kwarg('rng_seed', kwargs)) + distances = [] + if data_len is None: + data_len = len(dataloader) * dataloader.batch_size + z0 = sample_random(rng, (data_len, generator.z_channels), "normal") + z1 = sample_random(rng, (data_len, generator.z_channels), "normal") + if latent_space == "Z": + z1 = batch_interp(z0, z1, epsilon, interp) + + distances = torch.zeros(data_len, dtype=torch.float32, device=tops.get_device()) + print(distances.shape) + end = 0 + n_samples = 0 + for it, batch in enumerate(utils.tqdm_(dataloader, desc="Perceptual Path Length")): + start = end + end = start + batch["img"].shape[0] + n_samples += batch["img"].shape[0] + batch_lat_e0 = tops.to_cuda(z0[start:end]) + batch_lat_e1 = tops.to_cuda(z1[start:end]) + if latent_space == "W": + w0 = generator.get_w(batch_lat_e0, update_emas=False) + w1 = generator.get_w(batch_lat_e1, update_emas=False) + w1 = w0.lerp(w1, epsilon) # PPL end + rgb1 = generator(**batch, w=w0)["img"] + rgb2 = generator(**batch, w=w1)["img"] + else: + rgb1 = generator(**batch, z=batch_lat_e0)["img"] + rgb2 = generator(**batch, z=batch_lat_e1)["img"] + rgb1 = utils.denormalize_img(rgb1).mul(255).byte() + rgb2 = utils.denormalize_img(rgb2).mul(255).byte() + + sim = sample_similarity(rgb1, rgb2) + dist_lat_e01 = sim / (epsilon ** 2) + distances[start:end] = dist_lat_e01.view(-1) + distances = distances[:n_samples] + distances = tops.all_gather_uneven(distances).cpu().numpy() + if tops.rank() != 0: + return {"ppl/mean": -1, "ppl/std": -1} + if tops.rank() == 0: + cond, lo, hi = None, None, None + if discard_percentile_lower is not None: + lo = np.percentile(distances, discard_percentile_lower, interpolation='lower') + cond = lo <= distances + if discard_percentile_higher is not None: + hi = np.percentile(distances, discard_percentile_higher, interpolation='higher') + cond = np.logical_and(cond, distances <= 
hi) + if cond is not None: + distances = np.extract(cond, distances) + return { + "ppl/mean": float(np.mean(distances)), + "ppl/std": float(np.std(distances)), + } + else: + return {"ppl/mean"} diff --git a/dp2/metrics/torch_metrics.py b/dp2/metrics/torch_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a6682afbbbe9e5205d48743ac39db8c647607666 --- /dev/null +++ b/dp2/metrics/torch_metrics.py @@ -0,0 +1,176 @@ +import pickle +import numpy as np +import torch +import time +from pathlib import Path +from dp2 import utils +import tops +from .lpips import SampleSimilarityLPIPS +from torch_fidelity.defaults import DEFAULTS as trf_defaults +from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric +from torch_fidelity.utils import create_feature_extractor +lpips_model = None +fid_model = None + +@torch.no_grad() +def mse(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + se = (images1 - images2) ** 2 + se = se.view(images1.shape[0], -1).mean(dim=1) + return se + +@torch.no_grad() +def psnr(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + mse_ = mse(images1, images2) + psnr = 10 * torch.log10(1 / mse_) + return psnr + +@torch.no_grad() +def lpips(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + return _lpips_w_grad(images1, images2) + + +def _lpips_w_grad(images1: torch.Tensor, images2: torch.Tensor) -> torch.Tensor: + global lpips_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + + images1 = images1.mul(255) + images2 = images2.mul(255) + with torch.cuda.amp.autocast(tops.AMP()): + dists = lpips_model(images1, images2)[0].view(-1) + return dists + + + + +@torch.no_grad() +def compute_metrics_iteratively( + dataloader, generator, + cache_directory, + data_len=None, + truncation_value: float=None, + ) -> dict: + """ + Args: + n_samples (int): Creates N samples from same image to calculate stats + dataset_percentage (float): The percentage of the dataset to compute metrics on. 
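+ Returns: + A dict with keys lpips, lpips_diversity, fid and validation_time_s; real-image FID statistics are cached in cache_directory ("fid_stats.pkl").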
+ """ + + global lpips_model, fid_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + if fid_model is None: + fid_model = create_feature_extractor( + trf_defaults["feature_extractor"], [trf_defaults["feature_layer_fid"]], cuda=False) + fid_model = tops.to_cuda(fid_model) + cache_directory = Path(cache_directory) + start_time = time.time() + lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) + diversity_total = torch.zeros_like(lpips_total) + fid_cache_path = cache_directory.joinpath("fid_stats.pkl") + has_fid_cache = fid_cache_path.is_file() + if data_len is None: + data_len = len(dataloader)*dataloader.batch_size + if not has_fid_cache: + fid_features_real = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) + fid_features_fake = torch.zeros(data_len, 2048, dtype=torch.float32, device=tops.get_device()) + n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) + eidx = 0 + for batch in utils.tqdm_(iter(dataloader), desc="Computing FID, LPIPS and LPIPS Diversity"): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes1 = utils.denormalize_img(fakes1).mul(255) + fakes2 = utils.denormalize_img(fakes2).mul(255) + real_data = utils.denormalize_img(batch["img"]).mul(255) + lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) + fake2_lpips_feats = lpips_model.get_feats(fakes2) + lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) + + lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) + diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum() + if not has_fid_cache: + fid_features_real[sidx:eidx] = fid_model(real_data.byte())[0] + fid_features_fake[sidx:eidx] = fid_model(fakes1.byte())[0] + fid_features_fake = fid_features_fake[:n_samples_seen] + if has_fid_cache: + if tops.rank() == 0: + with open(fid_cache_path, "rb") as fp: + fid_stat_real = pickle.load(fp) + else: + fid_features_real = fid_features_real[:n_samples_seen] + fid_features_real = tops.all_gather_uneven(fid_features_real).cpu() + if tops.rank() == 0: + fid_stat_real = fid_features_to_statistics(fid_features_real) + cache_directory.mkdir(exist_ok=True, parents=True) + with open(fid_cache_path, "wb") as fp: + pickle.dump(fid_stat_real, fp) + fid_features_fake = tops.all_gather_uneven(fid_features_fake).cpu() + if tops.rank() == 0: + print("Starting calculation of fid from features of shape:", fid_features_fake.shape) + fid_stat_fake = fid_features_to_statistics(fid_features_fake) + fid_ = fid_statistics_to_metric(fid_stat_real, fid_stat_fake, verbose=False)["frechet_inception_distance"] + tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) + tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) + tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) + lpips_total = lpips_total / n_samples_seen + diversity_total = diversity_total / n_samples_seen + to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) + if tops.rank() == 0: + to_return["fid"] = fid_ + else: + to_return["fid"] = -1 + to_return["validation_time_s"] = time.time() - start_time + return to_return + + +@torch.no_grad() +def compute_lpips( + dataloader, generator, + truncation_value: 
float=None, + data_len=None, + ) -> dict: + """ + Computes LPIPS and LPIPS diversity over the given dataloader. + Args: + truncation_value (float): optional truncation value passed to generator.sample. + data_len (int): currently unused; kept for API symmetry with compute_metrics_iteratively. + """ + global lpips_model, fid_model + if lpips_model is None: + lpips_model = tops.to_cuda(SampleSimilarityLPIPS()) + start_time = time.time() + lpips_total = torch.tensor(0, dtype=torch.float32, device=tops.get_device()) + diversity_total = torch.zeros_like(lpips_total) + if data_len is None: + data_len = len(dataloader) * dataloader.batch_size + eidx = 0 + n_samples_seen = torch.tensor([0], dtype=torch.int32, device=tops.get_device()) + for batch in utils.tqdm_(dataloader, desc="Validating on dataset."): + sidx = eidx + eidx = sidx + batch["img"].shape[0] + n_samples_seen += batch["img"].shape[0] + with torch.cuda.amp.autocast(tops.AMP()): + fakes1 = generator.sample(**batch, truncation_value=truncation_value)["img"] + fakes2 = generator.sample(**batch, truncation_value=truncation_value)["img"] + real_data = batch["img"] + fakes1 = utils.denormalize_img(fakes1).mul(255) + fakes2 = utils.denormalize_img(fakes2).mul(255) + real_data = utils.denormalize_img(real_data).mul(255) + lpips_1, real_lpips_feats, fake1_lpips_feats = lpips_model(real_data, fakes1) + fake2_lpips_feats = lpips_model.get_feats(fakes2) + lpips_2 = lpips_model.lpips_from_feats(real_lpips_feats, fake2_lpips_feats) + + lpips_total += lpips_1.sum().add(lpips_2.sum()).div(2) + diversity_total += lpips_model.lpips_from_feats(fake1_lpips_feats, fake2_lpips_feats).sum() + tops.all_reduce(n_samples_seen, torch.distributed.ReduceOp.SUM) + tops.all_reduce(lpips_total, torch.distributed.ReduceOp.SUM) + tops.all_reduce(diversity_total, torch.distributed.ReduceOp.SUM) + lpips_total = lpips_total / n_samples_seen + diversity_total = diversity_total / n_samples_seen + to_return = dict(lpips=lpips_total, lpips_diversity=diversity_total) + to_return = {k: v.cpu().item() for k, v in to_return.items()} + to_return["validation_time_s"] = time.time() - start_time + return to_return diff --git a/dp2/utils/__init__.py b/dp2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a28c97be96580185208e5c6e56b1307a0fd9b9be --- /dev/null +++ b/dp2/utils/__init__.py @@ -0,0 +1,23 @@ +import pathlib +from tops.config import LazyConfig +from .torch_utils import ( + im2torch, im2numpy, denormalize_img, set_requires_grad, forward_D_fake, + binary_dilation, crop_box, remove_pad +) +from .ema import EMA +from .utils import init_tops, tqdm_, print_config, config_to_str, trange_ +from .cse import from_E_to_vertex + + +def load_config(config_path): + config_path = pathlib.Path(config_path) + assert config_path.is_file(), config_path + cfg = LazyConfig.load(str(config_path)) + cfg.output_dir = pathlib.Path(str(config_path).replace("configs", str(cfg.common.output_dir)).replace(".py", "")) + if cfg.common.experiment_name is None: + cfg.experiment_name = str(config_path) + else: + cfg.experiment_name = cfg.common.experiment_name + cfg.checkpoint_dir = cfg.output_dir.joinpath("checkpoints") + print("Saving outputs to:", cfg.output_dir) + return cfg diff --git a/dp2/utils/bufferless_video_capture.py b/dp2/utils/bufferless_video_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..b071c1c4316ad48127c86c4f52ca40f66530edf7 --- /dev/null +++ b/dp2/utils/bufferless_video_capture.py @@ -0,0 +1,32 @@ +import queue +import threading +import cv2 + + +class BufferlessVideoCapture: + + def __init__(self,
name, width=None, height=None): + self.cap = cv2.VideoCapture(name) + if width is not None and height is not None: + self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, width) + self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height) + self.q = queue.Queue() + t = threading.Thread(target=self._reader) + t.daemon = True + t.start() + + # read frames as soon as they are available, keeping only the most recent one + def _reader(self): + while True: + ret, frame = self.cap.read() + if not ret: + break + if not self.q.empty(): + try: + self.q.get_nowait() # discard previous (unprocessed) frame + except queue.Empty: + pass + self.q.put((ret, frame)) + + def read(self): + return self.q.get() \ No newline at end of file diff --git a/dp2/utils/cse.py b/dp2/utils/cse.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b5a2b55cb912df0a50961eafb260eb297b0d4f --- /dev/null +++ b/dp2/utils/cse.py @@ -0,0 +1,21 @@ +import warnings +import torch +from densepose.modeling.cse.utils import get_closest_vertices_mask_from_ES + + +def from_E_to_vertex(E, M, embed_map): + """ + M is 1 for unknown regions + """ + assert len(E.shape) == 4 + assert len(E.shape) == len(M.shape), (E.shape, M.shape) + assert E.shape[0] == 1 + M = M.float() + M = torch.cat([M, 1-M], dim=1) + with warnings.catch_warnings(): # Ignore the UserWarning raised by the pytorch interpolate call inside detectron2 + warnings.filterwarnings("ignore") + vertices, _ = get_closest_vertices_mask_from_ES( + E, M, E.shape[2], E.shape[3], + embed_map, device=E.device) + + return vertices.long() diff --git a/dp2/utils/ema.py b/dp2/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..0c508213d4c445e2417607a7d3957d3ec953eb1f --- /dev/null +++ b/dp2/utils/ema.py @@ -0,0 +1,79 @@ +import torch +import copy +import tops +from tops import logger +from .torch_utils import set_requires_grad + +class EMA: + """ + Exponential moving average. + See: + Yazici, Y. et al. The unusual effectiveness of averaging in GAN training.
ICLR 2019 + + """ + + def __init__( + self, + generator: torch.nn.Module, + batch_size: int, + rampup: float, + ): + self.rampup = rampup + self._nimg_half_time = batch_size * 10 / 32 * 1000 + self._batch_size = batch_size + with torch.no_grad(): + self.generator = copy.deepcopy(generator.cpu()).eval() + self.generator = tops.to_cuda(self.generator) + self.old_ra_beta = 0 + set_requires_grad(self.generator, False) + + def update_beta(self): + y = self._nimg_half_time + global_step = logger.global_step() + if self.rampup != None: + y = min(y, global_step*self.rampup) + self.ra_beta = 0.5 ** (self._batch_size/max(y, 1e-8)) + if self.ra_beta != self.old_ra_beta: + logger.add_scalar("stats/EMA_beta", self.ra_beta) + self.old_ra_beta = self.ra_beta + + @torch.no_grad() + def update(self, normal_G): + with torch.autograd.profiler.record_function("EMA_update"): + for ema_p, p in zip(self.generator.parameters(), + normal_G.parameters()): + ema_p.copy_(p.lerp(ema_p, self.ra_beta)) + for ema_buf, buff in zip(self.generator.buffers(), + normal_G.buffers()): + ema_buf.copy_(buff) + + def __call__(self, *args, **kwargs): + return self.generator(*args, **kwargs) + + def __getattr__(self, name: str): + if hasattr(self.generator, name): + return getattr(self.generator, name) + raise AttributeError(f"Generator object has no attribute {name}") + + def cuda(self, *args, **kwargs): + self.generator = self.generator.cuda() + return self + + def state_dict(self, *args, **kwargs): + return self.generator.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.generator.load_state_dict(*args, **kwargs) + + def eval(self): + self.generator.eval() + + def train(self): + self.generator.train() + + @property + def module(self): + return self.generator.module + + def sample(self, *args, **kwargs): + return self.generator.sample(*args, **kwargs) diff --git a/dp2/utils/torch_utils.py b/dp2/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..46defb854b11a615aad3b918dce4a650b0b80889 --- /dev/null +++ b/dp2/utils/torch_utils.py @@ -0,0 +1,111 @@ +import torch +import tops + + +def denormalize_img(image, mean=0.5, std=0.5): + image = image * std + mean + image = torch.clamp(image.float(), 0, 1) + image = (image * 255) + image = torch.round(image) + return image / 255 + + +@torch.no_grad() +def im2numpy(images, to_uint8=False, denormalize=False): + if denormalize: + images = denormalize_img(images) + if images.dtype != torch.uint8: + images = images.clamp(0, 1) + return tops.im2numpy(images, to_uint8=to_uint8) + + +@torch.no_grad() +def im2torch(im, cuda=False, normalize=True, to_float=True): + im = tops.im2torch(im, cuda=cuda, to_float=to_float) + if normalize: + assert im.min() >= 0.0 and im.max() <= 1.0 + if normalize: + im = im * 2 - 1 + return im + + +@torch.no_grad() +def binary_dilation(im: torch.Tensor, kernel: torch.Tensor): + assert len(im.shape) == 4 + assert len(kernel.shape) == 2 + kernel = kernel.unsqueeze(0).unsqueeze(0) + padding = kernel.shape[-1]//2 + assert kernel.shape[-1] % 2 != 0 + if isinstance(im, torch.cuda.FloatTensor): + im, kernel = im.half(), kernel.half() + else: + im, kernel = im.float(), kernel.float() + im = torch.nn.functional.conv2d( + im, kernel, groups=im.shape[1], padding=padding) + im = im > 0.5 + return im + + +@torch.no_grad() +def binary_erosion(im: torch.Tensor, kernel: torch.Tensor): + assert len(im.shape) == 4 + assert len(kernel.shape) == 2 + kernel = kernel.unsqueeze(0).unsqueeze(0) + padding = 
kernel.shape[-1]//2 + assert kernel.shape[-1] % 2 != 0 + if isinstance(im, torch.cuda.FloatTensor): + im, kernel = im.half(), kernel.half() + else: + im, kernel = im.float(), kernel.float() + ksum = kernel.sum() + padding = (padding, padding, padding, padding) + im = torch.nn.functional.pad(im, padding, mode="reflect") + im = torch.nn.functional.conv2d( + im, kernel, groups=im.shape[1]) + return im.round() == ksum + + +def set_requires_grad(value: torch.nn.Module, requires_grad: bool): + if isinstance(value, (list, tuple)): + for param in value: + param.requires_grad = requires_grad + return + for p in value.parameters(): + p.requires_grad = requires_grad + + +def forward_D_fake(batch, fake_img, discriminator, **kwargs): + fake_batch = {k: v for k, v in batch.items() if k != "img"} + fake_batch["img"] = fake_img + return discriminator(**fake_batch, **kwargs) + + + +def remove_pad(x: torch.Tensor, bbox_XYXY, imshape): + """ + Remove padding that is shown as negative + """ + H, W = imshape + x0, y0, x1, y1 = bbox_XYXY + padding = [ + max(0, -x0), + max(0, -y0), + max(x1 - W, 0), + max(y1 - H, 0) + ] + x0, y0 = padding[:2] + x1 = x.shape[2] - padding[2] + y1 = x.shape[1] - padding[3] + return x[:, y0:y1, x0:x1] + + +def crop_box(x: torch.Tensor, bbox_XYXY) -> torch.Tensor: + """ + Crops x by bbox_XYXY. + """ + x0, y0, x1, y1 = bbox_XYXY + x0 = max(x0, 0) + y0 = max(y0, 0) + x1 = min(x1, x.shape[-1]) + y1 = min(y1, x.shape[-2]) + return x[..., y0:y1, x0:x1] \ No newline at end of file diff --git a/dp2/utils/utils.py b/dp2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..965eb6ae987bd6b5644e50116bfa6317e4e36769 --- /dev/null +++ b/dp2/utils/utils.py @@ -0,0 +1,30 @@ +import tops +import tqdm +from tops import logger, highlight_py_str +from tops.config import LazyConfig + + +def print_config(cfg): + logger.log("\n" + highlight_py_str(LazyConfig.to_py(cfg, prefix=""))) + + +def config_to_str(cfg): + return LazyConfig.to_py(cfg, prefix=".") + + +def init_tops(cfg, reinit=False): + tops.init( + cfg.output_dir, cfg.common.logger_backend, cfg.experiment_name, + cfg.common.wandb_project, dict(cfg), reinit) + + +def tqdm_(iterator, *args, **kwargs): + if tops.rank() == 0: + return tqdm.tqdm(iterator, *args, **kwargs) + return iterator + + +def trange_(*args, **kwargs): + if tops.rank() == 0: + return tqdm.trange(*args, **kwargs) + return range(*args) \ No newline at end of file diff --git a/dp2/utils/vis_utils.py b/dp2/utils/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee227f359b4ab3af3a1123e2900373fed01f6c38 --- /dev/null +++ b/dp2/utils/vis_utils.py @@ -0,0 +1,407 @@ +import torch +import tops +import cv2 +import torchvision.transforms.functional as F +from typing import Optional, List, Union, Tuple +from .cse import from_E_to_vertex +import numpy as np +from tops import download_file +from .torch_utils import ( + denormalize_img, binary_dilation, binary_erosion, + remove_pad, crop_box) +from torchvision.utils import _generate_color_palette +from PIL import Image, ImageColor, ImageDraw + + +def get_coco_keypoints(): + # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py + keypoints = [ + 'nose', + 'left_eye', + 'right_eye', + 'left_ear', + 'right_ear', + 'left_shoulder', + 'right_shoulder', + 'left_elbow', + 'right_elbow', + 'left_wrist', + 'right_wrist', + 'left_hip', + 'right_hip', + 'left_knee', + 'right_knee', + 'left_ankle', + 'right_ankle' + ] + keypoint_flip_map = { + 'left_eye': 
'right_eye', + 'left_ear': 'right_ear', + 'left_shoulder': 'right_shoulder', + 'left_elbow': 'right_elbow', + 'left_wrist': 'right_wrist', + 'left_hip': 'right_hip', + 'left_knee': 'right_knee', + 'left_ankle': 'right_ankle' + } + connectivity = { + "nose": "left_eye", + "left_eye": "right_eye", + "right_eye": "nose", + "left_ear": "left_eye", + "right_ear": "right_eye", + "left_shoulder": "nose", + "right_shoulder": "nose", + "left_elbow": "left_shoulder", + "right_elbow": "right_shoulder", + "left_wrist": "left_elbow", + "right_wrist": "right_elbow", + "left_hip": "left_shoulder", + "right_hip": "right_shoulder", + "left_knee": "left_hip", + "right_knee": "right_hip", + "left_ankle": "left_knee", + "right_ankle": "right_knee" + } + connectivity_indices = [ + (sidx, keypoints.index(connectivity[kp])) + for sidx, kp in enumerate(keypoints) + ] + return keypoints, keypoint_flip_map, connectivity_indices + + +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[List[Tuple[int, int]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = 1, + width: int = 1, +) -> torch.Tensor: + + + """ + Function taken from torchvision source code. Added in torchvision 0.12 + + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (List[Tuple[int, int]]]): A List of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. 
Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + img_kpts = keypoints.to(torch.int64).tolist() + + for kpt_id, kpt_inst in enumerate(img_kpts): + for inst_id, kpt in enumerate(kpt_inst): + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + if connectivity: + for connection in connectivity: + if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst): + continue + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + +def visualize_batch( + img: torch.Tensor, mask: torch.Tensor, + vertices: torch.Tensor=None, + E_mask: torch.Tensor=None, + embed_map: torch.Tensor=None, + semantic_mask: torch.Tensor=None, + embedding: torch.Tensor=None, + keypoints: torch.Tensor=None, + maskrcnn_mask: torch.Tensor=None, + **kwargs) -> torch.ByteTensor: + img = denormalize_img(img).mul(255).byte() + img = draw_mask(img, mask) + if maskrcnn_mask is not None: + img = draw_mask(img, maskrcnn_mask) + if vertices is not None or embedding is not None: + assert E_mask is not None + assert embed_map is not None + img = draw_cse(img, E_mask, embedding, embed_map, vertices) + elif semantic_mask is not None: + img = draw_segmentation_masks(img, semantic_mask) + if keypoints is not None: + keypoints = keypoints.clone() + keypoints[:, :, 0] *= img.shape[-1] + keypoints[:, :, 1] *= img.shape[-2] + _, _, connectivity = get_coco_keypoints() + connectivity = np.array(connectivity) + for idx in range(img.shape[0]): + if keypoints.shape[-1] == 3: + visible = (keypoints[idx:idx+1, :, 2] > 0 ).view(-1) + else: + visible = torch.ones(keypoints.shape[1], device=keypoints.device, dtype=torch.bool) + + if keypoints.shape[1] == 17: # COCO Connectivity + c = connectivity[visible.cpu().numpy()].tolist() + else: + c = None + + kp = keypoints[idx:idx+1, visible].long() + img[idx] = draw_keypoints(img[idx], kp, colors="red", connectivity=c) + return img + + +@torch.no_grad() +def draw_cse( + img: torch.Tensor, E_seg: torch.Tensor, + embedding: torch.Tensor = None, + embed_map: torch.Tensor = None, + vertices: torch.Tensor = None, t=0.7 + ): + """ + E_seg: 1 for areas with embedding + """ + assert img.dtype == torch.uint8 + img = img.view(-1, *img.shape[-3:]) + E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) + if vertices is None: + assert embedding is not None + assert embed_map is not None + embedding = embedding.view(-1, *embedding.shape[-3:]) + vertices = torch.stack( + [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map) + for e, e_seg in zip(embedding, E_seg)]) + + i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1) + colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0]) + color_embed_map, _ = np.load(download_file("https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True) + color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0] + color_embed_map -= color_embed_map.min() + color_embed_map /= color_embed_map.max() + vertx2idx = 
(color_embed_map*255).long() + vertx2colormap = colormap_JET[vertx2idx] + + vertices = vertices.view(-1, *vertices.shape[-2:]) + E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) + # This operation might be good to do on cpu... + E_color = vertx2colormap[vertices.long()] + E_color = E_color.to(E_seg.device) + E_color = E_color.permute(0, 3, 1, 2) + E_color = E_color*E_seg.byte() + + m = E_seg.bool().repeat(1, 3, 1, 1) + img[m] = (img[m] * (1-t) + t * E_color[m]).byte() + return img + + +def draw_cse_all( + embedding: List[torch.Tensor], E_mask: List[torch.Tensor], + im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7): + """ + E_seg: 1 for areas with embedding + """ + assert len(im.shape) == 3, im.shape + assert im.dtype == torch.uint8 + + N = len(E_mask) + im = im.clone() + for i in range(N): + assert len(E_mask[i].shape) == 2 + assert len(embedding[i].shape) == 3 + assert embed_map.shape[1] == embedding[i].shape[0] + assert len(boxes_XYXY[i]) == 4 + E = embedding[i] + x0, y0, x1, y1 = boxes_XYXY[i] + E = F.resize(E, (y1-y0, x1-x0), antialias=True) + s = E_mask[i].float() + s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float() + box = boxes_XYXY[i] + + im_ = crop_box(im, box) + s = remove_pad(s, box, im.shape[1:]) + E = remove_pad(E, box, im.shape[1:]) + E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None],embed_map=embed_map)[0] + E_color = E_color.to(im.device) + s = s.bool().repeat(3, 1, 1) + crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte() + return im + + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (list or None): List containing the colors of the masks. The colors can + be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list + with one element. By default, random colors are generated for each mask. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. + """ + + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. 
Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") + num_masks = masks.size()[0] + if num_masks == 0: + return image + if colors is None: + colors = _generate_color_palette(num_masks) + if not isinstance(colors[0], (Tuple, List)): + colors = [colors for i in range(num_masks)] + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = torch.uint8 + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + color = torch.tensor(color, dtype=out_dtype, device=masks.device) + colors_.append(color) + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * (1 - alpha) + img_to_draw * alpha + return out.to(out_dtype) + + +def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True): + """ + Visualize mask where mask = 0. + Supports multiple instances. + mask shape: [N, C, H, W], where C is different instances in same image. + """ + orig_imshape = im.shape + if mask.numel() == 0: return im + assert len(mask.shape) in (3, 4), mask.shape + mask = mask.view(-1, *mask.shape[-3:]) + im = im.view(-1, *im.shape[-3:]) + assert im.dtype == torch.uint8, im.dtype + assert 0 <= t <= 1 + if not visualize_instances: + mask = mask.any(dim=1, keepdim=True) + mask = mask.bool() + kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device) + outer_border = binary_dilation(mask, kernel).logical_xor(mask) + outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 + inner_border = binary_erosion(mask, kernel).logical_xor(mask) + inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 + mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1) + color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1)#.repeat(1, *im.shape[1:]) + color = color.repeat(im.shape[0], 1, *im.shape[-2:]) + im[mask] = (im[mask] * (1-t) + t * color[mask]).byte() + im[outer_border] = 255 + im[inner_border] = 0 + return im.view(*orig_imshape) + + +def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs): + for i, box in enumerate(boxes): + x0, y0, x1, y1 = boxes[i] + orig_shape = (y1-y0, x1-x0) + m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None] + m = remove_pad(m, boxes[i], im.shape[-2:]) + crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m)) + return im + + +def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs): + n_boxes = boxes.shape[0] + tops.assert_shape(all_keypoints, (n_boxes, 17, 3)) + im = im.clone() + for i, box in enumerate(boxes): + + x0, y0, x1, y1 = boxes[i] + orig_shape = (y1-y0, x1-x0) + keypoints = all_keypoints[i].clone() + keypoints[:, 0] *= orig_shape[1] + keypoints[:, 1] *= orig_shape[0] + keypoints = keypoints.long() + _, _, connectivity = get_coco_keypoints() + connectivity = np.array(connectivity) + visible 
= (keypoints[:, 2] > 0) + if keypoints.shape[0] == 17: # COCO Connectivity + c = connectivity[visible.cpu().numpy()].tolist() + else: + c = None + # Remove padding from keypoints before visualization + keypoints[:, 0] += min(x0, 0) + keypoints[:, 1] += min(y0, 0) + im_with_kp = draw_keypoints(crop_box(im, box), keypoints[None, visible, :2], colors="red", connectivity=c) + crop_box(im, box).copy_(im_with_kp) + return im
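The visualization helpers added in dp2/utils/vis_utils.py operate on uint8 CHW tensors. Below is a minimal usage sketch, not part of the diff: the function names come from the files above, but the 128x128 image size and the random mask/keypoint values are made up purely for illustration.

import torch
from dp2.utils.vis_utils import draw_mask, draw_keypoints, get_coco_keypoints

# Dummy uint8 RGB image of shape (3, H, W) and a single-instance boolean mask of shape (N=1, C=1, H, W).
img = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)
mask = torch.zeros(1, 1, 128, 128, dtype=torch.bool)
mask[..., 32:96, 32:96] = True

# draw_mask blends the highlight color into the region where mask == 0 and marks the instance borders.
img = draw_mask(img, mask)

# draw_keypoints expects pixel coordinates of shape (num_instances, K, 2); connectivity pairs come from get_coco_keypoints.
_, _, connectivity = get_coco_keypoints()
keypoints = torch.randint(8, 120, (1, 17, 2))
img = draw_keypoints(img, keypoints, connectivity=connectivity, colors="red")

This is the same pattern visualize_batch and draw_cropped_keypoints apply internally to whole batches and per-box crops.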