rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import tempfile
from functools import partial
from pathlib import Path
import numpy as np
import torch
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.utils import digit_version
from mmdet.registry import MODELS
try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError('Please upgrade mmengine >= 0.6.0')
def parse_args():
parser = argparse.ArgumentParser(description='Get a detector flops')
parser.add_argument('config', help='train config file path')
parser.add_argument(
'--num-images',
type=int,
default=100,
help='num images of calculate model flops')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args
def inference(args, logger):
if digit_version(torch.__version__) < digit_version('1.12'):
logger.warning(
'Some config files, such as configs/yolact and configs/detectors,'
'may have compatibility issues with torch.jit when torch<1.12. '
'If you want to calculate flops for these models, '
'please make sure your pytorch version is >=1.12.')
config_name = Path(args.config)
if not config_name.exists():
logger.error(f'{config_name} not found.')
cfg = Config.fromfile(args.config)
cfg.val_dataloader.batch_size = 1
cfg.work_dir = tempfile.TemporaryDirectory().name
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
init_default_scope(cfg.get('default_scope', 'mmdet'))
# TODO: The following usage is temporary and not safe
# use hard code to convert mmSyncBN to SyncBN. This is a known
# bug in mmengine, mmSyncBN requires a distributed environment,
# this question involves models like configs/strong_baselines
if hasattr(cfg, 'head_norm_cfg'):
cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
type='SyncBN', requires_grad=True)
cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
type='SyncBN', requires_grad=True)
result = {}
avg_flops = []
data_loader = Runner.build_dataloader(cfg.val_dataloader)
model = MODELS.build(cfg.model)
if torch.cuda.is_available():
model = model.cuda()
model = revert_sync_batchnorm(model)
model.eval()
_forward = model.forward
for idx, data_batch in enumerate(data_loader):
if idx == args.num_images:
break
data = model.data_preprocessor(data_batch)
result['ori_shape'] = data['data_samples'][0].ori_shape
result['pad_shape'] = data['data_samples'][0].pad_shape
if hasattr(data['data_samples'][0], 'batch_input_shape'):
result['pad_shape'] = data['data_samples'][0].batch_input_shape
model.forward = partial(_forward, data_samples=data['data_samples'])
outputs = get_model_complexity_info(
model,
None,
inputs=data['inputs'],
show_table=False,
show_arch=False)
avg_flops.append(outputs['flops'])
params = outputs['params']
result['compute_type'] = 'dataloader: load a picture from the dataset'
del data_loader
mean_flops = _format_size(int(np.average(avg_flops)))
params = _format_size(params)
result['flops'] = mean_flops
result['params'] = params
return result
def main():
args = parse_args()
logger = MMLogger.get_instance(name='MMLogger')
result = inference(args, logger)
split_line = '=' * 30
ori_shape = result['ori_shape']
pad_shape = result['pad_shape']
flops = result['flops']
params = result['params']
compute_type = result['compute_type']
if pad_shape != ori_shape:
print(f'{split_line}\nUse size divisor set input shape '
f'from {ori_shape} to {pad_shape}')
print(f'{split_line}\nCompute type: {compute_type}\n'
f'Input shape: {pad_shape}\nFlops: {flops}\n'
f'Params: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify '
'that the flops computation is correct.')
if __name__ == '__main__':
main()