Spaces:
Sleeping
Sleeping
File size: 14,591 Bytes
d9272c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from copy import deepcopy
import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.albef_models import compute_sim_matrix
from lavis.models.base_model import (
MomentumDistilationMixin,
SharedQueueMixin,
all_gather_with_grad,
concat_all_gather,
)
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import (
BlipOutput,
BlipSimilarity,
BlipIntermediateOutput,
)
from lavis.models.med import XBertEncoder
from lavis.models.vit import VisionTransformerEncoder
from torch import nn
@registry.register_model("blip_retrieval")
class BlipRetrieval(BlipBase, MomentumDistilationMixin, SharedQueueMixin):
    """
    BLIP retrieval model.

    Trains an image encoder and a text encoder with two objectives: an
    image-text contrastive (ITC) loss computed against momentum features plus
    a feature queue, and an image-text matching (ITM) loss computed on
    positive pairs and sampled hard negatives.

    Supported model types:
        - coco: fine-tuned BLIP base model on COCO dataset (Karpathy split).
        - flickr: fine-tuned BLIP base model on Flickr30k dataset.

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip_retrieval", "coco")
        >>> model = load_model("blip_retrieval", "flickr")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "coco": "configs/models/blip_retrieval_coco.yaml",
        "flickr": "configs/models/blip_retrieval_flickr.yaml",
    }

    def __init__(
        self,
        image_encoder,
        text_encoder,
        queue_size,
        alpha=0.4,
        embed_dim=256,
        momentum=0.995,
        negative_all_rank=False,
        max_txt_len=35,
    ):
        """
        Args:
            image_encoder: Vision backbone. Must expose ``vision_width`` and
                ``forward_features``.
            text_encoder: Text backbone. Must expose ``config.hidden_size`` and
                ``forward_text``.
            queue_size (int): Number of columns in the image/text/idx queues.
            alpha (float): Maximum weight of the momentum-distilled soft
                targets in the ITC loss (ramped up during training).
            embed_dim (int): Output size of the ITC projection layers.
            momentum (float): Stored as ``self.momentum``; consumed by the
                momentum-update mixin.
            negative_all_rank (bool): If True, ITM hard negatives are sampled
                from the samples gathered across all ranks instead of only the
                local batch.
            max_txt_len (int): Maximum tokenized caption length.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        # creating projection layers for ITC
        text_width = text_encoder.config.hidden_size
        vision_width = image_encoder.vision_width

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

        # create the momentum encoder
        self.visual_encoder_m = deepcopy(self.visual_encoder)
        self.text_encoder_m = deepcopy(self.text_encoder)

        self.vision_proj_m = deepcopy(self.vision_proj)
        self.text_proj_m = deepcopy(self.text_proj)

        # (online module, momentum module) pairs handled by the momentum mixin
        self.model_pairs = [
            [self.visual_encoder, self.visual_encoder_m],
            [self.text_encoder, self.text_encoder_m],
            [self.vision_proj, self.vision_proj_m],
            [self.text_proj, self.text_proj_m],
        ]
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
        # -100 marks queue slots that do not hold a real image id yet
        self.register_buffer("idx_queue", torch.full((1, queue_size), -100))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # queue columns are unit-normalized feature vectors
        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

        self.queue_size = queue_size
        self.momentum = momentum
        self.temp = nn.Parameter(0.07 * torch.ones([]))

        self.alpha = alpha
        self.max_txt_len = max_txt_len

        self.negative_all_rank = negative_all_rank

    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        """
        Return a linear warm-up factor in [0, 1].

        The factor grows with the global step and saturates at 1 once two
        epochs' worth of iterations have elapsed.
        """
        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))

    def forward(self, samples):
        """
        Compute the ITC and ITM training losses for one batch.

        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images.
                - text_input (list): A list of length batch_size, each element is a string of text/caption.
                - image_id (torch.Tensor): A tensor of shape (batch_size, ). The image ids, used to identify same images in batch.
                - epoch (int): The current epoch.
                - iters (int): The current iteration.
                - num_iters_per_epoch (int): The number of iterations per epoch.

        Returns:
            BlipOutput: A BlipOutput object. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details.
            ``sims.sim_i2t`` / ``sims.sim_t2i`` hold the ITC logits, i.e. the
            tensors that are compared against ``sims.sim_i2t_targets`` /
            ``sims.sim_t2i_targets`` in the contrastive loss.

        Examples:
            >>> import torch
            >>> from lavis.models import load_model
            >>> model = load_model("blip_retrieval", "coco")
            >>> images = torch.randn(4, 3, 384, 384)
            >>> text_input = ["caption of image 1", "another caption of image 1", "caption of image 2", "caption of image 3"]
            >>> image_id = torch.tensor([1, 1, 2, 3])
            >>> samples = {"image": images, "text_input": text_input, "image_id": image_id, "epoch": 0, "iters": 0, "num_iters_per_epoch": 100}
            >>> output = model(samples)
            >>> output.keys()
            odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm'])
        """
        image = samples["image"]
        caption = samples["text_input"]
        idx = samples["image_id"]

        # weight of the momentum-distilled soft targets, warmed up over training
        alpha = self.alpha * self._rampup_factor(
            epoch=samples["epoch"],
            iters=samples["iters"],
            num_iters_per_epoch=samples["num_iters_per_epoch"],
        )

        # keep the learnable temperature inside a fixed range
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        image_embeds = self.visual_encoder.forward_features(image)
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image.device
        )
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)

        text = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)

        text_output = self.text_encoder.forward_text(text)
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)

        # Image-text Contrastive Learning
        # Every batch/queue entry that shares the image id is a positive; each
        # row of sim_targets is the uniform distribution over its positives.
        idx = idx.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(
                self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1
            )
            image_feat_m_all = torch.cat(
                [image_feat_m.t(), self.image_queue.clone().detach()], dim=1
            )

            text_output_m = self.text_encoder_m.forward_text(text)
            text_embeds_m = text_output_m.last_hidden_state
            text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1)
            text_feat_m_all = torch.cat(
                [text_feat_m.t(), self.text_queue.clone().detach()], dim=1
            )

            sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp

            # blend momentum predictions with the hard id-based targets
            sim_i2t_targets = (
                alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            )
            sim_t2i_targets = (
                alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
            )

        sim_i2t = image_feat @ text_feat_m_all / self.temp
        sim_t2i = text_feat @ image_feat_m_all / self.temp

        loss_i2t = -torch.sum(
            F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1
        ).mean()
        loss_t2i = -torch.sum(
            F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1
        ).mean()

        loss_itc = (loss_i2t + loss_t2i) / 2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

        # Image-text Matching
        encoder_input_ids = text.input_ids.clone()
        encoder_input_ids[:, 0] = self.tokenizer.enc_token_id

        # forward the positive image-text pair
        bs = image.size(0)
        output_pos = self.text_encoder(
            encoder_input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        # The mining similarities stay local to the helper so they cannot
        # overwrite the ITC logits (sim_i2t / sim_t2i) reported in the output.
        image_embeds_neg, text_ids_neg, text_atts_neg = self._sample_hard_negatives(
            idx=idx,
            image_feat=image_feat,
            text_feat=text_feat,
            image_embeds=image_embeds,
            text_ids=encoder_input_ids,
            text_atts=text.attention_mask,
        )

        # first half: (text_i, negative image); second half: (negative text, image_i)
        text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)

        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)

        output_neg = self.text_encoder(
            text_ids_all,
            attention_mask=text_atts_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        vl_embeddings = torch.cat(
            [
                output_pos.last_hidden_state[:, 0, :],
                output_neg.last_hidden_state[:, 0, :],
            ],
            dim=0,
        )
        itm_logits = self.itm_head(vl_embeddings)

        # label 1 for the bs positive pairs, 0 for the 2 * bs negative pairs
        itm_labels = torch.cat(
            [
                torch.ones(bs, dtype=torch.long, device=itm_logits.device),
                torch.zeros(2 * bs, dtype=torch.long, device=itm_logits.device),
            ],
            dim=0,
        )
        loss_itm = F.cross_entropy(itm_logits, itm_labels)

        return BlipOutput(
            loss=loss_itc + loss_itm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            sims=BlipSimilarity(
                sim_i2t=sim_i2t,
                sim_t2i=sim_t2i,
                sim_i2t_m=sim_i2t_m,
                sim_t2i_m=sim_t2i_m,
                sim_i2t_targets=sim_i2t_targets,
                sim_t2i_targets=sim_t2i_targets,
            ),
            intermediate_output=BlipIntermediateOutput(
                image_embeds=image_embeds,
                image_embeds_m=image_embeds_m,
                text_embeds=text_embeds,
                text_embeds_m=text_embeds_m,
                encoder_output=output_pos,
                encoder_output_neg=output_neg,
                itm_logits=itm_logits,
                itm_labels=itm_labels,
            ),
        )

    def _sample_hard_negatives(
        self, idx, image_feat, text_feat, image_embeds, text_ids, text_atts
    ):
        """
        Sample one negative image per text and one negative text per image.

        Candidates are drawn with probability given by the softmax of the
        image-text similarity, after zeroing every candidate that shares the
        anchor's image id. The candidate pool is the local batch, or the
        samples gathered from all ranks when ``self.negative_all_rank`` is
        True.

        Args:
            idx (torch.Tensor): Image ids of the local batch, shape (bs, 1).
            image_feat (torch.Tensor): Normalized projected image features.
            text_feat (torch.Tensor): Normalized projected text features.
            image_embeds (torch.Tensor): Image encoder outputs, indexed along
                dim 0 to pick negatives.
            text_ids (torch.Tensor): Token ids fed to the ITM encoder.
            text_atts (torch.Tensor): Attention mask matching ``text_ids``.

        Returns:
            tuple: ``(image_embeds_neg, text_ids_neg, text_atts_neg)``, each
            with ``bs`` entries along dim 0.

        Raises:
            RuntimeError: Propagated from ``torch.multinomial`` when an anchor
                has no valid candidate (every candidate shares its image id).
        """
        with torch.no_grad():
            if self.negative_all_rank:
                idx_pool = concat_all_gather(idx)
                image_feat_pool = concat_all_gather(image_feat)
                text_feat_pool = concat_all_gather(text_feat)
            else:
                idx_pool, image_feat_pool, text_feat_pool = idx, image_feat, text_feat

            # True where a candidate depicts the same image as the anchor
            same_image = torch.eq(idx, idx_pool.t())

            weights_i2t = F.softmax(image_feat @ text_feat_pool.t() / self.temp, dim=1)
            weights_i2t.masked_fill_(same_image, 0)

            weights_t2i = F.softmax(text_feat @ image_feat_pool.t() / self.temp, dim=1)
            weights_t2i.masked_fill_(same_image, 0)

            # Row-wise sampling: one draw per anchor, without a Python loop
            # and the per-sample .item() synchronization it would require.
            neg_image_idx = torch.multinomial(weights_t2i, 1).squeeze(1)
            neg_text_idx = torch.multinomial(weights_i2t, 1).squeeze(1)

        if self.negative_all_rank:
            # image embeddings are gathered with gradients so that negatives
            # taken from other ranks still take part in backpropagation
            image_embeds_pool = all_gather_with_grad(image_embeds)
            text_ids_pool = concat_all_gather(text_ids)
            text_atts_pool = concat_all_gather(text_atts)
        else:
            image_embeds_pool = image_embeds
            text_ids_pool = text_ids
            text_atts_pool = text_atts

        return (
            image_embeds_pool[neg_image_idx],
            text_ids_pool[neg_text_idx],
            text_atts_pool[neg_text_idx],
        )

    def reset_queue_ptr(self):
        """
        Reset the queue write pointer to 0.

        The registered buffer is zeroed in place so that it keeps its current
        device and dtype even when this is called after the model was moved.
        """
        self.queue_ptr.zero_()

    @classmethod
    def from_config(cls, cfg=None):
        """
        Build the model from a config object and load its checkpoint.

        Reads ``embed_dim``, ``momentum``, ``alpha``, ``negative_all_rank``,
        ``queue_size`` and ``max_txt_len`` from ``cfg`` (with defaults), then
        loads weights via ``load_checkpoint_from_config`` and resets the queue
        pointer.
        """
        # set from_pretrained=True to load weights for 'bert-base-uncased'
        image_encoder = VisionTransformerEncoder.from_config(cfg)
        text_encoder = XBertEncoder.from_config(cfg)

        embed_dim = cfg.get("embed_dim", 256)
        momentum = cfg.get("momentum", 0.995)
        alpha = cfg.get("alpha", 0.4)
        negative_all_rank = cfg.get("negative_all_rank", False)

        queue_size = cfg.get("queue_size", 0)
        max_txt_len = cfg.get("max_txt_len", 35)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            queue_size=queue_size,
            alpha=alpha,
            embed_dim=embed_dim,
            momentum=momentum,
            negative_all_rank=negative_all_rank,
            max_txt_len=max_txt_len,
        )

        model.load_checkpoint_from_config(cfg)
        # start writing at slot 0 regardless of the pointer stored in the checkpoint
        model.reset_queue_ptr()

        return model

    def compute_sim_matrix(self, data_loader, task_cfg):
        """
        Compute similarity i2t, t2i matrix for the given data loader.

        Delegates to the module-level ``compute_sim_matrix`` imported from
        ``lavis.models.albef_models``, passing ``task_cfg.k_test`` as ``k_test``.
        """
        k_test = task_cfg.k_test

        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)
|