import os
import random
import time
import numpy as np
import torch
from PIL import Image
import cv2

from torchvision import transforms, datasets
from torchvision.transforms import functional as F

def _getvocpallete(num_colors):
    # Builds a flat palette of num_colors RGB triplets; every entry here is black,
    # so only the explicit Cityscapes palette below produces colored masks.
    return [0, 0, 0] * num_colors


# ----------------------
# Augmentation Classes
# ----------------------
# Each class samples its random parameters in __init__, so a single instance applies
# the identical transform to an image and its corresponding mask.
class Rotate:
    def __init__(self, angle):
        self.angle = random.randint(-angle, angle)

    def __call__(self, img):
        return F.rotate(img, angle=self.angle)
    
class Shear:
    def __init__(self, shear=10, scale=(1.0, 1.0)):
        self.shear = random.uniform(-shear, shear)
        self.scale = random.uniform(scale[0], scale[1])

    def __call__(self, img):
        return F.affine(img, angle=0, translate=(0, 0), scale=self.scale, shear=[self.shear, self.shear])
    
class Skew:
    def __init__(self, magnitude=0.2):
        self.xshift = random.uniform(-magnitude, magnitude)
        self.yshift = random.uniform(-magnitude, magnitude)

    def __call__(self, img):
        width, height = img.size
        x_shift = int(self.xshift * width)
        y_shift = int(self.yshift * height)
        return img.transform(img.size, Image.AFFINE, (1, 0, x_shift, 0, 1, y_shift))
    
class Crop:
    def __init__(self, min_crop=0.8, max_crop=0.9):
        self.crop_scale = random.uniform(min_crop, max_crop)
        self.seed = time.time()

    def __call__(self, img):
        width, height = img.size        
        crop_width = int(self.crop_scale * width)
        crop_height = int(self.crop_scale * height)
        
        # Re-seed so that successive calls (image, then mask) pick the same crop window.
        random.seed(self.seed)
        left = random.randint(0, width - crop_width)
        top = random.randint(0, height - crop_height)
        return F.crop(img, top, left, crop_height, crop_width).resize((width, height))
    

class GaussianNoise:
    def __init__(self, mean=0, std=(10,20)):
        self.mean = mean
        self.std = random.uniform(std[0], std[1])

    def __call__(self, img):
        img = np.array(img)

        noise = np.random.normal(self.mean, self.std, img.shape)
        img = img + noise
        img = np.clip(img, 0, 255).astype(np.uint8)
        return Image.fromarray(img)
    

class SaltAndPepperNoise:
    def __init__(self, min_prob=0.01, max_prob=0.05):
        self.salt_prob = random.uniform(min_prob, max_prob)
        self.pepper_prob = random.uniform(min_prob, max_prob)

    def __call__(self, img):
        img_array = np.array(img)

        salt_mask = np.random.rand(*img_array.shape[:2]) < self.salt_prob
        pepper_mask = np.random.rand(*img_array.shape[:2]) < self.pepper_prob
        img_array[salt_mask] = 255
        img_array[pepper_mask] = 0
        return Image.fromarray(img_array.astype(np.uint8))
    
class MotionBlur:
    def __init__(self, min_size=3, max_size=21):
        self.kernel_size = random.randint(min_size, max_size)

    def __call__(self, img):
        img_array = np.array(img)

        kernel = np.zeros((self.kernel_size, self.kernel_size))
        kernel[int((self.kernel_size - 1) / 2), :] = np.ones(self.kernel_size)
        kernel = kernel / self.kernel_size
        blurred = cv2.filter2D(img_array, -1, kernel)
        return Image.fromarray(blurred.astype(np.uint8))
    
class HideAndSeekNoise:
    def __init__(self, min_size=90, max_size=190):
        self.patch_size = random.randint(min_size, max_size)
        self.seed = time.time()

    def __call__(self, img):
        img_array = np.array(img)
        height, width, _ = img_array.shape
        
        # Re-seed so the image and its mask receive the same occlusion patch.
        random.seed(self.seed)
        top = random.randint(0, height - self.patch_size)
        left = random.randint(0, width - self.patch_size)
        img_array[top:top + self.patch_size, left:left + self.patch_size] = [0, 0, 0]
        return Image.fromarray(img_array)
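
# A minimal, hypothetical sketch (not part of the original pipeline) illustrating the paired
# image/mask behavior noted above: one instance reuses its sampled parameters across calls.
def _paired_augmentation_demo():
    img = Image.new("RGB", (512, 512), color=(128, 128, 128))   # illustrative dummy image
    mask = Image.new("RGB", (512, 512), color=(0, 0, 0))        # illustrative dummy mask
    rot = Rotate(angle=30)        # the rotation angle is sampled once, here
    return rot(img), rot(mask)    # both calls reuse the same sampled angle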
    


class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, path_list, transform=None, data_set='val', seed=None,
                 img_size=768, interpolation=Image.BILINEAR, color_pallete='city'):
        """
        :param path_list: Path to a text file listing one image path per line.
        :param transform: Additional torchvision transforms applied to both image and mask.
        :param data_set: 'train' enables random augmentations; any other value disables them.
        :param seed: Seed used to shuffle the image list.
        :param img_size: Side length to which images and masks are resized.
        :param interpolation: Interpolation method used for resizing.
        :param color_pallete: Palette name used to colorize masks ('city' or VOC-style).
        """
        self.transform = transform
        self.data_set = data_set
        self.color_pallete = color_pallete

        with open(path_list, "r") as file:
            self.imgs = file.readlines()

        if seed:
            random.seed(seed)
            random.shuffle(self.imgs)

        # Mask paths default to the image paths; learning_map must be populated
        # before convert_label is called.
        self.masks = list(self.imgs)
        self.learning_map = None
        
        self.aug_weights = [0.4, 0.3, 0.3, 0.2, 0.2, 0.05, 0.05, 0.02, 0.02]
        if img_size:
            self.transform_resize = transforms.Resize((img_size, img_size), interpolation=interpolation)

    def convert_label(self, label, inverse=False):
        """Map raw label ids to training ids via self.learning_map (reverse the mapping when inverse=True)."""
        temp = label.copy()
        converted_label = np.zeros_like(label)
        for k, v in self.learning_map.items():
            if inverse:
                converted_label[temp == v] = k
            else:
                converted_label[temp == k] = v
        return converted_label

    def get_color_pallete(self, npimg, dataset='city'):
        out_img = Image.fromarray(npimg.astype('uint8')).convert('P')
        if dataset == 'city':
            cityspallete = [
                0, 0, 0,
                128, 64, 128,
                244, 35, 232,
                70, 70, 70,
                102, 102, 156,
                190, 153, 153,
                153, 153, 153,
                250, 170, 30,
                220, 220, 0,
                107, 142, 35,
                152, 251, 152,
                0, 130, 180,
                220, 20, 60,
                255, 0, 0,
                0, 0, 142,
                0, 0, 70,
                0, 60, 100,
                0, 80, 100,
                0, 0, 230,
                119, 11, 32,
            ]
            out_img.putpalette(cityspallete)
        else:
            vocpallete = _getvocpallete(256)
            out_img.putpalette(vocpallete)
        return out_img.convert("RGB")

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index].rstrip(), self.masks[index].rstrip()

        # Load and resize the image
        img = Image.open(img_path).convert('RGB')
        img = self.transform_resize(img)
        
        # Load, convert, and resize the mask
        mask = Image.open(mask_path)
        mask = np.array(mask)
        mask = self.convert_label(mask)
        mask = mask.astype(np.uint8)
        mask = self.get_color_pallete(mask, self.color_pallete)
        mask = self.transform_resize(mask)
        
        # Augmentation stage
        augmentation_num = random.choices(range(9), weights=self.aug_weights, k=1)[0] if self.data_set == 'train' else 0
        if augmentation_num > 0:
            augmentation_set = [
                transforms.RandomHorizontalFlip(p=1),  # Flip horizontally
                transforms.RandomVerticalFlip(p=1),  # Flip vertically
                Crop(min_crop=0.6, max_crop=0.9),  # Random crop
                Rotate(angle=90),  # Rotate
                Shear(shear=10, scale=(0.8, 1.2)),  # Shear
                Skew(magnitude=0.2),  # Skew
                HideAndSeekNoise(min_size=90, max_size=210),  # Hide-and-seek occlusion
                GaussianNoise(mean=0, std=(5, 20)),  # Gaussian noise (only for image)
                SaltAndPepperNoise(min_prob=0.01, max_prob=0.03),  # Salt-and-pepper noise (only for image)
                transforms.GaussianBlur(kernel_size=3, sigma=(0.2, 1)),  # Gaussian blur (only for image)
                MotionBlur(min_size=3, max_size=15), # Motion blur (only for image)
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Color jitter (only for image)
            ]
            random.shuffle(augmentation_set)
            augmentation_set = augmentation_set[:augmentation_num]
            for aug in augmentation_set:
                img = aug(img)
                # Photometric augmentations are image-only; geometric ones must also transform the mask.
                if not isinstance(aug, (transforms.GaussianBlur, transforms.ColorJitter, GaussianNoise, SaltAndPepperNoise, MotionBlur)):
                    mask = aug(mask)

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        return img, mask, img_path

    def __len__(self):
        return len(self.imgs)
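
# ----------------------
# Usage sketch (hypothetical)
# ----------------------
# A minimal, hedged example of how this dataset might be wired up. The DemoDataset subclass,
# the "train_list.txt" path, and the identity learning_map are illustrative assumptions and
# are not part of the original training pipeline.
if __name__ == "__main__":
    class DemoDataset(BaseDataset):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # convert_label needs a learning_map; an identity map over 19 classes is assumed here.
            self.learning_map = {i: i for i in range(19)}

    dataset = DemoDataset("train_list.txt", transform=transforms.ToTensor(), data_set="train",
                          seed=42, img_size=768, color_pallete="city")
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    for img, mask, path in loader:
        print(img.shape, mask.shape, path[0])
        break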