rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.utils import InstanceList, OptInstanceList
from ..losses import carl_loss, isr_p
from ..utils import images_to_levels
from .retina_head import RetinaHead
@MODELS.register_module()
class PISARetinaHead(RetinaHead):
    """PISA RetinaNet Head.

    The head has the same structure as the RetinaNet head, but differs in
    two aspects:

    1. Importance-based Sample Reweighting Positive (ISR-P) is applied to
       change the positive loss weights.
    2. Classification-aware regression loss (CARL) is adopted as a third
       loss.
    """

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level, each
                with shape (N, num_anchors * cls_out_channels, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, num_anchors * 4, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict: Loss dict comprising ``loss_cls`` and ``loss_bbox``, plus
            the CARL loss entries when ``train_cfg`` contains ``carl``.
            Returns None if ``get_targets`` returns None.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         avg_factor, sampling_results_list) = cls_reg_targets

        # Anchor count of each level, taken from the first image.
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # Concatenate every image's multi-level anchors into one tensor, then
        # regroup them per level so they line up with the per-level targets.
        concat_anchor_list = [torch.cat(anchors) for anchors in anchor_list]
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        num_imgs = len(batch_img_metas)
        # Reshape by cls_out_channels (the real per-anchor channel count of
        # the classifier output) so that softmax heads, whose output has
        # num_classes + 1 channels, are flattened correctly too.
        cls_channels = self.cls_out_channels
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, cls_channels)
            for cls_score in cls_scores
        ], dim=1).reshape(-1, cls_channels)
        flatten_bbox_preds = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], dim=1).reshape(-1, 4)
        # Targets are concatenated across levels along dim=1, then flattened
        # image-major, matching the layout of the flattened predictions.
        flatten_labels = torch.cat(labels_list, dim=1).reshape(-1)
        flatten_label_weights = torch.cat(
            label_weights_list, dim=1).reshape(-1)
        flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4)
        flatten_bbox_targets = torch.cat(
            bbox_targets_list, dim=1).reshape(-1, 4)
        flatten_bbox_weights = torch.cat(
            bbox_weights_list, dim=1).reshape(-1, 4)

        # Apply ISR-P: reweights the positive targets. It runs under
        # no_grad because it only produces new targets/weights.
        isr_cfg = self.train_cfg.get('isr', None)
        if isr_cfg is not None:
            all_targets = (flatten_labels, flatten_label_weights,
                           flatten_bbox_targets, flatten_bbox_weights)
            with torch.no_grad():
                all_targets = isr_p(
                    flatten_cls_scores,
                    flatten_bbox_preds,
                    all_targets,
                    flatten_anchors,
                    sampling_results_list,
                    bbox_coder=self.bbox_coder,
                    loss_cls=self.loss_cls,
                    num_class=self.num_classes,
                    **isr_cfg)
            (flatten_labels, flatten_label_weights, flatten_bbox_targets,
             flatten_bbox_weights) = all_targets

        # For convenience the loss is computed once over all levels instead
        # of per FPN level, so the weights need not be split by level again.
        # The result is the same.
        losses_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels,
            flatten_label_weights,
            avg_factor=avg_factor)
        losses_bbox = self.loss_bbox(
            flatten_bbox_preds,
            flatten_bbox_targets,
            flatten_bbox_weights,
            avg_factor=avg_factor)
        loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

        # CARL loss: pass the head's own activation choice instead of
        # assuming sigmoid, so softmax-based heads score positives correctly.
        carl_cfg = self.train_cfg.get('carl', None)
        if carl_cfg is not None:
            loss_carl = carl_loss(
                flatten_cls_scores,
                flatten_labels,
                flatten_bbox_preds,
                flatten_bbox_targets,
                self.loss_bbox,
                **carl_cfg,
                avg_factor=avg_factor,
                sigmoid=self.use_sigmoid_cls,
                num_class=self.num_classes)
            loss_dict.update(loss_carl)
        return loss_dict