# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import os
import json
import shutil
import random

import numpy as np
from PIL import Image

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class collate_fn_crfrp:
    """Collate function implementing CRFR-P masking: every patch covering one
    randomly chosen facial region group is masked, then patches from the other
    region groups are masked proportionally until the target mask ratio is met."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # Finer-grained alternative grouping (kept for reference):
        # self.facial_region_group = [
        #     [2],         # right eyebrow
        #     [3],         # left eyebrow
        #     [4],         # right eye
        #     [5],         # left eye
        #     [6],         # nose
        #     [7, 8],      # upper mouth
        #     [8, 9],      # lower mouth
        #     [10, 1, 0],  # facial boundaries
        #     [10],        # hair
        #     [1],         # facial skin
        #     [0],         # background
        # ]
        self.facial_region_group = [
            [2, 3],      # eyebrows
            [4, 5],      # eyes
            [6],         # nose
            [7, 8, 9],   # mouth
            [10, 1, 0],  # face boundaries
            [10],        # hair
            [1],         # facial skin
            [0],         # background
        ]
        # Parsing-map label indices:
        # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']

    def __call__(self, samples):
        image, img_mask, facial_region_mask, random_specific_facial_region \
            = self.CRFR_P_masking(samples, specified_facial_region=None)
        return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask}

        # Use the following instead when applying a different data augmentation to the target view:
        # image, img_mask, facial_region_mask, random_specific_facial_region \
        #     = self.CRFR_P_masking(samples, specified_facial_region=None)
        # image_cl, img_mask_cl, facial_region_mask_cl, random_specific_facial_region_cl \
        #     = self.CRFR_P_masking(samples, specified_facial_region=random_specific_facial_region)
        #
        # return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask,
        #         'image_cl': image_cl, 'img_mask_cl': img_mask_cl,
        #         'specific_facial_region_mask_cl': facial_region_mask_cl}

    def CRFR_P_masking(self, samples, specified_facial_region=None):
        image = torch.stack([sample['image'] for sample in samples])              # [BS, 3, 224, 224]
        parsing_map = torch.stack([sample['parsing_map'] for sample in samples])  # [BS, 1, 224, 224]
        parsing_map = parsing_map.squeeze(1)                                      # [BS, 1, 224, 224] -> [BS, 224, 224]

        # Cover a randomly selected facial region group and get fr_mask
        # (masking all patches that include this region):
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # [BS, H/P, W/P]
        facial_region_mask, random_specific_facial_region \
            = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        img_mask, facial_region_mask \
            = self.variable_proportional_masking(parsing_map, facial_region_mask, random_specific_facial_region)
        # both masks: [BS, num_patches]

        del parsing_map
        return image, img_mask, facial_region_mask, random_specific_facial_region
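
    # Worked example of the mask budget (descriptive note, not used at runtime):
    # with input_size=224 and patch_size=16 there are 14 * 14 = 196 patches per
    # image, so mask_ratio=0.75 targets 0.75 * 196 = 147 masked patches. If the
    # chosen region covers fewer patches, variable_proportional_masking fills the
    # deficit from the other region groups; if it already covers more, some of its
    # patches are unmasked to keep the overall ratio fixed.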

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask,
                                                             # specified_facial_region=None
                                                             ):
        # while True:
        #     random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        #     if random_specific_facial_region != specified_facial_region:
        #         break
        random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        if random_specific_facial_region == [10, 1, 0]:  # facial boundaries: 10-hair, 1-skin, 0-background
            # True for patches containing hair (10) or background (0):
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) | (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            # True for patches containing skin (1):
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            # Patches with skin&hair or skin&bg are defined as facial boundaries:
            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            for facial_region_index in random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), random_specific_facial_region

    def variable_proportional_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        img_mask = facial_region_mask.clone()

        # Proportionally mask patches in the remaining regions:
        other_facial_region_group = [region for region in self.facial_region_group
                                     if region != random_specific_facial_region]

        for i in range(facial_region_mask.size(0)):  # iterate over each map in the batch
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            # mask_change_to = 1 if num_mask_to_change > 0 else 0:
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

            if mask_change_to == 1:
                # Mask patches in the other facial regions according to the corresponding ratio:
                mask_ratio_other_fr = (num_mask_to_change
                                       / (self.num_patches - facial_region_mask[i].sum(dim=-1)))
                masked_patches = facial_region_mask[i].clone()

                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) | (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(),
                                                  kernel_size=self.patch_size)
                        # Patches with skin&hair or skin&bg are defined as facial boundaries:
                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(
                                to_mask_patches,
                                F.max_pool2d((parsing_map[i].unsqueeze(0) == facial_region_index).float(),
                                             kernel_size=self.patch_size))

                    # Ignore patches that are already masked:
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                        :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    img_mask[i, select_indices[change_indices]] = mask_change_to
                    # Prevent overlap with region groups handled later:
                    masked_patches = masked_patches + to_mask_patches.float()

                # Mask/unmask a few remaining patches from other facial regions
                # so img_mask ends up with a fixed size:
                num_mask_to_change = (self.mask_ratio * self.num_patches - img_mask[i].sum(dim=-1)).int()
                # mask_change_to = 1 if num_mask_to_change > 0 else 0:
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
                # Never unmask patches belonging to facial_region_mask:
                select_indices = ((img_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(
                    as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
            else:
                # Extreme case: fr_mask already covers >= num_patches * mask_ratio patches,
                # so unmask some of its patches to keep img_mask at the fixed ratio:
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = img_mask[i]

        return img_mask, facial_region_mask
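

# Minimal usage sketch for the collate function (kept as a comment so importing this
# module stays side-effect free; the path, batch size, and transform are placeholder
# assumptions, and FaceParsingDataset is defined below):
#
#   transform = transforms.Compose([transforms.ToTensor(),
#                                   transforms.Resize((224, 224))])
#   dataset = FaceParsingDataset(root='/path/to/pretrain_data', transform=transform)
#   loader = DataLoader(dataset, batch_size=64, shuffle=True,
#                       collate_fn=collate_fn_crfrp(input_size=224, patch_size=16, mask_ratio=0.75))
#   batch = next(iter(loader))
#   # batch['image']: [64, 3, 224, 224]; batch['img_mask'] and
#   # batch['specific_facial_region_mask']: [64, 196] float {0, 1} patch masks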


class FaceParsingDataset(Dataset):
    """Pretraining dataset: each image in `images/` has a same-named .npy
    facial parsing map in `parsing_maps/`."""

    def __init__(self, root, transform=None):
        self.root_dir = root
        self.transform = transform
        self.image_folder = os.path.join(root, 'images')
        self.parsing_map_folder = os.path.join(root, 'parsing_maps')
        self.image_names = os.listdir(self.image_folder)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_names[idx])
        parsing_map_name = os.path.join(self.parsing_map_folder, self.image_names[idx].replace('.png', '.npy'))
        image = Image.open(img_name).convert("RGB")
        parsing_map_np = np.load(parsing_map_name)

        if self.transform:
            image = self.transform(image)

        # Convert the parsing map to a tensor:
        parsing_map = torch.from_numpy(parsing_map_np)
        del parsing_map_np  # may save memory

        return {'image': image, 'parsing_map': parsing_map}


class TestImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None):
        super(TestImageFolder, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        # Call the parent class method to load the image and label:
        original_tuple = super(TestImageFolder, self).__getitem__(index)
        # Get the video name; '_frame_' separates the video name from the frame index:
        video_name = self.imgs[index][0].split('/')[-1].split('_frame_')[0]
        # Extend the tuple to include the video name:
        extended_tuple = original_tuple + (video_name,)
        return extended_tuple


def get_mean_std(args):
    print('dataset_paths:', args.data_path)
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((args.input_size, args.input_size),
                                                      interpolation=transforms.InterpolationMode.BICUBIC)])
    if len(args.data_path) > 1:
        pretrain_datasets = [FaceParsingDataset(root=path, transform=transform) for path in args.data_path]
        dataset_pretrain = ConcatDataset(pretrain_datasets)
    else:
        pretrain_datasets = args.data_path[0]
        dataset_pretrain = FaceParsingDataset(root=pretrain_datasets, transform=transform)
    print('Compute mean and variance for pretraining data.')
    print('len(dataset_train): ', len(dataset_pretrain))

    loader = DataLoader(
        dataset_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for sample in loader:
        data = sample['image']
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1
    mean = channels_sum / num_batches
    # Var(X) = E[X^2] - (E[X])^2:
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    print(f'train dataset mean: {mean.numpy()}, std: {std.numpy()}')

    del pretrain_datasets, dataset_pretrain, loader
    return mean.numpy(), std.numpy()
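

# Illustrative note (not part of the original pipeline): build_transform() below reads
# dataset statistics from 'pretrain_ds_mean_std.txt' as a single JSON line of the form
# {"mean": [...], "std": [...]}. A hypothetical helper persisting the output of
# get_mean_std() in that format could look like:
#
#   def save_mean_std(args):
#       mean, std = get_mean_std(args)
#       with open(os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt'), 'w') as f:
#           json.dump({'mean': mean.tolist(), 'std': std.tolist()}, f)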


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.eval:
        # Evaluation never loads the training set; always use the test split:
        root = os.path.join(args.data_path, 'test')
        dataset = TestImageFolder(root, transform=transform)
    else:
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)
    return dataset


def build_transform(is_train, args):
    if args.normalize_from_IMN:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD
        # print(f'mean:{mean}, std:{std}')
    else:
        if not os.path.exists(os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt')) and not args.eval:
            shutil.copyfile(os.path.join(os.path.dirname(args.finetune), 'pretrain_ds_mean_std.txt'),
                            os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt'))
        stats_path = (os.path.join(os.path.dirname(args.resume), 'pretrain_ds_mean_std.txt') if args.eval
                      else os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt'))
        with open(stats_path, 'r') as file:
            ds_stat = json.loads(file.readline())
        mean = ds_stat['mean']
        std = ds_stat['std']
        # print(f'mean:{mean}, std:{std}')

    if args.apply_simple_augment:
        if is_train:
            # This should always dispatch to transforms_imagenet_train:
            transform = create_transform(
                input_size=args.input_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation='bicubic',  # timm's create_transform expects an interpolation string
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=mean,
                std=std,
            )
            return transform

        # No augmentation / eval transform:
        t = []
        if args.input_size <= 224:
            crop_pct = 224 / 256
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # e.g. 256 for input_size=224
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain the same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))  # e.g. 224
        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)
    else:
        t = []
        if args.input_size < 224:
            crop_pct = args.input_size / 224
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # size = 224 when input_size < 224
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain the same ratio w.r.t. 224 images
        )
        # t.append(
        #     transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        #     # to maintain the same ratio w.r.t. 224 images
        # )
        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)
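

# Illustrative usage sketch for build_dataset (hypothetical argument values; the real
# ones come from the training script's argparse config, and the auto-augment spec is
# just one valid timm policy string):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       data_path='/path/to/dataset', input_size=224, eval=False,
#       normalize_from_IMN=True, apply_simple_augment=True,
#       color_jitter=None, aa='rand-m9-mstd0.5-inc1',
#       reprob=0.25, remode='pixel', recount=1,
#       output_dir='./out', finetune='', resume='',
#   )
#   train_set = build_dataset(is_train=True, args=args)   # expects <data_path>/train
#   val_set = build_dataset(is_train=False, args=args)    # expects <data_path>/val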