""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
from warnings import warn

import torch
import torch.nn.functional as F
from torch import nn

from lavis.common.config import node_to_dict
from lavis.common.registry import registry
from lavis.models.alpro_models import AlproBase
from lavis.models.alpro_models.alpro_outputs import (
    AlproIntermediateOutput,
    AlproOutputWithLogits,
)
from lavis.models.med import XBertEncoder
from lavis.models.timesformer.vit import TimeSformer


@registry.register_model("alpro_qa")
class AlproQA(AlproBase):
    PRETRAINED_MODEL_CONFIG_DICT = {
        "msrvtt": "configs/models/alpro_qa_msrvtt.yaml",
        "msvd": "configs/models/alpro_qa_msvd.yaml",
    }

    def __init__(
        self, visual_encoder, text_encoder, hidden_size, num_classes, max_txt_len=40
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder

        if num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 2),
                nn.ReLU(True),
                nn.Linear(hidden_size * 2, num_classes),
            )
        else:
            warn(
                f"num_classes is {num_classes}. Initialized {type(self)} without classifier."
            )

        self.max_txt_len = max_txt_len
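
    # Expected layout of `samples` in forward()/predict() (shapes inferred from
    # the calls below; exact sizes depend on the dataset/processor configuration):
    #   samples["video"]:      float tensor of shape (b, c, t, h, w) for TimeSformer
    #   samples["text_input"]: list of b question strings
    #   samples["answers"]:    long tensor of shape (b,) with answer-class indices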
    def forward(self, samples, is_train=True):
        visual_inputs = samples["video"]
        question = samples["text_input"]
        targets = samples["answers"]

        # forward text
        text = self.tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)

        text_output = self.text_encoder.forward_text(
            text,
            token_type_ids=torch.zeros(
                text.input_ids.shape, dtype=torch.long, device=self.device
            ),
        )
        text_embeds = text_output.last_hidden_state

        # forward visual
        # TimeSformer asks for (b, c, t, h, w) as input.
        video_embeds = self.visual_encoder.forward_features(visual_inputs)
        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        # forward cross-encoder
        attention_mask = torch.cat([text.attention_mask, video_atts], dim=1)
        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_output = self.text_encoder(
            encoder_embeds=embedding_output,
            attention_mask=attention_mask,
            return_dict=True,
            mode="fusion",
        )

        prediction = self.classifier(encoder_output.last_hidden_state[:, 0, :])
        if is_train:
            loss = F.cross_entropy(prediction, targets)

            # return {"loss": loss}
            return AlproOutputWithLogits(
                loss=loss,
                intermediate_output=AlproIntermediateOutput(
                    video_embeds=video_embeds,
                    text_embeds=text_embeds,
                    encoder_output=encoder_output,
                ),
                logits=prediction,
            )
        else:
            return {"predictions": prediction, "targets": targets}
    def predict(self, samples):
        output = self.forward(samples, is_train=False)
        return output

    @classmethod
    def from_config(cls, cfg):
        # vision encoder
        visual_encoder_config = node_to_dict(cfg.timesformer)
        visual_encoder = TimeSformer(**visual_encoder_config)

        # text encoder
        text_encoder = XBertEncoder.from_config(cfg)

        num_classes = cfg.get("num_classes", -1)
        hidden_size = cfg.get("hidden_size", 768)

        model = cls(
            visual_encoder=visual_encoder,
            text_encoder=text_encoder,
            hidden_size=hidden_size,
            num_classes=num_classes,
        )

        num_patches = (
            visual_encoder_config["image_size"] // visual_encoder_config["patch_size"]
        ) ** 2
        num_frames = visual_encoder_config["n_frms"]

        model.load_checkpoint_from_config(
            cfg, num_frames=num_frames, num_patches=num_patches
        )

        return model
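

# Minimal usage sketch (not part of the original module): it assumes the
# lavis.models.load_model helper and the "msrvtt" model type listed in
# PRETRAINED_MODEL_CONFIG_DICT above. The dummy video shape (frame count,
# resolution) is a placeholder; real inference should use the ALPRO video/text
# processors so that inputs match the loaded configuration.
if __name__ == "__main__":
    from lavis.models import load_model

    # Loads the pretrained MSRVTT QA checkpoint registered as "alpro_qa".
    model = load_model("alpro_qa", "msrvtt", is_eval=True)

    # Dummy batch: 1 video of 8 RGB frames at 224x224, one question, one answer id.
    samples = {
        "video": torch.randn(1, 3, 8, 224, 224),
        "text_input": ["what is the person doing"],
        "answers": torch.zeros(1, dtype=torch.long),
    }

    output = model.predict(samples)
    print(output["predictions"].shape)  # (1, num_classes) answer logits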