"""Script to download the pre-trained tensorflow weights and convert them to pytorch weights."""
import os
import argparse
import torch
import numpy as np
from tensorflow.python.training import py_checkpoint_reader
from repnet import utils
from repnet.model import RepNet
# Relevant paths
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
# Mapping of ndim -> permutation to go from tf to pytorch
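# tf stores dense kernels as (in, out) and conv kernels as (kH, kW, in, out) / (kD, kH, kW, in, out),
# while pytorch expects (out, in), (out, in, kH, kW) and (out, in, kD, kH, kW) respectively.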
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2)
}
# Mapping of tf attributes -> pytorch attributes
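# `kernel` and `bias` come from dense/conv layers, while `beta`, `gamma`, `moving_mean` and
# `moving_variance` are batch-norm parameters mapped to the corresponding pytorch BatchNorm attributes.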
ATTR_MAPPING = {
    'kernel': 'weight',
    'bias': 'bias',
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var'
}
# Mapping of tf checkpoint -> tf model -> pytorch model
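# Each entry is (tf checkpoint key, original tf layer name, pytorch module path); the tf layer name is
# kept for reference only and is not used by the conversion loop below.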
WEIGHTS_MAPPING = [
    # Base frame encoder
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),
    # Temporal convolution
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),
    # Period length head
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),
    # Periodicity head
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]
# Script arguments
parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.')
if __name__ == '__main__':
    args = parser.parse_args()
    # Download tensorflow checkpoints
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for file in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, file)
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file}', dst)
    # Load tensorflow weights into a dictionary
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')  # Checkpoint prefix shared by the .index and .data files
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map.keys()):
        var_tensor = checkpoint_reader.get_tensor(var_name)
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue  # Skip variables that are not part of the model, e.g. from the optimizer
        # Split var_name into path
        var_path = var_name.split('/')[1:]  # Remove the leading `model` key from the path
        var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
        # Map weights into a nested dictionary
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor
    # Merge transformer self-attention weights into a single tensor
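    # (tf uses separate wq/wk/wv dense layers, while pytorch's nn.MultiheadAttention expects a single
    # in_proj_weight of shape (3 * embed_dim, embed_dim) and a single in_proj_bias, so the three
    # projections are transposed and concatenated along the first axis)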
    for k in ['transformer_layers', 'transformer_layers2']:
        v = tf_state_dict[k]['0']['mha']
        v['w_weight'] = np.concatenate([v['wq']['kernel'].T, v['wk']['kernel'].T, v['wv']['kernel'].T], axis=0)
        v['w_bias'] = np.concatenate([v['wq']['bias'].T, v['wk']['bias'].T, v['wv']['bias'].T], axis=0)
        del v['wk'], v['wq'], v['wv']
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Add missing final level for some weights
    for k, v in tf_state_dict.items():
        if not isinstance(v, dict):
            tf_state_dict[k] = {None: v}
    # Convert to a format compatible with PyTorch and save
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        assert k_pt not in pt_state_dict
        pt_state_dict[k_pt] = {}
        for attr in tf_state_dict[k_tf]:
            new_attr = ATTR_MAPPING.get(attr, attr)
            pt_state_dict[k_pt][new_attr] = torch.from_numpy(tf_state_dict[k_tf][attr])
            if attr == 'kernel':
                weights_permutation = WEIGHTS_PERMUTATION[pt_state_dict[k_pt][new_attr].ndim]  # Permute weights if needed
                pt_state_dict[k_pt][new_attr] = pt_state_dict[k_pt][new_attr].permute(weights_permutation)
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)
    # Initialize the model and try to load the weights
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path)
    model.load_state_dict(pt_state_dict)
    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')