Spaces:
Runtime error
Runtime error
| """This package includes all the modules related to data loading and preprocessing | |
| To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. | |
| You need to implement four functions: | |
| -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). | |
| -- <__len__>: return the size of dataset. | |
| -- <__getitem__>: get a data point from data loader. | |
| -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. | |
| Now you can use the dataset class by specifying flag '--dataset_mode dummy'. | |
| See our template dataset class 'template_dataset.py' for more details. | |
| """ | |
| import importlib | |
| import torch.utils.data | |
| from swapae.data.base_dataset import BaseDataset | |
| import swapae.util as util | |
| def find_dataset_using_name(dataset_name): | |
| """Import the module "data/[dataset_name]_dataset.py". | |
| In the file, the class called DatasetNameDataset() will | |
| be instantiated. It has to be a subclass of BaseDataset, | |
| and it is case-insensitive. | |
| """ | |
| dataset_filename = "swapae.data." + dataset_name + "_dataset" | |
| datasetlib = importlib.import_module(dataset_filename) | |
| dataset = None | |
| target_dataset_name = dataset_name.replace('_', '') + 'dataset' | |
| for name, cls in datasetlib.__dict__.items(): | |
| if name.lower() == target_dataset_name.lower() \ | |
| and issubclass(cls, BaseDataset): | |
| dataset = cls | |
| if dataset is None: | |
| raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) | |
| return dataset | |
| def get_option_setter(dataset_name): | |
| """Return the static method <modify_commandline_options> of the dataset class.""" | |
| dataset_class = find_dataset_using_name(dataset_name) | |
| return dataset_class.modify_commandline_options | |
| def create_dataset(opt): | |
| return ConfigurableDataLoader(opt) | |
| class DataPrefetcher(): | |
| def __init__(self, dataset): | |
| self.dataset = dataset | |
| self.stream = torch.cuda.Stream() | |
| self.preload() | |
| def preload(self): | |
| try: | |
| self.next_input = next(self.dataset) | |
| except StopIteration: | |
| self.next_input = None | |
| return | |
| with torch.cuda.stream(self.stream): | |
| self.next_input = self.next_input.cuda(non_blocking=True) | |
| def __next__(self): | |
| torch.cuda.current_stream().wait_stream(self.stream) | |
| input = self.next_input | |
| self.preload() | |
| return input | |
| def __iter__(self): | |
| return self | |
| def __len__(self): | |
| return len(self.dataset) | |
| class ConfigurableDataLoader(): | |
| def __init__(self, opt): | |
| self.opt = opt | |
| self.initialize(opt.phase) | |
| def initialize(self, phase): | |
| opt = self.opt | |
| self.phase = phase | |
| if hasattr(self, "dataloader"): | |
| del self.dataloader | |
| dataset_class = find_dataset_using_name(opt.dataset_mode) | |
| dataset = dataset_class(util.copyconf(opt, phase=phase, isTrain=phase == "train")) | |
| shuffle = phase == "train" if opt.shuffle_dataset is None else opt.shuffle_dataset == "true" | |
| print("dataset [%s] of size %d was created. shuffled=%s" % (type(dataset).__name__, len(dataset), shuffle)) | |
| #dataset = DataPrefetcher(dataset) | |
| self.opt = opt | |
| self.dataloader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=opt.batch_size, | |
| shuffle=shuffle, | |
| num_workers=int(opt.num_gpus), | |
| drop_last=phase == "train", | |
| ) | |
| #self.dataloader = dataset | |
| self.dataloader_iterator = iter(self.dataloader) | |
| self.repeat = phase == "train" | |
| self.length = len(dataset) | |
| self.underlying_dataset = dataset | |
| def set_phase(self, target_phase): | |
| if self.phase != target_phase: | |
| self.initialize(target_phase) | |
| def __iter__(self): | |
| self.dataloader_iterator = iter(self.dataloader) | |
| return self | |
| def __len__(self): | |
| return self.length | |
| def __next__(self): | |
| try: | |
| return next(self.dataloader_iterator) | |
| except StopIteration: | |
| if self.repeat: | |
| self.dataloader_iterator = iter(self.dataloader) | |
| return next(self.dataloader_iterator) | |
| else: | |
| raise StopIteration | |