# mllm/dataset/utils/concatenate_dataset.py
from typing import List, Optional, Literal
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import ConcatDataset as TorchConcatDataset
from torch.utils.data import Subset as TorchSubset
from ..root import DATASETS
@DATASETS.register_module()
class ConcatDataset(Dataset):
    """Chain several registry-built datasets into one indexable dataset."""

    _repr_indent = 4

    def __init__(self, cfgs):
        # Build each sub-dataset from its config, then wrap them in a single
        # torch ConcatDataset for unified indexing.
        self.cfgs = cfgs
        built = [DATASETS.build(c) for c in cfgs]
        self.concat_dataset = TorchConcatDataset(built)

    def __len__(self):
        return len(self.concat_dataset)

    def __getitem__(self, index):
        return self.concat_dataset[index]

    def __repr__(self) -> str:
        header = "Dataset " + self.__class__.__name__
        parts = [f"Number of datapoints: {self.__len__()}"]
        subsets = self.concat_dataset.datasets
        total = len(subsets)
        for idx, sub in enumerate(subsets, start=1):
            parts.append(f"Subset {idx}/{total}")
            parts.extend(repr(sub).splitlines())
        indent = " " * self._repr_indent
        return "\n".join([header] + [indent + p for p in parts])
@DATASETS.register_module()
class InterleaveDateset(Dataset):
    """Interleave samples from several datasets via a precomputed index map.

    NOTE(review): the 'Dateset' spelling is kept — the class name is part of
    the registered public API (see ``__all__``).
    """

    _repr_indent = 4

    def __init__(
            self,
            cfgs,
            probabilities: Optional[List[float]] = None,
            seed: Optional[int] = 42,
            stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
    ):
        self.cfgs = cfgs
        self.probabilities = probabilities
        self.seed = seed
        self.stopping_strategy = stopping_strategy

        built = [DATASETS.build(c) for c in cfgs]
        self.concat_dataset = TorchConcatDataset(built)
        # Precompute the interleaved order once so __getitem__ is a cheap
        # table lookup into the concatenated dataset.
        self.index_mapping = _interleave_dataset_index(
            lengths=[len(d) for d in built],
            probabilities=probabilities,
            seed=seed,
            stopping_strategy=stopping_strategy,
        )

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, index):
        return self.concat_dataset[self.index_mapping[index]]

    def __repr__(self) -> str:
        header = "Dataset " + self.__class__.__name__
        parts = [
            f"Number of datapoints: {self.__len__()}",
            f"Probabilities: {self.probabilities}",
            f"stopping_strategy: {self.stopping_strategy}",
            f"seed: {self.seed}",
        ]
        subsets = self.concat_dataset.datasets
        total = len(subsets)
        for idx, sub in enumerate(subsets, start=1):
            parts.append(f"Subset {idx}/{total}")
            parts.extend(repr(sub).splitlines())
        indent = " " * self._repr_indent
        return "\n".join([header] + [indent + p for p in parts])
# stolen from huggingface/datasets
# https://github.com/huggingface/datasets/blob/074925b9b7c1dfd33b8675aa99c07cc26375665c/src/datasets/arrow_dataset.py#L5987
def _interleave_dataset_index(
*,
lengths: List[int],
probabilities: Optional[List[float]] = None,
seed: Optional[int] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
if probabilities is not None and 0 in probabilities:
assert stopping_strategy == 'first_exhausted', "you will meet a Infinite loop"
# Let's now build the indices to pass to .select()
offsets = np.cumsum([0] + lengths[:-1])
# if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted"
oversampling = stopping_strategy == "all_exhausted"
if probabilities is None and not oversampling:
# Undersampling situation with cycling between each sources
# Example:: If lengths of the datasets are [3, 4, 5]
# Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9]
# Note that we only have 3 examples per dataset since the first dataset ran out of examples
# Reasoning behind the following operation: keeping the min_length first indices of each dataset
# while offsetting in order to correspond to the right indices of the concatenated dataset
# and flattening to effectively interleave the datasets
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
elif probabilities is None:
# Oversampling situation with cycling between each sources
# Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
# Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
# Reasoning behind the following operation: for each dataset indices (i.e column) repeat the indices to have max_length indices per dataset
# For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0,1,2,0,1]
indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))
# We have to keep the indices to their respective dataset offsets and to flatten to effectively interleave the datasets
indices = (indices + offsets).flatten().tolist()
else:
# boolean array indicating if at index i if the dataset_i has been fully exhausted
is_exhausted = np.full(len(lengths), False)
# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every dataset has been visited at least once
bool_strategy_func = np.all if oversampling else np.any
def iter_random_indices():
"""Get an infinite iterator that randomly samples the index of the source to pick examples from."""
rng = np.random.default_rng(seed)
while True:
yield from (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities))
current_index = [0] * len(lengths)
indices = []
for source_idx in iter_random_indices():
# If no oversampling, we stop as soon as a dataset has ran out of examples (np.any)
# Otherwise, we stop as soon as every dataset has ran out of examples (np.all)
if bool_strategy_func(is_exhausted):
# the stopping condition was reached, let's stop
break
# let's add the example at the current index of the `source_idx`-th dataset
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1
# we've ran out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0
if current_index[source_idx] >= lengths[source_idx]:
is_exhausted[source_idx] = True
current_index[source_idx] = 0
return indices
@DATASETS.register_module()
class SubSet(TorchSubset):
    """A fractional (optionally shuffled) subset of a single built dataset.

    Args:
        cfg: config forwarded to ``DATASETS.build``.
        portion: fraction of the dataset to keep, in ``(0, 1]``.
        do_shuffle: if True, pick the subset uniformly at random;
            otherwise keep the first ``portion`` of the dataset in order.
        seed: RNG seed used when ``do_shuffle`` is True.

    Raises:
        ValueError: if ``portion`` is outside ``(0, 1]``.
    """

    def __init__(self, cfg, portion, do_shuffle=True, seed=42):
        # Was `assert 0 < portion <= 1`, which is stripped under `python -O`;
        # raise explicitly so invalid configs always fail loudly.
        if not 0 < portion <= 1:
            raise ValueError(f"portion must be in (0, 1], got {portion}")
        dataset = DATASETS.build(cfg=cfg)
        target_len = int(len(dataset) * portion)
        if do_shuffle:
            rng = np.random.default_rng(seed)
            indices = list(range(len(dataset)))
            rng.shuffle(indices)
            indices = indices[:target_len]
        else:
            # Deterministic prefix when shuffling is disabled.
            indices = list(range(target_len))
        super().__init__(dataset, indices)
@DATASETS.register_module()
class ConcatDatasetWithShuffle(TorchSubset):
    """Concatenate several built datasets, shuffle, and resize to ``portion``.

    Args:
        cfgs: list of configs, each forwarded to ``DATASETS.build``.
        seed: RNG seed for the shuffle.
        portion: target size as a fraction of the concatenated length;
            values > 1 oversample by repeating indices before shuffling.

    Raises:
        ValueError: if ``portion`` is not positive.
    """

    _repr_indent = 4

    def __init__(self, cfgs, seed=42, portion=1):
        # A non-positive portion would silently produce an empty/garbage
        # subset; fail loudly instead (consistent with SubSet's validation).
        if portion <= 0:
            raise ValueError(f"portion must be positive, got {portion}")
        self.cfgs = cfgs
        self.seed = seed
        self.portion = portion

        dataset = TorchConcatDataset([DATASETS.build(cfg) for cfg in cfgs])
        target_len = int(len(dataset) * portion)
        # Repeat the full index range enough times to cover portion > 1,
        # then shuffle and truncate down to the target length.
        indices = list(range(len(dataset))) * int(np.ceil(portion))
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
        indices = indices[:target_len]
        super().__init__(dataset, indices)
# Public API of this module (note: 'InterleaveDateset' matches the class's
# actual — typo'd — name).
__all__ = ['ConcatDataset', 'InterleaveDateset', 'SubSet', 'ConcatDatasetWithShuffle']