|
from typing import Optional, Dict, Any
import functools

import numpy as np
import torch
import torch.nn.functional as F

from ..transform import get_face_align_matrix, make_tanh_warp_grid
from .base import FaceAttribute
from ..farl import farl_classification
from ..util import download_jit
|
|
def get_std_points_xray(out_size=256, mid_size=500):
    """Return the 5-point target landmark template used for face alignment.

    The canonical 256x256 template points are shifted down by 30 pixels,
    re-centered on a `mid_size` canvas, and rescaled to `out_size`.
    """
    std_points_256 = np.array(
        [
            [85.82991, 85.7792],
            [169.0532, 84.3381],
            [127.574, 137.0006],
            [90.6964, 174.7014],
            [167.3069, 173.3733],
        ]
    )
    std_points_256[:, 1] += 30
    old_size = 256
    mid = mid_size / 2
    new_std_points = std_points_256 - old_size / 2 + mid
    target_pts = new_std_points * out_size / mid_size
    target_pts = torch.from_numpy(target_pts).float()
    return target_pts
|
|
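# Each entry in `pretrain_settings` bundles everything needed to run one
# pretrained attribute head: the checkpoint URL, the `num_classes`/`layers`
# arguments forwarded to `farl_classification`, the attribute label names, and
# the functions that build the face-alignment matrix from `data["points"]`
# ("get_matrix_fn") and the tanh-warp sampling grid ("get_grid_fn") used to
# crop each face to 224x224.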
pretrain_settings = {
    "celeba/224": {
        "num_classes": 40,
        "layers": [11],
        "url": "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_attribute.farl.celeba.pt",
        "matrix_src_tag": "points",
        "get_matrix_fn": functools.partial(
            get_face_align_matrix,
            target_shape=(224, 224),
            target_pts=get_std_points_xray(out_size=224, mid_size=500),
        ),
        "get_grid_fn": functools.partial(
            make_tanh_warp_grid, warp_factor=0.0, warped_shape=(224, 224)
        ),
        "classes": [
            "5_o_Clock_Shadow",
            "Arched_Eyebrows",
            "Attractive",
            "Bags_Under_Eyes",
            "Bald",
            "Bangs",
            "Big_Lips",
            "Big_Nose",
            "Black_Hair",
            "Blond_Hair",
            "Blurry",
            "Brown_Hair",
            "Bushy_Eyebrows",
            "Chubby",
            "Double_Chin",
            "Eyeglasses",
            "Goatee",
            "Gray_Hair",
            "Heavy_Makeup",
            "High_Cheekbones",
            "Male",
            "Mouth_Slightly_Open",
            "Mustache",
            "Narrow_Eyes",
            "No_Beard",
            "Oval_Face",
            "Pale_Skin",
            "Pointy_Nose",
            "Receding_Hairline",
            "Rosy_Cheeks",
            "Sideburns",
            "Smiling",
            "Straight_Hair",
            "Wavy_Hair",
            "Wearing_Earrings",
            "Wearing_Hat",
            "Wearing_Lipstick",
            "Wearing_Necklace",
            "Wearing_Necktie",
            "Young",
        ],
    }
}
|
|
def load_face_attr(model_path, num_classes=40, layers=[11]):
    """Build a FaRL classification head and load pretrained weights from `model_path`."""
    model = farl_classification(num_classes=num_classes, layers=layers)
    state_dict = download_jit(model_path, jit=False)
    model.load_state_dict(state_dict)
    return model
|
|
class FaRLFaceAttribute(FaceAttribute):
    """The face attribute recognition models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
    @article{zheng2021farl,
      title={General Facial Representation Learning in a Visual-Linguistic Manner},
      author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong
              and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
      journal={arXiv preprint arXiv:2112.03109},
      year={2021}
    }
    ```
    """
|
    def __init__(
        self,
        conf_name: Optional[str] = None,
        model_path: Optional[str] = None,
        device=None,
    ) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = "celeba/224"
        if model_path is None:
            model_path = pretrain_settings[conf_name]["url"]
        self.conf_name = conf_name

        setting = pretrain_settings[self.conf_name]
        self.labels = setting["classes"]
        self.net = load_face_attr(
            model_path, num_classes=setting["num_classes"], layers=setting["layers"]
        )
        if device is not None:
            self.net = self.net.to(device)

        self.eval()
|
    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0
        _, _, h, w = images.shape

        # Gather the source image for each detected face, then build the
        # alignment matrix from its landmarks and the corresponding warp grid.
        simages = images[data["image_ids"]]
        matrix = setting["get_matrix_fn"](data[setting["matrix_src_tag"]])
        grid = setting["get_grid_fn"](matrix=matrix, orig_shape=(h, w))

        # Warp every face to the aligned 224x224 crop expected by the classifier.
        w_images = F.grid_sample(simages, grid, mode="bilinear", align_corners=False)

        # Multi-label prediction: an independent sigmoid probability per attribute.
        outputs = self.net(w_images)
        probs = torch.sigmoid(outputs)

        data["attrs"] = probs
        return data
|
|
if __name__ == "__main__":
    model = FaRLFaceAttribute()
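    # Minimal usage sketch (assumptions, not part of the pretrained pipeline):
    # in the real facer pipeline `images` is a uint8 batch of shape (B, 3, H, W)
    # and `data` comes from a face detector, with "image_ids" mapping each face
    # to its image and "points" holding 5 (x, y) landmarks per face. The random
    # tensors below are placeholders purely to illustrate the expected shapes.
    images = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
    data = {
        "image_ids": torch.tensor([0]),
        "points": torch.rand(1, 5, 2) * 256,
    }
    with torch.inference_mode():
        out = model(images, data)
    # `out["attrs"]` holds per-face sigmoid probabilities, one column per label
    # in `model.labels` (40 CelebA attributes for the default config).
    print(out["attrs"].shape)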
|
|