# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import functools

import torch
import torch.nn.functional as F

def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    # allow loading of timesformer with fewer num_frames
    curr_keys = list(current_model_state_dict.keys())
    temporal_embed = ['visual.temporal_embed', 'visual.prompt_embed']
    for x in temporal_embed:
        if x in new_state_dict and x in curr_keys:
            load_temporal_embed = new_state_dict[x]
            load_num_frames = load_temporal_embed.shape[1]
            curr_num_frames = num_frames
            embed_dim = load_temporal_embed.shape[2]

            if load_num_frames != curr_num_frames:
                if load_num_frames > curr_num_frames:
                    print(f'### loaded SpaceTimeTransformer model has MORE frames than current...'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
                else:
                    print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    if load_temporal_fix == 'zeros':
                        new_temporal_embed = torch.zeros([load_temporal_embed.shape[0],
                                                          curr_num_frames, embed_dim])
                        new_temporal_embed[:, :load_num_frames] = load_temporal_embed
                    elif load_temporal_fix in ['interp', 'bilinear']:
                        # interpolate along the frame axis;
                        # unsqueeze so PyTorch treats it as a 1-channel image
                        mode = 'nearest'
                        if load_temporal_fix == 'bilinear':
                            mode = 'bilinear'
                        load_temporal_embed = load_temporal_embed.unsqueeze(0)
                        new_temporal_embed = F.interpolate(
                            load_temporal_embed,
                            (curr_num_frames, embed_dim), mode=mode).squeeze(0)
                    else:
                        raise NotImplementedError
                new_state_dict[x] = new_temporal_embed
    # allow loading with smaller spatial patches. assumes custom border crop, to append the
    # border patches to the input sequence
    if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
        load_pos_embed = new_state_dict['visual.pos_embed']
        load_num_patches = load_pos_embed.shape[1]
        curr_pos_embed = current_model_state_dict['visual.pos_embed']
        if load_num_patches != curr_pos_embed.shape[1]:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')

    return new_state_dict
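
# A minimal usage sketch (illustrative assumptions, not part of this module):
# `build_model` and 'pretrained.pt' are hypothetical; the typical caller is a
# checkpoint loader that wants, e.g., a 16-frame model from a 4-frame
# pretrained state dict.
#
#   model = build_model(num_frames=16)                     # hypothetical builder
#   checkpoint = torch.load('pretrained.pt', map_location='cpu')
#   state_dict = inflate_positional_embeds(
#       model.state_dict(), checkpoint['state_dict'],
#       num_frames=16, load_temporal_fix='bilinear',
#   )
#   model.load_state_dict(state_dict, strict=False)
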
def rsetattr(obj, attr, val):
    # setattr that accepts a dotted attribute path, e.g. 'visual.pos_embed'
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    # getattr that accepts a dotted attribute path; *args is an optional default
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))
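
# Illustrative only: the dotted paths these helpers accept are not supported
# by plain getattr/setattr.
#
#   import torch.nn as nn
#   m = nn.Sequential(OrderedDict(fc=nn.Linear(4, 4)))
#   weight = rgetattr(m, 'fc.weight')                      # same as m.fc.weight
#   rsetattr(m, 'fc.bias', nn.Parameter(torch.zeros(4)))
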
# util functions to convert CLIP-style model keys to TimeSformer-style
def remap_keys(clip_state_dict, transformer_layers=12):
    remapped_state_dict = OrderedDict()
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    for layer in range(transformer_layers):
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"

    for key in clip_state_dict:
        if key == 'proj':
            continue  # due to possible dim mismatch, we load this later
        if key == "class_embedding":
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
        if key == "positional_embedding":
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = clip_state_dict[key]

    return remapped_state_dict
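
# A usage sketch under the assumption that the checkpoint comes from the
# openai/clip package (which is not a dependency of this file):
#
#   import clip
#   clip_model, _ = clip.load('ViT-B/16', device='cpu')
#   visual_sd = clip_model.visual.state_dict()
#   remapped = remap_keys(visual_sd, transformer_layers=12)
#   # `remapped` can then go through inflate_positional_embeds() before being
#   # loaded into a SpaceTimeTransformer-style model with strict=False.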