Spaces:
Sleeping
Sleeping
| from __future__ import division | |
| import collections | |
| import numbers | |
| import random | |
| import torch | |
| from PIL import Image | |
| from skimage import color | |
| import src.data.functional as F | |
# Names exported by ``from <this module> import *``.
# NOTE(review): "Scale", "Pad", "RandomVerticalFlip", "RandomResizedCrop",
# "RandomSizedCrop", "FiveCrop" and "TenCrop" have no definition in the
# visible part of this file. If they are not defined further down, a star
# import will fail with AttributeError -- confirm and prune if needed.
__all__ = [
    "Compose",
    "Concatenate",
    "ToTensor",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "RGB2Lab",
]
def CustomFunc(inputs, func, *args, **kwargs):
    """Apply ``func`` to every element of an 11-slot sample.

    Slots 0-5 of ``inputs`` are single items and are passed to ``func``
    directly. Slots 6-10 are each indexed at ``[0]`` and ``[1]`` and
    ``func`` is applied to both members of the pair.

    Args:
        inputs: Indexable holding at least 11 entries; entries beyond
            index 10 are ignored.
        func: Callable invoked as ``func(item, *args, **kwargs)``.
        *args: Extra positional arguments forwarded to ``func``.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Returns:
        list: Six transformed items followed by five two-element lists,
        in the same order as ``inputs``.
    """

    def apply(item):
        return func(item, *args, **kwargs)

    # Slots 0-5: the single (non-paired) items.
    transformed = [apply(inputs[slot]) for slot in range(6)]

    # Slots 6-10: one pair per level.
    for slot in range(6, 11):
        pair = inputs[slot]
        transformed.append([apply(pair[0]), apply(pair[1])])
    return transformed
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply,
            in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, inputs):
        """Feed ``inputs`` through each transform in turn and return the result."""
        result = inputs
        for transform in self.transforms:
            result = transform(result)
        return result
class Concatenate(object):
    """Flatten an 11-slot sample into the final 8-tuple.

    Input (indexable, at least 11 entries):
        [im_l, im_ab, warp_ba, warp_aba, im_glb_ab, bgr_mc_im,
         (err_ba, err_ab) x 5]

    Output:
        (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab)

    ``cerr_ba`` / ``cerr_ab`` are the five per-level error tensors
    concatenated along dimension 0. ``bgr_mc_im`` is returned with its
    first-axis entries reordered as [2, 1, 0].
    """

    def __call__(self, inputs):
        im_l, im_ab, warp_ba, warp_aba, im_glb_ab = (inputs[k] for k in range(5))

        # Reverse the order of the three leading-axis entries
        # (presumably an RGB <-> BGR channel swap -- confirm against caller).
        flipped_im = inputs[5][[2, 1, 0], ...]

        pairs = [inputs[6 + level] for level in range(5)]
        cerr_ba = torch.cat([pair[0] for pair in pairs], 0)
        cerr_ab = torch.cat([pair[1] for pair in pairs], 0)

        return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, flipped_im, im_ab, cerr_ab)
class ToTensor(object):
    """Convert every image of a sample to a tensor.

    The conversion itself is done by ``F.to_mytensor``; the exact layout,
    dtype and value scaling are defined there, not in this class.
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample as consumed by ``CustomFunc``.

        Returns:
            list: The sample with ``F.to_mytensor`` applied to every image.
        """
        converter = F.to_mytensor
        return CustomFunc(inputs, converter)
class Normalize(object):
    """Normalize every tensor of a sample with fixed, hard-coded statistics.

    This transform takes no constructor arguments. Each slot gets its own
    ``(mean, std)`` pair, passed to ``F.normalize``:

    * slot 0 (``im_l``): mean 50, std 1
    * slots 1 and 4 (``im_ab``, ``im_gbl_ab``): mean (0, 0), std (1, 1)
    * slots 2 and 3 (``warp_ba``, ``warp_aba``): entry 0 as for slot 0,
      entries 1-2 as for slot 1, written back into the input tensor
    * slot 5 (``bgr_mc_im``): mean (123.68, 116.78, 103.938), std (1, 1, 1)
    * slots 6-10 (error pairs): mean 127, std 2
    """

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample of tensors. Slots 2 and 3 are modified
                in place.

        Returns:
            list: Normalized sample in the same layout as ``inputs``.
        """
        ab_mean, ab_std = (0, 0), (1, 1)

        im_l = F.normalize(inputs[0], 50, 1)  # [0, 100]
        im_ab = F.normalize(inputs[1], ab_mean, ab_std)  # [-100, 100]

        # The warped images carry L and ab together; normalize each part
        # with its own statistics and write it back into the same tensor.
        for slot in (2, 3):
            lab = inputs[slot]
            lab[0:1, :, :] = F.normalize(lab[0:1, :, :], 50, 1)
            lab[1:3, :, :] = F.normalize(lab[1:3, :, :], ab_mean, ab_std)

        im_gbl_ab = F.normalize(inputs[4], ab_mean, ab_std)  # [-100, 100]
        bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))

        normalized = [im_l, im_ab, inputs[2], inputs[3], im_gbl_ab, bgr_mc_im]
        for slot in range(6, 11):
            pair = inputs[slot]
            err_ba = F.normalize(pair[0], 127, 2)  # [0, 255]
            err_ab = F.normalize(pair[1], 127, 2)  # [0, 255]
            normalized.append([err_ba, err_ab])
        return normalized
class Resize(object):
    """Resize every image of a sample to the given size.

    Args:
        size (sequence or int): Desired output size, either an int or a
            sequence of length 2 such as (h, w). How the value is
            interpreted is decided by ``F.resize``, which receives it
            unchanged.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC now
        # lives only in ``collections.abc``. Imported locally so this class
        # does not depend on a change to the file's import block.
        try:
            from collections.abc import Iterable
        except ImportError:  # very old interpreters without collections.abc
            from collections import Iterable

        assert isinstance(size, int) or (
            isinstance(size, Iterable) and len(size) == 2
        ), "size must be an int or a sequence of length 2"
        self.size = size
        self.interpolation = interpolation

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample as consumed by ``CustomFunc``.

        Returns:
            list: The sample with ``F.resize`` applied to every image.
        """
        return CustomFunc(inputs, F.resize, self.size, self.interpolation)
class RandomCrop(object):
    """Crop every image of a sample at one shared random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    # Declared static: the method takes no ``self`` yet is invoked through
    # the instance in ``__call__``; without the decorator that call raises
    # TypeError because ``self`` is passed as an extra argument.
    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img: Image exposing ``size`` as ``(width, height)``.
            output_size (tuple): Expected output size ``(h, w)`` of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample as consumed by ``CustomFunc``.

        Returns:
            list: The sample with the same crop applied to every image.
        """
        padding = self.padding
        # ``padding`` may be a number or a sequence; comparing a sequence
        # with 0 raises TypeError on Python 3, so handle both forms.
        if isinstance(padding, numbers.Number):
            needs_padding = padding > 0
        else:
            needs_padding = any(side > 0 for side in padding)
        if needs_padding:
            inputs = CustomFunc(inputs, F.pad, padding)

        # Crop parameters come from the first image and are reused for the
        # whole sample (assumes all images share its size -- TODO confirm).
        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)
class CenterCrop(object):
    """Crop every image of a sample at its center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    # Declared static: the method takes no ``self`` yet is invoked through
    # the instance in ``__call__``; without the decorator that call raises
    # TypeError because ``self`` is passed as an extra argument.
    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a centered crop.

        Args:
            img: Image exposing ``size`` as ``(width, height)``.
            output_size (tuple): Expected output size ``(h, w)`` of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop``; ``(i, j)``
            is the top-left corner that centers the crop window.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = (h - th) // 2
        j = (w - tw) // 2
        return i, j, th, tw

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample as consumed by ``CustomFunc``.

        Returns:
            list: The sample with the same center crop applied to every image.
        """
        padding = self.padding
        # ``padding`` may be a number or a sequence; comparing a sequence
        # with 0 raises TypeError on Python 3, so handle both forms.
        if isinstance(padding, numbers.Number):
            needs_padding = padding > 0
        else:
            needs_padding = any(side > 0 for side in padding)
        if needs_padding:
            inputs = CustomFunc(inputs, F.pad, padding)

        # Crop parameters come from the first image and are reused for the
        # whole sample (assumes all images share its size -- TODO confirm).
        i, j, h, w = self.get_params(inputs[0], self.size)
        return CustomFunc(inputs, F.crop, i, j, h, w)
class RandomHorizontalFlip(object):
    """Horizontally flip all images of a sample together with probability 0.5."""

    def __call__(self, inputs):
        """
        Args:
            inputs: 11-slot sample as consumed by ``CustomFunc``.

        Returns:
            Either ``inputs`` itself (no flip) or a new list with
            ``F.hflip`` applied to every image.
        """
        # One draw decides for the whole sample so all images stay aligned.
        if random.random() >= 0.5:
            return inputs
        return CustomFunc(inputs, F.hflip)
class RGB2Lab(object):
    """Convert the RGB images of a sample to Lab and split out channels."""

    def __call__(self, inputs):
        """Convert slots 0, 2, 3 and 4 with ``skimage.color.rgb2lab``.

        Args:
            inputs: Mutable sequence whose slots 0, 2, 3 and 4 hold images
                accepted by ``color.rgb2lab`` (assumed H x W x 3 RGB --
                TODO confirm). It is modified in place.

        Returns:
            The same ``inputs`` object, where

            * slot 0 holds the first Lab channel of the original slot 0,
            * slot 1 holds the remaining channels of the original slot 0
              (any previous content of slot 1 is discarded),
            * slots 2 and 3 hold full Lab images,
            * slot 4 holds only the channels after the first,
            * all other slots are untouched.
        """
        image_lab = color.rgb2lab(inputs[0])
        warp_ba_lab = color.rgb2lab(inputs[2])
        warp_aba_lab = color.rgb2lab(inputs[3])
        im_gbl_lab = color.rgb2lab(inputs[4])

        inputs[0] = image_lab[:, :, :1]  # l channel
        inputs[1] = image_lab[:, :, 1:]  # ab channel
        inputs[2] = warp_ba_lab  # lab channel
        inputs[3] = warp_aba_lab  # lab channel
        inputs[4] = im_gbl_lab[:, :, 1:]  # ab channel
        return inputs