# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
"""Gradio demo visualizing the facial masking strategies of FSFM-3C.

Flow: crop the largest face with dlib, parse it into facial regions with the
facer LaPa parser, then build a patch-level mask with one of five strategies
(Random, Fasking-I, FRP, CRFR-R, CRFR-P) and render the mask, the masked face
and the masked, patchified parsing map.
"""
# pip uninstall nvidia_cublas_cu11
import functools
import os
import sys

# Presumably so that `models_mae` can be imported from the parent directory.
sys.path.append('..')
# dlib is installed at runtime; this must run before `import dlib` below.
os.system('pip install dlib')

import dlib
import facer
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from torchvision.transforms import transforms

import models_mae

# MAE ViT-B/16. This demo only uses its patchify/unpatchify helpers and
# patch_embed.patch_size; its forward pass is never called here.
model = models_mae.mae_vit_base_patch16()

# Label ids of the facer LaPa parser used below ('farl/lapa/448'):
# ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
#   0 background, 1 face skin, 2 right brow, 3 left brow, 4 right eye,
#   5 left eye, 6 nose, 7 upper lip, 8 inner mouth, 9 lower lip, 10 hair.

# Pseudo-region "face boundaries": every label group equal to this list is
# interpreted as "patches containing skin together with hair or background".
_FACE_BOUNDARY = [10, 1, 0]


class ITEM:
    """Mutable holder for the most recently parsed face."""

    def __init__(self, img, parsing_map):
        self.image = img                # cropped 224x224 PIL face image
        self.parsing_map = parsing_map  # int8 label map, shape [1, 224, 224]


# NOTE(review): process-wide state shared by every callback, so concurrent
# users of the demo overwrite each other's parsing result (gr.State would
# isolate sessions).
face_to_show = ITEM(None, None)

# UI region name -> parser label ids (see the label table above).
check_region = {'Eyebrows': [2, 3], 'Eyes': [4, 5], 'Nose': [6], 'Mouth': [7, 8, 9],
                'Face Boundaries': [10, 1, 0], 'Hair': [10], 'Skin': [1], 'Background': [0]}


def get_boundingbox(face, width, height, minsize=None, scale=1.3):
    """Return a square crop box around a dlib face detection.

    :param face: dlib face rectangle (needs left/top/right/bottom methods)
    :param width: frame width in pixels
    :param height: frame height in pixels
    :param minsize: optional minimum side length of the box
    :param scale: multiplier applied to the longer side of the detection to
        include context around the face (default 1.3, the former hard-coded value)
    :return: (x, y, size) - top-left corner and side length, clipped so the
        box stays inside the frame
    """
    x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize:
        size_bb = max(size_bb, minsize)
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # Top-left corner, clamped to the frame.
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Shrink the box if it would run past the right or bottom edge.
    size_bb = min(width - x1, height - y1, size_bb)
    return x1, y1, size_bb


def extract_face(frame):
    """Crop the largest face found in `frame` (a PIL image).

    :return: PIL image of the square crop (FF++-style, see get_boundingbox),
        or None when dlib detects no face
    """
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    faces = face_detector(image, 1)
    if len(faces) == 0:
        return None

    # Keep only the biggest face (the old code took faces[0], contrary to its comment).
    face = max(faces, key=lambda r: (r.right() - r.left()) * (r.bottom() - r.top()))
    x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
    return Image.fromarray(image[y:y + size, x:x + size])


def _region_patch_mask(parsing_map, labels, patch_size):
    """Mark patches that contain any pixel of the given parser labels.

    :param parsing_map: integer label tensor [C, H, W] (C is the batch/sample dim)
    :param labels: list of label ids; `_FACE_BOUNDARY` selects patches holding
        skin (1) together with hair (10) or background (0)
    :param patch_size: patch side length in pixels
    :return: float tensor [C, H/patch_size, W/patch_size] with values in {0, 1}
    """
    labels = list(labels)
    if labels == _FACE_BOUNDARY:
        hair_or_bg = F.max_pool2d(((parsing_map == 10) | (parsing_map == 0)).float(), kernel_size=patch_size)
        skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=patch_size)
        return (hair_or_bg.bool() & skin.bool()).float()

    region = F.max_pool2d(torch.zeros_like(parsing_map, dtype=torch.float32), kernel_size=patch_size)
    for label in labels:
        region = torch.maximum(region, F.max_pool2d((parsing_map == label).float(), kernel_size=patch_size))
    return region


def _balance_mask_row(mask, i, target):
    """Randomly mask or unmask patches of row `i` to approach `target` masked patches.

    Flips trunc(target - current) patches (masking visible ones when too few
    are masked, unmasking masked ones when too many). Modifies `mask` in place.
    """
    num_to_change = (target - mask[i].sum(dim=-1)).int()
    change_to = torch.clamp(num_to_change, 0, 1).item()
    candidates = (mask[i] == (1 - change_to)).nonzero(as_tuple=False).view(-1)
    picked = torch.randperm(len(candidates))[:torch.abs(num_to_change)]
    mask[i, candidates[picked]] = change_to


def show_one_img_patchify(img, model):
    """Render an HxWx3 uint8 image as a grid of its ViT patches.

    The image is split with `model.patchify`, patches are pasted with a
    3-pixel white gap, and the grid is resized back to 224x224 (the code
    assumes a 224x224 input).
    """
    x = torch.einsum('nhwc->nchw', torch.tensor(img).unsqueeze(dim=0))
    patches = model.patchify(x)[0]        # [num_patches, p*p*3]
    n = int(np.sqrt(patches.shape[0]))    # patches per side
    patch_h, patch_w = model.patch_embed.patch_size[0], model.patch_embed.patch_size[1]
    cell = int(224 / n)
    padding = 3
    side = n * cell + padding * (n - 1)

    canvas = Image.new('RGB', (side, side), 'white')
    to_pil = transforms.ToPILImage()
    for i, patch in enumerate(patches):
        row, col = divmod(i, n)
        tile = torch.einsum('hwc->chw', torch.reshape(patch, (patch_h, patch_w, 3)))
        canvas.paste(to_pil(tile), (col * (cell + padding), row * (cell + padding)))
    return canvas.resize((224, 224), Image.BICUBIC)


def show_one_img_parchify_mask(img, parsing_map, mask, model):
    """Render a patch mask three ways.

    :param img: HxWx3 uint8 face image (numpy); not modified
    :param parsing_map: HxWx3 uint8 colorized parsing map (numpy); not modified
    :param mask: [1, num_patches] tensor: 0 visible, 1 masked, 2 masked patch
        of the highlighted facial region (CRFR strategies)
    :return: [patchified masked parsing map, mask image, masked face image]
    """
    mask = mask.detach()
    p = model.patch_embed.patch_size[0]
    mask_patches = mask.unsqueeze(-1).repeat(1, 1, p ** 2 * 3)  # (N, H*W, p*p*3)
    pixel_mask = torch.einsum('nchw->nhwc', model.unpatchify(mask_patches)).detach().cpu()[0]  # [H, W, 3]

    # Mask image: 1 -> gray (127), 2 -> black (0), 0 -> white (254).
    vis_mask = pixel_mask.clone()
    vis_mask[pixel_mask == 2] = -1
    vis_mask[pixel_mask == 0] = 2
    vis_mask = torch.clip(vis_mask * 127, 0, 255).int()
    fasking_mask = Image.fromarray(vis_mask.numpy().astype(np.uint8))

    mask_np = pixel_mask.numpy()
    # Work on copies so the caller's arrays are not altered.
    im_masked = img.copy()
    im_masked[mask_np == 1] = 127
    im_masked[mask_np == 2] = 0

    parsing_map_masked = parsing_map.copy()
    parsing_map_masked[mask_np == 1] = 127
    parsing_map_masked[mask_np == 2] = 0

    return [show_one_img_patchify(parsing_map_masked, model), fasking_mask, Image.fromarray(im_masked)]


def _masking_inputs(image, parsing_map_vis):
    """Validate UI inputs and convert them for the CollateFn classes.

    :return: (image tensor [1, 3, H, W], parsing map tensor from face_to_show)
    :raises gr.Error: when face parsing has not been run yet
    """
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error("Please upload a facial image and click 'Face Parsing' first.")
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    return img, parsing_map


# Random
class CollateFn_Random:
    """Random masking: mask `mask_ratio` of all patches uniformly at random."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

    def __call__(self, image, parsing_map):
        random_mask = torch.zeros(parsing_map.size(0), self.num_patches, dtype=torch.float32)  # [BS, num_patches]
        random_mask = self.masking(parsing_map, random_mask)
        return {'image': image, 'random_mask': random_mask}

    def masking(self, parsing_map, random_mask):
        """Set int(mask_ratio * num_patches) random patches of every row to 1."""
        num_to_mask = max(int(self.mask_ratio * self.num_patches), 0)
        for i in range(random_mask.size(0)):
            random_mask[i, torch.randperm(self.num_patches)[:num_to_mask]] = 1
        return random_mask


def do_random_masking(image, parsing_map_vis, ratio):
    """UI callback for Random masking; returns the three rendered images."""
    img, parsing_map = _masking_inputs(image, parsing_map_vis)
    mask_method = CollateFn_Random(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['random_mask']
    return tuple(show_one_img_parchify_mask(image, parsing_map_vis, mask, model))


# Fasking
class CollateFn_Fasking:
    """Fasking-I: mask whole facial parts (brows/eyes, nose, mouth, hair) in
    order until the masking ratio is reached, then adjust randomly to the
    exact budget. Skin and background are never prioritized."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 4],     # right eye
            [3, 5],     # left eye
            [6],        # nose
            [7, 8, 9],  # mouth
            [10],       # hair
            [1],        # skin
            [0]         # background
        ]

    def __call__(self, image, parsing_map):
        fasking_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                   dtype=torch.float32)  # [BS, H/P, W/P]
        fasking_mask = self.fasking(parsing_map, fasking_mask)
        return {'image': image, 'fasking_mask': fasking_mask}

    def fasking(self, parsing_map, fasking_mask):
        """Return the mask flattened to [BS, num_patches]."""
        threshold = self.mask_ratio * self.num_patches / self.num_patches  # fraction of patches to reach
        for i in range(parsing_map.size(0)):
            sample_map = parsing_map[i].unsqueeze(0)
            done = False
            for seg_group in self.facial_region_group[:-2]:  # skip skin and background
                for comp_value in seg_group:
                    region = _region_patch_mask(sample_map, [comp_value], self.patch_size)[0]
                    fasking_mask[i] = torch.maximum(fasking_mask[i], region)
                    if fasking_mask[i].mean() >= threshold:
                        done = True
                        break
                if done:
                    break

        fasking_mask = fasking_mask.view(parsing_map.size(0), -1)
        for i in range(fasking_mask.size(0)):
            # Normalize to the target count so masks can be batched.
            _balance_mask_row(fasking_mask, i, self.mask_ratio * self.num_patches)
        return fasking_mask


def do_fasking_masking(image, parsing_map_vis, ratio):
    """UI callback for Fasking-I; returns the three rendered images."""
    img, parsing_map = _masking_inputs(image, parsing_map_vis)
    mask_method = CollateFn_Fasking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['fasking_mask']
    return tuple(show_one_img_parchify_mask(image, parsing_map_vis, mask, model))


# FRP
class CollateFn_FR_P_Masking:
    """FRP: mask the same fraction (`mask_ratio`) of patches inside each
    facial region, then adjust randomly to the exact budget."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],     # eyebrows
            [4, 5],     # eyes
            [6],        # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],       # hair
            [1],        # facial skin
            [0]         # background
        ]

    def __call__(self, image, parsing_map):
        P_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                             dtype=torch.float32)  # [BS, H/P, W/P]
        P_mask = self.random_variable_facial_semantics_masking(parsing_map, P_mask)
        return {'image': image, 'P_mask': P_mask}

    def random_variable_facial_semantics_masking(self, parsing_map, P_mask):
        """Return the mask flattened to [BS, num_patches]."""
        P_mask = P_mask.view(P_mask.size(0), -1)
        for i in range(parsing_map.size(0)):
            sample_map = parsing_map[i].unsqueeze(0)
            # Skin and background are not processed region-wise.
            for seg_group in self.facial_region_group[:-2]:
                region = _region_patch_mask(sample_map, seg_group, self.patch_size).view(-1)
                # Region patches that earlier regions have not masked yet.
                unmasked_in_region = (region - P_mask[i]) > 0
                region_size = region.sum(dim=-1)
                already_masked = region_size - unmasked_in_region.sum(dim=-1)
                mask_num = (region_size * self.mask_ratio - already_masked).int()
                if mask_num > 0:
                    candidates = unmasked_in_region.nonzero(as_tuple=False).view(-1)
                    picked = torch.randperm(len(candidates))[:mask_num]
                    P_mask[i, candidates[picked]] = 1
            _balance_mask_row(P_mask, i, self.mask_ratio * self.num_patches)
        return P_mask


def do_FRP_masking(image, parsing_map_vis, ratio):
    """UI callback for FRP; returns the three rendered images."""
    img, parsing_map = _masking_inputs(image, parsing_map_vis)
    mask_method = CollateFn_FR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['P_mask']
    return tuple(show_one_img_parchify_mask(image, parsing_map_vis, mask, model))


# CRFR_R
class CollateFn_CRFR_R_Masking:
    """CRFR-R: cover every patch of the chosen facial region, then mask random
    other patches (or unmask part of the region) to hit the budget."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],     # eyebrows
            [4, 5],     # eyes
            [6],        # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],       # hair
            [1],        # facial skin
            [0]         # background
        ]
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # [BS, H/P, W/P]
        facial_region_mask, random_specific_facial_region = \
            self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        CRFR_R_mask, facial_region_mask = self.random_variable_masking(facial_region_mask)
        return {'image': image, 'CRFR_R_mask': CRFR_R_mask, 'fr_mask': facial_region_mask}

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        :param parsing_map: [BS, img_size, img_size] label tensor
        :param facial_region_mask: [BS, num_patches ** .5, num_patches ** .5]
        :return: (facial_region_mask flattened to [BS, num_patches], region label list)
        """
        region = _region_patch_mask(parsing_map, self.random_specific_facial_region, self.patch_size)
        if self.random_specific_facial_region == _FACE_BOUNDARY:
            facial_region_mask = region
        else:
            facial_region_mask = torch.maximum(facial_region_mask, region)
        return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region

    def random_variable_masking(self, facial_region_mask):
        """Return (CRFR_R_mask, facial_region_mask), both [BS, num_patches]."""
        CRFR_R_mask = facial_region_mask.clone()
        target = self.mask_ratio * self.num_patches
        for i in range(facial_region_mask.size(0)):
            num_mask_to_change = (target - facial_region_mask[i].sum(dim=-1)).int()
            mask_change_to = 1 if num_mask_to_change >= 0 else 0
            candidates = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            picked = torch.randperm(len(candidates))[:torch.abs(num_mask_to_change)]
            CRFR_R_mask[i, candidates[picked]] = mask_change_to
            if num_mask_to_change < 0:
                # The region alone exceeds the budget: only its kept part counts as the region.
                facial_region_mask[i] = CRFR_R_mask[i]
        return CRFR_R_mask, facial_region_mask


def do_CRFR_R_masking(image, parsing_map_vis, ratio, region):
    """UI callback for CRFR-R; returns the three rendered images."""
    img, parsing_map = _masking_inputs(image, parsing_map_vis)
    mask_method = CollateFn_CRFR_R_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    masks = mask_method(img, parsing_map)
    # Adding the region mask gives region patches the value 2 (drawn black).
    return tuple(show_one_img_parchify_mask(image, parsing_map_vis, masks['CRFR_R_mask'] + masks['fr_mask'], model))


# CRFR_P
class CollateFn_CRFR_P_Masking:
    """CRFR-P: cover every patch of the chosen facial region, then mask the
    same fraction of patches in every other region to hit the budget."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],     # eyebrows
            [4, 5],     # eyes
            [6],        # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],       # hair
            [1],        # facial skin
            [0]         # background
        ]
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # [BS, H/P, W/P]
        facial_region_mask, random_specific_facial_region = \
            self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        CRFR_P_mask, facial_region_mask = self.random_variable_masking(
            parsing_map, facial_region_mask, random_specific_facial_region)
        return {'image': image, 'CRFR_P_mask': CRFR_P_mask, 'fr_mask': facial_region_mask}

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        :param parsing_map: [BS, img_size, img_size] label tensor
        :param facial_region_mask: [BS, num_patches ** .5, num_patches ** .5]
        :return: (facial_region_mask flattened to [BS, num_patches], region label list)
        """
        region = _region_patch_mask(parsing_map, self.random_specific_facial_region, self.patch_size)
        if self.random_specific_facial_region == _FACE_BOUNDARY:
            facial_region_mask = region
        else:
            facial_region_mask = torch.maximum(facial_region_mask, region)
        return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region

    def random_variable_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        """Return (CRFR_P_mask, facial_region_mask), both [BS, num_patches]."""
        CRFR_P_mask = facial_region_mask.clone()
        other_facial_region_group = [region for region in self.facial_region_group
                                     if region != random_specific_facial_region]
        target = self.mask_ratio * self.num_patches
        for i in range(facial_region_mask.size(0)):
            region_size = facial_region_mask[i].sum(dim=-1)
            num_mask_to_change = (target - region_size).int()
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
            if mask_change_to == 1:
                # Fraction of still-visible patches that must be masked; each
                # other region is masked with this same fraction.
                mask_ratio_other_fr = num_mask_to_change / (self.num_patches - region_size)
                masked_patches = facial_region_mask[i].clone()
                sample_map = parsing_map[i].unsqueeze(0)
                for other_fr in other_facial_region_group:
                    region = _region_patch_mask(sample_map, other_fr, self.patch_size).view(-1)
                    # Skip patches already claimed by the target or an earlier region.
                    to_mask_patches = (region - masked_patches) > 0
                    candidates = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    num_pick = torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()
                    picked = torch.randperm(len(candidates))[:num_pick]
                    CRFR_P_mask[i, candidates[picked]] = mask_change_to
                    masked_patches = masked_patches + to_mask_patches.float()

                # Top up / trim patches outside the target region to hit the budget.
                num_mask_to_change = (target - CRFR_P_mask[i].sum(dim=-1)).int()
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
                # Target-region patches sum to 2 here, so they are never selected.
                candidates = ((CRFR_P_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)) \
                    .nonzero(as_tuple=False).view(-1)
                picked = torch.randperm(len(candidates))[:torch.abs(num_mask_to_change)]
                CRFR_P_mask[i, candidates[picked]] = mask_change_to
            else:
                # The target region alone fills or exceeds the budget: unmask part of it.
                candidates = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                picked = torch.randperm(len(candidates))[:torch.abs(num_mask_to_change)]
                CRFR_P_mask[i, candidates[picked]] = mask_change_to
                facial_region_mask[i] = CRFR_P_mask[i]
        return CRFR_P_mask, facial_region_mask


def do_CRFR_P_masking(image, parsing_map_vis, ratio, region):
    """UI callback for CRFR-P; returns the three rendered images."""
    img, parsing_map = _masking_inputs(image, parsing_map_vis)
    mask_method = CollateFn_CRFR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    masks = mask_method(img, parsing_map)
    # Adding the region mask gives region patches the value 2 (drawn black).
    return tuple(show_one_img_parchify_mask(image, parsing_map_vis, masks['CRFR_P_mask'] + masks['fr_mask'], model))


# RGB color per parser label id; index 0 (background) is white.
_PART_COLORS = np.array([[255, 255, 255], [0, 0, 255], [255, 128, 0], [255, 255, 0],
                         [0, 255, 0], [0, 255, 128], [0, 255, 255], [255, 0, 255],
                         [255, 0, 128], [128, 0, 255], [255, 0, 0]], dtype=np.uint8)


def vis_parsing_maps(parsing_anno):
    """Colorize an [H, W] label map into an [H, W, 3] uint8 RGB image."""
    return _PART_COLORS[parsing_anno.astype(np.uint8)]


@functools.lru_cache(maxsize=1)
def _load_face_parsing_models():
    """Build the facer detector and parser once; every call reuses them."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    face_detector = facer.face_detector('retinaface/mobilenet', device=device, threshold=0.3)  # 0.3 for FF++
    face_parser = facer.face_parser('farl/lapa/448', device=device)  # LaPa parser, 11 labels
    return device, face_detector, face_parser


def do_face_parsing(img):
    """UI callback: crop, parse and visualize the face in `img` (PIL image).

    Stores the crop and its [1, 224, 224] int8 label map in `face_to_show` for
    the masking callbacks.

    :return: (cropped face, colorized parsing map, patchified parsing map)
    :raises gr.Error: when no image is given or no face is found
    """
    if img is None:
        raise gr.Error("Please upload a facial image first.")
    face = extract_face(img)
    if face is None:
        raise gr.Error("No face was detected in the image.")

    device, face_detector, face_parser = _load_face_parsing_models()
    with torch.inference_mode():
        face = face.resize((224, 224), Image.BICUBIC)
        image = torch.from_numpy(np.array(face.convert('RGB')))
        image = image.unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
        try:
            faces = face_detector(image)
            faces = face_parser(image, faces)
            seg_logits = faces['seg']['logits']
        except KeyError:
            # facer produced no segmentation: leave all three outputs unchanged.
            return gr.update(), gr.update(), gr.update()
        if seg_logits.size(0) == 0:
            raise gr.Error("The face parser could not find a face in the crop.")
        # Keep a single face so the map is [1, 224, 224] as the masking code expects.
        seg_probs = seg_logits[:1].softmax(dim=1)  # [1, 11, 224, 224]
        parsing_map = seg_probs.argmax(1).cpu().numpy().astype(np.int8)  # [1, 224, 224]

    parsing_map_vis = vis_parsing_maps(parsing_map.squeeze(0))
    face_to_show.image = face
    face_to_show.parsing_map = parsing_map
    return face, parsing_map_vis, show_one_img_patchify(parsing_map_vis, model)


# WebUI
with gr.Blocks() as demo:
    # NOTE(review): the HTML markup of the title, the separators and the footer
    # was lost in this copy of the file; the tags below are minimal
    # reconstructions - restore the original markup from upstream if available.
    gr.HTML("<h1>🧑‍ Visualization Demo of Facial Masking Strategies</h1>")
    gr.Markdown(
        "This is a demo of visualizing different facial masking strategies that are introduced in [FSFM-3C](https://fsfm-3c.github.io/) for facial masked image modeling (MIM)."
    )
    gr.Markdown(
        "- Random Masking: Random masking all patches."
    )
    gr.Markdown(
        "- Fasking-I: Use a face parser to divide facial regions and prioritize masking non-skin and non-background regions."
    )
    gr.Markdown(
        "- FRP: Facial Region Proportional masking, which masks an equal portion of patches in each facial region to the overall masking ratio."
    )
    gr.Markdown(
        "- CRFR-R: (1) Covering a Random Facial Region followed by (2) Random masking other patches."
    )
    gr.Markdown(
        "- CRFR-P _(suggested in FSFM-3C)_: (1) Covering a Random Facial Region followed by (2) Proportional masking of other regions."
    )

    with gr.Column():
        image = gr.Image(label="Upload/Capture/Paste a facial image", type="pil")
        image_submit_btn = gr.Button("🖱️ Face Parsing")
        with gr.Row():
            ori_image = gr.Image(interactive=False, label="Detected Face")
            parsing_map_vis = gr.Image(interactive=False, label="Face Parsing")
            patch_parsing_map = gr.Image(interactive=False, label="Patchify")
    gr.HTML('<hr>')

    with gr.Column():
        # Random
        random_submit_btn = gr.Button("🖱️ Random Masking")
        ratio_random = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5,
                                 label="Masking Ratio for Random Masking")
        with gr.Row():
            random_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
            random_mask = gr.Image(interactive=False, label="Mask")
            random_mask_on_image = gr.Image(interactive=False, label="Masked Face")
    gr.HTML('<hr>')

    with gr.Column():
        # Fasking-I
        fasking_submit_btn = gr.Button("🖱️ Fasking-I")
        ratio_fasking = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5,
                                  label="Masking Ratio for Fasking")
        with gr.Row():
            fasking_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
            fasking_mask = gr.Image(interactive=False, label="Mask")
            fasking_mask_on_image = gr.Image(interactive=False, label="Masked Face")
    gr.HTML('<hr>')

    with gr.Column():
        # FRP
        FRP_submit_btn = gr.Button("🖱️ FRP")
        ratio_FRP = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5,
                              label="Masking Ratio for FRP")
        with gr.Row():
            FRP_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
            FRP_mask = gr.Image(interactive=False, label="Mask")
            FRP_mask_on_image = gr.Image(interactive=False, label="Masked Face")
    gr.HTML('<hr>')

    with gr.Column():
        # CRFR-R
        CRFR_R_submit_btn = gr.Button("🖱️ CRFR-R")
        ratio_CRFR_R = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5,
                                 label="Masking Ratio for CRFR-R")
        mask_region_CRFR_R = gr.Radio(choices=list(check_region), value='Eyes',
                                      label="Facial Region (for CRFR, highlighted by black)")
        with gr.Row():
            CRFR_R_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
            CRFR_R_mask = gr.Image(interactive=False, label="Mask")
            CRFR_R_mask_on_image = gr.Image(interactive=False, label="Masked Face")
    gr.HTML('<hr>')

    with gr.Column():
        # CRFR-P
        CRFR_P_submit_btn = gr.Button("🖱️ CRFR-P (suggested in FSFM-3C)")
        ratio_CRFR_P = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5,
                                 label="Masking Ratio for CRFR-P")
        mask_region_CRFR_P = gr.Radio(choices=list(check_region), value='Eyes',
                                      label="Facial Region (for CRFR, highlighted by black)")
        with gr.Row():
            CRFR_P_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
            CRFR_P_mask = gr.Image(interactive=False, label="Mask")
            CRFR_P_mask_on_image = gr.Image(interactive=False, label="Masked Face")
    # NOTE(review): footer placeholder; its original markup was lost in this copy.
    gr.HTML('<div></div>')

    image_submit_btn.click(
        fn=do_face_parsing,
        inputs=image,
        outputs=[ori_image, parsing_map_vis, patch_parsing_map]
    )

    # Each strategy re-renders on its button and whenever its controls change
    # (same registration order as before).
    for event in (random_submit_btn.click, ratio_random.change):
        event(
            fn=do_random_masking,
            inputs=[ori_image, parsing_map_vis, ratio_random],
            outputs=[random_patch_on_parsing, random_mask, random_mask_on_image],
        )
    for event in (fasking_submit_btn.click, ratio_fasking.change):
        event(
            fn=do_fasking_masking,
            inputs=[ori_image, parsing_map_vis, ratio_fasking],
            outputs=[fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image],
        )
    for event in (FRP_submit_btn.click, ratio_FRP.change):
        event(
            fn=do_FRP_masking,
            inputs=[ori_image, parsing_map_vis, ratio_FRP],
            outputs=[FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image],
        )
    for event in (CRFR_R_submit_btn.click, ratio_CRFR_R.change, mask_region_CRFR_R.change):
        event(
            fn=do_CRFR_R_masking,
            inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
            outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
        )
    for event in (CRFR_P_submit_btn.click, ratio_CRFR_P.change, mask_region_CRFR_P.change):
        event(
            fn=do_CRFR_P_masking,
            inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
            outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
        )

if __name__ == "__main__":
    gr.close_all()
    demo.queue()
    demo.launch()