Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Utility function for weight initialization"""
import torch.nn as nn | |
from fvcore.nn.weight_init import c2_msra_fill | |
def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True):
    """
    Performs ResNet style weight initialization.

    Iterates over `model.modules()` and re-initializes, in place, the
    parameters of every `nn.Conv3d`, `nn.BatchNorm3d` and `nn.Linear` found.
    Modules of any other type are left untouched. Nothing is returned.

    Args:
        model (nn.Module): the model whose submodules are initialized in place.
        fc_init_std (float): the expected standard deviation for fc layer.
        zero_init_final_bn (bool): if True, zero initialize the final bn for
            every bottleneck.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            # Follow the initialization method proposed in:
            # {He, Kaiming, et al.
            # "Delving deep into rectifiers: Surpassing human-level
            # performance on imagenet classification."
            # arXiv preprint arXiv:1502.01852 (2015)}
            c2_msra_fill(m)
        elif isinstance(m, nn.BatchNorm3d):
            # A BN layer is treated as a "final bn" only when it carries a
            # truthy `transform_final_bn` attribute; such layers get a zero
            # scale when `zero_init_final_bn` is enabled, all others get 1.
            zero_scale = zero_init_final_bn and getattr(
                m, "transform_final_bn", False
            )
            batchnorm_weight = 0.0 if zero_scale else 1.0
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, batchnorm_weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        # Kept as a separate `if` (not `elif`), matching the original control
        # flow: the Linear check runs for every module.
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=fc_init_std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)