Create utils.py
utils/utils.py ADDED  +51 -0
@@ -0,0 +1,51 @@
import torch
import numpy as np
from einops import rearrange
from kornia.geometry.transform.crop2d import warp_affine

from utils.matlab_cp2tform import get_similarity_transform_for_cv2
from torchvision.transforms import Pad

REFERNCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051],
                                            [72.53179932, 51.50139999],
                                            [56.02519989, 71.73660278],
                                            [41.54930115, 92.3655014],
                                            [70.72990036, 92.20410156]
                                            ]) / 112  # Original points are for a 112x96 crop; 8 was added to the x axis to make it 112x112, then divided by 112 to get relative coordinates


@torch.no_grad()
def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> torch.Tensor:
    """
    Detect faces in the images using MTCNN. If no face is detected, use the whole image.
    """
    images = rearrange(images, "b c h w -> b h w c")
    if images.dtype != torch.uint8:
        images = ((images * 0.5 + 0.5) * 255).type(torch.uint8)  # Unnormalize

    _, _, landmarks = mtcnn(images, landmarks=True)

    return landmarks


def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, refernce_points=REFERNCE_FACIAL_POINTS_RELATIVE):
    """
    Detect faces in the images and crop them (in a differentiable way) to output_size x output_size (112x112 by default) using MTCNN.
    """
    images = Pad(200)(images)
    landmarks_batched = detect_face(images, mtcnn=mtcnn)
    affine_transformations = []
    invalid_indices = []
    for i, landmarks in enumerate(landmarks_batched):
        if landmarks is None:
            invalid_indices.append(i)
            affine_transformations.append(np.eye(2, 3).astype(np.float32))
        else:
            affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32),
                                                                           refernce_points.astype(np.float32) * output_size))
    affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32)

    invalid_indices = torch.tensor(invalid_indices).to(device=images.device)

    fp_images = images.to(torch.float32)
    return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices
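
For context, a minimal usage sketch is shown below. It is not part of this commit: it assumes the facenet-pytorch package as the MTCNN backend, and the small MTCNNWrapper class is a hypothetical adapter so that calling the module with landmarks=True returns (boxes, probs, landmarks), which is what detect_face expects.

# Hypothetical usage sketch -- not part of this commit. Assumes facenet-pytorch is
# installed; the thin wrapper below only adapts MTCNN.detect() to the call signature
# detect_face expects (module call with landmarks=True -> boxes, probs, landmarks).
import torch
from facenet_pytorch import MTCNN

from utils.utils import extract_faces_and_landmarks


class MTCNNWrapper(torch.nn.Module):
    def __init__(self, device: str):
        super().__init__()
        self.mtcnn = MTCNN(device=device)

    def forward(self, images, landmarks: bool = True):
        # images arrive as (B, H, W, C) uint8, as prepared inside detect_face.
        return self.mtcnn.detect(images, landmarks=landmarks)


device = "cuda" if torch.cuda.is_available() else "cpu"
mtcnn = MTCNNWrapper(device)

# Batch of images normalized to [-1, 1], shape (B, C, H, W); random noise is used here
# only to illustrate shapes -- real face images would normally be passed.
images = torch.rand(2, 3, 512, 512, device=device) * 2 - 1

faces, invalid_indices = extract_faces_and_landmarks(images, output_size=112, mtcnn=mtcnn)
print(faces.shape)       # torch.Size([2, 3, 112, 112]) -- aligned, differentiable crops
print(invalid_indices)   # indices where no face was found (identity transform was used)

Images without a detected face are returned as identity-warped crops; downstream code can use invalid_indices to mask them out of any identity loss.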