""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import torch from lavis.common.dist_utils import download_cached_file from lavis.common.utils import is_url from lavis.models.base_model import BaseModel from lavis.models.vit import interpolate_pos_embed from transformers import BertTokenizer class BlipBase(BaseModel): @classmethod def init_tokenizer(cls): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer.add_special_tokens({"bos_token": "[DEC]"}) tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] return tokenizer def load_from_pretrained(self, url_or_filename): if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") state_dict = checkpoint["model"] state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( state_dict["visual_encoder.pos_embed"], self.visual_encoder ) if "visual_encoder_m.pos_embed" in self.state_dict().keys(): state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m ) for key in self.state_dict().keys(): if key in state_dict.keys(): if state_dict[key].shape != self.state_dict()[key].shape: del state_dict[key] msg = self.load_state_dict(state_dict, strict=False) logging.info("Missing keys {}".format(msg.missing_keys)) logging.info("load checkpoint from %s" % url_or_filename) return msg