# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.

    ``bn.running_mean`` and ``bn.running_var`` are read unconditionally, so
    they must not be ``None`` (i.e. BN must track running statistics).

    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module providing ``_conv_forward``.
        x (torch.Tensor): Input feature map.

    Returns:
        torch.Tensor: result of ``conv._conv_forward`` with the BN affine
        transform folded into the conv weight and bias.
    """
    # These lines of code are designed to deal with various cases
    # like bn without affine transform, and conv without bias
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(
        [-1] + [1] * (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * \
        (bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """This function controls whether to use `efficient_conv_bn_eval_forward`.

    If the following `bn` is in `eval` mode and has running statistics, then
    we turn on the special `efficient_conv_bn_eval_forward`. Otherwise the
    conv output is passed through `bn` as usual.

    Args:
        bn (_BatchNorm): the BatchNorm module following `conv`.
        conv (nn._ConvNd): a conv module providing ``_conv_forward``.
        x (torch.Tensor): Input feature map of `conv`.

    Returns:
        torch.Tensor: the output of ``bn(conv(x))``.
    """
    # The fused path folds BN's *running* statistics into the conv. It is
    # only equivalent when BN itself would use them: eval mode with the
    # running buffers present. With ``track_running_stats=False`` the buffers
    # are None and BN normalizes with batch statistics even in eval mode.
    uses_running_stats = (not bn.training and bn.running_mean is not None
                          and bn.running_var is not None)
    if uses_running_stats:
        return efficient_conv_bn_eval_forward(bn, conv, x)
    conv_out = conv._conv_forward(x, conv.weight, conv.bias)
    return bn(conv_out)


def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation.

    Each ``bn(conv(x))`` pair of module calls is replaced by a single
    ``call_function`` node invoking `efficient_conv_bn_eval_control`. A pair
    is rewritten only when the BN input is itself a module call to a
    non-transposed ``_ConvNd`` with a positional input, and the conv output
    is used by nothing but that BN.

    Args:
        fx_model (torch.fx.GraphModule): traced model. Its graph is modified
            in place and its code is regenerated.
    """
    import torch.fx as fx

    modules = dict(fx_model.named_modules())
    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]
    # Transposed convolutions subclass ``_ConvNd`` but have no
    # ``_conv_forward`` and store weights as [C_in, C_out // groups, ...], so
    # neither path of `efficient_conv_bn_eval_control` can run them.
    unsupported_convs = (torch.nn.modules.conv._ConvTransposeNd, )

    pairs = []
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # If our current node isn't calling a Module then we can ignore it.
        if node.op != 'call_module' or not node.args:
            continue
        input_node = node.args[0]
        # Only a module call can be a conv. Placeholders, function calls and
        # method calls do not have module paths as targets. The conv must
        # also take its input positionally, because the fused node reuses
        # `conv_node.args[0]`.
        if not isinstance(input_node, fx.Node) or \
                input_node.op != 'call_module' or not input_node.args:
            continue
        target_module = modules[node.target]
        source_module = modules[input_node.target]
        found_pair = any(
            isinstance(target_module, bn_class)
            and isinstance(source_module, conv_class)
            and not isinstance(source_module, unsupported_convs)
            for conv_class, bn_class in patterns)
        # Not a conv-BN pattern or output of conv is used by other nodes
        if not found_pair or len(input_node.users) > 1:
            continue
        # Find a pair of conv and bn computation nodes to optimize
        pairs.append((input_node, node))

    graph = fx_model.graph
    for conv_node, bn_node in pairs:
        # Set the insertion point. The context manager restores the previous
        # insertion point on exit instead of leaving it aimed at `conv_node`,
        # which is erased below.
        with graph.inserting_before(conv_node):
            # create `get_attr` node to access modules
            # note that we directly call `create_node` to fill the `name`
            # argument. `graph.get_attr` and `graph.call_function` do not
            # allow the `name` argument.
            conv_get_node = graph.create_node(
                op='get_attr', target=conv_node.target, name='get_conv')
            bn_get_node = graph.create_node(
                op='get_attr', target=bn_node.target, name='get_bn')
            # prepare args for the fused function
            args = (bn_get_node, conv_get_node, conv_node.args[0])
            # create a new node
            new_node = graph.create_node(
                op='call_function',
                target=efficient_conv_bn_eval_control,
                args=args,
                name='efficient_conv_bn_eval')
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)

    # regenerate the code
    graph.lint()
    fx_model.recompile()


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    """Trace `model` with ``torch.fx`` and fuse its conv+bn pairs.

    ``model.forward`` is replaced in place by the forward of the transformed
    ``fx.GraphModule``.

    Args:
        model (torch.nn.Module): a model that ``fx.symbolic_trace`` can trace.
    """
    import torch.fx as fx

    # currently we use `fx.symbolic_trace` to trace models.
    # in the future, we might turn to pytorch 2.0 compile infrastructure to
    # get the `fx.GraphModule` IR. Nonetheless, the graph transform function
    # can remain unchanged. We just need to change the way
    # we get `fx.GraphModule`.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(fx_model)
    model.forward = fx_model.forward


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    """Enable efficient conv+bn eval on selected submodules of `model`.

    Args:
        model (torch.nn.Module): the root model.
        modules (list[str] | str): one or more (possibly dotted) attribute
            paths, resolved with ``operator.attrgetter`` relative to `model`.
            Each resolved submodule is traced and transformed separately.
    """
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(module)