rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor
try:
from transformers import BertConfig
except ImportError:
BertConfig = None
from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
permute_and_flatten, select_single_mlvl,
unpack_gt_instances)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead
def convert_grounding_to_cls_scores(logits: Tensor,
positive_maps: List[dict]) -> Tensor:
"""Convert logits to class scores."""
assert len(positive_maps) == logits.shape[0] # batch size
scores = torch.zeros(logits.shape[0], logits.shape[1],
len(positive_maps[0])).to(logits.device)
if positive_maps is not None:
if all(x == positive_maps[0] for x in positive_maps):
# only need to compute once
positive_map = positive_maps[0]
for label_j in positive_map:
scores[:, :, label_j -
1] = logits[:, :,
torch.LongTensor(positive_map[label_j]
)].mean(-1)
else:
for i, positive_map in enumerate(positive_maps):
for label_j in positive_map:
scores[i, :, label_j - 1] = logits[
i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
return scores
class Conv3x3Norm(nn.Module):
"""Conv3x3 and norm."""
def __init__(self,
in_channels: int,
out_channels: int,
stride: int,
groups: int = 1,
use_dcn: bool = False,
norm_type: Optional[Union[Sequence, str]] = None):
super().__init__()
if use_dcn:
self.conv = ModulatedDeformConv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=groups)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=groups)
if isinstance(norm_type, Sequence):
assert len(norm_type) == 2
assert norm_type[0] == 'gn'
gn_group = norm_type[1]
norm_type = norm_type[0]
if norm_type == 'bn':
bn_op = nn.BatchNorm2d(out_channels)
elif norm_type == 'gn':
bn_op = nn.GroupNorm(
num_groups=gn_group, num_channels=out_channels)
if norm_type is not None:
self.bn = bn_op
else:
self.bn = None
def forward(self, x, **kwargs):
x = self.conv(x, **kwargs)
if self.bn:
x = self.bn(x)
return x
class DyReLU(nn.Module):
"""Dynamic ReLU."""
def __init__(self,
in_channels: int,
out_channels: int,
expand_ratio: int = 4):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.expand_ratio = expand_ratio
self.out_channels = out_channels
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // expand_ratio),
nn.ReLU(inplace=True),
nn.Linear(in_channels // expand_ratio,
out_channels * self.expand_ratio),
nn.Hardsigmoid(inplace=True))
def forward(self, x) -> Tensor:
x_out = x
b, c, h, w = x.size()
x = self.avg_pool(x).view(b, c)
x = self.fc(x).view(b, -1, 1, 1)
a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
a1 = (a1 - 0.5) * 2 + 1.0
a2 = (a2 - 0.5) * 2
b1 = b1 - 0.5
b2 = b2 - 0.5
out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
return out
class DyConv(nn.Module):
"""Dynamic Convolution."""
def __init__(self,
conv_func: Callable,
in_channels: int,
out_channels: int,
use_dyfuse: bool = True,
use_dyrelu: bool = False,
use_dcn: bool = False):
super().__init__()
self.dyconvs = nn.ModuleList()
self.dyconvs.append(conv_func(in_channels, out_channels, 1))
self.dyconvs.append(conv_func(in_channels, out_channels, 1))
self.dyconvs.append(conv_func(in_channels, out_channels, 2))
if use_dyfuse:
self.attnconv = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, 1, kernel_size=1),
nn.ReLU(inplace=True))
self.h_sigmoid = nn.Hardsigmoid(inplace=True)
else:
self.attnconv = None
if use_dyrelu:
self.relu = DyReLU(in_channels, out_channels)
else:
self.relu = nn.ReLU()
if use_dcn:
self.offset = nn.Conv2d(
in_channels, 27, kernel_size=3, stride=1, padding=1)
else:
self.offset = None
self.init_weights()
def init_weights(self):
for m in self.dyconvs.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
if self.attnconv is not None:
for m in self.attnconv.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight.data, 0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
def forward(self, inputs: dict) -> dict:
visual_feats = inputs['visual']
out_vis_feats = []
for level, feature in enumerate(visual_feats):
offset_conv_args = {}
if self.offset is not None:
offset_mask = self.offset(feature)
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, 18:, :, :].sigmoid()
offset_conv_args = dict(offset=offset, mask=mask)
temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]
if level > 0:
temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
**offset_conv_args))
if level < len(visual_feats) - 1:
temp_feats.append(
F.upsample_bilinear(
self.dyconvs[0](visual_feats[level + 1],
**offset_conv_args),
size=[feature.size(2),
feature.size(3)]))
mean_feats = torch.mean(
torch.stack(temp_feats), dim=0, keepdim=False)
if self.attnconv is not None:
attn_feat = []
res_feat = []
for feat in temp_feats:
res_feat.append(feat)
attn_feat.append(self.attnconv(feat))
res_feat = torch.stack(res_feat)
spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))
mean_feats = torch.mean(
res_feat * spa_pyr_attn, dim=0, keepdim=False)
out_vis_feats.append(mean_feats)
out_vis_feats = [self.relu(item) for item in out_vis_feats]
features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}
return features_dict
class VLFusionModule(BaseModel):
"""Visual-lang Fusion Module."""
def __init__(self,
in_channels: int,
feat_channels: int,
num_base_priors: int,
early_fuse: bool = False,
num_dyhead_blocks: int = 6,
lang_model_name: str = 'bert-base-uncased',
use_dyrelu: bool = True,
use_dyfuse: bool = True,
use_dcn: bool = True,
use_checkpoint: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
if BertConfig is None:
raise RuntimeError(
'transformers is not installed, please install it by: '
'pip install transformers.')
self.in_channels = in_channels
self.feat_channels = feat_channels
self.num_base_priors = num_base_priors
self.early_fuse = early_fuse
self.num_dyhead_blocks = num_dyhead_blocks
self.use_dyrelu = use_dyrelu
self.use_dyfuse = use_dyfuse
self.use_dcn = use_dcn
self.use_checkpoint = use_checkpoint
self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
self.lang_dim = self.lang_cfg.hidden_size
self._init_layers()
def _init_layers(self) -> None:
"""Initialize layers of the model."""
bias_value = -math.log((1 - 0.01) / 0.01)
dyhead_tower = []
for i in range(self.num_dyhead_blocks):
if self.early_fuse:
# cross-modality fusion
dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
# lang branch
dyhead_tower.append(
BertEncoderLayer(
self.lang_cfg,
clamp_min_for_underflow=True,
clamp_max_for_overflow=True))
# vision branch
dyhead_tower.append(
DyConv(
lambda i, o, s: Conv3x3Norm(
i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
self.in_channels if i == 0 else self.feat_channels,
self.feat_channels,
use_dyrelu=(self.use_dyrelu
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dyrelu,
use_dyfuse=(self.use_dyfuse
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dyfuse,
use_dcn=(self.use_dcn
and self.in_channels == self.feat_channels)
if i == 0 else self.use_dcn,
))
self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))
self.bbox_pred = nn.Conv2d(
self.feat_channels, self.num_base_priors * 4, kernel_size=1)
self.centerness = nn.Conv2d(
self.feat_channels, self.num_base_priors * 1, kernel_size=1)
self.dot_product_projection_text = nn.Linear(
self.lang_dim,
self.num_base_priors * self.feat_channels,
bias=True)
self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
self.bias_lang = nn.Parameter(
torch.zeros(self.lang_dim), requires_grad=True)
self.bias0 = nn.Parameter(
torch.Tensor([bias_value]), requires_grad=True)
self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])
def forward(self, visual_feats: Tuple[Tensor],
language_feats: dict) -> Tuple:
feat_inputs = {'visual': visual_feats, 'lang': language_feats}
dyhead_tower = self.dyhead_tower(feat_inputs)
if self.early_fuse:
embedding = dyhead_tower['lang']['hidden']
else:
embedding = language_feats['embedded']
embedding = F.normalize(embedding, p=2, dim=-1)
dot_product_proj_tokens = self.dot_product_projection_text(embedding /
2.0)
dot_product_proj_tokens_bias = torch.matmul(
embedding, self.bias_lang) + self.bias0
bbox_preds = []
centerness = []
cls_logits = []
for i, feature in enumerate(visual_feats):
visual = dyhead_tower['visual'][i]
B, C, H, W = visual.shape
bbox_pred = self.scales[i](self.bbox_pred(visual))
bbox_preds.append(bbox_pred)
centerness.append(self.centerness(visual))
dot_product_proj_queries = permute_and_flatten(
visual, B, self.num_base_priors, C, H, W)
bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
1, self.num_base_priors, 1)
dot_product_logit = (
torch.matmul(dot_product_proj_queries,
dot_product_proj_tokens.transpose(-1, -2)) /
self.log_scale.exp()) + bias
dot_product_logit = torch.clamp(
dot_product_logit, max=MAX_CLAMP_VALUE)
dot_product_logit = torch.clamp(
dot_product_logit, min=-MAX_CLAMP_VALUE)
cls_logits.append(dot_product_logit)
return bbox_preds, centerness, cls_logits
@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
"""ATSS head with visual-language fusion module.
Args:
early_fuse (bool): Whether to fuse visual and language features
Defaults to False.
use_checkpoint (bool): Whether to use checkpoint. Defaults to False.
num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
lang_model_name (str): Name of the language model.
Defaults to 'bert-base-uncased'.
"""
def __init__(self,
*args,
early_fuse: bool = False,
use_checkpoint: bool = False,
num_dyhead_blocks: int = 6,
lang_model_name: str = 'bert-base-uncased',
init_cfg=None,
**kwargs):
super().__init__(*args, **kwargs, init_cfg=init_cfg)
self.head = VLFusionModule(
in_channels=self.in_channels,
feat_channels=self.feat_channels,
num_base_priors=self.num_base_priors,
early_fuse=early_fuse,
use_checkpoint=use_checkpoint,
num_dyhead_blocks=num_dyhead_blocks,
lang_model_name=lang_model_name)
self.text_masks = None
def _init_layers(self) -> None:
"""No need to initialize the ATSS head layer."""
pass
def forward(self, visual_feats: Tuple[Tensor],
language_feats: dict) -> Tuple[Tensor]:
"""Forward function."""
bbox_preds, centerness, cls_logits = self.head(visual_feats,
language_feats)
return cls_logits, bbox_preds, centerness
def loss(self, visual_feats: Tuple[Tensor], language_feats: dict,
batch_data_samples):
outputs = unpack_gt_instances(batch_data_samples)
(batch_gt_instances, batch_gt_instances_ignore,
batch_img_metas) = outputs
outs = self(visual_feats, language_feats)
self.text_masks = language_feats['masks']
loss_inputs = outs + (batch_gt_instances, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss_by_feat(*loss_inputs)
return losses
def loss_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
centernesses: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
centernesses (list[Tensor]): Centerness for each scale
level with shape (N, num_anchors * 1, H, W)
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, batch_img_metas, device=device)
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
batch_gt_instances,
batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
(anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, avg_factor) = cls_reg_targets
avg_factor = reduce_mean(
torch.tensor(avg_factor, dtype=torch.float, device=device)).item()
anchors = torch.cat(anchor_list, dim=1)
labels = torch.cat(labels_list, dim=1)
label_weights = torch.cat(label_weights_list, dim=1)
bbox_targets = torch.cat(bbox_targets_list, dim=1)
cls_scores = torch.cat(cls_scores, dim=1)
centernesses_ = []
bbox_preds_ = []
for bbox_pred, centerness in zip(bbox_preds, centernesses):
centernesses_.append(
centerness.permute(0, 2, 3,
1).reshape(cls_scores.size(0), -1, 1))
bbox_preds_.append(
bbox_pred.permute(0, 2, 3,
1).reshape(cls_scores.size(0), -1, 4))
bbox_preds = torch.cat(bbox_preds_, dim=1)
centernesses = torch.cat(centernesses_, dim=1)
losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \
self._loss_by_feat(
anchors,
cls_scores,
bbox_preds,
centernesses,
labels,
label_weights,
bbox_targets,
avg_factor=avg_factor)
bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
losses_bbox = losses_bbox / bbox_avg_factor
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_centerness=loss_centerness)
def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor,
bbox_pred: Tensor, centerness: Tensor, labels: Tensor,
label_weights: Tensor, bbox_targets: Tensor,
avg_factor: float) -> dict:
"""Calculate the loss of all scale level based on the features
extracted by the detection head.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
anchors = anchors.reshape(-1, 4)
# ===== this change =====
pos_inds = (labels.sum(-1) > 0).reshape(-1)
# Loss is not computed for the padded regions of the text.
assert (self.text_masks.dim() == 2)
text_mask = (self.text_masks > 0).unsqueeze(1)
text_mask = text_mask.repeat(1, cls_score.size(1), 1)
cls_score = torch.masked_select(cls_score, text_mask).contiguous()
labels = torch.masked_select(labels, text_mask)
label_weights = label_weights[...,
None].repeat(1, 1, text_mask.size(-1))
label_weights = torch.masked_select(label_weights, text_mask)
bbox_pred = bbox_pred.reshape(-1, 4)
centerness = centerness.reshape(-1)
bbox_targets = bbox_targets.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
# classification loss
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=avg_factor)
if pos_inds.sum() > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_centerness = centerness[pos_inds]
centerness_targets = self.centerness_target(
pos_anchors, pos_bbox_targets)
if torch.isnan(centerness_targets).any():
print('=====Centerness includes NaN=====')
mask = ~torch.isnan(centerness_targets)
centerness_targets = centerness_targets[mask]
pos_centerness = pos_centerness[mask]
pos_anchors = pos_anchors[mask]
pos_bbox_targets = pos_bbox_targets[mask]
pos_bbox_pred = pos_bbox_pred[mask]
if pos_bbox_targets.shape[0] == 0:
loss_bbox = bbox_pred.sum() * 0
loss_centerness = centerness.sum() * 0
centerness_targets = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, loss_centerness, \
centerness_targets.sum()
# The decoding process takes the offset into consideration.
pos_anchors[:, 2:] += 1
pos_decode_bbox_pred = self.bbox_coder.decode(
pos_anchors, pos_bbox_pred)
# regression loss
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_bbox_targets,
weight=centerness_targets,
avg_factor=1.0)
# centerness loss
loss_centerness = self.loss_centerness(
pos_centerness, centerness_targets, avg_factor=avg_factor)
else:
loss_bbox = bbox_pred.sum() * 0
loss_centerness = centerness.sum() * 0
centerness_targets = bbox_targets.new_tensor(0.)
return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
def _get_targets_single(self,
flat_anchors: Tensor,
valid_flags: Tensor,
num_level_anchors: List[int],
gt_instances: InstanceData,
img_meta: dict,
gt_instances_ignore: Optional[InstanceData] = None,
unmap_outputs: bool = True) -> tuple:
"""Compute regression, classification targets for anchors in a single
image.
Args:
flat_anchors (Tensor): Multi-level anchors of the image, which are
concatenated into a single tensor of shape (num_anchors ,4)
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
num_level_anchors (List[int]): Number of anchors of each scale
level.
gt_instances (:obj:`InstanceData`): Ground truth of instance
annotations. It usually includes ``bboxes`` and ``labels``
attributes.
img_meta (dict): Meta information for current image.
gt_instances_ignore (:obj:`InstanceData`, optional): Instances
to be ignored during training. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple: N is the number of total anchors in the image.
labels (Tensor): Labels of all anchors in the image with shape
(N,).
label_weights (Tensor): Label weights of all anchor in the
image with shape (N,).
bbox_targets (Tensor): BBox targets of all anchors in the
image with shape (N, 4).
bbox_weights (Tensor): BBox weights of all anchors in the
image with shape (N, 4)
pos_inds (Tensor): Indices of positive anchor with shape
(num_pos,).
neg_inds (Tensor): Indices of negative anchor with shape
(num_neg,).
sampling_result (:obj:`SamplingResult`): Sampling results.
"""
anchors = flat_anchors
# Align the official implementation
anchors[:, 2:] -= 1
num_level_anchors_inside = num_level_anchors
pred_instances = InstanceData(priors=anchors)
assign_result = self.assigner.assign(pred_instances,
num_level_anchors_inside,
gt_instances, gt_instances_ignore)
sampling_result = self.sampler.sample(assign_result, pred_instances,
gt_instances)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
# ===== this change =====
labels = anchors.new_full((num_valid_anchors, self.feat_channels),
0,
dtype=torch.float32)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if self.reg_decoded_bbox:
pos_bbox_targets = sampling_result.pos_gt_bboxes
else:
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
# ===== this change =====
labels[pos_inds] = gt_instances.positive_maps[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg['pos_weight'] <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg['pos_weight']
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
return (anchors, labels, label_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds, sampling_result)
def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
"""Calculate the centerness between anchors and gts.
Only calculate pos centerness targets, otherwise there may be nan.
Args:
anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy" format.
Returns:
Tensor: Centerness between anchors and gts.
"""
anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
l_ = anchors_cx - gts[:, 0]
t_ = anchors_cy - gts[:, 1]
r_ = gts[:, 2] - anchors_cx
b_ = gts[:, 3] - anchors_cy
left_right = torch.stack([l_, r_], dim=1)
top_bottom = torch.stack([t_, b_], dim=1)
centerness = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
(top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
# assert not torch.isnan(centerness).any()
return centerness
def predict(self,
visual_feats: Tuple[Tensor],
language_feats: dict,
batch_data_samples,
rescale: bool = True):
"""Perform forward propagation of the detection head and predict
detection results on the features of the upstream network.
Args:
visual_feats (tuple[Tensor]): Multi-level visual features from the
upstream network, each is a 4D-tensor.
language_feats (dict): Language features from the upstream network.
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[obj:`InstanceData`]: Detection results of each image
after the post process.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
batch_token_positive_maps = [
data_samples.token_positive_map
for data_samples in batch_data_samples
]
outs = self(visual_feats, language_feats)
predictions = self.predict_by_feat(
*outs,
batch_img_metas=batch_img_metas,
batch_token_positive_maps=batch_token_positive_maps,
rescale=rescale)
return predictions
def predict_by_feat(self,
cls_logits: List[Tensor],
bbox_preds: List[Tensor],
score_factors: List[Tensor],
batch_img_metas: Optional[List[dict]] = None,
batch_token_positive_maps: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
Note: When score_factors is not None, the cls_scores are
usually multiplied by it then obtain the real score used in NMS,
such as CenterNess in FCOS, IoU branch in ATSS.
Args:
cls_logits (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
batch_token_positive_maps (list[dict], Optional): Batch token
positive map. Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
assert len(bbox_preds) == len(score_factors)
num_levels = len(bbox_preds)
featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes,
dtype=bbox_preds[0].dtype,
device=bbox_preds[0].device)
result_list = []
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
token_positive_maps = batch_token_positive_maps[img_id]
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
cls_logit_list = select_single_mlvl(
cls_logits, img_id, detach=True)
results = self._predict_by_feat_single(
bbox_pred_list=bbox_pred_list,
score_factor_list=score_factor_list,
cls_logit_list=cls_logit_list,
mlvl_priors=mlvl_priors,
token_positive_maps=token_positive_maps,
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list
def _predict_by_feat_single(self,
bbox_pred_list: List[Tensor],
score_factor_list: List[Tensor],
cls_logit_list: List[Tensor],
mlvl_priors: List[Tensor],
token_positive_maps: dict,
img_meta: dict,
cfg: ConfigDict,
rescale: bool = True,
with_nms: bool = True) -> InstanceData:
"""Transform a single image's features extracted from the head into
bbox results.
Args:
bbox_pred_list (list[Tensor]): Box energies / deltas from
all scale levels of a single image, each item has shape
(num_priors * 4, H, W).
score_factor_list (list[Tensor]): Score factor from all scale
levels of a single image, each item has shape
(num_priors * 1, H, W).
cls_logit_list (list[Tensor]): Box scores from all scale
levels of a single image, each item has shape
(num_priors * num_classes, H, W).
mlvl_priors (list[Tensor]): Each element in the list is
the priors of a single level in feature pyramid. In all
anchor-based methods, it has shape (num_priors, 4). In
all anchor-free methods, it has shape (num_priors, 2)
when `with_stride=True`, otherwise it still has shape
(num_priors, 4).
token_positive_maps (dict): Token positive map.
img_meta (dict): Image meta info.
cfg (mmengine.Config): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
:obj:`InstanceData`: Detection results of each image
after the post process.
Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)
img_shape = img_meta['img_shape']
nms_pre = cfg.get('nms_pre', -1)
score_thr = cfg.get('score_thr', 0)
mlvl_bbox_preds = []
mlvl_valid_priors = []
mlvl_scores = []
mlvl_labels = []
for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
enumerate(zip(bbox_pred_list,
score_factor_list, cls_logit_list, mlvl_priors)):
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
-1, self.bbox_coder.encode_size)
score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
scores = convert_grounding_to_cls_scores(
logits=cls_logit.sigmoid()[None],
positive_maps=[token_positive_maps])[0]
results = filter_scores_and_topk(
scores, score_thr, nms_pre,
dict(bbox_pred=bbox_pred, priors=priors))
scores, labels, keep_idxs, filtered_results = results
bbox_pred = filtered_results['bbox_pred']
priors = filtered_results['priors']
score_factor = score_factor[keep_idxs]
scores = torch.sqrt(scores * score_factor)
mlvl_bbox_preds.append(bbox_pred)
mlvl_valid_priors.append(priors)
mlvl_scores.append(scores)
mlvl_labels.append(labels)
bbox_pred = torch.cat(mlvl_bbox_preds)
priors = cat_boxes(mlvl_valid_priors)
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)
results = InstanceData()
results.bboxes = bboxes
results.scores = torch.cat(mlvl_scores)
results.labels = torch.cat(mlvl_labels)
predictions = self._bbox_post_process(
results=results,
cfg=cfg,
rescale=rescale,
with_nms=with_nms,
img_meta=img_meta)
if len(predictions) > 0:
# Note: GLIP adopts a very strange bbox decoder logic,
# and if 1 is not added here, it will not align with
# the official mAP.
predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
return predictions