# Copyright (c) Facebook, Inc. and its affiliates.
import logging | |
import math | |
import numpy as np | |
from typing import Dict, List, Tuple | |
import fvcore.nn.weight_init as weight_init | |
import torch | |
from torch import Tensor, nn | |
from torch.nn import functional as F | |
from detectron2.config import configurable | |
from detectron2.layers import Conv2d, ShapeSpec, cat, interpolate | |
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY | |
from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference, mask_rcnn_loss | |
from detectron2.structures import Boxes | |
from .point_features import ( | |
generate_regular_grid_point_coords, | |
get_point_coords_wrt_image, | |
get_uncertain_point_coords_on_grid, | |
get_uncertain_point_coords_with_randomness, | |
point_sample, | |
point_sample_fine_grained_features, | |
sample_point_labels, | |
) | |
from .point_head import build_point_head, roi_mask_point_loss | |
def calculate_uncertainty(logits, classes):
    """
    Estimate uncertainty as the negative L1 distance between 0.0 and the logit
    prediction in `logits` for the foreground class given in `classes`.

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) for class-specific
            prediction or (R, 1, ...) for class-agnostic prediction, where R is
            the total number of predicted masks in all images and C is the
            number of foreground classes. The values are logits.
        classes (list): A list of length R containing either the predicted or
            the ground-truth class for each mask; unused in the class-agnostic
            case.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) of uncertainty scores;
            the most uncertain locations carry the highest score.
    """
    num_channels = logits.shape[1]
    if num_channels == 1:
        # Class-agnostic head: the single channel already holds the
        # foreground logit for every mask.
        selected_logits = logits.clone()
    else:
        # Pick, per mask, the channel of its (predicted or GT) class.
        row_indices = torch.arange(logits.shape[0], device=logits.device)
        selected_logits = logits[row_indices, classes].unsqueeze(1)
    # A logit near 0 means the model is least certain -> highest score.
    return -torch.abs(selected_logits)
class ConvFCHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
    to the standard box head.
    """

    _version = 2

    # `@configurable` lets callers construct this head either with explicit
    # keyword arguments or as `ConvFCHead(cfg, input_shape, ...)`, routing the
    # cfg through `from_config` (this is how the heads in this file call it).
    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dim: int, fc_dims: List[int], output_shape: Tuple[int]
    ):
        """
        Args:
            input_shape: shape (channels/height/width) of the pooled input features
            conv_dim: the output dimension of the conv layers
            fc_dims: a list of N>0 integers representing the output dimensions of N FC layers
            output_shape: shape of the output mask prediction
        """
        super().__init__()

        # fmt: off
        input_channels    = input_shape.channels
        input_h           = input_shape.height
        input_w           = input_shape.width
        self.output_shape = output_shape
        # fmt: on

        self.conv_layers = []
        if input_channels > conv_dim:
            # 1x1 conv to shrink the channel dimension before flattening.
            self.reduce_channel_dim_conv = Conv2d(
                input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)

        # Stride-2 conv halves each spatial dimension (hence input_dim //= 4 below).
        self.reduce_spatial_dim_conv = Conv2d(
            conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        input_dim = conv_dim * input_h * input_w
        input_dim //= 4

        self.fcs = []
        for k, fc_dim in enumerate(fc_dims):
            fc = nn.Linear(input_dim, fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = fc_dim

        output_dim = int(np.prod(self.output_shape))

        self.prediction = nn.Linear(fc_dims[-1], output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Build the constructor kwargs from a detectron2 config node."""
        output_shape = (
            cfg.MODEL.ROI_HEADS.NUM_CLASSES,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
            cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION,
        )
        fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        ret = dict(
            input_shape=input_shape,
            conv_dim=cfg.MODEL.ROI_MASK_HEAD.CONV_DIM,
            fc_dims=[fc_dim] * num_fc,
            output_shape=output_shape,
        )
        return ret

    def forward(self, x):
        """
        Args:
            x (Tensor): pooled features of shape (N, C, H, W)
        Returns:
            Tensor of shape (N, *self.output_shape) with mask logits.
        """
        N = x.shape[0]
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        output_shape = [N] + list(self.output_shape)
        return self.prediction(x).view(*output_shape)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Checkpoints saved before _version 2 named the FC layers
        # "coarse_mask_fc*"; rename them on the fly so old weights still load.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            logger.warning(
                "Weight format of PointRend models have changed! "
                "Applying automatic conversion now ..."
            )
            for k in list(state_dict.keys()):
                newk = k
                if k.startswith(prefix + "coarse_mask_fc"):
                    newk = k.replace(prefix + "coarse_mask_fc", prefix + "fc")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
@ROI_MASK_HEAD_REGISTRY.register()
class PointRendMaskHead(nn.Module):
    """
    PointRend mask head: a coarse mask head (`ConvFCHead`) predicts a low-res
    mask, and a point head refines logits at adaptively selected points —
    uncertainty-sampled points during training, an iterative subdivision grid
    during inference.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        # Map feature name (e.g. "p2") -> scale of that feature map w.r.t. the image.
        self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
        # point head (must not assume roi-head attributes exist yet)
        self._init_point_head(cfg, input_shape)
        # coarse mask head
        self.roi_pooler_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
        self.roi_pooler_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        # Channels are summed because pooled features from all levels are concatenated.
        in_channels = int(np.sum([input_shape[f].channels for f in self.roi_pooler_in_features]))
        self._init_roi_head(
            cfg,
            ShapeSpec(
                channels=in_channels,
                width=self.roi_pooler_size,
                height=self.roi_pooler_size,
            ),
        )

    def _init_roi_head(self, cfg, input_shape):
        # Coarse head consumes pooled RoI features and emits low-res mask logits.
        self.coarse_head = ConvFCHead(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape):
        # fmt: off
        self.mask_point_on                      = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
        if not self.mask_point_on:
            return
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.mask_point_oversample_ratio        = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        # next three parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_init_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.mask_point_subdivision_steps       = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points  = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))

        # An optimization to skip unused subdivision steps: if after subdivision, all pixels on
        # the mask will be selected and recomputed anyway, we should just double our init_resolution
        while (
            4 * self.mask_point_subdivision_init_resolution**2
            <= self.mask_point_subdivision_num_points
        ):
            self.mask_point_subdivision_init_resolution *= 2
            self.mask_point_subdivision_steps -= 1

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            In training: a dict of losses ("loss_mask" and, when the point
            head is on, "loss_mask_point"). In inference: `instances` with
            `pred_masks` filled in.
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_mask": mask_rcnn_loss(coarse_mask, instances)}
            if not self.mask_point_on:
                return losses

            point_coords, point_labels = self._sample_train_points(coarse_mask, instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, coarse_mask
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            coarse_mask = self.coarse_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, coarse_mask, instances)

    def _roi_pooler(self, features: List[Tensor], boxes: List[Boxes]):
        """
        Extract per-box feature. This is similar to RoIAlign(sampling_ratio=1) except:
        1. It's implemented by point_sample
        2. It pools features across all levels and concat them, while typically
           RoIAlign select one level for every box. However in the config we only use
           one level (p2) so there is no difference.

        Returns:
            Tensor of shape (R, C, pooler_size, pooler_size) where R is the total number of boxes
        """
        features_list = [features[k] for k in self.roi_pooler_in_features]
        features_scales = [self._feature_scales[k] for k in self.roi_pooler_in_features]

        num_boxes = sum(x.tensor.size(0) for x in boxes)
        output_size = self.roi_pooler_size
        point_coords = generate_regular_grid_point_coords(num_boxes, output_size, boxes[0].device)
        # For regular grids of points, this function is equivalent to `len(features_list)' calls
        # of `ROIAlign` (with `SAMPLING_RATIO=1`), and concat the results.
        roi_features, _ = point_sample_fine_grained_features(
            features_list, features_scales, boxes, point_coords
        )
        return roi_features.view(num_boxes, roi_features.shape[1], output_size, output_size)

    def _sample_train_points(self, coarse_mask, instances):
        """Sample training point coords (box-normalized) and their GT labels."""
        assert self.training
        gt_classes = cat([x.gt_classes for x in instances])
        with torch.no_grad():
            # sample point_coords: biased toward uncertain locations of the coarse mask
            point_coords = get_uncertain_point_coords_with_randomness(
                coarse_mask,
                lambda logits: calculate_uncertainty(logits, gt_classes),
                self.mask_point_train_num_points,
                self.mask_point_oversample_ratio,
                self.mask_point_importance_sample_ratio,
            )
            # sample point_labels from GT masks at image-space coordinates
            proposal_boxes = [x.proposal_boxes for x in instances]
            cat_boxes = Boxes.cat(proposal_boxes)
            point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
            point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _point_pooler(self, features, proposal_boxes, point_coords):
        """Sample fine-grained image-level features at the given point coords."""
        point_features_list = [features[k] for k in self.mask_point_in_features]
        point_features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]
        # sample image-level features
        point_fine_grained_features, _ = point_sample_fine_grained_features(
            point_features_list, point_features_scales, proposal_boxes, point_coords
        )
        return point_fine_grained_features

    def _get_point_logits(self, point_fine_grained_features, point_coords, coarse_mask):
        # Condition the point head on the coarse prediction sampled at the same points.
        coarse_features = point_sample(coarse_mask, point_coords, align_corners=False)
        point_logits = self.point_head(point_fine_grained_features, coarse_features)
        return point_logits

    def _subdivision_inference(self, features, mask_representations, instances):
        """Iteratively upsample the mask and re-predict its most uncertain points."""
        assert not self.training

        pred_boxes = [x.pred_boxes for x in instances]
        pred_classes = cat([x.pred_classes for x in instances])

        mask_logits = None
        # +1 here to include an initial step to generate the coarsest mask
        # prediction with init_resolution, when mask_logits is None.
        # We compute initial mask by sampling on a regular grid. coarse_mask
        # can be used as initial mask as well, but it's typically very low-res
        # so it will be completely overwritten during subdivision anyway.
        for _ in range(self.mask_point_subdivision_steps + 1):
            if mask_logits is None:
                point_coords = generate_regular_grid_point_coords(
                    pred_classes.size(0),
                    self.mask_point_subdivision_init_resolution,
                    pred_boxes[0].device,
                )
            else:
                mask_logits = interpolate(
                    mask_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points
                )

            # Run the point head for every point in point_coords
            fine_grained_features = self._point_pooler(features, pred_boxes, point_coords)
            point_logits = self._get_point_logits(
                fine_grained_features, point_coords, mask_representations
            )

            if mask_logits is None:
                # Create initial mask_logits using point_logits on this regular grid
                R, C, _ = point_logits.shape
                mask_logits = point_logits.reshape(
                    R,
                    C,
                    self.mask_point_subdivision_init_resolution,
                    self.mask_point_subdivision_init_resolution,
                )
                # The subdivision code will fail with the empty list of boxes
                if len(pred_classes) == 0:
                    mask_rcnn_inference(mask_logits, instances)
                    return instances
            else:
                # Put point predictions to the right places on the upsampled grid.
                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (
                    mask_logits.reshape(R, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(R, C, H, W)
                )
        mask_rcnn_inference(mask_logits, instances)
        return instances
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendMaskHead(PointRendMaskHead):
    """
    Implicit PointRend mask head: instead of a coarse mask, the RoI head
    (`parameter_head`) predicts per-instance point-head parameters; masks are
    then produced purely by evaluating the point head at sampled points.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__(cfg, input_shape)

    def _init_roi_head(self, cfg, input_shape):
        # The parameter head's output size depends on the point head, which the
        # parent class initializes first (see PointRendMaskHead.__init__).
        assert hasattr(self, "num_params"), "Please initialize point_head first!"
        self.parameter_head = ConvFCHead(cfg, input_shape, output_shape=(self.num_params,))
        # L2 penalty weight on the predicted parameters.
        self.regularizer = cfg.MODEL.IMPLICIT_POINTREND.PARAMS_L2_REGULARIZER

    def _init_point_head(self, cfg, input_shape):
        # fmt: off
        self.mask_point_on                      = True  # always on
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features             = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points        = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        # next two parameters are used in the adaptive subdivision inference procedure
        self.mask_point_subdivision_steps       = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points  = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        in_channels = int(np.sum([input_shape[f].channels for f in self.mask_point_in_features]))
        self.point_head = build_point_head(cfg, ShapeSpec(channels=in_channels, width=1, height=1))
        self.num_params = self.point_head.num_params

        # inference parameters: the initial grid must use exactly
        # subdivision_num_points points, so its side must be a perfect square root.
        self.mask_point_subdivision_init_resolution = int(
            math.sqrt(self.mask_point_subdivision_num_points)
        )
        assert (
            self.mask_point_subdivision_init_resolution
            * self.mask_point_subdivision_init_resolution
            == self.mask_point_subdivision_num_points
        )

    def forward(self, features, instances):
        """
        Args:
            features (dict[str, Tensor]): a dict of image-level features
            instances (list[Instances]): proposals in training; detected
                instances in inference

        Returns:
            In training: dict with "loss_l2" (parameter regularizer) and
            "loss_mask_point". In inference: `instances` with `pred_masks`.
        """
        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, proposal_boxes))
            losses = {"loss_l2": self.regularizer * (parameters**2).mean()}

            point_coords, point_labels = self._uniform_sample_train_points(instances)
            point_fine_grained_features = self._point_pooler(features, proposal_boxes, point_coords)
            point_logits = self._get_point_logits(
                point_fine_grained_features, point_coords, parameters
            )
            losses["loss_mask_point"] = roi_mask_point_loss(point_logits, instances, point_labels)
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            parameters = self.parameter_head(self._roi_pooler(features, pred_boxes))
            return self._subdivision_inference(features, parameters, instances)

    def _uniform_sample_train_points(self, instances):
        """Sample points uniformly at random in each box (no uncertainty bias)."""
        assert self.training
        proposal_boxes = [x.proposal_boxes for x in instances]
        cat_boxes = Boxes.cat(proposal_boxes)
        # uniform sample in box-normalized [0, 1) x [0, 1) coordinates
        point_coords = torch.rand(
            len(cat_boxes), self.mask_point_train_num_points, 2, device=cat_boxes.tensor.device
        )
        # sample point_labels from GT masks at image-space coordinates
        point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords)
        point_labels = sample_point_labels(instances, point_coords_wrt_image)
        return point_coords, point_labels

    def _get_point_logits(self, fine_grained_features, point_coords, parameters):
        # Unlike the parent, the point head is conditioned on predicted
        # per-instance parameters rather than coarse-mask samples.
        return self.point_head(fine_grained_features, point_coords, parameters)