# Copyright (c) Facebook, Inc. and its affiliates.
# -*- coding: utf-8 -*-
import logging
import typing
import torch
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn
from detectron2.structures import BitMasks, Boxes, ImageList, Instances
from .logger import log_first_n

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"

# some extra ops to ignore from counting.
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::nonzero_numpy",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "aten::upsample_nearest2d",
    "prim::PythonOp",
    "torchvision::nms",  # TODO estimate flop for nms
}


def flop_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of `fvcore.nn.flop_count` that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model are often input-dependent; for example,
        the flops of the box & mask heads depend on the number of proposals and
        the number of detected objects.
        Therefore, counting flops with a single input may not accurately
        reflect the computation cost of the model.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to the model, in detectron2's standard format.
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)
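

# A minimal usage sketch (illustrative, not part of this module's API). The config
# and input tensor below are placeholders; real usage would load a config file,
# weights, and a preprocessed image. Only batch size 1 is supported.
#
#   from detectron2.config import get_cfg
#   from detectron2.modeling import build_model
#
#   cfg = get_cfg()  # plus whatever merge_from_file / weights the caller uses
#   model = build_model(cfg)
#   image = torch.rand(3, 800, 800)  # dummy CHW tensor; counts are input-dependent
#   gflops = flop_count_operators(model, [{"image": image}])
#   print(dict(gflops))  # per-operator counts, in GFlops (fvcore's convention)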


def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activations counting using jit.
    This is a wrapper of `fvcore.nn.activation_count` that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model are often input-dependent; for example,
        the activations of the box & mask heads depend on the number of proposals and
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to the model, in detectron2's standard format.
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
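

# Usage mirrors `flop_count_operators` above, e.g. (sketch):
#   mega_acts = activation_count_operators(model, [{"image": image}])
# The returned per-operator counts follow fvcore's `activation_count` convention
# of mega-activations (counts divided by 1e6).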


def _flatten_to_tuple(outputs):
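    """
    Recursively flatten model outputs (tensors nested inside lists, tuples, dicts,
    Instances, Boxes, BitMasks, or ImageList) into a flat tuple of tensors, which
    is the only output form jit tracing can handle. For example (illustrative), an
    output like [{"instances": Instances(..., pred_boxes=Boxes(t1), scores=t2)}]
    flattens to (t1, t2).
    """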
    result = []
    if isinstance(outputs, torch.Tensor):
        result.append(outputs)
    elif isinstance(outputs, (list, tuple)):
        for v in outputs:
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, dict):
        for _, v in outputs.items():
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, Instances):
        result.extend(_flatten_to_tuple(outputs.get_fields()))
    elif isinstance(outputs, (Boxes, BitMasks, ImageList)):
        result.append(outputs.tensor)
    else:
        log_first_n(
            logging.WARN,
            f"Output of type {type(outputs)} not included in flops/activations count.",
            n=10,
        )
    return tuple(result)


def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]

    class WrapModel(nn.Module):
        def __init__(self, model):
            super().__init__()
            if isinstance(
                model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
            ):
                self.model = model.module
            else:
                self.model = model

        def forward(self, image):
            # jit requires the input/output to be Tensors
            inputs = [{"image": image}]
            outputs = self.model.forward(inputs)
            # Only the subgraph that computes the returned tuple of tensors is
            # counted, so we flatten everything we find into a tuple of tensors.
            return _flatten_to_tuple(outputs)

    old_train = model.training
    with torch.no_grad():
        if mode == FLOPS_MODE:
            ret = flop_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
        elif mode == ACTIVATIONS_MODE:
            ret = activation_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
        else:
            raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # compatible with change in fvcore
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
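

# A hedged sketch of passing a custom fvcore op handle through **kwargs, so that an
# op ignored above (e.g. torchvision::nms, whose flops are not estimated) gets
# counted; the user-supplied handle overrides the ignore entry because of the
# `supported_ops.update(...)` call above. `naive_nms_flop_jit` is a hypothetical
# name; fvcore handles receive the op's jit inputs/outputs and return a dict/Counter.
#
#   def naive_nms_flop_jit(inputs, outputs):
#       # extremely rough: charge one flop-unit per input box (illustrative only)
#       num_boxes = inputs[0].type().sizes()[0]
#       return {"nms": num_boxes}
#
#   gflops = flop_count_operators(
#       model, inputs, supported_ops={"torchvision::nms": naive_nms_flop_jit}
#   )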