# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn

try:
    # transformers < 4.0 layout.
    from transformers.modeling_bert import (
        BertEmbeddings,
        ACT2FN,
    )
except ImportError:
    # transformers >= 4.0 moved these; fall back so the module still
    # imports instead of leaving the names undefined (the original bare
    # `pass` would trigger a NameError at class-definition time below).
    from transformers.models.bert.modeling_bert import BertEmbeddings
    from transformers.activations import ACT2FN


class VideoTokenMLP(nn.Module):
    """Projects video features into the text embedding space."""

    def __init__(self, config):
        super().__init__()
        input_dim = config.input_dim if hasattr(config, "input_dim") else 512
        self.linear1 = nn.Linear(input_dim, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.linear2(hidden_states)
        return hidden_states
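

# Usage sketch (illustrative, not from the original file): VideoTokenMLP maps
# per-segment video features of width `config.input_dim` into BERT's hidden
# space so they can later be concatenated with token embeddings. A
# `SimpleNamespace` stands in for the real config object here.
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(input_dim=512, hidden_size=768, hidden_act="gelu")
#     mlp = VideoTokenMLP(config)
#     feats = torch.randn(2, 32, 512)   # (batch, video_len, input_dim)
#     tokens = mlp(feats)               # -> (2, 32, 768)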


class MMBertEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super().__init__(config)
        self.max_video_len = config.max_video_len
        if hasattr(config, "use_seg_emb") and config.use_seg_emb:
            """The original VLM paper uses seg_embeddings for the temporal
            dimension. Although never used in the forward pass, creating
            the module changes the randomness of initialization, so we
            keep it for reproducibility.
            """
            self.seg_embeddings = nn.Embedding(256, config.hidden_size)

    def forward(
        self,
        input_ids,
        input_video_embeds,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        input_tensor = input_ids if input_ids is not None else inputs_embeds
        if input_video_embeds is not None:
            input_shape = (
                input_tensor.size(0),
                input_tensor.size(1) + input_video_embeds.size(1),
            )
        else:
            input_shape = (input_tensor.size(0), input_tensor.size(1))
        if position_ids is None:
""" | |
Auto skip position embeddings for text only case. | |
use cases: | |
(1) action localization and segmentation: | |
feed in len-1 dummy video token needs text part to | |
skip input_video_embeds.size(1) for the right | |
position_ids for video [SEP] and rest text tokens. | |
(2) MMFusionShare for two forward passings: | |
in `forward_text`: input_video_embeds is None. | |
need to skip video [SEP] token. | |
# video_len + 1: [CLS] + video_embed | |
# self.max_video_len + 1: [SEP] for video. | |
# self.max_video_len + 2: [SEP] for video. | |
# self.max_video_len + input_ids.size(1): rest for text. | |
""" | |
            if input_video_embeds is not None:
                video_len = input_video_embeds.size(1)
                starting_offset = self.max_video_len + 1  # video [SEP]
                ending_offset = self.max_video_len + input_ids.size(1)
            else:
                video_len = 0
                starting_offset = self.max_video_len + 2  # first text token.
                ending_offset = self.max_video_len + input_ids.size(1) + 1
            position_ids = torch.cat([
                self.position_ids[:, :video_len + 1],
                self.position_ids[:, starting_offset:ending_offset]
            ], dim=1)

        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=self.position_ids.device
            )
""" | |
the format of input_ids is [CLS] [SEP] caption [SEP] padding. | |
the goal is to build [CLS] video tokens [SEP] caption [SEP] . | |
""" | |
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if input_video_embeds is not None:
            inputs_mm_embeds = torch.cat([
                inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
            ], dim=1)
        else:
            # text only for `MMFusionShare`.
            inputs_mm_embeds = inputs_embeds

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_mm_embeds + position_embeddings
        embeddings += token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
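

# Worked example (illustrative, with assumed sizes): with max_video_len = 32,
# a clip of video_len = 4 features, and input_ids = [CLS] [SEP] w1 w2 w3 [SEP]
# (length 6), the embedded sequence is [CLS] v1..v4 [SEP] w1 w2 w3 [SEP], and
# the position_ids built in forward() are [0, 1, 2, 3, 4] + [33, 34, 35, 36, 37]:
# [CLS] and the video occupy positions 0 .. video_len, while the video [SEP]
# and the text always start at max_video_len + 1, independent of clip length.
#
#     from transformers import BertConfig
#     config = BertConfig()                 # assumed stand-in config
#     config.max_video_len = 32
#     emb = MMBertEmbeddings(config)
#     # video_tokens: VideoTokenMLP output, shape (batch, 4, hidden_size)
#     out = emb(input_ids, video_tokens)    # -> (batch, 10, hidden_size)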


class AlignHead(nn.Module):
    """Video-text alignment head. It mirrors BERT's NSP head (the
    `seq_relationship` linear layer), so pre-trained NSP weights can be
    loaded into it, which is desirable.
    """

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, dropout_pooled_output):
        logits = self.seq_relationship(dropout_pooled_output)
        return logits
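

# Usage sketch (illustrative): AlignHead scores whether a (video, text) pair
# is aligned, in the same way BERT's next-sentence-prediction head scores
# sentence pairs; reusing the `seq_relationship` name is what lets pre-trained
# NSP checkpoint weights initialize it.
#
#     head = AlignHead(config)
#     pooled = torch.randn(2, config.hidden_size)   # dropout(pooled [CLS] output)
#     logits = head(pooled)                         # -> (2, 2)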