import numbers
import os
import queue as Queue
import threading

import mxnet as mx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class BackgroundGenerator(threading.Thread):
    """Wraps an iterator and prefetches up to `max_prefetch` items on a background thread."""

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        # A None sentinel marks the end of the wrapped iterator.
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches and copies them to the GPU on a side CUDA stream."""

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        # Queue the host-to-device copies asynchronously on the side stream.
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(
                    device=self.local_rank, non_blocking=True)

    def __next__(self):
        # Make sure the asynchronous copies have finished before the batch is used.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch
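# Illustrative sketch (not part of the original file): the copy/compute overlap that
# DataLoaderX relies on, written out with plain torch calls. `copy_stream` and
# `host_batch` are placeholder names used only in this sketch, and a CUDA device
# is assumed.
#
#     copy_stream = torch.cuda.Stream()
#     host_batch = torch.empty(128, 3, 112, 112, pin_memory=True)
#     with torch.cuda.stream(copy_stream):
#         # Asynchronous host-to-device copy, queued on the side stream.
#         device_batch = host_batch.to("cuda", non_blocking=True)
#     # The default stream waits for the copy before any kernel touches the data.
#     torch.cuda.current_stream().wait_stream(copy_stream)
#     out = device_batch * 2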
class MXFaceDataset(Dataset):
    """Face dataset stored as an MXNet indexed RecordIO pair (train.rec / train.idx)."""

    def __init__(self, root_dir, local_rank):
        super(MXFaceDataset, self).__init__()
        self.transform = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            # Record 0 is a meta header; label[0] marks the end of the image records.
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
class SyntheticDataset(Dataset):
    """Returns one fixed random, normalized 112x112 image for every index."""

    def __init__(self, local_rank):
        super(SyntheticDataset, self).__init__()
        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img).squeeze(0).float()
        img = ((img / 255) - 0.5) / 0.5
        self.img = img
        self.label = 1

    def __getitem__(self, index):
        return self.img, self.label

    def __len__(self):
        return 1000000
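# Illustrative usage sketch, not part of the original file: a minimal smoke test that
# wires SyntheticDataset into DataLoaderX (MXFaceDataset works the same way, given a
# directory containing train.rec/train.idx). The batch size and worker count below are
# placeholders, and a CUDA device is assumed.
if __name__ == "__main__":
    local_rank = 0
    trainset = SyntheticDataset(local_rank)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=trainset, batch_size=128,
        num_workers=2, pin_memory=True, shuffle=True, drop_last=True)
    for step, (img, label) in enumerate(train_loader):
        # preload() has already moved both tensors to cuda:local_rank.
        print(step, img.shape, img.device, label.device)
        if step == 2:
            break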