# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa

from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair


def _clip_index(index: Tensor, upper: int) -> Tensor:
    """Clip an integer index tensor to the closed range ``[0, upper]``.

    The bounds are created on ``index.device`` so that ``torch.where`` never
    mixes devices. ``torch.where`` is kept (rather than switching to another
    op) so the emitted operations stay the same as before.
    """
    lower_bound = torch.tensor(0, device=index.device)
    upper_bound = torch.tensor(upper, device=index.device)
    index = torch.where(index < 0, lower_bound, index)
    return torch.where(index > upper, upper_bound, index)


def bilinear_grid_sample(im: Tensor,
                         grid: Tensor,
                         align_corners: bool = False) -> Tensor:
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid. Supported only bilinear
    interpolation method to sample the input pixels.

    Locations that fall outside the input are sampled from an implicit
    one-pixel zero border, i.e. zero padding is used.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Map normalized coordinates in [-1, 1] to (fractional) pixel indices.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    # Integer corners surrounding every sampling location.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights; they must be computed before the indices are shifted
    # and clipped below.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply default for grid_sample function zero padding
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2

    # Shift the corner indices by the padding width and clip them to the
    # padded image. Every out-of-range corner lands on the zero border.
    x0 = _clip_index(x0 + 1, padded_w - 1)
    x1 = _clip_index(x1 + 1, padded_w - 1)
    y0 = _clip_index(y0 + 1, padded_h - 1)
    y1 = _clip_index(y1 + 1, padded_h - 1)

    # Gather the four corner values from the flattened padded image.
    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)


def normalize(grid: Tensor) -> Tensor:
    """Normalize input grid from [-1, 1] to [0, 1]

    Args:
        grid (torch.Tensor): The grid to be normalized, range [-1, 1].

    Returns:
        torch.Tensor: Normalized grid, range [0, 1].
    """
    return (grid + 1.0) / 2.0


def denormalize(grid: Tensor) -> Tensor:
    """Denormalize input grid from range [0, 1] to [-1, 1]

    Args:
        grid (torch.Tensor): The grid to be denormalized, range [0, 1].

    Returns:
        torch.Tensor: Denormalized grid, range [-1, 1].
    """
    return grid * 2.0 - 1.0


def generate_grid(num_grid: int, size: Tuple[int, int],
                  device: torch.device) -> Tensor:
    """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
    space.

    Args:
        num_grid (int): The number of grids to sample, one for each region.
        size (tuple[int, int]): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        torch.Tensor: A tensor of shape (num_grid, size[0]*size[1], 2) that
        contains coordinates for the regular grids. The last dimension is
        ordered (x, y). The result is an expanded view, so all ``num_grid``
        entries share the same memory.
    """
    # An identity affine transform yields the pixel-center grid in [-1, 1].
    affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
    grid = F.affine_grid(
        affine_trans, torch.Size((1, 1, *size)), align_corners=False)
    grid = normalize(grid)
    return grid.view(1, -1, 2).expand(num_grid, -1, -1)


def rel_roi_point_to_abs_img_point(rois: Tensor,
                                   rel_roi_points: Tensor) -> Tensor:
    """Convert roi based relative point coordinates to image based absolute
    point coordinates.

    The conversion runs under ``torch.no_grad()``, so no gradient flows
    through the returned coordinates.

    Args:
        rois (torch.Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
        rel_roi_points (torch.Tensor): Point coordinates inside RoI, relative
            to RoI, location, range (0, 1), shape (N, P, 2)

    Returns:
        torch.Tensor: Image based absolute point coordinates, shape (N, P, 2)
    """
    with torch.no_grad():
        assert rel_roi_points.size(0) == rois.size(0)
        assert rois.dim() == 2
        assert rel_roi_points.dim() == 3
        assert rel_roi_points.size(2) == 2
        # remove batch idx
        if rois.size(1) == 5:
            rois = rois[:, 1:]
        # To avoid an error during exporting to onnx use independent
        # variables and out-of-place arithmetic only.
        roi_w = rois[:, None, 2] - rois[:, None, 0]
        roi_h = rois[:, None, 3] - rois[:, None, 1]
        xs = rel_roi_points[:, :, 0] * roi_w + rois[:, None, 0]
        ys = rel_roi_points[:, :, 1] * roi_h + rois[:, None, 1]
        abs_img_points = torch.stack([xs, ys], dim=2)
    return abs_img_points


def get_shape_from_feature_map(x: Tensor) -> Tensor:
    """Get spatial resolution of input feature map considering exporting to
    onnx mode.

    Args:
        x (torch.Tensor): Input tensor, shape (N, C, H, W)

    Returns:
        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
    """
    # x.shape[2:] is (H, W); flip it so the result is ordered (W, H).
    img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1,
                                                       2).to(x.device).float()
    return img_shape


def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
                                   img: Union[tuple, Tensor],
                                   spatial_scale: float = 1.) -> Tensor:
    """Convert image based absolute point coordinates to image based relative
    coordinates for sampling.

    Args:
        abs_img_points (torch.Tensor): Image based absolute point coordinates,
            shape (N, P, 2)
        img (tuple or torch.Tensor): (height, width) of image or feature map.
            A tensor must be 4-dimensional; its last two dimensions are used
            as (height, width).
        spatial_scale (float, optional): Scale points by this factor.
            Default: 1.

    Returns:
        Tensor: Image based relative point coordinates for sampling, shape
        (N, P, 2).
    """
    assert (isinstance(img, tuple) and len(img) == 2) or \
           (isinstance(img, torch.Tensor) and len(img.shape) == 4)

    if isinstance(img, tuple):
        h, w = img
        # Points are (x, y), so the divisor is ordered (w, h).
        scale = torch.tensor([w, h],
                             dtype=torch.float,
                             device=abs_img_points.device)
        scale = scale.view(1, 1, 2)
    else:
        scale = get_shape_from_feature_map(img)

    return abs_img_points / scale * spatial_scale


def rel_roi_point_to_rel_img_point(rois: Tensor,
                                   rel_roi_points: Tensor,
                                   img: Union[tuple, Tensor],
                                   spatial_scale: float = 1.) -> Tensor:
    """Convert roi based relative point coordinates to image based relative
    point coordinates.

    Args:
        rois (torch.Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
        rel_roi_points (torch.Tensor): Point coordinates inside RoI, relative
            to RoI, location, range (0, 1), shape (N, P, 2)
        img (tuple or torch.Tensor): (height, width) of image or feature map.
        spatial_scale (float, optional): Scale points by this factor.
            Default: 1.

    Returns:
        torch.Tensor: Image based relative point coordinates for sampling,
        shape (N, P, 2).
    """
    abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
    rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
                                                   spatial_scale)
    return rel_img_point


def point_sample(input: Tensor,
                 points: Tensor,
                 align_corners: bool = False,
                 **kwargs) -> Tensor:
    """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
    Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
    lie inside ``[0, 1] x [0, 1]`` square.

    Args:
        input (torch.Tensor): Feature map, shape (N, C, H, W).
        points (torch.Tensor): Image based absolute point coordinates
            (normalized), range [0, 1] x [0, 1], shape (N, P, 2) or
            (N, Hgrid, Wgrid, 2).
        align_corners (bool, optional): Whether align_corners.
            Default: False
        **kwargs: Forwarded unchanged to
            :func:`torch.nn.functional.grid_sample`.

    Returns:
        torch.Tensor: Features of `point` on `input`, shape (N, C, P) or
        (N, C, Hgrid, Wgrid).
    """
    add_dim = False
    if points.dim() == 3:
        # grid_sample needs a 4D grid: treat the P points as a (P, 1) grid.
        add_dim = True
        points = points.unsqueeze(2)
    output = F.grid_sample(
        input, denormalize(points), align_corners=align_corners, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


class SimpleRoIAlign(nn.Module):
    """Simple RoI align in PointRend, faster than standard RoIAlign.

    Every RoI is sampled on a regular ``output_size`` grid of points with
    bilinear interpolation.
    """

    def __init__(self,
                 output_size: Tuple[int],
                 spatial_scale: float,
                 aligned: bool = True) -> None:
        """Simple RoI align in PointRend, faster than standard RoIAlign.

        Args:
            output_size (tuple[int]): h, w
            spatial_scale (float): scale the input boxes by this number
            aligned (bool): if False, use the legacy implementation in
                MMDetection, align_corners=True will be used in F.grid_sample.
                If True, align the results more perfectly.
        """

        super().__init__()
        self.output_size = _pair(output_size)
        self.spatial_scale = float(spatial_scale)
        # to be consistent with other RoI ops
        self.use_torchvision = False
        self.aligned = aligned

    def forward(self, features: Tensor, rois: Tensor) -> Tensor:
        """Extract a fixed-size feature patch for every RoI.

        Args:
            features (torch.Tensor): Feature maps, shape (B, C, H, W).
            rois (torch.Tensor): RoIs, shape (K, 5). Column 0 holds the index
                of the image each RoI belongs to; the remaining columns are
                the box passed to ``rel_roi_point_to_rel_img_point``.

        Returns:
            torch.Tensor: RoI features, shape (K, C, h, w), where row ``i``
            corresponds to ``rois[i]``. An empty RoI set yields a tensor of
            shape (0, C, h, w).
        """
        num_imgs = features.size(0)
        num_rois = rois.size(0)
        channels = features.size(1)

        # torch.cat cannot handle an empty list, so answer the "no RoIs"
        # case directly.
        if num_rois == 0:
            return features.new_zeros((0, channels, *self.output_size))

        rel_roi_points = generate_grid(
            num_rois, self.output_size, device=rois.device)
        batch_inds = rois[:, 0].long()

        point_feats = []
        roi_ids = []
        for batch_ind in range(num_imgs):
            # unravel batch dim
            feat = features[batch_ind].unsqueeze(0)
            inds = (batch_inds == batch_ind)
            if inds.any():
                rel_img_points = rel_roi_point_to_rel_img_point(
                    rois[inds], rel_roi_points[inds], feat,
                    self.spatial_scale).unsqueeze(0)
                point_feat = point_sample(
                    feat, rel_img_points, align_corners=not self.aligned)
                # (1, C, k, P) -> (k, C, P)
                point_feat = point_feat.squeeze(0).transpose(0, 1)
                point_feats.append(point_feat)
                # Remember which RoIs these rows belong to.
                roi_ids.append(
                    torch.nonzero(inds, as_tuple=False).squeeze(1))

        point_feats_t = torch.cat(point_feats, dim=0)
        # Features were gathered image by image. Undo that grouping so the
        # output rows line up with the input RoIs even when the RoIs are not
        # sorted by batch index (a no-op for sorted input).
        restore_order = torch.argsort(torch.cat(roi_ids))
        point_feats_t = point_feats_t[restore_order]

        roi_feats = point_feats_t.reshape(num_rois, channels,
                                          *self.output_size)

        return roi_feats

    def __repr__(self) -> str:
        """Return a constructor-like description of the module."""
        format_str = self.__class__.__name__
        format_str += '(output_size={}, spatial_scale={}, aligned={})'.format(
            self.output_size, self.spatial_scale, self.aligned)
        return format_str