# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transforms and data augmentation for both image + bbox.
"""

import logging
import random
from typing import Iterable

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as Fv2
from PIL import Image as PILImage
from torchvision.transforms import InterpolationMode

from training.utils.data_utils import VideoDatapoint


def hflip(datapoint, index):
    datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.hflip(obj.segment)
    return datapoint


def get_size_with_aspect_ratio(image_size, size, max_size=None):
    w, h = image_size
    if max_size is not None:
        min_original_size = float(min((w, h)))
        max_original_size = float(max((w, h)))
        if max_original_size / min_original_size * size > max_size:
            size = max_size * min_original_size / max_original_size

    if (w <= h and w == size) or (h <= w and h == size):
        return (h, w)

    if w < h:
        ow = int(round(size))
        oh = int(round(size * h / w))
    else:
        oh = int(round(size))
        ow = int(round(size * w / h))

    return (oh, ow)
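

# Illustrative worked example (comment only, not part of the original module):
# for a 640x480 PIL image (w=640, h=480) with size=512 and max_size=1024, the
# max_size cap is not hit (640 / 480 * 512 = 682.67 <= 1024), the shorter side
# is scaled to 512 and the longer side to round(512 * 640 / 480) = 683, so
#
#     get_size_with_aspect_ratio((640, 480), size=512, max_size=1024)  # -> (512, 683)
#
# i.e. the returned (oh, ow) keeps the original aspect ratio.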


def resize(datapoint, index, size, max_size=None, square=False, v2=False):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    if square:
        size = size, size
    else:
        cur_size = (
            datapoint.frames[index].data.size()[-2:][::-1]
            if v2
            else datapoint.frames[index].data.size
        )
        size = get_size(cur_size, size, max_size)

    old_size = (
        datapoint.frames[index].data.size()[-2:][::-1]
        if v2
        else datapoint.frames[index].data.size
    )
    if v2:
        datapoint.frames[index].data = Fv2.resize(
            datapoint.frames[index].data, size, antialias=True
        )
    else:
        datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)

    new_size = (
        datapoint.frames[index].data.size()[-2:][::-1]
        if v2
        else datapoint.frames[index].data.size
    )

    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            obj.segment = F.resize(obj.segment[None, None], size).squeeze()

    h, w = size
    datapoint.frames[index].size = (h, w)
    return datapoint
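

# Usage sketch (illustrative only, not part of the original module): with
# square=True the frame is resized to an exact size x size square; otherwise
# `size` is treated as the shorter-side target, optionally capped by max_size
# via get_size_with_aspect_ratio above. Assuming `dp` is a VideoDatapoint
# whose frames hold PIL images:
#
#     dp = resize(dp, index=0, size=512, max_size=1024)   # keep aspect ratio
#     dp = resize(dp, index=0, size=1024, square=True)    # force 1024x1024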


def pad(datapoint, index, padding, v2=False):
    old_h, old_w = datapoint.frames[index].size
    h, w = old_h, old_w
    if len(padding) == 2:
        # assumes that we only pad on the right and bottom sides
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data, (0, 0, padding[0], padding[1])
        )
        h += padding[1]
        w += padding[0]
    else:
        # left, top, right, bottom
        datapoint.frames[index].data = F.pad(
            datapoint.frames[index].data,
            (padding[0], padding[1], padding[2], padding[3]),
        )
        h += padding[1] + padding[3]
        w += padding[0] + padding[2]

    datapoint.frames[index].size = (h, w)

    for obj in datapoint.frames[index].objects:
        if obj.segment is not None:
            if v2:
                if len(padding) == 2:
                    obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = Fv2.pad(obj.segment, tuple(padding))
            else:
                if len(padding) == 2:
                    obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
                else:
                    obj.segment = F.pad(obj.segment, tuple(padding))
    return datapoint
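

# Padding convention sketch (illustrative only): a 2-element padding pads the
# right and bottom edges, while a 4-element padding follows torchvision's
# (left, top, right, bottom) order. Assuming `dp` is a VideoDatapoint:
#
#     dp = pad(dp, index=0, padding=(7, 5))        # +7 px right, +5 px bottom
#     dp = pad(dp, index=0, padding=(2, 3, 7, 5))  # left=2, top=3, right=7, bottom=5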


class RandomHorizontalFlip:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for i in range(len(datapoint.frames)):
                    datapoint = hflip(datapoint, i)
            return datapoint
        for i in range(len(datapoint.frames)):
            if random.random() < self.p:
                datapoint = hflip(datapoint, i)
        return datapoint


class RandomResizeAPI:
    def __init__(
        self, sizes, consistent_transform, max_size=None, square=False, v2=False
    ):
        if isinstance(sizes, int):
            sizes = (sizes,)
        assert isinstance(sizes, Iterable)
        self.sizes = list(sizes)
        self.max_size = max_size
        self.square = square
        self.consistent_transform = consistent_transform
        self.v2 = v2

    def __call__(self, datapoint, **kwargs):
        if self.consistent_transform:
            size = random.choice(self.sizes)
            for i in range(len(datapoint.frames)):
                datapoint = resize(
                    datapoint, i, size, self.max_size, square=self.square, v2=self.v2
                )
            return datapoint
        for i in range(len(datapoint.frames)):
            size = random.choice(self.sizes)
            datapoint = resize(
                datapoint, i, size, self.max_size, square=self.square, v2=self.v2
            )
        return datapoint


class ToTensorAPI:
    def __init__(self, v2=False):
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                img.data = Fv2.to_image_tensor(img.data)
            else:
                img.data = F.to_tensor(img.data)
        return datapoint


class NormalizeAPI:
    def __init__(self, mean, std, v2=False):
        self.mean = mean
        self.std = std
        self.v2 = v2

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for img in datapoint.frames:
            if self.v2:
                img.data = Fv2.convert_image_dtype(img.data, torch.float32)
                img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
            else:
                img.data = F.normalize(img.data, mean=self.mean, std=self.std)
        return datapoint


class ComposeAPI:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, datapoint, **kwargs):
        for t in self.transforms:
            datapoint = t(datapoint, **kwargs)
        return datapoint

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string
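

# Hedged usage sketch (illustrative only; the size/mean/std values below are
# assumptions, not values prescribed by this file): ComposeAPI chains the
# datapoint-level transforms defined in this module and forwards **kwargs to
# each of them.
#
#     train_transforms = ComposeAPI(
#         transforms=[
#             RandomHorizontalFlip(consistent_transform=True, p=0.5),
#             RandomResizeAPI(sizes=[512, 640, 768], consistent_transform=True, max_size=1024),
#             ToTensorAPI(),
#             NormalizeAPI(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#         ]
#     )
#     datapoint = train_transforms(datapoint)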


class RandomGrayscale:
    def __init__(self, consistent_transform, p=0.5):
        self.p = p
        self.consistent_transform = consistent_transform
        self.Grayscale = T.Grayscale(num_output_channels=3)

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            if random.random() < self.p:
                for img in datapoint.frames:
                    img.data = self.Grayscale(img.data)
            return datapoint
        for img in datapoint.frames:
            if random.random() < self.p:
                img.data = self.Grayscale(img.data)
        return datapoint


class ColorJitter:
    def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
        self.consistent_transform = consistent_transform
        self.brightness = (
            brightness
            if isinstance(brightness, list)
            else [max(0, 1 - brightness), 1 + brightness]
        )
        self.contrast = (
            contrast
            if isinstance(contrast, list)
            else [max(0, 1 - contrast), 1 + contrast]
        )
        self.saturation = (
            saturation
            if isinstance(saturation, list)
            else [max(0, 1 - saturation), 1 + saturation]
        )
        self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        if self.consistent_transform:
            # Sample one set of color jitter params shared across all frames
            (
                fn_idx,
                brightness_factor,
                contrast_factor,
                saturation_factor,
                hue_factor,
            ) = T.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        for img in datapoint.frames:
            if not self.consistent_transform:
                (
                    fn_idx,
                    brightness_factor,
                    contrast_factor,
                    saturation_factor,
                    hue_factor,
                ) = T.ColorJitter.get_params(
                    self.brightness, self.contrast, self.saturation, self.hue
                )
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    img.data = F.adjust_brightness(img.data, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    img.data = F.adjust_contrast(img.data, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    img.data = F.adjust_saturation(img.data, saturation_factor)
                elif fn_id == 3 and hue_factor is not None:
                    img.data = F.adjust_hue(img.data, hue_factor)
        return datapoint
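

# Illustrative parameters (assumptions, not defaults from this file): scalar
# arguments are expanded to symmetric ranges, e.g. brightness=0.2 becomes
# [0.8, 1.2]. With consistent_transform=True the same jitter factors are
# sampled once per video and applied to every frame.
#
#     jitter = ColorJitter(
#         consistent_transform=True, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
#     )
#     datapoint = jitter(datapoint)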


class RandomAffine:
    def __init__(
        self,
        degrees,
        consistent_transform,
        scale=None,
        translate=None,
        shear=None,
        image_mean=(123, 116, 103),
        log_warning=True,
        num_tentatives=1,
        image_interpolation="bicubic",
    ):
        """
        The mask is required for this transform.
        If consistent_transform is True, the same random affine is applied to all frames and masks.
        """
        self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
        self.scale = scale
        self.shear = (
            shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
        )
        self.translate = translate
        self.fill_img = image_mean
        self.consistent_transform = consistent_transform
        self.log_warning = log_warning
        self.num_tentatives = num_tentatives
        if image_interpolation == "bicubic":
            self.image_interpolation = InterpolationMode.BICUBIC
        elif image_interpolation == "bilinear":
            self.image_interpolation = InterpolationMode.BILINEAR
        else:
            raise NotImplementedError

    def __call__(self, datapoint: VideoDatapoint, **kwargs):
        for _tentative in range(self.num_tentatives):
            res = self.transform_datapoint(datapoint)
            if res is not None:
                return res

        if self.log_warning:
            logging.warning(
                f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
            )
        return datapoint

    def transform_datapoint(self, datapoint: VideoDatapoint):
        _, height, width = F.get_dimensions(datapoint.frames[0].data)
        img_size = [width, height]

        if self.consistent_transform:
            # Create a random affine transformation
            affine_params = T.RandomAffine.get_params(
                degrees=self.degrees,
                translate=self.translate,
                scale_ranges=self.scale,
                shears=self.shear,
                img_size=img_size,
            )

        for img_idx, img in enumerate(datapoint.frames):
            this_masks = [
                obj.segment.unsqueeze(0) if obj.segment is not None else None
                for obj in img.objects
            ]
            if not self.consistent_transform:
                # If not consistent, create new random affine params for every frame & mask pair
                affine_params = T.RandomAffine.get_params(
                    degrees=self.degrees,
                    translate=self.translate,
                    scale_ranges=self.scale,
                    shears=self.shear,
                    img_size=img_size,
                )

            transformed_bboxes, transformed_masks = [], []
            for i in range(len(img.objects)):
                if this_masks[i] is None:
                    transformed_masks.append(None)
                    # Dummy bbox for a dummy target
                    transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
                else:
                    transformed_mask = F.affine(
                        this_masks[i],
                        *affine_params,
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                    )
                    if img_idx == 0 and transformed_mask.max() == 0:
                        # We are dealing with a video and the object is not visible in the first frame
                        # Return the datapoint without transformation
                        return None
                    transformed_masks.append(transformed_mask.squeeze())

            for i in range(len(img.objects)):
                img.objects[i].segment = transformed_masks[i]

            img.data = F.affine(
                img.data,
                *affine_params,
                interpolation=self.image_interpolation,
                fill=self.fill_img,
            )
        return datapoint
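

# Illustrative construction (parameter values are assumptions): with
# consistent_transform=True a single affine (rotation/translation/scale/shear)
# is sampled per video and applied to all frames and masks; the transform is
# retried up to num_tentatives times if the first-frame mask becomes empty.
#
#     affine = RandomAffine(
#         degrees=20,
#         consistent_transform=True,
#         scale=(0.8, 1.2),
#         translate=(0.1, 0.1),
#         shear=10,
#         num_tentatives=4,
#     )
#     datapoint = affine(datapoint)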


def random_mosaic_frame(
    datapoint,
    index,
    grid_h,
    grid_w,
    target_grid_y,
    target_grid_x,
    should_hflip,
):
    # Step 1: downsize the images and paste them into a mosaic
    image_data = datapoint.frames[index].data
    is_pil = isinstance(image_data, PILImage.Image)
    if is_pil:
        H_im = image_data.height
        W_im = image_data.width
        image_data_output = PILImage.new("RGB", (W_im, H_im))
    else:
        H_im = image_data.size(-2)
        W_im = image_data.size(-1)
        image_data_output = torch.zeros_like(image_data)

    downsize_cache = {}
    for grid_y in range(grid_h):
        for grid_x in range(grid_w):
            y_offset_b = grid_y * H_im // grid_h
            x_offset_b = grid_x * W_im // grid_w
            y_offset_e = (grid_y + 1) * H_im // grid_h
            x_offset_e = (grid_x + 1) * W_im // grid_w
            H_im_downsize = y_offset_e - y_offset_b
            W_im_downsize = x_offset_e - x_offset_b

            if (H_im_downsize, W_im_downsize) in downsize_cache:
                image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
            else:
                image_data_downsize = F.resize(
                    image_data,
                    size=(H_im_downsize, W_im_downsize),
                    interpolation=InterpolationMode.BILINEAR,
                    antialias=True,  # antialiasing for downsizing
                )
                downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
            if should_hflip[grid_y, grid_x].item():
                image_data_downsize = F.hflip(image_data_downsize)

            if is_pil:
                image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
            else:
                image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
                    image_data_downsize
                )

    datapoint.frames[index].data = image_data_output

    # Step 2: downsize the masks and paste them into the target grid of the mosaic
    for obj in datapoint.frames[index].objects:
        if obj.segment is None:
            continue
        assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
        segment_output = torch.zeros_like(obj.segment)

        target_y_offset_b = target_grid_y * H_im // grid_h
        target_x_offset_b = target_grid_x * W_im // grid_w
        target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
        target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
        target_H_im_downsize = target_y_offset_e - target_y_offset_b
        target_W_im_downsize = target_x_offset_e - target_x_offset_b

        segment_downsize = F.resize(
            obj.segment[None, None],
            size=(target_H_im_downsize, target_W_im_downsize),
            interpolation=InterpolationMode.BILINEAR,
            antialias=True,  # antialiasing for downsizing
        )[0, 0]
        if should_hflip[target_grid_y, target_grid_x].item():
            segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]

        segment_output[
            target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
        ] = segment_downsize
        obj.segment = segment_output

    return datapoint


class RandomMosaicVideoAPI:
    def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
        self.prob = prob
        self.grid_h = grid_h
        self.grid_w = grid_w
        self.use_random_hflip = use_random_hflip

    def __call__(self, datapoint, **kwargs):
        if random.random() > self.prob:
            return datapoint

        # select a random location to place the target mask in the mosaic
        target_grid_y = random.randint(0, self.grid_h - 1)
        target_grid_x = random.randint(0, self.grid_w - 1)
        # whether to flip each grid in the mosaic horizontally
        if self.use_random_hflip:
            should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
        else:
            should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
        for i in range(len(datapoint.frames)):
            datapoint = random_mosaic_frame(
                datapoint,
                i,
                grid_h=self.grid_h,
                grid_w=self.grid_w,
                target_grid_y=target_grid_y,
                target_grid_x=target_grid_x,
                should_hflip=should_hflip,
            )
        return datapoint
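

# Illustrative usage (parameter values are assumptions): with probability
# `prob`, every frame is replaced by a grid_h x grid_w mosaic of downsized
# copies of itself, and the object masks are kept only in one randomly chosen
# target cell (each cell optionally h-flipped when use_random_hflip is True).
#
#     mosaic = RandomMosaicVideoAPI(prob=0.1, grid_h=2, grid_w=2, use_random_hflip=True)
#     datapoint = mosaic(datapoint)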