from typing import List, Dict, Callable, Tuple, Optional

import torch
import torch.nn.functional as F
import functools
import numpy as np

def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """
    Args:
        box: b x 4 (x1, y1, x2, y2).
        target_shape: (h, w, ...) of the output crop.
        target_face_scale: Scale factor applied to the crop size relative to
            the box; values greater than 1 produce a tighter (zoomed-in) crop.
        make_square_crop: If `True`, expand the shorter side of the box so the
            crop is square.
        offset_xy: Optional (dx, dy) shift applied to the box center.
        align_corners (bool): Set this to `True` only if the box you give has
            coordinates ranging from `0` to `h-1` or `w-1`.

        offset_box_coords (bool): Set this to `True` if the box you give has
            coordinates ranging from `0` to `h` or `w`.

            Set this to `False` if the box coordinates range from `-0.5` to
            `h-0.5` or `w-0.5`.

            If the box coordinates range from `0` to `h-1` or `w-1`, set
            `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    # Box center and half extents, rescaled by the target face scale.
    x1, y1, x2, y2 = box.split(1, dim=1)
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # Map (x1, y1) -> (0, 0) and (x2, y2) -> (w-1, h-1).
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)
    else:
        if offset_box_coords:
            # Convert from the [0, w] x [0, h] convention to the
            # [-0.5, w-0.5] x [-0.5, h-0.5] convention.
            x1, x2, y1, y2 = x1 - 0.5, x2 - 0.5, y1 - 0.5, y2 - 0.5

        # Map (x1, y1) -> (-0.5, -0.5) and (x2, y2) -> (w-0.5, h-0.5).
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5 * ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5 * ones_pl,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)
    return matrix
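

# Illustrative usage sketch (not part of the original module): build a crop
# matrix for a single face box and check that, with `align_corners=True`, the
# box corners land on the output image corners.  The box values and the
# 256x256 target shape below are arbitrary assumptions.
def _example_crop_and_resize_matrix():
    box = torch.tensor([[30.0, 40.0, 130.0, 140.0]])  # b x 4 (x1, y1, x2, y2)
    matrix = get_crop_and_resize_matrix(box, (256, 256), align_corners=True)
    corners = torch.tensor([[[30.0, 40.0, 1.0], [130.0, 140.0, 1.0]]])
    mapped = torch.bmm(corners, matrix.transpose(1, 2))[:, :, :2]
    assert torch.allclose(mapped, torch.tensor([[[0.0, 0.0], [255.0, 255.0]]]))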


def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """Estimate a least-squares similarity transform (uniform scale, rotation
    and translation) that maps `from_pts` to `to_pts`.

    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)
    mto = to_pts.mean(dim=1, keepdim=True)

    a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False)
    c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False)

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] -
          to_delta[:, :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)
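

# Illustrative sketch (not part of the original module): recover a known
# similarity transform from point correspondences.  The points and the
# transform parameters below are arbitrary assumptions.
def _example_similarity_transform():
    from_pts = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
    # Target points obtained by scaling by 2, rotating 90 degrees and
    # translating by (3, 4):  x' = 2*y + 3,  y' = -2*x + 4.
    to_pts = torch.tensor([[[3.0, 4.0], [3.0, 2.0], [5.0, 4.0], [5.0, 2.0]]])
    matrix = get_similarity_transform_matrix(from_pts, to_pts)
    homo = torch.cat([from_pts, torch.ones_like(from_pts[:, :, :1])], dim=-1)
    mapped = torch.bmm(homo, matrix.transpose(1, 2))[:, :, :2]
    assert torch.allclose(mapped, to_pts, atol=1e-5)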


@functools.lru_cache()
def _standard_face_pts():
    # Canonical 5-point template (two eye centers, nose tip, two mouth
    # corners) for a 512x512 face, normalized to [-1, 1].
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):
    """Compute b x 3 x 3 similarity matrices that align `face_pts` (b x n x 2)
    to `target_pts` (n is 5 when using the default template)."""
    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w-1, h-1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)
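

# Illustrative sketch (not part of the original module): align a batch of
# 5-point landmarks (two eyes, nose tip, two mouth corners) to the default
# template for a 448x448 output.  The landmark values are arbitrary.
def _example_face_align_matrix():
    face_pts = torch.tensor([[[205.0, 230.0], [310.0, 228.0], [258.0, 290.0],
                              [222.0, 355.0], [295.0, 352.0]]])  # b x 5 x 2
    matrix = get_face_align_matrix(face_pts, (448, 448), target_face_scale=1.0)
    assert matrix.shape == (1, 3, 3)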


def rot90(v):
    # Rotate a 2D vector by 90 degrees.
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    """Compute an oriented square crop quadrilateral (4 x 2) around the face
    from 5 landmarks (two eyes, nose tip, two mouth corners)."""
    lm = lm.detach().cpu().numpy()

    # Crop axes derived from the eye-to-eye and eye-to-mouth vectors.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    # Reorder the corners to (top-left, top-right, bottom-right, bottom-left),
    # matching the target corner order used in `get_face_align_matrix_celebm`.
    quad_for_coeffs = quad[[0, 3, 2, 1]]
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]):
    """Compute b x 3 x 3 similarity matrices that map the crop quad of each
    face onto the full (square) target image."""
    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    assert target_shape[0] == target_shape[1]
    target_size = target_shape[0]
    target_pts = torch.as_tensor(
        [[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]).to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy, xx


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
    """Forge a transform map with a given function `fn`.

    Args:
        batch_size (int): The batch size b.
        device (torch.device): The device to create the grid on.
        output_shape (tuple): (h, w, ...).
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a b x n x 2 array and outputs the transformed b x n x 2 array. Both
            input and output store (x, y) coordinates.

    Returns:
        torch.Tensor: b x h x w x 2. For each pixel (y, x),
            `out[b, y, x] = fn([x, y])`.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h*w, 2])
    out_xxyy: torch.Tensor = fn(in_xxyy)
    return out_xxyy.reshape(batch_size, h, w, 2)
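

# Illustrative sketch (not part of the original module): with an identity
# function, `_forge_grid` simply returns the pixel-coordinate grid, storing
# (x, y) at each pixel (y, x).
def _example_forge_grid():
    grid = _forge_grid(2, torch.device('cpu'), (4, 3), lambda pts: pts)
    assert grid.shape == (2, 4, 3, 2)
    assert grid[0, 1, 2].tolist() == [2.0, 1.0]  # pixel (y=1, x=2) stores (x, y)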


def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    # Clamp away from +/-1 to avoid infinite outputs.
    return torch.clamp(x, -1 + eps, 1 - eps).arctanh()


def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    h, w, *_ = warped_shape

    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # Normalize coordinates to [-1, 1].
        coords = coords / w_h * 2 - 1

        # Border regions where the tanh warp was applied.
        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        # Invert the tanh warp in the border regions.
        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # Back to un-normalized coordinates.
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)

    # Map back to the original image by applying the inverse of the alignment
    # matrix (batched 3x3 inverse; stays on the input device and dtype).
    inv_matrix = torch.linalg.inv(matrix)
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape

    w_h = torch.tensor([[w, h]]).to(coords)

    # Apply the alignment matrix to map into the aligned (not-warped) image.
    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)

    coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1))
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])

    if warp_factor > 0:
        # Normalize coordinates to [-1, 1].
        coords = coords / w_h * 2 - 1

        # Border regions where the tanh warp is applied.
        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # Back to un-normalized coordinates.
        coords = (coords + 1) / 2 * w_h

    return coords
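

# Illustrative sketch (not part of the original module): `tanh_warp_transform`
# and `inverted_tanh_warp_transform` are inverse mappings, so warping a few
# original-image points and then un-warping them recovers the originals.  The
# box, warp factor and shapes below are arbitrary assumptions.
def _example_tanh_warp_round_trip():
    box = torch.tensor([[60.0, 80.0, 220.0, 240.0]])
    matrix = get_crop_and_resize_matrix(box, (256, 256))
    pts = torch.tensor([[[100.0, 120.0], [150.0, 200.0]]])  # b x n x 2
    warped = tanh_warp_transform(pts, matrix, warp_factor=0.8,
                                 warped_shape=(256, 256))
    restored = inverted_tanh_warp_transform(warped, matrix, warp_factor=0.8,
                                            warped_shape=(256, 256))
    assert torch.allclose(restored, pts, atol=1e-2)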


def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), normalized to [-1, 1] w.r.t. the
            original image.
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1


def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), normalized to [-1, 1] w.r.t. the
            warped image.
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1
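

# Minimal end-to-end sketch (not part of the original module; all values are
# arbitrary assumptions): align a face with a similarity matrix, sample the
# tanh-warped crop with `F.grid_sample`, then sample it back into
# original-image coordinates with the inverted grid.
if __name__ == "__main__":
    image = torch.rand(1, 3, 480, 640)  # b x c x h x w
    landmarks = torch.tensor([[[260.0, 210.0], [380.0, 205.0], [320.0, 270.0],
                               [275.0, 330.0], [365.0, 325.0]]])  # b x 5 x 2
    matrix = get_face_align_matrix(landmarks, (448, 448))

    # Grid for sampling the aligned (tanh-warped) face from the original image.
    warp_grid = make_tanh_warp_grid(matrix, warp_factor=0.8,
                                    warped_shape=(448, 448),
                                    orig_shape=(480, 640))
    warped_face = F.grid_sample(image, warp_grid, mode='bilinear',
                                align_corners=False)

    # Grid for mapping the warped face back into original-image coordinates.
    inv_grid = make_inverted_tanh_warp_grid(matrix, warp_factor=0.8,
                                            warped_shape=(448, 448),
                                            orig_shape=(480, 640))
    restored = F.grid_sample(warped_face, inv_grid, mode='bilinear',
                             align_corners=False)

    print(warped_face.shape, restored.shape)  # (1, 3, 448, 448), (1, 3, 480, 640)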