"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from warnings import warn
import torch
import torch.nn.functional as F
from lavis.common.config import node_to_dict
from lavis.common.registry import registry
from lavis.models.alpro_models import AlproBase
from lavis.models.alpro_models.alpro_outputs import (
AlproIntermediateOutput,
AlproOutputWithLogits,
)
from lavis.models.med import XBertEncoder
from lavis.models.timesformer.vit import TimeSformer
from torch import nn


@registry.register_model("alpro_qa")
class AlproQA(AlproBase):
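    """
    ALPRO model for open-ended video question answering, cast as
    classification over a fixed answer vocabulary: TimeSformer video
    features and BERT question features are fused by the cross-modal
    encoder, and the fused [CLS] state is fed to an MLP classifier.

    Supported model types:
        - msrvtt: fine-tuned on MSRVTT-QA
        - msvd: fine-tuned on MSVD-QA

    Usage sketch (assuming the standard LAVIS loader):
        >>> from lavis.models import load_model
        >>> model = load_model("alpro_qa", model_type="msrvtt")
    """
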
    PRETRAINED_MODEL_CONFIG_DICT = {
        "msrvtt": "configs/models/alpro_qa_msrvtt.yaml",
        "msvd": "configs/models/alpro_qa_msvd.yaml",
    }

    def __init__(
        self, visual_encoder, text_encoder, hidden_size, num_classes, max_txt_len=40
    ):
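        """
        Args:
            visual_encoder (TimeSformer): video encoder producing per-patch features.
            text_encoder (XBertEncoder): BERT-based text and fusion encoder.
            hidden_size (int): hidden dimension shared by both encoders.
            num_classes (int): size of the answer vocabulary; if not positive,
                no classifier head is created.
            max_txt_len (int): maximum number of question tokens after truncation.
        """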
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = visual_encoder
        self.text_encoder = text_encoder

        if num_classes > 0:
            # MLP head mapping the fused [CLS] representation to answer logits.
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 2),
                nn.ReLU(True),
                nn.Linear(hidden_size * 2, num_classes),
            )
        else:
            warn(
                f"num_classes is {num_classes}. Initialized {type(self)} without classifier."
            )

        self.max_txt_len = max_txt_len

    def forward(self, samples, is_train=True):
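        """
        Args:
            samples (dict): expects
                - "video" (Tensor): raw video of shape (b, c, t, h, w),
                  as required by TimeSformer.
                - "text_input" (list[str]): batch of questions.
                - "answers" (Tensor): answer class indices of shape (b,).
            is_train (bool): if True, return loss and logits wrapped in
                AlproOutputWithLogits; otherwise return raw predictions
                and targets for evaluation.
        """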
        visual_inputs = samples["video"]
        question = samples["text_input"]
        targets = samples["answers"]

        # forward text
        text = self.tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)

        text_output = self.text_encoder.forward_text(
            text,
            token_type_ids=torch.zeros(
                text.input_ids.shape, dtype=torch.long, device=self.device
            ),
        )
        text_embeds = text_output.last_hidden_state

        # forward visual
        # TimeSformer expects (b, c, t, h, w) as input.
        video_embeds = self.visual_encoder.forward_features(visual_inputs)
        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        # forward cross-encoder
        attention_mask = torch.cat([text.attention_mask, video_atts], dim=1)
        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_output = self.text_encoder(
            encoder_embeds=embedding_output,
            attention_mask=attention_mask,
            return_dict=True,
            mode="fusion",
        )

        # classify on the fused [CLS] token (first position)
        prediction = self.classifier(encoder_output.last_hidden_state[:, 0, :])
        if is_train:
            loss = F.cross_entropy(prediction, targets)

            return AlproOutputWithLogits(
                loss=loss,
                intermediate_output=AlproIntermediateOutput(
                    video_embeds=video_embeds,
                    text_embeds=text_embeds,
                    encoder_output=encoder_output,
                ),
                logits=prediction,
            )
        else:
            return {"predictions": prediction, "targets": targets}

    def predict(self, samples):
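        """Inference entry point; returns {"predictions": logits, "targets": targets}."""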
        output = self.forward(samples, is_train=False)
        return output

    @classmethod
    def from_config(cls, cfg):
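        """
        Build an AlproQA model from a config node (see the YAML files listed
        in PRETRAINED_MODEL_CONFIG_DICT) and load pretrained weights via
        load_checkpoint_from_config.
        """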
        # vision encoder
        visual_encoder_config = node_to_dict(cfg.timesformer)
        visual_encoder = TimeSformer(**visual_encoder_config)

        # text encoder
        text_encoder = XBertEncoder.from_config(cfg)

        num_classes = cfg.get("num_classes", -1)
        hidden_size = cfg.get("hidden_size", 768)
        model = cls(
            visual_encoder=visual_encoder,
            text_encoder=text_encoder,
            hidden_size=hidden_size,
            num_classes=num_classes,
        )

        num_patches = (
            visual_encoder_config["image_size"] // visual_encoder_config["patch_size"]
        ) ** 2
        num_frames = visual_encoder_config["n_frms"]

        model.load_checkpoint_from_config(
            cfg, num_frames=num_frames, num_patches=num_patches
        )

        return model