# unilm/edgelm/examples/MMPT/mmpt/processors/how2retriprocessor.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .how2processor import (
    ShardedHow2MetaProcessor,
    ShardedVideoProcessor,
    ShardedTextProcessor,
    VariedLenAligner,
    OverlappedAligner
)


class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor):
    """Meta processor for video retrieval: groups sharded How2 video ids
    into fixed-size candidate lists, one list per retrieval batch."""

    def __init__(self, config):
        super().__init__(config)
        self.num_video_per_batch = config.num_video_per_batch
        # Slice the video ids into candidate lists of `num_video_per_batch`;
        # the total is truncated to a multiple of 8 * num_video_per_batch so
        # the number of candidate lists divides evenly by 8 (presumably one
        # slice per worker in an 8-GPU run).
        self.cands = [
            self.data[batch_offset:batch_offset + self.num_video_per_batch]
            for batch_offset in range(
                0,
                (len(self.data) // (8 * self.num_video_per_batch))
                * 8 * self.num_video_per_batch,
                self.num_video_per_batch)
        ]
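        # Worked example (hypothetical sizes, for illustration only): with
        # len(self.data) == 1000 and num_video_per_batch == 16, the range
        # stops at (1000 // 128) * 128 == 896, giving 896 // 16 == 56
        # candidate lists of 16 video ids each; the last 104 videos are
        # dropped so the number of lists stays a multiple of 8.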

    def __len__(self):
        return len(self.cands)

    def set_candidates(self, cands):
        # Expected to keep the number of batches unchanged; the assert is
        # left disabled in the original code.
        print(len(self.cands), "->", len(cands))
        # assert len(self.cands) == len(cands)
        self.cands = cands

    def __getitem__(self, idx):
        video_ids = self.cands[idx]
        assert isinstance(video_ids, list)
        sharded_video_idxs = []
        for video_id in video_ids:
            # Look up which shard holds this video and its index in that shard.
            shard_id, video_idx = self.video_id_to_shard[video_id]
            sharded_video_idxs.append((video_id, -1, shard_id, video_idx))
        # Return the same list twice: one copy for the video processor and
        # one for the text processor.
        return sharded_video_idxs, sharded_video_idxs
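    # Illustration of a single retrieval item (hypothetical ids/indices,
    # not taken from real data):
    #   processor[0] -> ([("vid_a", -1, 0, 17), ("vid_b", -1, 3, 2)],
    #                    [("vid_a", -1, 0, 17), ("vid_b", -1, 3, 2)])
    # i.e. one index tuple per candidate video, duplicated for the video
    # and text processors below.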


class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor):
    """In the retrieval case `sharded_video_idxs` is a list of per-candidate
    index tuples (containing the shard id and the video index within that
    shard) rather than a single video id."""

    def __call__(self, sharded_video_idxs):
        assert isinstance(sharded_video_idxs, list)
        # Load the video features of every candidate with the parent
        # (single-video) processor and keep them as a list.
        cand_feats = []
        for sharded_video_idx in sharded_video_idxs:
            feat = super().__call__(sharded_video_idx)
            cand_feats.append(feat)
        return cand_feats


class ShardedVideoRetriTextProcessor(ShardedTextProcessor):
    """In the retrieval case `sharded_video_idxs` is a list of per-candidate
    index tuples (shard id plus video index within the shard) rather than a
    single video id."""

    def __call__(self, sharded_video_idxs):
        assert isinstance(sharded_video_idxs, list)
        # Load the captions of every candidate with the parent
        # (single-video) processor and keep them as a list.
        cand_caps = []
        for sharded_video_idx in sharded_video_idxs:
            caps = super().__call__(sharded_video_idx)
            cand_caps.append(caps)
        return cand_caps
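# Illustration of the shared contract above (assumed types; the exact
# feature/caption structures come from the parent processors in
# how2processor): for K candidate videos the video processor returns a list
# of K feature arrays and the text processor a list of K caption structures,
# kept index-aligned so the aligners below can zip them back together.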


class VideoRetriAligner(VariedLenAligner):
    # Retritask will trim dim-0 of the collated batch.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        from transformers import default_data_collator
        batch, video_ids = [], []
        # Align each candidate separately, then collate the per-candidate
        # sub-batches into one batch with a leading candidate dimension.
        for video_id, video_feature, text_feature in \
                zip(sharded_video_idxs, video_features, text_features):
            sub_batch = super().__call__(video_id, video_feature, text_feature)
            batch.append(sub_batch)
            if isinstance(video_id, tuple):
                # Keep only the plain video id for bookkeeping.
                video_id = video_id[0]
            video_ids.append(video_id)
        batch = default_data_collator(batch)
        batch["video_id"] = video_ids
        return batch


class VideoRetriOverlappedAligner(OverlappedAligner):
    # Same collation as VideoRetriAligner, but built on OverlappedAligner;
    # Retritask will trim dim-0 of the collated batch.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        from transformers import default_data_collator
        batch, video_ids = [], []
        for video_id, video_feature, text_feature in \
                zip(sharded_video_idxs, video_features, text_features):
            sub_batch = super().__call__(video_id, video_feature, text_feature)
            batch.append(sub_batch)
            if isinstance(video_id, tuple):
                video_id = video_id[0]
            video_ids.append(video_id)
        batch = default_data_collator(batch)
        batch["video_id"] = video_ids
        return batch
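

# A minimal, self-contained sketch of the collation performed by the two
# aligners above, assuming torch and transformers are installed. The keys
# and shapes below are invented for illustration; real sub-batches come
# from VariedLenAligner / OverlappedAligner.
if __name__ == "__main__":
    import torch
    from transformers import default_data_collator

    num_candidates, seq_len, feat_dim = 3, 5, 4
    sub_batches = [
        {
            "caps": torch.zeros(seq_len, dtype=torch.long),
            "vfeats": torch.zeros(seq_len, feat_dim),
        }
        for _ in range(num_candidates)
    ]
    collated = default_data_collator(sub_batches)
    # Every field gains a leading candidate dimension (dim-0), which the
    # retrieval task is then responsible for trimming/reshaping.
    print(collated["caps"].shape)    # torch.Size([3, 5])
    print(collated["vfeats"].shape)  # torch.Size([3, 5, 4])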