Spaces:
Sleeping
Sleeping
from __future__ import division | |
import collections | |
import numbers | |
import random | |
import torch | |
from PIL import Image | |
from skimage import color | |
import src.data.functional as F | |
# Public names exported by ``from <this module> import *``.
# NOTE(review): "Scale", "Pad", "RandomVerticalFlip", "RandomResizedCrop",
# "RandomSizedCrop", "FiveCrop" and "TenCrop" have no visible definition in
# this module. If they are truly absent, a star-import of this module raises
# AttributeError. Confirm and prune the list if so.
__all__ = [
    "Compose",
    "Concatenate",
    "ToTensor",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "RGB2Lab",
]
def CustomFunc(inputs, func, *args, **kwargs):
    """Apply ``func`` to every image of a training sample.

    The sample layout is fixed. The first six entries of ``inputs`` are single
    items, and entries 6..10 are pairs whose two elements are each transformed.
    Extra positional/keyword arguments are forwarded to every ``func`` call.

    Args:
        inputs: indexable sample; only positions 0-10 are read.
        func: callable invoked as ``func(item, *args, **kwargs)``.

    Returns:
        list: six transformed items followed by five two-element lists.
    """
    def _apply(item):
        return func(item, *args, **kwargs)

    # Positions 0-5: im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im.
    transformed = [_apply(inputs[idx]) for idx in range(6)]
    # Positions 6-10: per-layer (err_ba, err_ab) pairs.
    for offset in range(5):
        pair = inputs[6 + offset]
        transformed.append([_apply(pair[0]), _apply(pair[1])])
    return transformed
class Compose(object):
    """Chain several transforms so they run one after another.

    Each transform receives the output of the previous one, and the result
    of the last transform is returned. An empty list returns the input
    unchanged.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, inputs):
        result = inputs
        for transform in self.transforms:
            result = transform(result)
        return result
class Concatenate(object):
    """Merge a transformed sample into the tuple fed to the network.

    Input (indexable, positions 0-10):
        0 im_l, 1 im_ab, 2 warp_ba, 3 warp_aba, 4 im_glb_ab, 5 bgr_mc_im,
        6-10 per-layer pairs ``(err_ba, err_ab)``.

    Output (8-tuple):
        (im_l, cat(err_ba), warp_ba, warp_aba, im_glb_ab, mc_im, im_ab, cat(err_ab))

    ``mc_im`` is ``bgr_mc_im`` indexed with channels (2, 1, 0) along dim 0.
    The five ``err_ba`` tensors and the five ``err_ab`` tensors are each
    concatenated along dim 0. All other items are passed through as-is.
    """

    def __call__(self, inputs):
        im_l, im_ab, warp_ba, warp_aba, im_glb_ab, bgr_mc_im = (inputs[k] for k in range(6))
        # Re-order the channels as (2, 1, 0), i.e. reverse a 3-channel image.
        mc_im = bgr_mc_im[[2, 1, 0], ...]

        layers = [inputs[6 + k] for k in range(5)]
        cerr_ba = torch.cat([layer[0] for layer in layers], 0)
        cerr_ab = torch.cat([layer[1] for layer in layers], 0)

        return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, mc_im, im_ab, cerr_ab)
class ToTensor(object):
    """Convert every image of a sample with ``F.to_mytensor``.

    NOTE(review): the conversion itself (dtype, channel order, value scaling)
    is defined by ``F.to_mytensor``, which is not visible here. torchvision's
    ``ToTensor`` rescales [0, 255] to [0.0, 1.0]; whether this helper does the
    same is unconfirmed.
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``.
        Returns:
            list: the sample with each image converted by ``F.to_mytensor``.
        """
        to_tensor = F.to_mytensor
        return CustomFunc(inputs, to_tensor)
class Normalize(object):
    """Normalize every image of a sample with fixed, per-position constants.

    ``F.normalize(x, mean, std)`` is applied with constants hard-coded per
    position; there are no constructor arguments.

    * position 0 (L):          mean 50, std 1
    * position 1 (ab):         mean (0, 0), std (1, 1)
    * positions 2 and 3 (Lab): channel 0 like L, channels 1-2 like ab; these
      two tensors are updated in place through slice assignment
    * position 4 (ab):         mean (0, 0), std (1, 1)
    * position 5:              mean (123.68, 116.78, 103.938), std (1, 1, 1)
    * positions 6-10 (pairs):  mean 127, std 2 for both elements
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``;
                positions 2 and 3 must support slice assignment.
        Returns:
            list: six normalized items followed by five two-element lists.
        """
        l_mean, l_std = 50, 1
        ab_mean, ab_std = (0, 0), (1, 1)
        err_mean, err_std = 127, 2

        im_l = F.normalize(inputs[0], l_mean, l_std)  # L input presumably in [0, 100]
        im_ab = F.normalize(inputs[1], ab_mean, ab_std)  # ab input presumably in [-100, 100]

        # Lab warps: normalize the L and ab channel groups in place.
        for idx in (2, 3):
            lab = inputs[idx]
            lab[0:1, :, :] = F.normalize(lab[0:1, :, :], l_mean, l_std)
            lab[1:3, :, :] = F.normalize(lab[1:3, :, :], ab_mean, ab_std)
        warp_ba, warp_aba = inputs[2], inputs[3]

        im_gbl_ab = F.normalize(inputs[4], ab_mean, ab_std)
        bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))

        normalized = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
        # Error maps presumably in [0, 255].
        for offset in range(5):
            pair = inputs[6 + offset]
            normalized.append([
                F.normalize(pair[0], err_mean, err_std),
                F.normalize(pair[1], err_mean, err_std),
            ])
        return normalized
class Resize(object):
    """Resize every image of a sample to the given size.

    Args:
        size (sequence or int): Desired output size; either an int or a
            length-2 sequence. It is forwarded unchanged to ``F.resize``.
            NOTE(review): torchvision semantics are (h, w) for a sequence and
            "match the smaller edge" for an int; confirm that ``F.resize``
            follows them.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Raises:
        AssertionError: if ``size`` is neither an int nor a length-2 iterable.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc`` (available since Python 3.3).
        from collections.abc import Iterable

        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``.
        Returns:
            list: the sample with each image passed through
            ``F.resize(img, size, interpolation)``.
        """
        return CustomFunc(inputs, F.resize, self.size, self.interpolation)
class RandomCrop(object):
    """Crop every image of a sample at the same random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image, forwarded to ``F.pad`` before cropping. Default is 0,
            i.e. no padding. A sequence is passed through as-is (torchvision
            convention: length 4 pads left, top, right, bottom); how it is
            interpreted is up to ``F.pad``.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def _needs_padding(padding):
        """Return True if ``padding`` (a number or a sequence) requests a border."""
        if isinstance(padding, numbers.Number):
            return padding > 0
        # A plain ``padding > 0`` raises TypeError for sequences in Python 3.
        return any(p > 0 for p in padding)

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        A ``@staticmethod`` so that ``self.get_params(img, size)`` works; without
        it the instance call passed ``self`` as ``img`` and raised TypeError.

        Args:
            img: image exposing ``size`` as ``(width, height)`` (e.g. a PIL Image).
            output_size (tuple): Expected output size ``(th, tw)`` of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: if the image is smaller than ``output_size``.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        if h < th or w < tw:
            raise ValueError(
                "Crop size {} is larger than input image size {}".format((th, tw), (h, w))
            )
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``.
        Returns:
            list: the (optionally padded) sample with every image cropped.
        """
        if self._needs_padding(self.padding):
            inputs = CustomFunc(inputs, F.pad, self.padding)
        # One offset, computed from the first image, is shared by all images
        # so they stay aligned.
        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)
class CenterCrop(object):
    """Crop every image of a sample at its center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image, forwarded to ``F.pad`` before cropping. Default is 0,
            i.e. no padding. A sequence is passed through as-is (torchvision
            convention: length 4 pads left, top, right, bottom); how it is
            interpreted is up to ``F.pad``.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def _needs_padding(padding):
        """Return True if ``padding`` (a number or a sequence) requests a border."""
        if isinstance(padding, numbers.Number):
            return padding > 0
        # A plain ``padding > 0`` raises TypeError for sequences in Python 3.
        return any(p > 0 for p in padding)

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a center crop.

        A ``@staticmethod`` so that ``self.get_params(img, size)`` works; without
        it the instance call passed ``self`` as ``img`` and raised TypeError.

        Args:
            img: image exposing ``size`` as ``(width, height)`` (e.g. a PIL Image).
            output_size (tuple): Expected output size ``(th, tw)`` of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a center crop.
            NOTE(review): if the image is smaller than the crop, ``i``/``j`` are
            negative and are passed to ``F.crop`` unchanged.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = (h - th) // 2
        j = (w - tw) // 2
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``.
        Returns:
            list: the (optionally padded) sample with every image center-cropped.
        """
        if self._needs_padding(self.padding):
            inputs = CustomFunc(inputs, F.pad, self.padding)
        # Offsets from the first image are applied to all images.
        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)
class RandomHorizontalFlip(object):
    """Flip every image of a sample horizontally with probability 0.5.

    A single ``random.random()`` draw decides for the whole sample, so the
    images are either all flipped (via ``F.hflip``) or all left as-is.
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: sample in the layout expected by ``CustomFunc``.
        Returns:
            The flipped sample produced by ``CustomFunc``, or ``inputs``
            itself when no flip happens.
        """
        keep_original = random.random() >= 0.5
        if keep_original:
            return inputs
        return CustomFunc(inputs, F.hflip)
class RGB2Lab(object):
    """Convert the colour images of a sample from RGB to CIE-Lab.

    Applies ``skimage.color.rgb2lab`` to positions 0, 2, 3 and 4 and rewrites
    the sample in place:

    * ``inputs[0]`` <- L channel of ``inputs[0]`` (last axis kept, size 1)
    * ``inputs[1]`` <- ab channels of ``inputs[0]`` (the old value is discarded)
    * ``inputs[2]`` <- full Lab of ``inputs[2]``
    * ``inputs[3]`` <- full Lab of ``inputs[3]``
    * ``inputs[4]`` <- ab channels of ``inputs[4]``

    Positions 5 and up are left untouched. The ``[:, :, k]`` slicing implies
    3-D, channel-last image arrays.
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: mutable sequence (item assignment is used), e.g. a list.
        Returns:
            The same ``inputs`` object after conversion.
        """
        # The conversion previously sat in a nested ``__call__`` that was
        # defined but never invoked, so this transform returned ``None``.
        # All conversions run before any assignment, so a failure in
        # ``rgb2lab`` leaves ``inputs`` unmodified.
        image_lab = color.rgb2lab(inputs[0])
        warp_ba_lab = color.rgb2lab(inputs[2])
        warp_aba_lab = color.rgb2lab(inputs[3])
        im_gbl_lab = color.rgb2lab(inputs[4])
        inputs[0] = image_lab[:, :, :1]  # L channel
        inputs[1] = image_lab[:, :, 1:]  # ab channels
        inputs[2] = warp_ba_lab  # Lab channels
        inputs[3] = warp_aba_lab  # Lab channels
        inputs[4] = im_gbl_lab[:, :, 1:]  # ab channels
        return inputs