# NOTE: page-status text captured together with this file ("Spaces: Build error"); it is not part of the program.
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import tempfile | |
from functools import partial | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from mmengine.config import Config, DictAction | |
from mmengine.logging import MMLogger | |
from mmengine.model import revert_sync_batchnorm | |
from mmengine.registry import init_default_scope | |
from mmengine.runner import Runner | |
from mmengine.utils import digit_version | |
from mmdet.registry import MODELS | |
try: | |
from mmengine.analysis import get_model_complexity_info | |
from mmengine.analysis.print_helper import _format_size | |
except ImportError: | |
raise ImportError('Please upgrade mmengine >= 0.6.0') | |
def parse_args():
    """Parse the command-line options of the FLOPs counting script.

    Returns:
        argparse.Namespace: Namespace with ``config`` (path of the config
        file), ``num_images`` (int, 100 by default) and ``cfg_options``
        (config overrides parsed by ``DictAction``).
    """
    cli = argparse.ArgumentParser(description='Get a detector flops')

    # Positional argument: the config file of the detector to analyse.
    cli.add_argument('config', help='train config file path')

    # Number of images the FLOPs computation is run on.
    cli.add_argument(
        '--num-images',
        type=int,
        default=100,
        help='num images of calculate model flops')

    # Free-form ``key=value`` overrides that are merged into the config.
    override_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    cli.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help=override_help)

    return cli.parse_args()
def inference(args, logger):
    """Measure the average FLOPs and the parameter count of a detector.

    The model and the validation dataloader are built from ``args.config``.
    Batches are read from the dataloader until ``args.num_images`` of them
    have been processed; each one is run through the model's data
    preprocessor and handed to ``get_model_complexity_info``.

    Args:
        args: Parsed command-line namespace providing ``config``,
            ``num_images`` and ``cfg_options``.
        logger: Logger used for the torch-version warning and for the
            "config not found" error message.

    Returns:
        dict: ``ori_shape`` and ``pad_shape`` of the last processed sample,
        ``compute_type``, plus ``flops`` (mean over the processed images)
        and ``params``, both formatted by ``_format_size``.

    Raises:
        ValueError: If no image was processed (for example an empty
            validation dataloader or ``--num-images 0``), because no
            statistics can be computed in that case.
    """
    if digit_version(torch.__version__) < digit_version('1.12'):
        logger.warning(
            'Some config files, such as configs/yolact and configs/detectors, '
            'may have compatibility issues with torch.jit when torch<1.12. '
            'If you want to calculate flops for these models, '
            'please make sure your pytorch version is >=1.12.')

    # A missing config is only logged here; loading it is still attempted
    # right below.
    config_name = Path(args.config)
    if not config_name.exists():
        logger.error(f'{config_name} not found.')

    cfg = Config.fromfile(args.config)
    # One image per batch, so every complexity measurement covers one image.
    # User overrides are merged afterwards and may therefore change this.
    cfg.val_dataloader.batch_size = 1
    # Only the path string is kept; the TemporaryDirectory object itself is
    # not retained.
    cfg.work_dir = tempfile.TemporaryDirectory().name

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmdet'))

    # TODO: The following usage is temporary and not safe.
    # mmSyncBN is hard-coded to SyncBN here. This works around a known
    # mmengine bug: mmSyncBN requires a distributed environment. It affects
    # models such as configs/strong_baselines.
    if hasattr(cfg, 'head_norm_cfg'):
        cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)

    result = {}
    avg_flops = []
    params = None
    data_loader = Runner.build_dataloader(cfg.val_dataloader)
    model = MODELS.build(cfg.model)
    if torch.cuda.is_available():
        model = model.cuda()
    model = revert_sync_batchnorm(model)
    model.eval()
    # Keep the unpatched forward so each iteration can bind its own samples.
    _forward = model.forward

    for idx, data_batch in enumerate(data_loader):
        if idx == args.num_images:
            break
        data = model.data_preprocessor(data_batch)
        sample = data['data_samples'][0]
        result['ori_shape'] = sample.ori_shape
        result['pad_shape'] = sample.pad_shape
        # Prefer the batch input shape as the reported padded shape when the
        # sample carries one.
        if hasattr(sample, 'batch_input_shape'):
            result['pad_shape'] = sample.batch_input_shape
        # The analyser is only given ``inputs``; the data samples are bound
        # to forward up front.
        model.forward = partial(_forward, data_samples=data['data_samples'])
        outputs = get_model_complexity_info(
            model,
            None,
            inputs=data['inputs'],
            show_table=False,
            show_arch=False)
        avg_flops.append(outputs['flops'])
        params = outputs['params']
        result['compute_type'] = 'dataloader: load a picture from the dataset'
    del data_loader

    # Without this guard the averaging below fails with an opaque
    # "cannot convert float NaN to integer" error.
    if not avg_flops:
        raise ValueError(
            'No image was processed, so flops cannot be computed. Check '
            'that the validation dataset is not empty and that '
            '--num-images is not 0.')

    mean_flops = _format_size(int(np.average(avg_flops)))
    params = _format_size(params)
    result['flops'] = mean_flops
    result['params'] = params

    return result
def main():
    """Run the FLOPs analysis for the command-line config and print it."""
    cli_args = parse_args()
    mm_logger = MMLogger.get_instance(name='MMLogger')
    stats = inference(cli_args, mm_logger)

    split_line = '=' * 30
    # Keys are read in the same order as before so that a missing key fails
    # on the same lookup.
    ori_shape, pad_shape = stats['ori_shape'], stats['pad_shape']
    flops, params = stats['flops'], stats['params']
    compute_type = stats['compute_type']

    # Mention the shape change only when the two shapes differ.
    if pad_shape != ori_shape:
        resize_note = (f'{split_line}\nUse size divisor set input shape '
                       f'from {ori_shape} to {pad_shape}')
        print(resize_note)

    summary = (f'{split_line}\nCompute type: {compute_type}\n'
               f'Input shape: {pad_shape}\nFlops: {flops}\n'
               f'Params: {params}\n{split_line}')
    print(summary)

    caution = ('!!!Please be cautious if you use the results in papers. '
               'You may need to check if all ops are supported and verify '
               'that the flops computation is correct.')
    print(caution)
# Run the report only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()