import functools
import io
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F

from .base import FaceAlignment
from .network import (FaRLVisualFeatures, MMSEG_UPerHead,
                      FaceAlignmentTransformer, denormalize_points,
                      heatmap2points)
from ..transform import (get_face_align_matrix,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from ..util import download_jit

pretrain_settings = {
    'ibug300w/448': {
        # inter_ocular 0.028835 epoch 60
        'num_classes': 68,
        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.ibug300w.main_ema_jit.pt",
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448),
                                           target_face_scale=0.8),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.0,
                                         warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.0,
                                             warped_shape=(448, 448)),
    },
    'aflw19/448': {
        # diag 0.009329 epoch 15
        'num_classes': 19,
        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.aflw19.main_ema_jit.pt",
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448),
                                           target_face_scale=0.8),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.0,
                                         warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.0,
                                             warped_shape=(448, 448)),
    },
    'wflw/448': {
        # inter_ocular 0.038933 epoch 20
        'num_classes': 98,
        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.wflw.main_ema_jit.pt",
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448),
                                           target_face_scale=0.8),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.0,
                                         warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.0,
                                             warped_shape=(448, 448)),
    },
}
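
# Each entry above bundles everything needed to run one pretrained variant:
# the landmark count, the TorchScript checkpoint URL, and three partials that
# build (a) the face-alignment matrix from detector landmarks ('points'),
# (b) the tanh-warp sampling grid that crops each face to 448x448, and (c)
# the inverse grid used to map predicted heatmaps back into the original
# image. With warp_factor=0.0 the tanh warp presumably degenerates to a
# plain affine crop.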


def load_face_alignment_model(model_path: str, num_classes: int = 68):
    """Build the FaRL alignment network, loading weights from either a
    TorchScript export (paths containing "jit") or a training checkpoint."""
    backbone = FaRLVisualFeatures("base", None, forced_input_resolution=448,
                                  output_indices=None).cpu()
    if "jit" in model_path:
        # The TorchScript archive holds the traced heatmap head; the backbone
        # state dict is embedded alongside it via _extra_files (see __main__).
        extra_files = {"backbone": None}
        heatmap_head = download_jit(model_path, map_location="cpu",
                                    _extra_files=extra_files)
        backbone_weight_io = io.BytesIO(extra_files["backbone"])
        backbone.load_state_dict(torch.load(backbone_weight_io))
        state = None
    else:
        # Rebuild the UPerNet head (this branch requires mmseg as a
        # dependency) and read the EMA weights from the training checkpoint.
        channels = backbone.get_output_channel("base")
        heatmap_head = MMSEG_UPerHead(in_channels=[channels] * 4,
                                      channels=channels,
                                      num_classes=num_classes)
        state = torch.load(model_path, map_location="cpu")["networks"]["main_ema"]
    main_network = FaceAlignmentTransformer(
        backbone, heatmap_head, heatmap_act="sigmoid").cpu()
    if state is not None:
        main_network.load_state_dict(state, strict=True)
    return main_network
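
# Illustrative direct use (a sketch; in normal use FaRLFaceAlignment below
# handles model construction and device placement):
#
#   net = load_face_alignment_model(
#       pretrain_settings['ibug300w/448']['url'], num_classes=68)
#   heatmaps = net(torch.randn(1, 3, 448, 448))  # sigmoid heatmaps, one
#                                                # channel per landmark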


class FaRLFaceAlignment(FaceAlignment):
    """The face alignment models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing:

    ```bibtex
    @article{zheng2021farl,
      title={General Facial Representation Learning in a Visual-Linguistic Manner},
      author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and
              Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and
              Zeng, Ming and Wen, Fang},
      journal={arXiv preprint arXiv:2112.03109},
      year={2021}
    }
    ```
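
    Example (an illustrative sketch; `points` would normally be the landmark
    tensor produced by a facer face detector for each detected face):

        aligner = FaRLFaceAlignment('ibug300w/448', device='cpu')
        images = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)
        data = {'image_ids': torch.tensor([0, 1]), 'points': points}
        data = aligner(images, data)
        landmarks = data['alignment']  # one (68, 2) landmark set per face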
"""

    def __init__(self, conf_name: Optional[str] = None,
                 model_path: Optional[str] = None, device=None) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = 'ibug300w/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        setting = pretrain_settings[self.conf_name]
        self.net = load_face_alignment_model(
            model_path, num_classes=setting["num_classes"])
        if device is not None:
            self.net = self.net.to(device)
        self.heatmap_interpolate_mode = 'bilinear'
        self.eval()
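
    # Valid conf_name values are the keys of pretrain_settings above:
    # 'ibug300w/448' (68 points), 'aflw19/448' (19), 'wflw/448' (98).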

    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0  # the backbone applies its own normalization
        _, _, h, w = images.shape

        # Pick the source image for each detected face, then build the warp
        # grids that crop each face to the 448x448 network input.
        simages = images[data['image_ids']]
        matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        _, _, warp_h, warp_w = w_images.shape
        heatmap_acted = self.net(w_images)

        # Upsample the heatmaps to the warped-image resolution, then map them
        # back into the original image frame with the inverse grid.
        warped_heatmap = F.interpolate(
            heatmap_acted, size=(warp_h, warp_w),
            mode=self.heatmap_interpolate_mode, align_corners=False)
        pred_heatmap = F.grid_sample(
            warped_heatmap, inv_grid, mode='bilinear', align_corners=False)

        # Decode heatmaps into normalized landmark coordinates, then rescale
        # them to pixel coordinates of the original images.
        landmark = heatmap2points(pred_heatmap)
        landmark = denormalize_points(landmark, h, w)

        data['alignment'] = landmark
        return data
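
# A typical end-to-end pipeline (sketch; 'image_ids' and 'points' are filled
# in by a facer face detector upstream):
#
#   detections = detector(images)
#   detections = aligner(images, detections)
#   landmarks = detections['alignment']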


if __name__ == "__main__":
    # Smoke-test the checkpoint-loaded model, then optionally export it as a
    # single TorchScript file and verify the export reproduces its output.
    image = torch.randn(1, 3, 448, 448)
    aligner1 = FaRLFaceAlignment("wflw/448")
    x1 = aligner1.net(image)

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--jit_path", type=str, default=None)
    args = parser.parse_args()
    if args.jit_path is None:
        exit(0)

    net = aligner1.net.cpu()
    features, _ = net.backbone(image)

    # Use torch.jit.trace to turn the heatmap head into a torch.jit.ScriptModule;
    # the backbone is not traced, since its weights are reloaded separately.
    traced_script_module = torch.jit.trace(net.heatmap_head,
                                           example_inputs=[features])

    # Serialize the backbone state dict and embed it in the same TorchScript
    # archive via _extra_files, so a single file carries both components.
    buffer = io.BytesIO()
    torch.save(net.backbone.state_dict(), buffer)
    torch.jit.save(traced_script_module, args.jit_path,
                   _extra_files={"backbone": buffer.getvalue()})

    # Round-trip check: reload through the "jit" branch and compare outputs.
    aligner2 = FaRLFaceAlignment(model_path=args.jit_path)
    x2 = aligner2.net(image)
    print(torch.allclose(x1, x2))