|
import os |
|
import random |
|
import time |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
import cv2 |
|
|
|
from torchvision import transforms, datasets |
|
from torchvision.transforms import functional as F |
|
|
|
def _getvocpallete(num_colors):
    """Build the PASCAL VOC colour palette as a flat ``[r, g, b, ...]`` list.

    The previous implementation returned ``[0, 0, 0] * num_colors``, which
    maps every class index to black and makes a colourised mask carry no
    information.  This generates the standard VOC colormap instead: the bits
    of the class index are distributed over the R, G and B channels starting
    from the most significant bit, so index 0 is black, 1 is (128, 0, 0),
    2 is (0, 128, 0), and so on.

    :param num_colors: Number of palette entries to generate.
    :return: List of ``3 * num_colors`` ints in ``[0, 255]``, suitable for
        ``PIL.Image.Image.putpalette``.  Empty when ``num_colors <= 0``.
    """
    pallete = [0] * (num_colors * 3)
    for class_id in range(num_colors):
        label = class_id
        # Each pass consumes three bits of the label (one per channel) and
        # writes them one position lower than the previous pass.
        shift = 7
        while label > 0:
            pallete[class_id * 3 + 0] |= ((label >> 0) & 1) << shift
            pallete[class_id * 3 + 1] |= ((label >> 1) & 1) << shift
            pallete[class_id * 3 + 2] |= ((label >> 2) & 1) << shift
            shift -= 1
            label >>= 3
    return pallete
|
|
|
|
|
|
|
|
|
|
|
class Rotate:
    """Rotate an image by an angle that is fixed at construction time.

    The angle is drawn once in ``__init__``, so every call on the same
    instance applies the identical rotation (useful for keeping an image and
    its mask aligned).
    """

    def __init__(self, angle):
        """Draw a whole-degree angle uniformly from ``[-angle, angle]``.

        :param angle: Non-negative integer bound, in degrees.
        """
        lower, upper = -angle, angle
        self.angle = random.randint(lower, upper)

    def __call__(self, img):
        """Return ``img`` rotated by the stored angle via ``F.rotate``."""
        rotated = F.rotate(img, angle=self.angle)
        return rotated
|
|
|
class Shear:
    """Shear (and optionally rescale) an image with parameters fixed per instance.

    Both the shear angle and the scale factor are sampled once in
    ``__init__``; repeated calls on the same instance therefore apply the
    same affine transform.
    """

    def __init__(self, shear=10, scale=(1.0, 1.0)):
        """Sample the shear angle and the scale factor.

        :param shear: Bound; the shear angle is uniform in ``[-shear, shear]``.
        :param scale: ``(low, high)`` range the scale factor is drawn from.
        """
        shear_limit = shear
        # Draw order matters for reproducibility: shear first, then scale.
        self.shear = random.uniform(-shear_limit, shear_limit)
        scale_low = scale[0]
        scale_high = scale[1]
        self.scale = random.uniform(scale_low, scale_high)

    def __call__(self, img):
        """Return ``img`` after ``F.affine`` with the stored shear and scale."""
        # The same angle is used for both shear axes.
        shear_xy = [self.shear, self.shear]
        return F.affine(
            img,
            angle=0,
            translate=(0, 0),
            scale=self.scale,
            shear=shear_xy,
        )
|
|
|
class Skew:
    """Shift an image by a fraction of its size, fixed per instance.

    Note: the affine coefficients used are ``(1, 0, dx, 0, 1, dy)``, i.e. an
    identity linear part plus an offset, so despite the class name the effect
    is a translation rather than a shear.
    """

    def __init__(self, magnitude=0.2):
        """Sample horizontal and vertical shift fractions.

        :param magnitude: Bound; each fraction is uniform in
            ``[-magnitude, magnitude]``.
        """
        bound = magnitude
        # Draw order: x first, then y.
        self.xshift = random.uniform(-bound, bound)
        self.yshift = random.uniform(-bound, bound)

    def __call__(self, img):
        """Return ``img`` transformed with the stored offsets (in pixels)."""
        size = img.size
        width, height = size
        dx = int(self.xshift * width)
        dy = int(self.yshift * height)
        coefficients = (1, 0, dx, 0, 1, dy)
        return img.transform(size, Image.AFFINE, coefficients)
|
|
|
class Crop:
    """Crop a random window and resize it back to the original size.

    The crop scale and an RNG seed are fixed in ``__init__`` so that calling
    the same instance on two images of equal size (e.g. an image and its
    mask) selects the same window.
    """

    def __init__(self, min_crop=0.8, max_crop=0.9):
        """Sample the crop scale and remember a seed for the window position.

        :param min_crop: Lower bound of the crop scale (fraction of each side).
        :param max_crop: Upper bound of the crop scale.
        """
        self.crop_scale = random.uniform(min_crop, max_crop)
        self.seed = time.time()

    def _crop_box(self, width, height):
        """Return ``(top, left, crop_height, crop_width)`` for an image size.

        The position comes from a private ``random.Random`` seeded with
        ``self.seed``.  Seeding the module-level generator here (as the code
        used to do) reset the global random state to a wall-clock value on
        every call, breaking reproducibility of all later ``random`` draws;
        a private generator yields the same offsets without that side effect.
        """
        # Clamp so a degenerate scale can neither produce an empty crop nor an
        # empty randint range.
        crop_width = max(1, min(width, int(self.crop_scale * width)))
        crop_height = max(1, min(height, int(self.crop_scale * height)))

        rng = random.Random(self.seed)
        left = rng.randint(0, width - crop_width)
        top = rng.randint(0, height - crop_height)
        return top, left, crop_height, crop_width

    def __call__(self, img):
        """Return the cropped window of ``img`` resized to ``img``'s size."""
        width, height = img.size
        top, left, crop_height, crop_width = self._crop_box(width, height)
        return F.crop(img, top, left, crop_height, crop_width).resize((width, height))
|
|
|
|
|
class GaussianNoise:
    """Add Gaussian noise whose standard deviation is fixed per instance."""

    def __init__(self, mean=0, std=(10,20)):
        """Store the mean and sample the standard deviation.

        :param mean: Mean of the noise distribution.
        :param std: ``(low, high)`` range the standard deviation is drawn from.
        """
        self.mean = mean
        std_low, std_high = std[0], std[1]
        self.std = random.uniform(std_low, std_high)

    def __call__(self, img):
        """Return a PIL image with noise added and values clipped to [0, 255]."""
        pixels = np.array(img)
        noisy = pixels + np.random.normal(self.mean, self.std, pixels.shape)
        # Clip before the uint8 cast so out-of-range values do not wrap.
        clipped = np.clip(noisy, 0, 255).astype(np.uint8)
        return Image.fromarray(clipped)
|
|
|
|
|
class SaltAndPepperNoise:
    """Set random pixels to white (salt) or black (pepper).

    The two per-pixel probabilities are sampled once in ``__init__``.
    """

    def __init__(self, min_prob=0.01, max_prob=0.05):
        """Sample the salt and pepper probabilities from ``[min_prob, max_prob]``."""
        # Draw order: salt first, then pepper.
        self.salt_prob = random.uniform(min_prob, max_prob)
        self.pepper_prob = random.uniform(min_prob, max_prob)

    def __call__(self, img):
        """Return a PIL image with salt and pepper pixels applied."""
        pixels = np.array(img)
        plane = pixels.shape[:2]

        # One mask entry per pixel position; for multi-channel arrays the
        # boolean index sets every channel of the selected pixels.
        salt = np.random.rand(*plane) < self.salt_prob
        pepper = np.random.rand(*plane) < self.pepper_prob

        # Pepper is written last, so it wins where both masks are set.
        pixels[salt] = 255
        pixels[pepper] = 0
        return Image.fromarray(pixels.astype(np.uint8))
|
|
|
class MotionBlur:
    """Blur an image along the horizontal axis with a line-shaped kernel.

    The kernel size is sampled once in ``__init__``.
    """

    def __init__(self, min_size=3, max_size=21):
        """Draw an integer kernel size uniformly from ``[min_size, max_size]``."""
        self.kernel_size = random.randint(min_size, max_size)

    def __call__(self, img):
        """Return a PIL image filtered with the horizontal averaging kernel."""
        pixels = np.array(img)
        size = self.kernel_size

        # Kernel is all zeros except one row of ones, normalised so the
        # weights sum to 1.
        middle_row = int((size - 1) / 2)
        kernel = np.zeros((size, size))
        kernel[middle_row, :] = np.ones(size)
        kernel = kernel / size

        # ddepth=-1 keeps the output depth equal to the input depth.
        blurred = cv2.filter2D(pixels, -1, kernel)
        return Image.fromarray(blurred.astype(np.uint8))
|
|
|
class HideAndSeekNoise:
    """Black out one square patch at a position that is fixed per instance.

    The patch size and an RNG seed are chosen in ``__init__`` so that calling
    the same instance on two images of equal size hides the same region.
    """

    def __init__(self, min_size=90, max_size=190):
        """Sample the patch side length and remember a seed for its position.

        :param min_size: Smallest patch side, in pixels.
        :param max_size: Largest patch side, in pixels.
        """
        self.patch_size = random.randint(min_size, max_size)
        self.seed = time.time()

    def _occlude(self, img_array):
        """Zero the patch in ``img_array`` in place and return the array.

        Works for 2-D (single channel) and 3-D (H x W x C) arrays.  The patch
        is clamped to the image size so images smaller than ``patch_size``
        are fully occluded instead of raising from an empty randint range.

        The position comes from a private ``random.Random`` seeded with
        ``self.seed``; reseeding the module-level generator here would reset
        the global random state to a wall-clock value on every call.
        """
        height, width = img_array.shape[:2]
        patch_height = min(self.patch_size, height)
        patch_width = min(self.patch_size, width)

        rng = random.Random(self.seed)
        # Draw order (top, then left) is kept so positions match the values
        # the global generator would have produced for the same seed.
        top = rng.randint(0, height - patch_height)
        left = rng.randint(0, width - patch_width)

        # Scalar fill broadcasts over any number of channels.
        img_array[top:top + patch_height, left:left + patch_width] = 0
        return img_array

    def __call__(self, img):
        """Return a PIL image equal to ``img`` with the patch blacked out."""
        return Image.fromarray(self._occlude(np.array(img)))
|
|
|
|
|
|
|
class BaseDataset(torch.utils.data.Dataset):
    """Image / colourised-mask dataset driven by a text file of paths.

    Each item is ``(img, mask, img_path)``: the RGB image, the label mask
    rendered as an RGB colour image through a palette, and the image path.
    When ``data_set == 'train'`` a random number of augmentations is applied.

    ``self.masks`` defaults to the same paths as ``self.imgs`` and
    ``self.learning_map`` defaults to ``None``; presumably subclasses
    override both -- TODO confirm against the concrete datasets.
    """

    # Flat [r, g, b, ...] palette used when the palette name is 'city';
    # entry i is the colour of label index i.  Kept at class level so it is
    # not rebuilt on every sample.
    _CITY_PALETTE = [
        0, 0, 0,
        128, 64, 128,
        244, 35, 232,
        70, 70, 70,
        102, 102, 156,
        190, 153, 153,
        153, 153, 153,
        250, 170, 30,
        220, 220, 0,
        107, 142, 35,
        152, 251, 152,
        0, 130, 180,
        220, 20, 60,
        255, 0, 0,
        0, 0, 142,
        0, 0, 70,
        0, 60, 100,
        0, 80, 100,
        0, 0, 230,
        119, 11, 32,
    ]

    def __init__(self, path_list, transform=None, data_set='val', seed=None,
                 img_size=768, interpolation=Image.BILINEAR, color_pallete='city'):
        """
        :param path_list: Path to file listing image paths, one per line.
        :param transform: Additional torchvision transforms, applied to both
            the image and the mask.
        :param data_set: 'train' enables random augmentation; any other value
            disables it.
        :param seed: Seed for shuffling.  ``None`` keeps the file order; any
            other value (including 0) seeds the global ``random`` module and
            shuffles the path list.
        :param img_size: Side length images and masks are resized to
            (square).  A falsy value disables resizing.
        :param interpolation: Interpolation method for resizing.
        :param color_pallete: Palette name passed to ``get_color_pallete``;
            'city' selects the built-in palette, anything else the palette
            from ``_getvocpallete``.
        """
        self.transform = transform
        self.data_set = data_set
        self.color_pallete = color_pallete

        with open(path_list, "r") as file:
            # Lines keep their trailing newline (stripped on access); blank
            # lines are dropped because they cannot name an image.
            self.imgs = [line for line in file if line.strip()]

        # ``is not None`` rather than truthiness so that seed=0 is honoured.
        if seed is not None:
            random.seed(seed)
            random.shuffle(self.imgs)

        self.masks = list(self.imgs)
        self.learning_map = None

        # Weight of applying 0, 1, ..., 8 augmentations to a training sample.
        self.aug_weights = [0.4, 0.3, 0.3, 0.2, 0.2, 0.05, 0.05, 0.02, 0.02]

        # Always define the attribute so __getitem__ can test it; the caller's
        # interpolation is used instead of a hard-coded mode.
        if img_size:
            self.transform_resize = transforms.Resize(
                (img_size, img_size), interpolation=interpolation
            )
        else:
            self.transform_resize = None

    def convert_label(self, label, inverse=False):
        """Remap label ids through ``self.learning_map``.

        :param label: Array of label ids.
        :param inverse: When true, apply the mapping value -> key instead of
            key -> value.  If several keys share a value, the key iterated
            last wins.
        :return: New array of the same shape and dtype.  Ids without a
            mapping become 0.  When ``learning_map`` is ``None`` an unchanged
            copy of ``label`` is returned.
        """
        if self.learning_map is None:
            return label.copy()

        converted_label = np.zeros_like(label)
        for source_id, target_id in self.learning_map.items():
            if inverse:
                source_id, target_id = target_id, source_id
            converted_label[label == source_id] = target_id
        return converted_label

    def get_color_pallete(self, npimg, dataset='city'):
        """Render an index mask as an RGB image through a colour palette.

        :param npimg: Array of label indices (cast to uint8).
        :param dataset: 'city' for the built-in palette; any other value uses
            ``_getvocpallete(256)``.
        :return: PIL image in mode "RGB".
        """
        out_img = Image.fromarray(npimg.astype('uint8')).convert('P')
        if dataset == 'city':
            out_img.putpalette(self._CITY_PALETTE)
        else:
            out_img.putpalette(_getvocpallete(256))
        return out_img.convert("RGB")

    def _augment(self, img, mask, augmentation_num):
        """Apply ``augmentation_num`` randomly chosen augmentations.

        Fresh augmentation instances are built per sample; each fixes its
        random parameters at construction, so applying the same instance to
        the image and the mask keeps them aligned.
        """
        augmentation_set = [
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomVerticalFlip(p=1),
            Crop(min_crop=0.6, max_crop=0.9),
            Rotate(angle=90),
            Shear(shear=10, scale=(0.8, 1.2)),
            Skew(magnitude=0.2),
            HideAndSeekNoise(min_size=90, max_size=210),
            GaussianNoise(mean=0, std=(5,20)),
            SaltAndPepperNoise(min_prob=0.01, max_prob=0.03),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.2, 1)),
            MotionBlur(min_size=3, max_size=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        ]
        # These only alter pixel values and are applied to the image alone;
        # every other augmentation is applied to both image and mask.
        image_only = (
            transforms.GaussianBlur,
            transforms.ColorJitter,
            GaussianNoise,
            SaltAndPepperNoise,
            MotionBlur,
        )

        random.shuffle(augmentation_set)
        for aug in augmentation_set[:augmentation_num]:
            img = aug(img)
            if not isinstance(aug, image_only):
                mask = aug(mask)
        return img, mask

    def __getitem__(self, index):
        """Load, resize, optionally augment and transform one sample.

        :param index: Position in the path list.
        :return: ``(img, mask, img_path)``.
        """
        img_path = self.imgs[index].rstrip()
        mask_path = self.masks[index].rstrip()

        # Context managers close the underlying files deterministically;
        # convert() / np.array() materialise the pixel data before that.
        with Image.open(img_path) as source:
            img = source.convert('RGB')
        with Image.open(mask_path) as source:
            mask = np.array(source)

        mask = self.convert_label(mask).astype(np.uint8)
        mask = self.get_color_pallete(mask, self.color_pallete)

        if self.transform_resize is not None:
            img = self.transform_resize(img)
            mask = self.transform_resize(mask)

        if self.data_set == 'train':
            # The candidate counts follow the weight list so the two cannot
            # fall out of sync if a subclass changes ``aug_weights``.
            augmentation_num = random.choices(
                range(len(self.aug_weights)), weights=self.aug_weights, k=1
            )[0]
        else:
            augmentation_num = 0

        if augmentation_num > 0:
            img, mask = self._augment(img, mask, augmentation_num)

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        return img, mask, img_path

    def __len__(self):
        """Return the number of listed images."""
        return len(self.imgs)