"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import torch
import torch.nn.functional as F
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from transformers import BertTokenizer
class AlproBase(BaseModel):
    """Shared base class for Alpro video-language models.

    Provides the common BERT tokenizer and the checkpoint-loading logic,
    including resizing of spatial/temporal position embeddings when the
    target model uses a different number of patches or frames than the
    pretrained checkpoint.
    """

    @classmethod
    def init_tokenizer(cls):
        """Return the lowercased BERT tokenizer used by all Alpro models."""
        return BertTokenizer.from_pretrained("bert-base-uncased")

    def load_from_pretrained(self, url_or_filename, num_frames, num_patches):
        """Load pretrained weights from a URL or a local checkpoint file.

        Args:
            url_or_filename: HTTP(S) URL or local path of the checkpoint.
            num_frames: number of frames the target model expects; the
                temporal position embedding is resized to match.
            num_patches: number of spatial patches per frame; the spatial
                position embedding is resized to ``num_patches + 1``
                (patches plus the [CLS] token).

        Returns:
            The result of ``load_state_dict`` (missing/unexpected keys).

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an
                existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            # NOTE(review): torch.load unpickles arbitrary objects —
            # acceptable only because checkpoints come from trusted URLs.
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

        # Strip the "bert." prefix from text-encoder keys. The previous
        # check (`"bert" in key`) mishandled a key that contains "bert"
        # without the dot: new_key == key, so the freshly reassigned entry
        # was deleted outright. Requiring "bert." fixes that.
        for key in list(state_dict.keys()):
            if "bert." in key:
                state_dict[key.replace("bert.", "")] = state_dict.pop(key)

        spatial_embed_key = "visual_encoder.model.pos_embed"
        temporal_embed_key = "visual_encoder.model.time_embed"

        ## Resizing spatial embeddings in case they don't match.
        # Guarded on key presence for symmetry with the temporal branch;
        # previously a checkpoint without this key raised KeyError here.
        if (
            spatial_embed_key in state_dict
            and num_patches + 1 != state_dict[spatial_embed_key].size(1)
        ):
            state_dict[spatial_embed_key] = resize_spatial_embedding(
                state_dict, spatial_embed_key, num_patches
            )
        else:
            logging.info(
                "The length of spatial position embedding matches. No need to resize."
            )

        ## Resizing time embeddings in case they don't match.
        if temporal_embed_key in state_dict and num_frames != state_dict[
            temporal_embed_key
        ].size(1):
            state_dict[temporal_embed_key] = resize_temporal_embedding(
                state_dict, temporal_embed_key, num_frames
            )
        else:
            logging.info(
                "No temporal encoding found. Or the length of temporal position embedding matches. No need to resize."
            )

        # strict=False: Alpro heads not present in the checkpoint stay
        # randomly initialized; missing keys are logged for inspection.
        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg
def resize_spatial_embedding(state_dict, key, num_patches):
    """Resample a spatial position embedding to cover ``num_patches`` patches.

    The class-token embedding (index 0) is kept verbatim; the remaining
    patch embeddings are nearest-neighbour interpolated along the token
    axis. Returns a tensor of shape (1, num_patches + 1, dim).
    """
    embed = state_dict[key]
    logging.info(
        f"Resizing spatial position embedding from {embed.size(1)} to {num_patches + 1}"
    )
    # Keep the [CLS] slot untouched as a (1, 1, dim) tensor.
    cls_token = embed[:1, :1, :]
    # Move the embedding dimension to the channel axis for interpolation:
    # (1, old_patches, dim) -> (1, dim, old_patches).
    patch_tokens = embed[:1, 1:, :].permute(0, 2, 1)
    patch_tokens = F.interpolate(patch_tokens, size=num_patches, mode="nearest")
    patch_tokens = patch_tokens.permute(0, 2, 1)
    return torch.cat((cls_token, patch_tokens), 1)
def resize_temporal_embedding(state_dict, key, num_frames):
    """Resample a temporal position embedding to ``num_frames`` frames.

    All entries are nearest-neighbour interpolated along the frame axis
    (no special class token). Returns a tensor of shape
    (1, num_frames, dim).
    """
    embed = state_dict[key]
    logging.info(
        f"Resizing temporal position embedding from {embed.size(1)} to {num_frames}"
    )
    # (1, old_frames, dim) -> (1, dim, old_frames) so interpolation runs
    # over the frame axis, then back.
    channels_first = embed.permute(0, 2, 1)
    resized = F.interpolate(channels_first, size=num_frames, mode="nearest")
    return resized.permute(0, 2, 1)