from typing import List, Optional, Literal
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset as TorchConcatDataset
from torch.utils.data import Subset as TorchSubset
from ..root import DATASETS

@DATASETS.register_module()
class ConcatDataset(Dataset):
    _repr_indent = 4

    def __init__(self, cfgs):
        self.cfgs = cfgs
        datasets = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(datasets)

    def __len__(self):
        return len(self.concat_dataset)

    def __getitem__(self, index):
        return self.concat_dataset[index]

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
        ]
        for i, ds in enumerate(self.concat_dataset.datasets):
            body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}")
            body += ds.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)
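
# Usage sketch (illustrative only): the cfg dicts below are hypothetical and depend on
# which dataset classes are registered in DATASETS; this assumes an mmcv/mmengine-style
# registry where the 'type' key selects the class and the remaining keys become kwargs.
#
#   concat = DATASETS.build(dict(
#       type='ConcatDataset',
#       cfgs=[
#           dict(type='DatasetA'),   # hypothetical registered dataset
#           dict(type='DatasetB'),   # hypothetical registered dataset
#       ],
#   ))
#   len(concat)   # total datapoints across both subsets
#   concat[0]     # first sample of DatasetA (subsets are laid out back-to-back)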

@DATASETS.register_module()
class InterleaveDateset(Dataset):
    _repr_indent = 4

    def __init__(
        self,
        cfgs,
        probabilities: Optional[List[float]] = None,
        seed: Optional[int] = 42,
        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
    ):
        self.cfgs = cfgs
        self.probabilities = probabilities
        self.seed = seed
        self.stopping_strategy = stopping_strategy

        datasets = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(datasets)
        self.index_mapping = _interleave_dataset_index(
            lengths=[len(ds) for ds in datasets],
            probabilities=probabilities,
            seed=seed,
            stopping_strategy=stopping_strategy,
        )

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, index):
        return self.concat_dataset[self.index_mapping[index]]

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
            f"Probabilities: {self.probabilities}",
            f"stopping_strategy: {self.stopping_strategy}",
            f"seed: {self.seed}",
        ]
        for i, ds in enumerate(self.concat_dataset.datasets):
            body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}")
            body += ds.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)
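
# Usage sketch (illustrative only; the inner cfg dicts are hypothetical): interleave two
# datasets so that roughly 70% of the drawn samples come from the first one. With
# 'first_exhausted' the index mapping stops as soon as either dataset runs out of samples.
#
#   interleaved = DATASETS.build(dict(
#       type='InterleaveDateset',
#       cfgs=[dict(type='DatasetA'), dict(type='DatasetB')],
#       probabilities=[0.7, 0.3],
#       seed=42,
#       stopping_strategy='first_exhausted',
#   ))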

# stolen from huggingface/datasets
# https://github.com/huggingface/datasets/blob/074925b9b7c1dfd33b8675aa99c07cc26375665c/src/datasets/arrow_dataset.py#L5987
def _interleave_dataset_index(
    *,
    lengths: List[int],
    probabilities: Optional[List[float]] = None,
    seed: Optional[int] = None,
    stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
    if probabilities is not None and 0 in probabilities:
        assert stopping_strategy == 'first_exhausted', \
            "a zero probability with stopping_strategy='all_exhausted' never exhausts that dataset and would loop forever"

    # Let's now build the indices to pass to .select()
    offsets = np.cumsum([0] + lengths[:-1])

    # "first_exhausted" is an undersampling situation, whereas "all_exhausted" is an oversampling situation
    oversampling = stopping_strategy == "all_exhausted"

    if probabilities is None and not oversampling:
        # Undersampling situation, cycling between the sources
        # Example: if the lengths of the datasets are [3, 4, 5],
        # the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9].
        # Note that we only keep 3 examples per dataset, since the first dataset runs out of examples first.
        # Reasoning behind the following operation: keep the min_length first indices of each dataset,
        # offset them so they correspond to the right positions in the concatenated dataset,
        # and flatten to effectively interleave the datasets
        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
    elif probabilities is None:
        # Oversampling situation, cycling between the sources
        # Example: with the same lengths [3, 4, 5],
        # the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11].
        # Note that we keep 5 examples per dataset with a rolling window, since the longest dataset has 5 samples.
        # Reasoning behind the following operation: for each dataset's indices (i.e. each column), repeat the indices
        # so that there are max_length indices per dataset.
        # For example, if max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0, 1, 2, 0, 1].
        indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))
        # Shift the indices by their respective dataset offsets and flatten to effectively interleave the datasets
        indices = (indices + offsets).flatten().tolist()
    else:
        # boolean array indicating whether dataset i has been fully exhausted
        is_exhausted = np.full(len(lengths), False)

        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted,
        # i.e. as soon as every sample of every dataset has been visited at least once
        bool_strategy_func = np.all if oversampling else np.any

        def iter_random_indices():
            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
            rng = np.random.default_rng(seed)
            while True:
                yield from (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities))

        current_index = [0] * len(lengths)
        indices = []
        for source_idx in iter_random_indices():
            # If no oversampling, we stop as soon as one dataset has run out of examples (np.any)
            # Otherwise, we stop as soon as every dataset has run out of examples (np.all)
            if bool_strategy_func(is_exhausted):
                # the stopping condition was reached, let's stop
                break

            # add the example at the current index of the `source_idx`-th dataset
            indices.append(current_index[source_idx] + offsets[source_idx])
            current_index[source_idx] += 1

            # when we've run out of examples for the current dataset, update the boolean array
            # and bring current_index back to 0
            if current_index[source_idx] >= lengths[source_idx]:
                is_exhausted[source_idx] = True
                current_index[source_idx] = 0
    return indices
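
# Worked example (values follow from the comments above, computed by hand for
# lengths=[3, 4, 5], i.e. offsets=[0, 3, 7]):
#
#   _interleave_dataset_index(lengths=[3, 4, 5])
#   -> [0, 3, 7, 1, 4, 8, 2, 5, 9]                      # stop once the shortest dataset is exhausted
#
#   _interleave_dataset_index(lengths=[3, 4, 5], stopping_strategy='all_exhausted')
#   -> [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]  # shorter datasets wrap around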

@DATASETS.register_module()
class SubSet(TorchSubset):
    def __init__(self, cfg, portion, do_shuffle=True, seed=42):
        assert 0 < portion <= 1
        dataset = DATASETS.build(cfg=cfg)
        target_len = int(len(dataset) * portion)
        if do_shuffle:
            rng = np.random.default_rng(seed)
            indices = list(range(len(dataset)))
            rng.shuffle(indices)
            indices = indices[:target_len]
        else:
            indices = list(range(target_len))
        super().__init__(dataset, indices)
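
# Usage sketch (illustrative only; the inner cfg dict is hypothetical): keep a
# reproducible random 10% of a dataset.
#
#   subset = DATASETS.build(dict(
#       type='SubSet',
#       cfg=dict(type='DatasetA'),
#       portion=0.1,
#       do_shuffle=True,
#       seed=42,
#   ))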

@DATASETS.register_module()
class ConcatDatasetWithShuffle(TorchSubset):
    _repr_indent = 4

    def __init__(self, cfgs, seed=42, portion=1):
        self.cfgs = cfgs
        self.seed = seed
        self.portion = portion

        dataset = TorchConcatDataset([DATASETS.build(cfg) for cfg in cfgs])
        target_len = int(len(dataset) * portion)
        indices = list(range(len(dataset))) * int(np.ceil(portion))
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
        indices = indices[:target_len]
        super().__init__(dataset, indices)
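
# Note on `portion` here (unlike SubSet, it may exceed 1): the index list is replicated
# ceil(portion) times before shuffling and truncating, so values above 1 oversample the
# concatenated data and individual indices may repeat.
#
#   # Usage sketch (illustrative only; cfg dicts are hypothetical):
#   shuffled = DATASETS.build(dict(
#       type='ConcatDatasetWithShuffle',
#       cfgs=[dict(type='DatasetA'), dict(type='DatasetB')],
#       portion=1.5,
#       seed=42,
#   ))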
__all__ = ['ConcatDataset', 'InterleaveDateset', 'SubSet', 'ConcatDatasetWithShuffle']