rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Dict, Optional
import mmengine
import torch # noqa
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.registry import HOOKS
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa
from .utils import get_single_padding, write_to_json
@HOOKS.register_module()
class RFSearchHook(Hook):
    """Receptive field search via dilation rates.

    Please refer to `RF-Next: Efficient Receptive Field
    Search for Convolutional Neural Networks
    <https://arxiv.org/abs/2206.06637>`_ for more details.

    Args:
        mode (str, optional): It can be set to the following types:
            'search', 'fixed_single_branch', or 'fixed_multi_branch'.
            Defaults to 'search'.
        config (Dict, optional): config dict of search. The dict is stored
            by reference and updated in place (a "structure" entry is
            added and ``config["search"]["step"]`` is incremented while
            searching). ``None`` is replaced by a fresh empty dict.
            The config must contain "search", and config["search"] must
            include:

            - "step": recording the current searching step.
            - "max_step": The maximum number of searching steps
              to update the structures.
            - "search_interval": The interval (epoch/iteration)
              between two updates.
            - "exp_rate": The controller of the sparsity of search space.
            - "init_alphas": The value for initializing weights of each
              branch.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "skip_layer": The modules in skip_layer will be ignored
              during the receptive field search.
        rfstructure_file (str, optional): Path to load searched receptive
            fields of the model. Defaults to None.
        by_epoch (bool, optional): Determine to perform step by epoch or
            by iteration. If set to True, it will step by epoch. Otherwise, by
            iteration. Defaults to True.
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.

    Raises:
        KeyError: If ``config`` has no "search" entry or that entry has no
            "num_branches" key (both are read by the constructor).
    """

    def __init__(self,
                 mode: str = 'search',
                 config: Optional[Dict] = None,
                 rfstructure_file: Optional[str] = None,
                 by_epoch: bool = True,
                 verbose: bool = True):
        assert mode in [
            'search', 'fixed_single_branch', 'fixed_multi_branch'
        ], 'unsupported RFSearchHook mode: %s' % str(mode)
        # A literal ``{}`` default would be one dict shared by every
        # instance, and it is mutated right below; build a fresh one instead.
        if config is None:
            config = {}
        self.config = config
        self.config['structure'] = {}
        self.verbose = verbose
        if rfstructure_file is not None:
            # Only the "structure" entry of the loaded file is used.
            rfstructure = mmengine.load(rfstructure_file)['structure']
            self.config['structure'] = rfstructure
        self.mode = mode
        self.num_branches = self.config['search']['num_branches']
        self.by_epoch = by_epoch

    def init_model(self, model: nn.Module):
        """init model with search ability.

        In 'search' mode the model is wrapped with searchable ops, after
        first applying any already-loaded structure. In
        'fixed_single_branch' mode only the stored structure is applied.
        In 'fixed_multi_branch' mode the structure is applied and the model
        is then wrapped.

        Args:
            model (nn.Module): pytorch model

        Raises:
            NotImplementedError: only support three modes:
                search/fixed_single_branch/fixed_multi_branch
        """
        if self.verbose:
            print_log('RFSearch init begin.', 'current')
        if self.mode == 'search':
            # A non-empty structure means dilation rates were loaded from
            # ``rfstructure_file``; apply them before wrapping.
            if self.config['structure']:
                self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_single_branch':
            self.set_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_multi_branch':
            self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        else:
            raise NotImplementedError
        if self.verbose:
            print_log('RFSearch init end.', 'current')

    def after_train_epoch(self, runner):
        """Performs a dilation searching step after one training epoch."""
        if self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """Performs a dilation searching step after one training iteration."""
        if not self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def step(self, model: nn.Module, work_dir: str) -> None:
        """Performs a dilation searching step.

        The step counter in ``config['search']['step']`` is always
        incremented. Only when the new counter is a multiple of
        ``search_interval`` and still below ``max_step`` are the rates
        re-estimated, the current dilation of every searchable op recorded
        in ``config['structure']``, and the config written to
        ``local_search_config_step<step>.json`` inside ``work_dir``.

        Args:
            model (nn.Module): pytorch model
            work_dir (str): Directory to save the searching results.
        """
        search_cfg = self.config['search']
        search_cfg['step'] += 1
        current_step = search_cfg['step']
        if (current_step % search_cfg['search_interval'] == 0
                and current_step < search_cfg['max_step']):
            self.estimate_and_expand(model)
            for name, module in model.named_modules():
                if isinstance(module, BaseConvRFSearchOp):
                    self.config['structure'][name] = module.op_layer.dilation
            write_to_json(
                self.config,
                os.path.join(
                    work_dir,
                    'local_search_config_step%d.json' % current_step,
                ),
            )

    def estimate_and_expand(self, model: nn.Module) -> None:
        """estimate and search for RFConvOp.

        Calls ``estimate_rates()`` followed by ``expand_rates()`` on every
        ``BaseConvRFSearchOp`` found in the model.

        Args:
            model (nn.Module): pytorch model
        """
        for module in model.modules():
            if isinstance(module, BaseConvRFSearchOp):
                module.estimate_rates()
                module.expand_rates()

    @staticmethod
    def _qualified_name(prefix: str, name: str) -> str:
        """Build the dotted name of a child; top-level children get the
        'module.' prefix."""
        if prefix == '':
            return 'module.' + name
        return prefix + '.' + name

    def _is_skipped(self, fullname: str) -> bool:
        """Tell whether any configured skip_layer entry occurs in fullname."""
        skip_layers = self.config['search']['skip_layer']
        if skip_layers is None:
            return False
        return any(layer in fullname for layer in skip_layers)

    @staticmethod
    def _has_odd_kernel(module: nn.Module) -> bool:
        """Tell whether either kernel dimension is odd and larger than 1."""
        kernel_size = module.kernel_size
        return ((1 < kernel_size[0] and 0 != kernel_size[0] % 2)
                or (1 < kernel_size[1] and 0 != kernel_size[1] % 2))

    @staticmethod
    def _search_wrapper_cls(search_op: str):
        """Look up the ``<search_op>RFSearchOp`` class in this module."""
        wrapper_name = search_op + 'RFSearchOp'
        try:
            return globals()[wrapper_name]
        except KeyError:
            # Same exception type an undefined name would normally produce.
            raise NameError(
                "name '%s' is not defined" % wrapper_name) from None

    def wrap_model(self,
                   model: nn.Module,
                   search_op: str = 'Conv2d',
                   prefix: str = '') -> None:
        """wrap model to support searchable conv op.

        Walks the children of ``model`` recursively. Every ``torch.nn``
        layer of type ``search_op`` whose kernel is odd and larger than 1 in
        at least one dimension is replaced by the matching
        ``<search_op>RFSearchOp`` wrapper. Children matching ``skip_layer``
        are left untouched and not descended into, and existing
        ``BaseConvRFSearchOp`` instances are not descended into either.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            prefix (str): Prefix for function recursion. Defaults to ''.
        """
        # Attribute lookup instead of eval(): only names from torch.nn.
        torch_op = getattr(nn, search_op)
        for name, module in model.named_children():
            fullname = self._qualified_name(prefix, name)
            if self._is_skipped(fullname):
                continue
            if isinstance(module, torch_op):
                if self._has_odd_kernel(module):
                    wrapped = self._search_wrapper_cls(search_op)(
                        module, self.config['search'], self.verbose)
                    # Keep the wrapper on the same device as the wrapped
                    # layer's weights.
                    wrapped = wrapped.to(module.weight.device)
                    if self.verbose:
                        print_log(
                            'Wrap model %s to %s.' %
                            (str(module), str(wrapped)), 'current')
                    setattr(model, name, wrapped)
            elif not isinstance(module, BaseConvRFSearchOp):
                self.wrap_model(module, search_op, fullname)

    def set_model(self,
                  model: nn.Module,
                  search_op: str = 'Conv2d',
                  init_rates: Optional[int] = None,
                  prefix: str = '') -> None:
        """set model based on config.

        Walks the children of ``model`` recursively. For every ``torch.nn``
        layer of type ``search_op`` whose kernel is odd and larger than 1 in
        at least one dimension, the dilation stored in
        ``config['structure']`` under the layer's full name is applied and
        the padding is recomputed with ``get_single_padding``. An integer
        entry is expanded in place to ``[rate, rate]``.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            init_rates (int, optional): Set to other initial dilation rates.
                Only forwarded to the recursive calls. Defaults to None.
            prefix (str): Prefix for function recursion. Defaults to ''.

        Raises:
            KeyError: If a matching layer has no entry in
                ``config['structure']``.
        """
        # Attribute lookup instead of eval(): only names from torch.nn.
        torch_op = getattr(nn, search_op)
        structure = self.config['structure']
        for name, module in model.named_children():
            fullname = self._qualified_name(prefix, name)
            if self._is_skipped(fullname):
                continue
            if isinstance(module, torch_op):
                if self._has_odd_kernel(module):
                    if fullname not in structure:
                        raise KeyError(
                            'No dilation rate found for layer %r in '
                            "config['structure']" % fullname)
                    if isinstance(structure[fullname], int):
                        # A single rate applies to both dimensions.
                        structure[fullname] = [
                            structure[fullname], structure[fullname]
                        ]
                    rate_h = structure[fullname][0]
                    rate_w = structure[fullname][1]
                    module.dilation = (rate_h, rate_w)
                    module.padding = (
                        get_single_padding(module.kernel_size[0],
                                           module.stride[0], rate_h),
                        get_single_padding(module.kernel_size[1],
                                           module.stride[1], rate_w))
                    setattr(model, name, module)
                    if self.verbose:
                        print_log(
                            'Set module %s dilation as: [%d %d]' %
                            (fullname, module.dilation[0],
                             module.dilation[1]), 'current')
            elif not isinstance(module, BaseConvRFSearchOp):
                self.set_model(module, search_op, init_rates, fullname)