from collections import OrderedDict
import functools

import torch
import torch.nn.functional as F


def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    """Adapt the temporal positional embeddings in `new_state_dict` to the current
    model's number of frames by truncating, zero-padding, or interpolating them."""
    curr_keys = list(current_model_state_dict.keys())
    temporal_embed = ['visual.temporal_embed', 'visual.prompt_embed']
    for x in temporal_embed:
        if x in new_state_dict and x in curr_keys:
            load_temporal_embed = new_state_dict[x]
            load_num_frames = load_temporal_embed.shape[1]
            curr_num_frames = num_frames
            embed_dim = load_temporal_embed.shape[2]

            if load_num_frames != curr_num_frames:
                if load_num_frames > curr_num_frames:
                    print(f'### loaded SpaceTimeTransformer model has MORE frames than current...\n'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    # keep only the first curr_num_frames temporal positions
                    new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
                else:
                    print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...\n'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    if load_temporal_fix == 'zeros':
                        # copy the loaded frames and zero-fill the remaining positions
                        new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
                        new_temporal_embed[:, :load_num_frames] = load_temporal_embed
                    elif load_temporal_fix in ['interp', 'bilinear']:
                        # interpolate along the temporal axis: (1, T, D) -> (1, 1, T, D) -> (1, T', D)
                        mode = 'nearest'
                        if load_temporal_fix == 'bilinear':
                            mode = 'bilinear'
                        load_temporal_embed = load_temporal_embed.unsqueeze(0)
                        new_temporal_embed = F.interpolate(load_temporal_embed,
                                                           (curr_num_frames, embed_dim), mode=mode).squeeze(0)
                    else:
                        raise NotImplementedError
                new_state_dict[x] = new_temporal_embed

    # loading checkpoints with a different spatial resolution / number of patches is not supported
    if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
        load_pos_embed = new_state_dict['visual.pos_embed']
        load_num_patches = load_pos_embed.shape[1]
        curr_pos_embed = current_model_state_dict['visual.pos_embed']
        if load_num_patches != curr_pos_embed.shape[1]:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')

    return new_state_dict
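

# Illustrative usage sketch (not part of the original module): adapting a
# pretrained checkpoint's temporal embeddings before loading it into a model
# with a different number of frames. The checkpoint layout (a 'state_dict'
# key) and the `model` / `checkpoint_path` arguments are assumptions.
def _example_inflate_usage(model, checkpoint_path, num_frames=16):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # adapt 'visual.temporal_embed' / 'visual.prompt_embed' to num_frames
    state_dict = inflate_positional_embeds(
        model.state_dict(), checkpoint['state_dict'],
        num_frames=num_frames, load_temporal_fix='bilinear',
    )
    # strict=False tolerates keys that exist on only one side
    model.load_state_dict(state_dict, strict=False)
    return model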


def rsetattr(obj, attr, val):
    """setattr that accepts a dotted attribute path, e.g. 'visual.head.weight'."""
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """getattr that accepts a dotted attribute path; an extra arg acts as a default."""
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
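

# Brief sketch (not part of the original module) of how rgetattr / rsetattr can
# be used: reading or replacing a nested attribute of a torch module by a
# dotted path. The toy Sequential below is only an illustration.
def _example_dotted_access():
    import torch.nn as nn
    m = nn.Sequential(OrderedDict(backbone=nn.Linear(4, 4), head=nn.Linear(4, 2)))
    weight = rgetattr(m, 'head.weight')   # the nested Parameter, shape (2, 4)
    rsetattr(m, 'head', nn.Identity())    # equivalent to m.head = nn.Identity()
    return weight, m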


def remap_keys(clip_state_dict, transformer_layers=12):
    """Remap CLIP visual-encoder parameter names to the timm-style names used here."""
    remapped_state_dict = OrderedDict()
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    for layer in range(transformer_layers):
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"

    for key in clip_state_dict:
        if key == 'proj':
            # skip the output projection; it is not part of the remapped backbone
            continue
        if key == "class_embedding":
            # (D,) -> (1, 1, D) to match the timm cls_token layout
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
        if key == "positional_embedding":
            # (N, D) -> (1, N, D) to match the timm pos_embed layout
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = clip_state_dict[key]

    return remapped_state_dict
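

# Illustrative sketch (not from the original codebase): remapping the visual
# tower of an OpenAI CLIP checkpoint so it can initialize the timm-style vision
# transformer used here. The `clip` package (https://github.com/openai/CLIP)
# and the `visual_model` argument are assumptions about the surrounding code.
def _example_remap_clip(visual_model):
    import clip
    clip_model, _ = clip.load('ViT-B/16', device='cpu')
    remapped = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
    # strict=False: temporal / prompt embeddings have no CLIP counterpart
    missing, unexpected = visual_model.load_state_dict(remapped, strict=False)
    return missing, unexpected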