John6666's picture
Upload 351 files
e84842d verified
raw
history blame
2.22 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import torch
from lavis.common.dist_utils import download_cached_file
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from lavis.models.vit import interpolate_pos_embed
from transformers import BertTokenizer
class BlipBase(BaseModel):
    """Shared tokenizer setup and checkpoint loading for BLIP models."""

    @classmethod
    def init_tokenizer(cls):
        """Build the BERT tokenizer used by BLIP models.

        Loads the ``bert-base-uncased`` tokenizer, registers ``[DEC]`` as the
        BOS token and ``[ENC]`` as an additional special token, and stores the
        id of ``[ENC]`` on the tokenizer as ``enc_token_id``.

        Returns:
            The configured ``BertTokenizer`` instance.
        """
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]})
        # "[ENC]" is the only additional special token registered above, so it
        # is the first entry of additional_special_tokens_ids.
        tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
        return tokenizer

    def load_from_pretrained(self, url_or_filename):
        """Load pretrained weights into this model from a URL or a local file.

        The checkpoint is read onto the CPU and its ``"model"`` entry is used
        as the state dict. Position embeddings of the visual encoder (and of
        the momentum encoder ``visual_encoder_m`` when both the model and the
        checkpoint have one) are passed through ``interpolate_pos_embed``
        before loading. Checkpoint tensors whose shape differs from the
        model's parameter of the same name are dropped, and the rest is
        loaded with ``strict=False``.

        Args:
            url_or_filename: URL or filesystem path of the checkpoint.

        Returns:
            The result of ``load_state_dict`` (exposes ``missing_keys``).

        Raises:
            RuntimeError: If ``url_or_filename`` is neither a URL nor an
                existing file.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources; consider
        # passing weights_only=True where the installed torch supports it.
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        # Snapshot the model's own state dict once: every call to
        # self.state_dict() rebuilds the whole mapping, so calling it inside
        # the loop below would be quadratic in the number of parameters.
        own_state = self.state_dict()

        # presumably adapts the checkpoint's position embeddings to this
        # encoder's size -- see lavis.models.vit.interpolate_pos_embed
        state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
            state_dict["visual_encoder.pos_embed"], self.visual_encoder
        )

        # The momentum encoder is optional on both sides. If the checkpoint
        # lacks it, skip the interpolation instead of raising KeyError; the
        # key is then reported through missing_keys by the non-strict load.
        momentum_key = "visual_encoder_m.pos_embed"
        if momentum_key in own_state and momentum_key in state_dict:
            state_dict[momentum_key] = interpolate_pos_embed(
                state_dict[momentum_key], self.visual_encoder_m
            )

        # Drop checkpoint tensors whose shape does not match the model, since
        # load_state_dict would otherwise fail on them even with strict=False.
        for key, own_tensor in own_state.items():
            if key in state_dict and state_dict[key].shape != own_tensor.shape:
                del state_dict[key]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys %s", msg.missing_keys)
        logging.info("load checkpoint from %s", url_or_filename)

        return msg