"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import datetime
import logging
import time
import lavis.common.dist_utils as dist_utils
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from lavis.common.config import node_to_dict
from lavis.common.dist_utils import get_rank
from lavis.common.logger import MetricLogger
from lavis.common.registry import registry
from lavis.models.alpro_models import AlproBase
from lavis.models.alpro_models.alpro_outputs import AlproIntermediateOutput, AlproOutput
from lavis.models.base_model import all_gather_with_grad
from lavis.models.med import XBertEncoder
from lavis.models.timesformer.vit import TimeSformer
from torch import nn


@registry.register_model("alpro_retrieval")
class AlproRetrieval(AlproBase):
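    """
    ALPRO model for video-text retrieval, trained with a video-text
    contrastive (VTC) loss and a video-text matching (VTM) loss that uses
    in-batch hard negatives. Pretrained configs are provided for MSRVTT and
    DiDeMo (see PRETRAINED_MODEL_CONFIG_DICT).

    Usage sketch (assuming the standard LAVIS ``load_model`` helper is
    available; adjust the call to your setup):
        >>> from lavis.models import load_model
        >>> model = load_model("alpro_retrieval", "msrvtt")
    """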
PRETRAINED_MODEL_CONFIG_DICT = {
"msrvtt": "configs/models/alpro_retrieval_msrvtt.yaml",
"didemo": "configs/models/alpro_retrieval_didemo.yaml",
}

    def __init__(
self,
visual_encoder,
text_encoder,
vision_width=768,
text_width=768,
embed_dim=256,
max_txt_len=35,
temp=0.07,
):
super().__init__()
self.temp = nn.Parameter(torch.ones([]) * temp)
self.tokenizer = self.init_tokenizer()
self.visual_encoder = visual_encoder
self.text_encoder = text_encoder
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
self.max_txt_len = max_txt_len

    def forward(self, samples):
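        # samples: dict with "video" of shape (b, t, c, h, w) and
        # "text_input", a list of b caption strings. Returns an AlproOutput
        # with the combined VTC + VTM loss and intermediate activations.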
with torch.no_grad():
self.temp.clamp_(0.001, 0.5)
visual_inputs = samples["video"]
caption = samples["text_input"]
b, t, c, h, w = visual_inputs.shape
# forward text
text = self.tokenizer(
caption,
padding="max_length",
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(self.device)
text_output = self.text_encoder.forward_text(
text,
token_type_ids=torch.zeros(
text.input_ids.shape, dtype=torch.long, device=self.device
),
)
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)
# forward visual
# timeSformer asks for (b, c, t, h, w) as input.
video_embeds = self.visual_encoder.forward_features(visual_inputs)
video_feat = F.normalize(self.vision_proj(video_embeds[:, 0, :]), dim=-1)
video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(
self.device
)
# ========== (in-batch) ITC loss ==========
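        # Gather projected features from all ranks (keeping gradients) and
        # contrast each local sample against the global batch; the positive
        # targets sit on this rank's diagonal block of the similarity matrix.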
gathered_video_feats = all_gather_with_grad(video_feat)
gathered_text_feats = all_gather_with_grad(text_feat)
sim_v2t = video_feat @ gathered_text_feats.t() / self.temp
sim_t2v = text_feat @ gathered_video_feats.t() / self.temp
sim_targets = torch.zeros_like(sim_v2t)
local_rank = get_rank()
b_start, b_end = b * local_rank, b * (local_rank + 1)
sim_targets[:, b_start:b_end] = torch.eye(b)
loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_targets, dim=1).mean()
loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_targets, dim=1).mean()
vtc_loss = (loss_v2t + loss_t2v) / 2
(
vtm_loss,
vtm_logits,
vtm_labels,
encoder_output,
encoder_output_neg,
) = self.compute_vtm(
text_embeds=text_embeds,
text_atts=text.attention_mask,
image_embeds=video_embeds,
image_atts=video_atts,
sim_i2t=sim_v2t.clone(), # for hard mining
sim_t2i=sim_t2v.clone(), # for hard mining
)
loss = vtc_loss + vtm_loss
# return {"loss": loss}
return AlproOutput(
loss=loss,
loss_vtc=vtc_loss,
loss_vtm=vtm_loss,
intermediate_output=AlproIntermediateOutput(
video_embeds=video_embeds,
text_embeds=text_embeds,
encoder_output=encoder_output,
encoder_output_neg=encoder_output_neg,
vtm_logits=vtm_logits,
vtm_labels=vtm_labels,
),
)

    def compute_vtm(
self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i
):
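        # Video-text matching: the fusion encoder classifies concatenated
        # [text; video] token sequences as matched or mismatched, using
        # in-batch hard negatives sampled from the contrastive similarities.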
device = self.device
# ====== positive pairs =======
attention_mask = torch.cat([text_atts, image_atts], dim=1)
embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)
encoder_outputs_pos = self.text_encoder(
encoder_embeds=embedding_output_pos,
attention_mask=attention_mask,
return_dict=True,
mode="fusion",
)
# ====== negative pairs =======
bs = text_embeds.shape[0]
local_rank = get_rank()
b_start, b_end = bs * local_rank, bs * (local_rank + 1)
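        # Restrict to this rank's block of the gathered similarity matrices so
        # that negatives are sampled from the local batch only.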
with torch.no_grad():
weights_v2t = sim_i2t[:, b_start:b_end]
weights_t2v = sim_t2i[:, b_start:b_end]
# never select self as negative
            weights_v2t.fill_diagonal_(-np.inf)
            weights_t2v.fill_diagonal_(-np.inf)
weights_v2t = F.softmax(weights_v2t, dim=1)
weights_t2v = F.softmax(weights_t2v, dim=1)
# select a negative image for each text
# FIXME to optimize using indexing operations
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2v[b], 1).item()
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
# select a negative text for each image
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_v2t[b], 1).item()
text_embeds_neg.append(text_embeds[neg_idx])
text_atts_neg.append(text_atts[neg_idx])
text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)
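        # Build 2 * bs negative pairs: each original text with its sampled
        # negative video, and each sampled negative text with its original video.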
text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0)
video_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
video_atts_all = torch.cat([image_atts, image_atts], dim=0)
attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)
# forward negative pairs via cross encoder
encoder_outputs_neg = self.text_encoder(
encoder_embeds=embedding_output_all,
attention_mask=attention_mask_all,
return_dict=True,
mode="fusion",
)
vl_embeddings = torch.cat(
[
encoder_outputs_pos.last_hidden_state[:, 0, :],
encoder_outputs_neg.last_hidden_state[:, 0, :],
],
dim=0,
)
vtm_logits = self.itm_head(vl_embeddings)
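        # Labels: the bs positive pairs come first (label 1), followed by the
        # 2 * bs constructed negative pairs (label 0).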
vtm_labels = torch.cat(
[torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
dim=0,
).to(device)
vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)
return (
vtm_loss,
vtm_logits,
vtm_labels,
encoder_outputs_pos,
encoder_outputs_neg,
)

    def compute_sim_matrix(self, data_loader, task_cfg):
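        # Two-stage retrieval evaluation: first score all video-text pairs by
        # the dot product of their projected features, then re-rank each
        # query's top-k_test candidates with the fusion encoder's matching head.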
k_test = task_cfg.get("k_test")
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
logging.info("Computing features for evaluation...")
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_feats = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(self.device)
text_output = self.text_encoder.forward_text(
text_input,
token_type_ids=torch.zeros(
text_input.input_ids.shape, dtype=torch.long, device=self.device
),
)
text_feats.append(text_output.last_hidden_state.cpu())
text_embed = F.normalize(
self.text_proj(text_output.last_hidden_state[:, 0, :])
)
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds, dim=0)
text_ids = torch.cat(text_ids, dim=0)
text_atts = torch.cat(text_atts, dim=0)
text_feats = torch.cat(text_feats, dim=0)
video_feats = []
video_embeds = []
for samples in data_loader:
video = samples["video"]
video = video.to(self.device)
video_feat = self.visual_encoder.forward_features(video)
video_embed = self.vision_proj(video_feat[:, 0, :])
video_embed = F.normalize(video_embed, dim=-1)
video_feats.append(video_feat.cpu())
video_embeds.append(video_embed)
video_feats = torch.cat(video_feats, dim=0)
video_embeds = torch.cat(video_embeds, dim=0)
sims_matrix = video_embeds @ text_embeds.t()
score_matrix_v2t = torch.full(
(len(data_loader.dataset.image), len(texts)), -100.0
).to(self.device)
num_tasks = dist_utils.get_world_size()
rank = dist_utils.get_rank()
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
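        # Rows are sharded across ranks; each rank fills its slice and the
        # partial score matrices are summed with an all_reduce afterwards.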
# video-to-text
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
video_feats_repeat = (
video_feats[start + i].repeat(k_test, 1, 1).to(self.device)
)
video_atts_repeat = torch.ones(
video_feats_repeat.size()[:-1], dtype=torch.long
).to(self.device)
attention_mask = torch.cat([text_atts[topk_idx], video_atts_repeat], dim=1)
embedding_output = torch.cat(
[text_feats[topk_idx].to(self.device), video_feats_repeat], dim=1
)
output = self.text_encoder(
encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode="fusion",
)
score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
score_matrix_v2t[start + i, topk_idx] = score + topk_sim
# text-to-video
sims_matrix = sims_matrix.t()
score_matrix_t2v = torch.full(
(len(texts), len(data_loader.dataset.image)), -100.0
).to(self.device)
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
text_feats_repeat = (
text_feats[start + i].repeat(k_test, 1, 1).to(self.device)
)
text_atts_repeat = text_atts[start + i].repeat(k_test, 1).to(self.device)
video_atts = torch.ones(
video_feats[topk_idx].size()[:-1], dtype=torch.long
).to(self.device)
embedding_output = torch.cat(
[text_feats_repeat, video_feats[topk_idx].to(self.device)], dim=1
)
attention_mask = torch.cat([text_atts_repeat, video_atts], dim=1)
output = self.text_encoder(
encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode="fusion",
)
score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
score_matrix_t2v[start + i, topk_idx] = score + topk_sim
if dist_utils.is_dist_avail_and_initialized():
dist.barrier()
torch.distributed.all_reduce(
score_matrix_v2t, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.all_reduce(
score_matrix_t2v, op=torch.distributed.ReduceOp.SUM
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logging.info("Evaluation time {}".format(total_time_str))
return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()

    @classmethod
def from_config(cls, cfg):
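        # Build the TimeSformer video encoder and the BERT-based text encoder
        # from the config, then load pretrained weights via
        # load_checkpoint_from_config, which is given the patch/frame geometry.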
# vision encoder
visual_encoder_config = node_to_dict(cfg.timesformer)
visual_encoder = TimeSformer(**visual_encoder_config)
# text encoder
text_encoder = XBertEncoder.from_config(cfg)
max_txt_len = cfg.get("max_txt_len", 35)
model = cls(
visual_encoder=visual_encoder,
text_encoder=text_encoder,
max_txt_len=max_txt_len,
)
num_patches = (
visual_encoder_config["image_size"] // visual_encoder_config["patch_size"]
) ** 2
num_frames = visual_encoder_config["n_frms"]
model.load_checkpoint_from_config(
cfg, num_frames=num_frames, num_patches=num_patches
)
return model