# Factory and thin wrapper classes around torch.utils.data.DataLoader for DDDataset.
import torch.utils.data
from data_test_flow.dd_dataset import DDDataset

def CreateDataLoader(opt):
    """Build and initialize the project's data loader from an options object."""
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader
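
# Note on the expected options object (an observation from the code below, not
# enforced anywhere): only four attributes of `opt` are ever read --
#   opt.batch_size        int, samples per batch
#   opt.shuffle           bool, forwarded to DataLoader
#   opt.num_threads       int-like, number of DataLoader worker processes
#   opt.max_dataset_size  cap on how many samples are iterated per epoch
# so any object exposing these (e.g. an argparse.Namespace) is sufficient for
# the loader itself; DDDataset.initialize may require more.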


class BaseDataLoader():
    """Minimal interface that concrete data loaders implement."""

    def __init__(self):
        pass

    def initialize(self, opt):
        # Stash the options object; subclasses extend this.
        self.opt = opt

    def load_data(self):
        return None

class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps DDDataset in a torch DataLoader and iterates it in batches."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = DDDataset()
        self.dataset.initialize(opt)
        # Alternative distributed path, kept for reference: shard the dataset
        # across processes with a DistributedSampler (shuffle must then stay
        # False, since the sampler controls ordering).
        # sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
        # self.dataloader = torch.utils.data.DataLoader(
        #     self.dataset,
        #     batch_size=opt.batch_size,
        #     shuffle=False,
        #     sampler=sampler)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=opt.shuffle,
            drop_last=True,                       # discard the final partial batch
            num_workers=int(opt.num_threads))     # worker processes for loading
        
    def load_data(self):
        return self

    def __len__(self):
        # Report the dataset size, capped at opt.max_dataset_size.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        # Stop early once roughly max_dataset_size samples have been yielded.
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
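
# A minimal smoke-test sketch (an assumption, not part of the original module):
# DDDataset.initialize most likely needs further options (e.g. data paths)
# beyond the four loader fields, so the Namespace below is illustrative only.
# Iterating the loader yields batches until max_dataset_size is reached.
#
# if __name__ == '__main__':
#     from argparse import Namespace
#     opt = Namespace(batch_size=4, shuffle=True, num_threads=2,
#                     max_dataset_size=float('inf'))
#     loader = CreateDataLoader(opt).load_data()
#     print('dataset size = %d' % len(loader))
#     for data in loader:
#         pass  # feed `data` to the model here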
