Spaces:
Sleeping
Sleeping
""" | |
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/mlfoundations/open_clip
""" | |
import warnings

from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=""):
    """
    Convert `BatchNorm2d` / `SyncBatchNorm` layers of `module` into `FrozenBatchNorm2d`.

    If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm` (and its name is selected by
    `module_match`), a new `FrozenBatchNorm2d` carrying copies of its parameters and running statistics is
    returned. Otherwise the module's children are walked recursively and matching submodules are replaced in
    place, and `module` itself is returned.

    A batch-norm layer that has no running statistics (e.g. built with `track_running_stats=False`) cannot be
    frozen because there is nothing to freeze it to; such a layer is left unchanged and a warning is emitted.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Container of full module names to freeze (all batch-norm layers if empty).
            Only membership (`name in module_match`) is tested; the container is never modified.
        name (str): Full (dotted) name of `module`, used as the prefix for child names.

    Returns:
        torch.nn.Module: The frozen replacement when `module` itself was converted, otherwise `module`
        (with any matching descendants swapped in place).

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # An empty/falsy `module_match` selects every batch-norm layer.
    is_match = name in module_match if module_match else True
    bn_types = (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)

    if not (is_match and isinstance(module, bn_types)):
        # Not a layer to convert: recurse into the children and swap in any replacements.
        for child_name, child in module.named_children():
            full_child_name = ".".join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                module.add_module(child_name, new_child)
        return module

    if module.running_mean is None or module.running_var is None:
        # Without tracked statistics there are no values to freeze; converting would silently
        # substitute the default mean/variance buffers, so keep the layer as it is.
        warnings.warn(
            "Batch norm layer '%s' has no running statistics and was not frozen."
            % (name or type(module).__name__),
            stacklevel=2,
        )
        return module

    frozen = FrozenBatchNorm2d(module.num_features)
    frozen.num_features = module.num_features
    frozen.affine = module.affine
    if module.affine:
        frozen.weight.data = module.weight.data.clone().detach()
        frozen.bias.data = module.bias.data.clone().detach()
    # Clone the statistics as well: sharing storage with the source layer would let a later
    # training-mode forward pass on the original mutate the "frozen" values.
    frozen.running_mean.data = module.running_mean.data.clone().detach()
    frozen.running_var.data = module.running_var.data.clone().detach()
    frozen.eps = module.eps
    return frozen