File size: 1,048 Bytes
8bb8404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import pytorch_lightning as pl
from torch.utils.data import Dataset
import webdataset as wds
from torch.utils.data.distributed import DistributedSampler
class DummyDataset(pl.LightningDataModule):
    """Lightning data module that serves ``DummyData`` through ``wds.WebLoader``.

    Only the ``'fit'`` stage is supported by ``setup``; it creates one
    training and one validation ``DummyData`` instance.
    """

    def __init__(self, seed):
        # NOTE: ``seed`` is accepted but never read or stored by this class.
        super().__init__()

    def setup(self, stage):
        """Create the train/val datasets for the ``'fit'`` stage.

        Raises:
            NotImplementedError: if ``stage`` is anything other than ``'fit'``.
        """
        if stage != 'fit':
            raise NotImplementedError

        self.train_dataset = DummyData(True)
        self.val_dataset = DummyData(False)

    def train_dataloader(self):
        """Return a ``WebLoader`` over ``self.train_dataset`` (set by ``setup``)."""
        loader_options = dict(batch_size=1, num_workers=0, shuffle=False)
        return wds.WebLoader(self.train_dataset, **loader_options)

    def val_dataloader(self):
        """Return a ``WebLoader`` over ``self.val_dataset`` (set by ``setup``)."""
        loader_options = dict(batch_size=1, num_workers=0, shuffle=False)
        return wds.WebLoader(self.val_dataset, **loader_options)

    def test_dataloader(self):
        """Return a ``WebLoader`` over a freshly built non-training ``DummyData``.

        No loader options are passed here, so the loader's own defaults apply.
        """
        test_dataset = DummyData(False)
        return wds.WebLoader(test_dataset)

class DummyData(Dataset):
    """Placeholder dataset in which every item is an empty dict.

    The reported length is 99999999 when ``is_train`` is truthy and 1
    otherwise.
    """

    def __init__(self, is_train):
        self.is_train = is_train

    def __len__(self):
        """Return the nominal dataset size for the configured mode."""
        return 99999999 if self.is_train else 1

    def __getitem__(self, index):
        """Return a new empty dict; ``index`` is not inspected."""
        return {}