# jchwenger's picture
# Upload 351 files (#2)
# d9272c6 verified
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from copy import deepcopy
import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.base_model import MomentumDistilationMixin
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import (
BlipIntermediateOutput,
BlipOutputWithLogits,
)
from lavis.models.med import XBertEncoder
from lavis.models.vit import VisionTransformerEncoder
from torch import nn
@registry.register_model("blip_classification")
class BlipClassification(BlipBase, MomentumDistilationMixin):
    """BLIP-based image+text classifier with optional momentum distillation.

    An image is encoded by ``visual_encoder``; the tokenized text together with
    the image embeddings is passed through ``text_encoder.forward_automask``;
    the hidden state at sequence position 0 is fed to a two-layer MLP head that
    produces ``num_classes`` logits.

    When ``use_distill`` is true, deep copies of the two encoders and the head
    ("momentum" models, suffixed ``_m``) are kept, and their softmax output is
    used as a soft target that is blended into the training loss.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base": "configs/models/blip_classification_base.yaml",
    }

    def __init__(
        self,
        image_encoder,
        text_encoder,
        num_classes,
        momentum=0.995,
        alpha=0.4,
        max_txt_len=40,
        use_distill=True,
    ):
        """Build the classifier around already-constructed encoders.

        Args:
            image_encoder: Visual encoder; must expose ``forward_features`` and
                be callable on an image batch.
            text_encoder: Text encoder; must expose ``config.hidden_size`` and
                ``forward_automask(tokenized_text, image_embeds)``.
            num_classes: Output dimension of the classification head.
            momentum: Stored as ``self.momentum`` (only when ``use_distill``).
                Presumably the EMA coefficient read by ``_momentum_update`` in
                the mixin -- not visible here, confirm there.
            alpha: Maximum weight of the distillation term in the loss
                (stored only when ``use_distill``).
            max_txt_len: Maximum token length used when tokenizing text.
            use_distill: Whether to create momentum copies and train with the
                distillation loss.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.use_distill = use_distill

        self.visual_encoder = image_encoder
        self.text_encoder = text_encoder

        hidden_size = text_encoder.config.hidden_size
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )

        if self.use_distill:
            # Momentum ("teacher") copies of every trainable sub-module.
            self.visual_encoder_m = deepcopy(self.visual_encoder)
            self.text_encoder_m = deepcopy(self.text_encoder)
            self.cls_head_m = deepcopy(self.cls_head)

            self.momentum = momentum
            self.alpha = alpha

            # [online, momentum] pairs; consumed by copy_params() below and,
            # presumably, by _momentum_update() (both defined outside this
            # class body -- confirm in MomentumDistilationMixin).
            self.model_pairs = [
                [self.visual_encoder, self.visual_encoder_m],
                [self.text_encoder, self.text_encoder_m],
                [self.cls_head, self.cls_head_m],
            ]

            self.copy_params()

        self.max_txt_len = max_txt_len

    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        """Return a factor that rises linearly from 0 to 1 over the first epoch.

        The value is ``(epoch * num_iters_per_epoch + iters) / num_iters_per_epoch``
        clamped to at most 1, so it equals ``iters / num_iters_per_epoch`` during
        epoch 0 and 1 from epoch 1 onwards.
        """
        return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch)

    def forward(self, samples, is_train=True):
        """Run the classifier and, in training mode, compute the loss.

        Args:
            samples: Dict read for the keys ``"text_input"``, ``"image"`` and
                ``"label"``. When training with distillation it is also read
                for ``"epoch"``, ``"iters"`` and ``"num_iters_per_epoch"``.
                The dict is mutated: the tokenizer output is stored under
                ``"tokenized_text"``.
            is_train: If true return a ``BlipOutputWithLogits`` containing the
                loss; otherwise return a plain dict.

        Returns:
            ``BlipOutputWithLogits`` when ``is_train`` is true. Its momentum
            fields (``image_embeds_m``, ``encoder_output_m``, ``logits_m``) are
            ``None`` when distillation is disabled.
            ``{"predictions": logits, "targets": labels}`` otherwise.
        """
        tokenized_text = self.tokenizer(
            samples["text_input"],
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        samples.update({"tokenized_text": tokenized_text})

        targets = samples["label"]

        image_embeds = self.visual_encoder.forward_features(samples["image"])
        encoder_output = self.text_encoder.forward_automask(
            samples["tokenized_text"], image_embeds
        )

        # Classify from the hidden state at sequence position 0.
        prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :])

        if not is_train:
            return {"predictions": prediction, "targets": targets}

        # Momentum-branch outputs only exist when distillation is enabled.
        # Default them to None so the return below never reads an unbound
        # local (previously an UnboundLocalError with use_distill=False).
        # NOTE(review): assumes the output containers accept None for these
        # fields -- confirm in blip_outputs.
        image_embeds_m = None
        encoder_output_m = None
        prediction_m = None

        if self.use_distill:
            # The momentum models only provide soft targets: no gradients.
            with torch.no_grad():
                self._momentum_update()

                image_embeds_m = self.visual_encoder_m(samples["image"])
                encoder_output_m = self.text_encoder_m.forward_automask(
                    samples["tokenized_text"], image_embeds_m
                )
                prediction_m = self.cls_head_m(
                    encoder_output_m.last_hidden_state[:, 0, :]
                )

            # Distillation weight is ramped up to self.alpha over epoch 0.
            alpha = self.alpha * self._rampup_factor(
                epoch=samples["epoch"],
                iters=samples["iters"],
                num_iters_per_epoch=samples["num_iters_per_epoch"],
            )

            # Blend hard-label cross-entropy with the cross-entropy against
            # the momentum model's softmax distribution.
            loss = (1 - alpha) * F.cross_entropy(
                prediction, targets
            ) - alpha * torch.sum(
                F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1),
                dim=1,
            ).mean()
        else:
            loss = F.cross_entropy(prediction, targets)

        return BlipOutputWithLogits(
            loss=loss,
            intermediate_output=BlipIntermediateOutput(
                image_embeds=image_embeds,
                image_embeds_m=image_embeds_m,
                encoder_output=encoder_output,
                encoder_output_m=encoder_output_m,
            ),
            logits=prediction,
            logits_m=prediction_m,
        )

    def predict(self, samples):
        """Run ``forward`` in evaluation mode and return its dict output."""
        return self.forward(samples, is_train=False)

    @classmethod
    def from_config(cls, cfg=None):
        """Build the model from a config object exposing ``get``.

        Reads ``use_distill``, ``momentum``, ``num_classes``, ``alpha``,
        ``max_txt_len`` and ``pretrained`` from ``cfg``; the two encoders are
        built from the same ``cfg`` via their own ``from_config``.

        Raises:
            AssertionError: If ``num_classes`` is missing or not greater than 1.
        """
        # Validate first so a bad config fails before the encoders are built.
        num_classes = cfg.get("num_classes", -1)
        assert num_classes > 1, "Invalid number of classes provided, found {}".format(
            num_classes
        )

        image_encoder = VisionTransformerEncoder.from_config(cfg)

        # text encoder + multimodal encoder
        text_encoder = XBertEncoder.from_config(cfg)

        use_distill = cfg.get("use_distill", True)
        momentum = cfg.get("momentum", 0.995)
        alpha = cfg.get("alpha", 0.4)
        max_txt_len = cfg.get("max_txt_len", 40)

        model = cls(
            image_encoder=image_encoder,
            text_encoder=text_encoder,
            use_distill=use_distill,
            alpha=alpha,
            num_classes=num_classes,
            momentum=momentum,
            max_txt_len=max_txt_len,
        )

        # load pre-trained weights
        pretrain_path = cfg.get("pretrained", None)
        if pretrain_path is not None:
            model.load_from_pretrained(url_or_filename=pretrain_path)

        return model