# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
TODO (huxu): fairseq wrapper class for all datasets you define, mostly MMDataset.
"""

from collections import OrderedDict

from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from fairseq.data import FairseqDataset, data_utils


class FairseqMMDataset(FairseqDataset):
    """
    A fairseq wrapper class for MMDataset.
    """

    def __init__(self, mmdataset):
        if not isinstance(mmdataset, Dataset):
            raise TypeError("mmdataset must be of type `torch.utils.data.Dataset`.")
        self.mmdataset = mmdataset
        # default epoch so `__getitem__` works even before `set_epoch` is called.
        self.epoch = 0

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, idx):
        # seed numpy deterministically per (epoch, index) so any random
        # sampling inside the wrapped dataset is reproducible.
        with data_utils.numpy_seed(43211, self.epoch, idx):
            return self.mmdataset[idx]

    def __len__(self):
        return len(self.mmdataset)

    def collater(self, samples):
        # prefer the dataset's own collator when it defines one.
        if hasattr(self.mmdataset, "collator"):
            return self.mmdataset.collator(samples)
        if len(samples) == 0:
            return {}
        if isinstance(samples[0], dict):
            # collate dict samples key by key, skipping keys whose value is None.
            batch = OrderedDict()
            for key in samples[0]:
                if samples[0][key] is not None:
                    batch[key] = default_collate([sample[key] for sample in samples])
            return batch
        else:
            return default_collate(samples)

    def size(self, index):
        """dummy implementation: we don't use --max-tokens"""
        return 1

    def num_tokens(self, index):
        """dummy implementation: we don't use --max-tokens"""
        return 1
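

# The example below is an illustrative usage sketch, not part of the original
# module: `ToyMMDataset` is a hypothetical stand-in for MMDataset, used only
# to show how `FairseqMMDataset` delegates indexing and falls back to per-key
# `default_collate` when the wrapped dataset has no `collator`.
if __name__ == "__main__":
    import torch

    class ToyMMDataset(Dataset):
        """Hypothetical toy dataset returning dict samples."""

        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {"caps": torch.tensor([idx, idx + 1]), "video_len": idx}

    fs_dataset = FairseqMMDataset(ToyMMDataset())
    fs_dataset.set_epoch(0)
    samples = [fs_dataset[i] for i in range(len(fs_dataset))]
    batch = fs_dataset.collater(samples)
    # expected: {'caps': (4, 2), 'video_len': (4,)}
    print({k: tuple(v.shape) for k, v in batch.items()})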