Spaces:
Sleeping
Sleeping
File size: 7,181 Bytes
d9272c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import os
import time
import lavis.common.dist_utils as dist_utils
import torch
import torch.distributed as dist
import torch.nn.functional as F
from lavis.common.dist_utils import download_cached_file
from lavis.common.logger import MetricLogger
from lavis.common.utils import is_url
from lavis.models.base_model import BaseModel
from lavis.models.vit import interpolate_pos_embed
from transformers import BertTokenizer
class AlbefBase(BaseModel):
    """Base class providing the tokenizer factory and checkpoint loading."""

    @classmethod
    def init_tokenizer(cls):
        """Return the pretrained ``bert-base-uncased`` ``BertTokenizer``."""
        return BertTokenizer.from_pretrained("bert-base-uncased")

    def load_from_pretrained(self, url_or_filename, rename_text_keys=True):
        """Load checkpoint weights into this model (non-strict).

        Args:
            url_or_filename: URL (downloaded through ``download_cached_file``)
                or path of an existing local checkpoint file.
            rename_text_keys: when true, the substring ``"bert."`` is removed
                from every checkpoint key that contains it.

        Returns:
            The result of ``self.load_state_dict(state_dict, strict=False)``;
            its ``missing_keys`` are logged.

        Raises:
            RuntimeError: if ``url_or_filename`` is neither a URL nor an
                existing file.
            KeyError: if the checkpoint has no ``"visual_encoder.pos_embed"``
                entry (it is interpolated unconditionally).
        """
        if is_url(url_or_filename):
            checkpoint_path = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
        elif os.path.isfile(url_or_filename):
            checkpoint_path = url_or_filename
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be passed in here.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Checkpoints may wrap the weights under a "model" key.
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

        # Snapshot the model's own state once; rebuilding it per lookup is
        # wasteful and the parameter shapes do not change below.
        own_state = self.state_dict()

        state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed(
            state_dict["visual_encoder.pos_embed"], self.visual_encoder
        )
        if (
            "visual_encoder_m.pos_embed" in own_state
            and "visual_encoder_m.pos_embed" in state_dict
        ):
            state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed(
                state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m
            )

        if rename_text_keys:
            for key in list(state_dict.keys()):
                # Only touch keys the replacement actually changes. A key that
                # merely contains "bert" (without the dot) would otherwise be
                # re-assigned to itself and then deleted, dropping its weight.
                if "bert." in key:
                    state_dict[key.replace("bert.", "")] = state_dict.pop(key)

        # Drop checkpoint entries whose shape differs from the model's so the
        # non-strict load reports them as missing instead of raising.
        for key, own_value in own_state.items():
            if key in state_dict and state_dict[key].shape != own_value.shape:
                del state_dict[key]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys %s", msg.missing_keys)
        logging.info("load checkpoint from %s", url_or_filename)

        return msg
@torch.no_grad()
def compute_sim_matrix(model, data_loader, **kwargs):
    """Compute image-to-text and text-to-image retrieval score matrices.

    Text and image embeddings are computed for the whole dataset, a cosine
    similarity matrix is formed from the normalized embeddings, and for every
    row handled by this rank the top-k candidates are re-scored with
    ``model.text_encoder`` + ``model.itm_head``. The stored score is the
    ITM head output (column 1) plus the embedding similarity.

    Runs under ``torch.no_grad()``: the results are converted to NumPy, which
    is not possible for tensors that track gradients, and no graph needs to be
    kept for the cached features.

    Args:
        model: object exposing ``tokenizer``, ``text_encoder``, ``text_proj``,
            ``visual_encoder``, ``vision_proj``, ``itm_head`` and ``device``.
        data_loader: iterable of batches with an ``"image"`` entry; its
            ``dataset`` must expose ``text`` and ``image``.
        **kwargs: must contain ``k_test``, the number of candidates to
            re-rank per query. It is clamped to the number of available
            candidates.

    Returns:
        Tuple ``(score_matrix_i2t, score_matrix_t2i)`` of NumPy arrays with
        shapes ``(len(dataset.image), len(dataset.text))`` and
        ``(len(dataset.text), len(dataset.image))``. Entries are initialised
        to ``-100.0``; in distributed runs the per-rank matrices are summed
        with ``all_reduce``.

    Raises:
        KeyError: if ``k_test`` is not supplied.
    """
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    # ---- text features, computed in fixed-size batches ----
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_output = model.text_encoder.forward_text(text_input)
        # Project the hidden state at sequence position 0 and normalize it.
        text_embed = F.normalize(
            model.text_proj(text_output.last_hidden_state[:, 0, :])
        )
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    if hasattr(model.tokenizer, "enc_token_id"):
        # Tokenizers that define an encoder token get it written into the
        # first position of every sequence.
        text_ids[:, 0] = model.tokenizer.enc_token_id

    # ---- image features ----
    image_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"].to(model.device)
        image_feat = model.visual_encoder.forward_features(image)
        image_embed = model.vision_proj(image_feat[:, 0, :])
        image_embed = F.normalize(image_embed, dim=-1)

        # Full features are parked on the CPU; only the embeddings stay on
        # the model device.
        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()

    # ---- image -> text ----
    sims_matrix = image_embeds @ text_embeds.t()
    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0, device=model.device
    )

    # topk raises when k exceeds the number of candidates, so clamp it.
    k_i2t = min(k_test, sims_matrix.size(1))

    # Each rank handles one contiguous slice of rows.
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_i2t, dim=0)

        # Same image features repeated once per candidate text.
        encoder_output = image_feats[start + i].repeat(k_i2t, 1, 1).to(model.device)
        encoder_att = torch.ones(
            encoder_output.size()[:-1], dtype=torch.long, device=model.device
        )
        output = model.text_encoder(
            text_ids[topk_idx],
            attention_mask=text_atts[topk_idx],
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            return_dict=True,
        )
        # Column 1 of the ITM head output (presumably the "match" logit).
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    # ---- text -> image ----
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0, device=model.device
    )

    k_t2i = min(k_test, sims_matrix.size(1))

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_t2i, dim=0)

        # image_feats lives on the CPU, so index with CPU indices first.
        encoder_output = image_feats[topk_idx.cpu()].to(model.device)
        encoder_att = torch.ones(
            encoder_output.size()[:-1], dtype=torch.long, device=model.device
        )
        output = model.text_encoder(
            text_ids[start + i].repeat(k_t2i, 1),
            attention_mask=text_atts[start + i].repeat(k_t2i, 1),
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            return_dict=True,
        )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    if dist_utils.is_dist_avail_and_initialized():
        # Wait for every rank, then combine the per-rank slices.
        dist.barrier()
        dist.all_reduce(score_matrix_i2t, op=dist.ReduceOp.SUM)
        dist.all_reduce(score_matrix_t2i, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
|