import os
import json
import shutil
import random

import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
class collate_fn_crfrp:
    """CRFR-P masking collate: covers a randomly chosen facial region
    completely, then masks the remaining regions proportionally until the
    target mask ratio is reached."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # Groups of face-parsing labels treated as one facial region. The
        # label semantics noted below assume the common 11-class face-parsing
        # convention (0 background, 1 skin, 2/3 brows, 4/5 eyes, 6 nose,
        # 7/8/9 mouth, 10 hair); adjust if your parsing maps differ.
        self.facial_region_group = [
            [2, 3],      # eyebrows
            [4, 5],      # eyes
            [6],         # nose
            [7, 8, 9],   # mouth
            [10, 1, 0],  # face boundary (patches mixing hair/background with skin)
            [10],        # hair
            [1],         # skin
            [0]          # background
        ]

    def __call__(self, samples):
        image, img_mask, facial_region_mask, _ \
            = self.CRFR_P_masking(samples, specified_facial_region=None)

        return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask}

    def CRFR_P_masking(self, samples, specified_facial_region=None):
        # specified_facial_region is accepted for API symmetry but unused here.
        image = torch.stack([sample['image'] for sample in samples])
        parsing_map = torch.stack([sample['parsing_map'] for sample in samples])
        parsing_map = parsing_map.squeeze(1)  # (B, 1, H, W) -> (B, H, W)

        # Step 1: mask every patch of one randomly chosen facial region.
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)
        facial_region_mask, random_specific_facial_region \
            = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)

        # Step 2: spread the remaining mask budget proportionally over the
        # other regions until self.mask_ratio is met exactly.
        img_mask, facial_region_mask \
            = self.variable_proportional_masking(parsing_map, facial_region_mask, random_specific_facial_region)

        del parsing_map
        return image, img_mask, facial_region_mask, random_specific_facial_region

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        # Skin ([1]) and background ([0]) are excluded from the random draw.
        random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        if random_specific_facial_region == [10, 1, 0]:
            # Face boundary: patches containing both hair/background and skin.
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            # A patch belongs to the region if any of its pixels carries one of
            # the region's parsing labels (max-pooling over each patch).
            for facial_region_index in random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), random_specific_facial_region

    def variable_proportional_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        img_mask = facial_region_mask.clone()

        other_facial_region_group = [region for region in self.facial_region_group if
                                     region != random_specific_facial_region]

        for i in range(facial_region_mask.size(0)):
            # Remaining mask budget for this sample (negative if the chosen
            # region alone already exceeds the target ratio).
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            # 1: mask additional patches; 0: unmask surplus patches.
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

            if mask_change_to == 1:
                # Mask the same proportion of every other region's unmasked patches.
                mask_ratio_other_fr = (
                        num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))

                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(),
                                                  kernel_size=self.patch_size)
                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(to_mask_patches,
                                                            F.max_pool2d((parsing_map[i].unsqueeze(
                                                                0) == facial_region_index).float(),
                                                                         kernel_size=self.patch_size))

                    # Only patches not claimed by an earlier region are candidates.
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                                     :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    img_mask[i, select_indices[change_indices]] = mask_change_to

                    masked_patches = masked_patches + to_mask_patches.float()

                # Rounding can leave the ratio slightly off; correct the
                # residual by flipping patches outside the chosen region.
                num_mask_to_change = (self.mask_ratio * self.num_patches - img_mask[i].sum(dim=-1)).int()
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

                select_indices = ((img_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(
                    as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to

            else:
                # The chosen region alone exceeds the budget: randomly unmask
                # patches inside it until the target ratio is met, and keep the
                # returned region mask consistent with the final image mask.
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = img_mask[i]

        return img_mask, facial_region_mask
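
# Example (sketch): exercising the CRFR-P collate on synthetic samples. The
# dict layout mimics FaceParsingDataset below; the 11-class label range and
# the 224/16 geometry are assumptions carried over from the defaults above.
def _demo_crfrp_collate():
    collate = collate_fn_crfrp(input_size=224, patch_size=16, mask_ratio=0.75)
    samples = [{'image': torch.randn(3, 224, 224),
                'parsing_map': torch.randint(0, 11, (1, 224, 224))} for _ in range(2)]
    batch = collate(samples)
    # img_mask is flat over the 14x14 patch grid; after the residual
    # correction each row sums to mask_ratio * 196 = 147.
    print(batch['image'].shape, batch['img_mask'].shape, batch['img_mask'].sum(dim=-1))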
|
|
|
|
|
class FaceParsingDataset(Dataset):
    """Pairs each face image with its pre-computed parsing map (.npy)."""

    def __init__(self, root, transform=None):
        self.root_dir = root
        self.transform = transform
        self.image_folder = os.path.join(root, 'images')
        self.parsing_map_folder = os.path.join(root, 'parsing_maps')
        # Sort for a deterministic ordering across file systems.
        self.image_names = sorted(os.listdir(self.image_folder))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_names[idx])
        parsing_map_name = os.path.join(self.parsing_map_folder, self.image_names[idx].replace('.png', '.npy'))

        image = Image.open(img_name).convert("RGB")
        parsing_map_np = np.load(parsing_map_name)

        if self.transform:
            image = self.transform(image)

        parsing_map = torch.from_numpy(parsing_map_np)
        del parsing_map_np

        return {'image': image, 'parsing_map': parsing_map}
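
# Usage (sketch): assumes the on-disk layout implied above, i.e. matching file
# stems under <root>/images/*.png and <root>/parsing_maps/*.npy, and a
# transform that resizes images to the collate's input_size:
#   ds = FaceParsingDataset(root='/path/to/pretrain_data',
#                           transform=transforms.ToTensor())
#   loader = DataLoader(ds, batch_size=64, collate_fn=collate_fn_crfrp())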
|
|
|
|
|
class TestImageFolder(datasets.ImageFolder):
    """ImageFolder that also returns the source video of each frame, so
    frame-level predictions can be aggregated per video at test time."""

    def __init__(self, root, transform=None, target_transform=None):
        super(TestImageFolder, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        original_tuple = super(TestImageFolder, self).__getitem__(index)

        # Frames are expected to be named '<video_name>_frame_<k>.<ext>';
        # os.path.basename keeps this portable across path separators.
        video_name = os.path.basename(self.imgs[index][0]).split('_frame_')[0]

        extended_tuple = (original_tuple + (video_name,))

        return extended_tuple
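
# Note (sketch): the third tuple element supports video-level aggregation,
# e.g. averaging per-frame scores grouped by video_name:
#   image, target, video_name = test_dataset[0]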
|
|
|
|
|
def get_mean_std(args):
    """Computes per-channel mean/std of the pretraining data in one pass."""
    print('dataset_paths:', args.data_path)
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((args.input_size, args.input_size),
                                                      interpolation=transforms.InterpolationMode.BICUBIC)])

    if len(args.data_path) > 1:
        pretrain_datasets = [FaceParsingDataset(root=path, transform=transform) for path in args.data_path]
        dataset_pretrain = ConcatDataset(pretrain_datasets)
    else:
        pretrain_datasets = args.data_path[0]
        dataset_pretrain = FaceParsingDataset(root=pretrain_datasets, transform=transform)

    print('Compute mean and variance for pretraining data.')
    print('len(dataset_pretrain): ', len(dataset_pretrain))

    loader = DataLoader(
        dataset_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Accumulate E[x] and E[x^2] per channel, then std = sqrt(E[x^2] - E[x]^2).
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for sample in loader:
        data = sample['image']
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    print(f'train dataset mean: {mean.numpy()} std: {std.numpy()}')
    del pretrain_datasets, dataset_pretrain, loader
    return mean.numpy(), std.numpy()
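
# Sketch (assumption): build_transform() below reads the stats back from a
# one-line JSON file named 'pretrain_ds_mean_std.txt'; a writer matching that
# format could look like this (the helper name is hypothetical):
def save_pretrain_ds_mean_std(output_dir, mean, std):
    stats = {'mean': [float(m) for m in mean], 'std': [float(s) for s in std]}
    with open(os.path.join(output_dir, 'pretrain_ds_mean_std.txt'), 'w') as f:
        f.write(json.dumps(stats) + '\n')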
|
|
|
|
|
def build_dataset(is_train, args):
    transform = build_transform(is_train, args)
    if args.eval:
        # Evaluation always reads the held-out test split.
        root = os.path.join(args.data_path, 'test')
        dataset = TestImageFolder(root, transform=transform)
    else:
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
    print(dataset)

    return dataset
|
|
|
|
|
def build_transform(is_train, args):
    if args.normalize_from_IMN:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD
    else:
        # Reuse the mean/std computed at pretraining time: copy the stats file
        # next to the finetuning outputs once, then read it back as JSON.
        stats_file = os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt')
        if not os.path.exists(stats_file) and not args.eval:
            shutil.copyfile(os.path.join(os.path.dirname(args.finetune), 'pretrain_ds_mean_std.txt'),
                            stats_file)
        stats_path = (os.path.join(os.path.dirname(args.resume), 'pretrain_ds_mean_std.txt')
                      if args.eval else stats_file)
        with open(stats_path, 'r') as file:
            ds_stat = json.loads(file.readline())
        mean = ds_stat['mean']
        std = ds_stat['std']

    if args.apply_simple_augment:
        if is_train:
            # Train: timm's standard augmentation pipeline.
            transform = create_transform(
                input_size=args.input_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation='bicubic',  # timm expects the interpolation as a string
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=mean,
                std=std,
            )
            return transform

        # Eval: resize, center-crop, normalize.
        t = []
        if args.input_size <= 224:
            crop_pct = 224 / 256
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        )
        t.append(transforms.CenterCrop(args.input_size))

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)

    else:
        t = []
        if args.input_size < 224:
            crop_pct = args.input_size / 224
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        )
        # Center-crop to a square input; restored here (assumption), since
        # Resize(int) alone keeps the aspect ratio and can yield non-square images.
        t.append(transforms.CenterCrop(args.input_size))

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)
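
# Example (sketch): wiring the pieces together for evaluation. The `args`
# fields mirror those referenced above; the namespace values are hypothetical.
#   args = argparse.Namespace(eval=True, data_path='/data/faces', input_size=224,
#                             normalize_from_IMN=True, apply_simple_augment=True)
#   dataset_test = build_dataset(is_train=False, args=args)
#   loader = DataLoader(dataset_test, batch_size=32, shuffle=False)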
|
|