Delete src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py (DELETED)
@@ -1,174 +0,0 @@
from typing import Optional, Dict, Any
import functools
import torch
import torch.nn.functional as F

from ..util import download_jit
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from .base import FaceParser
import numpy as np
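# Each pretrained configuration below bundles everything FaRLFaceParser needs for one
# checkpoint: the JIT model URL(s), which detector output feeds the alignment matrix
# ('matrix_src_tag'), the functions that build the face-alignment matrix and the
# forward/inverse tanh-warp sampling grids, and the label names the model predicts
# ('lapa/448' has 11 classes, 'celebm/448' has 19).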
pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448), target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8, warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
                                           target_shape=(448, 448)),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0, warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l']
    }
}

class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
    @article{zheng2021farl,
        title={General Facial Representation Learning in a Visual-Linguistic Manner},
        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
                Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
                Dong and Zeng, Ming and Wen, Fang},
        journal={arXiv preprint arXiv:2112.03109},
        year={2021}
    }
    ```
    """

    def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, device=None) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        self.net = download_jit(model_path, map_location=device)
        self.eval()
        self.device = device
        self.setting = pretrain_settings[conf_name]
        self.label_names = self.setting['label_names']

    def get_warp_grid(self, images: torch.Tensor, matrix_src):
        _, _, h, w = images.shape
        matrix = self.setting['get_matrix_fn'](matrix_src)
        grid = self.setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = self.setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
        return grid, inv_grid

    def warp_images(self, images: torch.Tensor, data: Dict[str, Any]):
        simages = self.unify_image_dtype(images)
        simages = simages[data['image_ids']]
        matrix_src = data[self.setting['matrix_src_tag']]
        grid, inv_grid = self.get_warp_grid(simages, matrix_src)

        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)
        return w_images, grid, inv_grid

    def decode_image_to_cv2(self, images: torch.Tensor):
        '''
        output: b x h x w x 3, np.uint8, [0, 255]
        '''
        assert images.ndim == 4
        assert images.shape[1] == 3
        images = images.permute(0, 2, 3, 1).cpu().numpy() * 255
        images = images.astype(np.uint8)
        return images

    def unify_image_dtype(self, images: torch.Tensor | np.ndarray | list):
        '''
        output: b x 3 x h x w, torch.float32, [0, 1]
        '''
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        elif isinstance(images, torch.Tensor):
            pass
        elif isinstance(images, list):
            assert len(images) > 0, "images is empty"
            first_image = images[0]
            if isinstance(first_image, np.ndarray):
                images = [torch.from_numpy(image).permute(2, 0, 1) for image in images]
                images = torch.stack(images)
            elif isinstance(first_image, torch.Tensor):
                images = torch.stack(images)
            else:
                raise ValueError(f"Unsupported image type: {type(first_image)}")
        else:
            raise ValueError(f"Unsupported image type: {type(images)}")

        assert images.ndim == 4
        assert images.shape[1] == 3

        max_val = images.max()
        if max_val <= 1:
            assert images.dtype == torch.float32 or images.dtype == torch.float16
        elif max_val <= 255:
            assert images.dtype == torch.uint8
            images = images.float() / 255.0
        else:
            raise ValueError(f"Unsupported image dtype: {images.dtype}")
        if images.device != self.device:
            images = images.to(device=self.device)
        return images

    @torch.no_grad()
    @torch.inference_mode()
    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        '''
        images: b x 3 x h x w, torch.uint8, [0, 255]
        data: {'rects': rects, 'points': points, 'scores': scores, 'image_ids': image_ids}
        '''
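        # Pipeline: warp each detected face into the model's canonical 448x448 view,
        # run the parsing network there, then resample the logits back onto the
        # original image grid with the inverse warp so data['seg'] aligns with the input.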
        w_images, grid, inv_grid = self.warp_images(images, data)
        w_seg_logits = self.forward_warped(w_images, return_preds=False)

        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits, 'label_names': self.label_names}
        return data

    def logits2predictions(self, logits: torch.Tensor):
        return logits.argmax(dim=1)

    @torch.no_grad()
    @torch.inference_mode()
    def forward_warped(self, images: torch.Tensor, return_preds: bool = True):
        '''
        images: b x 3 x h x w, torch.uint8, [0, 255]
        '''
        images = self.unify_image_dtype(images)
        seg_logits, _ = self.net(images)  # nfaces x c x h x w
        # seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
        if return_preds:
            seg_preds = self.logits2predictions(seg_logits)
            return seg_logits, seg_preds, self.label_names
        else:
            return seg_logits
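For reference, a minimal usage sketch of the deleted class (not from the file itself): it assumes the vendored facer package is importable under `facer`, and that `data` comes from one of its face detectors, which populate the 'rects', 'points', 'scores', and 'image_ids' fields that forward() expects.

import torch
from facer.face_parsing.farl import FaRLFaceParser  # assumed import path for the vendored package

device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = FaRLFaceParser(conf_name='lapa/448', device=device)  # downloads the LaPa JIT checkpoint

# images: b x 3 x h x w, uint8 in [0, 255]; random data stands in for real frames here.
images = torch.randint(0, 256, (1, 3, 512, 512), dtype=torch.uint8, device=device)

# data = face_detector(images)        # hypothetical: produced by a facer face detector
# data = parser(images, data)         # adds data['seg'] = {'logits', 'label_names'}
# preds = parser.logits2predictions(data['seg']['logits'])  # per-pixel class indices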