"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
| import logging | |
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from lavis.common.dist_utils import download_cached_file | |
| from lavis.common.utils import is_url | |
| from lavis.models.base_model import BaseModel | |
| from transformers import BertTokenizer | |
class AlproBase(BaseModel):
    """Shared base for Alpro video-text models.

    Provides the common BERT tokenizer and checkpoint loading with
    automatic resizing of spatial/temporal position embeddings.
    """

    @classmethod
    def init_tokenizer(cls):
        """Return the pretrained BERT tokenizer used by Alpro text encoders.

        Fixed: the method takes ``cls`` but was missing the ``@classmethod``
        decorator, so calling it on the class itself raised a TypeError.
        """
        return BertTokenizer.from_pretrained("bert-base-uncased")

    def load_from_pretrained(self, url_or_filename, num_frames, num_patches):
        """Load pretrained weights into this model from a URL or local file.

        Position embeddings in the checkpoint are interpolated to match this
        model's spatial/temporal resolution before loading.

        Args:
            url_or_filename (str): HTTP(S) URL or local filesystem path of
                the checkpoint.
            num_frames (int): number of video frames this model expects; the
                checkpoint's time embedding is resized to this length.
            num_patches (int): number of spatial patches this model expects;
                the checkpoint's pos embedding is resized to this length + 1
                ([CLS] slot).

        Returns:
            The ``NamedTuple`` returned by ``load_state_dict`` (carries
            ``missing_keys`` / ``unexpected_keys``).

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an
                existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # Some checkpoints nest the weights under a "model" key.
        if "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        # Strip the "bert." prefix so keys line up with this model's text
        # encoder module names. NOTE(review): the guard matches any key
        # containing "bert", not just keys starting with "bert." — assumes
        # checkpoint keys only ever use the "bert." prefix form.
        for key in list(state_dict.keys()):
            if "bert" in key:
                new_key = key.replace("bert.", "")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        spatial_embed_key = "visual_encoder.model.pos_embed"
        temporal_embed_key = "visual_encoder.model.time_embed"

        # Resize spatial embeddings in case they don't match
        # (num_patches + 1 accounts for the [CLS] token position).
        if num_patches + 1 != state_dict[spatial_embed_key].size(1):
            state_dict[spatial_embed_key] = resize_spatial_embedding(
                state_dict, spatial_embed_key, num_patches
            )
        else:
            logging.info(
                "The length of spatial position embedding matches. No need to resize."
            )

        # Resize time embeddings in case they exist and don't match.
        if temporal_embed_key in state_dict and num_frames != state_dict[
            temporal_embed_key
        ].size(1):
            state_dict[temporal_embed_key] = resize_temporal_embedding(
                state_dict, temporal_embed_key, num_frames
            )
        else:
            logging.info(
                "No temporal encoding found. Or the length of temporal position embedding matches. No need to resize."
            )

        # strict=False: this model may legitimately lack some checkpoint
        # heads (and vice versa); missing keys are logged for inspection.
        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg
def resize_spatial_embedding(state_dict, key, num_patches):
    """Nearest-neighbour resize of a ViT spatial position embedding.

    The [CLS] slot (index 0) is carried over untouched; only the patch
    positions are interpolated to ``num_patches`` entries, so the result
    has ``num_patches + 1`` positions. ``state_dict`` is not modified;
    the resized tensor is returned.
    """
    embed = state_dict[key]
    logging.info(
        f"Resizing spatial position embedding from {embed.size(1)} to {num_patches + 1}"
    )
    # Separate the [CLS] embedding so interpolation only touches patches.
    cls_token = embed[0, 0, :].unsqueeze(0).unsqueeze(1)
    # (1, P, D) -> (1, D, P): interpolate treats dim 1 as channels.
    patch_tokens = embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
    resized = F.interpolate(patch_tokens, size=num_patches, mode="nearest").transpose(1, 2)
    return torch.cat((cls_token, resized), 1)
def resize_temporal_embedding(state_dict, key, num_frames):
    """Nearest-neighbour resize of a temporal position embedding to ``num_frames``.

    ``state_dict`` is not modified; the resized tensor is returned.
    """
    logging.info(
        f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}"
    )
    # (1, T, D) -> (1, D, T) so interpolation runs along the frame axis.
    as_channels = state_dict[key].transpose(1, 2)
    resized = F.interpolate(as_channels, size=num_frames, mode="nearest")
    return resized.transpose(1, 2)