from typing import Optional, Dict, Any
import functools

import numpy as np
import torch
import torch.nn.functional as F

from ..transform import get_face_align_matrix, make_tanh_warp_grid
from .base import FaceAttribute
from ..farl import farl_classification
from ..util import download_jit


def get_std_points_xray(out_size=256, mid_size=500):
    """Return the 5-point facial landmark template used for face alignment.

    Starts from canonical 256x256 landmark positions (two eyes, nose tip,
    two mouth corners), shifts them down by 30 px, re-centers them on a
    ``mid_size`` canvas, then scales to ``out_size``.

    Args:
        out_size: side length of the target (square) aligned image.
        mid_size: side length of the intermediate canvas used for centering.

    Returns:
        torch.Tensor: a ``(5, 2)`` float tensor of target landmark
        coordinates in ``out_size`` pixel space.
    """
    std_points_256 = np.array(
        [
            [85.82991, 85.7792],
            [169.0532, 84.3381],
            [127.574, 137.0006],
            [90.6964, 174.7014],
            [167.3069, 173.3733],
        ]
    )
    # Shift all landmarks down so the crop includes more of the forehead.
    std_points_256[:, 1] += 30
    old_size = 256
    mid = mid_size / 2
    new_std_points = std_points_256 - old_size / 2 + mid
    target_pts = new_std_points * out_size / mid_size
    target_pts = torch.from_numpy(target_pts).float()
    return target_pts


# Per-configuration settings: checkpoint URL, alignment/warping functions,
# and the ordered attribute class names the checkpoint predicts.
pretrain_settings = {
    "celeba/224": {
        # acc 92.06617474555969
        "num_classes": 40,
        "layers": [11],
        "url": "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_attribute.farl.celeba.pt",
        "matrix_src_tag": "points",
        "get_matrix_fn": functools.partial(
            get_face_align_matrix,
            target_shape=(224, 224),
            target_pts=get_std_points_xray(out_size=224, mid_size=500),
        ),
        "get_grid_fn": functools.partial(
            make_tanh_warp_grid, warp_factor=0.0, warped_shape=(224, 224)
        ),
        "classes": [
            "5_o_Clock_Shadow",
            "Arched_Eyebrows",
            "Attractive",
            "Bags_Under_Eyes",
            "Bald",
            "Bangs",
            "Big_Lips",
            "Big_Nose",
            "Black_Hair",
            "Blond_Hair",
            "Blurry",
            "Brown_Hair",
            "Bushy_Eyebrows",
            "Chubby",
            "Double_Chin",
            "Eyeglasses",
            "Goatee",
            "Gray_Hair",
            "Heavy_Makeup",
            "High_Cheekbones",
            "Male",
            "Mouth_Slightly_Open",
            "Mustache",
            "Narrow_Eyes",
            "No_Beard",
            "Oval_Face",
            "Pale_Skin",
            "Pointy_Nose",
            "Receding_Hairline",
            "Rosy_Cheeks",
            "Sideburns",
            "Smiling",
            "Straight_Hair",
            "Wavy_Hair",
            "Wearing_Earrings",
            "Wearing_Hat",
            "Wearing_Lipstick",
            "Wearing_Necklace",
            "Wearing_Necktie",
            "Young",
        ],
    }
}


def load_face_attr(model_path, num_classes=40, layers=None):
    """Build a FaRL attribute classifier and load pretrained weights.

    Args:
        model_path: local path or URL of the checkpoint.
        num_classes: number of attribute outputs of the classifier head.
        layers: backbone layer indices to tap; defaults to ``[11]``.

    Returns:
        The classifier with the checkpoint's state dict loaded.
    """
    # Resolve the default here instead of using a mutable default argument,
    # which would be shared across calls.
    if layers is None:
        layers = [11]
    model = farl_classification(num_classes=num_classes, layers=layers)
    state_dict = download_jit(model_path, jit=False)
    model.load_state_dict(state_dict)
    return model


class FaRLFaceAttribute(FaceAttribute):
    """The face attribute recognition models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
        @article{zheng2021farl,
            title={General Facial Representation Learning in a Visual-Linguistic
                Manner},
            author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin
                and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
                Dong and Zeng, Ming and Wen, Fang},
            journal={arXiv preprint arXiv:2112.03109},
            year={2021}
        }
    ```
    """

    def __init__(
        self,
        conf_name: Optional[str] = None,
        model_path: Optional[str] = None,
        device=None,
    ) -> None:
        """Create the model from a named pretrain configuration.

        Args:
            conf_name: key into ``pretrain_settings``; defaults to
                ``"celeba/224"``.
            model_path: checkpoint path or URL; defaults to the
                configuration's ``url``.
            device: optional device to move the network to.
        """
        super().__init__()
        if conf_name is None:
            conf_name = "celeba/224"
        if model_path is None:
            model_path = pretrain_settings[conf_name]["url"]
        self.conf_name = conf_name
        setting = pretrain_settings[self.conf_name]
        self.labels = setting["classes"]
        self.net = load_face_attr(
            model_path,
            num_classes=setting["num_classes"],
            layers=setting["layers"],
        )
        if device is not None:
            self.net = self.net.to(device)
        self.eval()

    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        """Predict per-face attribute probabilities.

        Args:
            images: uint8-range image batch of shape ``(B, C, H, W)``.
            data: detection results; must contain ``"image_ids"`` and the
                landmark tag named by the configuration's
                ``"matrix_src_tag"`` (here ``"points"``).

        Returns:
            ``data`` with an added ``"attrs"`` tensor of per-attribute
            sigmoid probabilities, one row per detected face.
        """
        setting = pretrain_settings[self.conf_name]
        # Scale to [0, 1] only; the backbone performs its own normalization.
        images = images.float() / 255.0
        _, _, h, w = images.shape
        # Gather the source image for each detected face.
        simages = images[data["image_ids"]]
        matrix = setting["get_matrix_fn"](data[setting["matrix_src_tag"]])
        grid = setting["get_grid_fn"](matrix=matrix, orig_shape=(h, w))
        w_images = F.grid_sample(
            simages, grid, mode="bilinear", align_corners=False
        )
        outputs = self.net(w_images)
        # Multi-label task: independent sigmoid per attribute.
        probs = torch.sigmoid(outputs)
        data["attrs"] = probs
        return data


if __name__ == "__main__":
    model = FaRLFaceAttribute()