from typing import List, Optional, Literal

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset as TorchConcatDataset
from torch.utils.data import Subset as TorchSubset

from ..root import DATASETS


@DATASETS.register_module()
class ConcatDataset(Dataset):
    _repr_indent = 4

    def __init__(self, cfgs):
        self.cfgs = cfgs
        datasets = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(datasets)

    def __len__(self):
        return len(self.concat_dataset)

    def __getitem__(self, index):
        return self.concat_dataset[index]

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
        ]
        for i, ds in enumerate(self.concat_dataset.datasets):
            body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}")
            body += ds.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)


@DATASETS.register_module()
class InterleaveDataset(Dataset):
    _repr_indent = 4

    def __init__(
            self,
            cfgs,
            probabilities: Optional[List[float]] = None,
            seed: Optional[int] = 42,
            stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
    ):
        self.cfgs = cfgs
        self.probabilities = probabilities
        self.seed = seed
        self.stopping_strategy = stopping_strategy

        datasets = [DATASETS.build(cfg) for cfg in cfgs]
        self.concat_dataset = TorchConcatDataset(datasets)

        self.index_mapping = _interleave_dataset_index(
            lengths=[len(ds) for ds in datasets],
            probabilities=probabilities,
            seed=seed,
            stopping_strategy=stopping_strategy,
        )

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, index):
        return self.concat_dataset[self.index_mapping[index]]

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
            f"Probabilities: {self.probabilities}",
            f"stopping_strategy: {self.stopping_strategy}",
            f"seed: {self.seed}",
        ]
        for i, ds in enumerate(self.concat_dataset.datasets):
            body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}")
            body += ds.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)


# stolen from huggingface/datasets
# https://github.com/huggingface/datasets/blob/074925b9b7c1dfd33b8675aa99c07cc26375665c/src/datasets/arrow_dataset.py#L5987
def _interleave_dataset_index(
        *,
        lengths: List[int],
        probabilities: Optional[List[float]] = None,
        seed: Optional[int] = None,
        stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
    if probabilities is not None and 0 in probabilities:
        # a zero-probability source is never sampled, hence never marked exhausted;
        # with "all_exhausted" the sampling loop below would never terminate
        assert stopping_strategy == "first_exhausted", \
            "a probability of 0 requires stopping_strategy='first_exhausted', otherwise the loop never terminates"
    # Let's now build the flat indices into the concatenated dataset
    offsets = np.cumsum([0] + lengths[:-1])

    # if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted"
    oversampling = stopping_strategy == "all_exhausted"

    if probabilities is None and not oversampling:
        # Undersampling situation with cycling between each sources
        # Example: if the lengths of the datasets are [3, 4, 5],
        # then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
        # Note that we only keep 3 examples per dataset since the shortest dataset runs out first

        # Reasoning behind the following operation: keeping the min_length first indices of each dataset
        # while offsetting in order to correspond to the right indices of the concatenated dataset
        # and flattening to effectively interleave the datasets
        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
    elif probabilities is None:
        # Oversampling situation with cycling between the sources
        # With the same lengths [3, 4, 5], the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
        # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples

        # Reasoning behind the following operation: for each dataset's indices (i.e. each column), repeat the indices cyclically to get max_length indices per dataset
        # For example, if max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0, 1, 2, 0, 1]
        indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))

        # We then shift the indices by their respective dataset offsets and flatten to effectively interleave the datasets
        indices = (indices + offsets).flatten().tolist()

    else:
        # boolean array whose i-th entry indicates whether the i-th dataset has been fully exhausted
        is_exhausted = np.full(len(lengths), False)

        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
        bool_strategy_func = np.all if oversampling else np.any

        def iter_random_indices():
            """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
            rng = np.random.default_rng(seed)
            while True:
                yield from (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities))

        current_index = [0] * len(lengths)
        indices = []
        for source_idx in iter_random_indices():
            # If not oversampling, we stop as soon as one dataset has run out of examples (np.any)
            # Otherwise, we stop as soon as every dataset has run out of examples (np.all)
            if bool_strategy_func(is_exhausted):
                # the stopping condition was reached, let's stop
                break

            # let's add the example at the current index of the `source_idx`-th dataset
            indices.append(current_index[source_idx] + offsets[source_idx])
            current_index[source_idx] += 1

            # we've run out of examples for the current dataset, so update the boolean array and bring current_index back to 0
            if current_index[source_idx] >= lengths[source_idx]:
                is_exhausted[source_idx] = True
                current_index[source_idx] = 0
    return indices
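
# Worked example (deterministic when probabilities is None; values follow from the
# offsets [0, 3, 7] for lengths [3, 4, 5]):
#
#   _interleave_dataset_index(lengths=[3, 4, 5])
#   -> [0, 3, 7, 1, 4, 8, 2, 5, 9]                        # stop after the shortest (3 rounds)
#   _interleave_dataset_index(lengths=[3, 4, 5], stopping_strategy="all_exhausted")
#   -> [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]    # 5 rounds, shorter sets wrap around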


@DATASETS.register_module()
class SubSet(TorchSubset):
    def __init__(self, cfg, portion, do_shuffle=True, seed=42):
        assert 0 < portion <= 1
        dataset = DATASETS.build(cfg=cfg)
        target_len = int(len(dataset) * portion)
        if do_shuffle:
            rng = np.random.default_rng(seed)
            indices = list(range(len(dataset)))
            rng.shuffle(indices)
            indices = indices[:target_len]
        else:
            indices = list(range(target_len))
        super().__init__(dataset, indices)
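
# Usage sketch (hypothetical cfg; keeps a random 10% of the wrapped dataset):
#
#   subset = SubSet(cfg=dict(type='DatasetA'), portion=0.1, do_shuffle=True, seed=0)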


@DATASETS.register_module()
class ConcatDatasetWithShuffle(TorchSubset):
    _repr_indent = 4

    def __init__(self, cfgs, seed=42, portion=1):
        self.cfgs = cfgs
        self.seed = seed
        self.portion = portion
        dataset = TorchConcatDataset([DATASETS.build(cfg) for cfg in cfgs])

        # portion may exceed 1: indices are tiled ceil(portion) times, shuffled, then truncated
        target_len = int(len(dataset) * portion)
        indices = list(range(len(dataset))) * int(np.ceil(portion))
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
        indices = indices[:target_len]
        super().__init__(dataset, indices)
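
# Usage sketch (hypothetical cfgs; portion=2 draws every concatenated sample exactly
# twice, in shuffled order):
#
#   shuffled = ConcatDatasetWithShuffle(
#       cfgs=[dict(type='DatasetA'), dict(type='DatasetB')], seed=42, portion=2)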


__all__ = ['ConcatDataset', 'InterleaveDataset', 'SubSet', 'ConcatDatasetWithShuffle']