# Copyright (c) Facebook, Inc. All Rights Reserved
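"""Vector pools: accumulate pooled video/text embeddings produced during
evaluation and build a retrieval index over them, optionally synced across
distributed ranks via the filesystem."""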
import os
import pickle

import numpy as np
import torch

from . import retri
from ..utils import get_local_rank


class VectorPool(object):
    """
    Base class of a retrieval vector space.
    """

    def __init__(self, config):
        from transformers import AutoConfig
        self.hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        self.retriever_cls = getattr(retri, config.retriever_cls)

    def __call__(self, sample, **kwargs):
        raise NotImplementedError
    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=512,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """build a retriever over the pooled vectors and return it."""
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        return self.retriver

    def __repr__(self):
        if hasattr(self, "retriver"):
            retriver_name = str(len(self.retriver))
        else:
            retriver_name = "no retriver field yet"
        return self.__class__.__name__ + "(" + retriver_name + ")"


class VideoVectorPool(VectorPool):
    """
    Average the clips of a video as the video representation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, subsampling, **kwargs):
        # average the video and text embeddings of each clip, then average
        # over the `subsampling` clips of a video to get one vector per video.
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling,
            hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.detach().cpu().numpy()
        video_ids = []
        for video_id in sample["video_id"]:
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id: keep only the id itself.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        self.retriver.add(
            hidden_states.astype("float32"),
            video_ids
        )


class DistributedVectorPool(VectorPool):
    """
    Vector pool with support for syncing across multiple GPUs/nodes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.out_dir = os.path.join(
            config.fairseq.checkpoint.save_dir,
            "retri")
        os.makedirs(self.out_dir, exist_ok=True)
        self.hidden_states = []
        self.video_ids = []

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=4096,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """merge results from multiple gpus and return a retriever: every
        rank saves its shard to disk, waits on a barrier, then loads all
        shards so that each process ends up with the full index."""
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        if torch.distributed.is_initialized():
            self.save()
            # sync saving.
            torch.distributed.barrier()
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        # each gpu process builds its own full retriever from all ranks'
        # shards (save() keys the files by global rank).
        for local_rank in range(world_size):
            if get_local_rank() == 0:
                print("load local_rank", local_rank)
            hidden_states, video_ids = self.load(local_rank)
            hidden_states = hidden_states.astype("float32")
            self.retriver.add(hidden_states, video_ids)
        return self.retriver

    def load(self, local_rank):
        hidden_states = np.load(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"
            )
        )
        with open(
            os.path.join(
                self.out_dir, "video_id" + str(local_rank) + ".pkl"),
                "rb") as fr:
            video_ids = pickle.load(fr)
        return hidden_states, video_ids
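
    # On-disk layout written by save() and read by load() for rank `r`
    # (paths as constructed above; a filesystem shared across ranks is
    # assumed for the multi-node case):
    #   <save_dir>/retri/hidden_state<r>.npy  - stacked float vectors
    #   <save_dir>/retri/video_id<r>.pkl      - pickled list of video ids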

    def save(self):
        hidden_states = np.vstack(self.hidden_states)
        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
            len(hidden_states),
            len(self.video_ids)
        )
        # note: this is the global rank; it keys this process's shard files.
        local_rank = torch.distributed.get_rank() \
            if torch.distributed.is_initialized() else 0
        np.save(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"),
            hidden_states)
        with open(
            os.path.join(
                self.out_dir,
                "video_id" + str(local_rank) + ".pkl"),
                "wb") as fw:
            pickle.dump(
                self.video_ids,
                fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )


class DistributedVideoVectorPool(DistributedVectorPool):
    """
    Average the clips of a video as the video representation (distributed).
    """

    def __call__(self, sample, subsampling, **kwargs):
        # same pooling as VideoVectorPool, but buffer the vectors locally;
        # merging across ranks happens later in build_retriver().
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling,
            hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.detach().cpu().numpy()
        video_ids = []
        for video_id in sample["video_id"]:
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id: keep only the id itself.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        self.hidden_states.append(hidden_states)
        self.video_ids.extend(video_ids)


# ------------ the following are deprecated --------------
class TextClipVectorPool(VectorPool):
    def __init__(self, config):
        from transformers import AutoConfig
        hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        retriever_cls = getattr(retri, config.retriever_cls)
        self.build_retriver(retriever_cls, hidden_size)

    def __call__(self, sample, **kwargs):
        clip_meta = sample["clip_meta"].cpu()
        # columns 4 and 5 appear to be start/end offsets (start <= end).
        assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
        text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]
        if hasattr(self, "retriver"):
            # build_retriver has been called.
            self.retriver.add(
                sample["pooled_text"].cpu().numpy().astype("float32"),
                text_meta
            )
        else:
            raise NotImplementedError


class MMClipVectorPool(VectorPool):
    """
    Multimodal clip-level vector pool.

    `hidden_states` stores `(video, text)` pairs; `video_ids` stores
    `(video_id, (video_start, video_end), (text_start, text_end))` tuples.
    """

    def __init__(self, out_dir):
        super().__init__(out_dir)

    def __call__(self, sample, **kwargs):
        pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy()
        pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy()
        self.hidden_states.append(
            np.concatenate([pooled_video, pooled_text], axis=1)
        )
        video_starts = sample["video_start"].cpu()
        video_ends = sample["video_end"].cpu()
        assert torch.all(torch.le(video_starts, video_ends))
        text_starts = sample["text_start"].cpu()
        text_ends = sample["text_end"].cpu()
        assert torch.all(torch.le(text_starts, text_ends))
        # each video id is repeated once per subsampled clip in the batch.
        subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
        video_ids = [
            video_id
            for video_id in sample["video_id"]
            for _ in range(subsample_size)
        ]
        for video_id, video_start, video_end, text_start, text_end in zip(
                video_ids, video_starts, video_ends, text_starts, text_ends):
            self.video_ids.append((
                video_id,
                (int(video_start), int(video_end)),
                (int(text_start), int(text_end))
            ))
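

# A minimal end-to-end sketch (hypothetical: the dataloader, the exact config
# fields beyond those read above, and `subsampling=8` are assumptions):
#
#   pool = DistributedVideoVectorPool(config)
#   for sample in dataloader:        # sample holds pooled_video, pooled_text,
#       pool(sample, subsampling=8)  # video_id, etc.; vectors buffer locally.
#   retriever = pool.build_retriver()  # save, barrier, load all ranks' shards.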