# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from torch.utils.data import Dataset


class RepeatedDatasetWrapper(Dataset):
    def __init__(self, original_dataset, num_repeats):
        """
        Dataset wrapper that repeats the original dataset n times.

        Args:
            original_dataset (torch.utils.data.Dataset): The original dataset to be repeated.
            num_repeats (int): The number of times the dataset should be repeated.
        """
        self.original_dataset = original_dataset
        self.num_repeats = num_repeats

    def __getitem__(self, index):
        """
        Retrieve the item at the given index.
        
        Args:
            index (int): The index of the item to be retrieved.
        """
        original_index = index % len(self.original_dataset)
        return self.original_dataset[original_index]

    def __len__(self):
        """
        Get the length of the dataset after repeating it n times.
        
        Returns:
            int: The length of the dataset.
        """
        return len(self.original_dataset) * self.num_repeats
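
# Usage sketch (illustrative only; `base_dataset` stands in for any map-style
# torch Dataset you already have):
#
#   repeated = RepeatedDatasetWrapper(base_dataset, num_repeats=3)
#   assert len(repeated) == 3 * len(base_dataset)
#   repeated[len(base_dataset)]  # wraps around: same item as base_dataset[0]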


class SubsampleDatasetWrapper(Dataset):
    def __init__(self, original_dataset, dataset_size, seed=0, return_orig_idx=False):
        """
        Dataset wrapper that randomly subsamples the original dataset.

        Args:
            original_dataset (torch.utils.data.Dataset): The original dataset to be subsampled.
            dataset_size (int): The size of the subsampled dataset. If None, falls back to the size of the original dataset.
            seed (int): The seed to use for selecting the subset of indices of the original dataset.
            return_orig_idx (bool): Whether to return the original index of the item in the original dataset.
        """
        self.original_dataset = original_dataset
        self.dataset_size = dataset_size or len(original_dataset)
        self.return_orig_idx = return_orig_idx
        # Use a dedicated RandomState so seeding does not mutate NumPy's global RNG.
        rng = np.random.RandomState(seed)
        self.indices = rng.permutation(len(self.original_dataset))[:self.dataset_size]

    def __getitem__(self, index):
        """
        Retrieve the item at the given index.

        Args:
            index (int): The index of the item to be retrieved.

        Returns:
            The sample, or a (sample, original_index) tuple if return_orig_idx is True.
        """
        original_index = self.indices[index]
        sample = self.original_dataset[original_index]
        return (sample, original_index) if self.return_orig_idx else sample

    def __len__(self):
        """
        Get the length of the dataset after subsampling it.
        
        Returns:
            int: The length of the dataset.
        """
        return len(self.indices)
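

# Minimal smoke test (an illustrative sketch; the toy dataset below is not part
# of the original module and only exercises the two wrappers above):
if __name__ == "__main__":
    class _ToyDataset(Dataset):
        """Map-style dataset that simply returns its own index."""
        def __init__(self, n):
            self.n = n

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.n

    base = _ToyDataset(10)

    # Repeating triples the length; indices wrap around the original dataset.
    repeated = RepeatedDatasetWrapper(base, num_repeats=3)
    assert len(repeated) == 30
    assert repeated[10] == base[0]

    # Subsampling picks a seeded random subset; with return_orig_idx=True each
    # item comes back with its index into the original dataset.
    subsampled = SubsampleDatasetWrapper(base, dataset_size=4, seed=0, return_orig_idx=True)
    assert len(subsampled) == 4
    sample, orig_idx = subsampled[0]
    assert sample == base[orig_idx]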