# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import bbox2result, bbox_mapping_back
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class CornerNet(SingleStageDetector):
    """CornerNet.

    This detector is the implementation of the paper `CornerNet: Detecting
    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)

    def merge_aug_results(self, aug_results, img_metas):
        """Merge augmented detection bboxes and scores.

        Args:
            aug_results (list[list[Tensor]]): Det_bboxes and det_labels of
                each image.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            tuple: ``(out_bboxes, out_labels)`` merged over all
                augmentations.
        """
        recovered_bboxes, aug_labels = [], []
        for bboxes_labels, img_info in zip(aug_results, img_metas):
            img_shape = img_info[0]['img_shape']  # using shape before padding
            scale_factor = img_info[0]['scale_factor']
            flip = img_info[0]['flip']
            bboxes, labels = bboxes_labels
            # Each row of ``bboxes`` is (x1, y1, x2, y2, score); split the
            # coordinates from the score before mapping back.
            bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
            # Undo the test-time flip and rescaling so predictions from all
            # augmentations live in the same coordinate frame.
            bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
            recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
            aug_labels.append(labels)

        bboxes = torch.cat(recovered_bboxes, dim=0)
        labels = torch.cat(aug_labels)

        if bboxes.shape[0] > 0:
            # De-duplicate overlapping predictions from different
            # augmentations with the head's NMS.
            out_bboxes, out_labels = self.bbox_head._bboxes_nms(
                bboxes, labels, self.bbox_head.test_cfg)
        else:
            out_bboxes, out_labels = bboxes, labels

        return out_bboxes, out_labels
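
    # Worked illustration of the map-back step (an assumption for clarity,
    # not upstream code): with ``img_shape`` width 100 and ``scale_factor``
    # 2.0, a box (x1, y1, x2, y2) = (10, 20, 30, 40) detected on the
    # horizontally flipped image maps back to
    # (100 - 30, 20, 100 - 10, 40) = (70, 20, 90, 40), then is divided by
    # the scale factor to (35, 10, 45, 20). This is the transform that
    # ``bbox_mapping_back`` applies in the loop above.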

    def aug_test(self, imgs, img_metas, rescale=False):
        """Augmented testing of CornerNet.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Note:
            ``imgs`` must include flipped image pairs.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        img_inds = list(range(len(imgs)))
        # At least one image of the leading pair must carry the ``flip``
        # flag; the loop below relies on (original, flipped) pairs being
        # adjacent in ``imgs``.
        assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
            'aug test must have flipped image pair')
        aug_results = []
        for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
            # Batch the original image with its flipped counterpart so both
            # views share a single forward pass.
            img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
            x = self.extract_feat(img_pair)
            outs = self.bbox_head(x)
            bbox_list = self.bbox_head.get_bboxes(
                *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
            aug_results.append(bbox_list[0])
            aug_results.append(bbox_list[1])

        bboxes, labels = self.merge_aug_results(aug_results, img_metas)
        bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)
        return [bbox_results]
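

# Minimal usage sketch (illustrative, not part of the upstream module).
# ``aug_test`` expects ``imgs`` ordered as adjacent (original, flipped)
# pairs with matching ``img_metas``. The meta fields and shapes below are
# assumptions for demonstration; ``model`` would be a built and trained
# CornerNet instance, e.g. obtained via ``mmdet.apis.init_detector``.
if __name__ == '__main__':
    img = torch.rand(1, 3, 511, 511)
    flipped = torch.flip(img, dims=[3])  # flip along the width axis
    img_metas = [
        [dict(img_shape=(511, 511, 3), scale_factor=1.0, flip=False)],
        [dict(img_shape=(511, 511, 3), scale_factor=1.0, flip=True,
              flip_direction='horizontal')],
    ]
    # Hypothetical call, assuming ``model`` has been built as noted above:
    # results = model.aug_test([img, flipped], img_metas, rescale=True)
    print('paired inputs:', img.shape, flipped.shape)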