|
from typing import List, Optional, Literal |
|
|
|
import numpy as np |
|
from torch.utils.data import Dataset |
|
from torch.utils.data import ConcatDataset as TorchConcatDataset |
|
from torch.utils.data import Subset as TorchSubset |
|
|
|
from ..root import DATASETS |
|
|
|
|
|
@DATASETS.register_module()
class ConcatDataset(Dataset):
    """Concatenation of several registry-built datasets.

    Each entry of ``cfgs`` is built via ``DATASETS.build`` and the results
    are chained with ``torch.utils.data.ConcatDataset`` so they can be
    indexed as one contiguous dataset.
    """

    # Indent width (spaces) used when rendering nested dataset reprs.
    _repr_indent = 4

    def __init__(self, cfgs):
        self.cfgs = cfgs
        built = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(built)

    def __len__(self):
        return len(self.concat_dataset)

    def __getitem__(self, index):
        return self.concat_dataset[index]

    def __repr__(self) -> str:
        title = "Dataset " + self.__class__.__name__
        details = [f"Number of datapoints: {len(self)}"]
        subsets = self.concat_dataset.datasets
        total = len(subsets)
        for position, subset in enumerate(subsets, start=1):
            details.append(f"Subset {position}/{total}")
            details.extend(repr(subset).splitlines())
        pad = " " * self._repr_indent
        return "\n".join([title] + [pad + line for line in details])
|
|
|
|
|
@DATASETS.register_module()
class InterleaveDateset(Dataset):
    """Deterministic interleaving of several registry-built datasets.

    A flat index mapping is computed once at construction (see
    ``_interleave_dataset_index``), so ``__getitem__`` is a cheap table
    lookup into the underlying concatenated dataset.
    """

    # Indent width (spaces) used when rendering nested dataset reprs.
    _repr_indent = 4

    def __init__(
        self,
        cfgs,
        probabilities: Optional[List[float]] = None,
        seed: Optional[int] = 42,
        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
    ):
        self.cfgs = cfgs
        self.probabilities = probabilities
        self.seed = seed
        self.stopping_strategy = stopping_strategy

        built = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(built)

        # Precompute the interleaved order over the concatenated index space.
        self.index_mapping = _interleave_dataset_index(
            lengths=[len(ds) for ds in built],
            probabilities=probabilities,
            seed=seed,
            stopping_strategy=stopping_strategy,
        )

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, index):
        return self.concat_dataset[self.index_mapping[index]]

    def __repr__(self) -> str:
        title = "Dataset " + self.__class__.__name__
        details = [
            f"Number of datapoints: {len(self)}",
            f"Probabilities: {self.probabilities}",
            f"stopping_strategy: {self.stopping_strategy}",
            f"seed: {self.seed}",
        ]
        subsets = self.concat_dataset.datasets
        total = len(subsets)
        for position, subset in enumerate(subsets, start=1):
            details.append(f"Subset {position}/{total}")
            details.extend(repr(subset).splitlines())
        pad = " " * self._repr_indent
        return "\n".join([title] + [pad + line for line in details])
|
|
|
|
|
|
|
|
|
def _interleave_dataset_index( |
|
*, |
|
lengths: List[int], |
|
probabilities: Optional[List[float]] = None, |
|
seed: Optional[int] = None, |
|
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", |
|
): |
|
if probabilities is not None and 0 in probabilities: |
|
assert stopping_strategy == 'first_exhausted', "you will meet a Infinite loop" |
|
|
|
offsets = np.cumsum([0] + lengths[:-1]) |
|
|
|
|
|
oversampling = stopping_strategy == "all_exhausted" |
|
|
|
if probabilities is None and not oversampling: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() |
|
elif probabilities is None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1)) |
|
|
|
|
|
indices = (indices + offsets).flatten().tolist() |
|
|
|
else: |
|
|
|
is_exhausted = np.full(len(lengths), False) |
|
|
|
|
|
|
|
bool_strategy_func = np.all if oversampling else np.any |
|
|
|
def iter_random_indices(): |
|
"""Get an infinite iterator that randomly samples the index of the source to pick examples from.""" |
|
rng = np.random.default_rng(seed) |
|
while True: |
|
yield from (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities)) |
|
|
|
current_index = [0] * len(lengths) |
|
indices = [] |
|
for source_idx in iter_random_indices(): |
|
|
|
|
|
if bool_strategy_func(is_exhausted): |
|
|
|
break |
|
|
|
|
|
indices.append(current_index[source_idx] + offsets[source_idx]) |
|
current_index[source_idx] += 1 |
|
|
|
|
|
if current_index[source_idx] >= lengths[source_idx]: |
|
is_exhausted[source_idx] = True |
|
current_index[source_idx] = 0 |
|
return indices |
|
|
|
|
|
@DATASETS.register_module()
class SubSet(TorchSubset):
    """A (possibly shuffled) fraction of a single registry-built dataset.

    Keeps ``int(len(dataset) * portion)`` samples: a random selection when
    ``do_shuffle`` is true, otherwise the leading prefix.
    """

    def __init__(self, cfg, portion, do_shuffle=True, seed=42):
        assert 0 < portion <= 1
        base = DATASETS.build(cfg=cfg)
        keep = int(len(base) * portion)
        if do_shuffle:
            # Shuffle the full index range, then truncate.
            rng = np.random.default_rng(seed)
            order = list(range(len(base)))
            rng.shuffle(order)
            selected = order[:keep]
        else:
            selected = list(range(keep))
        super().__init__(base, selected)
|
|
|
|
|
@DATASETS.register_module()
class ConcatDatasetWithShuffle(TorchSubset):
    """Shuffled concatenation of several registry-built datasets.

    With ``portion < 1`` a random fraction of the concatenation is kept;
    with ``portion > 1`` the index pool is repeated ``ceil(portion)``
    times before truncation, i.e. samples may appear more than once.
    """

    # Indent width (spaces) used when rendering nested dataset reprs.
    _repr_indent = 4

    def __init__(self, cfgs, seed=42, portion=1):
        self.cfgs = cfgs
        self.seed = seed
        self.portion = portion
        base = TorchConcatDataset([DATASETS.build(cfg) for cfg in cfgs])

        keep = int(len(base) * portion)
        # Replicate the index range enough times to cover portion > 1.
        pool = list(range(len(base))) * int(np.ceil(portion))
        rng = np.random.default_rng(seed)
        rng.shuffle(pool)
        super().__init__(base, pool[:keep])
|
|
|
|
|
# Public API: dataset-composition wrappers registered in DATASETS.
__all__ = ['ConcatDataset', 'InterleaveDateset', 'SubSet', 'ConcatDatasetWithShuffle']
|
|