""" | |
Copyright (c) 2022, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import torch
import torch.nn as nn
from lavis.common.registry import registry
from lavis.models.base_model import BaseModel
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


@registry.register_model("gpt_dialogue")
class GPTDialogue(BaseModel, GPT2LMHeadModel):
    """GPT-2 based video-grounded dialogue model.

    Video features are projected into the GPT-2 embedding space and prepended
    to the token embeddings; training combines a language modeling loss with
    a video feature regression loss.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/gpt_dialogue_base.yaml"}

    def __init__(self, config, len_video_ft=4224):
        super().__init__(config)

        # Project raw video features into (and back out of) the hidden space.
        self.video_ff = nn.Linear(len_video_ft, config.n_embd)
        self.video_ff_out = nn.Linear(config.n_embd, len_video_ft)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        samples,
        past_key_values=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Embed the dialogue tokens, project the video features, and prepend
        # them so the sequence reads [video features | dialogue tokens].
        input_embs = self.transformer.wte(samples["input_ids"])
        video_embs = self.video_ff(samples["video_fts"])
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        transformer_outputs = self.transformer(
            attention_mask=samples["attn_mask"],
            token_type_ids=samples["token_type_ids"],
            inputs_embeds=input_embs,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if samples["labels"] is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = samples["labels"][..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
if samples["video_fts"] is not None: | |
len_video_fts = samples["video_fts"].shape[1] | |
video_logits = self.video_ff_out(hidden_states[:, :len_video_fts, :]) | |
# Shift so that tokens < n predict n | |
shift_logits = video_logits[..., :-1, :].contiguous() | |
shift_labels = samples["video_fts"][..., 1:, :].contiguous() | |
# Flatten the tokens | |
loss_fct = MSELoss(reduction="mean") | |
video_loss = loss_fct(shift_logits, shift_labels) | |
if loss is not None: | |
loss = loss + video_loss | |
else: | |
loss = video_loss | |
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @classmethod
    def from_config(cls, cfg):
        # Load GPT-2 weights through the GPT2LMHeadModel base class, then
        # resize the embedding table to match the dialogue tokenizer.
        model = cls.__bases__[1].from_pretrained("gpt2")
        model.resize_token_embeddings(cfg["len_tokenizer"])
        return model
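

# ---------------------------------------------------------------------------
# Minimal self-contained sketch: it rebuilds the two shifted losses from
# forward() on dummy tensors to document the `samples` schema the model
# consumes. All sizes below are illustrative assumptions, not values taken
# from the LAVIS configs or datasets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    batch, n_video, n_text, vocab, len_video_ft = 2, 4, 8, 50257, 4224
    total = n_video + n_text  # video features are prepended to text tokens

    # Masks, token types, and labels must span the concatenated sequence.
    samples = {
        "input_ids": torch.randint(0, vocab, (batch, n_text)),
        "video_fts": torch.randn(batch, n_video, len_video_ft),
        "attn_mask": torch.ones(batch, total, dtype=torch.long),
        "token_type_ids": torch.zeros(batch, total, dtype=torch.long),
        "labels": torch.randint(0, vocab, (batch, total)),
    }
    samples["labels"][:, :n_video] = -1  # LM loss ignores video positions

    # Shifted LM loss as in forward(): position t predicts token t + 1
    # (dummy logits stand in for the lm_head output).
    lm_logits = torch.randn(batch, total, vocab)
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = samples["labels"][..., 1:].contiguous()
    lm_loss = CrossEntropyLoss(ignore_index=-1)(
        shift_logits.view(-1, vocab), shift_labels.view(-1)
    )

    # Shifted video regression loss: the output at video position t (dummy
    # stand-in for video_ff_out) regresses the feature at position t + 1.
    video_logits = torch.randn(batch, n_video, len_video_ft)
    video_loss = MSELoss(reduction="mean")(
        video_logits[..., :-1, :].contiguous(),
        samples["video_fts"][..., 1:, :].contiguous(),
    )

    print("total loss:", (lm_loss + video_loss).item())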