# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math
import numpy as np
from typing import Dict, List, Tuple
import fvcore.nn.weight_init as weight_init
import torch
from torch import Tensor, nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, cat, interpolate
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference, mask_rcnn_loss
from detectron2.structures import Boxes

from .point_features import (
    generate_regular_grid_point_coords,
    get_point_coords_wrt_image,
    get_uncertain_point_coords_on_grid,
    get_uncertain_point_coords_with_randomness,
    point_sample,
    point_sample_fine_grained_features,
    sample_point_labels,
)
from .point_head import build_point_head, roi_mask_point_loss


def calculate_uncertainty(logits, classes):
    """
    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in `logits` for
    the foreground class in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic prediction, where R is the total number of predicted masks in all
            images and C is the number of foreground classes. The values are logits.
        classes (list): A list of length R that contains either the predicted or the ground
            truth class for each predicted mask.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if logits.shape[1] == 1:
        gt_class_logits = logits.clone()
    else:
        gt_class_logits = logits[
            torch.arange(logits.shape[0], device=logits.device), classes
        ].unsqueeze(1)
    return -(torch.abs(gt_class_logits))
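
# Illustrative sketch (not used by the module): `calculate_uncertainty` gathers
# each mask's class channel and negates its absolute value, so logits near 0
# (the most ambiguous predictions) receive the highest scores.
#
#   logits = torch.randn(3, 80, 7, 7)               # R=3 masks, C=80 classes
#   classes = torch.tensor([5, 0, 79])               # one class index per mask
#   scores = calculate_uncertainty(logits, classes)  # shape (3, 1, 7, 7)
#   assert scores.max() <= 0  # -|logit| is never positive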


class ConvFCHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks
    analogously to the standard box head.
    """

    _version = 2

    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dim: int, fc_dims: List[int], output_shape: Tuple[int]
    ):
        """
        Args:
            conv_dim: the output dimension of the conv layers
            fc_dims: a list of N>0 integers representing the output dimensions of N FC layers
            output_shape: shape of the output mask prediction
        """
        super().__init__()

        # fmt: off
        input_channels    = input_shape.channels
        input_h           = input_shape.height
        input_w           = input_shape.width
        self.output_shape = output_shape
        # fmt: on

        self.conv_layers = []
        if input_channels > conv_dim:
            self.reduce_channel_dim_conv = Conv2d(
                input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)

        self.reduce_spatial_dim_conv = Conv2d(
            conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        input_dim = conv_dim * input_h * input_w
        input_dim //= 4

        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            fc = nn.Linear(input_dim, fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = fc_dim

        output_dim = int(np.prod(self.output_shape))

        self.prediction = nn.Linear(fc_dims[-1], output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        output_shape = (
            cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
        )
        fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        ret = dict(
            input_shape=input_shape,
            conv_dim=cfg.MODEL.ROI_MASK_HEAD.CONV_DIM,
            fc_dims=[fc_dim] * num_fc,
            output_shape=output_shape,
        )
        return ret

    def forward(self, x):
        N = x.shape[0]
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        output_shape = [N] + list(self.output_shape)
        return self.prediction(x).view(*output_shape)
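
    # Shape walk-through (hedged sketch; 14x14 pooling, conv_dim=256 and
    # fc_dims=[1024, 1024] are assumed config values): an input of shape
    # (N, C, 14, 14) is reduced to (N, 256, 7, 7) by the stride-2 conv,
    # flattened to (N, 256 * 7 * 7), mapped through two FC layers to (N, 1024),
    # then projected and reshaped to (N,) + output_shape, e.g. (N, num_classes, 7, 7).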

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            logger.warning(
                "Weight format of PointRend models has changed! "
                "Applying automatic conversion now ..."
            )
            for k in list(state_dict.keys()):
                newk = k
                if k.startswith(prefix + "coarse_mask_fc"):
                    newk = k.replace(prefix + "coarse_mask_fc", prefix + "fc")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]


@ROI_MASK_HEAD_REGISTRY.register()
class PointRendMaskHead(nn.Module):
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
        # point head
        self._init_point_head(cfg, input_shape)
        # coarse mask head
        self.roi_pooler_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
        self.roi_pooler_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        in_channels = int(np.sum([input_shape[f].channels for f in self.roi_pooler_in_features]))
        self._init_roi_head(
            cfg,
            ShapeSpec(
                channels=in_channels,
                width=self.roi_pooler_size,
                height=self.roi_pooler_size,
            ),
        )

    def _init_roi_head(self, cfg, input_shape):
        self.coarse_head = ConvFCHead(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape):
        # fmt: off
        self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
        if not self.mask_point_on:
            return
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.mask_point_oversample_ratio        = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        # the next three parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_init_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.mask_point_subdivision_steps           = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points      = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

        # An optimization to skip unused subdivision steps: if after subdivision, all pixels on
        # the mask will be selected and recomputed anyway, we should just double our init_resolution
        while (
            4 * self.mask_point_subdivision_init_resolution**2
            <= self.mask_point_subdivision_num_points
        ):
            self.mask_point_subdivision_init_resolution *= 2
            self.mask_point_subdivision_steps -= 1
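
    # Worked example of the loop above (values assumed from typical PointRend
    # configs): with init_resolution=7, subdivision_steps=5 and
    # subdivision_num_points=28 * 28 = 784:
    #   4 * 7**2  = 196  <= 784  ->  init_resolution=14, steps=4
    #   4 * 14**2 = 784  <= 784  ->  init_resolution=28, steps=3
    #   4 * 28**2 = 3136 >  784  ->  stop
    # The first two subdivision steps would have recomputed every pixel anyway,
    # so inference starts directly from a 28x28 grid with the remaining 3 steps.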

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_mask": mask_rcnn_loss(coarse_mask, instances)}
            if not self.mask_point_on:
                return losses

            point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, coarse_mask
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, coarse_mask, instances)

    def _roi_pooler(self, features: Dict[str, Tensor], boxes: List[Boxes]):
        """
        Extract per-box features. This is similar to RoIAlign(sampling_ratio=1) except:
        1. It's implemented by point_sample
        2. It pools features across all levels and concatenates them, while typically
           RoIAlign selects one level for every box. However, in the config we only use
           one level (p2), so there is no difference.

        Returns:
            Tensor of shape (R, C, pooler_size, pooler_size) where R is the total number of boxes
        """
        features_list = [features[k] for k in self.roi_pooler_in_features]
        features_scales = [self._feature_scales[k] for k in self.roi_pooler_in_features]

        num_boxes = sum(x.tensor.size(0) for x in boxes)
        output_size = self.roi_pooler_size
        point_coords = generate_regular_grid_point_coords(num_boxes, output_size, boxes[0].device)
        # For regular grids of points, this function is equivalent to `len(features_list)` calls
        # of `ROIAlign` (with `SAMPLING_RATIO=1`), concatenating the results.
        roi_features, _ = point_sample_fine_grained_features(
            features_list, features_scales, boxes, point_coords
        )
        return roi_features.view(num_boxes, roi_features.shape[1], output_size, output_size)
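
    # Illustrative shapes (hedged, with roi_pooler_size=14 assumed): for R boxes
    # total across the batch, `point_coords` has shape (R, 14 * 14, 2) holding a
    # regular grid of normalized coordinates inside each box, and the returned
    # tensor has shape (R, C, 14, 14), where C is the channel sum over the
    # pooled feature levels.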

    def _sample_train_points(self, coarse_mask, instances):
        assert self.training
        gt_classes = cat([x.gt_classes for x in instances])
        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                coarse_mask,
                lambda logits: calculate_uncertainty(logits, gt_classes),
                self.mask_point_train_num_points,
                self.mask_point_oversample_ratio,
                self.mask_point_importance_sample_ratio,
            )
            # sample point_labels
            proposal_boxes = [x.proposal_boxes for x in instances]
            cat_boxes = Boxes.cat(proposal_boxes)
            point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
            point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels
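
    # Worked example of the biased sampling above (config values assumed from
    # the PointRend defaults): with train_num_points=14 * 14 = 196,
    # oversample_ratio=3 and importance_sample_ratio=0.75,
    # `get_uncertain_point_coords_with_randomness` draws 196 * 3 = 588 random
    # candidates per box, keeps the 0.75 * 196 = 147 most uncertain ones, and
    # fills the remaining 49 with fresh uniform samples, yielding
    # `point_coords` of shape (R, 196, 2).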

    def _point_pooler(self, features, proposal_boxes, point_coords):
        point_features_list = [features[k] for k in self.mask_point_in_features]
        point_features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]
        # sample image-level features
        point_fine_grained_features, _ = point_sample_fine_grained_features(
            point_features_list, point_features_scales, proposal_boxes, point_coords
        )
        return point_fine_grained_features

    def _get_point_logits(self, point_fine_grained_features, point_coords, coarse_mask):
        coarse_features = point_sample(coarse_mask, point_coords, align_corners=False)
        point_logits = self.point_head(point_fine_grained_features, coarse_features)
        return point_logits

    def _subdivision_inference(self, features, mask_representations, instances):
        assert not self.training

        pred_boxes = [x.pred_boxes for x in instances]
        pred_classes = cat([x.pred_classes for x in instances])

        mask_logits = None
        # +1 here to include an initial step to generate the coarsest mask
        # prediction with init_resolution, when mask_logits is None.
        # We compute the initial mask by sampling on a regular grid. coarse_mask
        # can be used as the initial mask as well, but it's typically very low-res,
        # so it will be completely overwritten during subdivision anyway.
        for _ in range(self.mask_point_subdivision_steps + 1):
            if mask_logits is None:
                point_coords = generate_regular_grid_point_coords(
                    pred_classes.size(0),
                    self.mask_point_subdivision_init_resolution,
                    pred_boxes[0].device,
                )
            else:
                mask_logits = interpolate(
                    mask_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points
                )

            # Run the point head for every point in point_coords
            fine_grained_features = self._point_pooler(features, pred_boxes, point_coords)
            point_logits = self._get_point_logits(
                fine_grained_features, point_coords, mask_representations
            )

            if mask_logits is None:
                # Create initial mask_logits using point_logits on this regular grid
                R, C, _ = point_logits.shape
                mask_logits = point_logits.reshape(
                    R,
                    C,
                    self.mask_point_subdivision_init_resolution,
                    self.mask_point_subdivision_init_resolution,
                )
                # The subdivision code will fail with an empty list of boxes
                if len(pred_classes) == 0:
                    mask_rcnn_inference(mask_logits, instances)
                    return instances
            else:
                # Put point predictions to the right places on the upsampled grid.
                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (
                    mask_logits.reshape(R, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(R, C, H, W)
                )
        mask_rcnn_inference(mask_logits, instances)
        return instances
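
    # End-to-end sketch of the loop above (values assumed from typical PointRend
    # configs): with init_resolution=28 (after the doubling optimization in
    # `_init_point_head`), subdivision_steps=3 and subdivision_num_points=784,
    # the first iteration predicts a dense 28x28 grid; each later iteration
    # bilinearly upsamples the logits 2x (28 -> 56 -> 112 -> 224) and re-predicts
    # only the 784 most uncertain points at that resolution, so a 224x224 mask
    # costs 28 * 28 + 3 * 784 point evaluations instead of 224 * 224.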


@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendMaskHead(PointRendMaskHead):
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__(cfg, input_shape)

    def _init_roi_head(self, cfg, input_shape):
        assert hasattr(self, "num_params"), "Please initialize point_head first!"
        self.parameter_head = ConvFCHead(cfg, input_shape, output_shape=(self.num_params,))
        self.regularizer = cfg.MODEL.IMPLICIT_POINTREND.PARAMS_L2_REGULARIZER

    def _init_point_head(self, cfg, input_shape):
        # fmt: off
        self.mask_point_on = True  # always on
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features      = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        # the next two parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_steps      = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
        self.num_params = self.point_head.num_params

        # inference parameters
        self.mask_point_subdivision_init_resolution = int(
            math.sqrt(self.mask_point_subdivision_num_points)
        )
        assert (
            self.mask_point_subdivision_init_resolution
            * self.mask_point_subdivision_init_resolution
            == self.mask_point_subdivision_num_points
        )
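
    # Note on the assertion above: the implicit variant starts inference from a
    # regular grid of subdivision_num_points points, so that number must be a
    # perfect square, e.g. 28 * 28 = 784 -> init_resolution 28; a value such as
    # 800 would trip the assertion.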

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_l2": self.regularizer * (parameters**2).mean()}

            point_coords, point_labels = self._uniform_sample_train_points(instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, parameters
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, parameters, instances)

    def _uniform_sample_train_points(self, instances):
        assert self.training
        proposal_boxes = [x.proposal_boxes for x in instances]
        cat_boxes = Boxes.cat(proposal_boxes)
        # uniform sample
        point_coords = torch.rand(
            len(cat_boxes), self.mask_point_train_num_points, 2, device=cat_boxes.tensor.device
        )
        # sample point_labels
        point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
        point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _get_point_logits(self, fine_grained_features, point_coords, parameters):
        return self.point_head(fine_grained_features, point_coords, parameters)
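

# Hedged usage sketch: both heads are exposed through the detectron2 mask head
# registry, so a config selects one by name. `add_pointrend_config` is assumed
# to be the project-level helper that registers the POINT_HEAD options.
#
#   from detectron2.config import get_cfg
#   from point_rend import add_pointrend_config
#
#   cfg = get_cfg()
#   add_pointrend_config(cfg)
#   cfg.MODEL.ROI_MASK_HEAD.NAME = "PointRendMaskHead"  # or "ImplicitPointRendMaskHead"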