"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from lavis.common.registry import registry
from lavis.models.base_model import BaseModel


@registry.register_model("gpt_dialogue")
class GPTDialogue(BaseModel, GPT2LMHeadModel):
PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/gpt_dialogue_base.yaml"}

    def __init__(self, config, len_video_ft=4224):
        super().__init__(config)

        # Project raw video features into the GPT-2 embedding space on the
        # way in, and back out for the video-feature regression loss.
        self.video_ff = nn.Linear(len_video_ft, config.n_embd)
        self.video_ff_out = nn.Linear(config.n_embd, len_video_ft)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
self,
samples,
past_key_values=None,
position_ids=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
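        """Forward pass over video features prepended to dialogue tokens.

        ``samples`` is expected to carry ``input_ids``, ``video_fts``,
        ``attn_mask``, ``token_type_ids`` and ``labels``; the keys are
        inferred from their usage below rather than documented upstream.
        """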
        # Embed the text tokens, project the video features to the embedding
        # width, and prepend them so the video acts as a prefix to the dialogue.
        input_embs = self.transformer.wte(samples["input_ids"])
        video_embs = self.video_ff(samples["video_fts"])
        input_embs = torch.cat([video_embs, input_embs], dim=1)
transformer_outputs = self.transformer(
attention_mask=samples["attn_mask"],
token_type_ids=samples["token_type_ids"],
inputs_embeds=input_embs,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
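        # The hidden states cover the video prefix followed by the text
        # tokens; the video loss below slices the prefix positions back out.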
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
if samples["labels"] is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = samples["labels"][..., 1:].contiguous()
# Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-1)  # labels of -1 are excluded
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)

        if samples["video_fts"] is not None:
            len_video_fts = samples["video_fts"].shape[1]
            video_logits = self.video_ff_out(hidden_states[:, :len_video_fts, :])
            # Shift so that features at steps < n predict the feature at step n
            shift_logits = video_logits[..., :-1, :].contiguous()
            shift_labels = samples["video_fts"][..., 1:, :].contiguous()
            # MSE between predicted and ground-truth next-step video features
            loss_fct = MSELoss(reduction="mean")
            video_loss = loss_fct(shift_logits, shift_labels)

            if loss is not None:
                loss = loss + video_loss
            else:
                loss = video_loss
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)

    @classmethod
    def from_config(cls, cfg):
        # cls.__bases__[1] is GPT2LMHeadModel: load pretrained GPT-2 weights,
        # then resize the embedding table to the dialogue tokenizer's size.
        model = cls.__bases__[1].from_pretrained("gpt2")
        model.resize_token_embeddings(cfg["len_tokenizer"])
        return model
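

# Minimal shape sketch: a hedged illustration of the tensor contract that
# forward() assumes. All sizes below are assumptions chosen for illustration,
# not values from the LAVIS training pipeline. Video features are projected
# to n_embd and prepended to the token embeddings, so attn_mask,
# token_type_ids and labels must all cover n_video + n_text positions, and
# labels set to -1 are ignored by CrossEntropyLoss(ignore_index=-1) above.
if __name__ == "__main__":
    n_video, n_text, n_embd, len_video_ft = 8, 12, 768, 4224

    video_fts = torch.randn(1, n_video, len_video_ft)  # raw video features
    token_embs = torch.randn(1, n_text, n_embd)  # stand-in for transformer.wte

    video_ff = nn.Linear(len_video_ft, n_embd)  # mirrors self.video_ff
    input_embs = torch.cat([video_ff(video_fts), token_embs], dim=1)
    assert input_embs.shape == (1, n_video + n_text, n_embd)

    # Masks and type ids are passed straight through to the transformer, so
    # they must match the concatenated sequence length.
    attn_mask = torch.ones(1, n_video + n_text, dtype=torch.long)
    token_type_ids = torch.zeros(1, n_video + n_text, dtype=torch.long)

    # Labels span the full sequence; video prefix positions get -1 so only
    # the dialogue tokens contribute to the language-modeling loss.
    labels = torch.full((1, n_video + n_text), -1, dtype=torch.long)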