# Copyright (c) Facebook, Inc. All Rights Reserved

import numpy as np
import os
import torch


class Processor(object):
    """
    A generic processor for video (codec, feature etc.) and text.
    """

    def __call__(self, **kwargs):
        raise NotImplementedError


class MetaProcessor(Processor):
    """
    A meta processor is expected to load the metadata of a dataset
    (e.g., video_ids or captions).
    You must implement `__getitem__` (meta datasets are rather diverse).
    """

    def __init__(self, config):
        self.split = config.split

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        raise NotImplementedError

    def _get_split_path(self, config):
        splits = {
            "train": config.train_path,
            "valid": config.val_path,
            "test": config.test_path,
        }
        if config.split is not None:
            return splits[config.split]
        return config.train_path
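
# A minimal sketch of a MetaProcessor subclass (illustrative only: the json
# layout and the returned (video_id, text_id) pair are assumptions, not part
# of this module):
#
#   class JsonMetaProcessor(MetaProcessor):
#       def __init__(self, config):
#           super().__init__(config)
#           import json
#           with open(self._get_split_path(config)) as fr:
#               self.data = json.load(fr)  # e.g., a list of video ids.
#
#       def __getitem__(self, idx):
#           video_id = self.data[idx]
#           return video_id, video_id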


class TextProcessor(Processor):
    """
    A generic text processor: tokenize a string of text on-the-fly.
    Warning: mostly used for end tasks
    (on-the-fly tokenization is slow for how2).
    TODO(huxu): rename this as `withTokenizer` and move it to a subclass.
    """

    def __init__(self, config):
        self.bert_name = str(config.bert_name)
        self.use_fast = config.use_fast
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, use_fast=self.use_fast
        )

    def __call__(self, text_id):
        caption = self.tokenizer(text_id, add_special_tokens=False)
        return caption["input_ids"]
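
# Usage sketch for TextProcessor (hypothetical config values; only `bert_name`
# and `use_fast` are read in __init__ above):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(bert_name="bert-base-uncased", use_fast=True)
#   text_processor = TextProcessor(config)
#   token_ids = text_processor("a person is slicing an onion")  # list of ids.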


class VideoProcessor(Processor):
    """
    A generic video processor: load numpy video tokens by default.
    """

    def __init__(self, config):
        self.vfeat_dir = config.vfeat_dir

    def __call__(self, video_fn):
        if isinstance(video_fn, tuple):
            video_fn = video_fn[0]
        assert isinstance(video_fn, str)
        video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy")
        feat = np.load(video_fn)
        return feat
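
# Usage sketch for VideoProcessor (the feature directory is a hypothetical
# path; `__call__` simply loads `<vfeat_dir>/<video_id>.npy`):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(vfeat_dir="data/feat/feat_how2_s3d")
#   video_processor = VideoProcessor(config)
#   feat = video_processor("video_id_123")  # np.ndarray of shape (num_tokens, dim).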


class Aligner(object):
    """
    An aligner processor aligns video and text and outputs a dict of tensors (for a model).
    """

    def __init__(self, config):
        """__init__ needs to be lightweight for more workers/threads."""
        self.split = config.split
        self.max_video_len = config.max_video_len
        self.max_len = config.max_len
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            str(config.bert_name), use_fast=config.use_fast
        )
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.mask_token_id = tokenizer.mask_token_id

    def __call__(self, video_id, video_feature, text_feature):
        raise NotImplementedError

    def _build_video_seq(self, video_feature, video_clips=None):
        """
        `video_feature`: available video tokens.
        `video_clips`: video clip sequence to build.
        """
        if not isinstance(video_feature, np.ndarray):
            raise ValueError(
                "unsupported type of video_feature", type(video_feature)
            )

        if video_clips is None:
            # this is borrowed from DSAligner
            video_start = 0
            video_end = min(len(video_feature), self.max_video_len)
            # the whole sequence is a single clip.
            video_clips = {"start": [video_start], "end": [video_end]}

        vfeats = np.zeros(
            (self.max_video_len, video_feature.shape[1]), dtype=np.float32
        )
        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
        video_len = 0
        for start, end in zip(video_clips["start"], video_clips["end"]):
            clip_len = min(self.max_video_len - video_len, (end - start))
            if clip_len > 0:
                vfeats[video_len: video_len + clip_len] = video_feature[
                    start: start + clip_len
                ]
                vmasks[video_len: video_len + clip_len] = 1
                video_len += clip_len
        vfeats = torch.from_numpy(vfeats)

        return vfeats, vmasks

    def _build_text_seq(self, text_feature, text_clip_indexs=None):
        """
        `text_feature`: all available clips.
        `text_clip_indexs`: clip sequence to build.
        """
        if text_clip_indexs is None:
            text_clip_indexs = [0]

        full_caps = []
        if isinstance(text_feature, dict):
            for clip_idx in text_clip_indexs:
                full_caps.extend(text_feature["cap"][clip_idx])
        else:
            full_caps = text_feature
        max_text_len = self.max_len - self.max_video_len - 3
        full_caps = full_caps[:max_text_len]
        full_caps = (
            [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id]
        )
        text_pad_len = self.max_len - len(full_caps) - self.max_video_len
        padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len
        caps = torch.LongTensor(padded_full_caps)
        cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool)
        cmasks[: len(full_caps)] = 1

        return caps, cmasks

    def batch_post_processing(self, batch, video_feature):
        return batch
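
# Sketch of how a subclass typically uses the two builders above (values are
# hypothetical; Aligner itself can be constructed, only __call__ is abstract):
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(split="train", max_video_len=32, max_len=96,
#                            bert_name="bert-base-uncased", use_fast=True)
#   aligner = Aligner(config)
#   video_feature = np.random.randn(100, 512).astype(np.float32)
#   vfeats, vmasks = aligner._build_video_seq(video_feature)  # (32, 512), (32,)
#   caps, cmasks = aligner._build_text_seq([2023, 2003, 1037])
#   # caps/cmasks have length max_len - max_video_len = 64: [CLS], [SEP],
#   # the (truncated) caption token ids, [SEP], then [PAD]s.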


class MMAttentionMask2DProcessor(Processor):
    """Text generation requires a 2D mask,
    which is harder to generate on GPU at this stage."""

    def __call__(self, vmask, cmask, mtype):
        if mtype == "textgen":
            return self._build_textgeneration_mask(vmask, cmask)
        elif mtype == "videogen":
            return self._build_videogeneration_mask(vmask, cmask)
        else:
            return self._build_mm_mask(vmask, cmask)

    def _build_mm_mask(self, vmask, cmask):
        mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
        return mask_1d[None, :].repeat(mask_1d.size(0), 1)

    def _build_videogeneration_mask(self, vmask, cmask):
        # cls_mask is only about text, otherwise it will leak generation.
        cls_text_mask = torch.cat([
            # [CLS]
            torch.ones(
                (1,), dtype=torch.bool, device=cmask.device),
            # video tokens and [SEP] for video.
            torch.zeros(
                (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device),
            cmask[2:]
        ], dim=0)

        # concat horizontally.
        video_len = int(vmask.sum())
        video_masks = torch.cat([
            # [CLS]
            torch.ones(
                (video_len, 1), dtype=torch.bool, device=cmask.device
            ),
            torch.tril(
                torch.ones(
                    (video_len, video_len),
                    dtype=torch.bool, device=cmask.device)),
            # video_padding
            torch.zeros(
                (video_len, vmask.size(0) - video_len),
                dtype=torch.bool, device=cmask.device
            ),
            # [SEP] for video (unused).
            torch.zeros(
                (video_len, 1), dtype=torch.bool, device=cmask.device
            ),
            cmask[2:].unsqueeze(0).repeat(video_len, 1)
        ], dim=1)

        text_masks = cls_text_mask[None, :].repeat(
            cmask.size(0) - 2, 1)
        video_padding_masks = cls_text_mask[None, :].repeat(
            vmask.size(0) - video_len, 1)

        return torch.cat([
            cls_text_mask[None, :],
            video_masks,
            video_padding_masks,
            torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None, :],
            text_masks
        ], dim=0)

    def _build_textgeneration_mask(self, vmask, cmask):
        # cls_mask is only about video, otherwise it will leak generation.
        cls_video_mask = torch.cat([
            # [CLS]
            torch.ones(
                (1,), dtype=torch.bool, device=cmask.device),
            vmask,
            # [SEP]
            torch.ones((1,), dtype=torch.bool, device=cmask.device),
            torch.zeros(
                (cmask.size(0) - 2,), dtype=torch.bool, device=cmask.device)
        ], dim=0)

        # concat horizontally.
        text_len = int(cmask[2:].sum())
        text_masks = torch.cat([
            # [CLS]
            torch.ones(
                (text_len, 1), dtype=torch.bool, device=cmask.device
            ),
            vmask.unsqueeze(0).repeat(text_len, 1),
            # [SEP] for video.
            torch.ones(
                (text_len, 1), dtype=torch.bool, device=cmask.device
            ),
            torch.tril(
                torch.ones(
                    (text_len, text_len),
                    dtype=torch.bool, device=cmask.device)),
            # padding.
            torch.zeros(
                (text_len, cmask.size(0) - text_len - 2),
                dtype=torch.bool, device=cmask.device
            )
        ], dim=1)

        cls_video_masks = cls_video_mask[None, :].repeat(
            vmask.size(0) + 2, 1)
        text_padding_masks = cls_video_mask[None, :].repeat(
            cmask.size(0) - text_len - 2, 1)
        return torch.cat([
            cls_video_masks, text_masks, text_padding_masks], dim=0)
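
# Sketch of building a 2D attention mask from the 1D masks produced by an
# Aligner (shapes are illustrative):
#
#   mask_processor = MMAttentionMask2DProcessor()
#   vmask = torch.ones(32, dtype=torch.bool)  # video token mask.
#   cmask = torch.ones(64, dtype=torch.bool)  # [CLS], [SEP], then text tokens.
#   attn_2d = mask_processor(vmask, cmask, "textgen")
#   # attn_2d.shape == (96, 96): video-side rows attend only to
#   # [CLS] + video + [SEP]; text rows additionally get lower-triangular
#   # (causal) attention over the text positions.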