from typing import Optional, Dict, Any
import functools

import torch
import torch.nn.functional as F

from ..util import download_jit
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix,
                         get_face_align_matrix_celebm,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from .base import FaceParser


pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448),
                                           target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8,
                                         warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8,
                                             warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair'],
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
                                           target_shape=(448, 448)),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0,
                                         warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0,
                                             warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l'],
    },
}
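# Label abbreviations follow the LaPa / CelebAMask-HQ conventions: 'rb'/'lb'
# are the right/left brows, 're'/'le' the right/left eyes, and
# 'ulip'/'imouth'/'llip' the upper lip, inner mouth, and lower lip; in
# 'celebm/448', 'rr'/'lr' are the ears, 'eyeg' eyeglasses, 'earr' an earring,
# and 'neck_l' a necklace.
#
# Illustrative sketch (not part of the released settings): a hypothetical new
# configuration would follow the same recipe, pairing a TorchScript checkpoint
# URL with the alignment and warp functions consumed by FaRLFaceParser.forward:
#
#   pretrain_settings['mydataset/448'] = {
#       'url': ['https://example.com/my_face_parsing_checkpoint.pt'],
#       'matrix_src_tag': 'points',
#       'get_matrix_fn': functools.partial(get_face_align_matrix,
#                                          target_shape=(448, 448),
#                                          target_face_scale=1.0),
#       'get_grid_fn': functools.partial(make_tanh_warp_grid,
#                                        warp_factor=0.8, warped_shape=(448, 448)),
#       'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
#                                            warp_factor=0.8, warped_shape=(448, 448)),
#       'label_names': ['background', 'face', 'hair'],  # hypothetical labels
#   }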
|
|
class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing:
    ```bibtex
    @article{zheng2021farl,
        title={General Facial Representation Learning in a Visual-Linguistic Manner},
        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and
                Chen, Dongdong and Huang, Yangyu and Yuan, Lu and
                Chen, Dong and Zeng, Ming and Wen, Fang},
        journal={arXiv preprint arXiv:2112.03109},
        year={2021}
    }
    ```
    """
|
    def __init__(self, conf_name: Optional[str] = None,
                 model_path: Optional[str] = None, device=None) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        # Download (if necessary) and load the TorchScript checkpoint onto
        # the requested device; the module is used for inference only.
        self.net = download_jit(model_path, map_location=device)
        self.eval()
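    # Construction sketch (an illustration, not part of the original module;
    # the mirror URL is hypothetical):
    #   parser = FaRLFaceParser()                      # defaults to 'lapa/448'
    #   parser = FaRLFaceParser('celebm/448', device='cuda')
    #   parser = FaRLFaceParser('lapa/448',
    #       model_path='https://mirror.example.com/face_parsing.farl.lapa.pt')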
|
    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0  # uint8 [0, 255] -> float [0, 1]
        _, _, h, w = images.shape

        # Select the source image for each detected face, then derive the
        # alignment matrix and the forward/inverse tanh-warp sampling grids
        # from the face landmarks.
        simages = images[data['image_ids']]
        matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        # Warp each face into the aligned 448x448 space the network expects.
        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        w_seg_logits, _ = self.net(w_images)

        # Warp the logits back so they align pixel-wise with the input images.
        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits,
                       'label_names': setting['label_names']}
        return data
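
# End-to-end usage sketch (an assumption based on the facer top-level API,
# which wraps this class; the image path is illustrative):
#
#   import torch
#   import facer
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   image = facer.hwc2bchw(facer.read_hwc('example.jpg')).to(device)  # uint8, b x c x h x w
#   detector = facer.face_detector('retinaface/mobilenet', device=device)
#   parser = facer.face_parser('farl/lapa/448', device=device)
#   with torch.inference_mode():
#       faces = detector(image)       # yields 'rects', 'points', 'image_ids', ...
#       faces = parser(image, faces)  # adds 'seg' with 'logits' and 'label_names'
#   per_pixel_probs = faces['seg']['logits'].softmax(dim=1)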
|
|