# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import functools
from collections import OrderedDict

import torch
import torch.nn.functional as F


def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    """Adapt temporal positional embeddings in a loaded state dict to the
    current model's number of frames.

    Allows loading a TimeSformer-style checkpoint trained with a different
    number of frames: extra frames are truncated, and missing frames are
    filled with zeros or by interpolation, depending on ``load_temporal_fix``.
    """
    curr_keys = list(current_model_state_dict.keys())
    temporal_embed_keys = ['visual.temporal_embed', 'visual.prompt_embed']
    for x in temporal_embed_keys:
        if x in new_state_dict and x in curr_keys:
            load_temporal_embed = new_state_dict[x]  # shape: [1, T_load, D]
            load_num_frames = load_temporal_embed.shape[1]
            curr_num_frames = num_frames
            embed_dim = load_temporal_embed.shape[2]
            if load_num_frames != curr_num_frames:
                if load_num_frames > curr_num_frames:
                    print(f'### loaded SpaceTimeTransformer model has MORE frames than current...'
                          f'### loading {x} weights, truncating to the first {curr_num_frames} frames')
                    new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
                else:
                    print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    if load_temporal_fix == 'zeros':
                        # pad the missing frames with zeros
                        new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
                        new_temporal_embed[:, :load_num_frames] = load_temporal_embed
                    elif load_temporal_fix in ['interp', 'bilinear']:
                        # interpolate along the temporal axis;
                        # unsqueeze so PyTorch treats the embedding as an image
                        mode = 'nearest'
                        if load_temporal_fix == 'bilinear':
                            mode = 'bilinear'
                        load_temporal_embed = load_temporal_embed.unsqueeze(0)
                        new_temporal_embed = F.interpolate(load_temporal_embed,
                                                           (curr_num_frames, embed_dim), mode=mode).squeeze(0)
                    else:
                        raise NotImplementedError
                new_state_dict[x] = new_temporal_embed
    # Loading with a different number of spatial patches would assume a custom border crop
    # that appends the border patches to the input sequence; this is not implemented yet.
    if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
        load_pos_embed = new_state_dict['visual.pos_embed']
        load_num_patches = load_pos_embed.shape[1]
        curr_pos_embed = current_model_state_dict['visual.pos_embed']
        if load_num_patches != curr_pos_embed.shape[1]:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')
    return new_state_dict
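

# Usage sketch (an assumption, not part of the original API): how this helper would
# typically be applied when loading a checkpoint into a model built with a different
# frame count. `build_model` and 'checkpoint.pth' are hypothetical placeholders.
#
#   model = build_model(num_frames=16)
#   ckpt = torch.load('checkpoint.pth', map_location='cpu')
#   state_dict = inflate_positional_embeds(
#       model.state_dict(), ckpt['state_dict'],
#       num_frames=16, load_temporal_fix='bilinear',
#   )
#   model.load_state_dict(state_dict, strict=False)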


def rsetattr(obj, attr, val):
    """setattr that accepts dotted attribute paths, e.g. 'visual.head.weight'."""
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """getattr that accepts dotted attribute paths; *args may hold a default."""
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
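

# Illustrative example (not from the original source): rgetattr / rsetattr walk dotted
# attribute paths, which is handy for patching nested submodules by the names returned
# by model.named_modules(). `model` and 'visual.head.fc' are hypothetical here.
#
#   old_fc = rgetattr(model, 'visual.head.fc')
#   rsetattr(model, 'visual.head.fc', torch.nn.Identity())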


# util function to convert CLIP-style (OpenAI ViT) state-dict keys to TimeSformer-style keys
def remap_keys(clip_state_dict, transformer_layers=12):
    remapped_state_dict = OrderedDict()
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    for layer in range(transformer_layers):
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"

    for key in clip_state_dict:
        if key == 'proj':
            continue  # skipped here due to a possible dimension mismatch; loaded separately later
        if key == "class_embedding":
            # [D] -> [1, 1, D] to match the TimeSformer cls_token layout
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
        if key == "positional_embedding":
            # [N, D] -> [1, N, D] to match the TimeSformer pos_embed layout
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = clip_state_dict[key]
    return remapped_state_dict
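

if __name__ == '__main__':
    # Minimal self-contained smoke test with synthetic tensors. Shapes and key names are
    # illustrative assumptions, not values taken from a real checkpoint.
    current_sd = {
        'visual.temporal_embed': torch.zeros(1, 16, 8),
        'visual.pos_embed': torch.zeros(1, 197, 8),
    }
    loaded_sd = {
        'visual.temporal_embed': torch.randn(1, 4, 8),
        'visual.pos_embed': torch.zeros(1, 197, 8),
    }
    inflated = inflate_positional_embeds(current_sd, loaded_sd, num_frames=16,
                                         load_temporal_fix='bilinear')
    print(inflated['visual.temporal_embed'].shape)  # torch.Size([1, 16, 8])

    # Remap a tiny fake CLIP visual state dict (no transformer block keys, for brevity).
    clip_sd = OrderedDict([
        ('class_embedding', torch.zeros(8)),
        ('positional_embedding', torch.zeros(197, 8)),
        ('conv1.weight', torch.zeros(8, 3, 16, 16)),
        ('ln_pre.weight', torch.ones(8)),
        ('ln_pre.bias', torch.zeros(8)),
        ('ln_post.weight', torch.ones(8)),
        ('ln_post.bias', torch.zeros(8)),
    ])
    remapped = remap_keys(clip_sd)
    print(list(remapped.keys()))
    # ['cls_token', 'pos_embed', 'patch_embed.proj.weight', 'ln_pre.weight',
    #  'ln_pre.bias', 'norm.weight', 'norm.bias']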