"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from warnings import warn

import torch
import torch.nn.functional as F
from lavis.common.config import node_to_dict
from lavis.common.registry import registry
from lavis.models.alpro_models import AlproBase
from lavis.models.alpro_models.alpro_outputs import (
    AlproIntermediateOutput,
    AlproOutputWithLogits,
)
from lavis.models.med import XBertEncoder
from lavis.models.timesformer.vit import TimeSformer
from torch import nn


@registry.register_model("alpro_qa")
class AlproQA(AlproBase):
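    """
    ALPRO model for video question answering, cast as classification over a
    fixed answer vocabulary: a TimeSformer encodes the video, an XBert encoder
    encodes the question, and the same text encoder fuses both streams before
    a classifier head predicts the answer.

    Supported model types:
        - msrvtt: configured for MSRVTT-QA.
        - msvd: configured for MSVD-QA.

    Usage (sketch, assuming the standard LAVIS model loader):
        >>> from lavis.models import load_model
        >>> model = load_model("alpro_qa", "msrvtt")
    """
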
    PRETRAINED_MODEL_CONFIG_DICT = {
        "msrvtt": "configs/models/alpro_qa_msrvtt.yaml",
        "msvd": "configs/models/alpro_qa_msvd.yaml",
    }

    def __init__(
        self, visual_encoder, text_encoder, hidden_size, num_classes, max_txt_len=40
    ):
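        """
        Args:
            visual_encoder: video encoder (a TimeSformer) producing patch embeddings.
            text_encoder: BERT-based encoder used both for the question and, in
                "fusion" mode, for cross-modal interaction with the video.
            hidden_size (int): hidden dimension fed to the classifier head.
            num_classes (int): size of the answer vocabulary; if not positive,
                no classifier head is created.
            max_txt_len (int): maximum number of question tokens kept by the tokenizer.
        """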
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = visual_encoder

        self.text_encoder = text_encoder

        if num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 2),
                nn.ReLU(True),
                nn.Linear(hidden_size * 2, num_classes),
            )
        else:
            warn(f"num_classes is 0. Initialized {type(self)} without classifier.")

        self.max_txt_len = max_txt_len

    def forward(self, samples, is_train=True):
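        """
        Compute answer logits (and, in training, the cross-entropy loss) for a batch.

        Args:
            samples (dict): expects
                - "video": video tensor of shape (b, c, t, h, w),
                - "text_input": list of question strings,
                - "answers": tensor of answer-class indices.
            is_train (bool): if True, return an AlproOutputWithLogits including
                the loss; otherwise return a dict with "predictions" and "targets".
        """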
        visual_inputs = samples["video"]
        question = samples["text_input"]
        targets = samples["answers"]

        # forward text
        text = self.tokenizer(
            question,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)

        text_output = self.text_encoder.forward_text(
            text,
            token_type_ids=torch.zeros(
                text.input_ids.shape, dtype=torch.long, device=self.device
            ),
        )
        text_embeds = text_output.last_hidden_state

        # forward visual
        # TimeSformer expects input of shape (b, c, t, h, w).
        video_embeds = self.visual_encoder.forward_features(visual_inputs)
        video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        # forward cross-encoder
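        # Concatenate text and video token sequences (and their attention masks)
        # and pass them through the text encoder in "fusion" mode so the two
        # modalities can attend to each other.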
        attention_mask = torch.cat([text.attention_mask, video_atts], dim=1)
        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_output = self.text_encoder(
            encoder_embeds=embedding_output,
            attention_mask=attention_mask,
            return_dict=True,
            mode="fusion",
        )

        # Classify over the answer vocabulary from the fused [CLS] (first) token.
        prediction = self.classifier(encoder_output.last_hidden_state[:, 0, :])
        if is_train:
            loss = F.cross_entropy(prediction, targets)
            # return {"loss": loss}
            return AlproOutputWithLogits(
                loss=loss,
                intermediate_output=AlproIntermediateOutput(
                    video_embeds=video_embeds,
                    text_embeds=text_embeds,
                    encoder_output=encoder_output,
                ),
                logits=prediction,
            )
        else:
            return {"predictions": prediction, "targets": targets}

    def predict(self, samples):
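        """Inference helper: run forward() in evaluation mode, without computing the loss."""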
        output = self.forward(samples, is_train=False)
        return output

    @classmethod
    def from_config(cls, cfg):
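        """
        Build an AlproQA model from a config node: construct the TimeSformer and
        XBert encoders, then load pretrained weights. num_patches and num_frames
        are forwarded to the checkpoint loader, which uses them to fit the
        pretrained position embeddings to the configured input resolution and
        number of frames.
        """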
        # vision encoder
        visual_encoder_config = node_to_dict(cfg.timesformer)
        visual_encoder = TimeSformer(**visual_encoder_config)

        # text encoder
        text_encoder = XBertEncoder.from_config(cfg)

        num_classes = cfg.get("num_classes", -1)
        hidden_size = cfg.get("hidden_size", 768)

        model = cls(
            visual_encoder=visual_encoder,
            text_encoder=text_encoder,
            hidden_size=hidden_size,
            num_classes=num_classes,
        )

        num_patches = (
            visual_encoder_config["image_size"] // visual_encoder_config["patch_size"]
        ) ** 2
        num_frames = visual_encoder_config["n_frms"]

        model.load_checkpoint_from_config(
            cfg, num_frames=num_frames, num_patches=num_patches
        )

        return model