""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import logging
import os

import torch
import torch.nn.functional as F
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from transformers import BertTokenizer


class AlproBase(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        return BertTokenizer.from_pretrained("bert-base-uncased")

    def load_from_pretrained(self, url_or_filename, num_frames, num_patches):
        """Load a pretrained checkpoint from a URL or local path, resizing the
        visual encoder's position embeddings when they do not match the
        current model configuration."""
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Strip the "bert." prefix so checkpoint keys line up with this
        # model's text encoder parameter names.
        for key in list(state_dict.keys()):
            if "bert" in key:
                new_key = key.replace("bert.", "")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        spatial_embed_key = "visual_encoder.model.pos_embed"
        temporal_embed_key = "visual_encoder.model.time_embed"

        # Resize the spatial position embedding if it does not match;
        # num_patches + 1 accounts for the CLS token.
        if num_patches + 1 != state_dict[spatial_embed_key].size(1):
            state_dict[spatial_embed_key] = resize_spatial_embedding(
                state_dict, spatial_embed_key, num_patches
            )
        else:
            logging.info(
                "The length of spatial position embedding matches. No need to resize."
            )

        # Resize the temporal position embedding if present and mismatched.
        if temporal_embed_key in state_dict and num_frames != state_dict[
            temporal_embed_key
        ].size(1):
            state_dict[temporal_embed_key] = resize_temporal_embedding(
                state_dict, temporal_embed_key, num_frames
            )
        else:
            logging.info(
                "No temporal encoding found, or the length of temporal position embedding matches. No need to resize."
            )

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg
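
# Usage sketch (hedged: "AlproRetrieval" and the argument values below are
# assumptions based on typical LAVIS subclasses, not guaranteed by this file):
#
#     model = AlproRetrieval(...)                 # any AlproBase subclass
#     msg = model.load_from_pretrained(
#         "https://example.com/alpro.pth",        # hypothetical checkpoint URL
#         num_frames=8,
#         num_patches=196,
#     )
#     print(msg.missing_keys)                     # weights left randomly initialized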


def resize_spatial_embedding(state_dict, key, num_patches):
    """Interpolate the patch position embedding to a new number of patches,
    keeping the CLS token embedding unchanged."""
    logging.info(
        f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}"
    )

    pos_embed = state_dict[key]

    # Split off the CLS embedding; only the patch embeddings are interpolated.
    cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
    other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)

    # (1, dim, old_num_patches) -> (1, dim, num_patches) -> (1, num_patches, dim)
    new_pos_embed = F.interpolate(other_pos_embed, size=num_patches, mode="nearest")
    new_pos_embed = new_pos_embed.transpose(1, 2)
    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)

    return new_pos_embed
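
# Example with hypothetical shapes: a 224x224 ViT-B/16 checkpoint stores
# pos_embed of shape (1, 197, 768) -- one CLS token plus 14*14 = 196 patch
# embeddings. Loading it at 336x336 (21*21 = 441 patches) interpolates the
# 196 patch rows to 441, so the resized pos_embed has shape (1, 442, 768).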


def resize_temporal_embedding(state_dict, key, num_frames):
    """Interpolate the temporal position embedding to a new number of frames."""
    logging.info(
        f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}"
    )

    # (1, old_num_frames, dim) -> (1, dim, old_num_frames) -> interpolate -> back
    time_embed = state_dict[key].transpose(1, 2)
    new_time_embed = F.interpolate(time_embed, size=num_frames, mode="nearest")

    return new_time_embed.transpose(1, 2)
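

if __name__ == "__main__":
    # Minimal smoke test for the two resize helpers. The shapes below are
    # illustrative assumptions (ViT-B/16-like sizes), not taken from any
    # particular ALPRO checkpoint.
    dummy_state = {
        "pos_embed": torch.randn(1, 197, 768),  # CLS + 14*14 patches
        "time_embed": torch.randn(1, 8, 768),   # 8 frames
    }

    new_spatial = resize_spatial_embedding(dummy_state, "pos_embed", num_patches=441)
    assert new_spatial.shape == (1, 442, 768)

    new_temporal = resize_temporal_embedding(dummy_state, "time_embed", num_frames=16)
    assert new_temporal.shape == (1, 16, 768)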