FSFM-3C
commited on
Commit
·
d4e7f2f
1
Parent(s):
52bc0c5
Add V1.0
Browse files- app.py +797 -0
- facer/__init__.py +55 -0
- facer/draw.py +186 -0
- facer/face_alignment/__init__.py +2 -0
- facer/face_alignment/base.py +24 -0
- facer/face_alignment/farl.py +180 -0
- facer/face_alignment/network/__init__.py +42 -0
- facer/face_alignment/network/common.py +91 -0
- facer/face_alignment/network/geometry.py +45 -0
- facer/face_alignment/network/mmseg.py +29 -0
- facer/face_alignment/network/transformers.py +173 -0
- facer/face_attribute/__init__.py +2 -0
- facer/face_attribute/base.py +24 -0
- facer/face_attribute/farl.py +156 -0
- facer/face_detection/__init__.py +2 -0
- facer/face_detection/base.py +19 -0
- facer/face_detection/retinaface.py +672 -0
- facer/face_parsing/__init__.py +2 -0
- facer/face_parsing/base.py +27 -0
- facer/face_parsing/farl.py +92 -0
- facer/farl/__init__.py +5 -0
- facer/farl/classification.py +149 -0
- facer/farl/model.py +419 -0
- facer/io.py +28 -0
- facer/show.py +37 -0
- facer/transform.py +386 -0
- facer/util.py +169 -0
- facer/version.py +1 -0
- models_mae.py +251 -0
- requirements.txt +80 -0
- util/crop.py +43 -0
- util/datasets.py +349 -0
- util/lars.py +44 -0
- util/loss_contrastive.py +360 -0
- util/lr_decay.py +72 -0
- util/lr_sched.py +44 -0
- util/metrics.py +88 -0
- util/misc.py +390 -0
- util/pos_embed.py +118 -0
app.py
ADDED
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Author: Gaojian Wang@ZJUICSR
|
3 |
+
# --------------------------------------------------------
|
4 |
+
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
|
5 |
+
# You can find the license in the LICENSE file in the root directory of this source tree.
|
6 |
+
# --------------------------------------------------------
|
7 |
+
# pip uninstall nvidia_cublas_cu11
|
8 |
+
|
9 |
+
|
10 |
+
import sys
|
11 |
+
sys.path.append('..')
|
12 |
+
import os
|
13 |
+
os.system(f'pip install dlib')
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
from PIL import Image
|
17 |
+
import models_mae
|
18 |
+
from torch.nn import functional as F
|
19 |
+
import dlib
|
20 |
+
|
21 |
+
import gradio as gr
|
22 |
+
|
23 |
+
|
24 |
+
# loading model
|
25 |
+
model = getattr(models_mae, 'mae_vit_base_patch16')()
|
26 |
+
|
27 |
+
|
28 |
+
class ITEM:
|
29 |
+
def __init__(self, img, parsing_map):
|
30 |
+
self.image = img
|
31 |
+
self.parsing_map = parsing_map
|
32 |
+
|
33 |
+
face_to_show = ITEM(None, None)
|
34 |
+
|
35 |
+
|
36 |
+
check_region = {'Eyebrows': [2, 3],
|
37 |
+
'Eyes': [4, 5],
|
38 |
+
'Nose': [6],
|
39 |
+
'Mouth': [7, 8, 9],
|
40 |
+
'Face Boundaries': [10, 1, 0],
|
41 |
+
'Hair': [10],
|
42 |
+
'Skin': [1],
|
43 |
+
'Background': [0]}
|
44 |
+
|
45 |
+
|
46 |
+
def get_boundingbox(face, width, height, minsize=None):
|
47 |
+
"""
|
48 |
+
Expects a dlib face to generate a quadratic bounding box.
|
49 |
+
:param face: dlib face class
|
50 |
+
:param width: frame width
|
51 |
+
:param height: frame height
|
52 |
+
:param cfg.face_scale: bounding box size multiplier to get a bigger face region
|
53 |
+
:param minsize: set minimum bounding box size
|
54 |
+
:return: x, y, bounding_box_size in opencv form
|
55 |
+
"""
|
56 |
+
x1 = face.left()
|
57 |
+
y1 = face.top()
|
58 |
+
x2 = face.right()
|
59 |
+
y2 = face.bottom()
|
60 |
+
size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
|
61 |
+
if minsize:
|
62 |
+
if size_bb < minsize:
|
63 |
+
size_bb = minsize
|
64 |
+
center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
|
65 |
+
|
66 |
+
# Check for out of bounds, x-y top left corner
|
67 |
+
x1 = max(int(center_x - size_bb // 2), 0)
|
68 |
+
y1 = max(int(center_y - size_bb // 2), 0)
|
69 |
+
# Check for too big bb size for given x, y
|
70 |
+
size_bb = min(width - x1, size_bb)
|
71 |
+
size_bb = min(height - y1, size_bb)
|
72 |
+
|
73 |
+
return x1, y1, size_bb
|
74 |
+
|
75 |
+
|
76 |
+
def extract_face(frame):
|
77 |
+
face_detector = dlib.get_frontal_face_detector()
|
78 |
+
image = np.array(frame.convert('RGB'))
|
79 |
+
faces = face_detector(image, 1)
|
80 |
+
if len(faces) > 0:
|
81 |
+
# For now only take the biggest face
|
82 |
+
face = faces[0]
|
83 |
+
# Face crop and rescale(follow FF++)
|
84 |
+
x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
|
85 |
+
# Get the landmarks/parts for the face in box d only with the five key points
|
86 |
+
cropped_face = image[y:y + size, x:x + size]
|
87 |
+
# cropped_face = cv2.resize(cropped_face, (224, 224), interpolation=cv2.INTER_CUBIC)
|
88 |
+
return Image.fromarray(cropped_face)
|
89 |
+
else:
|
90 |
+
return None
|
91 |
+
|
92 |
+
|
93 |
+
from torchvision.transforms import transforms
|
94 |
+
def show_one_img_patchify(img, model):
|
95 |
+
x = torch.tensor(img)
|
96 |
+
|
97 |
+
# make it a batch-like
|
98 |
+
x = x.unsqueeze(dim=0)
|
99 |
+
x = torch.einsum('nhwc->nchw', x)
|
100 |
+
x_patches = model.patchify(x)
|
101 |
+
|
102 |
+
# visualize the img_patchify
|
103 |
+
n = int(np.sqrt(x_patches.shape[1]))
|
104 |
+
image_size = int(224/n)
|
105 |
+
padding = 3
|
106 |
+
new_img = Image.new('RGB', (n * image_size + padding*(n-1), n * image_size + padding*(n-1)), 'white')
|
107 |
+
for i, patch in enumerate(x_patches[0]):
|
108 |
+
ax = i % n
|
109 |
+
ay = int(i / n)
|
110 |
+
patch_img_tensor = torch.reshape(patch, (model.patch_embed.patch_size[0], model.patch_embed.patch_size[1], 3))
|
111 |
+
patch_img_tensor = torch.einsum('hwc->chw', patch_img_tensor)
|
112 |
+
patch_img = transforms.ToPILImage()(patch_img_tensor)
|
113 |
+
new_img.paste(patch_img, (ax * image_size + padding * ax, ay * image_size + padding * ay))
|
114 |
+
|
115 |
+
new_img = new_img.resize((224, 224), Image.BICUBIC)
|
116 |
+
return new_img
|
117 |
+
|
118 |
+
|
119 |
+
def show_one_img_parchify_mask(img, parsing_map, mask, model):
|
120 |
+
mask = mask.detach()
|
121 |
+
mask_patches = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3) # (N, H*W, p*p*3)
|
122 |
+
mask = model.unpatchify(mask_patches) # 1 is removing, 0 is keeping
|
123 |
+
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
|
124 |
+
|
125 |
+
# visualize mask
|
126 |
+
vis_mask = mask[0].clone()
|
127 |
+
vis_mask[vis_mask == 1] = 1 # gray for masked
|
128 |
+
vis_mask[vis_mask == 2] = -1 # black for highlight masked facial region
|
129 |
+
vis_mask[vis_mask == 0] = 2 # white for visible
|
130 |
+
vis_mask = torch.clip(vis_mask * 127, 0, 255).int()
|
131 |
+
fasking_mask = vis_mask.numpy().astype(np.uint8)
|
132 |
+
fasking_mask = Image.fromarray(fasking_mask)
|
133 |
+
|
134 |
+
# visualize the masked image
|
135 |
+
im_masked = img
|
136 |
+
im_masked[mask[0] == 1] = 127
|
137 |
+
im_masked[mask[0] == 2] = 0
|
138 |
+
im_masked = Image.fromarray(im_masked)
|
139 |
+
|
140 |
+
# visualize the masked image_patchify
|
141 |
+
parsing_map_masked = parsing_map
|
142 |
+
parsing_map_masked[mask[0] == 1] = 127
|
143 |
+
parsing_map_masked[mask[0] == 2] = 0
|
144 |
+
|
145 |
+
return [show_one_img_patchify(parsing_map_masked, model), fasking_mask, im_masked]
|
146 |
+
|
147 |
+
|
148 |
+
# Random
|
149 |
+
class CollateFn_Random:
|
150 |
+
def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
|
151 |
+
self.img_size = input_size
|
152 |
+
self.patch_size = patch_size
|
153 |
+
self.num_patches_axis = input_size // patch_size
|
154 |
+
self.num_patches = (input_size // patch_size) ** 2
|
155 |
+
self.mask_ratio = mask_ratio
|
156 |
+
|
157 |
+
def __call__(self, image, parsing_map):
|
158 |
+
random_mask = torch.zeros(parsing_map.size(0), self.num_patches, dtype=torch.float32) # torch.Size([BS, 14, 14])
|
159 |
+
random_mask = self.masking(parsing_map, random_mask)
|
160 |
+
|
161 |
+
return {'image': image, 'random_mask': random_mask}
|
162 |
+
|
163 |
+
def masking(self, parsing_map, random_mask):
|
164 |
+
"""
|
165 |
+
:return:
|
166 |
+
"""
|
167 |
+
for i in range(random_mask.size(0)):
|
168 |
+
# normalize the masking to strictly target percentage for batch computation.
|
169 |
+
num_mask_to_change = int(self.mask_ratio * self.num_patches)
|
170 |
+
mask_change_to = 1 if num_mask_to_change >= 0 else 0
|
171 |
+
change_indices = torch.randperm(self.num_patches)
|
172 |
+
for idx in range(num_mask_to_change):
|
173 |
+
random_mask[i, change_indices[idx]] = mask_change_to
|
174 |
+
|
175 |
+
return random_mask
|
176 |
+
|
177 |
+
|
178 |
+
def do_random_masking(image, parsing_map_vis, ratio):
|
179 |
+
img = torch.from_numpy(image)
|
180 |
+
img = img.unsqueeze(0).permute(0, 3, 1, 2)
|
181 |
+
parsing_map = face_to_show.parsing_map
|
182 |
+
parsing_map = torch.tensor(parsing_map)
|
183 |
+
|
184 |
+
mask_method = CollateFn_Random(input_size=224, patch_size=16, mask_ratio=ratio)
|
185 |
+
mask = mask_method(img, parsing_map)['random_mask']
|
186 |
+
|
187 |
+
random_patch_on_parsing, random_mask, random_mask_on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
|
188 |
+
|
189 |
+
return random_patch_on_parsing, random_mask, random_mask_on_image
|
190 |
+
|
191 |
+
|
192 |
+
# Fasking
|
193 |
+
class CollateFn_Fasking:
|
194 |
+
def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
|
195 |
+
self.img_size = input_size
|
196 |
+
self.patch_size = patch_size
|
197 |
+
self.num_patches_axis = input_size // patch_size
|
198 |
+
self.num_patches = (input_size // patch_size) ** 2
|
199 |
+
self.mask_ratio = mask_ratio
|
200 |
+
# --------------------------------------------------------------------------
|
201 |
+
self.facial_region_group = [
|
202 |
+
[2, 4], # right eye
|
203 |
+
[3, 5], # left eye
|
204 |
+
[6], # nose
|
205 |
+
[7, 8, 9], # mouth
|
206 |
+
[10], # hair
|
207 |
+
[1], # skin
|
208 |
+
[0] # background
|
209 |
+
] # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
|
210 |
+
|
211 |
+
def __call__(self, image, parsing_map):
|
212 |
+
# image = torch.stack([sample['image'] for sample in samples]) # torch.Size([bs, 3, 224, 224])
|
213 |
+
# parsing_map = torch.stack([sample['parsing_map'] for sample in samples]) # torch.Size([bs, 1, 224, 224])
|
214 |
+
# parsing_map = parsing_map.squeeze(1) # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])
|
215 |
+
|
216 |
+
# random select a facial semantic region and get corresponding mask(masking all patches include this region)
|
217 |
+
fasking_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32) # torch.Size([BS, 14, 14])
|
218 |
+
fasking_mask = self.fasking(parsing_map, fasking_mask)
|
219 |
+
|
220 |
+
return {'image': image, 'fasking_mask': fasking_mask}
|
221 |
+
|
222 |
+
def fasking(self, parsing_map, fasking_mask):
|
223 |
+
"""
|
224 |
+
:return:
|
225 |
+
"""
|
226 |
+
for i in range(parsing_map.size(0)):
|
227 |
+
terminate = False
|
228 |
+
for seg_group in self.facial_region_group[:-2]:
|
229 |
+
if terminate:
|
230 |
+
break
|
231 |
+
for comp_value in seg_group:
|
232 |
+
fasking_mask[i] = torch.maximum(
|
233 |
+
fasking_mask[i], F.max_pool2d((parsing_map[i].unsqueeze(0) == comp_value).float(), kernel_size=self.patch_size))
|
234 |
+
if fasking_mask[i].mean() >= ((self.mask_ratio * self.num_patches) / self.num_patches):
|
235 |
+
terminate = True
|
236 |
+
break
|
237 |
+
|
238 |
+
fasking_mask = fasking_mask.view(parsing_map.size(0), -1)
|
239 |
+
for i in range(fasking_mask.size(0)):
|
240 |
+
# normalize the masking to strictly target percentage for batch computation.
|
241 |
+
num_mask_to_change = (self.mask_ratio * self.num_patches - fasking_mask[i].sum(dim=-1)).int()
|
242 |
+
mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
|
243 |
+
select_indices = (fasking_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
|
244 |
+
change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
|
245 |
+
fasking_mask[i, select_indices[change_indices]] = mask_change_to
|
246 |
+
|
247 |
+
return fasking_mask
|
248 |
+
|
249 |
+
|
250 |
+
def do_fasking_masking(image, parsing_map_vis, ratio):
|
251 |
+
img = torch.from_numpy(image)
|
252 |
+
img = img.unsqueeze(0).permute(0, 3, 1, 2)
|
253 |
+
parsing_map = face_to_show.parsing_map
|
254 |
+
parsing_map = torch.tensor(parsing_map)
|
255 |
+
|
256 |
+
mask_method = CollateFn_Fasking(input_size=224, patch_size=16, mask_ratio=ratio)
|
257 |
+
mask = mask_method(img, parsing_map)['fasking_mask']
|
258 |
+
|
259 |
+
fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
|
260 |
+
|
261 |
+
return fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image
|
262 |
+
|
263 |
+
|
264 |
+
# FRP
|
265 |
+
class CollateFn_FR_P_Masking:
|
266 |
+
def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
|
267 |
+
self.img_size = input_size
|
268 |
+
self.patch_size = patch_size
|
269 |
+
self.num_patches_axis = input_size // patch_size
|
270 |
+
self.num_patches = (input_size // patch_size) ** 2
|
271 |
+
self.mask_ratio = mask_ratio
|
272 |
+
self.facial_region_group = [
|
273 |
+
[2, 3], # eyebrows
|
274 |
+
[4, 5], # eyes
|
275 |
+
[6], # nose
|
276 |
+
[7, 8, 9], # mouth
|
277 |
+
[10, 1, 0], # face boundaries
|
278 |
+
[10], # hair
|
279 |
+
[1], # facial skin
|
280 |
+
[0] # background
|
281 |
+
] # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
|
282 |
+
|
283 |
+
def __call__(self, image, parsing_map):
|
284 |
+
# image = torch.stack([sample['image'] for sample in samples]) # torch.Size([bs, 3, 224, 224])
|
285 |
+
# parsing_map = torch.stack([sample['parsing_map'] for sample in samples]) # torch.Size([bs, 1, 224, 224])
|
286 |
+
# parsing_map = parsing_map.squeeze(1) # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])
|
287 |
+
|
288 |
+
# random select a facial semantic region and get corresponding mask(masking all patches include this region)
|
289 |
+
P_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32) # torch.Size([BS, 14, 14])
|
290 |
+
P_mask = self.random_variable_facial_semantics_masking(parsing_map, P_mask)
|
291 |
+
|
292 |
+
return {'image': image, 'P_mask': P_mask}
|
293 |
+
|
294 |
+
def random_variable_facial_semantics_masking(self, parsing_map, P_mask):
|
295 |
+
"""
|
296 |
+
:return:
|
297 |
+
"""
|
298 |
+
P_mask = P_mask.view(P_mask.size(0), -1)
|
299 |
+
for i in range(parsing_map.size(0)):
|
300 |
+
|
301 |
+
for seg_group in self.facial_region_group[:-2]:
|
302 |
+
mask_in_seg_group = torch.zeros(1, self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
|
303 |
+
if seg_group == [10, 1, 0]:
|
304 |
+
patch_hair_bg = F.max_pool2d(
|
305 |
+
((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
|
306 |
+
kernel_size=self.patch_size)
|
307 |
+
patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(), kernel_size=self.patch_size)
|
308 |
+
# skin&hair or skin&bg defined as facial boundaries:
|
309 |
+
mask_in_seg_group = torch.maximum(mask_in_seg_group,
|
310 |
+
(patch_hair_bg.bool() & patch_skin.bool()).float())
|
311 |
+
else:
|
312 |
+
for comp_value in seg_group:
|
313 |
+
mask_in_seg_group = torch.maximum(mask_in_seg_group,
|
314 |
+
F.max_pool2d(
|
315 |
+
(parsing_map[i].unsqueeze(0) == comp_value).float(),
|
316 |
+
kernel_size=self.patch_size))
|
317 |
+
|
318 |
+
mask_in_seg_group = mask_in_seg_group.view(-1)
|
319 |
+
# to_mask_patches_in_seg_group = mask_in_seg_group - (mask_in_seg_group & P_mask[i])
|
320 |
+
to_mask_patches_in_seg_group = (mask_in_seg_group - P_mask[i]) > 0
|
321 |
+
mask_num = (mask_in_seg_group.sum(dim=-1) * self.mask_ratio -
|
322 |
+
(mask_in_seg_group.sum(dim=-1)-to_mask_patches_in_seg_group.sum(dim=-1))).int()
|
323 |
+
if mask_num > 0:
|
324 |
+
select_indices = (to_mask_patches_in_seg_group == 1).nonzero(as_tuple=False).view(-1)
|
325 |
+
change_indices = torch.randperm(len(select_indices))[:mask_num]
|
326 |
+
P_mask[i, select_indices[change_indices]] = 1
|
327 |
+
|
328 |
+
num_mask_to_change = (self.mask_ratio * self.num_patches - P_mask[i].sum(dim=-1)).int()
|
329 |
+
mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
|
330 |
+
select_indices = (P_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
|
331 |
+
change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
|
332 |
+
P_mask[i, select_indices[change_indices]] = mask_change_to
|
333 |
+
|
334 |
+
return P_mask
|
335 |
+
|
336 |
+
|
337 |
+
def do_FRP_masking(image, parsing_map_vis, ratio):
|
338 |
+
img = torch.from_numpy(image)
|
339 |
+
img = img.unsqueeze(0).permute(0, 3, 1, 2)
|
340 |
+
parsing_map = face_to_show.parsing_map
|
341 |
+
parsing_map = torch.tensor(parsing_map)
|
342 |
+
|
343 |
+
mask_method = CollateFn_FR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio)
|
344 |
+
masks = mask_method(img, parsing_map)
|
345 |
+
mask = masks['P_mask']
|
346 |
+
|
347 |
+
FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
|
348 |
+
|
349 |
+
return FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image
|
350 |
+
|
351 |
+
|
352 |
+
# CRFR_R
|
353 |
+
class CollateFn_CRFR_R_Masking:
|
354 |
+
def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
|
355 |
+
self.img_size = input_size
|
356 |
+
self.patch_size = patch_size
|
357 |
+
self.num_patches_axis = input_size // patch_size
|
358 |
+
self.num_patches = (input_size // patch_size) ** 2
|
359 |
+
self.mask_ratio = mask_ratio
|
360 |
+
self.facial_region_group = [
|
361 |
+
[2, 3], # eyebrows
|
362 |
+
[4, 5], # eyes
|
363 |
+
[6], # nose
|
364 |
+
[7, 8, 9], # mouth
|
365 |
+
[10, 1, 0], # face boundaries
|
366 |
+
[10], # hair
|
367 |
+
[1], # facial skin
|
368 |
+
[0] # background
|
369 |
+
] # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
|
370 |
+
self.random_specific_facial_region = check_region[region]
|
371 |
+
|
372 |
+
def __call__(self, image, parsing_map):
|
373 |
+
# mage = torch.stack([sample['image'] for sample in samples]) # torch.Size([bs, 3, 224, 224])
|
374 |
+
# parsing_map = torch.stack([sample['parsing_map'] for sample in samples]) # torch.Size([bs, 1, 224, 224])
|
375 |
+
# parsing_map = parsing_map.squeeze(1) # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])
|
376 |
+
|
377 |
+
# random select a facial semantic region and get corresponding mask(masking all patches include this region)
|
378 |
+
facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32) # torch.Size([1, H/P, W/P])
|
379 |
+
facial_region_mask, random_specific_facial_region = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
|
380 |
+
# torch.Size([num_patches,]), list
|
381 |
+
|
382 |
+
CRFR_R_mask, facial_region_mask = self.random_variable_masking(facial_region_mask)
|
383 |
+
# torch.Size([num_patches,]), torch.Size([num_patches,])
|
384 |
+
|
385 |
+
return {'image': image, 'CRFR_R_mask': CRFR_R_mask, 'fr_mask': facial_region_mask}
|
386 |
+
|
387 |
+
def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
|
388 |
+
"""
|
389 |
+
:param parsing_map: [1, img_size, img_size])
|
390 |
+
:param facial_region_mask: [1, num_patches ** .5, num_patches ** .5]
|
391 |
+
:return: facial_region_mask, random_specific_facial_region
|
392 |
+
"""
|
393 |
+
# random_specific_facial_region = random.choice(self.facial_region_group[:-2])
|
394 |
+
# random_specific_facial_region = [6] # for test: nose
|
395 |
+
if self.random_specific_facial_region == [10, 1, 0]: # facial boundaries, 10-hair 1-skin 0-background
|
396 |
+
# True for hair(10) or bg(0) patches:
|
397 |
+
patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(),
|
398 |
+
kernel_size=self.patch_size)
|
399 |
+
# True for skin(1) patches:
|
400 |
+
patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
|
401 |
+
# skin&hair or skin&bg is defined as facial boundaries:
|
402 |
+
facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
|
403 |
+
else:
|
404 |
+
for facial_region_index in self.random_specific_facial_region:
|
405 |
+
facial_region_mask = torch.maximum(facial_region_mask,
|
406 |
+
F.max_pool2d((parsing_map == facial_region_index).float(),
|
407 |
+
kernel_size=self.patch_size))
|
408 |
+
|
409 |
+
return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region
|
410 |
+
|
411 |
+
def random_variable_masking(self, facial_region_mask):
|
412 |
+
CRFR_R_mask = facial_region_mask.clone()
|
413 |
+
|
414 |
+
for i in range(facial_region_mask.size(0)):
|
415 |
+
num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
|
416 |
+
mask_change_to = 1 if num_mask_to_change >= 0 else 0
|
417 |
+
|
418 |
+
select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
|
419 |
+
change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
|
420 |
+
CRFR_R_mask[i, select_indices[change_indices]] = mask_change_to
|
421 |
+
|
422 |
+
facial_region_mask[i] = CRFR_R_mask[i] if num_mask_to_change < 0 else facial_region_mask[i]
|
423 |
+
|
424 |
+
return CRFR_R_mask, facial_region_mask
|
425 |
+
|
426 |
+
|
427 |
+
def do_CRFR_R_masking(image, parsing_map_vis, ratio, region):
|
428 |
+
img = torch.from_numpy(image)
|
429 |
+
img = img.unsqueeze(0).permute(0, 3, 1, 2)
|
430 |
+
parsing_map = face_to_show.parsing_map
|
431 |
+
parsing_map = torch.tensor(parsing_map)
|
432 |
+
|
433 |
+
mask_method = CollateFn_CRFR_R_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
|
434 |
+
masks = mask_method(img, parsing_map)
|
435 |
+
mask = masks['CRFR_R_mask']
|
436 |
+
fr_mask = masks['fr_mask']
|
437 |
+
|
438 |
+
CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask+fr_mask, model)
|
439 |
+
|
440 |
+
return CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image
|
441 |
+
|
442 |
+
|
443 |
+
# CRFR_P
|
444 |
+
class CollateFn_CRFR_P_Masking:
|
445 |
+
def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
|
446 |
+
self.img_size = input_size
|
447 |
+
self.patch_size = patch_size
|
448 |
+
self.num_patches_axis = input_size // patch_size
|
449 |
+
self.num_patches = (input_size // patch_size) ** 2
|
450 |
+
self.mask_ratio = mask_ratio
|
451 |
+
|
452 |
+
self.facial_region_group = [
|
453 |
+
[2, 3], # eyebrows
|
454 |
+
[4, 5], # eyes
|
455 |
+
[6], # nose
|
456 |
+
[7, 8, 9], # mouth
|
457 |
+
[10, 1, 0], # face boundaries
|
458 |
+
[10], # hair
|
459 |
+
[1], # facial skin
|
460 |
+
[0] # background
|
461 |
+
] # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
|
462 |
+
self.random_specific_facial_region = check_region[region]
|
463 |
+
|
464 |
+
def __call__(self, image, parsing_map):
|
465 |
+
# image = torch.stack([sample['image'] for sample in samples]) # torch.Size([bs, 3, 224, 224])
|
466 |
+
# parsing_map = torch.stack([sample['parsing_map'] for sample in samples]) # torch.Size([bs, 1, 224, 224])
|
467 |
+
# parsing_map = parsing_map.squeeze(1) # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])
|
468 |
+
|
469 |
+
# random select a facial semantic region and get corresponding mask(masking all patches include this region)
|
470 |
+
facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
|
471 |
+
dtype=torch.float32) # torch.Size([1, H/P, W/P])
|
472 |
+
facial_region_mask, random_specific_facial_region = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
|
473 |
+
# torch.Size([num_patches,]), list
|
474 |
+
|
475 |
+
CRFR_P_mask, facial_region_mask = self.random_variable_masking(parsing_map, facial_region_mask, random_specific_facial_region)
|
476 |
+
# torch.Size([num_patches,]), torch.Size([num_patches,])
|
477 |
+
|
478 |
+
return {'image': image, 'CRFR_P_mask': CRFR_P_mask, 'fr_mask': facial_region_mask}
|
479 |
+
|
480 |
+
def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
|
481 |
+
"""
|
482 |
+
:param parsing_map: [1, img_size, img_size])
|
483 |
+
:param facial_region_mask: [1, num_patches ** .5, num_patches ** .5]
|
484 |
+
:return: facial_region_mask, random_specific_facial_region
|
485 |
+
"""
|
486 |
+
# random_specific_facial_region = random.choice(self.facial_region_group[:-2])
|
487 |
+
# random_specific_facial_region = [4, 5] # for test: eyes
|
488 |
+
if self.random_specific_facial_region == [10, 1, 0]: # facial boundaries, 10-hair 1-skin 0-background
|
489 |
+
# True for hair(10) or bg(0) patches:
|
490 |
+
patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(), kernel_size=self.patch_size)
|
491 |
+
# True for skin(1) patches:
|
492 |
+
patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
|
493 |
+
# skin&hair or skin&bg is defined as facial boundaries:
|
494 |
+
facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
|
495 |
+
|
496 |
+
# # True for hair(10) or skin(1) patches:
|
497 |
+
# patch_hair_face = F.max_pool2d(((parsing_map == 10) + (parsing_map == 1)).float(),
|
498 |
+
# kernel_size=self.patch_size)
|
499 |
+
# # True for bg(0) patches:
|
500 |
+
# patch_bg = F.max_pool2d((parsing_map == 0).float(), kernel_size=self.patch_size)
|
501 |
+
# # skin&bg or hair&bg defined as facial boundaries:
|
502 |
+
# facial_region_mask = (patch_hair_face.bool() & patch_bg.bool()).float()
|
503 |
+
|
504 |
+
else:
|
505 |
+
for facial_region_index in self.random_specific_facial_region:
|
506 |
+
facial_region_mask = torch.maximum(facial_region_mask,
|
507 |
+
F.max_pool2d((parsing_map == facial_region_index).float(),
|
508 |
+
kernel_size=self.patch_size))
|
509 |
+
|
510 |
+
return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region
|
511 |
+
|
512 |
+
def random_variable_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
|
513 |
+
CRFR_P_mask = facial_region_mask.clone()
|
514 |
+
other_facial_region_group = [region for region in self.facial_region_group if
|
515 |
+
region != random_specific_facial_region]
|
516 |
+
# print(other_facial_region_group)
|
517 |
+
for i in range(facial_region_mask.size(0)): # iterate each map in BS
|
518 |
+
num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
|
519 |
+
# mask_change_to = 1 if num_mask_to_change >= 0 else 0
|
520 |
+
mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
|
521 |
+
|
522 |
+
# masking patches in other facial regions according to the corresponding ratio
|
523 |
+
if mask_change_to == 1:
|
524 |
+
# mask_ratio_other_fr = remain(unmasked) patches should be masked / remain(unmasked) patches
|
525 |
+
mask_ratio_other_fr = (
|
526 |
+
num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))
|
527 |
+
|
528 |
+
masked_patches = facial_region_mask[i].clone()
|
529 |
+
for other_fr in other_facial_region_group:
|
530 |
+
to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
|
531 |
+
dtype=torch.float32)
|
532 |
+
if other_fr == [10, 1, 0]:
|
533 |
+
patch_hair_bg = F.max_pool2d(
|
534 |
+
((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
|
535 |
+
kernel_size=self.patch_size)
|
536 |
+
patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(), kernel_size=self.patch_size)
|
537 |
+
# skin&hair or skin&bg defined as facial boundaries:
|
538 |
+
to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
|
539 |
+
else:
|
540 |
+
for facial_region_index in other_fr:
|
541 |
+
to_mask_patches = torch.maximum(to_mask_patches,
|
542 |
+
F.max_pool2d((parsing_map[i].unsqueeze(0) == facial_region_index).float(),
|
543 |
+
kernel_size=self.patch_size))
|
544 |
+
|
545 |
+
# ignore already masked patches:
|
546 |
+
to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
|
547 |
+
# to_mask_patches = to_mask_patches.view(-1) - (to_mask_patches.view(-1) & masked_patches)
|
548 |
+
select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
|
549 |
+
change_indices = torch.randperm(len(select_indices))[
|
550 |
+
:torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
|
551 |
+
CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to
|
552 |
+
# prevent overlap
|
553 |
+
masked_patches = masked_patches + to_mask_patches.float()
|
554 |
+
|
555 |
+
# mask/unmask patch from other facial regions to get CRFR_P_mask with fixed size
|
556 |
+
num_mask_to_change = (self.mask_ratio * self.num_patches - CRFR_P_mask[i].sum(dim=-1)).int()
|
557 |
+
# mask_change_to = 1 if num_mask_to_change >= 0 else 0
|
558 |
+
mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
|
559 |
+
# prevent unmasking facial_region_mask
|
560 |
+
select_indices = ((CRFR_P_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
|
561 |
+
change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
|
562 |
+
CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to
|
563 |
+
|
564 |
+
else:
|
565 |
+
# if the num of facial_region_mask is over (num_patches*mask_ratio),
|
566 |
+
# unmask it to get CRFR_P_mask with fixed size
|
567 |
+
select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
|
568 |
+
change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
|
569 |
+
CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to
|
570 |
+
facial_region_mask[i] = CRFR_P_mask[i]
|
571 |
+
|
572 |
+
return CRFR_P_mask, facial_region_mask
|
573 |
+
|
574 |
+
|
575 |
+
def do_CRFR_P_masking(image, parsing_map_vis, ratio, region):
|
576 |
+
img = torch.from_numpy(image)
|
577 |
+
img = img.unsqueeze(0).permute(0, 3, 1, 2)
|
578 |
+
parsing_map = face_to_show.parsing_map
|
579 |
+
parsing_map = torch.tensor(parsing_map)
|
580 |
+
|
581 |
+
mask_method = CollateFn_CRFR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
|
582 |
+
masks = mask_method(img, parsing_map)
|
583 |
+
mask = masks['CRFR_P_mask']
|
584 |
+
fr_mask = masks['fr_mask']
|
585 |
+
|
586 |
+
CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask+fr_mask, model)
|
587 |
+
|
588 |
+
return CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image
|
589 |
+
|
590 |
+
|
591 |
+
def vis_parsing_maps(parsing_anno):
|
592 |
+
part_colors = [[255, 255, 255],
|
593 |
+
[0, 0, 255], [255, 128, 0], [255, 255, 0],
|
594 |
+
[0, 255, 0], [0, 255, 128],
|
595 |
+
[0, 255, 255], [255, 0, 255], [255, 0, 128],
|
596 |
+
[128, 0, 255], [255, 0, 0]]
|
597 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
598 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
599 |
+
|
600 |
+
num_of_class = np.max(vis_parsing_anno)
|
601 |
+
|
602 |
+
for pi in range(1, num_of_class + 1):
|
603 |
+
index = np.where(vis_parsing_anno == pi)
|
604 |
+
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
|
605 |
+
|
606 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
607 |
+
return vis_parsing_anno_color
|
608 |
+
|
609 |
+
|
610 |
+
#from facer import facer
|
611 |
+
import facer
|
612 |
+
def do_face_parsing(img):
|
613 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
614 |
+
|
615 |
+
face_detector = facer.face_detector('retinaface/mobilenet', device=device, threshold=0.3) # 0.3 for FF++
|
616 |
+
face_parser = facer.face_parser('farl/lapa/448', device=device) # celebm parser
|
617 |
+
|
618 |
+
img = extract_face(img)
|
619 |
+
with torch.inference_mode():
|
620 |
+
img = img.resize((224, 224), Image.BICUBIC)
|
621 |
+
image = torch.from_numpy(np.array(img.convert('RGB')))
|
622 |
+
image = image.unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
|
623 |
+
try:
|
624 |
+
faces = face_detector(image)
|
625 |
+
faces = face_parser(image, faces)
|
626 |
+
|
627 |
+
seg_logits = faces['seg']['logits']
|
628 |
+
seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
|
629 |
+
seg_probs = seg_probs.data # torch.Size([1, 11, 224, 224])
|
630 |
+
parsing = seg_probs.argmax(1) # [1, 224, 224]
|
631 |
+
|
632 |
+
parsing_map = parsing.data.cpu().numpy() # [1, 224, 224] int64
|
633 |
+
parsing_map = parsing_map.astype(np.int8) # smaller space
|
634 |
+
parsing_map_vis = vis_parsing_maps(parsing_map.squeeze(0))
|
635 |
+
|
636 |
+
except KeyError:
|
637 |
+
return gr.update()
|
638 |
+
|
639 |
+
face_to_show.image = img
|
640 |
+
face_to_show.parsing_map = parsing_map
|
641 |
+
return img, parsing_map_vis, show_one_img_patchify(parsing_map_vis, model)
|
642 |
+
|
643 |
+
|
644 |
+
# WebUI
|
645 |
+
with gr.Blocks() as demo:
|
646 |
+
gr.Markdown("# Demo of Facial Masking Strategies")
|
647 |
+
gr.Markdown(
|
648 |
+
"This is a demo of different facial masking strategies for MIM that are introduced in [FSFM-3C](https://fsfm-3c.github.io/)"
|
649 |
+
)
|
650 |
+
gr.Markdown(
|
651 |
+
"- <b>Random Masking</b>: Random masking all patches."
|
652 |
+
)
|
653 |
+
gr.Markdown(
|
654 |
+
"- <b>Fasking-I</b>: Use a face parser to divide facial regions and priority masking non-skin and non-background regions."
|
655 |
+
)
|
656 |
+
gr.Markdown(
|
657 |
+
"- <b>FRP</b>: Facial Region Proportional masking, which masks an equal portion of patches in each facial region to the overall masking ratio."
|
658 |
+
)
|
659 |
+
gr.Markdown(
|
660 |
+
"- <b>CRFR-R</b>: (1) Covering a Random Facial Region followed by (2) Random masking other patche."
|
661 |
+
)
|
662 |
+
gr.Markdown(
|
663 |
+
"- <b>CRFR-P (_suggested in FSFM-3C_)</b>: (1) Covering a Random Facial Region followed by (2) Proportional masking masking other regions."
|
664 |
+
)
|
665 |
+
|
666 |
+
with gr.Column():
|
667 |
+
image = gr.Image(label="Upload/Capture/Paste a facial image", type="pil")
|
668 |
+
|
669 |
+
image_submit_btn = gr.Button("🖱️ Face Parsing")
|
670 |
+
with gr.Row():
|
671 |
+
ori_image = gr.Image(interactive=False, label="Detected Face")
|
672 |
+
parsing_map_vis = gr.Image(interactive=False, label="Face Parsing")
|
673 |
+
patch_parsing_map = gr.Image(interactive=False, label="Patchify")
|
674 |
+
gr.HTML('<div class="spacer-20"></div>')
|
675 |
+
|
676 |
+
with gr.Column(): # Random
|
677 |
+
random_submit_btn = gr.Button("🖱️ Random Masking")
|
678 |
+
ratio_random = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Masking Ratio for Random Masking")
|
679 |
+
with gr.Row():
|
680 |
+
random_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
|
681 |
+
random_mask = gr.Image(interactive=False, label="Mask")
|
682 |
+
random_mask_on_image = gr.Image(interactive=False, label="Masked Face")
|
683 |
+
gr.HTML('<div class="spacer-20"></div>')
|
684 |
+
|
685 |
+
with gr.Column(): # Fasking-I
|
686 |
+
fasking_submit_btn = gr.Button("🖱️ Fasking-I")
|
687 |
+
ratio_fasking = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Masking Ratio for Fasking")
|
688 |
+
with gr.Row():
|
689 |
+
fasking_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
|
690 |
+
fasking_mask = gr.Image(interactive=False, label="Mask")
|
691 |
+
fasking_mask_on_image = gr.Image(interactive=False, label="Masked Face")
|
692 |
+
gr.HTML('<div class="spacer-20"></div>')
|
693 |
+
|
694 |
+
with gr.Column(): # FRP
|
695 |
+
FRP_submit_btn = gr.Button("🖱️ FRP")
|
696 |
+
ratio_FRP = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Masking Ratio for FRP")
|
697 |
+
with gr.Row():
|
698 |
+
FRP_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
|
699 |
+
FRP_mask = gr.Image(interactive=False, label="Mask")
|
700 |
+
FRP_mask_on_image = gr.Image(interactive=False, label="Masked Face")
|
701 |
+
gr.HTML('<div class="spacer-20"></div>')
|
702 |
+
|
703 |
+
with gr.Column(): # CRFR-R
|
704 |
+
CRFR_R_submit_btn = gr.Button("🖱️ CRFR-R")
|
705 |
+
ratio_CRFR_R = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Masking Ratio for CRFR-R")
|
706 |
+
mask_region_CRFR_R = gr.Radio(choices=['Eyebrows', 'Eyes', 'Nose', 'Mouth', 'Face Boundaries', 'Hair','Skin','Background'],
|
707 |
+
value='Eyes',
|
708 |
+
label="Facial Region (for CRFR, highlighted by black)")
|
709 |
+
with gr.Row():
|
710 |
+
CRFR_R_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
|
711 |
+
CRFR_R_mask = gr.Image(interactive=False, label="Mask")
|
712 |
+
CRFR_R_mask_on_image = gr.Image(interactive=False, label="Masked Face")
|
713 |
+
gr.HTML('<div class="spacer-20"></div>')
|
714 |
+
|
715 |
+
with gr.Column(): # CRFR-P
|
716 |
+
CRFR_P_submit_btn = gr.Button("🖱️ CRFR-P (suggested in FSFM-3C)")
|
717 |
+
ratio_CRFR_P = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Masking Ratio for CRFR-P")
|
718 |
+
mask_region_CRFR_P = gr.Radio(choices=['Eyebrows', 'Eyes', 'Nose', 'Mouth', 'Face Boundaries', 'Hair', 'Skin', 'Background'],
|
719 |
+
value='Eyes',
|
720 |
+
label="Facial Region (for CRFR, highlighted by black)")
|
721 |
+
with gr.Row():
|
722 |
+
CRFR_P_patch_on_parsing = gr.Image(interactive=False, label="Mask/Parsing")
|
723 |
+
CRFR_P_mask = gr.Image(interactive=False, label="Mask")
|
724 |
+
CRFR_P_mask_on_image = gr.Image(interactive=False, label="Masked Face")
|
725 |
+
|
726 |
+
|
727 |
+
parseing_map = []
|
728 |
+
image_submit_btn.click(
|
729 |
+
fn = do_face_parsing,
|
730 |
+
inputs=image,
|
731 |
+
outputs=[ori_image, parsing_map_vis, patch_parsing_map]
|
732 |
+
)
|
733 |
+
random_submit_btn.click(
|
734 |
+
fn = do_random_masking,
|
735 |
+
inputs=[ori_image, parsing_map_vis, ratio_random],
|
736 |
+
outputs=[random_patch_on_parsing, random_mask, random_mask_on_image],
|
737 |
+
)
|
738 |
+
ratio_random.change(
|
739 |
+
fn = do_random_masking,
|
740 |
+
inputs=[ori_image, parsing_map_vis, ratio_random],
|
741 |
+
outputs=[random_patch_on_parsing, random_mask, random_mask_on_image],
|
742 |
+
)
|
743 |
+
fasking_submit_btn.click(
|
744 |
+
fn = do_fasking_masking,
|
745 |
+
inputs=[ori_image, parsing_map_vis, ratio_fasking],
|
746 |
+
outputs=[fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image],
|
747 |
+
)
|
748 |
+
ratio_fasking.change(
|
749 |
+
fn = do_fasking_masking,
|
750 |
+
inputs=[ori_image, parsing_map_vis, ratio_fasking],
|
751 |
+
outputs=[fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image],
|
752 |
+
)
|
753 |
+
FRP_submit_btn.click(
|
754 |
+
fn = do_FRP_masking,
|
755 |
+
inputs=[ori_image, parsing_map_vis, ratio_FRP],
|
756 |
+
outputs=[FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image],
|
757 |
+
)
|
758 |
+
ratio_FRP.change(
|
759 |
+
fn = do_FRP_masking,
|
760 |
+
inputs=[ori_image, parsing_map_vis, ratio_FRP],
|
761 |
+
outputs=[FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image],
|
762 |
+
)
|
763 |
+
CRFR_R_submit_btn.click(
|
764 |
+
fn = do_CRFR_R_masking,
|
765 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
|
766 |
+
outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
|
767 |
+
)
|
768 |
+
ratio_CRFR_R.change(
|
769 |
+
fn = do_CRFR_R_masking,
|
770 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
|
771 |
+
outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
|
772 |
+
)
|
773 |
+
mask_region_CRFR_R.change(
|
774 |
+
fn = do_CRFR_R_masking,
|
775 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
|
776 |
+
outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
|
777 |
+
)
|
778 |
+
CRFR_P_submit_btn.click(
|
779 |
+
fn = do_CRFR_P_masking,
|
780 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
|
781 |
+
outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
|
782 |
+
)
|
783 |
+
ratio_CRFR_P.change(
|
784 |
+
fn=do_CRFR_P_masking,
|
785 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
|
786 |
+
outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
|
787 |
+
)
|
788 |
+
mask_region_CRFR_P.change(
|
789 |
+
fn=do_CRFR_P_masking,
|
790 |
+
inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
|
791 |
+
outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
|
792 |
+
)
|
793 |
+
|
794 |
+
if __name__ == "__main__":
|
795 |
+
gr.close_all()
|
796 |
+
demo.queue()
|
797 |
+
demo.launch()
|
facer/__init__.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .io import read_hwc, write_hwc
|
5 |
+
from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
|
6 |
+
from .draw import draw_bchw, draw_landmarks
|
7 |
+
from .show import show_bchw, show_bhw
|
8 |
+
|
9 |
+
from .face_detection import FaceDetector
|
10 |
+
from .face_parsing import FaceParser
|
11 |
+
from .face_alignment import FaceAlignment
|
12 |
+
from .face_attribute import FaceAttribute
|
13 |
+
|
14 |
+
|
15 |
+
def _split_name(name: str) -> Tuple[str, Optional[str]]:
|
16 |
+
if '/' in name:
|
17 |
+
detector_type, conf_name = name.split('/', 1)
|
18 |
+
else:
|
19 |
+
detector_type, conf_name = name, None
|
20 |
+
return detector_type, conf_name
|
21 |
+
|
22 |
+
|
23 |
+
def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
|
24 |
+
detector_type, conf_name = _split_name(name)
|
25 |
+
if detector_type == 'retinaface':
|
26 |
+
from .face_detection import RetinaFaceDetector
|
27 |
+
return RetinaFaceDetector(conf_name, **kwargs).to(device)
|
28 |
+
else:
|
29 |
+
raise RuntimeError(f'Unknown detector type: {detector_type}')
|
30 |
+
|
31 |
+
|
32 |
+
def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
|
33 |
+
parser_type, conf_name = _split_name(name)
|
34 |
+
if parser_type == 'farl':
|
35 |
+
from .face_parsing import FaRLFaceParser
|
36 |
+
return FaRLFaceParser(conf_name, device=device, **kwargs).to(device)
|
37 |
+
else:
|
38 |
+
raise RuntimeError(f'Unknown parser type: {parser_type}')
|
39 |
+
|
40 |
+
|
41 |
+
def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
|
42 |
+
aligner_type, conf_name = _split_name(name)
|
43 |
+
if aligner_type == 'farl':
|
44 |
+
from .face_alignment import FaRLFaceAlignment
|
45 |
+
return FaRLFaceAlignment(conf_name, device=device, **kwargs).to(device)
|
46 |
+
else:
|
47 |
+
raise RuntimeError(f'Unknown aligner type: {aligner_type}')
|
48 |
+
|
49 |
+
def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
|
50 |
+
attr_type, conf_name = _split_name(name)
|
51 |
+
if attr_type == 'farl':
|
52 |
+
from .face_attribute import FaRLFaceAttribute
|
53 |
+
return FaRLFaceAttribute(conf_name, device=device, **kwargs).to(device)
|
54 |
+
else:
|
55 |
+
raise RuntimeError(f'Unknown attribute type: {attr_type}')
|
facer/draw.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
import torch
|
3 |
+
import colorsys
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
from skimage.draw import line_aa, circle_perimeter_aa
|
7 |
+
import cv2
|
8 |
+
from .util import select_data
|
9 |
+
|
10 |
+
|
11 |
+
def _gen_random_colors(N, bright=True):
|
12 |
+
brightness = 1.0 if bright else 0.7
|
13 |
+
hsv = [(i / N, 1, brightness) for i in range(N)]
|
14 |
+
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
|
15 |
+
random.shuffle(colors)
|
16 |
+
return colors
|
17 |
+
|
18 |
+
|
19 |
+
_static_label_colors = [
|
20 |
+
np.array((1.0, 1.0, 1.0), np.float32),
|
21 |
+
np.array((255, 250, 79), np.float32) / 255.0, # face
|
22 |
+
np.array([255, 125, 138], np.float32) / 255.0, # lb
|
23 |
+
np.array([213, 32, 29], np.float32) / 255.0, # rb
|
24 |
+
np.array([0, 144, 187], np.float32) / 255.0, # le
|
25 |
+
np.array([0, 196, 253], np.float32) / 255.0, # re
|
26 |
+
np.array([255, 129, 54], np.float32) / 255.0, # nose
|
27 |
+
np.array([88, 233, 135], np.float32) / 255.0, # ulip
|
28 |
+
np.array([0, 117, 27], np.float32) / 255.0, # llip
|
29 |
+
np.array([255, 76, 249], np.float32) / 255.0, # imouth
|
30 |
+
np.array((1.0, 0.0, 0.0), np.float32), # hair
|
31 |
+
np.array((255, 250, 100), np.float32) / 255.0, # lr
|
32 |
+
np.array((255, 250, 100), np.float32) / 255.0, # rr
|
33 |
+
np.array((250, 245, 50), np.float32) / 255.0, # neck
|
34 |
+
np.array((0.0, 1.0, 0.5), np.float32), # cloth
|
35 |
+
np.array((1.0, 0.0, 0.5), np.float32),
|
36 |
+
] + _gen_random_colors(256)
|
37 |
+
|
38 |
+
_names_in_static_label_colors = [
|
39 |
+
'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
|
40 |
+
'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
|
41 |
+
'cloth', 'eyeg', 'hat', 'earr'
|
42 |
+
]
|
43 |
+
|
44 |
+
|
45 |
+
def _blend_labels(image, labels, label_names_dict=None,
|
46 |
+
default_alpha=0.6, color_offset=None):
|
47 |
+
assert labels.ndim == 2
|
48 |
+
bg_mask = labels == 0
|
49 |
+
if label_names_dict is None:
|
50 |
+
colors = _static_label_colors
|
51 |
+
else:
|
52 |
+
colors = [np.array((1.0, 1.0, 1.0), np.float32)]
|
53 |
+
for i in range(1, labels.max() + 1):
|
54 |
+
if isinstance(label_names_dict, dict) and i not in label_names_dict:
|
55 |
+
bg_mask = np.logical_or(bg_mask, labels == i)
|
56 |
+
colors.append(np.zeros((3)))
|
57 |
+
continue
|
58 |
+
label_name = label_names_dict[i]
|
59 |
+
if label_name in _names_in_static_label_colors:
|
60 |
+
color = _static_label_colors[
|
61 |
+
_names_in_static_label_colors.index(
|
62 |
+
label_name)]
|
63 |
+
else:
|
64 |
+
color = np.array((1.0, 1.0, 1.0), np.float32)
|
65 |
+
colors.append(color)
|
66 |
+
|
67 |
+
if color_offset is not None:
|
68 |
+
ncolors = []
|
69 |
+
for c in colors:
|
70 |
+
nc = np.array(c)
|
71 |
+
if (nc != np.zeros(3)).any():
|
72 |
+
nc += color_offset
|
73 |
+
ncolors.append(nc)
|
74 |
+
colors = ncolors
|
75 |
+
|
76 |
+
if image is None:
|
77 |
+
image = orig_image = np.zeros(
|
78 |
+
[labels.shape[0], labels.shape[1], 3], np.float32)
|
79 |
+
alpha = 1.0
|
80 |
+
else:
|
81 |
+
orig_image = image / np.max(image)
|
82 |
+
image = orig_image * (1.0 - default_alpha)
|
83 |
+
alpha = default_alpha
|
84 |
+
for i in range(1, np.max(labels) + 1):
|
85 |
+
image += alpha * \
|
86 |
+
np.tile(
|
87 |
+
np.expand_dims(
|
88 |
+
(labels == i).astype(np.float32), -1),
|
89 |
+
[1, 1, 3]) * colors[(i) % len(colors)]
|
90 |
+
image[np.where(image > 1.0)] = 1.0
|
91 |
+
image[np.where(image < 0)] = 0.0
|
92 |
+
image[np.where(bg_mask)] = orig_image[np.where(bg_mask)]
|
93 |
+
return image
|
94 |
+
|
95 |
+
|
96 |
+
def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
|
97 |
+
device = image.device
|
98 |
+
image = np.array(image.cpu().numpy(), copy=True)
|
99 |
+
dtype = image.dtype
|
100 |
+
h, w, _ = image.shape
|
101 |
+
|
102 |
+
draw_score_error = False
|
103 |
+
for tag, batch_content in data.items():
|
104 |
+
if tag == 'rects':
|
105 |
+
for cid, content in enumerate(batch_content):
|
106 |
+
x1, y1, x2, y2 = [int(v) for v in content]
|
107 |
+
y1, y2 = [max(min(v, h-1), 0) for v in [y1, y2]]
|
108 |
+
x1, x2 = [max(min(v, w-1), 0) for v in [x1, x2]]
|
109 |
+
for xx1, yy1, xx2, yy2 in [
|
110 |
+
[x1, y1, x2, y1],
|
111 |
+
[x1, y2, x2, y2],
|
112 |
+
[x1, y1, x1, y2],
|
113 |
+
[x2, y1, x2, y2]
|
114 |
+
]:
|
115 |
+
rr, cc, val = line_aa(yy1, xx1, yy2, xx2)
|
116 |
+
val = val[:, None][:, [0, 0, 0]]
|
117 |
+
image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
|
118 |
+
|
119 |
+
if 'scores' in data:
|
120 |
+
try:
|
121 |
+
import cv2
|
122 |
+
score = data['scores'][cid].item()
|
123 |
+
score_str = f'{score:0.3f}'
|
124 |
+
image_c = np.array(image).copy()
|
125 |
+
cv2.putText(image_c, score_str, org=(int(x1), int(y2)),
|
126 |
+
fontFace=cv2.FONT_HERSHEY_TRIPLEX,
|
127 |
+
fontScale=0.6, color=(255, 255, 255), thickness=1)
|
128 |
+
image[:, :, :] = image_c
|
129 |
+
except Exception as e:
|
130 |
+
if not draw_score_error:
|
131 |
+
print(f'Failed to draw scores on image.')
|
132 |
+
print(e)
|
133 |
+
draw_score_error = True
|
134 |
+
|
135 |
+
if tag == 'points':
|
136 |
+
for content in batch_content:
|
137 |
+
# content: npoints x 2
|
138 |
+
for x, y in content:
|
139 |
+
x = max(min(int(x), w-1), 0)
|
140 |
+
y = max(min(int(y), h-1), 0)
|
141 |
+
rr, cc, val = circle_perimeter_aa(y, x, 1)
|
142 |
+
valid = np.all([rr >= 0, rr < h, cc >= 0, cc < w], axis=0)
|
143 |
+
rr = rr[valid]
|
144 |
+
cc = cc[valid]
|
145 |
+
val = val[valid]
|
146 |
+
val = val[:, None][:, [0, 0, 0]]
|
147 |
+
image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
|
148 |
+
|
149 |
+
if tag == 'seg':
|
150 |
+
label_names = batch_content['label_names']
|
151 |
+
for seg_logits in batch_content['logits']:
|
152 |
+
# content: nclasses x h x w
|
153 |
+
seg_probs = seg_logits.softmax(dim=0)
|
154 |
+
seg_labels = seg_probs.argmax(dim=0).cpu().numpy()
|
155 |
+
image = (_blend_labels(image.astype(np.float32) /
|
156 |
+
255, seg_labels,
|
157 |
+
label_names_dict=label_names) * 255).astype(dtype)
|
158 |
+
|
159 |
+
return torch.from_numpy(image).to(device=device)
|
160 |
+
|
161 |
+
|
162 |
+
def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
|
163 |
+
images2 = []
|
164 |
+
for image_id, image_chw in enumerate(images):
|
165 |
+
selected_data = select_data(image_id == data['image_ids'], data)
|
166 |
+
images2.append(
|
167 |
+
_draw_hwc(image_chw.permute(1, 2, 0), selected_data).permute(2, 0, 1))
|
168 |
+
return torch.stack(images2, dim=0)
|
169 |
+
|
170 |
+
def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
|
171 |
+
"""
|
172 |
+
Input:
|
173 |
+
- img: gray or RGB
|
174 |
+
- bbox: type of BBox
|
175 |
+
- landmark: reproject landmark of (5L, 2L)
|
176 |
+
Output:
|
177 |
+
- img marked with landmark and bbox
|
178 |
+
"""
|
179 |
+
img = cv2.UMat(img).get()
|
180 |
+
if bbox is not None:
|
181 |
+
x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32)
|
182 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
|
183 |
+
if landmark is not None:
|
184 |
+
for x, y in np.array(landmark).astype(np.int32):
|
185 |
+
cv2.circle(img, (int(x), int(y)), 2, color, -1)
|
186 |
+
return img
|
facer/face_alignment/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceAlignment
|
2 |
+
from .farl import FaRLFaceAlignment
|
facer/face_alignment/base.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class FaceAlignment(nn.Module):
|
5 |
+
""" face alignment
|
6 |
+
|
7 |
+
Args:
|
8 |
+
images (torch.Tensor): b x c x h x w
|
9 |
+
|
10 |
+
data (Dict[str, Any]):
|
11 |
+
|
12 |
+
* image_ids (torch.Tensor): nfaces
|
13 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
14 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
data (Dict[str, Any]):
|
18 |
+
|
19 |
+
* image_ids (torch.Tensor): nfaces
|
20 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
21 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
22 |
+
* alignment
|
23 |
+
"""
|
24 |
+
pass
|
facer/face_alignment/farl.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, Any
|
2 |
+
import functools
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from .network import FaRLVisualFeatures, MMSEG_UPerHead, FaceAlignmentTransformer, denormalize_points, heatmap2points
|
6 |
+
from ..transform import (get_face_align_matrix,
|
7 |
+
make_inverted_tanh_warp_grid, make_tanh_warp_grid)
|
8 |
+
from .base import FaceAlignment
|
9 |
+
from ..util import download_jit
|
10 |
+
import io
|
11 |
+
|
12 |
+
pretrain_settings = {
|
13 |
+
'ibug300w/448': {
|
14 |
+
# inter_ocular 0.028835 epoch 60
|
15 |
+
'num_classes': 68,
|
16 |
+
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.ibug300w.main_ema_jit.pt",
|
17 |
+
'matrix_src_tag': 'points',
|
18 |
+
'get_matrix_fn': functools.partial(get_face_align_matrix,
|
19 |
+
target_shape=(448, 448), target_face_scale=0.8),
|
20 |
+
'get_grid_fn': functools.partial(make_tanh_warp_grid,
|
21 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
22 |
+
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
|
23 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
24 |
+
|
25 |
+
},
|
26 |
+
'aflw19/448': {
|
27 |
+
# diag 0.009329 epoch 15
|
28 |
+
'num_classes': 19,
|
29 |
+
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.aflw19.main_ema_jit.pt",
|
30 |
+
'matrix_src_tag': 'points',
|
31 |
+
'get_matrix_fn': functools.partial(get_face_align_matrix,
|
32 |
+
target_shape=(448, 448), target_face_scale=0.8),
|
33 |
+
'get_grid_fn': functools.partial(make_tanh_warp_grid,
|
34 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
35 |
+
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
|
36 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
37 |
+
},
|
38 |
+
'wflw/448': {
|
39 |
+
# inter_ocular 0.038933 epoch 20
|
40 |
+
'num_classes': 98,
|
41 |
+
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.wflw.main_ema_jit.pt",
|
42 |
+
'matrix_src_tag': 'points',
|
43 |
+
'get_matrix_fn': functools.partial(get_face_align_matrix,
|
44 |
+
target_shape=(448, 448), target_face_scale=0.8),
|
45 |
+
'get_grid_fn': functools.partial(make_tanh_warp_grid,
|
46 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
47 |
+
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
|
48 |
+
warp_factor=0.0, warped_shape=(448, 448)),
|
49 |
+
},
|
50 |
+
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
def load_face_alignment_model(model_path: str, num_classes=68):
|
55 |
+
backbone = FaRLVisualFeatures("base", None, forced_input_resolution=448, output_indices=None).cpu()
|
56 |
+
if "jit" in model_path:
|
57 |
+
extra_files = {"backbone": None}
|
58 |
+
heatmap_head = download_jit(model_path, map_location="cpu", _extra_files=extra_files)
|
59 |
+
backbone_weight_io = io.BytesIO(extra_files["backbone"])
|
60 |
+
backbone.load_state_dict(torch.load(backbone_weight_io))
|
61 |
+
# print("load from jit")
|
62 |
+
else:
|
63 |
+
channels = backbone.get_output_channel("base")
|
64 |
+
in_channels = [channels] * 4
|
65 |
+
num_classes = num_classes
|
66 |
+
heatmap_head = MMSEG_UPerHead(in_channels=in_channels, channels=channels, num_classes=num_classes) # this requires mmseg as a dependency
|
67 |
+
state = torch.load(model_path,map_location="cpu")["networks"]["main_ema"]
|
68 |
+
# print("load from checkpoint")
|
69 |
+
|
70 |
+
main_network = FaceAlignmentTransformer(backbone, heatmap_head, heatmap_act="sigmoid").cpu()
|
71 |
+
|
72 |
+
if "jit" not in model_path:
|
73 |
+
main_network.load_state_dict(state, strict=True)
|
74 |
+
|
75 |
+
return main_network
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
class FaRLFaceAlignment(FaceAlignment):
|
80 |
+
""" The face alignment models from [FaRL](https://github.com/FacePerceiver/FaRL).
|
81 |
+
|
82 |
+
Please consider citing
|
83 |
+
```bibtex
|
84 |
+
@article{zheng2021farl,
|
85 |
+
title={General Facial Representation Learning in a Visual-Linguistic Manner},
|
86 |
+
author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
|
87 |
+
Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
|
88 |
+
Dong and Zeng, Ming and Wen, Fang},
|
89 |
+
journal={arXiv preprint arXiv:2112.03109},
|
90 |
+
year={2021}
|
91 |
+
}
|
92 |
+
```
|
93 |
+
"""
|
94 |
+
|
95 |
+
def __init__(self, conf_name: Optional[str] = None,
|
96 |
+
model_path: Optional[str] = None, device=None) -> None:
|
97 |
+
super().__init__()
|
98 |
+
if conf_name is None:
|
99 |
+
conf_name = 'ibug300w/448'
|
100 |
+
if model_path is None:
|
101 |
+
model_path = pretrain_settings[conf_name]['url']
|
102 |
+
self.conf_name = conf_name
|
103 |
+
|
104 |
+
setting = pretrain_settings[self.conf_name]
|
105 |
+
self.net = load_face_alignment_model(model_path, num_classes = setting["num_classes"])
|
106 |
+
if device is not None:
|
107 |
+
self.net = self.net.to(device)
|
108 |
+
|
109 |
+
self.heatmap_interpolate_mode = 'bilinear'
|
110 |
+
self.eval()
|
111 |
+
|
112 |
+
def forward(self, images: torch.Tensor, data: Dict[str, Any]):
|
113 |
+
setting = pretrain_settings[self.conf_name]
|
114 |
+
images = images.float() / 255.0 # backbone 自带 normalize
|
115 |
+
_, _, h, w = images.shape
|
116 |
+
|
117 |
+
simages = images[data['image_ids']]
|
118 |
+
matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
|
119 |
+
grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
|
120 |
+
inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
|
121 |
+
|
122 |
+
w_images = F.grid_sample(
|
123 |
+
simages, grid, mode='bilinear', align_corners=False)
|
124 |
+
|
125 |
+
_, _, warp_h, warp_w = w_images.shape
|
126 |
+
|
127 |
+
heatmap_acted = self.net(w_images)
|
128 |
+
|
129 |
+
warpped_heatmap = F.interpolate(
|
130 |
+
heatmap_acted, size=(warp_h, warp_w),
|
131 |
+
mode=self.heatmap_interpolate_mode, align_corners=False)
|
132 |
+
|
133 |
+
pred_heatmap = F.grid_sample(
|
134 |
+
warpped_heatmap, inv_grid, mode='bilinear', align_corners=False)
|
135 |
+
|
136 |
+
landmark = heatmap2points(pred_heatmap)
|
137 |
+
|
138 |
+
landmark = denormalize_points(landmark, h, w)
|
139 |
+
|
140 |
+
data['alignment'] = landmark
|
141 |
+
|
142 |
+
return data
|
143 |
+
|
144 |
+
|
145 |
+
if __name__=="__main__":
|
146 |
+
image = torch.randn(1, 3, 448, 448)
|
147 |
+
|
148 |
+
aligner1 = FaRLFaceAlignment("wflw/448")
|
149 |
+
|
150 |
+
x1 = aligner1.net(image)
|
151 |
+
|
152 |
+
import argparse
|
153 |
+
|
154 |
+
parser = argparse.ArgumentParser()
|
155 |
+
parser.add_argument("--jit_path", type=str, default=None)
|
156 |
+
args = parser.parse_args()
|
157 |
+
|
158 |
+
if args.jit_path is None:
|
159 |
+
exit(0)
|
160 |
+
|
161 |
+
net = aligner1.net.cpu()
|
162 |
+
|
163 |
+
features, _ = net.backbone(image)
|
164 |
+
|
165 |
+
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
|
166 |
+
traced_script_module = torch.jit.trace(net.heatmap_head, example_inputs=[features])
|
167 |
+
|
168 |
+
buffer = io.BytesIO()
|
169 |
+
|
170 |
+
torch.save(net.backbone.state_dict(), buffer)
|
171 |
+
|
172 |
+
# Save to file
|
173 |
+
torch.jit.save(traced_script_module, args.jit_path,
|
174 |
+
_extra_files={"backbone": buffer.getvalue()})
|
175 |
+
|
176 |
+
aligner2 = FaRLFaceAlignment(model_path=args.jit_path)
|
177 |
+
|
178 |
+
# compare the output
|
179 |
+
x2 = aligner2.net(image)
|
180 |
+
print(torch.allclose(x1, x2))
|
facer/face_alignment/network/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from .common import (load_checkpoint, Activation, MLP, Residual)
|
5 |
+
from .geometry import (normalize_points, denormalize_points,
|
6 |
+
heatmap2points)
|
7 |
+
from .mmseg import MMSEG_UPerHead
|
8 |
+
from .transformers import FaRLVisualFeatures
|
9 |
+
from torch import nn
|
10 |
+
from typing import Optional, List, Tuple
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class FaceAlignmentTransformer(nn.Module):
|
15 |
+
"""Face alignment transformer.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
landmark (torch.Tensor): Float32 tensor with shape [b, npoints, 2], coordinates normalized to [0, 1].
|
22 |
+
aux_outputs:
|
23 |
+
heatmap (torch.Tensor): Float32 tensor with shape [b, npoints, S, S]
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, backbone: nn.Module, heatmap_head: nn.Module,
|
27 |
+
heatmap_act: Optional[str] = 'relu'):
|
28 |
+
super().__init__()
|
29 |
+
self.backbone = backbone
|
30 |
+
self.heatmap_head = heatmap_head
|
31 |
+
self.heatmap_act = Activation(heatmap_act)
|
32 |
+
self.float()
|
33 |
+
|
34 |
+
def forward(self, image):
|
35 |
+
features, _ = self.backbone(image)
|
36 |
+
heatmap = self.heatmap_head(features) # b x npoints x s x s
|
37 |
+
heatmap_acted = self.heatmap_act(heatmap)
|
38 |
+
# landmark = heatmap2points(heatmap_acted) # b x npoints x 2
|
39 |
+
# return landmark, {'heatmap': heatmap, 'heatmap_acted': heatmap_acted}
|
40 |
+
return heatmap_acted
|
41 |
+
|
42 |
+
|
facer/face_alignment/network/common.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from typing import List, Optional, Tuple, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
def load_checkpoint(net: nn.Module, checkpoint_path: str, network_name: str):
|
12 |
+
states = torch.load(open(checkpoint_path, 'rb'), map_location={
|
13 |
+
'cuda:0': f'cuda:{torch.cuda.current_device()}'})
|
14 |
+
network_states = states['networks']
|
15 |
+
net.load_state_dict(network_states[network_name])
|
16 |
+
return net
|
17 |
+
|
18 |
+
|
19 |
+
class Activation(nn.Module):
|
20 |
+
def __init__(self, name: Optional[str], **kwargs):
|
21 |
+
super().__init__()
|
22 |
+
if name == 'relu':
|
23 |
+
self.fn = F.relu
|
24 |
+
elif name == 'softplus':
|
25 |
+
self.fn = F.softplus
|
26 |
+
elif name == 'gelu':
|
27 |
+
self.fn = F.gelu
|
28 |
+
elif name == 'sigmoid':
|
29 |
+
self.fn = torch.sigmoid
|
30 |
+
elif name == 'sigmoid_x':
|
31 |
+
self.epsilon = kwargs.get('epsilon', 1e-3)
|
32 |
+
self.fn = lambda x: torch.clamp(
|
33 |
+
x.sigmoid() * (1.0 + self.epsilon*2.0) - self.epsilon,
|
34 |
+
min=0.0, max=1.0)
|
35 |
+
elif name == None:
|
36 |
+
self.fn = lambda x: x
|
37 |
+
else:
|
38 |
+
raise RuntimeError(f'Unknown activation name: {name}')
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return self.fn(x)
|
42 |
+
|
43 |
+
|
44 |
+
class MLP(nn.Module):
|
45 |
+
def __init__(self, channels: List[int], act: Optional[str]):
|
46 |
+
super().__init__()
|
47 |
+
assert len(channels) > 1
|
48 |
+
layers = []
|
49 |
+
for i in range(len(channels)-1):
|
50 |
+
layers.append(nn.Linear(channels[i], channels[i+1]))
|
51 |
+
if i+1 < len(channels):
|
52 |
+
layers.append(Activation(act))
|
53 |
+
self.layers = nn.Sequential(*layers)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return self.layers(x)
|
57 |
+
|
58 |
+
|
59 |
+
class Residual(nn.Module):
|
60 |
+
def __init__(self, net: nn.Module, res_weight_init: Optional[float] = 0.0):
|
61 |
+
super().__init__()
|
62 |
+
self.net = net
|
63 |
+
if res_weight_init is not None:
|
64 |
+
self.res_weight = nn.Parameter(torch.tensor(res_weight_init))
|
65 |
+
else:
|
66 |
+
self.res_weight = None
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
if self.res_weight is not None:
|
70 |
+
return self.res_weight * self.net(x) + x
|
71 |
+
else:
|
72 |
+
return self.net(x) + x
|
73 |
+
|
74 |
+
|
75 |
+
class SE(nn.Module):
|
76 |
+
def __init__(self, channel: int, r: int = 1):
|
77 |
+
super().__init__()
|
78 |
+
self.branch = nn.Sequential(
|
79 |
+
nn.Conv2d(channel, channel//r, (1, 1)),
|
80 |
+
nn.ReLU(),
|
81 |
+
nn.Conv2d(channel//r, channel, (1, 1)),
|
82 |
+
nn.Sigmoid()
|
83 |
+
)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
# x: b x channel x h x w
|
87 |
+
v = x.mean([2, 3], keepdim=True) # b x channel x 1 x 1
|
88 |
+
v = self.branch(v) # b x channel x 1 x 1
|
89 |
+
return x * v
|
90 |
+
|
91 |
+
|
facer/face_alignment/network/geometry.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from typing import Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def normalize_points(points: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
10 |
+
""" Normalize coordinates to [0, 1].
|
11 |
+
"""
|
12 |
+
return (points + 0.5) / torch.tensor([[[w, h]]]).to(points)
|
13 |
+
|
14 |
+
|
15 |
+
def denormalize_points(normalized_points: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
16 |
+
""" Reverse normalize_points.
|
17 |
+
"""
|
18 |
+
return normalized_points * torch.tensor([[[w, h]]]).to(normalized_points) - 0.5
|
19 |
+
|
20 |
+
|
21 |
+
def heatmap2points(heatmap, t_scale: Union[None, float, torch.Tensor] = None):
|
22 |
+
""" Heatmaps -> normalized points [b x npoints x 2(XY)].
|
23 |
+
"""
|
24 |
+
dtype = heatmap.dtype
|
25 |
+
_, _, h, w = heatmap.shape
|
26 |
+
|
27 |
+
# 0 ~ h-1, 0 ~ w-1
|
28 |
+
yy, xx = torch.meshgrid(
|
29 |
+
torch.arange(h).float(),
|
30 |
+
torch.arange(w).float())
|
31 |
+
|
32 |
+
yy = yy.view(1, 1, h, w).to(heatmap)
|
33 |
+
xx = xx.view(1, 1, h, w).to(heatmap)
|
34 |
+
|
35 |
+
if t_scale is not None:
|
36 |
+
heatmap = (heatmap * t_scale).exp()
|
37 |
+
heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
|
38 |
+
|
39 |
+
yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum # b x npoints
|
40 |
+
xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum # b x npoints
|
41 |
+
|
42 |
+
points = torch.stack([xx_coord, yy_coord], dim=-1) # b x npoints x 2
|
43 |
+
|
44 |
+
normalized_points = normalize_points(points, h, w)
|
45 |
+
return normalized_points
|
facer/face_alignment/network/mmseg.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class MMSEG_UPerHead(nn.Module):
|
8 |
+
"""Wraps the UPerHead from mmseg for segmentation.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, num_classes: int,
|
12 |
+
in_channels: list = [384, 384, 384, 384], channels: int = 512):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
from mmseg.models.decode_heads import UPerHead
|
16 |
+
self.head = UPerHead(
|
17 |
+
in_channels=in_channels,
|
18 |
+
in_index=[0, 1, 2, 3],
|
19 |
+
pool_scales=(1, 2, 3, 6),
|
20 |
+
channels=channels,
|
21 |
+
dropout_ratio=0.1,
|
22 |
+
num_classes=num_classes,
|
23 |
+
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
24 |
+
align_corners=False,
|
25 |
+
loss_decode=dict(
|
26 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
|
27 |
+
|
28 |
+
def forward(self, inputs):
|
29 |
+
return self.head(inputs)
|
facer/face_alignment/network/transformers.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
from typing import Optional, List, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
from ... import farl
|
14 |
+
|
15 |
+
|
16 |
+
def _make_fpns(vision_patch_size: int, output_channels: int):
|
17 |
+
if vision_patch_size in {16, 14}:
|
18 |
+
fpn1 = nn.Sequential(
|
19 |
+
nn.ConvTranspose2d(output_channels, output_channels,
|
20 |
+
kernel_size=2, stride=2),
|
21 |
+
nn.SyncBatchNorm(output_channels),
|
22 |
+
nn.GELU(),
|
23 |
+
nn.ConvTranspose2d(output_channels, output_channels, kernel_size=2, stride=2))
|
24 |
+
|
25 |
+
fpn2 = nn.ConvTranspose2d(
|
26 |
+
output_channels, output_channels, kernel_size=2, stride=2)
|
27 |
+
fpn3 = nn.Identity()
|
28 |
+
fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
29 |
+
return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
|
30 |
+
elif vision_patch_size == 8:
|
31 |
+
fpn1 = nn.Sequential(nn.ConvTranspose2d(
|
32 |
+
output_channels, output_channels, kernel_size=2, stride=2))
|
33 |
+
fpn2 = nn.Identity()
|
34 |
+
fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
|
35 |
+
fpn4 = nn.MaxPool2d(kernel_size=4, stride=4)
|
36 |
+
return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
|
37 |
+
else:
|
38 |
+
raise NotImplementedError()
|
39 |
+
|
40 |
+
|
41 |
+
def _resize_pe(pe: torch.Tensor, new_size: int, mode: str = 'bicubic', num_tokens: int = 1) -> torch.Tensor:
|
42 |
+
"""Resize positional embeddings.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
pe (torch.Tensor): A tensor with shape (num_tokens + old_size ** 2, width). pe[0, :] is the CLS token.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
torch.Tensor: A tensor with shape (num_tokens + new_size **2, width).
|
49 |
+
"""
|
50 |
+
l, w = pe.shape
|
51 |
+
old_size = int(math.sqrt(l-num_tokens))
|
52 |
+
assert old_size ** 2 + num_tokens == l
|
53 |
+
return torch.cat([
|
54 |
+
pe[:num_tokens, :],
|
55 |
+
F.interpolate(pe[num_tokens:, :].reshape(1, old_size, old_size, w).permute(0, 3, 1, 2),
|
56 |
+
(new_size, new_size), mode=mode, align_corners=False).view(w, -1).t()], dim=0)
|
57 |
+
|
58 |
+
|
59 |
+
class FaRLVisualFeatures(nn.Module):
|
60 |
+
"""Extract features from FaRL visual encoder.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
image (torch.Tensor): Float32 tensor with shape [b, 3, h, w],
|
64 |
+
normalized to [0, 1].
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
List[torch.Tensor]: A list of features.
|
68 |
+
"""
|
69 |
+
image_mean: torch.Tensor
|
70 |
+
image_std: torch.Tensor
|
71 |
+
output_channels: int
|
72 |
+
num_outputs: int
|
73 |
+
|
74 |
+
def __init__(self, model_type: str,
|
75 |
+
model_path: Optional[str] = None, output_indices: Optional[List[int]] = None,
|
76 |
+
forced_input_resolution: Optional[int] = None,
|
77 |
+
apply_fpn: bool = True):
|
78 |
+
super().__init__()
|
79 |
+
self.visual = farl.load_farl(model_type, model_path)
|
80 |
+
|
81 |
+
vision_patch_size = self.visual.conv1.weight.shape[-1]
|
82 |
+
|
83 |
+
self.input_resolution = self.visual.input_resolution
|
84 |
+
if forced_input_resolution is not None and \
|
85 |
+
self.input_resolution != forced_input_resolution:
|
86 |
+
# resizing the positonal embeddings
|
87 |
+
self.visual.positional_embedding = nn.Parameter(
|
88 |
+
_resize_pe(self.visual.positional_embedding,
|
89 |
+
forced_input_resolution//vision_patch_size))
|
90 |
+
self.input_resolution = forced_input_resolution
|
91 |
+
|
92 |
+
self.output_channels = self.visual.transformer.width
|
93 |
+
|
94 |
+
if output_indices is None:
|
95 |
+
output_indices = self.__class__.get_default_output_indices(
|
96 |
+
model_type)
|
97 |
+
self.output_indices = output_indices
|
98 |
+
self.num_outputs = len(output_indices)
|
99 |
+
|
100 |
+
self.register_buffer('image_mean', torch.tensor(
|
101 |
+
[0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
|
102 |
+
self.register_buffer('image_std', torch.tensor(
|
103 |
+
[0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
|
104 |
+
|
105 |
+
if apply_fpn:
|
106 |
+
self.fpns = _make_fpns(vision_patch_size, self.output_channels)
|
107 |
+
else:
|
108 |
+
self.fpns = None
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def get_output_channel(model_type):
|
112 |
+
if model_type == 'base':
|
113 |
+
return 768
|
114 |
+
if model_type == 'large':
|
115 |
+
return 1024
|
116 |
+
if model_type == 'huge':
|
117 |
+
return 1280
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def get_default_output_indices(model_type):
|
121 |
+
if model_type == 'base':
|
122 |
+
return [3, 5, 7, 11]
|
123 |
+
if model_type == 'large':
|
124 |
+
return [7, 11, 15, 23]
|
125 |
+
if model_type == 'huge':
|
126 |
+
return [8, 14, 20, 31]
|
127 |
+
|
128 |
+
def forward(self, image: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
129 |
+
# b x 3 x res x res
|
130 |
+
_, _, input_h, input_w = image.shape
|
131 |
+
if input_h != self.input_resolution or input_w != self.input_resolution:
|
132 |
+
image = F.interpolate(image, self.input_resolution,
|
133 |
+
mode='bilinear', align_corners=False)
|
134 |
+
|
135 |
+
image = (image - self.image_mean.to(image.device)) / self.image_std.to(image.device)
|
136 |
+
|
137 |
+
x = image.to(self.visual.conv1.weight.data)
|
138 |
+
|
139 |
+
x = self.visual.conv1(x) # shape = [*, width, grid, grid]
|
140 |
+
N, _, S, S = x.shape
|
141 |
+
|
142 |
+
# shape = [*, width, grid ** 2]
|
143 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
144 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
145 |
+
x = torch.cat([self.visual.class_embedding.to(x.dtype) +
|
146 |
+
torch.zeros(x.shape[0], 1, x.shape[-1],
|
147 |
+
dtype=x.dtype, device=x.device),
|
148 |
+
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
149 |
+
|
150 |
+
x = x + self.visual.positional_embedding.to(x.dtype)
|
151 |
+
|
152 |
+
x = self.visual.ln_pre(x)
|
153 |
+
|
154 |
+
x = x.permute(1, 0, 2).contiguous() # NLD -> LND
|
155 |
+
|
156 |
+
features = []
|
157 |
+
cls_tokens = []
|
158 |
+
for blk in self.visual.transformer.resblocks:
|
159 |
+
x = blk(x) # [S ** 2 + 1, N, D]
|
160 |
+
# if idx in self.output_indices:
|
161 |
+
feature = x[1:, :, :].permute(
|
162 |
+
1, 2, 0).view(N, -1, S, S).contiguous().float()
|
163 |
+
features.append(feature)
|
164 |
+
cls_tokens.append(x[0, :, :])
|
165 |
+
|
166 |
+
features = [features[ind] for ind in self.output_indices]
|
167 |
+
cls_tokens = [cls_tokens[ind] for ind in self.output_indices]
|
168 |
+
|
169 |
+
if self.fpns is not None:
|
170 |
+
for i, fpn in enumerate(self.fpns):
|
171 |
+
features[i] = fpn(features[i])
|
172 |
+
|
173 |
+
return features, cls_tokens
|
facer/face_attribute/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceAttribute
|
2 |
+
from .farl import FaRLFaceAttribute
|
facer/face_attribute/base.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
class FaceAttribute(nn.Module):
|
4 |
+
""" face attribute base class
|
5 |
+
|
6 |
+
Args:
|
7 |
+
images (torch.Tensor): b x c x h x w
|
8 |
+
|
9 |
+
data (Dict[str, Any]):
|
10 |
+
|
11 |
+
* image_ids (torch.Tensor): nfaces
|
12 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
13 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
data (Dict[str, Any]):
|
17 |
+
|
18 |
+
* image_ids (torch.Tensor): nfaces
|
19 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
20 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
21 |
+
* attrs (Dict[str, Any]):
|
22 |
+
* logits (torch.Tensor): nfaces x nclasses
|
23 |
+
"""
|
24 |
+
pass
|
facer/face_attribute/farl.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, Any
|
2 |
+
import functools
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from ..transform import get_face_align_matrix, make_tanh_warp_grid
|
6 |
+
from .base import FaceAttribute
|
7 |
+
from ..farl import farl_classification
|
8 |
+
from ..util import download_jit
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
def get_std_points_xray(out_size=256, mid_size=500):
|
13 |
+
std_points_256 = np.array(
|
14 |
+
[
|
15 |
+
[85.82991, 85.7792],
|
16 |
+
[169.0532, 84.3381],
|
17 |
+
[127.574, 137.0006],
|
18 |
+
[90.6964, 174.7014],
|
19 |
+
[167.3069, 173.3733],
|
20 |
+
]
|
21 |
+
)
|
22 |
+
std_points_256[:, 1] += 30
|
23 |
+
old_size = 256
|
24 |
+
mid = mid_size / 2
|
25 |
+
new_std_points = std_points_256 - old_size / 2 + mid
|
26 |
+
target_pts = new_std_points * out_size / mid_size
|
27 |
+
target_pts = torch.from_numpy(target_pts).float()
|
28 |
+
return target_pts
|
29 |
+
|
30 |
+
|
31 |
+
pretrain_settings = {
|
32 |
+
"celeba/224": {
|
33 |
+
# acc 92.06617474555969
|
34 |
+
"num_classes": 40,
|
35 |
+
"layers": [11],
|
36 |
+
"url": "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_attribute.farl.celeba.pt",
|
37 |
+
"matrix_src_tag": "points",
|
38 |
+
"get_matrix_fn": functools.partial(
|
39 |
+
get_face_align_matrix,
|
40 |
+
target_shape=(224, 224),
|
41 |
+
target_pts=get_std_points_xray(out_size=224, mid_size=500),
|
42 |
+
),
|
43 |
+
"get_grid_fn": functools.partial(
|
44 |
+
make_tanh_warp_grid, warp_factor=0.0, warped_shape=(224, 224)
|
45 |
+
),
|
46 |
+
"classes": [
|
47 |
+
"5_o_Clock_Shadow",
|
48 |
+
"Arched_Eyebrows",
|
49 |
+
"Attractive",
|
50 |
+
"Bags_Under_Eyes",
|
51 |
+
"Bald",
|
52 |
+
"Bangs",
|
53 |
+
"Big_Lips",
|
54 |
+
"Big_Nose",
|
55 |
+
"Black_Hair",
|
56 |
+
"Blond_Hair",
|
57 |
+
"Blurry",
|
58 |
+
"Brown_Hair",
|
59 |
+
"Bushy_Eyebrows",
|
60 |
+
"Chubby",
|
61 |
+
"Double_Chin",
|
62 |
+
"Eyeglasses",
|
63 |
+
"Goatee",
|
64 |
+
"Gray_Hair",
|
65 |
+
"Heavy_Makeup",
|
66 |
+
"High_Cheekbones",
|
67 |
+
"Male",
|
68 |
+
"Mouth_Slightly_Open",
|
69 |
+
"Mustache",
|
70 |
+
"Narrow_Eyes",
|
71 |
+
"No_Beard",
|
72 |
+
"Oval_Face",
|
73 |
+
"Pale_Skin",
|
74 |
+
"Pointy_Nose",
|
75 |
+
"Receding_Hairline",
|
76 |
+
"Rosy_Cheeks",
|
77 |
+
"Sideburns",
|
78 |
+
"Smiling",
|
79 |
+
"Straight_Hair",
|
80 |
+
"Wavy_Hair",
|
81 |
+
"Wearing_Earrings",
|
82 |
+
"Wearing_Hat",
|
83 |
+
"Wearing_Lipstick",
|
84 |
+
"Wearing_Necklace",
|
85 |
+
"Wearing_Necktie",
|
86 |
+
"Young",
|
87 |
+
],
|
88 |
+
}
|
89 |
+
}
|
90 |
+
|
91 |
+
|
92 |
+
def load_face_attr(model_path, num_classes=40, layers=[11]):
|
93 |
+
model = farl_classification(num_classes=num_classes, layers=layers)
|
94 |
+
state_dict = download_jit(model_path, jit=False)
|
95 |
+
model.load_state_dict(state_dict)
|
96 |
+
return model
|
97 |
+
|
98 |
+
|
99 |
+
class FaRLFaceAttribute(FaceAttribute):
|
100 |
+
"""The face attribute recognition models from [FaRL](https://github.com/FacePerceiver/FaRL).
|
101 |
+
|
102 |
+
Please consider citing
|
103 |
+
```bibtex
|
104 |
+
@article{zheng2021farl,
|
105 |
+
title={General Facial Representation Learning in a Visual-Linguistic Manner},
|
106 |
+
author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
|
107 |
+
Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
|
108 |
+
Dong and Zeng, Ming and Wen, Fang},
|
109 |
+
journal={arXiv preprint arXiv:2112.03109},
|
110 |
+
year={2021}
|
111 |
+
}
|
112 |
+
```
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
conf_name: Optional[str] = None,
|
118 |
+
model_path: Optional[str] = None,
|
119 |
+
device=None,
|
120 |
+
) -> None:
|
121 |
+
super().__init__()
|
122 |
+
if conf_name is None:
|
123 |
+
conf_name = "celeba/224"
|
124 |
+
if model_path is None:
|
125 |
+
model_path = pretrain_settings[conf_name]["url"]
|
126 |
+
self.conf_name = conf_name
|
127 |
+
|
128 |
+
setting = pretrain_settings[self.conf_name]
|
129 |
+
self.labels = setting["classes"]
|
130 |
+
self.net = load_face_attr(model_path, num_classes=setting["num_classes"], layers = setting["layers"])
|
131 |
+
if device is not None:
|
132 |
+
self.net = self.net.to(device)
|
133 |
+
|
134 |
+
self.eval()
|
135 |
+
|
136 |
+
def forward(self, images: torch.Tensor, data: Dict[str, Any]):
|
137 |
+
setting = pretrain_settings[self.conf_name]
|
138 |
+
images = images.float() / 255.0 # backbone 自带 normalize
|
139 |
+
_, _, h, w = images.shape
|
140 |
+
|
141 |
+
simages = images[data["image_ids"]]
|
142 |
+
matrix = setting["get_matrix_fn"](data[setting["matrix_src_tag"]])
|
143 |
+
grid = setting["get_grid_fn"](matrix=matrix, orig_shape=(h, w))
|
144 |
+
|
145 |
+
w_images = F.grid_sample(simages, grid, mode="bilinear", align_corners=False)
|
146 |
+
|
147 |
+
outputs = self.net(w_images)
|
148 |
+
probs = torch.sigmoid(outputs)
|
149 |
+
|
150 |
+
data["attrs"] = probs
|
151 |
+
|
152 |
+
return data
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
model = FaRLFaceAttribute()
|
facer/face_detection/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceDetector
|
2 |
+
from .retinaface import RetinaFaceDetector
|
facer/face_detection/base.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class FaceDetector(nn.Module):
|
6 |
+
""" face detector
|
7 |
+
|
8 |
+
Args:
|
9 |
+
images (torch.Tensor): b x c x h x w
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
data (Dict[str, torch.Tensor]):
|
13 |
+
|
14 |
+
* rects: nfaces x 4 (x1, y1, x2, y2)
|
15 |
+
* points: nfaces x 5 x 2 (x, y)
|
16 |
+
* scores: nfaces
|
17 |
+
* image_ids: nfaces
|
18 |
+
"""
|
19 |
+
pass
|
facer/face_detection/retinaface.py
ADDED
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# largely borrowed from https://github.dev/elliottzheng/batch-face/face_detection/alignment.py
|
2 |
+
|
3 |
+
from typing import Dict, List, Optional, Tuple
|
4 |
+
import torch
|
5 |
+
import torch.backends.cudnn as cudnn
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision.models._utils as _utils
|
9 |
+
from .base import FaceDetector
|
10 |
+
|
11 |
+
|
12 |
+
from itertools import product as product
|
13 |
+
from math import ceil
|
14 |
+
|
15 |
+
|
16 |
+
pretrained_urls = {
|
17 |
+
"mobilenet": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/mobilenet0.25_Final.pth",
|
18 |
+
"resnet50": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/Resnet50_Final.pth"
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def conv_bn(inp, oup, stride=1, leaky=0):
|
23 |
+
return nn.Sequential(
|
24 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
25 |
+
nn.BatchNorm2d(oup),
|
26 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def conv_bn_no_relu(inp, oup, stride):
|
31 |
+
return nn.Sequential(
|
32 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
33 |
+
nn.BatchNorm2d(oup),
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def conv_bn1X1(inp, oup, stride, leaky=0):
|
38 |
+
return nn.Sequential(
|
39 |
+
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
|
40 |
+
nn.BatchNorm2d(oup),
|
41 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv_dw(inp, oup, stride, leaky=0.1):
|
46 |
+
return nn.Sequential(
|
47 |
+
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
48 |
+
nn.BatchNorm2d(inp),
|
49 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
50 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
51 |
+
nn.BatchNorm2d(oup),
|
52 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
class SSH(nn.Module):
|
57 |
+
def __init__(self, in_channel, out_channel):
|
58 |
+
super(SSH, self).__init__()
|
59 |
+
assert out_channel % 4 == 0
|
60 |
+
leaky = 0
|
61 |
+
if out_channel <= 64:
|
62 |
+
leaky = 0.1
|
63 |
+
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
|
64 |
+
|
65 |
+
self.conv5X5_1 = conv_bn(
|
66 |
+
in_channel, out_channel // 4, stride=1, leaky=leaky)
|
67 |
+
self.conv5X5_2 = conv_bn_no_relu(
|
68 |
+
out_channel // 4, out_channel // 4, stride=1)
|
69 |
+
|
70 |
+
self.conv7X7_2 = conv_bn(
|
71 |
+
out_channel // 4, out_channel // 4, stride=1, leaky=leaky
|
72 |
+
)
|
73 |
+
self.conv7x7_3 = conv_bn_no_relu(
|
74 |
+
out_channel // 4, out_channel // 4, stride=1)
|
75 |
+
|
76 |
+
def forward(self, input):
|
77 |
+
conv3X3 = self.conv3X3(input)
|
78 |
+
|
79 |
+
conv5X5_1 = self.conv5X5_1(input)
|
80 |
+
conv5X5 = self.conv5X5_2(conv5X5_1)
|
81 |
+
|
82 |
+
conv7X7_2 = self.conv7X7_2(conv5X5_1)
|
83 |
+
conv7X7 = self.conv7x7_3(conv7X7_2)
|
84 |
+
|
85 |
+
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
|
86 |
+
out = F.relu(out)
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
class FPN(nn.Module):
|
91 |
+
def __init__(self, in_channels_list, out_channels):
|
92 |
+
super(FPN, self).__init__()
|
93 |
+
leaky = 0
|
94 |
+
if out_channels <= 64:
|
95 |
+
leaky = 0.1
|
96 |
+
self.output1 = conv_bn1X1(
|
97 |
+
in_channels_list[0], out_channels, stride=1, leaky=leaky
|
98 |
+
)
|
99 |
+
self.output2 = conv_bn1X1(
|
100 |
+
in_channels_list[1], out_channels, stride=1, leaky=leaky
|
101 |
+
)
|
102 |
+
self.output3 = conv_bn1X1(
|
103 |
+
in_channels_list[2], out_channels, stride=1, leaky=leaky
|
104 |
+
)
|
105 |
+
|
106 |
+
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
|
107 |
+
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
|
108 |
+
|
109 |
+
def forward(self, input):
|
110 |
+
# names = list(input.keys())
|
111 |
+
input = list(input.values())
|
112 |
+
|
113 |
+
output1 = self.output1(input[0])
|
114 |
+
output2 = self.output2(input[1])
|
115 |
+
output3 = self.output3(input[2])
|
116 |
+
|
117 |
+
up3 = F.interpolate(
|
118 |
+
output3, size=[output2.size(2), output2.size(3)], mode="nearest"
|
119 |
+
)
|
120 |
+
output2 = output2 + up3
|
121 |
+
output2 = self.merge2(output2)
|
122 |
+
|
123 |
+
up2 = F.interpolate(
|
124 |
+
output2, size=[output1.size(2), output1.size(3)], mode="nearest"
|
125 |
+
)
|
126 |
+
output1 = output1 + up2
|
127 |
+
output1 = self.merge1(output1)
|
128 |
+
|
129 |
+
out = [output1, output2, output3]
|
130 |
+
return out
|
131 |
+
|
132 |
+
|
133 |
+
class MobileNetV1(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super(MobileNetV1, self).__init__()
|
136 |
+
self.stage1 = nn.Sequential(
|
137 |
+
conv_bn(3, 8, 2, leaky=0.1), # 3
|
138 |
+
conv_dw(8, 16, 1), # 7
|
139 |
+
conv_dw(16, 32, 2), # 11
|
140 |
+
conv_dw(32, 32, 1), # 19
|
141 |
+
conv_dw(32, 64, 2), # 27
|
142 |
+
conv_dw(64, 64, 1), # 43
|
143 |
+
)
|
144 |
+
self.stage2 = nn.Sequential(
|
145 |
+
conv_dw(64, 128, 2), # 43 + 16 = 59
|
146 |
+
conv_dw(128, 128, 1), # 59 + 32 = 91
|
147 |
+
conv_dw(128, 128, 1), # 91 + 32 = 123
|
148 |
+
conv_dw(128, 128, 1), # 123 + 32 = 155
|
149 |
+
conv_dw(128, 128, 1), # 155 + 32 = 187
|
150 |
+
conv_dw(128, 128, 1), # 187 + 32 = 219
|
151 |
+
)
|
152 |
+
self.stage3 = nn.Sequential(
|
153 |
+
conv_dw(128, 256, 2), # 219 +3 2 = 241
|
154 |
+
conv_dw(256, 256, 1), # 241 + 64 = 301
|
155 |
+
)
|
156 |
+
self.avg = nn.AdaptiveAvgPool2d((1, 1))
|
157 |
+
self.fc = nn.Linear(256, 1000)
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
x = self.stage1(x)
|
161 |
+
x = self.stage2(x)
|
162 |
+
x = self.stage3(x)
|
163 |
+
x = self.avg(x)
|
164 |
+
# x = self.model(x)
|
165 |
+
x = x.view(-1, 256)
|
166 |
+
x = self.fc(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class ClassHead(nn.Module):
|
171 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
172 |
+
super(ClassHead, self).__init__()
|
173 |
+
self.num_anchors = num_anchors
|
174 |
+
self.conv1x1 = nn.Conv2d(
|
175 |
+
inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0
|
176 |
+
)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
out = self.conv1x1(x)
|
180 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
181 |
+
return out.view(out.shape[0], -1, 2)
|
182 |
+
|
183 |
+
|
184 |
+
class BboxHead(nn.Module):
|
185 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
186 |
+
super(BboxHead, self).__init__()
|
187 |
+
self.conv1x1 = nn.Conv2d(
|
188 |
+
inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0
|
189 |
+
)
|
190 |
+
|
191 |
+
def forward(self, x):
|
192 |
+
out = self.conv1x1(x)
|
193 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
194 |
+
return out.view(out.shape[0], -1, 4)
|
195 |
+
|
196 |
+
|
197 |
+
class LandmarkHead(nn.Module):
|
198 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
199 |
+
super(LandmarkHead, self).__init__()
|
200 |
+
self.conv1x1 = nn.Conv2d(
|
201 |
+
inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
out = self.conv1x1(x)
|
206 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
207 |
+
return out.view(out.shape[0], -1, 10)
|
208 |
+
|
209 |
+
|
210 |
+
class RetinaFace(nn.Module):
|
211 |
+
def __init__(self, cfg=None, phase="train"):
|
212 |
+
"""
|
213 |
+
:param cfg: Network related settings.
|
214 |
+
:param phase: train or test.
|
215 |
+
"""
|
216 |
+
super(RetinaFace, self).__init__()
|
217 |
+
self.phase = phase
|
218 |
+
backbone = None
|
219 |
+
if cfg["name"] == "mobilenet0.25":
|
220 |
+
backbone = MobileNetV1()
|
221 |
+
elif cfg["name"] == "Resnet50":
|
222 |
+
import torchvision.models as models
|
223 |
+
backbone = models.resnet50(pretrained=cfg["pretrain"])
|
224 |
+
|
225 |
+
self.body = _utils.IntermediateLayerGetter(
|
226 |
+
backbone, cfg["return_layers"])
|
227 |
+
in_channels_stage2 = cfg["in_channel"]
|
228 |
+
in_channels_list = [
|
229 |
+
in_channels_stage2 * 2,
|
230 |
+
in_channels_stage2 * 4,
|
231 |
+
in_channels_stage2 * 8,
|
232 |
+
]
|
233 |
+
out_channels = cfg["out_channel"]
|
234 |
+
self.fpn = FPN(in_channels_list, out_channels)
|
235 |
+
self.ssh1 = SSH(out_channels, out_channels)
|
236 |
+
self.ssh2 = SSH(out_channels, out_channels)
|
237 |
+
self.ssh3 = SSH(out_channels, out_channels)
|
238 |
+
|
239 |
+
self.ClassHead = self._make_class_head(
|
240 |
+
fpn_num=3, inchannels=cfg["out_channel"])
|
241 |
+
self.BboxHead = self._make_bbox_head(
|
242 |
+
fpn_num=3, inchannels=cfg["out_channel"])
|
243 |
+
self.LandmarkHead = self._make_landmark_head(
|
244 |
+
fpn_num=3, inchannels=cfg["out_channel"]
|
245 |
+
)
|
246 |
+
|
247 |
+
def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
248 |
+
classhead = nn.ModuleList()
|
249 |
+
for i in range(fpn_num):
|
250 |
+
classhead.append(ClassHead(inchannels, anchor_num))
|
251 |
+
return classhead
|
252 |
+
|
253 |
+
def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
254 |
+
bboxhead = nn.ModuleList()
|
255 |
+
for i in range(fpn_num):
|
256 |
+
bboxhead.append(BboxHead(inchannels, anchor_num))
|
257 |
+
return bboxhead
|
258 |
+
|
259 |
+
def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
260 |
+
landmarkhead = nn.ModuleList()
|
261 |
+
for i in range(fpn_num):
|
262 |
+
landmarkhead.append(LandmarkHead(inchannels, anchor_num))
|
263 |
+
return landmarkhead
|
264 |
+
|
265 |
+
def forward(self, inputs):
|
266 |
+
out = self.body(inputs)
|
267 |
+
|
268 |
+
# FPN
|
269 |
+
fpn = self.fpn(out)
|
270 |
+
|
271 |
+
# SSH
|
272 |
+
feature1 = self.ssh1(fpn[0])
|
273 |
+
feature2 = self.ssh2(fpn[1])
|
274 |
+
feature3 = self.ssh3(fpn[2])
|
275 |
+
features = [feature1, feature2, feature3]
|
276 |
+
|
277 |
+
bbox_regressions = torch.cat(
|
278 |
+
[self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
|
279 |
+
)
|
280 |
+
classifications = torch.cat(
|
281 |
+
[self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
|
282 |
+
)
|
283 |
+
ldm_regressions = torch.cat(
|
284 |
+
[self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1
|
285 |
+
)
|
286 |
+
|
287 |
+
if self.phase == "train":
|
288 |
+
output = (bbox_regressions, classifications, ldm_regressions)
|
289 |
+
else:
|
290 |
+
output = (
|
291 |
+
bbox_regressions,
|
292 |
+
F.softmax(classifications, dim=-1),
|
293 |
+
ldm_regressions,
|
294 |
+
)
|
295 |
+
return output
|
296 |
+
|
297 |
+
|
298 |
+
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
299 |
+
def decode(loc: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
|
300 |
+
boxes = torch.cat(
|
301 |
+
(
|
302 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
303 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
|
304 |
+
),
|
305 |
+
1,
|
306 |
+
)
|
307 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
308 |
+
boxes[:, 2:] += boxes[:, :2]
|
309 |
+
return boxes
|
310 |
+
|
311 |
+
|
312 |
+
def decode_landm(pre: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
|
313 |
+
landms = torch.cat(
|
314 |
+
(
|
315 |
+
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
|
316 |
+
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
|
317 |
+
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
|
318 |
+
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
|
319 |
+
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
|
320 |
+
),
|
321 |
+
dim=1,
|
322 |
+
)
|
323 |
+
return landms
|
324 |
+
|
325 |
+
|
326 |
+
def nms(dets: torch.Tensor, thresh: float) -> List[int]:
|
327 |
+
"""Pure Python NMS baseline."""
|
328 |
+
x1 = dets[:, 0]
|
329 |
+
y1 = dets[:, 1]
|
330 |
+
x2 = dets[:, 2]
|
331 |
+
y2 = dets[:, 3]
|
332 |
+
scores = dets[:, 4]
|
333 |
+
|
334 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
335 |
+
order = torch.flip(scores.argsort(), [0])
|
336 |
+
|
337 |
+
keep = []
|
338 |
+
while order.numel() > 0:
|
339 |
+
i = order[0].item()
|
340 |
+
keep.append(i)
|
341 |
+
xx1 = torch.maximum(x1[i], x1[order[1:]])
|
342 |
+
yy1 = torch.maximum(y1[i], y1[order[1:]])
|
343 |
+
xx2 = torch.minimum(x2[i], x2[order[1:]])
|
344 |
+
yy2 = torch.minimum(y2[i], y2[order[1:]])
|
345 |
+
|
346 |
+
w = torch.maximum(torch.tensor(0.0).to(dets), xx2 - xx1 + 1)
|
347 |
+
h = torch.maximum(torch.tensor(0.0).to(dets), yy2 - yy1 + 1)
|
348 |
+
inter = w * h
|
349 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
350 |
+
|
351 |
+
inds = torch.where(ovr <= thresh)[0]
|
352 |
+
order = order[inds + 1]
|
353 |
+
|
354 |
+
return keep
|
355 |
+
|
356 |
+
|
357 |
+
class PriorBox:
|
358 |
+
def __init__(self, cfg: dict, image_size: Tuple[int, int]):
|
359 |
+
self.min_sizes = cfg["min_sizes"]
|
360 |
+
self.steps = cfg["steps"]
|
361 |
+
self.clip = cfg["clip"]
|
362 |
+
self.image_size = image_size
|
363 |
+
self.feature_maps = [
|
364 |
+
[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)]
|
365 |
+
for step in self.steps
|
366 |
+
]
|
367 |
+
|
368 |
+
def generate_anchors(self, device) -> torch.Tensor:
|
369 |
+
anchors = []
|
370 |
+
for k, f in enumerate(self.feature_maps):
|
371 |
+
min_sizes = self.min_sizes[k]
|
372 |
+
for i, j in product(range(f[0]), range(f[1])):
|
373 |
+
for min_size in min_sizes:
|
374 |
+
s_kx = min_size / self.image_size[1]
|
375 |
+
s_ky = min_size / self.image_size[0]
|
376 |
+
dense_cx = [
|
377 |
+
x * self.steps[k] / self.image_size[1] for x in [j + 0.5]
|
378 |
+
]
|
379 |
+
dense_cy = [
|
380 |
+
y * self.steps[k] / self.image_size[0] for y in [i + 0.5]
|
381 |
+
]
|
382 |
+
for cy, cx in product(dense_cy, dense_cx):
|
383 |
+
anchors += [cx, cy, s_kx, s_ky]
|
384 |
+
|
385 |
+
# back to torch land
|
386 |
+
output = torch.tensor(anchors).view(-1, 4)
|
387 |
+
if self.clip:
|
388 |
+
output.clamp_(max=1, min=0)
|
389 |
+
return output.to(device=device)
|
390 |
+
|
391 |
+
|
392 |
+
cfg_mnet = {
|
393 |
+
"name": "mobilenet0.25",
|
394 |
+
"min_sizes": [[16, 32], [64, 128], [256, 512]],
|
395 |
+
"steps": [8, 16, 32],
|
396 |
+
"variance": [0.1, 0.2],
|
397 |
+
"clip": False,
|
398 |
+
"loc_weight": 2.0,
|
399 |
+
"gpu_train": True,
|
400 |
+
"batch_size": 32,
|
401 |
+
"ngpu": 1,
|
402 |
+
"epoch": 250,
|
403 |
+
"decay1": 190,
|
404 |
+
"decay2": 220,
|
405 |
+
"image_size": 640,
|
406 |
+
"pretrain": True,
|
407 |
+
"return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
|
408 |
+
"in_channel": 32,
|
409 |
+
"out_channel": 64,
|
410 |
+
}
|
411 |
+
|
412 |
+
cfg_re50 = {
|
413 |
+
"name": "Resnet50",
|
414 |
+
"min_sizes": [[16, 32], [64, 128], [256, 512]],
|
415 |
+
"steps": [8, 16, 32],
|
416 |
+
"variance": [0.1, 0.2],
|
417 |
+
"clip": False,
|
418 |
+
"loc_weight": 2.0,
|
419 |
+
"gpu_train": True,
|
420 |
+
"batch_size": 24,
|
421 |
+
"ngpu": 4,
|
422 |
+
"epoch": 100,
|
423 |
+
"decay1": 70,
|
424 |
+
"decay2": 90,
|
425 |
+
"image_size": 840,
|
426 |
+
"pretrain": False,
|
427 |
+
"return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
|
428 |
+
"in_channel": 256,
|
429 |
+
"out_channel": 256,
|
430 |
+
}
|
431 |
+
|
432 |
+
|
433 |
+
def check_keys(model, pretrained_state_dict):
|
434 |
+
ckpt_keys = set(pretrained_state_dict.keys())
|
435 |
+
model_keys = set(model.state_dict().keys())
|
436 |
+
used_pretrained_keys = model_keys & ckpt_keys
|
437 |
+
assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
|
438 |
+
return True
|
439 |
+
|
440 |
+
|
441 |
+
def remove_prefix(state_dict, prefix):
|
442 |
+
""" Old style model is stored with all names of parameters sharing common prefix 'module.' """
|
443 |
+
def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
|
444 |
+
return {f(key): value for key, value in state_dict.items()}
|
445 |
+
|
446 |
+
|
447 |
+
def load_model(model, pretrained_path, load_to_cpu, network: str):
|
448 |
+
if pretrained_path is None:
|
449 |
+
url = pretrained_urls[network]
|
450 |
+
if load_to_cpu:
|
451 |
+
pretrained_dict = torch.utils.model_zoo.load_url(
|
452 |
+
url, map_location=lambda storage, loc: storage
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
pretrained_dict = torch.utils.model_zoo.load_url(
|
456 |
+
url, map_location=lambda storage, loc: storage.cuda(device)
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
if load_to_cpu:
|
460 |
+
pretrained_dict = torch.load(
|
461 |
+
pretrained_path, map_location=lambda storage, loc: storage
|
462 |
+
)
|
463 |
+
else:
|
464 |
+
device = torch.cuda.current_device()
|
465 |
+
pretrained_dict = torch.load(
|
466 |
+
pretrained_path, map_location=lambda storage, loc: storage.cuda(
|
467 |
+
device)
|
468 |
+
)
|
469 |
+
if "state_dict" in pretrained_dict.keys():
|
470 |
+
pretrained_dict = remove_prefix(
|
471 |
+
pretrained_dict["state_dict"], "module.")
|
472 |
+
else:
|
473 |
+
pretrained_dict = remove_prefix(pretrained_dict, "module.")
|
474 |
+
check_keys(model, pretrained_dict)
|
475 |
+
model.load_state_dict(pretrained_dict, strict=False)
|
476 |
+
return model
|
477 |
+
|
478 |
+
|
479 |
+
def load_net(model_path, network="mobilenet"):
|
480 |
+
if network == "mobilenet":
|
481 |
+
cfg = cfg_mnet
|
482 |
+
elif network == "resnet50":
|
483 |
+
cfg = cfg_re50
|
484 |
+
else:
|
485 |
+
raise NotImplementedError(network)
|
486 |
+
# net and model
|
487 |
+
net = RetinaFace(cfg=cfg, phase="test")
|
488 |
+
net = load_model(net, model_path, True, network=network)
|
489 |
+
net.eval()
|
490 |
+
cudnn.benchmark = True
|
491 |
+
# net = net.to(device)
|
492 |
+
return net
|
493 |
+
|
494 |
+
|
495 |
+
def parse_det(det: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]:
|
496 |
+
landmarks = det[5:].reshape(5, 2)
|
497 |
+
box = det[:4]
|
498 |
+
score = det[4]
|
499 |
+
return box, landmarks, score.item()
|
500 |
+
|
501 |
+
|
502 |
+
def post_process(
|
503 |
+
loc: torch.Tensor,
|
504 |
+
conf: torch.Tensor,
|
505 |
+
landms: torch.Tensor,
|
506 |
+
prior_data: torch.Tensor,
|
507 |
+
cfg: dict,
|
508 |
+
scale: float,
|
509 |
+
scale1: float,
|
510 |
+
resize,
|
511 |
+
confidence_threshold,
|
512 |
+
top_k,
|
513 |
+
nms_threshold,
|
514 |
+
keep_top_k,
|
515 |
+
):
|
516 |
+
boxes = decode(loc, prior_data, cfg["variance"])
|
517 |
+
boxes = boxes * scale / resize
|
518 |
+
# boxes = boxes.cpu().numpy()
|
519 |
+
# scores = conf.cpu().numpy()[:, 1]
|
520 |
+
scores = conf[:, 1]
|
521 |
+
landms_copy = decode_landm(landms, prior_data, cfg["variance"])
|
522 |
+
|
523 |
+
landms_copy = landms_copy * scale1 / resize
|
524 |
+
# landms_copy = landms_copy.cpu().numpy()
|
525 |
+
|
526 |
+
# ignore low scores
|
527 |
+
inds = torch.where(scores > confidence_threshold)[0]
|
528 |
+
boxes = boxes[inds]
|
529 |
+
landms_copy = landms_copy[inds]
|
530 |
+
scores = scores[inds]
|
531 |
+
|
532 |
+
# keep top-K before NMS
|
533 |
+
order = torch.flip(scores.argsort(), [0])[:top_k]
|
534 |
+
boxes = boxes[order]
|
535 |
+
landms_copy = landms_copy[order]
|
536 |
+
scores = scores[order]
|
537 |
+
|
538 |
+
# do NMS
|
539 |
+
dets = torch.hstack((boxes, scores.unsqueeze(-1))).to(
|
540 |
+
dtype=torch.float32, copy=False)
|
541 |
+
keep = nms(dets, nms_threshold)
|
542 |
+
# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
|
543 |
+
dets = dets[keep, :]
|
544 |
+
landms_copy = landms_copy[keep]
|
545 |
+
|
546 |
+
# keep top-K faster NMS
|
547 |
+
dets = dets[:keep_top_k, :]
|
548 |
+
landms_copy = landms_copy[:keep_top_k, :]
|
549 |
+
|
550 |
+
dets = torch.cat((dets, landms_copy), dim=1)
|
551 |
+
# show image
|
552 |
+
dets = sorted(dets, key=lambda x: x[4], reverse=True)
|
553 |
+
dets = [parse_det(x) for x in dets]
|
554 |
+
|
555 |
+
return dets
|
556 |
+
|
557 |
+
|
558 |
+
# @torch.no_grad()
|
559 |
+
def batch_detect(net: nn.Module, images: torch.Tensor, threshold: float = 0.5):
|
560 |
+
confidence_threshold = threshold
|
561 |
+
cfg = cfg_mnet
|
562 |
+
top_k = 5000
|
563 |
+
nms_threshold = 0.4
|
564 |
+
keep_top_k = 1
|
565 |
+
resize = 1
|
566 |
+
|
567 |
+
img = images.float()
|
568 |
+
mean = torch.as_tensor([104, 117, 123], dtype=img.dtype, device=img.device).view(
|
569 |
+
1, 3, 1, 1
|
570 |
+
)
|
571 |
+
img -= mean
|
572 |
+
(
|
573 |
+
_,
|
574 |
+
_,
|
575 |
+
im_height,
|
576 |
+
im_width,
|
577 |
+
) = img.shape
|
578 |
+
scale = torch.as_tensor(
|
579 |
+
[im_width, im_height, im_width, im_height],
|
580 |
+
dtype=img.dtype,
|
581 |
+
device=img.device,
|
582 |
+
)
|
583 |
+
scale = scale.to(img.device)
|
584 |
+
|
585 |
+
loc, conf, landms = net(img) # forward pass
|
586 |
+
|
587 |
+
priorbox = PriorBox(cfg, image_size=(im_height, im_width))
|
588 |
+
prior_data = priorbox.generate_anchors(device=img.device)
|
589 |
+
scale1 = torch.as_tensor(
|
590 |
+
[
|
591 |
+
img.shape[3],
|
592 |
+
img.shape[2],
|
593 |
+
img.shape[3],
|
594 |
+
img.shape[2],
|
595 |
+
img.shape[3],
|
596 |
+
img.shape[2],
|
597 |
+
img.shape[3],
|
598 |
+
img.shape[2],
|
599 |
+
img.shape[3],
|
600 |
+
img.shape[2],
|
601 |
+
],
|
602 |
+
dtype=img.dtype,
|
603 |
+
device=img.device,
|
604 |
+
)
|
605 |
+
scale1 = scale1.to(img.device)
|
606 |
+
|
607 |
+
all_dets = [
|
608 |
+
post_process(
|
609 |
+
loc_i,
|
610 |
+
conf_i,
|
611 |
+
landms_i,
|
612 |
+
prior_data,
|
613 |
+
cfg,
|
614 |
+
scale,
|
615 |
+
scale1,
|
616 |
+
resize,
|
617 |
+
confidence_threshold,
|
618 |
+
top_k,
|
619 |
+
nms_threshold,
|
620 |
+
keep_top_k,
|
621 |
+
)
|
622 |
+
for loc_i, conf_i, landms_i in zip(loc, conf, landms)
|
623 |
+
]
|
624 |
+
|
625 |
+
rects = []
|
626 |
+
points = []
|
627 |
+
scores = []
|
628 |
+
image_ids = []
|
629 |
+
for image_id, faces_in_one_image in enumerate(all_dets):
|
630 |
+
for rect, landmarks, score in faces_in_one_image:
|
631 |
+
rects.append(rect)
|
632 |
+
points.append(landmarks)
|
633 |
+
scores.append(score)
|
634 |
+
image_ids.append(image_id)
|
635 |
+
|
636 |
+
if len(rects) == 0:
|
637 |
+
return dict()
|
638 |
+
|
639 |
+
return {
|
640 |
+
'rects': torch.stack(rects, dim=0).to(img.device),
|
641 |
+
'points': torch.stack(points, dim=0).to(img.device),
|
642 |
+
'scores': torch.tensor(scores).to(img.device),
|
643 |
+
'image_ids': torch.tensor(image_ids).to(img.device),
|
644 |
+
}
|
645 |
+
|
646 |
+
|
647 |
+
class RetinaFaceDetector(FaceDetector):
|
648 |
+
"""RetinaFaceDetector
|
649 |
+
|
650 |
+
Args:
|
651 |
+
images (torch.Tensor): b x c x h x w, uint8, 0~255.
|
652 |
+
|
653 |
+
Returns:
|
654 |
+
faces (Dict[str, torch.Tensor]):
|
655 |
+
|
656 |
+
* image_ids: n, int
|
657 |
+
* rects: n x 4 (x1, y1, x2, y2)
|
658 |
+
* points: n x 5 x 2 (x, y)
|
659 |
+
* scores: n
|
660 |
+
"""
|
661 |
+
|
662 |
+
def __init__(self, conf_name: Optional[str] = None,
|
663 |
+
model_path: Optional[str] = None, threshold=0.8) -> None:
|
664 |
+
super().__init__()
|
665 |
+
if conf_name is None:
|
666 |
+
conf_name = 'mobilenet'
|
667 |
+
self.net = load_net(model_path, conf_name)
|
668 |
+
self.threshold = threshold
|
669 |
+
self.eval()
|
670 |
+
|
671 |
+
def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
|
672 |
+
return batch_detect(self.net, images.clone(), threshold=self.threshold)
|
facer/face_parsing/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceParser
|
2 |
+
from .farl import FaRLFaceParser
|
facer/face_parsing/base.py
ADDED
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn


class FaceParser(nn.Module):
    """ face parser

    Args:
        images (torch.Tensor): b x c x h x w

        data (Dict[str, Any]):

            * image_ids (torch.Tensor): nfaces
            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
            * points (torch.Tensor): nfaces x 5 x 2 (x, y)

    Returns:
        data (Dict[str, Any]):

            * image_ids (torch.Tensor): nfaces
            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
            * points (torch.Tensor): nfaces x 5 x 2 (x, y)
            * seg (Dict[str, Any]):

                * logits (torch.Tensor): nfaces x nclasses x h x w
                * label_names (List[str]): nclasses
    """
    pass
facer/face_parsing/farl.py
ADDED
@@ -0,0 +1,92 @@
from typing import Optional, Dict, Any
import functools
import torch
import torch.nn.functional as F

from ..util import download_jit
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from .base import FaceParser

pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448), target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8, warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
                                           target_shape=(448, 448)),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0, warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l']
    }
}


class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
    @article{zheng2021farl,
        title={General Facial Representation Learning in a Visual-Linguistic Manner},
        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
                Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
                Dong and Zeng, Ming and Wen, Fang},
        journal={arXiv preprint arXiv:2112.03109},
        year={2021}
    }
    ```
    """

    def __init__(self, conf_name: Optional[str] = None,
                 model_path: Optional[str] = None, device=None) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        self.net = download_jit(model_path, map_location=device)
        self.eval()

    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0
        _, _, h, w = images.shape

        simages = images[data['image_ids']]
        matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        w_seg_logits, _ = self.net(w_images)  # (b*n) x c x h x w

        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits,
                       'label_names': setting['label_names']}
        return data
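To show how the parser's forward signature is meant to be used, here is a hedged sketch (not from the commit) that chains it with the detector from `retinaface.py` above; `image` is the uint8 `1 x 3 x h x w` batch from the detection sketch, and the argmax over the returned logits is just one illustrative way to obtain a label map.

```python
import torch
from facer.face_detection import RetinaFaceDetector  # assumed re-export
from facer.face_parsing import FaRLFaceParser

device = 'cuda' if torch.cuda.is_available() else 'cpu'
detector = RetinaFaceDetector().to(device)
parser = FaRLFaceParser('lapa/448', device=device).to(device)

with torch.inference_mode():
    faces = detector(image)        # image: 1 x 3 x h x w, uint8
    faces = parser(image, faces)   # adds faces['seg']

seg_logits = faces['seg']['logits']        # nfaces x nclasses x h x w
label_map = seg_logits.argmax(dim=1)       # nfaces x h x w class ids
label_names = faces['seg']['label_names']  # 'background', 'face', ..., 'hair'
```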
facer/farl/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .model import load_farl, VisualTransformer
from .classification import farl_classification
facer/farl/classification.py
ADDED
@@ -0,0 +1,149 @@
from torch import nn
import torch
import torch.utils.checkpoint as checkpoint
from .model import VisualTransformer


class VITClassificationHeadV0(nn.Module):
    def __init__(
        self,
        num_features: int,
        channel: int,
        num_labels: int,
        norm=False,
        dropout=0.0,
        ret_feat=False,
    ):
        super().__init__()
        self.weights = nn.Parameter(
            torch.ones(1, num_features * 3, 1, dtype=torch.float32)
        )
        self.final_fc = nn.Linear(channel, num_labels)
        self.norm = norm
        if self.norm:
            for i in range(num_features * 3):
                setattr(self, f"norm_{i}", nn.LayerNorm(channel))
        self.dropout = nn.Dropout(p=dropout)
        self.ret_feat = ret_feat

    def forward(self, features, cls_tokens):
        xs = []
        for feature, cls_token in zip(features, cls_tokens):
            # feature: b x c x s x s
            # cls_token: b x c
            xs.append(feature.mean([2, 3]))
            xs.append(feature.max(-1).values.max(-1).values)
            xs.append(cls_token)
        if self.norm:
            xs = [getattr(self, f"norm_{i}")(x) for i, x in enumerate(xs)]
        xs = torch.stack(xs, dim=1)  # b x 3N x c
        feat = (xs * self.weights.softmax(dim=1)).sum(1)  # b x c
        x = self.dropout(feat)
        x = self.final_fc(x)  # b x num_labels
        if self.ret_feat:
            return x, feat
        else:
            return x


class FACTransformer(nn.Module):
    """A face attribute classification transformer leveraging multiple cls_tokens.
    Args:
        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
    Returns:
        logits (torch.Tensor): Float32 tensor with shape [b, n_classes].
        aux_outputs:
    """

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.cuda().float()

    def forward(self, image):
        logits = self.head(*self.backbone(image))
        return logits


def add_method(obj, name, method):
    import types

    setattr(obj, name, types.MethodType(method, obj))


def get_clip_encode_func(layers):
    def func(self, x):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        extra_tokens = getattr(self, "extra_tokens", [])
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        class_token = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        special_tokens = [
            getattr(self, name).to(x.dtype)
            + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            for name in extra_tokens
        ]
        x = torch.cat(
            [class_token, *special_tokens, x], dim=1
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        outs = []
        max_layer = max(layers)
        use_checkpoint = self.transformer.use_checkpoint
        for layer_i, blk in enumerate(self.transformer.resblocks):
            if layer_i > max_layer:
                break
            if self.training and use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            outs.append(x)

        outs = torch.stack(outs).permute(0, 2, 1, 3)
        cls_tokens = outs[layers, :, 0, :]

        extra_token_feats = {}
        for i, name in enumerate(extra_tokens):
            extra_token_feats[name] = outs[layers, :, i + 1, :]
        L, B, N, C = outs.shape
        import math

        W = int(math.sqrt(N - 1 - len(extra_tokens)))
        features = (
            outs[layers, :, 1 + len(extra_tokens):, :]
            .reshape(len(layers), B, W, W, C)
            .permute(0, 1, 4, 2, 3)
        )
        if getattr(self, "ret_special", False):
            return features, cls_tokens, extra_token_feats
        else:
            return features, cls_tokens

    return func


def farl_classification(num_classes=2, layers=list(range(12))):
    model = VisualTransformer(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
    )
    channel = 768
    model = model.cuda()
    del model.proj
    del model.ln_post
    add_method(model, "forward", get_clip_encode_func(layers))
    head = VITClassificationHeadV0(
        num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
    )
    model = FACTransformer(model, head)
    return model
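A hedged sketch of how `farl_classification` might be exercised (not from the commit). Both `farl_classification` and `FACTransformer.__init__` call `.cuda()`, so this only runs on a GPU machine; the batch of random 224 x 224 images is purely illustrative.

```python
import torch
from facer.farl import farl_classification

model = farl_classification(num_classes=2)  # ViT-B/16 backbone + fused 12-layer head
model.eval()

images = torch.rand(4, 3, 224, 224, device='cuda')  # float images in [0, 1]
with torch.inference_mode():
    logits = model(images)  # 4 x 2 class logits
```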
facer/farl/model.py
ADDED
@@ -0,0 +1,419 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from collections import OrderedDict
import logging
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import trunc_normal_, DropPath


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes *
                 self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2]
                      * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2,
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(
            width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.BatchNorm2d, LayerNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3)
            ]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def add_drop_path(self, drop_path):
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 use_checkpoint=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.use_checkpoint = use_checkpoint
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)]
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, attn_mask, drop_path=dpr[i])
            for i in range(layers)
        ])
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor):
        for i, blk in enumerate(self.resblocks):
            x = blk(x)
        return x


class VisualTransformer(nn.Module):
    positional_embedding: nn.Parameter

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 pool_type: str = 'default',
                 skip_cls: bool = False,
                 drop_path_rate=0.,
                 **kwargs):
        super().__init__()
        self.pool_type = pool_type
        self.skip_cls = skip_cls
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )
        self.config = kwargs.get("config", None)
        self.sequence_length = (input_resolution // patch_size) ** 2 + 1
        self.conv_pool = nn.Identity()
        if (self.pool_type == 'linear'):
            if (not self.skip_cls):
                self.conv_pool = nn.Conv1d(
                    width, width, self.sequence_length, stride=self.sequence_length, groups=width)
            else:
                self.conv_pool = nn.Conv1d(
                    width, width, self.sequence_length-1, stride=self.sequence_length, groups=width)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(
                self.sequence_length, width
            )
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(
            width, layers, heads, drop_path_rate=drop_path_rate)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        if self.config is not None and self.config.MIM.ENABLE:
            logging.info("MIM ENABLED")
            self.mim = True
            self.lm_transformer = Transformer(
                width, self.config.MIM.LAYERS, heads)
            self.ln_lm = LayerNorm(width)
            self.lm_head = nn.Linear(width, self.config.MIM.VOCAB_SIZE)
            self.mask_token = nn.Parameter(scale * torch.randn(width))
        else:
            self.mim = False
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1],
                      dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        if (self.pool_type == 'average'):
            if self.skip_cls:
                x = x[:, 1:, :]
            x = torch.mean(x, dim=1)
        elif (self.pool_type == 'linear'):
            if self.skip_cls:
                x = x[:, 1:, :]
            x = x.permute(0, 2, 1)
            x = self.conv_pool(x)
            x = x.permute(0, 2, 1).squeeze()
        else:
            x = x[:, 0, :]
        x = self.ln_post(x)
        if self.proj is not None:
            x = x @ self.proj
        return x

    def forward_mim(self, x: torch.Tensor, bool_masked_pos, return_all_tokens=False, disable_vlc=False):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        batch_size, seq_len, _ = x.size()
        mask_token = self.mask_token.unsqueeze(
            0).unsqueeze(0).expand(batch_size, seq_len, -1)
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        masked_x = x * (1 - w) + mask_token * w
        if disable_vlc:
            x = masked_x
            masked_start = 0
        else:
            x = torch.cat([x, masked_x], 0)
            masked_start = batch_size
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1],
            dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        masked_x = x[:, masked_start:]
        masked_x = self.lm_transformer(masked_x)
        masked_x = masked_x.permute(1, 0, 2)
        masked_x = masked_x[:, 1:]
        masked_x = self.ln_lm(masked_x)
        if not return_all_tokens:
            masked_x = masked_x[bool_masked_pos]
        logits = self.lm_head(masked_x)
        assert self.pool_type == "default"
        result = {"logits": logits}
        if not disable_vlc:
            x = x[0, :batch_size]
            x = self.ln_post(x)
            if self.proj is not None:
                x = x @ self.proj
            result["feature"] = x
        return result


def load_farl(model_type, model_file=None) -> VisualTransformer:
    if model_type == "base":
        model = VisualTransformer(
            input_resolution=224, patch_size=16, width=768, layers=12, heads=12, output_dim=512)
    elif model_type == "large":
        model = VisualTransformer(
            input_resolution=224, patch_size=16, width=1024, layers=24, heads=16, output_dim=512)
    elif model_type == "huge":
        model = VisualTransformer(
            input_resolution=224, patch_size=14, width=1280, layers=32, heads=16, output_dim=512)
    else:
        raise
    model.transformer.use_checkpoint = False
    if model_file is not None:
        checkpoint = torch.load(model_file, map_location='cpu')
        state_dict = {}
        for name, weight in checkpoint["state_dict"].items():
            if name.startswith("visual"):
                state_dict[name[7:]] = weight
        inco = model.load_state_dict(state_dict, strict=False)
        # print(inco.missing_keys)
        assert len(inco.missing_keys) == 0
    return model
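A short sketch of `load_farl` (again not from the commit): with `model_file=None` it simply builds a randomly initialized ViT-B/16; to load pretrained weights, pass the path of a FaRL checkpoint whose `state_dict` keys start with `visual.` (the path below is a placeholder).

```python
import torch
from facer.farl import load_farl

backbone = load_farl('base')  # or load_farl('base', 'path/to/farl_checkpoint.pth')
backbone.eval()
with torch.inference_mode():
    feat = backbone(torch.rand(1, 3, 224, 224))  # 1 x 512 projected [CLS] feature
```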
facer/io.py
ADDED
@@ -0,0 +1,28 @@
import torch
import numpy as np
from PIL import Image


def read_hwc(path: str) -> torch.Tensor:
    """Read an image from a given path.

    Args:
        path (str): The given path.
    """
    image = Image.open(path)
    np_image = np.array(image.convert('RGB'))
    return torch.from_numpy(np_image)


def write_hwc(image: torch.Tensor, path: str):
    """Write an image to a given path.

    Args:
        image (torch.Tensor): The image.
        path (str): The given path.
    """

    Image.fromarray(image.cpu().numpy()).save(path)
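A tiny round-trip sketch for these helpers together with the layout converters from `facer/util.py`; the file paths are placeholders.

```python
from facer.io import read_hwc, write_hwc
from facer.util import hwc2bchw, bchw2hwc

img_hwc = read_hwc('assets/face.jpg')   # h x w x 3, uint8
img_bchw = hwc2bchw(img_hwc)            # 1 x 3 x h x w, the layout the detectors expect
write_hwc(bchw2hwc(img_bchw), 'assets/face_copy.jpg')
```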
facer/show.py
ADDED
@@ -0,0 +1,37 @@
from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt

from .util import bchw2hwc


def set_figsize(*args):
    if len(args) == 0:
        plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"]
    elif len(args) == 1:
        plt.rcParams["figure.figsize"] = (args[0], args[0])
    elif len(args) == 2:
        plt.rcParams["figure.figsize"] = tuple(args)
    else:
        raise RuntimeError(
            f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')


def show_hwc(image: torch.Tensor):
    if image.dtype != torch.uint8:
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        image = image.repeat(1, 1, 3)
    pimage = Image.fromarray(image.cpu().numpy())
    plt.imshow(pimage)
    plt.imsave('12345.jpg', pimage)
    plt.show()


def show_bchw(image: torch.Tensor):
    show_hwc(bchw2hwc(image))


def show_bhw(image: torch.Tensor):
    show_bchw(image.unsqueeze(1))
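Continuing the parsing sketch above, the label map can be visualized with these helpers; multiplying by a constant simply spreads the small class ids over the gray range (an illustrative choice, not something the commit prescribes).

```python
from facer.show import show_bhw

show_bhw(label_map * 20)  # label_map: nfaces x h x w, from the parsing sketch
```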
facer/transform.py
ADDED
@@ -0,0 +1,386 @@
from typing import List, Dict, Callable, Tuple, Optional
import torch
import torch.nn.functional as F
import functools
import numpy as np


def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """
    Args:
        box: b x 4(x1, y1, x2, y2)
        align_corners (bool): Set this to `True` only if the box you give has coordinates
            ranging from `0` to `h-1` or `w-1`.

        offset_box_coords (bool): Set this to `True` if the box you give has coordinates
            ranging from `0` to `h` or `w`.

            Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.

            If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    x1, y1, x2, y2 = box.split(1, dim=1)  # b x 1
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1)
        # y -> (y - y1) / (y2 - y1) * (h - 1)
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    else:
        if offset_box_coords:
            # x1, x2 \in [0, w], y1, y2 \in [0, h]
            # first we should offset x1, x2, y1, y2 to be ranging in
            # [-0.5, w-0.5] and [-0.5, h-0.5]
            # so to convert these pixel coordinates into boundary coordinates.
            x1, x2, y1, y2 = x1-0.5, x2-0.5, y1-0.5, y2-0.5

        # x -> (x - x1) / (x2 - x1) * w - 0.5
        # y -> (y - y1) / (y2 - y1) * h - 0.5
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5*ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5*ones_pl,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    return matrix


def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """
    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False)  # b
    c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False)  # b

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] - to_delta[:,
          :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)


@functools.lru_cache()
def _standard_face_pts():
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):

    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)  # [-1 1]
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w-1, h-1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


def rot90(v):
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    # N,2
    lm = lm.detach().cpu().numpy()
    # Choose oriented crop rectangle.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    quad_for_coeffs = quad[[0, 3, 2, 1]]  # reorder the corners
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]):

    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    assert target_shape[0] == target_shape[1]
    target_size = target_shape[0]
    target_pts = torch.as_tensor([[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]).to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy, xx


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Forge transform maps with a given function `fn`.

    Args:
        output_shape (tuple): (b, h, w, ...).
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a bxnx2 array and outputs the transformed bxnx2 array. Both input
            and output store (x, y) coordinates.

    Note:
        both input and output arrays of `fn` should store (y, x) coordinates.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Two maps `X` and `Y`, where for each
            pixel (y, x) or coordinate (x, y),
            `(X[y, x], Y[y, x]) = fn([x, y])`
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h*w, 2])  # (h x w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # (h x w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)


def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    return torch.clamp(x, -1+eps, 1-eps).arctanh()


def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1

    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    # inv_matrix = torch.linalg.inv(matrix)  # b x 3 x 3
    device = matrix.device
    inv_matrix_np = np.linalg.inv(matrix.cpu().numpy())
    inv_matrix = torch.from_numpy(inv_matrix_np).to(device)
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1
    w_h = torch.tensor([[w, h]]).to(coords)

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1))  # b x n x 3
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    return coords


def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h*2-1


def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2-1
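The sketch below mirrors how `FaRLFaceParser.forward` above combines these helpers: 5-point landmarks -> similarity matrix -> tanh-warp sampling grid -> warped 448 x 448 crops. `float_images` (b x 3 x h x w, float in [0, 1]) and `faces` are assumed to come from the earlier detection sketch.

```python
import torch.nn.functional as F
from facer.transform import get_face_align_matrix, make_tanh_warp_grid

_, _, h, w = float_images.shape
matrix = get_face_align_matrix(faces['points'], target_shape=(448, 448), target_face_scale=1.0)
grid = make_tanh_warp_grid(matrix, warp_factor=0.8, warped_shape=(448, 448), orig_shape=(h, w))
crops = F.grid_sample(float_images[faces['image_ids']], grid,
                      mode='bilinear', align_corners=False)  # nfaces x 3 x 448 x 448
```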
facer/util.py
ADDED
@@ -0,0 +1,169 @@
import torch
from typing import Any, Optional, Union, List, Dict
import math
import os
from urllib.parse import urlparse
import errno
import sys
import validators
import requests
import json


def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    return images.unsqueeze(0).permute(0, 3, 1, 2)


def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """ make a grid image from an image batch.

    Args:
        images (torch.Tensor): input image batch.
        nrows: rows of grid.
        border: border size in pixel.
        background_value: color value of background.
    """
    assert images.ndim == 4  # n x c x h x w
    images = images.permute(0, 2, 3, 1)  # n x h x w x c
    n, h, w, c = images.shape
    if nrows is None:
        nrows = max(int(math.sqrt(n)), 1)
    ncols = (n + nrows - 1) // nrows
    result = torch.full([(h + border) * nrows - border,
                         (w + border) * ncols - border, c], background_value,
                        device=images.device,
                        dtype=images.dtype)

    for i, single_image in enumerate(images):
        row = i // ncols
        col = i % ncols
        yy = (h + border) * row
        xx = (w + border) * col
        result[yy:(yy + h), xx:(xx + w), :] = single_image
    return result


def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    return images.permute(0, 2, 3, 1)


def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    return images.permute(0, 3, 1, 2)


def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
    return bchw2hwc(bhwc2bchw(images), *kargs, **kwargs)


def select_data(selection, data):
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None, access_token=None):
    """ download files (including LFS files) from github.

    For example, in order to download https://github.com/FacePerceiver/facer/blob/main/README.md, call with
    ```
    download_from_github(
        to_path='README.md', organisation='FacePerceiver',
        repository='facer', file_path='README.md', branch='main')
    ```
    """
    if username is not None:
        assert access_token is not None
        auth = (username, access_token)
    else:
        auth = None
    r = requests.get(f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
                     auth=auth)
    data = json.loads(r.content)
    torch.hub.download_url_to_file(data['download_url'], to_path)


def is_github_url(url: str):
    """
    A typical github url should be like
    https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
    https://github.com/FacePerceiver/facer/raw/main/facer/util.py.
    """
    return ('blob' in url or 'raw' in url) and url.startswith('https://github.com/')


def get_github_components(url: str):
    assert is_github_url(url)
    organisation, repository, blob_or_raw, branch, * \
        path = url[len('https://github.com/'):].split('/')
    assert blob_or_raw in {'blob', 'raw'}
    return organisation, repository, branch, '/'.join(path)


def download_url_to_file(url, dst, **kwargs):
    if is_github_url(url):
        org, rep, branch, path = get_github_components(url)
        download_from_github(dst, org, rep, path, branch, kwargs.get(
            'username', None), kwargs.get('access_token', None))
    else:
        torch.hub.download_url_to_file(url, dst)


def select_data(selection, data):
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                url = url_or_path
                if model_dir is None:
                    if hasattr(torch.hub, 'get_dir'):
                        hub_dir = torch.hub.get_dir()
                    else:
                        hub_dir = os.path.join(os.path.expanduser(
                            '~'), '.cache', 'torch', 'hub')
                    model_dir = os.path.join(hub_dir, 'checkpoints')

                try:
                    os.makedirs(model_dir)
                except OSError as e:
                    if e.errno == errno.EEXIST:
                        # Directory already exists, ignore.
                        pass
                    else:
                        # Unexpected OSError, re-raise.
                        raise

                parts = urlparse(url)
                filename = os.path.basename(parts.path)
                cached_file = os.path.join(model_dir, filename)
                if not os.path.exists(cached_file):
                    sys.stderr.write(
                        'Downloading: "{}" to {}\n'.format(url, cached_file))
                    download_url_to_file(url, cached_file)
            else:
                cached_file = url_or_path
            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            else:
                return torch.load(cached_file, map_location=map_location, **kwargs)
        except:
            sys.stderr.write(f'failed downloading from {url_or_path}\n')
            raise

    raise RuntimeError('failed to download jit models from all given urls')
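`download_jit` is what `FaRLFaceParser` above uses to fetch its TorchScript weights: it caches the file under torch.hub's checkpoint directory and `torch.jit.load`s it. A sketch with the `lapa/448` URL from `facer/face_parsing/farl.py`:

```python
from facer.util import download_jit

url = ('https://github.com/FacePerceiver/facer/releases/download/models-v1/'
       'face_parsing.farl.lapa.main_ema_136500_jit191.pt')
net = download_jit(url, map_location='cpu')  # returns a torch.jit.ScriptModule
```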
facer/version.py
ADDED
@@ -0,0 +1 @@
__version__ = "0.0.4"
models_mae.py
ADDED
@@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)  # qk_scale=None
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(self.patch_embed.num_patches ** .5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
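A short sketch of exercising the MaskedAutoencoderViT above end to end. The random input batch and the 0.75 mask ratio are illustrative values, not settings prescribed by this commit.

# usage sketch (illustrative values)
import torch
from models_mae import mae_vit_base_patch16

model = mae_vit_base_patch16()          # ViT-B/16 encoder, 512-dim / 8-block decoder
imgs = torch.randn(2, 3, 224, 224)      # random images stand in for real face crops

loss, pred, mask = model(imgs, mask_ratio=0.75)
print(loss.item())    # scalar MSE on the masked patches
print(pred.shape)     # [2, 196, 768]: per-patch pixel predictions (16*16*3)
print(mask.shape)     # [2, 196]: 1 marks masked (reconstructed) patches

recon = model.unpatchify(pred)          # [2, 3, 224, 224], back to image space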
requirements.txt
ADDED
@@ -0,0 +1,80 @@
aiofiles==23.2.1
annotated-types==0.7.0
anyio==4.6.2.post1
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
cmake==3.31.0.1
contourpy==1.3.0
cycler==0.12.1
# dlib==19.24.6
exceptiongroup==1.2.2
fastapi==0.115.4
ffmpy==0.4.0
filelock==3.16.1
fonttools==4.54.1
fsspec==2024.10.0
gradio==4.44.1
gradio_client==1.3.0
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.2
idna==3.10
imageio==2.36.0
importlib_resources==6.4.5
Jinja2==3.1.4
kiwisolver==1.4.7
lazy_loader==0.4
lit==18.1.8
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.2
mdurl==0.1.2
mpmath==1.3.0
networkx==3.2.1
numpy==1.24.4
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
opencv-python==4.10.0.84
orjson==3.10.11
packaging==24.2
pandas==2.2.3
pillow==10.4.0
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
python-multipart==0.0.17
pytz==2024.2
PyYAML==6.0.2
requests==2.32.3
rich==13.9.4
ruff==0.7.3
safetensors==0.4.5
scikit-image==0.24.0
scipy==1.13.1
semantic-version==2.10.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
starlette==0.41.2
sympy==1.13.1
tifffile==2024.8.30
timm==1.0.11
tomlkit==0.12.0
torch==2.0.0
torchvision==0.15.1
tqdm==4.67.0
triton==2.0.0
typer==0.13.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
uvicorn==0.32.0
validators==0.34.0
websockets==12.0
zipp==3.21.0
util/crop.py
ADDED
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
util/datasets.py
ADDED
@@ -0,0 +1,349 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import os
import json
import shutil

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import numpy as np
from PIL import Image
import random
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms
from torch.nn import functional as F


class collate_fn_crfrp:
    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # --------------------------------------------------------------------------
        # self.facial_region_group = [
        #     [2],  # right eyebrow
        #     [3],  # left eyebrow
        #     [4],  # right eye
        #     [5],  # left eye
        #     [6],  # nose
        #     [7, 8],  # upper mouth
        #     [8, 9],  # lower mouth
        #     [10, 1, 0],  # facial boundaries
        #     [10],  # hair
        #     [1],  # facial skin
        #     [0]  # background
        # ]
        self.facial_region_group = [
            [2, 3],  # eyebrows
            [4, 5],  # eyes
            [6],  # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],  # hair
            [1],  # facial skin
            [0]  # background
        ]  # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']

    def __call__(self, samples):
        image, img_mask, facial_region_mask, random_specific_facial_region \
            = self.CRFR_P_masking(samples, specified_facial_region=None)

        return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask}

        # # use the following code if applying a different data augmentation to the target view
        # image, img_mask, facial_region_mask, random_specific_facial_region \
        #     = self.CRFR_P_masking(samples, specified_facial_region=None)
        # image_cl, img_mask_cl, facial_region_mask_cl, random_specific_facial_region_cl \
        #     = self.CRFR_P_masking(samples, specified_facial_region=random_specific_facial_region)
        #
        # return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask,
        #         'image_cl': image_cl, 'img_mask_cl': img_mask_cl, 'specific_facial_region_mask_cl': facial_region_mask_cl}

    def CRFR_P_masking(self, samples, specified_facial_region=None):
        image = torch.stack([sample['image'] for sample in samples])  # torch.Size([bs, 3, 224, 224])
        parsing_map = torch.stack([sample['parsing_map'] for sample in samples])  # torch.Size([bs, 1, 224, 224])
        parsing_map = parsing_map.squeeze(1)  # torch.Size([BS, 1, 224, 224]) → torch.Size([BS, 224, 224])

        # cover a randomly selected facial region group and get fr_mask (mask all patches that include this region)
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # torch.Size([BS, H/P, W/P])
        facial_region_mask, random_specific_facial_region \
            = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        # torch.Size([num_patches,]), list

        img_mask, facial_region_mask \
            = self.variable_proportional_masking(parsing_map, facial_region_mask, random_specific_facial_region)
        # torch.Size([num_patches,]), torch.Size([num_patches,])

        del parsing_map
        return image, img_mask, facial_region_mask, random_specific_facial_region

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask,
                                                             # specified_facial_region=None
                                                             ):
        # while True:
        #     random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        #     if random_specific_facial_region != specified_facial_region:
        #         break
        random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        if random_specific_facial_region == [10, 1, 0]:  # facial boundaries, 10-hair 1-skin 0-background
            # True for hair(10) or bg(0) patches:
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            # True for skin(1) patches:
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            # skin&hair or skin&bg is defined as facial boundaries:
            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            for facial_region_index in random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), random_specific_facial_region

    def variable_proportional_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        img_mask = facial_region_mask.clone()

        # proportional masking of patches in other regions
        other_facial_region_group = [region for region in self.facial_region_group if
                                     region != random_specific_facial_region]
        # print(other_facial_region_group)
        for i in range(facial_region_mask.size(0)):  # iterate each map in BS
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            # mask_change_to = 1 if num_mask_to_change >= 0 else 0
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

            if mask_change_to == 1:
                # proportionally mask patches in other facial regions according to the corresponding ratio
                mask_ratio_other_fr = (
                        num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))

                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(),
                                                  kernel_size=self.patch_size)
                        # skin&hair or skin&bg defined as facial boundaries:
                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(to_mask_patches,
                                                            F.max_pool2d((parsing_map[i].unsqueeze(
                                                                0) == facial_region_index).float(),
                                                                         kernel_size=self.patch_size))

                    # ignore already masked patches:
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                                     :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    img_mask[i, select_indices[change_indices]] = mask_change_to
                    # prevent overlap
                    masked_patches = masked_patches + to_mask_patches.float()

                # mask/unmask patches from other facial regions to get img_mask with a fixed size
                num_mask_to_change = (self.mask_ratio * self.num_patches - img_mask[i].sum(dim=-1)).int()
                # mask_change_to = 1 if num_mask_to_change >= 0 else 0
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
                # prevent unmasking facial_region_mask
                select_indices = ((img_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(
                    as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to

            else:
                # Extreme situations:
                # if fr_mask is already over (>=) num_patches*mask_ratio, unmask it to get img_mask with a fixed ratio
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = img_mask[i]

        return img_mask, facial_region_mask


class FaceParsingDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root_dir = root
        self.transform = transform
        self.image_folder = os.path.join(root, 'images')
        self.parsing_map_folder = os.path.join(root, 'parsing_maps')
        self.image_names = os.listdir(self.image_folder)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_names[idx])
        parsing_map_name = os.path.join(self.parsing_map_folder, self.image_names[idx].replace('.png', '.npy'))

        image = Image.open(img_name).convert("RGB")
        parsing_map_np = np.load(parsing_map_name)

        if self.transform:
            image = self.transform(image)

        # Convert mask to tensor
        parsing_map = torch.from_numpy(parsing_map_np)
        del parsing_map_np  # may save mem

        return {'image': image, 'parsing_map': parsing_map}


class TestImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None):
        super(TestImageFolder, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        # Call the parent class method to load image and label
        original_tuple = super(TestImageFolder, self).__getitem__(index)

        # Get the video name
        video_name = self.imgs[index][0].split('/')[-1].split('_frame_')[0]  # the separator of video name

        # Extend the tuple to include video name
        extended_tuple = (original_tuple + (video_name,))

        return extended_tuple


def get_mean_std(args):
    print('dataset_paths:', args.data_path)
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((args.input_size, args.input_size),
                                                      interpolation=transforms.InterpolationMode.BICUBIC)])

    if len(args.data_path) > 1:
        pretrain_datasets = [FaceParsingDataset(root=path, transform=transform) for path in args.data_path]
        dataset_pretrain = ConcatDataset(pretrain_datasets)
    else:
        pretrain_datasets = args.data_path[0]
        dataset_pretrain = FaceParsingDataset(root=pretrain_datasets, transform=transform)

    print('Compute mean and variance for pretraining data.')
    print('len(dataset_train): ', len(dataset_pretrain))

    loader = DataLoader(
        dataset_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    for sample in loader:
        data = sample['image']
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    print(f'train dataset mean: {mean.numpy()} std: {std.numpy()}')
    del pretrain_datasets, dataset_pretrain, loader
    return mean.numpy(), std.numpy()


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)
    if args.eval:
        # no loading of the training set
        root = os.path.join(args.data_path, 'test' if is_train else 'test')
        dataset = TestImageFolder(root, transform=transform)
    else:
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
    print(dataset)

    return dataset


def build_transform(is_train, args):
    if args.normalize_from_IMN:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD
        # print(f'mean:{mean}, std:{std}')
    else:
        if not os.path.exists(os.path.join(args.output_dir, 'pretrain_ds_mean_std.txt')) and not args.eval:
            shutil.copyfile(os.path.dirname(args.finetune) + '/pretrain_ds_mean_std.txt',
                            os.path.join(args.output_dir) + '/pretrain_ds_mean_std.txt')
        with open(os.path.join(os.path.dirname(args.resume)) + '/pretrain_ds_mean_std.txt' if args.eval
                  else os.path.join(args.output_dir) + '/pretrain_ds_mean_std.txt', 'r') as file:
            ds_stat = json.loads(file.readline())
        mean = ds_stat['mean']
        std = ds_stat['std']
        # print(f'mean:{mean}, std:{std}')

    if args.apply_simple_augment:
        if is_train:
            # this should always dispatch to transforms_imagenet_train
            transform = create_transform(
                input_size=args.input_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation=transforms.InterpolationMode.BICUBIC,
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=mean,
                std=std,
            )
            return transform

        # no augment / eval transform
        t = []
        if args.input_size <= 224:
            crop_pct = 224 / 256
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # 256
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))  # 224

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)

    else:
        t = []
        if args.input_size < 224:
            crop_pct = args.input_size / 224
        else:
            crop_pct = 1.0
        size = int(args.input_size / crop_pct)  # size = 224
        t.append(
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            # to maintain same ratio w.r.t. 224 images
        )
        # t.append(
        #     transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        #     # to maintain same ratio w.r.t. 224 images
        # )

        t.append(transforms.ToTensor())
        t.append(transforms.Normalize(mean, std))
        return transforms.Compose(t)
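A sketch of how FaceParsingDataset and collate_fn_crfrp plug into a DataLoader. The dataset root path is a placeholder; a real directory must contain the images/ (.png) and parsing_maps/ (.npy, 224x224) subfolders the class expects, and the transform choices are illustrative.

# usage sketch (placeholder path; illustrative transforms)
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from util.datasets import FaceParsingDataset, collate_fn_crfrp

dataset = FaceParsingDataset(
    root='/path/to/pretrain_data',  # hypothetical root with images/ and parsing_maps/
    transform=transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))]),
)
loader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=collate_fn_crfrp(input_size=224, patch_size=16, mask_ratio=0.75),
)
batch = next(iter(loader))
print(batch['image'].shape)                         # [8, 3, 224, 224]
print(batch['img_mask'].shape)                      # [8, 196], CRFR-P mask at ratio 0.75
print(batch['specific_facial_region_mask'].shape)   # [8, 196], fully masked facial region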
util/lars.py
ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
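A usage sketch for the LARS optimizer above, as typically paired with a simple linear head; the model and hyperparameters are illustrative, not values taken from this commit.

# usage sketch (illustrative model and hyperparameters)
import torch
import torch.nn as nn
from util.lars import LARS

head = nn.Linear(768, 2)  # 1-D bias is exempt from weight decay and rate scaling, per the class docstring
optimizer = LARS(head.parameters(), lr=0.1, weight_decay=0.0)

features = torch.randn(32, 768)
labels = torch.randint(0, 2, (32,))

loss = nn.CrossEntropyLoss()(head(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()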
util/loss_contrastive.py
ADDED
@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from __future__ import print_function

import torch
import torch.nn as nn
import math


class SimSiamLoss(nn.Module):
    def __init__(self):
        super(SimSiamLoss, self).__init__()
        self.criterion = nn.CosineSimilarity(dim=1)

    def forward(self, cl_features):

        if len(cl_features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(cl_features.shape) > 3:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        cl_features_1 = cl_features[:, 0]  # [BS, feat_cl_dim]
        cl_features_2 = cl_features[:, 1]  # [BS, feat_cl_dim]
        loss = -(self.criterion(cl_features_1, cl_features_2).mean()) * 0.5

        # if not math.isfinite(loss):
        #     print(cl_features_1, '\n', cl_features_2)
        #     print(self.criterion(cl_features_1, cl_features_2))

        return loss


class BYOLLoss(nn.Module):
    def __init__(self):
        super(BYOLLoss, self).__init__()

    @staticmethod
    def forward(cl_features):

        if len(cl_features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(cl_features.shape) > 3:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        cl_features_1 = cl_features[:, 0]  # [BS, feat_cl_dim]
        cl_features_2 = cl_features[:, 1]  # [BS, feat_cl_dim]
        loss = 2 - 2 * (cl_features_1 * cl_features_2).sum(dim=-1)
        # loss = 1 - (cl_features_1 * cl_features_2).sum(dim=-1)
        loss = loss.mean()

        if not math.isfinite(loss):
            print(cl_features_1, '\n', cl_features_2)
            print(2 - 2 * (cl_features_1 * cl_features_2).sum(dim=-1))

        return loss


# different implementations of InfoNCELoss, including MOCOV3Loss; SupConLoss
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.1, contrast_sample='all'):
        """
        from CMAE: https://github.com/ZhichengHuang/CMAE/issues/5
        :param temperature: 0.1 0.5 1.0, 1.5 2.0
        """
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()
        self.contrast_sample = contrast_sample

    def forward(self, cl_features):
        """
        Args:
            :param cl_features: hidden vector of shape [bsz, n_views, ...]
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if cl_features.is_cuda
                  else torch.device('cpu'))

        if len(cl_features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(cl_features.shape) > 3:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        cl_features_1 = cl_features[:, 0]  # [BS, feat_cl_dim]
        cl_features_2 = cl_features[:, 1]  # [BS, feat_cl_dim]
        score_all = torch.matmul(cl_features_1, cl_features_2.transpose(1, 0))  # [BS, BS]
        score_all = score_all / self.temperature
        bs = score_all.size(0)

        if self.contrast_sample == 'all':
            score = score_all
        elif self.contrast_sample == 'positive':
            mask = torch.eye(bs, dtype=torch.float).to(device)  # torch.Size([BS, BS])
            score = score_all * mask
        else:
            raise ValueError('Contrastive sample: all{pos&neg} or positive(positive)')

        # label = (torch.arange(bs, dtype=torch.long) +
        #          bs * torch.distributed.get_rank()).to(device)
        label = torch.arange(bs, dtype=torch.long).to(device)

        loss = 2 * self.temperature * self.criterion(score, label)

        if not math.isfinite(loss):
            print(cl_features_1, '\n', cl_features_2)
            print(score_all, '\n', score, '\n', mask)

        return loss


class MOCOV3Loss(nn.Module):
    def __init__(self, temperature=0.1):
        super(MOCOV3Loss, self).__init__()
        self.temperature = temperature

    def forward(self, cl_features):

        if len(cl_features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(cl_features.shape) > 3:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        cl_features_1 = cl_features[:, 0]  # [BS, feat_cl_dim]
        cl_features_2 = cl_features[:, 1]  # [BS, feat_cl_dim]

        # normalize
        cl_features_1 = nn.functional.normalize(cl_features_1, dim=1)
        cl_features_2 = nn.functional.normalize(cl_features_2, dim=1)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [cl_features_1, cl_features_2]) / self.temperature
        N = logits.shape[0]
        labels = (torch.arange(N, dtype=torch.long)).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.temperature)


class SupConLoss(nn.Module):
    """
    from: https://github.com/HobbitLong/SupContrast
    Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.1, contrast_mode='all', contrast_sample='all',
                 base_temperature=0.1):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.contrast_sample = contrast_sample
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)  # torch.Size([BS, BS])
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]  # contrast_count(2)
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)  # [BS*contrast_count, D]
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]  # [BS, D]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature  # [BS*contrast_count, D]
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)  # [BS*contrast_count, BS*contrast_count]
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)  # [BS*contrast_count, 1]
        logits = anchor_dot_contrast - logits_max.detach()  # [BS*contrast_count, BS*contrast_count]

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)  # [BS*anchor_count, BS*contrast_count]
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )  # [BS*anchor_count, BS*contrast_count]
        mask = mask * logits_mask  # [BS*anchor_count, BS*contrast_count]

        """
        logits_mask is used to get the denominator(positives and negatives).
        mask is used to get the numerator(positives). mask is applied to log_prob.
        """

        # compute log_prob, logits_mask is contrast anchor with both positives and negatives
        exp_logits = torch.exp(logits) * logits_mask  # [BS*anchor_count, BS*contrast_count]
        # compute log_prob, logits_mask is contrast anchor with negatives, i.e., denominator only negatives contrast:
        # exp_logits = torch.exp(logits) * (logits_mask-mask)

        if self.contrast_sample == 'all':
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))  # [BS*anchor_count, BS*anchor_count]
            # compute mean of log-likelihood over positive
            mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)  # [BS*anchor_count]
        elif self.contrast_sample == 'positive':
            mean_log_prob_pos = (mask * logits).sum(1) / mask.sum(1)
        else:
            raise ValueError('Contrastive sample: all{pos&neg} or positive(positive)')

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss


class InfoNCELossPatchLevel(nn.Module):
    """
    test: ref ConMIM: https://github.com/TencentARC/ConMIM.
    """
    def __init__(self, temperature=0.1, contrast_sample='all'):
        """
        :param temperature: 0.1 0.5 1.0, 1.5 2.0
        """
        super(InfoNCELossPatchLevel, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()
        self.contrast_sample = contrast_sample

        self.facial_region_group = [
            [2, 3],  # eyebrows
            [4, 5],  # eyes
            [6],  # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],  # hair
            [1],  # facial skin
            [0]  # background
        ]

    def forward(self, cl_features, parsing_map=None):
        """
        Args:
            :param parsing_map:
            :param cl_features: hidden vector of shape [bsz, n_views, ...]
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if cl_features.is_cuda
                  else torch.device('cpu'))

        if len(cl_features.shape) < 4:
            raise ValueError('`features` needs to be [bsz, n_views, n_cl_patches, ...],'
                             'at least 4 dimensions are required')
        if len(cl_features.shape) > 4:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], cl_features.shape[2], -1)
            # [BS, 2, num_cl_patches, feat_cl_dim]

        cl_features_1 = cl_features[:, 0]
        cl_features_2 = cl_features[:, 1]
        score = torch.matmul(cl_features_1, cl_features_2.permute(0, 2, 1))  # [BS, num_cl_patches, num_cl_patches]
        score = score / self.temperature
        bs = score.size(0)
        num_cl_patches = score.size(1)

        if self.contrast_sample == 'all':
            score = score
        elif self.contrast_sample == 'positive':
            mask = torch.eye(num_cl_patches, dtype=torch.float32)  # torch.Size([num_cl_patches, num_cl_patches])
            mask_batch = mask.unsqueeze(0).expand(bs, -1).to(device)  # [bs, num_cl_patches, num_cl_patches]
            score = score * mask_batch
        elif self.contrast_sample == 'region':
            cl_features_1_fr = []
            cl_features_2_fr = []
            for facial_region_index in self.facial_region_group:
                fr_mask = (parsing_map == facial_region_index).unsqueeze(2).expand(-1, -1, cl_features_1.size(-1))
                cl_features_1_fr.append((cl_features_1 * fr_mask).mean(dim=1, keepdim=False))
                cl_features_2_fr.append((cl_features_1 * fr_mask).mean(dim=1, keepdim=False))
            cl_features_1_fr = torch.stack(cl_features_1_fr, dim=1)
            cl_features_2_fr = torch.stack(cl_features_2_fr, dim=1)
            score = torch.matmul(cl_features_1_fr, cl_features_2_fr.permute(0, 2, 1))  # [BS, 8, 8]
            score = score / self.temperature
            # mask = torch.eye(cl_features_1_fr.size(1), dtype=torch.bool)
            # torch.Size([cl_features_1_fr.size(1), cl_features_1_fr.size(1)])
            # mask_batch = mask.unsqueeze(0).expand(bs, -1).to(device)
            # [bs, cl_features_1_fr.size(1), cl_features_1_fr.size(1)]
            # score = score*mask_batch
            label = torch.arange(cl_features_1_fr.size(1), dtype=torch.long).to(device)
            labels_batch = label.unsqueeze(0).expand(bs, -1)
            loss = 2 * self.temperature * self.criterion(score, labels_batch)
            return loss
        else:
            raise ValueError('Contrastive sample: all{pos&neg} or positive(positive)')

        # label = (torch.arange(bs, dtype=torch.long) +
        #          bs * torch.distributed.get_rank()).to(device)
        label = torch.arange(num_cl_patches, dtype=torch.long).to(device)
        labels_batch = label.unsqueeze(0).expand(bs, -1)

        loss = 2 * self.temperature * self.criterion(score, labels_batch)

        return loss


class MSELoss(nn.Module):
    """
    test: unused
    """
    def __init__(self):
        super(MSELoss, self).__init__()

    @staticmethod
    def forward(cl_features):

        if len(cl_features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, n_patches, ...],'
                             'at least 3 dimensions are required')
        if len(cl_features.shape) > 3:
            cl_features = cl_features.view(cl_features.shape[0], cl_features.shape[1], -1)  # [BS, 2, feat_cl_dim]

        cl_features_1 = cl_features[:, 0].float()  # [BS, feat_cl_dim]
        cl_features_2 = cl_features[:, 1].float()  # [BS, feat_cl_dim]

        return torch.nn.functional.mse_loss(cl_features_1, cl_features_2, reduction='mean')
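A small sketch showing the input convention shared by the image-level losses above: two views of the same batch stacked along dim=1. The batch size, feature dimension, and the L2 normalization step are illustrative choices, not prescribed by this commit.

# usage sketch (illustrative shapes; features normalized before the cosine-style scores)
import torch
import torch.nn.functional as F
from util.loss_contrastive import SimSiamLoss, InfoNCELoss

view1 = F.normalize(torch.randn(16, 256), dim=-1)
view2 = F.normalize(torch.randn(16, 256), dim=-1)
cl_features = torch.stack([view1, view2], dim=1)   # [bsz, n_views=2, feat_cl_dim]

print(SimSiamLoss()(cl_features))                                         # negative cosine similarity
print(InfoNCELoss(temperature=0.1, contrast_sample='all')(cl_features))   # cross-view InfoNCE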
util/lr_decay.py
ADDED
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = len(model.blocks) + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers
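A sketch of feeding param_groups_lrd into an optimizer for fine-tuning. It assumes a ViT-style backbone exposing a .blocks ModuleList and a no_weight_decay() set, as timm ViTs do; the model name and hyperparameters are illustrative.

# usage sketch (assumes a timm ViT-style backbone; values are illustrative)
import torch
import timm
from util.lr_decay import param_groups_lrd

model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2)

param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),  # e.g. {'pos_embed', 'cls_token', ...}
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# Deeper blocks keep an lr_scale closer to 1; patch_embed/cls_token/pos_embed get the smallest.
# adjust_learning_rate in util/lr_sched.py multiplies the base lr by these lr_scale factors.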
util/lr_sched.py
ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import math
import numpy as np


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule
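A brief sketch of calling adjust_learning_rate during training. The argparse-style namespace only mirrors the fields the function reads, and the values are illustrative.

# usage sketch (illustrative hyperparameters)
import argparse
import torch
from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1.5e-4, min_lr=1e-6, warmup_epochs=5, epochs=50)
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in range(args.epochs):
    # a fractional epoch (epoch + step / steps_per_epoch) can also be passed for per-iteration decay
    lr = adjust_learning_rate(optimizer, epoch, args)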
util/metrics.py
ADDED
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Author: Gaojian Wang@ZJUICSR
+# --------------------------------------------------------
+# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
+# You can find the license in the LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+
+from sklearn.metrics import roc_auc_score
+from sklearn.metrics import roc_curve
+from sklearn.metrics import auc, accuracy_score, balanced_accuracy_score
+from scipy.optimize import brentq
+from scipy.interpolate import interp1d
+
+
+def frame_level_acc(labels, y_preds):
+    return accuracy_score(labels, y_preds) * 100.
+
+
+def frame_level_balanced_acc(labels, y_preds):
+    return balanced_accuracy_score(labels, y_preds) * 100.
+
+
+def frame_level_auc(labels, preds):
+    return roc_auc_score(labels, preds) * 100.
+
+
+def frame_level_eer(labels, preds):
+    # recommended; more accurate, as used in MaskRelation (TIFS 2023)
+    fpr, tpr, thresholds = roc_curve(labels, preds, pos_label=1)  # pos_label must be given explicitly if the labels are not binary
+    eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+    # eer_thresh = interp1d(fpr, thresholds)(eer)
+    return eer
+
+
+# def frame_level_eer(labels, preds):
+#     fpr, tpr, thresholds = roc_curve(labels, preds, pos_label=1)
+#     eer_threshold = thresholds[(fpr + (1 - tpr)).argmin()]
+#     fpr_eer = fpr[thresholds == eer_threshold][0]
+#     fnr_eer = 1 - tpr[thresholds == eer_threshold][0]
+#     eer = (fpr_eer + fnr_eer) / 2
+#     metric_logger.meters['eer'].update(eer)
+#     return eer, eer_thresh
+
+
+def get_video_level_label_pred(f_label_list, v_name_list, f_pred_list):
+    """
+    References:
+    CADDM: https://github.com/megvii-research/CADDM
+    """
+    video_res_dict = dict()
+    video_pred_list = list()
+    video_y_pred_list = list()
+    video_label_list = list()
+    # summarize all the results for each video
+    for label, video, score in zip(f_label_list, v_name_list, f_pred_list):
+        if video not in video_res_dict.keys():
+            video_res_dict[video] = {"scores": [score], "label": label}
+        else:
+            video_res_dict[video]["scores"].append(score)
+    # get the score and label for each video
+    for video, res in video_res_dict.items():
+        score = sum(res['scores']) / len(res['scores'])
+        label = res['label']
+        video_pred_list.append(score)
+        video_label_list.append(label)
+        video_y_pred_list.append(score >= 0.5)
+
+    return video_label_list, video_pred_list, video_y_pred_list
+
+
+def video_level_acc(video_label_list, video_y_pred_list):
+    return accuracy_score(video_label_list, video_y_pred_list) * 100.
+
+
+def video_level_balanced_acc(video_label_list, video_y_pred_list):
+    return balanced_accuracy_score(video_label_list, video_y_pred_list) * 100.
+
+
+def video_level_auc(video_label_list, video_pred_list):
+    return roc_auc_score(video_label_list, video_pred_list) * 100.
+
+
+def video_level_eer(video_label_list, video_pred_list):
+    # recommended; more accurate, as used in MaskRelation (TIFS 2023)
+    fpr, tpr, thresholds = roc_curve(video_label_list, video_pred_list, pos_label=1)  # pos_label must be given explicitly if the labels are not binary
+    v_eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+    # eer_thresh = interp1d(fpr, thresholds)(eer)
+    return v_eer
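Usage note (not part of the committed diff): a minimal sketch of aggregating frame-level scores to video level and scoring them with the helpers above. The labels, video names, and probabilities are toy values; the import assumes the repository root is on PYTHONPATH.

    from util.metrics import (frame_level_auc, get_video_level_label_pred,
                              video_level_acc, video_level_auc)

    frame_labels = [0, 0, 1, 1, 1, 0]                 # binary frame labels (1 = positive class)
    video_names  = ['a', 'a', 'b', 'b', 'b', 'c']     # which video each frame belongs to
    frame_probs  = [0.1, 0.3, 0.8, 0.6, 0.9, 0.4]     # per-frame positive-class probability

    print('frame-level AUC: %.2f' % frame_level_auc(frame_labels, frame_probs))

    v_labels, v_probs, v_preds = get_video_level_label_pred(frame_labels, video_names, frame_probs)
    print('video-level AUC: %.2f' % video_level_auc(v_labels, v_probs))
    print('video-level ACC: %.2f' % video_level_acc(v_labels, v_preds))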
util/misc.py
ADDED
@@ -0,0 +1,390 @@
+# -*- coding: utf-8 -*-
+# Author: Gaojian Wang@ZJUICSR
+# --------------------------------------------------------
+# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
+# You can find the license in the LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch._six import inf
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    builtin_print = builtins.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        force = force or (get_world_size() > 8)
+        if is_master or force:
+            now = datetime.datetime.now().time()
+            builtin_print('[{}] '.format(now), end='')  # print with time stamp
+            builtin_print(*args, **kwargs)
+
+    builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if args.dist_on_itp:
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
+                                norm_type)
+    return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, tag=None):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    if loss_scaler is not None:
+        if tag is None:
+            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
+        else:
+            checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % tag)]
+        for checkpoint_path in checkpoint_paths:
+            to_save = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'epoch': epoch,
+                'scaler': loss_scaler.state_dict(),
+                'args': args,
+            }
+
+            save_on_master(to_save, checkpoint_path)
+    else:
+        client_state = {'epoch': epoch}
+        if tag is None:
+            model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name,
+                                  client_state=client_state)
+        else:
+            model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % tag,
+                                  client_state=client_state)
+
+
+def save_model_target_encoder(args, epoch, model, model_target_encoder_without_ddp, optimizer, loss_scaler, tag=None):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    if loss_scaler is not None:
+        if tag is None:
+            checkpoint_paths = [output_dir / ('checkpoint-te-%s.pth' % epoch_name)]
+        else:
+            checkpoint_paths = [output_dir / ('checkpoint-te-%s.pth' % tag)]
+        for checkpoint_path in checkpoint_paths:
+            to_save = {
+                'model': model_target_encoder_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'epoch': epoch,
+                'scaler': loss_scaler.state_dict(),
+                'args': args,
+            }
+
+            save_on_master(to_save, checkpoint_path)
+    else:
+        client_state = {'epoch': epoch}
+        if tag is None:
+            model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-te-%s" % epoch_name,
+                                  client_state=client_state)
+        else:
+            model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-te-%s" % tag,
+                                  client_state=client_state)
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+            print("With optim & sched!")
+
+
+def load_model_target_encoder(args, model_target_encoder_without_ddp, optimizer, loss_scaler):
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume_target_encoder, map_location='cpu')
+        model_target_encoder_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume_target_encoder)
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+            print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
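Usage note (not part of the committed diff): a minimal sketch of wiring MetricLogger and NativeScalerWithGradNormCount into a toy training loop. The model, data, and hyperparameters are placeholders; without CUDA the underlying GradScaler simply disables itself, and the import assumes the repository root is on PYTHONPATH.

    import torch

    from util.misc import MetricLogger, NativeScalerWithGradNormCount

    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    loss_scaler = NativeScalerWithGradNormCount()

    data = [(torch.randn(4, 16), torch.randint(0, 2, (4,))) for _ in range(10)]
    metric_logger = MetricLogger(delimiter="  ")

    for samples, targets in metric_logger.log_every(data, print_freq=5, header='Epoch: [0]'):
        loss = criterion(model(samples), targets)
        optimizer.zero_grad()
        # backward + optional gradient clipping + optimizer step, all through the scaler
        grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
        metric_logger.update(loss=loss.item(), grad_norm=grad_norm)

    print('Averaged stats:', metric_logger)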
util/pos_embed.py
ADDED
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Author: Gaojian Wang@ZJUICSR
+# --------------------------------------------------------
+# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
+# You can find the license in the LICENSE file in the root directory of this source tree.
+# -------------------------------------------------------------
+
+import numpy as np
+import torch
+
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def interpolate_pos_embed_ema(model, checkpoint_model):
+    if checkpoint_model.pos_embed is not None:
+        pos_embed_checkpoint = checkpoint_model.pos_embed
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model.pos_embed = new_pos_embed
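Usage note (not part of the committed diff): a minimal sketch of building the fixed 2D sin-cos table for a ViT-Base-style encoder (224x224 input, 16x16 patches, 768-dim tokens); the sizes are illustrative assumptions and the import assumes the repository root is on PYTHONPATH.

    import torch

    from util.pos_embed import get_2d_sincos_pos_embed

    embed_dim, grid_size = 768, 14          # 224 / 16 = 14 patches per side
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
    print(pos_embed.shape)                  # (1 + 14 * 14, 768) = (197, 768)

    # Typically copied into the model's (frozen) pos_embed parameter, e.g.:
    # model.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
    pos_embed_t = torch.from_numpy(pos_embed).float().unsqueeze(0)   # shape (1, 197, 768)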