File size: 1,064 Bytes
51e2f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import warnings
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.utils import data


class AugmentedDataset(data.Dataset):
    """Deprecated dataset wrapper that applies an augmentation to each item.

    Items are fetched from the wrapped ``dataset`` and passed through
    ``augmentation`` before being returned. When ``target_length`` is given,
    the wrapper reports that length and cycles through the wrapped dataset
    (``index % len(dataset)``) so it can appear longer, or shorter, than the
    underlying data.

    Instantiating the class emits a ``DeprecationWarning``.

    Args:
        dataset: The dataset to wrap. It must support ``len()`` and integer
            indexing.
        augmentation: Module called on every fetched item. ``None`` (the
            default) creates a fresh ``nn.Identity()`` per instance, so
            default-constructed datasets never share a module object.
        target_length: Length reported by ``len()``. ``None`` means "use the
            length of ``dataset``".

    Raises:
        ValueError: If ``target_length`` is negative, or if it is positive
            while ``dataset`` is empty (there would be nothing to cycle
            through).
    """

    def __init__(
        self,
        dataset: data.Dataset,
        augmentation: Optional[nn.Module] = None,
        target_length: Optional[int] = None,
    ) -> None:
        super().__init__()

        # stacklevel=2 attributes the warning to the code that instantiates
        # the class rather than to this module.
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        if target_length is not None and target_length < 0:
            raise ValueError(
                f"target_length must be non-negative, got {target_length}"
            )

        self.dataset = dataset
        # Build the default here instead of in the signature: a default of
        # ``nn.Identity()`` is evaluated once at definition time and would be
        # shared by every instance.
        self.augmentation: nn.Module = (
            nn.Identity() if augmentation is None else augmentation
        )

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length: int = (
            self.ds_length if target_length is None else target_length
        )

        # Fail early with a clear message instead of a ZeroDivisionError
        # from the modulo in __getitem__.
        if self.length > 0 and self.ds_length == 0:
            raise ValueError(
                "target_length must be 0 or None when the wrapped dataset "
                "is empty"
            )

    def __getitem__(
        self, index: int
    ) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
        """Return the augmented item at ``index``.

        Valid indices are ``-len(self) <= index < len(self)``; negative
        indices count from the end of the advertised length. The index is
        then wrapped modulo the wrapped dataset's length.

        Raises:
            IndexError: If ``index`` is outside the valid range. This is what
                lets plain iteration (``for item in ds``) terminate.
        """
        if not -self.length <= index < self.length:
            raise IndexError(
                f"index {index} is out of range for dataset of length "
                f"{self.length}"
            )
        if index < 0:
            index += self.length

        item = self.dataset[index % self.ds_length]
        return self.augmentation(item)

    def __len__(self) -> int:
        """Return ``target_length`` if it was given, else the wrapped length."""
        return self.length