# Copyright (c) Facebook, Inc. and its affiliates.
# -*- coding: utf-8 -*-

import logging
import typing
import torch
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from detectron2.structures import BitMasks, Boxes, ImageList, Instances

from .logger import log_first_n

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"


# Some extra ops to ignore when counting (no-op handlers are installed for them below).
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::nonzero_numpy",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "aten::upsample_nearest2d",
    "prim::PythonOp",
    "torchvision::nms",  # TODO estimate flop for nms
}
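
# Any op in ``_IGNORED_OPS`` can still be counted by supplying a custom handler
# through the ``supported_ops`` kwarg of the functions below, which overrides the
# no-op handlers installed in ``_wrapper_count_operators``. A minimal sketch (the
# handler name and estimate are illustrative, not part of this module; fvcore
# handlers receive the traced inputs/outputs and return per-op counts):
#
#   def my_nms_handler(inputs, outputs):
#       return {"torchvision::nms": 0.0}  # replace with a real flop estimate
#
#   flop_count_operators(model, data, supported_ops={"torchvision::nms": my_nms_handler})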


def flop_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of fvcore.nn.flop_count, that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model is often input-dependent, for example,
        the flops of box & mask head depends on the number of proposals &
        the number of detected objects.
        Therefore, the flops counting using a single input may not accurately
        reflect the computation cost of a model.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)
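
# Example usage (a minimal sketch; ``cfg`` and ``dataloader`` are assumed to be a
# detectron2 config and a dataloader yielding detectron2-format ``list[dict]``
# inputs -- adapt these names to your setup):
#
#   from detectron2.modeling import build_model
#   model = build_model(cfg)
#   model.eval()
#   data = next(iter(dataloader))  # must have batch size 1
#   counts = flop_count_operators(model, data)
#   print(counts)  # per-operator flops in units of GFlops, e.g. counts["conv"]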


def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activations counting using jit.
    This is a wrapper of fvcore.nn.activation_count, that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model is often input-dependent, for example,
        the activations of box & mask head depends on the number of proposals &
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
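
# Usage mirrors ``flop_count_operators`` (same assumptions as the sketch above;
# per fvcore's convention, counts are reported in mega-activations):
#
#   acts = activation_count_operators(model, data)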


def _flatten_to_tuple(outputs):
    result = []
    if isinstance(outputs, torch.Tensor):
        result.append(outputs)
    elif isinstance(outputs, (list, tuple)):
        for v in outputs:
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, dict):
        for _, v in outputs.items():
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, Instances):
        result.extend(_flatten_to_tuple(outputs.get_fields()))
    elif isinstance(outputs, (Boxes, BitMasks, ImageList)):
        result.append(outputs.tensor)
    else:
        log_first_n(
            logging.WARN,
            f"Output of type {type(outputs)} not included in flops/activations count.",
            n=10,
        )
    return tuple(result)
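
# For instance, a typical inference output ``[{"instances": Instances(...)}]``
# flattens to the tensors stored in each ``Instances`` field (predicted boxes,
# scores, classes, ...), so the traced graph retains the computation that
# produced them.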


def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:

    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]

    class WrapModel(nn.Module):
        def __init__(self, model):
            super().__init__()
            if isinstance(
                model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
            ):
                self.model = model.module
            else:
                self.model = model

        def forward(self, image):
            # jit requires the input/output to be Tensors
            inputs = [{"image": image}]
            outputs = self.model.forward(inputs)
            # Only the subgraph that computes the returned tuple of tensors will
            # be counted, so we flatten everything we find into a tuple of tensors.
            return _flatten_to_tuple(outputs)

    old_train = model.training
    with torch.no_grad():
        if mode == FLOPS_MODE:
            ret = flop_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
        elif mode == ACTIVATIONS_MODE:
            ret = activation_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
        else:
            raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # Newer versions of fvcore return a (counts, skipped_ops) tuple; keep only the counts.
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
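

# The re-exported fvcore parameter helpers need no inputs (a short sketch):
#
#   print(parameter_count_table(model, max_depth=3))  # human-readable table
#   total = parameter_count(model)[""]  # entry "" holds the model-wide total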