Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os | |
from typing import Dict, Optional | |
import mmengine | |
import torch # noqa | |
import torch.nn as nn | |
from mmengine.hooks import Hook | |
from mmengine.logging import print_log | |
from mmengine.registry import HOOKS | |
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa | |
from .utils import get_single_padding, write_to_json | |
class RFSearchHook(Hook):
    """Receptive field search via dilation rates.

    Please refer to `RF-Next: Efficient Receptive Field
    Search for Convolutional Neural Networks
    <https://arxiv.org/abs/2206.06637>`_ for more details.

    Args:
        mode (str, optional): It can be set to the following types:
            'search', 'fixed_single_branch', or 'fixed_multi_branch'.
            Defaults to 'search'.
        config (Dict, optional): config dict of search.
            It must contain a "search" entry (the empty default is rejected
            with a ``KeyError``), and config["search"] must include:

            - "step": recording the current searching step.
            - "max_step": The maximum number of searching steps
              to update the structures.
            - "search_interval": The interval (epoch/iteration)
              between two updates.
            - "exp_rate": The controller of the sparsity of search space.
            - "init_alphas": The value for initializing weights of each
              branch.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "skip_layer": The modules in skip_layer will be ignored
              during the receptive field search. May be None or omitted
              to skip nothing.

            The dict is stored by reference and updated in place: a
            "structure" entry is (re)created and "search"/"step" is
            incremented by :meth:`step`.
        rfstructure_file (str, optional): Path to load searched receptive
            fields of the model. Defaults to None.
        by_epoch (bool, optional): Determine to perform step by epoch or
            by iteration. If set to True, it will step by epoch. Otherwise, by
            iteration. Defaults to True.
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """

    def __init__(self,
                 mode: str = 'search',
                 config: Dict = {},
                 rfstructure_file: Optional[str] = None,
                 by_epoch: bool = True,
                 verbose: bool = True):
        assert mode in ['search', 'fixed_single_branch', 'fixed_multi_branch']
        assert config is not None
        # Read the mandatory search settings *before* writing anything into
        # ``config``: the shared ``{}`` default has no 'search' entry, so it
        # fails here with a KeyError and is never mutated across calls.
        self.num_branches = config['search']['num_branches']
        self.config = config
        self.config['structure'] = {}
        self.verbose = verbose
        if rfstructure_file is not None:
            # A previously searched structure replaces the empty one.
            rfstructure = mmengine.load(rfstructure_file)['structure']
            self.config['structure'] = rfstructure
        self.mode = mode
        self.by_epoch = by_epoch

    def init_model(self, model: nn.Module):
        """init model with search ability.

        Args:
            model (nn.Module): pytorch model

        Raises:
            NotImplementedError: only support three modes:
                search/fixed_single_branch/fixed_multi_branch
        """
        if self.verbose:
            print_log('RFSearch init begin.', 'current')
        if self.mode == 'search':
            # Only apply stored dilations when a structure is available;
            # the convolutions are wrapped for searching in either case.
            if self.config['structure']:
                self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_single_branch':
            self.set_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_multi_branch':
            self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        else:
            raise NotImplementedError
        if self.verbose:
            print_log('RFSearch init end.', 'current')

    def after_train_epoch(self, runner):
        """Performs a dilation searching step after one training epoch."""
        if self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def after_train_iter(self,
                         runner,
                         batch_idx,
                         data_batch=None,
                         outputs=None):
        """Performs a dilation searching step after one training iteration.

        ``batch_idx``, ``data_batch`` and ``outputs`` are accepted for
        compatibility with the hook interface and are not used here.
        """
        if not self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def step(self, model: nn.Module, work_dir: str) -> None:
        """Performs a dilation searching step.

        Increments config["search"]["step"]. When the new step is a multiple
        of "search_interval" and still below "max_step", the searchable ops
        are re-estimated/expanded, their dilations are recorded in
        config["structure"], and the whole config is written to
        ``local_search_config_step<step>.json`` inside ``work_dir``.

        Args:
            model (nn.Module): pytorch model
            work_dir (str): Directory to save the searching results.
        """
        search_cfg = self.config['search']
        search_cfg['step'] += 1
        current_step = search_cfg['step']
        if (current_step % search_cfg['search_interval'] == 0
                and current_step < search_cfg['max_step']):
            self.estimate_and_expand(model)
            for name, module in model.named_modules():
                if isinstance(module, BaseConvRFSearchOp):
                    self.config['structure'][name] = module.op_layer.dilation
            write_to_json(
                self.config,
                os.path.join(
                    work_dir,
                    'local_search_config_step%d.json' % current_step,
                ),
            )

    def estimate_and_expand(self, model: nn.Module) -> None:
        """estimate and search for RFConvOp.

        Calls ``estimate_rates()`` followed by ``expand_rates()`` on every
        ``BaseConvRFSearchOp`` found in ``model``.

        Args:
            model (nn.Module): pytorch model
        """
        for module in model.modules():
            if isinstance(module, BaseConvRFSearchOp):
                module.estimate_rates()
                module.expand_rates()

    @staticmethod
    def _child_fullname(prefix: str, name: str) -> str:
        """Build the dotted name of a child; top level starts at 'module.'."""
        # NOTE(review): the 'module.' root presumably mirrors the names seen
        # under a wrapped (e.g. distributed) model - confirm against the
        # structure files in use.
        if prefix == '':
            return 'module.' + name
        return prefix + '.' + name

    def _is_skipped(self, fullname: str) -> bool:
        """Tell whether any configured skip_layer entry occurs in the name."""
        skip_layer = self.config['search'].get('skip_layer')
        if skip_layer is None:
            return False
        # Substring match: a skipped entry also excludes all of its children.
        return any(layer in fullname for layer in skip_layer)

    @staticmethod
    def _has_searchable_kernel(module: nn.Module) -> bool:
        """Tell whether either kernel dimension is odd and larger than 1."""
        kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
        return (1 < kernel_h and 0 != kernel_h % 2) or \
            (1 < kernel_w and 0 != kernel_w % 2)

    @staticmethod
    def _get_search_op_class(search_op: str) -> type:
        """Look up the ``<search_op>RFSearchOp`` wrapper class by name.

        The class is resolved from this module's globals rather than with
        ``eval`` so that ``search_op`` can never be executed as code.

        Raises:
            NameError: If no such wrapper class is available in this module.
        """
        class_name = search_op + 'RFSearchOp'
        try:
            return globals()[class_name]
        except KeyError:
            raise NameError("name '%s' is not defined" % class_name) from None

    def wrap_model(self,
                   model: nn.Module,
                   search_op: str = 'Conv2d',
                   prefix: str = '') -> None:
        """wrap model to support searchable conv op.

        Recursively walks the children of ``model`` and replaces, in place,
        every ``torch.nn.<search_op>`` whose kernel is searchable with the
        matching ``<search_op>RFSearchOp`` wrapper. Children matching a
        "skip_layer" entry and already wrapped ops are left alone.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            prefix (str): Prefix for function recursion. Defaults to ''.
        """
        op_type = getattr(nn, search_op)
        for name, module in model.named_children():
            fullname = self._child_fullname(prefix, name)
            if self._is_skipped(fullname):
                continue
            if isinstance(module, op_type):
                if self._has_searchable_kernel(module):
                    wrapper_cls = self._get_search_op_class(search_op)
                    wrapped = wrapper_cls(module, self.config['search'],
                                          self.verbose)
                    # Keep the wrapper on the device of the wrapped conv.
                    wrapped = wrapped.to(module.weight.device)
                    if self.verbose:
                        print_log(
                            'Wrap model %s to %s.' %
                            (str(module), str(wrapped)), 'current')
                    setattr(model, name, wrapped)
            elif not isinstance(module, BaseConvRFSearchOp):
                self.wrap_model(module, search_op, fullname)

    def set_model(self,
                  model: nn.Module,
                  search_op: str = 'Conv2d',
                  init_rates: Optional[int] = None,
                  prefix: str = '') -> None:
        """set model based on config.

        Recursively walks the children of ``model`` and, for every
        ``torch.nn.<search_op>`` with a searchable kernel, applies the
        dilation stored under its full name in config["structure"] and
        recomputes the padding accordingly. An integer entry is expanded to
        a two-element list and written back to the config.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            init_rates (int, optional): Set to other initial dilation rates.
                Currently only forwarded through the recursion.
                Defaults to None.
            prefix (str): Prefix for function recursion. Defaults to ''.

        Raises:
            KeyError: If a searchable module has no entry in
                config["structure"].
        """
        op_type = getattr(nn, search_op)
        structure = self.config['structure']
        for name, module in model.named_children():
            fullname = self._child_fullname(prefix, name)
            if self._is_skipped(fullname):
                continue
            if isinstance(module, op_type):
                if self._has_searchable_kernel(module):
                    if isinstance(structure[fullname], int):
                        # A single rate applies to both spatial dimensions.
                        structure[fullname] = [
                            structure[fullname], structure[fullname]
                        ]
                    rate_h, rate_w = (structure[fullname][0],
                                      structure[fullname][1])
                    # The conv is modified in place, so it does not need to
                    # be re-registered on its parent.
                    module.dilation = (rate_h, rate_w)
                    module.padding = (
                        get_single_padding(module.kernel_size[0],
                                           module.stride[0], rate_h),
                        get_single_padding(module.kernel_size[1],
                                           module.stride[1], rate_w))
                    if self.verbose:
                        print_log(
                            'Set module %s dilation as: [%d %d]' %
                            (fullname, module.dilation[0], module.dilation[1]),
                            'current')
            elif not isinstance(module, BaseConvRFSearchOp):
                self.set_model(module, search_op, init_rates, fullname)