import functools
import io
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F

from .network import (FaRLVisualFeatures, MMSEG_UPerHead,
                      FaceAlignmentTransformer, denormalize_points,
                      heatmap2points)
from ..transform import (get_face_align_matrix,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from .base import FaceAlignment
from ..util import download_jit
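
# Pretrained configurations, keyed by '<training dataset>/<crop size>'. Each
# entry records the number of predicted landmarks ('num_classes'), the
# released weight URL, and the partial functions that build the per-face
# alignment matrix, the tanh-warp grid used to crop faces to 448x448, and
# the inverse grid used to map predicted heatmaps back to the input image.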
pretrain_settings = {
'ibug300w/448': {
        # inter-ocular NME 0.028835 at epoch 60
'num_classes': 68,
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.ibug300w.main_ema_jit.pt",
'matrix_src_tag': 'points',
'get_matrix_fn': functools.partial(get_face_align_matrix,
target_shape=(448, 448), target_face_scale=0.8),
'get_grid_fn': functools.partial(make_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
},
'aflw19/448': {
        # diagonal NME 0.009329 at epoch 15
'num_classes': 19,
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.aflw19.main_ema_jit.pt",
'matrix_src_tag': 'points',
'get_matrix_fn': functools.partial(get_face_align_matrix,
target_shape=(448, 448), target_face_scale=0.8),
'get_grid_fn': functools.partial(make_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
},
'wflw/448': {
        # inter-ocular NME 0.038933 at epoch 20
'num_classes': 98,
'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.wflw.main_ema_jit.pt",
'matrix_src_tag': 'points',
'get_matrix_fn': functools.partial(get_face_align_matrix,
target_shape=(448, 448), target_face_scale=0.8),
'get_grid_fn': functools.partial(make_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
warp_factor=0.0, warped_shape=(448, 448)),
},
}


def load_face_alignment_model(model_path: str, num_classes: int = 68):
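    """Build a FaceAlignmentTransformer from either released weight format.

    A path/URL containing "jit" is treated as a TorchScript release archive:
    the archive itself is the traced heatmap head, and the backbone state
    dict travels in its _extra_files. Anything else is treated as a training
    checkpoint whose "main_ema" weights are loaded; that branch constructs
    an MMSEG_UPerHead and therefore requires mmseg.
    """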
backbone = FaRLVisualFeatures("base", None, forced_input_resolution=448, output_indices=None).cpu()
if "jit" in model_path:
extra_files = {"backbone": None}
heatmap_head = download_jit(model_path, map_location="cpu", _extra_files=extra_files)
backbone_weight_io = io.BytesIO(extra_files["backbone"])
backbone.load_state_dict(torch.load(backbone_weight_io))
# print("load from jit")
    else:
        # Training checkpoint: build the UPerNet head (requires mmseg) and
        # fetch the EMA weights, which are loaded into the full network below.
        channels = backbone.get_output_channel("base")
        in_channels = [channels] * 4
        heatmap_head = MMSEG_UPerHead(
            in_channels=in_channels, channels=channels, num_classes=num_classes)
        state = torch.load(model_path, map_location="cpu")["networks"]["main_ema"]
    main_network = FaceAlignmentTransformer(
        backbone, heatmap_head, heatmap_act="sigmoid").cpu()
    if "jit" not in model_path:
        # JIT archives were fully loaded above; checkpoints are loaded here.
        main_network.load_state_dict(state, strict=True)
    return main_network


class FaRLFaceAlignment(FaceAlignment):
""" The face alignment models from [FaRL](https://github.com/FacePerceiver/FaRL).
Please consider citing
```bibtex
    @article{zheng2021farl,
      title={General Facial Representation Learning in a Visual-Linguistic Manner},
      author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and
              Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and
              Zeng, Ming and Wen, Fang},
      journal={arXiv preprint arXiv:2112.03109},
      year={2021}
    }
```
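
    Example (a minimal sketch; the detector-style input dict with
    'image_ids' and 5-point 'points' is an assumption based on how
    forward() consumes `data`, not a documented contract):

        import torch

        aligner = FaRLFaceAlignment('ibug300w/448', device='cpu')
        images = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.uint8)
        data = {
            'image_ids': torch.tensor([0]),       # one face, from image 0
            'points': torch.rand(1, 5, 2) * 400,  # detector landmarks (x, y)
        }
        data = aligner(images, data)
        print(data['alignment'].shape)            # (1, 68, 2) in image pixels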
"""
def __init__(self, conf_name: Optional[str] = None,
model_path: Optional[str] = None, device=None) -> None:
super().__init__()
if conf_name is None:
conf_name = 'ibug300w/448'
if model_path is None:
model_path = pretrain_settings[conf_name]['url']
self.conf_name = conf_name
setting = pretrain_settings[self.conf_name]
        self.net = load_face_alignment_model(
            model_path, num_classes=setting["num_classes"])
if device is not None:
self.net = self.net.to(device)
self.heatmap_interpolate_mode = 'bilinear'
self.eval()

    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0  # the backbone applies its own normalization
        _, _, h, w = images.shape

        # Select the source image for each detected face and warp it to the
        # 448x448 crop defined by the per-face alignment matrix.
        simages = images[data['image_ids']]
        matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        # Predict per-landmark heatmaps on the crop, resize them to the crop
        # resolution, and warp them back into the original image frame.
        _, _, warp_h, warp_w = w_images.shape
        heatmap_acted = self.net(w_images)
        warped_heatmap = F.interpolate(
            heatmap_acted, size=(warp_h, warp_w),
            mode=self.heatmap_interpolate_mode, align_corners=False)
        pred_heatmap = F.grid_sample(
            warped_heatmap, inv_grid, mode='bilinear', align_corners=False)

        # Decode heatmaps to normalized points, then to pixel coordinates.
        landmark = heatmap2points(pred_heatmap)
        landmark = denormalize_points(landmark, h, w)
        data['alignment'] = landmark
        return data
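

# When run as a script, this module performs a quick smoke test and, if
# --jit_path is given, exports the released TorchScript format: the heatmap
# head is traced with torch.jit.trace and the backbone state dict is embedded
# in the archive's _extra_files, matching the "jit" branch of
# load_face_alignment_model above.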
if __name__=="__main__":
image = torch.randn(1, 3, 448, 448)
aligner1 = FaRLFaceAlignment("wflw/448")
x1 = aligner1.net(image)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--jit_path", type=str, default=None)
args = parser.parse_args()
if args.jit_path is None:
exit(0)
net = aligner1.net.cpu()
features, _ = net.backbone(image)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(net.heatmap_head, example_inputs=[features])
buffer = io.BytesIO()
torch.save(net.backbone.state_dict(), buffer)
    # Save the traced head, embedding the backbone weights as an extra file.
    torch.jit.save(traced_script_module, args.jit_path,
                   _extra_files={"backbone": buffer.getvalue()})
    # Reload from the exported archive and check the outputs match.
    aligner2 = FaRLFaceAlignment(model_path=args.jit_path)
    x2 = aligner2.net(image)
    print(torch.allclose(x1, x2))