File size: 4,915 Bytes
128757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""

Implements the Generalized R-CNN framework

"""

import torch
from torch import nn

from maskrcnn_benchmark.structures.image_list import to_image_list

from ..backbone import build_backbone
from ..rpn import build_rpn
from ..roi_heads import build_roi_heads

import timeit

class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.

    It consists of three main parts:

    - backbone
    - rpn
    - heads: takes the features + the proposals from the RPN and computes
        detections / masks from it.

    Behaviour switches read from the config:

    - ``cfg.MODEL.DEBUG``: ``forward`` additionally collects per-stage
      timings/sizes and returns ``(result, debug_info)``.
    - ``cfg.MODEL.ONNX``: ``forward`` passes the raw ``images`` argument to
      the backbone instead of wrapping it with ``to_image_list``.
    - ``cfg.MODEL.{BACKBONE,FPN,RPN}.FREEZE`` and ``cfg.MODEL.LINEAR_PROB``:
      parameter freezing, re-applied on every ``train()`` call.
    """

    # Parameter-name substrings that remain trainable under linear probing;
    # every other rpn / roi_heads parameter is frozen.
    _LINEAR_PROB_KEYS = ('bbox_pred', 'cls_logits', 'centerness', 'cosine_scale')

    def __init__(self, cfg):
        super(GeneralizedRCNN, self).__init__()

        self.backbone = build_backbone(cfg)
        self.rpn = build_rpn(cfg)
        self.roi_heads = build_roi_heads(cfg)
        self.DEBUG = cfg.MODEL.DEBUG
        self.ONNX = cfg.MODEL.ONNX
        self.freeze_backbone = cfg.MODEL.BACKBONE.FREEZE
        self.freeze_fpn = cfg.MODEL.FPN.FREEZE
        self.freeze_rpn = cfg.MODEL.RPN.FREEZE

        if cfg.MODEL.LINEAR_PROB:
            assert cfg.MODEL.BACKBONE.FREEZE, "For linear probing, backbone should be frozen!"
            if hasattr(self.backbone, 'fpn'):
                assert cfg.MODEL.FPN.FREEZE, "For linear probing, FPN should be frozen!"
        self.linear_prob = cfg.MODEL.LINEAR_PROB

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode and stop gradients to all its parameters."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    @classmethod
    def _is_linear_prob_param(cls, name):
        """Return True if parameter ``name`` stays trainable under linear probing."""
        return any(key in name for key in cls._LINEAR_PROB_KEYS)

    def train(self, mode=True):
        """Convert the model into training mode while keeping frozen layers frozen.

        ``nn.Module.train(mode)`` recursively sets every child's mode, so the
        configured freezing is re-applied after it on each call.

        Arguments:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            self, matching ``nn.Module.train`` so calls can be chained.
        """
        super(GeneralizedRCNN, self).train(mode)
        if self.freeze_backbone:
            self._freeze(self.backbone.body)
        if self.freeze_fpn:
            self._freeze(self.backbone.fpn)
        if self.freeze_rpn:
            self._freeze(self.rpn)
        if self.linear_prob:
            for head in (self.rpn, self.roi_heads):
                # Same truthiness test as forward(): RPN-only models have no
                # (or an empty) roi_heads, which must be skipped, not iterated.
                if not head:
                    continue
                for name, param in head.named_parameters():
                    if not self._is_linear_prob_param(name):
                        param.requires_grad = False
        return self

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
            When DEBUG is enabled, ``(result, debug_info)`` is returned instead
            (in both training and testing), where ``result`` is the detector /
            proposal output and ``debug_info`` holds per-stage timings and sizes.

        Raises:
            ValueError: in training mode when ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        debug = self.DEBUG
        debug_info = {}
        if debug:
            debug_info['input_size'] = images[0].size()
            tic = timeit.default_timer()

        if self.ONNX:
            # Exported graphs take the already-batched input as-is.
            features = self.backbone(images)
        else:
            images = to_image_list(images)
            features = self.backbone(images.tensors)

        if debug:
            debug_info['feat_time'] = timeit.default_timer() - tic
            debug_info['feat_size'] = [feat.size() for feat in features]
            tic = timeit.default_timer()

        proposals, proposal_losses = self.rpn(images, features, targets)

        if debug:
            debug_info['rpn_time'] = timeit.default_timer() - tic
            debug_info['#rpn'] = list(proposals)
            tic = timeit.default_timer()

        if self.roi_heads:
            _, result, detector_losses = self.roi_heads(features, proposals, targets)
        else:
            # RPN-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if debug:
            debug_info['rcnn_time'] = timeit.default_timer() - tic
            debug_info['#rcnn'] = result
            return result, debug_info

        if self.training:
            # Proposal losses are applied last, so they win on key collisions.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        return result