# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
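"""Data pipeline for CRFR-P masked pretraining and deepfake-detection finetuning.

Provides:
  * collate_fn_crfrp   -- CRFR-P masking collator: fully masks one randomly chosen
                          facial region, then proportionally masks the remaining
                          regions up to a fixed overall mask ratio.
  * FaceParsingDataset -- pairs aligned face images with per-pixel parsing maps.
  * TestImageFolder    -- ImageFolder that additionally returns the source video name.
  * get_mean_std, build_dataset, build_transform -- dataset statistics and
    train/eval transform construction.
"""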

import json
import os
import random
import shutil

import numpy as np
import torch
from PIL import Image
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.nn import functional as F
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision import datasets, transforms


class collate_fn_crfrp:
    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # --------------------------------------------------------------------------
        # self.facial_region_group = [
        #     [2],  # right eyebrow
        #     [3],  # left eyebrow
        #     [4],  # right eye
        #     [5],  # left eye
        #     [6],  # nose
        #     [7, 8],  # upper mouth
        #     [8, 9],  # lower mouth
        #     [10, 1, 0],  # facial boundaries
        #     [10],  # hair
        #     [1],  # facial skin
        #     [0]  # background
        # ]
        self.facial_region_group = [
            [2, 3],      # eyebrows
            [4, 5],      # eyes
            [6],         # nose
            [7, 8, 9],   # mouth
            [10, 1, 0],  # face boundaries
            [10],        # hair
            [1],         # facial skin
            [0],         # background
        ]
        # parsing-map label order:
        # 0-background, 1-facial skin, 2-right brow, 3-left brow, 4-right eye,
        # 5-left eye, 6-nose, 7-upper lip, 8-inner mouth, 9-lower lip, 10-hair

    def __call__(self, samples):
        image, img_mask, facial_region_mask, random_specific_facial_region \
            = self.CRFR_P_masking(samples, specified_facial_region=None)

        return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask}

        # # Use the following block instead if the target view takes a different augmentation:
        # image, img_mask, facial_region_mask, random_specific_facial_region \
        #     = self.CRFR_P_masking(samples, specified_facial_region=None)
        # image_cl, img_mask_cl, facial_region_mask_cl, random_specific_facial_region_cl \
        #     = self.CRFR_P_masking(samples, specified_facial_region=random_specific_facial_region)
        #
        # return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask,
        #         'image_cl': image_cl, 'img_mask_cl': img_mask_cl, 'specific_facial_region_mask_cl': facial_region_mask_cl}

    def CRFR_P_masking(self, samples, specified_facial_region=None):
        image = torch.stack([sample['image'] for sample in samples])  # torch.Size([bs, 3, 224, 224])
        parsing_map = torch.stack([sample['parsing_map'] for sample in samples])  # torch.Size([bs, 1, 224, 224])
        parsing_map = parsing_map.squeeze(1)  # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])

        # cover a randomly selected facial region group: fr_mask masks every patch
        # that includes any pixel of this region
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # torch.Size([BS, H/P, W/P])
        facial_region_mask, random_specific_facial_region \
            = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        # torch.Size([BS, num_patches]), list

        img_mask, facial_region_mask \
            = self.variable_proportional_masking(parsing_map, facial_region_mask, random_specific_facial_region)
        # torch.Size([BS, num_patches]), torch.Size([BS, num_patches])

        del parsing_map
        return image, img_mask, facial_region_mask, random_specific_facial_region

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask,
                                                             # specified_facial_region=None
                                                             ):
        # while True:
        #     random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        #     if random_specific_facial_region != specified_facial_region:
        #         break
        random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        if random_specific_facial_region == [10, 1, 0]:  # facial boundaries, 10-hair 1-skin 0-background
            # True for hair(10) or bg(0) patches:
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) | (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            # True for skin(1) patches:
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            # skin&hair or skin&bg is defined as facial boundaries:
            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            for facial_region_index in random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), random_specific_facial_region
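
    # Note on the pooling trick above: F.max_pool2d with kernel_size == patch_size
    # computes a patch-wise logical "any" over a boolean region map. A minimal
    # sketch (illustrative shapes for input_size=224, patch_size=16):
    #   region = (parsing_map == 6).float()         # (BS, 224, 224), nose pixels
    #   hit = F.max_pool2d(region, kernel_size=16)  # (BS, 14, 14)
    # hit[b, i, j] == 1 iff any pixel of patch (i, j) belongs to the region.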

    def variable_proportional_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        img_mask = facial_region_mask.clone()

        # proportionally mask patches in the other facial regions
        other_facial_region_group = [region for region in self.facial_region_group if
                                     region != random_specific_facial_region]
        for i in range(facial_region_mask.size(0)):  # iterate each map in BS
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            # mask_change_to = 1 if num_mask_to_change >= 0 else 0
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

            if mask_change_to == 1:
                # mask patches in each of the other facial regions at a shared ratio
                mask_ratio_other_fr = (
                        num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))
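                # Worked example (illustrative numbers): with 196 patches and
                # mask_ratio=0.75 the target is 147 masked patches; if the chosen
                # region already covers 27, each other region's patches (169 in
                # total, minus overlaps) are masked with ratio (147 - 27) / 169 ≈ 0.71.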

                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) | (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(),
                                                  kernel_size=self.patch_size)
                        # skin&hair or skin&bg defined as facial boundaries:
                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(to_mask_patches,
                                                            F.max_pool2d((parsing_map[i].unsqueeze(
                                                                0) == facial_region_index).float(),
                                                                         kernel_size=self.patch_size))

                    # ignore already masked patches:
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                                     :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    img_mask[i, select_indices[change_indices]] = mask_change_to
                    # prevent overlap
                    masked_patches = masked_patches + to_mask_patches.float()

                # mask/unmask patches from the other facial regions so img_mask hits the exact target count
                num_mask_to_change = (self.mask_ratio * self.num_patches - img_mask[i].sum(dim=-1)).int()
                # mask_change_to = 1 if num_mask_to_change >= 0 else 0
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
                # prevent unmasking facial_region_mask
                select_indices = ((img_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(
                    as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to

            else:
                # Edge case: if fr_mask already covers >= num_patches * mask_ratio
                # patches, unmask random patches from it so img_mask keeps the fixed ratio
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = img_mask[i]

        return img_mask, facial_region_mask
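
# Usage sketch (a minimal example, assuming a FaceParsingDataset as defined
# below; batch size is illustrative):
#   collator = collate_fn_crfrp(input_size=224, patch_size=16, mask_ratio=0.75)
#   loader = DataLoader(dataset, batch_size=64, collate_fn=collator)
#   batch = next(iter(loader))
#   batch['image'].shape                        # torch.Size([64, 3, 224, 224])
#   batch['img_mask'].shape                     # torch.Size([64, 196])
#   batch['specific_facial_region_mask'].shape  # torch.Size([64, 196])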


class FaceParsingDataset(Dataset):
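    """Pairs each aligned face image with its per-pixel parsing map.

    Expected on-disk layout (file names are illustrative; each .npy parsing map
    shares its image's stem):
        root/
            images/xxxx.png
            parsing_maps/xxxx.npy
    """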
    def __init__(self, root, transform=None):
        self.root_dir = root
        self.transform = transform
        self.image_folder = os.path.join(root, 'images')
        self.parsing_map_folder = os.path.join(root, 'parsing_maps')
        self.image_names = os.listdir(self.image_folder)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_names[idx])
        parsing_map_name = os.path.join(self.parsing_map_folder, self.image_names[idx].replace('.png', '.npy'))

        image = Image.open(img_name).convert("RGB")
        parsing_map_np = np.load(parsing_map_name)

        if self.transform:
            image = self.transform(image)

        # convert the parsing map to a tensor and drop the numpy copy early
        parsing_map = torch.from_numpy(parsing_map_np)
        del parsing_map_np

        return {'image': image, 'parsing_map': parsing_map}


class TestImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None):
        super(TestImageFolder, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        # Call the parent class method to load image and label
        original_tuple = super(TestImageFolder, self).__getitem__(index)

        # Recover the source video name from the frame filename,
        # e.g. ".../video123_frame_0001.png" -> "video123"
        video_name = os.path.basename(self.imgs[index][0]).split('_frame_')[0]

        # Extend the tuple to include video name
        extended_tuple = (original_tuple + (video_name,))

        return extended_tuple


def get_mean_std(args):
    print('dataset_paths:', args.data_path)
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((args.input_size, args.input_size),
                                                      interpolation=transforms.InterpolationMode.BICUBIC)])

    if len(args.data_path) > 1:
        pretrain_datasets = [FaceParsingDataset(root=path, transform=transform) for path in args.data_path]
        dataset_pretrain = ConcatDataset(pretrain_datasets)
    else:
        dataset_pretrain = FaceParsingDataset(root=args.data_path[0], transform=transform)

    print('Computing mean and std of the pretraining data.')
    print('len(dataset_pretrain):', len(dataset_pretrain))

    loader = DataLoader(
        dataset_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for sample in loader:
        data = sample['image']
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
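    # Uses the identity Var[X] = E[X^2] - (E[X])^2; averaging per-batch means is
    # exact here because drop_last=True keeps every batch the same size.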

    print(f'train dataset mean: {mean.numpy()} std: {std.numpy()}')
    del dataset_pretrain, loader
    return mean.numpy(), std.numpy()


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)
    if args.eval:
        # evaluation only: load the test split, never the training set
        root = os.path.join(args.data_path, 'test')
        dataset = TestImageFolder(root, transform=transform)
    else:
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        print(dataset)

    return dataset


def build_transform(is_train, args):
    if args.normalize_from_IMN:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD
        # print(f'mean:{mean}, std:{std}')
    else:
        # locate the dataset statistics saved at pretraining time; at evaluation
        # they sit next to the resumed checkpoint, otherwise in the output dir
        stats_file = os.path.join(os.path.dirname(args.resume) if args.eval else args.output_dir,
                                  'pretrain_ds_mean_std.txt')
        if not args.eval and not os.path.exists(stats_file):
            shutil.copyfile(os.path.join(os.path.dirname(args.finetune), 'pretrain_ds_mean_std.txt'),
                            stats_file)
        with open(stats_file, 'r') as file:
            ds_stat = json.loads(file.readline())
            mean = ds_stat['mean']
            std = ds_stat['std']
            # print(f'mean:{mean}, std:{std}')
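            # The stats file holds one JSON line written at pretraining time,
            # e.g. (illustrative values):
            #   {"mean": [0.55, 0.42, 0.38], "std": [0.26, 0.24, 0.23]}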

    if args.apply_simple_augment:
        if is_train:
            # this should always dispatch to transforms_imagenet_train
            transform = create_transform(
                input_size=args.input_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation='bicubic',  # timm's create_transform expects the interpolation as a string
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=mean,
                std=std,
            )
            return transform

        # no augment / eval transform
        t = []
        if args.input_size <= 224:
            crop_pct = 224 / 256
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # e.g. 256 when input_size == 224
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))  # 224

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)

    else:
        t = []
        if args.input_size < 224:
            crop_pct = args.input_size / 224
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # 224 for smaller inputs, else input_size
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain same ratio w.r.t. 224 images
        )
        # t.append(
        #     transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        #     # to maintain same ratio w.r.t. 224 images
        # )

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)
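

# Usage sketch for the builders above (a minimal example; attribute names are
# the ones this module reads, values are illustrative):
#   import argparse
#   args = argparse.Namespace(
#       data_path='/path/to/dataset', input_size=224, eval=False,
#       normalize_from_IMN=True, apply_simple_augment=True,
#       color_jitter=None, aa='rand-m9-mstd0.5-inc1',
#       reprob=0.25, remode='pixel', recount=1,
#   )
#   dataset_train = build_dataset(is_train=True, args=args)
#   dataset_val = build_dataset(is_train=False, args=args)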