from typing import List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import (get_uncertain_point_coords_with_randomness,
                                get_uncertainty)
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptConfigType


@MODELS.register_module()
class MaskPointHead(BaseModule):
"""A mask point head use in PointRend. |
|
|
|
``MaskPointHead`` use shared multi-layer perceptron (equivalent to |
|
nn.Conv1d) to predict the logit of input points. The fine-grained feature |
|
and coarse feature will be concatenate together for predication. |
|
|
|
Args: |
|
num_fcs (int): Number of fc layers in the head. Defaults to 3. |
|
in_channels (int): Number of input channels. Defaults to 256. |
|
fc_channels (int): Number of fc channels. Defaults to 256. |
|
num_classes (int): Number of classes for logits. Defaults to 80. |
|
class_agnostic (bool): Whether use class agnostic classification. |
|
If so, the output channels of logits will be 1. Defaults to False. |
|
coarse_pred_each_layer (bool): Whether concatenate coarse feature with |
|
the output of each fc layer. Defaults to True. |
|
conv_cfg (:obj:`ConfigDict` or dict): Dictionary to construct |
|
and config conv layer. Defaults to dict(type='Conv1d')). |
|
norm_cfg (:obj:`ConfigDict` or dict, optional): Dictionary to construct |
|
and config norm layer. Defaults to None. |
|
loss_point (:obj:`ConfigDict` or dict): Dictionary to construct and |
|
config loss layer of point head. Defaults to |
|
dict(type='CrossEntropyLoss', use_mask=True, loss_weight=1.0). |
|
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ |
|
dict], optional): Initialization config dict. |
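
    Example:
        A minimal usage sketch with assumed shapes; building the default
        ``loss_point`` requires mmdet's model registry to be importable.

        >>> head = MaskPointHead(num_classes=80)
        >>> fine_grained_feats = torch.rand(4, 256, 196)
        >>> coarse_feats = torch.rand(4, 80, 196)
        >>> head(fine_grained_feats, coarse_feats).shape
        torch.Size([4, 80, 196])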
    """

    def __init__(
        self,
        num_classes: int,
        num_fcs: int = 3,
        in_channels: int = 256,
        fc_channels: int = 256,
        class_agnostic: bool = False,
        coarse_pred_each_layer: bool = True,
        conv_cfg: ConfigType = dict(type='Conv1d'),
        norm_cfg: OptConfigType = None,
        act_cfg: ConfigType = dict(type='ReLU'),
        loss_point: ConfigType = dict(
            type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
        init_cfg: MultiConfig = dict(
            type='Normal', std=0.001, override=dict(name='fc_logits'))
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.coarse_pred_each_layer = coarse_pred_each_layer
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.loss_point = MODELS.build(loss_point)

        fc_in_channels = in_channels + num_classes
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            fc_in_channels += num_classes if self.coarse_pred_each_layer else 0

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.fc_logits = nn.Conv1d(
            fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, fine_grained_feats: Tensor,
                coarse_feats: Tensor) -> Tensor:
        """Classify each point based on fine-grained and coarse features.

        Args:
            fine_grained_feats (Tensor): Fine-grained feature sampled from
                FPN, shape (num_rois, in_channels, num_points).
            coarse_feats (Tensor): Coarse feature sampled from CoarseMaskHead,
                shape (num_rois, num_classes, num_points).

        Returns:
            Tensor: Point classification results,
            shape (num_rois, num_classes, num_points).
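
        Example:
            A minimal sketch with assumed shapes; with
            ``class_agnostic=True`` the logits have a single channel.

            >>> head = MaskPointHead(num_classes=80, class_agnostic=True)
            >>> fine = torch.rand(4, 256, 196)
            >>> coarse = torch.rand(4, 80, 196)
            >>> head(fine, coarse).shape
            torch.Size([4, 1, 196])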
        """
        x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_feats), dim=1)
        return self.fc_logits(x)

    def get_targets(self, rois: Tensor, rel_roi_points: Tensor,
                    sampling_results: List[SamplingResult],
                    batch_gt_instances: InstanceList,
                    cfg: ConfigType) -> Tensor:
        """Get training targets of MaskPointHead for all images.

        Args:
            rois (Tensor): Region of Interest, shape (num_rois, 5).
            rel_roi_points (Tensor): Points coordinates relative to RoI, shape
                (num_rois, num_points, 2).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                after sampling and assignment.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            cfg (:obj:`ConfigDict` or dict): Training cfg.

        Returns:
            Tensor: Point target, shape (num_rois, num_points).
        """
        num_imgs = len(sampling_results)
        rois_list = []
        rel_roi_points_list = []
        for batch_ind in range(num_imgs):
            inds = (rois[:, 0] == batch_ind)
            rois_list.append(rois[inds])
            rel_roi_points_list.append(rel_roi_points[inds])
        pos_assigned_gt_inds_list = [
            res.pos_assigned_gt_inds for res in sampling_results
        ]
        cfg_list = [cfg for _ in range(num_imgs)]

        point_targets = map(self._get_targets_single, rois_list,
                            rel_roi_points_list, pos_assigned_gt_inds_list,
                            batch_gt_instances, cfg_list)
        point_targets = list(point_targets)

        if len(point_targets) > 0:
            point_targets = torch.cat(point_targets)

        return point_targets

    def _get_targets_single(self, rois: Tensor, rel_roi_points: Tensor,
                            pos_assigned_gt_inds: Tensor,
                            gt_instances: InstanceData,
                            cfg: ConfigType) -> Tensor:
        """Get training target of MaskPointHead for each image."""
        num_pos = rois.size(0)
        num_points = cfg.num_points
        if num_pos > 0:
            gt_masks_th = (
                gt_instances.masks.to_tensor(rois.dtype,
                                             rois.device).index_select(
                                                 0, pos_assigned_gt_inds))
            gt_masks_th = gt_masks_th.unsqueeze(1)
            rel_img_points = rel_roi_point_to_rel_img_point(
                rois, rel_roi_points, gt_masks_th)
            point_targets = point_sample(gt_masks_th,
                                         rel_img_points).squeeze(1)
        else:
            point_targets = rois.new_zeros((0, num_points))
        return point_targets

    def loss_and_target(self, point_pred: Tensor, rel_roi_points: Tensor,
                        sampling_results: List[SamplingResult],
                        batch_gt_instances: InstanceList,
                        cfg: ConfigType) -> dict:
        """Calculate loss for MaskPointHead.

        Args:
            point_pred (Tensor): Point prediction result, shape
                (num_rois, num_classes, num_points).
            rel_roi_points (Tensor): Points coordinates relative to RoI, shape
                (num_rois, num_points, 2).
            sampling_results (list[:obj:`SamplingResult`]): Sampling results
                after sampling and assignment.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            cfg (:obj:`ConfigDict` or dict): Training cfg.

        Returns:
            dict: A dictionary of point loss and point target.
        """
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])

        point_target = self.get_targets(rois, rel_roi_points, sampling_results,
                                        batch_gt_instances, cfg)
        if self.class_agnostic:
            loss_point = self.loss_point(point_pred, point_target,
                                         torch.zeros_like(pos_labels))
        else:
            loss_point = self.loss_point(point_pred, point_target, pos_labels)

        return dict(loss_point=loss_point, point_target=point_target)

    def get_roi_rel_points_train(self, mask_preds: Tensor, labels: Tensor,
                                 cfg: ConfigType) -> Tensor:
        """Get ``num_points`` most uncertain points with random points during
        train.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        the ``get_uncertainty()`` function that takes a point's logit
        prediction as input.

        Args:
            mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            labels (Tensor): The ground truth class for each instance.
            cfg (:obj:`ConfigDict` or dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates of sampled points.
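
        Example:
            An illustrative sketch; the shapes and cfg values below are
            assumptions, and constructing the head needs mmdet's model
            registry.

            >>> from mmengine.config import ConfigDict
            >>> head = MaskPointHead(num_classes=4)
            >>> mask_preds = torch.rand(2, 4, 14, 14)
            >>> labels = torch.tensor([1, 3])
            >>> cfg = ConfigDict(
            ...     num_points=49,
            ...     oversample_ratio=3,
            ...     importance_sample_ratio=0.75)
            >>> head.get_roi_rel_points_train(mask_preds, labels, cfg).shape
            torch.Size([2, 49, 2])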
        """
        point_coords = get_uncertain_point_coords_with_randomness(
            mask_preds, labels, cfg.num_points, cfg.oversample_ratio,
            cfg.importance_sample_ratio)
        return point_coords

    def get_roi_rel_points_test(self, mask_preds: Tensor, label_preds: Tensor,
                                cfg: ConfigType) -> Tuple[Tensor, Tensor]:
        """Get ``num_points`` most uncertain points during test.

        Args:
            mask_preds (Tensor): A tensor of shape (num_rois, num_classes,
                mask_height, mask_width) for class-specific or class-agnostic
                prediction.
            label_preds (Tensor): The predicted class for each instance.
            cfg (:obj:`ConfigDict` or dict): Testing config of point head.

        Returns:
            tuple:

            - point_indices (Tensor): A tensor of shape (num_rois, num_points)
              that contains indices from [0, mask_height x mask_width) of the
              most uncertain points.
            - point_coords (Tensor): A tensor of shape (num_rois, num_points,
              2) that contains [0, 1] x [0, 1] normalized coordinates of the
              most uncertain points from the [mask_height, mask_width] grid.
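
        Example:
            An illustrative sketch; the shapes and cfg value below are
            assumptions, and constructing the head needs mmdet's model
            registry.

            >>> from mmengine.config import ConfigDict
            >>> head = MaskPointHead(num_classes=4)
            >>> mask_preds = torch.rand(2, 4, 16, 16)
            >>> labels = torch.tensor([0, 2])
            >>> cfg = ConfigDict(subdivision_num_points=28)
            >>> inds, coords = head.get_roi_rel_points_test(
            ...     mask_preds, labels, cfg)
            >>> inds.shape, coords.shape
            (torch.Size([2, 28]), torch.Size([2, 28, 2]))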
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = get_uncertainty(mask_preds, label_preds)
        num_rois, _, mask_height, mask_width = uncertainty_map.shape

        # During ONNX export the traced shape values may be tensors rather
        # than Python ints, so keep the step computation compatible with both.
        if isinstance(mask_height, torch.Tensor):
            h_step = 1.0 / mask_height.float()
            w_step = 1.0 / mask_width.float()
        else:
            h_step = 1.0 / mask_height
            w_step = 1.0 / mask_width

        mask_size = int(mask_height * mask_width)
        uncertainty_map = uncertainty_map.view(num_rois, mask_size)
        num_points = min(mask_size, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        # Convert flat indices on the (mask_height, mask_width) grid to
        # normalized (x, y) coordinates at the cell centers.
        xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step
        ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step
        point_coords = torch.stack([xs, ys], dim=2)
        return point_indices, point_coords