Last commit not found
from typing import List, Dict, Callable, Tuple, Optional | |
import torch | |
import torch.nn.functional as F | |
import functools | |
import numpy as np | |
def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """Build affine matrices that crop `box` and resize the crop to `target_shape`.

    The box center is shifted by `offset_xy`, its half-extents are divided by
    `target_face_scale`, and (optionally) both half-extents are replaced by the
    larger of the two so the crop is square.

    Args:
        box: b x 4 tensor of (x1, y1, x2, y2).
        target_shape: (h, w, ...) of the output; only the first two entries are used.
        target_face_scale: Half-extents are divided by this, so values > 1 give a
            tighter crop.
        make_square_crop: If True, use max(half-width, half-height) on both axes.
        offset_xy: (dx, dy) added to the box center. Defaults to (0, 0).
        align_corners (bool): Set this to `True` only if the box you give has coordinates
            ranging from `0` to `h-1` or `w-1`. The crop edges are then mapped to
            `0` and `w-1` / `h-1`.
        offset_box_coords (bool): Only consulted when `align_corners=False`.
            Set this to `True` if the box you give has coordinates ranging from
            `0` to `h` or `w` (they are shifted by -0.5 first).
            Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.
            If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3 matrices mapping homogeneous source (x, y, 1) to
        target pixel coordinates.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    left, top, right, bottom = box.split(1, dim=1)  # each b x 1
    center_x = (left + right) / 2 + offset_xy[0]
    center_y = (top + bottom) / 2 + offset_xy[1]
    half_w = (right - left) / 2 / target_face_scale
    half_h = (bottom - top) / 2 / target_face_scale
    if make_square_crop:
        half_w = half_h = torch.maximum(half_w, half_h)

    x1, y1 = center_x - half_w, center_y - half_h
    x2, y2 = center_x + half_w, center_y + half_h

    h, w, *_ = target_shape
    zero = torch.zeros_like(x1)
    one = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1), and likewise for y with (h - 1)
        scale_x = 1.0 / (x2 - x1) * (w - 1)
        scale_y = 1.0 / (y2 - y1) * (h - 1)
        shift_x = -x1 * scale_x
        shift_y = -y1 * scale_y
    else:
        if offset_box_coords:
            # Boxes given in [0, w] x [0, h] pixel coordinates are moved to
            # [-0.5, w-0.5] x [-0.5, h-0.5] boundary coordinates.
            x1, x2, y1, y2 = x1 - 0.5, x2 - 0.5, y1 - 0.5, y2 - 0.5
        # x -> (x - x1) / (x2 - x1) * w - 0.5, and likewise for y with h
        scale_x = 1.0 / (x2 - x1) * w
        scale_y = 1.0 / (y2 - y1) * h
        shift_x = -x1 * scale_x - 0.5 * one
        shift_y = -y1 * scale_y - 0.5 * one

    entries = [scale_x, zero, shift_x,
               zero, scale_y, shift_y,
               zero, zero, one]
    return torch.cat(entries, dim=1).reshape(-1, 3, 3)  # b x 3 x 3
def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """Least-squares similarity transform (rotation + uniform scale + translation).

    Solves, per batch item, for a, b, dx, dy minimizing the squared distance
    between `to_pts` and the transformed `from_pts`, with the model

        x' =  a * x + b * y + dx
        y' = -b * x + a * y + dy

    Args:
        from_pts, to_pts: b x n x 2 corresponding (x, y) point sets.

    Returns:
        torch.Tensor: b x 3 x 3 matrices [[a, b, dx], [-b, a, dy], [0, 0, 1]].
    """
    from_mean = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    to_mean = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    src = from_pts - from_mean
    dst = to_pts - to_mean

    # NOTE: no guard against a zero denominator (all `from_pts` of an item equal).
    denom = src.square().sum([1, 2])  # b
    dot = (dst * src).sum([1, 2])  # b
    cross = (dst[:, :, 0] * src[:, :, 1] - dst[:, :, 1] * src[:, :, 0]).sum([1])  # b

    a = dot / denom
    b = cross / denom
    mean_x, mean_y = from_mean[:, 0, 0], from_mean[:, 0, 1]
    dx = to_mean[:, 0, 0] - a * mean_x - b * mean_y  # b
    dy = to_mean[:, 0, 1] + b * mean_x - a * mean_y  # b

    zero = torch.zeros_like(denom)
    one = torch.ones_like(denom)
    entries = [a, b, dx,
               -b, a, dy,
               zero, zero, one]
    return torch.stack(entries, dim=-1).reshape(-1, 3, 3)
def _standard_face_pts(): | |
pts = torch.tensor([ | |
196.0, 226.0, | |
316.0, 226.0, | |
256.0, 286.0, | |
220.0, 360.4, | |
292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0 | |
return torch.reshape(pts, (5, 2)) | |
def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):
    """Compute similarity transforms that move `face_pts` onto target landmark positions.

    Args:
        face_pts: b x n x 2 (x, y) source landmarks. When `target_pts` is None the
            canonical targets are 5 x 2, so n must be 5.
        target_shape: (h, w, ...) of the aligned image. Only used when
            `target_pts` is None.
        target_face_scale: Scales the canonical landmarks about the image center.
            Only used when `target_pts` is None.
        offset_xy: (dx, dy) added to the canonical target landmarks. Only used when
            `target_pts` is None.
        target_pts: Optional explicit targets, n x 2 or b x n x 2. A batch of one
            is broadcast to `face_pts`.

    Returns:
        torch.Tensor: b x 3 x 3 matrices from `get_similarity_transform_matrix`.

    Raises:
        AssertionError: if the (broadcast) targets do not match `face_pts.shape`.
    """
    if target_pts is not None:
        target_pts = target_pts.to(face_pts)
    else:
        with torch.no_grad():
            height, width, *_ = target_shape
            canonical = _standard_face_pts().to(face_pts)  # values in [-1, 1]
            extent = torch.tensor([width - 1, height - 1]).to(face_pts)
            # Map [-1, 1] to [0, w-1] x [0, h-1], scaling about the center.
            target_pts = (canonical * target_face_scale + 1) * extent / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]

    if target_pts.dim() == 2:
        target_pts = target_pts[None]
    if target_pts.size(0) == 1:
        target_pts = target_pts.expand(face_pts.shape)
    assert target_pts.shape == face_pts.shape
    return get_similarity_transform_matrix(face_pts, target_pts)
def rot90(v):
    """Rotate a 2-D vector by 90 degrees, (x, y) -> (-y, x), returning a new NumPy array."""
    x_comp, y_comp = v[0], v[1]
    return np.array([-y_comp, x_comp])
def get_quad(lm: torch.Tensor):
    """Compute an oriented square crop quad from one face's landmarks.

    Args:
        lm: N x 2 (x, y) landmarks. Indices 0/1 are treated as the two eyes and
            3/4 as the two mouth corners; index 2 is not used. The tensor is
            detached and moved to the CPU.

    Returns:
        torch.Tensor: 4 x 2 float32 CPU tensor of corners ordered
        c-x-y, c+x-y, c+x+y, c-x+y, where c is the crop center, x the crop's
        half-axis and y = rot90(x).
    """
    def perp(vec):
        # 90-degree rotation (x, y) -> (-y, x), identical to the module-level `rot90`.
        return np.array([-vec[1], vec[0]])

    pts = lm.detach().cpu().numpy()
    eye0, eye1, mouth0, mouth1 = pts[0], pts[1], pts[3], pts[4]

    eye_center = (eye0 + eye1) * 0.5 + 0.5
    mouth_center = (mouth0 + mouth1) * 0.5 + 0.5
    eye_vec = eye1 - eye0
    eye_to_mouth = mouth_center - eye_center

    # Choose the oriented crop rectangle.
    half_x = eye_vec - perp(eye_to_mouth)
    half_x /= np.hypot(*half_x)
    half_x *= max(np.hypot(*eye_vec) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    half_y = perp(half_x)
    center = eye_center + eye_to_mouth * 0.1

    # Corner order chosen to match the target square (0,0), (S,0), (S,S), (0,S).
    corners = np.stack([center - half_x - half_y,
                        center + half_x - half_y,
                        center + half_x + half_y,
                        center - half_x + half_y])
    return torch.from_numpy(corners).float()
def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]):
    """Compute similarity transforms mapping each face's `get_quad` crop onto a square.

    Args:
        face_pts: b x n x 2 (x, y) landmarks. Each item is passed to `get_quad`,
            which reads indices 0, 1, 3 and 4.
        target_shape: (size, size, ...). Must be square.

    Returns:
        torch.Tensor: b x 3 x 3 matrices sending the quad corners to
        (0, 0), (size, 0), (size, size), (0, size). An empty batch (b == 0)
        yields a 0 x 3 x 3 tensor instead of raising.

    Raises:
        AssertionError: if `target_shape` is not square.
    """
    assert target_shape[0] == target_shape[1], 'target_shape must be square'
    target_size = target_shape[0]

    if face_pts.size(0) == 0:
        # torch.stack rejects an empty list, so build the empty quad batch directly.
        quads = face_pts.new_zeros((0, 4, 2))
    else:
        # get_quad works in NumPy on the CPU; move the result to face_pts' device/dtype.
        quads = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    corners = torch.as_tensor(
        [[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]
    ).to(quads)  # 4 x 2
    # Every face in the batch shares the same target corners.
    target_pts = corners.unsqueeze(0).expand_as(quads)
    return get_similarity_transform_matrix(quads, target_pts)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]: | |
yy, xx = torch.meshgrid(torch.arange(h).float(), | |
torch.arange(w).float(), | |
indexing='ij') | |
return yy, xx | |
def _forge_grid(batch_size: int, device: torch.device, | |
output_shape: Tuple[int, int], | |
fn: Callable[[torch.Tensor], torch.Tensor] | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
""" Forge transform maps with a given function `fn`. | |
Args: | |
output_shape (tuple): (b, h, w, ...). | |
fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts | |
a bxnx2 array and outputs the transformed bxnx2 array. Both input | |
and output store (x, y) coordinates. | |
Note: | |
both input and output arrays of `fn` should store (y, x) coordinates. | |
Returns: | |
Tuple[torch.Tensor, torch.Tensor]: Two maps `X` and `Y`, where for each | |
pixel (y, x) or coordinate (x, y), | |
`(X[y, x], Y[y, x]) = fn([x, y])` | |
""" | |
h, w, *_ = output_shape | |
yy, xx = _meshgrid(h, w) # h x w | |
yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device) | |
xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device) | |
in_xxyy = torch.stack( | |
[xx, yy], dim=-1).reshape([batch_size, h*w, 2]) # (h x w) x 2 | |
out_xxyy: torch.Tensor = fn(in_xxyy) # (h x w) x 2 | |
return out_xxyy.reshape(batch_size, h, w, 2) | |
def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor: | |
return torch.clamp(x, -1+eps, 1-eps).arctanh() | |
def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """Inverse of `tanh_warp_transform`: map warped-image coordinates back to the original.

    The tanh warp is undone first with a clamped arctanh (so it is an exact inverse
    only away from the clamp). Then the affine `matrix` is inverted.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp. Only the outer
            `warp_factor` band of each normalized axis is non-linear.
        warped_shape (tuple): [height, width, ...] of the warped image.

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    height, width, *_ = warped_shape
    size_wh = torch.tensor([[width, height]]).to(coords)

    if warp_factor > 0:
        def clamped_atanh(t):
            # Same clamping as `_safe_arctanh` with eps=0.001.
            return torch.clamp(t, -1 + 0.001, 1 - 0.001).arctanh()

        unit = coords / size_wh * 2 - 1  # normalize to [-1, +1]
        upper_band = unit > 1.0 - warp_factor
        lower_band = unit < -1.0 + warp_factor
        upper_vals = clamped_atanh(
            (unit - 1.0 + warp_factor) / warp_factor) * warp_factor + 1.0 - warp_factor
        lower_vals = clamped_atanh(
            (unit + 1.0 - warp_factor) / warp_factor) * warp_factor - 1.0 + warp_factor
        unit = torch.where(upper_band, upper_vals,
                           torch.where(lower_band, lower_vals, unit))
        coords = (unit + 1) / 2 * size_wh  # back to pixel units

    ones = torch.ones_like(coords[:, :, [0]])
    homo = torch.cat([coords, ones], dim=-1)  # b x n 3
    # The inverse is computed with NumPy on the host (torch.linalg.inv was left unused).
    # NOTE(review): `.numpy()` raises if `matrix` requires grad, and the round-trip
    # forces a device sync for CUDA inputs — confirm this is intended.
    inverse = torch.from_numpy(np.linalg.inv(matrix.cpu().numpy())).to(matrix.device)
    homo = torch.bmm(homo, inverse.permute(0, 2, 1))  # b x n x 3
    return homo[:, :, :2] / homo[:, :, [2, 2]]
def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """Map original-image coordinates into the (tanh-)warped image.

    The affine `matrix` is applied first (with a homogeneous divide). Then, if
    `warp_factor > 0`, the outer `warp_factor` band of each normalized axis is
    compressed with tanh.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width, ...] of the warped image.

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    height, width, *_ = warped_shape
    size_wh = torch.tensor([[width, height]]).to(coords)

    ones = torch.ones_like(coords[:, :, [0]])
    homo = torch.bmm(torch.cat([coords, ones], dim=-1), matrix.transpose(2, 1))  # b x n x 3
    result = homo[:, :, :2] / homo[:, :, [2, 2]]  # b x n x 2

    if warp_factor > 0:
        unit = result / size_wh * 2 - 1  # normalize to [-1, +1]
        upper_band = unit > 1.0 - warp_factor
        lower_band = unit < -1.0 + warp_factor
        upper_vals = torch.tanh(
            (unit - 1.0 + warp_factor) / warp_factor) * warp_factor + 1.0 - warp_factor
        lower_vals = torch.tanh(
            (unit + 1.0 - warp_factor) / warp_factor) * warp_factor - 1.0 + warp_factor
        unit = torch.where(upper_band, upper_vals,
                           torch.where(lower_band, lower_vals, unit))
        result = (unit + 1) / 2 * size_wh  # back to pixel units
    return result
def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """Build a normalized grid that samples the original image into the warped frame.

    Each pixel of the warped output is sent through `inverted_tanh_warp_transform`
    to find its source (x, y) in the original image. Those locations are then
    normalized by the original size as `coord / size * 2 - 1`.

    Args:
        matrix: b x 3 x 3 matrix mapping original-image coordinates to the
            aligned, not-yet-warped image.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: (h, w, ...) of the target image to transform to.
        orig_shape: (h, w, ...) of the original image being sampled.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), with h, w taken from `warped_shape`.
        NOTE(review): presumably consumed by `F.grid_sample`; confirm the
        `align_corners` setting callers use matches this normalization.
    """
    orig_h, orig_w, *_ = orig_shape
    to_source = functools.partial(inverted_tanh_warp_transform,
                                  matrix=matrix,
                                  warp_factor=warp_factor,
                                  warped_shape=warped_shape)
    pixel_grid = _forge_grid(matrix.size(0), matrix.device, warped_shape, to_source)
    orig_wh = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return pixel_grid / orig_wh * 2 - 1
def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """Build a normalized grid that samples the warped image back into the original frame.

    Each pixel of the original image is sent through `tanh_warp_transform` to find
    its (x, y) in the warped image. Those locations are then normalized by the
    warped size as `coord / size * 2 - 1`.

    Args:
        matrix: b x 3 x 3 matrix mapping original-image coordinates to the
            aligned, not-yet-warped image.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: (h, w, ...) of the warped image being sampled.
        orig_shape: (h, w, ...) of the original image that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), with h, w taken from `orig_shape`.
        NOTE(review): presumably consumed by `F.grid_sample`; confirm the
        `align_corners` setting callers use matches this normalization.
    """
    height, width, *_ = warped_shape
    to_warped = functools.partial(tanh_warp_transform,
                                  matrix=matrix,
                                  warp_factor=warp_factor,
                                  warped_shape=warped_shape)
    pixel_grid = _forge_grid(matrix.size(0), matrix.device, orig_shape, to_warped)
    warped_wh = torch.tensor([width, height]).to(matrix).reshape(1, 1, 1, 2)
    return pixel_grid / warped_wh * 2 - 1