# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
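"""Benchmark a trained detector's robustness to common image corruptions.

For every requested corruption type and severity, a ``Corrupt`` transform is
inserted into the test pipeline, the test loop is re-run, and the collected
metrics are aggregated into the P / mPC / rPC robustness scores via
``robustness_eval.get_results``.

Example invocation (the script name and all paths below are illustrative
placeholders)::

    python tools/analysis_tools/test_robustness.py \
        configs/my_detector_config.py checkpoints/my_model.pth \
        --corruptions benchmark --out results.pkl
"""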
import argparse
import copy
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.dist import get_dist_info
from mmengine.evaluator import DumpResults
from mmengine.fileio import dump
from mmengine.runner import Runner

from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmdet.registry import RUNNERS
from tools.analysis_tools.robustness_eval import get_results


def parse_args():
parser = argparse.ArgumentParser(description='MMDet test detector')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument(
'--out',
type=str,
help='dump predictions to a pickle file for offline evaluation')
parser.add_argument(
'--corruptions',
type=str,
nargs='+',
        default=['benchmark'],
choices=[
'all', 'benchmark', 'noise', 'blur', 'weather', 'digital',
'holdout', 'None', 'gaussian_noise', 'shot_noise', 'impulse_noise',
'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow',
'frost', 'fog', 'brightness', 'contrast', 'elastic_transform',
'pixelate', 'jpeg_compression', 'speckle_noise', 'gaussian_blur',
'spatter', 'saturate'
],
        help='corruption types or groups of corruptions to evaluate')
parser.add_argument(
'--work-dir',
help='the directory to save the file containing evaluation metrics')
parser.add_argument(
'--severities',
type=int,
nargs='+',
default=[0, 1, 2, 3, 4, 5],
help='corruption severity levels')
parser.add_argument(
'--summaries',
        action='store_true',
        help='print summaries for every corruption and severity')
parser.add_argument('--show', action='store_true', help='show results')
parser.add_argument(
'--show-dir', help='directory where painted images will be saved')
parser.add_argument(
        '--wait-time', type=float, default=2, help='display interval of results (in seconds)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--final-prints',
type=str,
nargs='+',
choices=['P', 'mPC', 'rPC'],
default='mPC',
help='corruption benchmark metric to print at the end')
parser.add_argument(
'--final-prints-aggregate',
type=str,
choices=['all', 'benchmark'],
default='benchmark',
help='aggregate all results or only those for benchmark corruptions')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
        help='override some settings in the used config; the key-value pairs '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
args = parse_args()
assert args.out or args.show or args.show_dir, \
('Please specify at least one operation (save or show the results) '
         'with the argument "--out", "--show" or "--show-dir"')
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
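    # the full model is loaded from args.checkpoint below, so the backbone's
    # own pretrained initialization is disabled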
cfg.model.backbone.init_cfg.type = None
cfg.test_dataloader.dataset.test_mode = True
cfg.load_from = args.checkpoint
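    # enable the visualization hook so predictions are shown and/or saved as
    # painted images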
if args.show or args.show_dir:
cfg = trigger_visualization_hook(cfg, args)
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
# add `DumpResults` dummy metric
if args.out is not None:
assert args.out.endswith(('.pkl', '.pickle')), \
'The dump file must be a pkl file.'
runner.test_evaluator.metrics.append(
DumpResults(out_file_path=args.out))
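    # expand a corruption group keyword ('all', 'benchmark', 'noise', ...) into
    # the concrete list of corruption types to evaluate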
if 'all' in args.corruptions:
corruptions = [
'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
'brightness', 'contrast', 'elastic_transform', 'pixelate',
'jpeg_compression', 'speckle_noise', 'gaussian_blur', 'spatter',
'saturate'
]
elif 'benchmark' in args.corruptions:
corruptions = [
'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
'brightness', 'contrast', 'elastic_transform', 'pixelate',
'jpeg_compression'
]
elif 'noise' in args.corruptions:
corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise']
elif 'blur' in args.corruptions:
corruptions = [
'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'
]
elif 'weather' in args.corruptions:
corruptions = ['snow', 'frost', 'fog', 'brightness']
elif 'digital' in args.corruptions:
corruptions = [
'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
]
elif 'holdout' in args.corruptions:
corruptions = ['speckle_noise', 'gaussian_blur', 'spatter', 'saturate']
elif 'None' in args.corruptions:
corruptions = ['None']
args.severities = [0]
else:
corruptions = args.corruptions
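    # results are collected as aggregated_results[corruption][severity] = metrics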
aggregated_results = {}
for corr_i, corruption in enumerate(corruptions):
aggregated_results[corruption] = {}
for sev_i, corruption_severity in enumerate(args.severities):
# evaluate severity 0 (= no corruption) only once
if corr_i > 0 and corruption_severity == 0:
aggregated_results[corruption][0] = \
aggregated_results[corruptions[0]][0]
continue
test_loader_cfg = copy.deepcopy(cfg.test_dataloader)
# assign corruption and severity
if corruption_severity > 0:
corruption_trans = dict(
type='Corrupt',
corruption=corruption,
severity=corruption_severity)
# TODO: hard coded "1", we assume that the first step is
# loading images, which needs to be fixed in the future
test_loader_cfg.dataset.pipeline.insert(1, corruption_trans)
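            # rebuild the test dataloader with the (possibly corrupted) pipeline
            # and attach it to the runner's test loop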
test_loader = runner.build_dataloader(test_loader_cfg)
runner.test_loop.dataloader = test_loader
# set random seeds
if args.seed is not None:
runner.set_randomness(args.seed)
# print info
print(f'\nTesting {corruption} at severity {corruption_severity}')
eval_results = runner.test()
if args.out:
eval_results_filename = (
osp.splitext(args.out)[0] + '_results' +
osp.splitext(args.out)[1])
aggregated_results[corruption][
corruption_severity] = eval_results
dump(aggregated_results, eval_results_filename)
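    # the final aggregation reads the dumped results file, so it is only done
    # on the master rank and only when --out was given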
rank, _ = get_dist_info()
    if args.out and rank == 0:
eval_results_filename = (
osp.splitext(args.out)[0] + '_results' + osp.splitext(args.out)[1])
# print final results
print('\nAggregated results:')
prints = args.final_prints
aggregate = args.final_prints_aggregate
if cfg.dataset_type == 'VOCDataset':
get_results(
eval_results_filename,
dataset='voc',
prints=prints,
aggregate=aggregate)
else:
get_results(
eval_results_filename,
dataset='coco',
prints=prints,
                aggregate=aggregate)


if __name__ == '__main__':
main()