haakohu's picture
:)
548d634
raw
history blame
9.68 kB
from pathlib import Path
from typing import Dict, List
import torchvision
import torch
import tops
import torchvision.transforms.functional as F
from .functional import hflip
class RandomHorizontalFlip(torch.nn.Module):
    """Apply ``hflip`` to the sample container with probability ``p``.

    Args:
        p: Probability of flipping, in [0, 1]. ``None`` falls back to 0.5.
        flip_map: Passed through unchanged to ``hflip``; its meaning is defined
            by that helper (NOTE(review): presumably a left/right index
            permutation — confirm in ``.functional``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, p: float, flip_map=None, **kwargs):
        super().__init__()
        self.flip_ratio = 0.5 if p is None else p
        self.flip_map = flip_map
        assert 0 <= self.flip_ratio <= 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # torch.rand samples from [0, 1), so flipping when ``rand < p`` happens
        # with probability exactly p: never for p == 0, always for p == 1.
        # (The previous ``rand > p`` test could still flip when p == 0.)
        if torch.rand(1) >= self.flip_ratio:
            return container
        return hflip(container, self.flip_map)
class CenterCrop(torch.nn.Module):
    """Center-crop ``container["img"]`` to ``size``.

    If the image's shorter spatial side is smaller than ``size[0]``, the
    largest centered square is cropped instead and resized up to ``size``.
    H and W are read from the last two dimensions, so both unbatched
    (C, H, W) and batched (N, C, H, W) tensors are handled.

    NOTE: Does not transform the mask to improve runtime.
    """

    def __init__(self, size: List[int]):
        super().__init__()
        self.size = tuple(size)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        img = container["img"]
        # Use trailing dims: shape[1]/shape[2] would be C/H for batched input.
        min_size = min(img.shape[-2], img.shape[-1])
        if min_size < self.size[0]:
            square = F.center_crop(img, min_size)
            # antialias=True matches Resize/Resize2 and torchvision's modern default.
            container["img"] = F.resize(square, self.size, antialias=True)
            return container
        container["img"] = F.center_crop(img, self.size)
        return container
class Resize(torch.nn.Module):
    """Resize the image and any per-pixel maps present in the container to ``size``.

    - ``img`` is resized with ``interpolation`` and ``antialias=True``.
    - ``embedding`` (if present) is resized with ``interpolation``; no
      explicit antialias flag is passed.
    - ``semantic_mask``, ``mask``, ``E_mask``, ``maskrcnn_mask`` and
      ``vertices`` (if present) are resized with nearest-neighbour, so their
      values are never blended.

    Keys absent from the container are skipped.
    """

    # Label-like entries that must keep discrete values when resized.
    _NEAREST_KEYS = ("semantic_mask", "mask", "E_mask", "maskrcnn_mask", "vertices")

    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        container["img"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        if "embedding" in container:
            container["embedding"] = F.resize(
                container["embedding"], self.size, self.interpolation)
        for key in self._NEAREST_KEYS:
            if key in container:
                container[key] = F.resize(
                    container[key], self.size, F.InterpolationMode.NEAREST)
        return container

    def __repr__(self):
        base = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        return base + " " + " ".join(f"{k}: {v}" for k, v in vars_.items())
class InsertHRImage(torch.nn.Module):
    """Add resized copies of ``img``, ``condition`` and ``mask`` under ``*_hr`` keys.

    - ``img_hr`` and ``condition_hr`` are resized with ``interpolation`` and
      antialiasing.
    - ``mask_hr`` is obtained by adaptive max-pooling the *inverted* binary
      mask (``mask > 0``). An output pixel is therefore 1.0 only when every
      input pixel pooled into it had ``mask > 0``.
    - ``condition_hr`` is then re-blended so that wherever ``mask_hr`` is 1 it
      takes the value from ``img_hr``.

    Assumes ``condition`` has already been created. The original keys are not
    modified. ``img`` must be float32.
    """

    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        assert container["img"].dtype == torch.float32
        container["img_hr"] = F.resize(container["img"], self.size, self.interpolation, antialias=True)
        container["condition_hr"] = F.resize(container["condition"], self.size, self.interpolation, antialias=True)
        inv_mask = (container["mask"] > 0).logical_not().float()
        # Max-pooling the inverse propagates any mask<=0 pixel to its whole output cell.
        pooled = torch.nn.functional.adaptive_max_pool2d(inv_mask, output_size=self.size)
        container["mask_hr"] = (pooled > 0).logical_not().float()
        mask_hr = container["mask_hr"]
        container["condition_hr"] = container["condition_hr"] * (1 - mask_hr) + container["img_hr"] * mask_hr
        return container

    def __repr__(self):
        base = super().__repr__()
        vars_ = dict(size=self.size, interpolation=self.interpolation)
        # Previously the settings were computed but dropped from the string.
        return base + " " + " ".join(f"{k}: {v}" for k, v in vars_.items())
class CopyHRImage(torch.nn.Module):
    """Expose ``img``, ``condition`` and ``mask`` under ``*_hr`` keys as-is.

    No resizing or copying happens: each ``*_hr`` entry refers to the same
    object as its source entry.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        aliases = (
            ("img_hr", "img"),
            ("condition_hr", "condition"),
            ("mask_hr", "mask"),
        )
        for target, source in aliases:
            container[target] = container[source]
        return container
class Resize2(torch.nn.Module):
    """Resize ``img`` and ``mask`` in place, optionally also ``condition``.

    The mask is resized by adaptive max-pooling its inverse (``mask > 0``
    negated). An output pixel is therefore 1.0 only when every input pixel
    pooled into it had ``mask > 0``. Assumes ``condition`` already exists.

    Args:
        size: target spatial size, used for both ``F.resize`` and the pooling.
        interpolation: mode used when resizing ``img`` and ``condition``.
        downsample_condition: also resize ``condition`` to ``size``.
        mask_condition: afterwards, take ``img`` values in ``condition``
            wherever the resized mask is 1.
    """

    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, downsample_condition: bool = True, mask_condition= True):
        super().__init__()
        self.size = tuple(size)
        self.interpolation = interpolation
        self.downsample_condition = downsample_condition
        self.mask_condition = mask_condition

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        size, mode = self.size, self.interpolation
        container["img"] = F.resize(container["img"], size, mode, antialias=True)

        inv_mask = (container["mask"] > 0).logical_not().float()
        pooled = torch.nn.functional.adaptive_max_pool2d(inv_mask, output_size=size)
        container["mask"] = (pooled > 0).logical_not().float()

        if self.downsample_condition:
            container["condition"] = F.resize(container["condition"], size, mode, antialias=True)
        if not self.mask_condition:
            return container

        keep = container["mask"]
        container["condition"] = container["condition"] * (1 - keep) + container["img"] * keep
        return container

    def __repr__(self):
        base = super().__repr__()
        settings = {"size": self.size, "interpolation": self.interpolation}
        return base + " " + " ".join(f"{k}: {v}" for k, v in settings.items())
class Normalize(torch.nn.Module):
    """Apply ``F.normalize(mean, std, inplace)`` to each container entry in ``keys``.

    Args:
        mean: per-channel means forwarded to ``F.normalize``.
        std: per-channel standard deviations forwarded to ``F.normalize``.
        inplace: forwarded to ``F.normalize``.
        keys: container entries to normalize; defaults to ``["img"]``.
    """

    def __init__(self, mean, std, inplace, keys=None):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace
        # Fresh list per instance; a list default would be shared by all instances.
        self.keys = ["img"] if keys is None else list(keys)

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = F.normalize(container[key], self.mean, self.std, self.inplace)
        return container

    def __repr__(self):
        base = super().__repr__()
        vars_ = dict(mean=self.mean, std=self.std, inplace=self.inplace, keys=self.keys)
        return base + " " + " ".join(f"{k}: {v}" for k, v in vars_.items())
class ToFloat(torch.nn.Module):
    """Cast selected container entries to float32, optionally dividing by 255.

    Args:
        keys: container entries to convert; defaults to ``["img"]``.
        norm: if True, divide by 255 after casting (e.g. uint8 [0, 255] -> [0, 1]);
            otherwise only cast.
    """

    def __init__(self, keys=None, norm=True) -> None:
        super().__init__()
        # Fresh list per instance; a list default would be shared by all instances.
        self.keys = ["img"] if keys is None else list(keys)
        self.gain = 255 if norm else 1

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for key in self.keys:
            container[key] = container[key].float() / self.gain
        return container
class RandomCrop(torchvision.transforms.RandomCrop):
    """Randomly crop ``container["img"]`` with torchvision's ``RandomCrop``.

    Constructor arguments are those of ``torchvision.transforms.RandomCrop``.
    NOTE: Does not transform the mask to improve runtime.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        source = container["img"]
        cropped = super().forward(source)
        container["img"] = cropped
        return container
class CreateCondition(torch.nn.Module):
    """Build ``condition`` from ``img`` and ``mask``.

    For uint8 images, pixels where the mask is 0 are set to 127 and the rest
    keep their ``img`` value. For any other dtype, ``condition`` is simply
    ``img * mask``. Both branches assume a 0/1 mask.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        img, mask = container["img"], container["mask"]
        if img.dtype != torch.uint8:
            container["condition"] = img * mask
            return container
        keep = mask.byte()
        container["condition"] = img * keep + (1 - keep) * 127
        return container
class CreateEmbedding(torch.nn.Module):
    """Look up per-pixel embeddings for ``container["vertices"]`` in a stored table.

    Args:
        embed_path: file loaded with ``torch.load`` (mapped to CPU) as the lookup
            table, indexed by ``vertices.long()``.
        cuda: if True, move the table with ``tops.to_cuda`` after loading.

    ``forward`` writes ``embedding`` (gathered rows moved to channel-first and
    multiplied by ``E_mask``) and ``embed_map`` (a clone of the table) into the
    container. ``vertices`` must be 3-D or 4-D; NOTE(review): presumably
    (1, H, W) or (N, 1, H, W) given the squeezes — confirm.
    """

    def __init__(self, embed_path: Path, cuda=True) -> None:
        super().__init__()
        table = torch.load(embed_path, map_location=torch.device("cpu"))
        self.embed_map = tops.to_cuda(table) if cuda else table

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        vertices = container["vertices"]
        if vertices.ndim == 3:
            squeeze_dim, channel_first = 0, (2, 0, 1)
        else:
            assert vertices.ndim == 4
            squeeze_dim, channel_first = 1, (0, 3, 1, 2)
        gathered = self.embed_map[vertices.long()].squeeze(dim=squeeze_dim)
        container["embedding"] = gathered.permute(*channel_first) * container["E_mask"]
        container["embed_map"] = self.embed_map.clone()
        return container
class UpdateMask(torch.nn.Module):
    """Recompute ``mask`` from where ``img`` and ``condition`` agree.

    A pixel gets 1.0 when at least one channel of ``img`` equals the same
    channel of ``condition``, else 0.0. The channel axis is taken as the third
    from the end, so both (C, H, W) and (N, C, H, W) inputs give a mask with a
    kept size-1 channel dimension.

    NOTE(review): ``any`` marks a pixel as matching if a single channel
    coincides; confirm ``all`` is not what is intended.
    """

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        same = container["img"] == container["condition"]
        # dim=-3 is the channel axis for both batched and unbatched tensors
        # (dim=1 would reduce over H for a (C, H, W) tensor).
        container["mask"] = same.any(dim=-3, keepdim=True).float()
        return container
class LoadClassEmbedding(torch.nn.Module):
    """Attach a flattened per-class embedding chosen from the sample's ``__key__``.

    The lookup key is built from ``__key__``:
    1. Take the text after the last ``"train/"``, or the whole key if it has none.
    2. Drop the final path component.
    3. Join the remaining components with ``"_"``.

    For example, ``"x/train/a/b/img"`` becomes ``"a_b"``.
    """

    def __init__(self, embedding_path: Path) -> None:
        super().__init__()
        # NOTE: torch.load unpickles the file; only load trusted embedding files.
        self.embedding = torch.load(embedding_path, map_location="cpu")

    def forward(self, container: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        relative = container["__key__"].split("train/")[-1]
        *directories, _filename = relative.split("/")
        lookup = "_".join(directories)
        container["class_embedding"] = self.embedding[lookup].view(-1)
        return container