# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import numpy as np

from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators


class MultiItr(object):
    """Interleave several sized iterators, balancing progress across them."""

    def __init__(self, itr):
        self.itr = itr
        self._counts = [0 for x in itr]

    def __len__(self):
        return sum(len(itr) for itr in self.itr)

    def __iter__(self):
        return self

    def __next__(self):
        # Advance the sub-iterator whose consumed fraction is currently the
        # smallest, so that all sub-iterators finish at roughly the same time.
        ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
        idx = ratios.index(min(ratios))
        self._counts[idx] += 1
        return next(self.itr[idx])
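
# Illustrative sketch (not part of the original module): a small, self-contained
# demonstration of how ``MultiItr`` schedules draws. The ``_SizedIter`` helper is
# hypothetical and exists only for this demo; real callers pass the epoch
# iterators produced by ``EpochBatchIterator.next_epoch_itr``.
def _demo_multi_itr():
    class _SizedIter:
        """Minimal iterator with a length, standing in for an epoch iterator."""

        def __init__(self, items):
            self._items = list(items)
            self._pos = 0

        def __len__(self):
            return len(self._items)

        def __iter__(self):
            return self

        def __next__(self):
            if self._pos >= len(self._items):
                raise StopIteration
            item = self._items[self._pos]
            self._pos += 1
            return item

    # Draws are interleaved in proportion to the iterator lengths, e.g.
    # ['a', 'X', 'b', 'c', 'Y', 'd'] for lengths 4 and 2.
    return list(MultiItr([_SizedIter("abcd"), _SizedIter("XY")]))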

class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
    """A wrapper around multiple epoch batch iterators."""

    def __init__(
        self,
        dataset,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
    ):
        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        self.iterators = []
        self.epoch = epoch
        for key, dt in dataset.items():
            epoch_iter = iterators.EpochBatchIterator(
                dataset=dt,
                collate_fn=dt.collater,
                batch_sampler=batch_sampler[key],
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                # NOTE: sub-iterators are created single-threaded regardless of
                # the ``num_workers`` argument accepted above.
                num_workers=0,
                epoch=epoch,
            )
            self.iterators.append(epoch_iter)

    def __len__(self):
        return sum(len(itr) for itr in self.iterators)

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
        return MultiItr(
            [
                itr.next_epoch_itr(
                    shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
                )
                for itr in self.iterators
            ]
        )

    def end_of_epoch(self):
        return all(itr.end_of_epoch() for itr in self.iterators)
    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""
        epochs = [itr.next_epoch_idx for itr in self.iterators]
        self.epoch = epochs[0]
        assert all(epoch == self.epoch for epoch in epochs)
        return self.epoch

    @property
    def iterations_in_epoch(self):
        return sum(itr.iterations_in_epoch for itr in self.iterators)
    def state_dict(self):
        return {
            "iterators": [it.state_dict() for it in self.iterators],
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        for it, d in zip(self.iterators, state_dict["iterators"]):
            it.load_state_dict(d)


class MultitaskDatasetWrapper(BaseWrapperDataset):
    """A wrapper for a multitask dataset."""

    def __init__(self, dataset, target_language_id, sample=1.0, name=""):
        super().__init__(dataset)
        self.target_language_id = target_language_id
        self.sample = sample
        self.name = name

    def collater(self, *args, **kwargs):
        ans = self.dataset.collater(*args, **kwargs)
        if "net_input" in ans:
            ans["net_input"]["target_language_id"] = self.target_language_id
            ans["net_input"]["dataset_name"] = self.name
        return ans

    def num_tokens(self, *args, **kwargs):
        return self.dataset.num_tokens(*args, **kwargs)
    def ordered_indices(self, *args, **kwargs):
        indices = self.dataset.ordered_indices(*args, **kwargs)
        # Hacky solution for sampling: keep a random `sample` fraction of the
        # indices while preserving their original ordering.
        size = int(self.sample * indices.shape[0])
        return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))

    def size(self, index: int):
        return self.dataset.size(index)
    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
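
if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original module): wrap two toy
    # datasets, one per task, and draw batches that interleave them. The
    # ``_ToyDataset`` class, the task names, the ``src_tokens`` key, and the
    # single-batch samplers below are assumptions made purely for illustration;
    # real training builds proper datasets and batch samplers inside a task.
    class _ToyDataset(FairseqDataset):
        def __init__(self, values):
            self.values = values

        def __getitem__(self, index):
            return self.values[index]

        def __len__(self):
            return len(self.values)

        def collater(self, samples):
            return {"net_input": {"src_tokens": np.array(samples)}}

        def num_tokens(self, index):
            return 1

        def size(self, index):
            return 1

        def ordered_indices(self):
            return np.arange(len(self.values))

    datasets = OrderedDict(
        (
            name,
            MultitaskDatasetWrapper(
                _ToyDataset(vals), target_language_id=lang_id, name=name
            ),
        )
        for lang_id, (name, vals) in enumerate(
            [("task_a", [0, 1, 2, 3]), ("task_b", [10, 11])]
        )
    )
    # One index-batch per task, keyed like ``datasets``.
    batch_samplers = {
        name: [list(ds.ordered_indices())] for name, ds in datasets.items()
    }

    epoch_itr = MultidatasetEpochBatchIterator(datasets, batch_samplers)
    for batch in epoch_itr.next_epoch_itr(shuffle=False):
        # Each batch carries the task metadata injected by MultitaskDatasetWrapper.
        print(batch["net_input"]["dataset_name"], batch["net_input"]["src_tokens"])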