# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import functools
import torch
import torch.nn.functional as F


def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    """Resize temporal embeddings in a loaded state dict so they match the current
    model's number of frames (truncation, zero-padding, or interpolation)."""
    # allow loading of timesformer with fewer num_frames
    curr_keys = list(current_model_state_dict.keys())
    temporal_embed = ['visual.temporal_embed', 'visual.prompt_embed']
    for x in temporal_embed:
        if x in new_state_dict and x in curr_keys:
            load_temporal_embed = new_state_dict[x]
            load_num_frames = load_temporal_embed.shape[1]
            curr_num_frames = num_frames
            embed_dim = load_temporal_embed.shape[2]
            if load_num_frames != curr_num_frames:
                if load_num_frames > curr_num_frames:
                    print(f'### loaded SpaceTimeTransformer model has MORE frames than current...\n'
                          f'### loading {x} weights, truncating to the first {curr_num_frames} frames')
                    new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
                else:
                    print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...\n'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    if load_temporal_fix == 'zeros':
                        # keep the loaded frame embeddings and zero-pad the remaining frames
                        new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
                        new_temporal_embed[:, :load_num_frames] = load_temporal_embed
                    elif load_temporal_fix in ['interp', 'bilinear']:
                        # interpolate along the frame dimension;
                        # unsqueeze so pytorch treats the embedding as an image
                        mode = 'nearest'
                        if load_temporal_fix == 'bilinear':
                            mode = 'bilinear'
                        load_temporal_embed = load_temporal_embed.unsqueeze(0)
                        new_temporal_embed = F.interpolate(load_temporal_embed,
                            (curr_num_frames, embed_dim), mode=mode).squeeze(0)
                    else:
                        raise NotImplementedError
                new_state_dict[x] = new_temporal_embed
    # allow loading with smaller spatial patches. assumes custom border crop, to append the
    # border patches to the input sequence
    if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
        load_pos_embed = new_state_dict['visual.pos_embed']
        load_num_patches = load_pos_embed.shape[1]
        curr_pos_embed = current_model_state_dict['visual.pos_embed']
        if load_num_patches != curr_pos_embed.shape[1]:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')
    return new_state_dict
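
# Example usage (a sketch, not part of the original file; `VideoModel` and the
# checkpoint path are hypothetical placeholders):
#
#     model = VideoModel(num_frames=16)
#     ckpt = torch.load('clip_timesformer.pth', map_location='cpu')
#     ckpt = inflate_positional_embeds(
#         model.state_dict(), ckpt,
#         num_frames=16, load_temporal_fix='bilinear')
#     model.load_state_dict(ckpt, strict=False)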


def rsetattr(obj, attr, val):
    """setattr that accepts a dotted attribute path, e.g. 'visual.head.weight'."""
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """getattr that accepts a dotted attribute path; *args may carry a default."""
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
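
# Example (a sketch; `model` is a hypothetical nn.Module with a nested `visual.head`):
#
#     head = rgetattr(model, 'visual.head')                 # nested lookup via a dotted path
#     rsetattr(model, 'visual.head', torch.nn.Identity())   # nested assignment via a dotted path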


# util functions to convert CLIP-style model keys to TimeSformer-style
def remap_keys(clip_state_dict, transformer_layers=12):
    remapped_state_dict = OrderedDict()
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    for layer in range(transformer_layers):
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"
    for key in clip_state_dict:
        if key == 'proj':
            continue  # due to possible dim mismatch, we load this later
        if key == "class_embedding":
            # CLIP stores the class token as a 1-D vector; reshape to [1, 1, dim]
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
        if key == "positional_embedding":
            # add a leading batch dimension: [num_patches + 1, dim] -> [1, num_patches + 1, dim]
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = clip_state_dict[key]
    return remapped_state_dict
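

if __name__ == '__main__':
    # Minimal sanity check (a sketch, not part of the original training code): build a
    # tiny synthetic CLIP-style visual state dict with a single transformer block and
    # confirm that remap_keys renames it into the TimeSformer-style layout used above.
    dim, patches = 8, 4
    fake_clip = OrderedDict({
        'class_embedding': torch.zeros(dim),
        'positional_embedding': torch.zeros(patches + 1, dim),
        'conv1.weight': torch.zeros(dim, 3, 2, 2),
        'ln_pre.weight': torch.ones(dim), 'ln_pre.bias': torch.zeros(dim),
        'ln_post.weight': torch.ones(dim), 'ln_post.bias': torch.zeros(dim),
        'transformer.resblocks.0.attn.in_proj_weight': torch.zeros(3 * dim, dim),
        'transformer.resblocks.0.attn.in_proj_bias': torch.zeros(3 * dim),
        'transformer.resblocks.0.attn.out_proj.weight': torch.zeros(dim, dim),
        'transformer.resblocks.0.attn.out_proj.bias': torch.zeros(dim),
        'transformer.resblocks.0.ln_1.weight': torch.ones(dim),
        'transformer.resblocks.0.ln_1.bias': torch.zeros(dim),
        'transformer.resblocks.0.mlp.c_fc.weight': torch.zeros(4 * dim, dim),
        'transformer.resblocks.0.mlp.c_fc.bias': torch.zeros(4 * dim),
        'transformer.resblocks.0.mlp.c_proj.weight': torch.zeros(dim, 4 * dim),
        'transformer.resblocks.0.mlp.c_proj.bias': torch.zeros(dim),
        'transformer.resblocks.0.ln_2.weight': torch.ones(dim),
        'transformer.resblocks.0.ln_2.bias': torch.zeros(dim),
    })
    remapped = remap_keys(fake_clip, transformer_layers=1)
    assert 'blocks.0.attn.qkv.weight' in remapped
    assert remapped['cls_token'].shape == (1, 1, dim)
    assert remapped['pos_embed'].shape == (1, patches + 1, dim)
    print('remap_keys sanity check passed')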