#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn

from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.modeling.box_coder import BoxCoder


class PostProcessor(nn.Module):
    """
    From a set of classification scores, box regressions, and proposals,
    computes the post-processed boxes and applies NMS to obtain the
    final results.
    """

    def __init__(
        self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None, cfg=None
    ):
        """
        Arguments:
            score_thresh (float)
            nms (float)
            detections_per_img (int)
            box_coder (BoxCoder)
        """
        super(PostProcessor, self).__init__()
        self.cfg = cfg
        self.score_thresh = score_thresh
        self.nms = nms
        self.detections_per_img = detections_per_img
        if cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION:
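            # If no coder is supplied, fall back to the standard (wx, wy, ww, wh)
            # decoding weights of (10., 10., 5., 5.) used for the ROI heads.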
            if box_coder is None:
                box_coder = BoxCoder(weights=(10., 10., 5., 5.))
            self.box_coder = box_coder
 
    def forward(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = F.softmax(class_logits, -1)

        # TODO: think about a representation for a batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
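        # Masks attached to the proposals are carried through post-processing
        # when segmentation polygons or masked RoI features are enabled.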
        if (self.cfg.MODEL.SEG.USE_SEG_POLY
                or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE
                or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE):
            masks = [box.get_field('masks') for box in boxes]
        if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION:
            concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)
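            # Decode the per-class regression deltas against the concatenated
            # proposals into absolute xyxy boxes, then split them back per image.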
            proposals = self.box_coder.decode(
                box_regression.view(sum(boxes_per_image), -1), concat_boxes
            )
            proposals = proposals.split(boxes_per_image, dim=0)
        else:
            proposals = boxes
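        # The class probabilities include the background class at column 0,
        # which is skipped later in filter_results.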
        num_classes = class_prob.shape[1]
        class_prob = class_prob.split(boxes_per_image, dim=0)

        results = []
        if (self.cfg.MODEL.SEG.USE_SEG_POLY
                or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE
                or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE):
            for prob, boxes_per_img, image_shape, mask in zip(
                class_prob, proposals, image_shapes, masks
            ):
                boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, mask)
                if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION:
                    boxlist = boxlist.clip_to_image(remove_empty=False)
                    boxlist = self.filter_results(boxlist, num_classes)
                results.append(boxlist)
        else:
            for prob, boxes_per_img, image_shape in zip(
                class_prob, proposals, image_shapes
            ):
                boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
                if self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION:
                    boxlist = boxlist.clip_to_image(remove_empty=False)
                    boxlist = self.filter_results(boxlist, num_classes)
                results.append(boxlist)
        return results

    def prepare_boxlist(self, boxes, scores, image_shape, mask=None):
        """
        Return a BoxList built from `boxes` and attach the probability scores
        as an extra field.

        `boxes` has shape (#detections, 4 * #classes), where each row represents
        a list of predicted bounding boxes for each of the object classes in the
        dataset (including the background class). The detections in each row
        originate from the same object proposal.

        `scores` has shape (#detections, #classes), where each row represents a
        list of object detection confidence scores for each of the object classes
        in the dataset (including the background class). `scores[i, j]` corresponds
        to the box at `boxes[i, j * 4 : (j + 1) * 4]`.
        """
        if not self.cfg.MODEL.ROI_BOX_HEAD.USE_REGRESSION:
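            # Without box regression the proposals are already BoxLists, so just
            # attach the flattened scores and return them unchanged.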
            scores = scores.reshape(-1)
            boxes.add_field("scores", scores)
            return boxes
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        boxlist = BoxList(boxes, image_shape, mode="xyxy")
        boxlist.add_field("scores", scores)
        if mask is not None:
            boxlist.add_field('masks', mask)
        return boxlist

    def filter_results(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying non-maximum suppression (NMS).
        """
        # unwrap the boxlist to avoid additional overhead.
        # if we had multi-class NMS, we could perform this directly on the boxlist
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)

        device = scores.device
        result = []
        # Apply threshold on detection probabilities and apply NMS
        # Skip j = 0, because it's the background class
        inds_all = scores > self.score_thresh
        for j in range(1, num_classes):
            inds = inds_all[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4 : (j + 1) * 4]
            boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
            boxlist_for_class.add_field("scores", scores_j)
            boxlist_for_class = boxlist_nms(
                boxlist_for_class, self.nms, score_field="scores"
            )
            num_labels = len(boxlist_for_class)
            boxlist_for_class.add_field(
                "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device)
            )
            if (self.cfg.MODEL.SEG.USE_SEG_POLY
                    or self.cfg.MODEL.ROI_BOX_HEAD.USE_MASKED_FEATURE
                    or self.cfg.MODEL.ROI_MASK_HEAD.USE_MASKED_FEATURE):
                boxlist_for_class.add_field('masks', boxlist.get_field('masks'))
            result.append(boxlist_for_class)

        result = cat_boxlist(result)
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
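            # kthvalue returns the k-th smallest score with
            # k = number_of_detections - detections_per_img + 1; keeping scores
            # >= this value retains roughly the top detections_per_img boxes.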
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(), number_of_detections - self.detections_per_img + 1
            )
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result


def make_roi_box_post_processor(cfg):
    # use_fpn = cfg.MODEL.ROI_HEADS.USE_FPN

    bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
    box_coder = BoxCoder(weights=bbox_reg_weights)

    score_thresh = cfg.MODEL.ROI_HEADS.SCORE_THRESH
    nms_thresh = cfg.MODEL.ROI_HEADS.NMS
    detections_per_img = cfg.MODEL.ROI_HEADS.DETECTIONS_PER_IMG

    postprocessor = PostProcessor(
        score_thresh, nms_thresh, detections_per_img, box_coder, cfg
    )
    return postprocessor
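
# Usage sketch: `cfg`, `box_head`, `features`, and `proposals` below are
# placeholder names for whatever the surrounding detection pipeline provides.
#
#     post_processor = make_roi_box_post_processor(cfg)
#     class_logits, box_regression = box_head(features, proposals)
#     detections = post_processor((class_logits, box_regression), proposals)
#
# Each returned BoxList carries a "scores" field; when box regression is enabled
# the results are also NMS-filtered and get a per-detection "labels" field.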