Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Caffe2 to PyTorch checkpoint name converting utility."""
import re | |
def get_name_convert_func():
    """
    Get the function to convert Caffe2 layer names to PyTorch layer names.

    The returned function rewrites a name by applying an ordered list of
    regular-expression substitutions, each one operating on the output of the
    previous one. Order matters: more specific rules (e.g. the `_bn_` variants)
    are listed before the generic rules that would otherwise match the same
    name. End-to-end examples:

        'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal1.conv_theta.weight'
        't_pool1_subsample_bn_rm'  -> 's1_fuse.bn.running_mean'
        'res4_4_branch2c_bn_b'     -> 's4.pathway0_res4.branch2.c_bn.bias'
        'pred_b'                   -> 'head.projection.bias'

    Returns:
        (func): function to convert parameter name from Caffe2 format to PyTorch
            format.
    """
    pairs = [
        # ------------------------------------------------------------
        # Non-local block rules.
        # 'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal1_theta_w'
        [
            r"^nonlocal_conv([0-9]+)_([0-9]+)_(.*)",
            r"s\1.pathway0_nonlocal\2_\3",
        ],
        # '_nonlocal1_theta...' -> '_nonlocal1.conv_theta...'
        [r"^(.*)_nonlocal([0-9]+)_(theta)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '_nonlocal1_g...' -> '_nonlocal1.conv_g...'
        [r"^(.*)_nonlocal([0-9]+)_(g)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '_nonlocal1_phi...' -> '_nonlocal1.conv_phi...'
        [r"^(.*)_nonlocal([0-9]+)_(phi)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # '_nonlocal1_out...' -> '_nonlocal1.conv_out...'
        [r"^(.*)_nonlocal([0-9]+)_(out)(.*)", r"\1_nonlocal\2.conv_\3\4"],
        # 's4.pathway0_nonlocal5_bn_s' -> 's4.pathway0_nonlocal5.bn.s'
        [r"^(.*)_nonlocal([0-9]+)_(bn)_(.*)", r"\1_nonlocal\2.\3.\4"],
        # ------------------------------------------------------------
        # Fusion ("subsample") rules. The `_bn_` variants must come first,
        # otherwise the generic subsample rules below would swallow them.
        # 't_pool1_subsample_bn_rm' -> 's1_fuse.bn.rm'
        [r"^t_pool1_subsample_bn_(.*)", r"s1_fuse.bn.\1"],
        # 't_pool1_subsample_w' -> 's1_fuse.conv_f2s.w'
        [r"^t_pool1_subsample_(.*)", r"s1_fuse.conv_f2s.\1"],
        # 't_res4_5_branch2c_bn_subsample_bn_rm' -> 's4_fuse.bn.rm'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_bn_(.*)",
            r"s\1_fuse.bn.\3",
        ],
        # 't_res4_5_branch2c_bn_subsample_w' -> 's4_fuse.conv_f2s.w'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_(.*)",
            r"s\1_fuse.conv_f2s.\3",
        ],
        # ------------------------------------------------------------
        # Names without the 't_' prefix map to pathway0.
        # 'res4_4_branch2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b'
        [
            r"^res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)",
            r"s\1.pathway0_res\2.branch\3.\4_\5",
        ],
        # 'res_conv1_bn_s' -> 's1.pathway0_stem.bn.s'
        [r"^res_conv1_bn_(.*)", r"s1.pathway0_stem.bn.\1"],
        # 'conv1_xy_w' -> 's1.pathway0_stem.conv_xy_w'
        [r"^conv1_xy(.*)", r"s1.pathway0_stem.conv_xy\1"],
        # 'conv1_w' -> 's1.pathway0_stem.conv.w'
        [r"^conv1_(.*)", r"s1.pathway0_stem.conv.\1"],
        # 'res4_0_branch1_w' -> 's4.pathway0_res0.branch1_w'
        [
            r"^res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)",
            r"s\1.pathway0_res\2.branch\3_\4",
        ],
        # 'res_conv1_w' -> 's1.pathway0_stem.conv.w'
        [r"^res_conv1_(.*)", r"s1.pathway0_stem.conv.\1"],
        # ------------------------------------------------------------
        # Names with the 't_' prefix map to pathway1.
        # 't_res4_4_branch2c_bn_b' -> 's4.pathway1_res4.branch2.c_bn_b'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)",
            r"s\1.pathway1_res\2.branch\3.\4_\5",
        ],
        # 't_res_conv1_bn_s' -> 's1.pathway1_stem.bn.s'
        [r"^t_res_conv1_bn_(.*)", r"s1.pathway1_stem.bn.\1"],
        # 't_conv1_w' -> 's1.pathway1_stem.conv.w'
        [r"^t_conv1_(.*)", r"s1.pathway1_stem.conv.\1"],
        # 't_res4_0_branch1_w' -> 's4.pathway1_res0.branch1_w'
        [
            r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)",
            r"s\1.pathway1_res\2.branch\3_\4",
        ],
        # 't_res_conv1_w' -> 's1.pathway1_stem.conv.w'
        [r"^t_res_conv1_(.*)", r"s1.pathway1_stem.conv.\1"],
        # ------------------------------------------------------------
        # Head and parameter-suffix rules (not anchored at the start).
        # 'pred_' -> 'head.projection.'
        [r"pred_(.*)", r"head.projection.\1"],
        # 'b_bn_fc' -> 'se.fc'
        [r"(.*)b_bn_fc(.*)", r"\1se.fc\2"],
        # 'conv_5' -> 'head.conv_5'
        [r"conv_5(.*)", r"head.conv_5\1"],
        # 'lin_5' -> 'head.lin_5'
        [r"lin_5(.*)", r"head.lin_5\1"],
        # In the four 'bn' rules below the '.' is an unescaped regex wildcard,
        # so both the 'bn_x' and 'bn.x' spellings are matched.
        # 'bn_b' / 'bn.b' -> 'bn.bias'
        [r"(.*)bn.b\Z", r"\1bn.bias"],
        # 'bn_s' / 'bn.s' -> 'bn.weight'
        [r"(.*)bn.s\Z", r"\1bn.weight"],
        # 'bn_rm' / 'bn.rm' -> 'bn.running_mean'
        [r"(.*)bn.rm\Z", r"\1bn.running_mean"],
        # 'bn_riv' / 'bn.riv' -> 'bn.running_var'
        [r"(.*)bn.riv\Z", r"\1bn.running_var"],
        # trailing '_b' / '.b' -> '.bias'
        [r"(.*)[\._]b\Z", r"\1.bias"],
        # trailing '_w' / '.w' -> '.weight'
        [r"(.*)[\._]w\Z", r"\1.weight"],
    ]

    # Compile every pattern once, when the converter is built, so the returned
    # closure does not have to resolve each pattern string again on every call.
    compiled_pairs = [(re.compile(source), dest) for source, dest in pairs]

    def convert_caffe2_name_to_pytorch(caffe2_layer_name):
        """
        Convert the caffe2_layer_name to pytorch format by applying the list of
        regular expressions in order.

        Args:
            caffe2_layer_name (str): caffe2 layer name.
        Returns:
            (str): pytorch layer name. A name that matches none of the rules is
                returned unchanged.
        """
        for pattern, dest in compiled_pairs:
            caffe2_layer_name = pattern.sub(dest, caffe2_layer_name)
        return caffe2_layer_name

    return convert_caffe2_name_to_pytorch