import copy
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict
from mmcv.cnn import normal_init
from mmcv.ops import batched_nms
from ..builder import HEADS
from .anchor_head import AnchorHead
from .rpn_test_mixin import RPNTestMixin
@HEADS.register_module()
class RPNHead(RPNTestMixin, AnchorHead):
    """RPN head.

    A shared 3x3 conv (followed by ReLU) feeds two sibling 1x1 convs: one
    predicts ``num_anchors * cls_out_channels`` objectness channels and the
    other ``num_anchors * 4`` box-regression channels per location.

    Args:
        in_channels (int): Number of channels in the input feature map.
    """  # noqa: W605

    def __init__(self, in_channels, **kwargs):
        # The RPN only separates foreground from background, so the parent
        # head is built with a single class (first positional argument).
        super(RPNHead, self).__init__(1, in_channels, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        """Initialize weights of the head."""
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        """Forward feature map of a single scale level.

        Returns:
            tuple[Tensor, Tensor]: Raw classification scores and box
                regression outputs for this level.
        """
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with keys
                ``loss_rpn_cls`` and ``loss_rpn_bbox``.
        """
        # The RPN has no class labels, so ``gt_labels`` is passed as None.
        losses = super(RPNHead, self).loss(
            cls_scores,
            bbox_preds,
            gt_bboxes,
            None,
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        # Rename the generic keys, presumably so RPN losses stay distinct
        # from a second-stage head's losses when both are logged.
        return dict(
            loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

    def _get_bboxes(self,
                    cls_scores,
                    bbox_preds,
                    mlvl_anchors,
                    img_shapes,
                    scale_factors,
                    cfg,
                    rescale=False):
        """Transform batched multi-level outputs into proposals.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            mlvl_anchors (list[Tensor]): Box reference for each scale level
                with shape (num_total_anchors, 4).
            img_shapes (list[tuple[int]]): Shape of the input image,
                (height, width, 3).
            scale_factors (list[ndarray]): Scale factor of the image arranged
                as (w_scale, h_scale, w_scale, h_scale). Not used by this
                implementation.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): Not used by this implementation; no rescaling
                is applied to the proposals.

        Returns:
            list[Tensor]: One tensor per image of shape (n, 5). The first 4
                columns are (tl_x, tl_y, br_x, br_y) and the 5th column is
                the proposal score. At most ``cfg.max_per_img`` rows are kept.
        """
        cfg = self.test_cfg if cfg is None else cfg
        # Work on a copy so the deprecation shims never mutate the caller's
        # config.
        cfg = copy.deepcopy(cfg)
        # bboxes from different level should be independent during NMS,
        # level_ids are used as labels for batched NMS to separate them
        level_ids = []
        mlvl_scores = []
        mlvl_bbox_preds = []
        mlvl_valid_anchors = []
        batch_size = cls_scores[0].shape[0]
        nms_pre_tensor = torch.tensor(
            cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long)
        for idx, (rpn_cls_score,
                  rpn_bbox_pred) in enumerate(zip(cls_scores, bbox_preds)):
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(batch_size, -1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(batch_size, -1, 2)
                # We set FG labels to [0, num_class-1] and BG label to
                # num_class in RPN head since mmdet v2.5, which is unified to
                # be consistent with other head since mmdet v2.0. In mmdet v2.0
                # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
                scores = rpn_cls_score.softmax(-1)[..., 0]
            rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
                batch_size, -1, 4)
            anchors = mlvl_anchors[idx]
            anchors = anchors.expand_as(rpn_bbox_pred)
            if nms_pre_tensor > 0:
                # sort is faster than topk
                # _, topk_inds = scores.topk(cfg.nms_pre)
                # keep topk op for dynamic k in onnx model
                if torch.onnx.is_in_onnx_export():
                    # sort op will be converted to TopK in onnx
                    # and k<=3480 in TensorRT
                    scores_shape = torch._shape_as_tensor(scores)
                    nms_pre = torch.where(scores_shape[1] < nms_pre_tensor,
                                          scores_shape[1], nms_pre_tensor)
                    _, topk_inds = scores.topk(nms_pre)
                    batch_inds = torch.arange(batch_size).view(
                        -1, 1).expand_as(topk_inds)
                    scores = scores[batch_inds, topk_inds]
                    rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
                    anchors = anchors[batch_inds, topk_inds, :]
                elif scores.shape[-1] > cfg.nms_pre:
                    ranked_scores, rank_inds = scores.sort(descending=True)
                    topk_inds = rank_inds[:, :cfg.nms_pre]
                    scores = ranked_scores[:, :cfg.nms_pre]
                    batch_inds = torch.arange(batch_size).view(
                        -1, 1).expand_as(topk_inds)
                    rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
                    anchors = anchors[batch_inds, topk_inds, :]

            mlvl_scores.append(scores)
            mlvl_bbox_preds.append(rpn_bbox_pred)
            mlvl_valid_anchors.append(anchors)
            # Every candidate from this level is tagged with the level index.
            level_ids.append(
                scores.new_full((batch_size, scores.size(1)),
                                idx,
                                dtype=torch.long))

        batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
        batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
        batch_mlvl_rpn_bbox_pred = torch.cat(mlvl_bbox_preds, dim=1)
        batch_mlvl_proposals = self.bbox_coder.decode(
            batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes)
        batch_mlvl_ids = torch.cat(level_ids, dim=1)

        self._normalize_nms_cfg(cfg)

        result_list = []
        for proposals, scores, ids in zip(batch_mlvl_proposals,
                                          batch_mlvl_scores, batch_mlvl_ids):
            # Skip the data-dependent filtering while exporting to ONNX
            if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()):
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                # A boolean mask (rather than nonzero() indices) keeps the
                # result 2-D even when only one proposal survives, and
                # ``.all()`` is an exact "nothing was filtered" test.
                valid_mask = (w >= cfg.min_bbox_size) & (
                    h >= cfg.min_bbox_size)
                if not valid_mask.all():
                    proposals = proposals[valid_mask]
                    scores = scores[valid_mask]
                    ids = ids[valid_mask]

            if not torch.onnx.is_in_onnx_export() and proposals.numel() == 0:
                # NOTE(review): batched_nms appears to reduce over the boxes
                # (e.g. boxes.max()), which fails on empty input in some mmcv
                # versions; return an empty (0, 5) result instead.
                dets = proposals.new_zeros(0, 5)
            else:
                dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
            result_list.append(dets[:cfg.max_per_img])
        return result_list

    @staticmethod
    def _normalize_nms_cfg(cfg):
        """Map deprecated NMS keys in ``cfg`` onto the current ones, in place.

        ``nms_thr`` becomes ``cfg.nms.iou_threshold`` and ``max_num`` becomes
        ``cfg.max_per_img``. A warning is emitted whenever a deprecated form
        is used or ``nms`` is missing. When both the old and new forms are
        given, they must agree (checked with ``assert``).
        """
        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved to a dict named nms as '
                'iou_threshold, max_num has been renamed as max_per_img, '
                'name of original arguments and the way to specify '
                'iou_threshold of NMS will be deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, f'You ' \
                    f'set max_num and ' \
                    f'max_per_img at the same time, but get {cfg.max_num} ' \
                    f'and {cfg.max_per_img} respectively. ' \
                    'Please delete max_num which will be deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
                f' iou_threshold in nms and ' \
                f'nms_thr at the same time, but get' \
                f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
                f' respectively. Please delete the nms_thr ' \
                f'which will be deprecated.'