|
from typing import Dict, List, Optional, Tuple
|
|
|
import torch |
|
import torchvision |
|
from torch import nn, Tensor |
|
from torchvision.transforms import functional as F |
|
from torchvision.transforms import transforms as T |
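# Usage sketch (an assumption about the surrounding training script, not part of it):
# every transform below takes and returns an (image, target) pair, so they are chained
# with the Compose defined in this file rather than torchvision.transforms.Compose, e.g.
#
#     transforms = Compose([PILToTensor(), ConvertImageDtype(torch.float), RandomHorizontalFlip(0.5)])
#     image, target = transforms(pil_image, target)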
|
|
|
|
|
def _flip_coco_person_keypoints(kps, width): |
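    # COCO person keypoints come in a fixed 17-point order; each left/right pair
    # (eyes, ears, shoulders, elbows, wrists, hips, knees, ankles) is swapped so that
    # keypoints stay on the correct body side after a horizontal flip.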
|
flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] |
|
flipped_data = kps[:, flip_inds] |
|
flipped_data[..., 0] = width - flipped_data[..., 0] |
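    # Keep the COCO convention that keypoints with visibility 0 have x = y = 0.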
|
|
|
inds = flipped_data[..., 2] == 0 |
|
flipped_data[inds] = 0 |
|
return flipped_data |
|
|
|
|
|
class Compose: |
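    """Chain transforms that operate on an (image, target) pair."""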
|
def __init__(self, transforms): |
|
self.transforms = transforms |
|
|
|
def __call__(self, image, target): |
|
for t in self.transforms: |
|
image, target = t(image, target) |
|
return image, target |
|
|
|
|
|
class RandomHorizontalFlip(T.RandomHorizontalFlip): |
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
if torch.rand(1) < self.p: |
|
image = F.hflip(image) |
|
if target is not None: |
|
width, _ = F.get_image_size(image) |
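                # Mirror the boxes: new x1 = width - old x2, new x2 = width - old x1.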
|
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]] |
|
if "masks" in target: |
|
target["masks"] = target["masks"].flip(-1) |
|
if "keypoints" in target: |
|
keypoints = target["keypoints"] |
|
keypoints = _flip_coco_person_keypoints(keypoints, width) |
|
target["keypoints"] = keypoints |
|
return image, target |
|
|
|
|
|
class ToTensor(nn.Module): |
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
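        # Convert a PIL image to a uint8 tensor, then to float32 scaled to [0, 1].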
|
image = F.pil_to_tensor(image) |
|
image = F.convert_image_dtype(image) |
|
return image, target |
|
|
|
|
|
class PILToTensor(nn.Module): |
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
image = F.pil_to_tensor(image) |
|
return image, target |
|
|
|
|
|
class ConvertImageDtype(nn.Module): |
|
def __init__(self, dtype: torch.dtype) -> None: |
|
super().__init__() |
|
self.dtype = dtype |
|
|
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
image = F.convert_image_dtype(image, self.dtype) |
|
return image, target |
|
|
|
|
|
class RandomIoUCrop(nn.Module): |
|
def __init__( |
|
self, |
|
min_scale: float = 0.3, |
|
max_scale: float = 1.0, |
|
min_aspect_ratio: float = 0.5, |
|
max_aspect_ratio: float = 2.0, |
|
sampler_options: Optional[List[float]] = None, |
|
trials: int = 40, |
|
): |
|
super().__init__() |
|
|
|
self.min_scale = min_scale |
|
self.max_scale = max_scale |
|
self.min_aspect_ratio = min_aspect_ratio |
|
self.max_aspect_ratio = max_aspect_ratio |
|
if sampler_options is None: |
|
sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] |
|
self.options = sampler_options |
|
self.trials = trials |
|
|
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
if target is None: |
|
raise ValueError("The targets can't be None for this transform.") |
|
|
|
if isinstance(image, torch.Tensor): |
|
if image.ndimension() not in {2, 3}: |
|
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") |
|
elif image.ndimension() == 2: |
|
image = image.unsqueeze(0) |
|
|
|
orig_w, orig_h = F.get_image_size(image) |
|
|
|
while True: |
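            # Sample one of the minimum-IoU options; a value >= 1 means "return the input unchanged".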
|
|
|
idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) |
|
min_jaccard_overlap = self.options[idx] |
|
if min_jaccard_overlap >= 1.0: |
|
return image, target |
|
|
|
for _ in range(self.trials): |
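                # Sample a candidate crop size and reject it if its aspect ratio is out of range.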
|
|
|
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) |
|
new_w = int(orig_w * r[0]) |
|
new_h = int(orig_h * r[1]) |
|
aspect_ratio = new_w / new_h |
|
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): |
|
continue |
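                # Sample the crop's top-left corner and skip degenerate (zero-area) crops.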
|
|
|
|
|
r = torch.rand(2) |
|
left = int((orig_w - new_w) * r[0]) |
|
top = int((orig_h - new_h) * r[1]) |
|
right = left + new_w |
|
bottom = top + new_h |
|
if left == right or top == bottom: |
|
continue |
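                # Keep only the boxes whose centers fall strictly inside the crop.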
|
|
|
|
|
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2]) |
|
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3]) |
|
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) |
|
if not is_within_crop_area.any(): |
|
continue |
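                # Require at least one retained box to overlap the crop by the sampled IoU threshold.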
|
|
|
|
|
boxes = target["boxes"][is_within_crop_area] |
|
ious = torchvision.ops.boxes.box_iou( |
|
boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device) |
|
) |
|
if ious.max() < min_jaccard_overlap: |
|
continue |
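                # Accept the crop: keep the selected boxes and labels, translate and clip the boxes,
                # then crop the image. Masks and keypoints are not adjusted by this transform.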
|
|
|
|
|
target["boxes"] = boxes |
|
target["labels"] = target["labels"][is_within_crop_area] |
|
target["boxes"][:, 0::2] -= left |
|
target["boxes"][:, 1::2] -= top |
|
target["boxes"][:, 0::2].clamp_(min=0, max=new_w) |
|
target["boxes"][:, 1::2].clamp_(min=0, max=new_h) |
|
image = F.crop(image, top, left, new_h, new_w) |
|
|
|
return image, target |
|
|
|
|
|
class RandomZoomOut(nn.Module): |
|
def __init__( |
|
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 |
|
): |
|
super().__init__() |
|
if fill is None: |
|
fill = [0.0, 0.0, 0.0] |
|
self.fill = fill |
|
self.side_range = side_range |
|
if side_range[0] < 1.0 or side_range[0] > side_range[1]: |
|
raise ValueError(f"Invalid canvas side range provided {side_range}.") |
|
self.p = p |
|
|
|
@torch.jit.unused |
|
def _get_fill_value(self, is_pil): |
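        # Skipped under torch.jit scripting (forward falls back to fill=0); returns a
        # per-channel tuple for PIL images and 0 for tensors.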
|
|
|
|
|
return tuple(int(x) for x in self.fill) if is_pil else 0 |
|
|
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
if isinstance(image, torch.Tensor): |
|
if image.ndimension() not in {2, 3}: |
|
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") |
|
elif image.ndimension() == 2: |
|
image = image.unsqueeze(0) |
|
|
|
        # Apply the zoom-out with probability p; otherwise return the inputs unchanged.
        if torch.rand(1) >= self.p:
|
return image, target |
|
|
|
orig_w, orig_h = F.get_image_size(image) |
|
|
|
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) |
|
canvas_width = int(orig_w * r) |
|
canvas_height = int(orig_h * r) |
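        # Choose a random placement of the original image inside the enlarged canvas.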
|
|
|
r = torch.rand(2) |
|
left = int((canvas_width - orig_w) * r[0]) |
|
top = int((canvas_height - orig_h) * r[1]) |
|
right = canvas_width - (left + orig_w) |
|
bottom = canvas_height - (top + orig_h) |
|
|
|
if torch.jit.is_scripting(): |
|
fill = 0 |
|
else: |
|
fill = self._get_fill_value(F._is_pil_image(image)) |
|
|
|
image = F.pad(image, [left, top, right, bottom], fill=fill) |
|
if isinstance(image, torch.Tensor): |
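            # For tensors F.pad only accepts a scalar fill, so paint the padded borders
            # with the requested per-channel fill colour afterwards.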
|
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) |
|
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[ |
|
..., :, (left + orig_w) : |
|
] = v |
|
|
|
if target is not None: |
|
target["boxes"][:, 0::2] += left |
|
target["boxes"][:, 1::2] += top |
|
|
|
return image, target |
|
|
|
|
|
class RandomPhotometricDistort(nn.Module): |
|
def __init__( |
|
self, |
|
        contrast: Tuple[float, float] = (0.5, 1.5),

        saturation: Tuple[float, float] = (0.5, 1.5),

        hue: Tuple[float, float] = (-0.05, 0.05),

        brightness: Tuple[float, float] = (0.875, 1.125),
|
p: float = 0.5, |
|
): |
|
super().__init__() |
|
self._brightness = T.ColorJitter(brightness=brightness) |
|
self._contrast = T.ColorJitter(contrast=contrast) |
|
self._hue = T.ColorJitter(hue=hue) |
|
self._saturation = T.ColorJitter(saturation=saturation) |
|
self.p = p |
|
|
|
def forward( |
|
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None |
|
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: |
|
if isinstance(image, torch.Tensor): |
|
if image.ndimension() not in {2, 3}: |
|
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") |
|
elif image.ndimension() == 2: |
|
image = image.unsqueeze(0) |
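        # Draw all random decisions up front; each distortion is applied independently with probability p.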
|
|
|
r = torch.rand(7) |
|
|
|
if r[0] < self.p: |
|
image = self._brightness(image) |
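        # Contrast is applied either before or after the saturation/hue jitter, chosen at random.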
|
|
|
contrast_before = r[1] < 0.5 |
|
if contrast_before: |
|
if r[2] < self.p: |
|
image = self._contrast(image) |
|
|
|
if r[3] < self.p: |
|
image = self._saturation(image) |
|
|
|
if r[4] < self.p: |
|
image = self._hue(image) |
|
|
|
if not contrast_before: |
|
if r[5] < self.p: |
|
image = self._contrast(image) |
|
|
|
if r[6] < self.p: |
|
channels = F.get_image_num_channels(image) |
|
permutation = torch.randperm(channels) |
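            # Randomly permute the colour channels; PIL images are converted to tensors first and back afterwards.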
|
|
|
is_pil = F._is_pil_image(image) |
|
if is_pil: |
|
image = F.pil_to_tensor(image) |
|
image = F.convert_image_dtype(image) |
|
image = image[..., permutation, :, :] |
|
if is_pil: |
|
image = F.to_pil_image(image) |
|
|
|
return image, target |
|
|