import torch.utils.data

from data_test_flow.dd_dataset import DDDataset


def CreateDataLoader(opt):
    """Build and initialize the project's data loader from parsed options."""
    data_loader = CustomDatasetDataLoader()
    data_loader.initialize(opt)
    return data_loader


class BaseDataLoader():
    """Minimal interface that concrete data loaders implement."""

    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None


class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps DDDataset in a torch DataLoader, capping iteration at opt.max_dataset_size."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = DDDataset()
        self.dataset.initialize(opt)
        # Distributed-training variant, kept for reference:
        # sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
        # self.dataloader = torch.utils.data.DataLoader(
        #     self.dataset,
        #     batch_size=opt.batch_size,
        #     shuffle=False,
        #     sampler=sampler)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=opt.shuffle,
            drop_last=True,
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        # Report at most opt.max_dataset_size samples.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        # Stop once roughly max_dataset_size samples have been yielded.
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
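

if __name__ == '__main__':
    # Usage sketch, not a confirmed entry point: the option fields below
    # (batch_size, shuffle, num_threads, max_dataset_size) are the ones this
    # file reads, but DDDataset.initialize may require additional fields that
    # are not visible here, so treat these values as illustrative assumptions.
    from argparse import Namespace

    opt = Namespace(batch_size=4, shuffle=True, num_threads=2,
                    max_dataset_size=float('inf'))
    loader = CreateDataLoader(opt).load_data()
    for batch in loader:
        print(type(batch))  # each item is a collated batch from DDDataset
        break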