import functools
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F


def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None,
        align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """Build matrices that crop `box` out of an image and resize it to `target_shape`.

    Args:
        box: b x 4 (x1, y1, x2, y2).
        target_shape: (h, w, ...) of the output; only h and w are used.
        target_face_scale: The box half-extents are divided by this value, so
            values > 1 shrink the crop region and values < 1 enlarge it.
        make_square_crop: If True, both half-extents are replaced by the larger
            one, making the crop region square.
        offset_xy: Optional (dx, dy) added to the box center before cropping.
        align_corners (bool): Set this to `True` only if the box you give has
            coordinates ranging from `0` to `h-1` or `w-1`.
        offset_box_coords (bool): Only used when `align_corners=False`.
            Set this to `True` if the box you give has coordinates ranging
            from `0` to `h` or `w`. Set this to `False` if the box coordinates
            range from `-0.5` to `h-0.5` or `w-0.5`.
            If the box coordinates range from `0` to `h-1` or `w-1`,
            set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3 homogeneous matrices mapping source (x, y)
        coordinates to target (x, y) coordinates.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    x1, y1, x2, y2 = box.split(1, dim=1)  # each is b x 1
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1)
        # y -> (y - y1) / (y2 - y1) * (h - 1)
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl,
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    else:
        if offset_box_coords:
            # x1, x2 in [0, w], y1, y2 in [0, h]: first shift them into
            # [-0.5, w-0.5] and [-0.5, h-0.5] to convert these pixel
            # coordinates into boundary coordinates.
            x1, x2, y1, y2 = x1 - 0.5, x2 - 0.5, y1 - 0.5, y2 - 0.5

        # x -> (x - x1) / (x2 - x1) * w - 0.5
        # y -> (y - y1) / (y2 - y1) * h - 0.5
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5 * ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5 * ones_pl,
            zeros_pl, zeros_pl, ones_pl,
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    return matrix


def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """Fit a similarity transform (rotation, uniform scale, translation).

    Solves, in closed form, the least-squares problem that maps `from_pts`
    onto `to_pts` with a matrix of the form [[a, b, dx], [-b, a, dy], [0, 0, 1]].

    Args:
        from_pts, to_pts: b x n x 2 (x, y) corresponding points.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom

    a1 = from_delta.square().sum([1, 2], keepdim=False)  # b
    c1 = (to_delta * from_delta).sum([1, 2], keepdim=False)  # b
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1]
          - to_delta[:, :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)


@functools.lru_cache()
def _standard_face_pts():
    """Return the 5 x 2 reference landmarks in a normalized frame.

    Raw values are divided by 256 and shifted by -1, so they lie roughly in
    [-1, 1]. The returned tensor is cached and shared between calls; callers
    must not modify it in place.
    """
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0,
        offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute similarity matrices aligning `face_pts` to target landmarks.

    Args:
        face_pts: b x n x 2 (x, y) source landmarks.
        target_shape: (h, w, ...) of the aligned image; used only when
            `target_pts` is None.
        target_face_scale: Scales the built-in reference landmarks about the
            image center (only when `target_pts` is None).
        offset_xy: Optional (dx, dy) added to the built-in reference landmarks
            (only when `target_pts` is None).
        target_pts: Optional n x 2, 1 x n x 2 or b x n x 2 target landmarks.
            When None, the 5 built-in reference points are mapped from
            [-1, 1] into [0, w-1] x [0, h-1]; `face_pts` must then have n == 5.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)  # [-1 1]
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w - 1, h - 1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


def rot90(v):
    """Rotate a 2-vector by 90 degrees: (x, y) -> (-y, x)."""
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    """Compute an oriented square crop quad from 5 facial landmarks.

    Assumes landmarks 0/1 are the eyes and 3/4 the mouth corners, as the
    variable names below indicate. The recipe appears to follow the FFHQ
    alignment procedure (NOTE(review): confirm). The input is detached and
    moved to NumPy, so no gradient flows through this function.

    Args:
        lm: n x 2 (x, y) landmarks, n >= 5.

    Returns:
        torch.Tensor: 4 x 2 float32 corners, ordered so that they correspond
        to (0, 0), (s, 0), (s, s), (0, s) of the target square.
    """
    lm = lm.detach().cpu().numpy()
    # Choose oriented crop rectangle.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    # Reorder the corners to match the target corner order used by
    # get_face_align_matrix_celebm.
    quad_for_coeffs = quad[[0, 3, 2, 1]]
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor:
    """Align faces by mapping each landmark-derived quad onto a square image.

    Args:
        face_pts: b x n x 2 (x, y) landmarks (see `get_quad`).
        target_shape: (s, s); must be square.

    Returns:
        torch.Tensor: b x 3 x 3 similarity matrices sending each quad's
        corners to (0, 0), (s, 0), (s, s), (0, s).
    """
    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    assert target_shape[0] == target_shape[1]
    target_size = target_shape[0]
    target_pts = torch.as_tensor([
        [0, 0], [target_size, 0],
        [target_size, target_size], [0, target_size]]).to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return cached h x w float (yy, xx) index grids on CPU (do not mutate)."""
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy, xx


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
    """Forge a dense coordinate map by applying `fn` to every output pixel.

    Args:
        batch_size: Number of maps to produce (b).
        device: Device the pixel grid is moved to before calling `fn`.
        output_shape (tuple): (h, w, ...); only h and w are used.
        fn (Callable[[torch.Tensor], torch.Tensor]): Accepts a b x n x 2
            tensor of (x, y) pixel coordinates (n = h * w, row-major) and
            returns the transformed b x n x 2 tensor, also storing (x, y).

    Returns:
        torch.Tensor: b x h x w x 2 map `G` where `G[i, y, x]` is the value
        `fn` produced for pixel (x, y) of batch item i.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # each h x w
    # Move the single h x w grid to `device` before broadcasting, so only one
    # grid (not b copies) is transferred.
    yy = yy.to(device).unsqueeze(0).broadcast_to(batch_size, h, w)
    xx = xx.to(device).unsqueeze(0).broadcast_to(batch_size, h, w)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h * w, 2])  # b x (h*w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # b x (h*w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)


def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    """arctanh with its input clamped to [-1 + eps, 1 - eps] to stay finite."""
    return torch.clamp(x, -1 + eps, 1 - eps).arctanh()


def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
            Gradients are not propagated through its inverse.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates, with the
        dtype of `coords`.
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) / warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) / warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    # The inverse is computed with NumPy. Detach first: `.numpy()` raises on
    # tensors that require grad. Invert in float64 because NumPy has no
    # bfloat16 and np.linalg rejects float16. Cast the result to the
    # dtype/device of `coords` so torch.bmm gets matching operands.
    matrix_np = matrix.detach().to(device='cpu', dtype=torch.float64).numpy()
    inv_matrix = torch.from_numpy(np.linalg.inv(matrix_np)).to(
        device=coords_homo.device, dtype=coords_homo.dtype)  # b x 3 x 3
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
            It is cast to the dtype of `coords` before use.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([[w, h]]).to(coords)

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    # Cast so torch.bmm does not fail when matrix and coords dtypes differ.
    coords_homo = torch.bmm(
        coords_homo, matrix.to(coords_homo.dtype).transpose(2, 1))  # b x n x 3
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) / warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) / warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    return coords


def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """Build a sampling grid that reads the original image into the warped frame.

    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            Tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape (h, w, ...); its (w, h) is used
            to normalize the grid.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), with h, w from `warped_shape`.
        Values are original-image coordinates normalized as
        `coords / (orig_w, orig_h) * 2 - 1`. This looks intended for
        `F.grid_sample` with `align_corners=False` (NOTE(review): confirm at
        call site).
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1


def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """Build a sampling grid that reads the warped image back into the original frame.

    Args:
        matrix: bx3x3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla
            Tanh-warping, `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to; its (w, h) is
            used to normalize the grid.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y), with h, w from `orig_shape`.
        Values are warped-image coordinates normalized as
        `coords / (w, h) * 2 - 1`.
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1