import os
import json

import torch
from torchvision import transforms
import numpy as np
from PIL import Image


def imresize(im, size, interp='bilinear'):
    if interp == 'nearest':
        resample = Image.NEAREST
    elif interp == 'bilinear':
        resample = Image.BILINEAR
    elif interp == 'bicubic':
        resample = Image.BICUBIC
    else:
        raise Exception('resample method undefined!')

    return im.resize(size, resample)


class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, odgt, opt, **kwargs):
        # parse options
        self.imgSizes = opt.imgSizes
        self.imgMaxSize = opt.imgMaxSize
        # max downsampling rate of the network, to avoid rounding during conv or pooling
        self.padding_constant = opt.padding_constant

        # parse the input list
        self.parse_input_list(odgt, **kwargs)

        # mean and std
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])

    def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
        if isinstance(odgt, list):
            self.list_sample = odgt
        elif isinstance(odgt, str):
            self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]
        if start_idx >= 0 and end_idx >= 0:  # divide the file list
            self.list_sample = self.list_sample[start_idx:end_idx]

        self.num_sample = len(self.list_sample)
        assert self.num_sample > 0
        print('# samples: {}'.format(self.num_sample))

    def img_transform(self, img):
        # 0-255 to 0-1
        img = np.float32(np.array(img)) / 255.
        img = img.transpose((2, 0, 1))
        img = self.normalize(torch.from_numpy(img.copy()))
        return img

    def segm_transform(self, segm):
        # to tensor; shift labels from 0-150 to -1 to 149, with -1 = ignore
        segm = torch.from_numpy(np.array(segm)).long() - 1
        return segm

    # Round x up to the nearest multiple of p, so that x' >= x and p | x'
    def round2nearest_multiple(self, x, p):
        return ((x - 1) // p + 1) * p
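
# Worked example (illustrative; not part of the original module): the padding
# rule above rounds a side length UP to the nearest multiple of the padding
# constant p, e.g. with p = 8: 33 -> 40, 40 -> 40, 41 -> 48. Keeping every
# side a multiple of the network's maximum downsampling rate means no
# conv/pooling stage has to round its output size. The helper below is a
# hypothetical convenience for eyeballing this rule; it is safe to remove.
def _demo_padding_rule(p=8):
    for x in (33, 40, 41):
        x_padded = ((x - 1) // p + 1) * p  # same formula as round2nearest_multiple
        print('{} -> {}'.format(x, x_padded))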
class TrainDataset(BaseDataset):
    def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
        super(TrainDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset
        # downsampling rate of the segm label
        self.segm_downsampling_rate = opt.segm_downsampling_rate
        self.batch_per_gpu = batch_per_gpu

        # classify images into two classes: 1. h > w and 2. h <= w
        self.batch_record_list = [[], []]

        # override dataset length when training with batch_per_gpu > 1
        self.cur_idx = 0
        self.if_shuffled = False

    def _get_sub_batch(self):
        while True:
            # get a sample record
            this_sample = self.list_sample[self.cur_idx]
            if this_sample['height'] > this_sample['width']:
                self.batch_record_list[0].append(this_sample)  # h > w, go to 1st class
            else:
                self.batch_record_list[1].append(this_sample)  # h <= w, go to 2nd class

            # update current sample pointer
            self.cur_idx += 1
            if self.cur_idx >= self.num_sample:
                self.cur_idx = 0
                np.random.shuffle(self.list_sample)

            if len(self.batch_record_list[0]) == self.batch_per_gpu:
                batch_records = self.batch_record_list[0]
                self.batch_record_list[0] = []
                break
            elif len(self.batch_record_list[1]) == self.batch_per_gpu:
                batch_records = self.batch_record_list[1]
                self.batch_record_list[1] = []
                break
        return batch_records

    def __getitem__(self, index):
        # NOTE: random shuffle on the first access; shuffling in __init__ is
        # useless because every loader worker maintains its own sample list
        if not self.if_shuffled:
            np.random.seed(index)
            np.random.shuffle(self.list_sample)
            self.if_shuffled = True

        # get sub-batch candidates
        batch_records = self._get_sub_batch()

        # resize all images' short edges to the chosen size
        if isinstance(self.imgSizes, (list, tuple)):
            this_short_size = np.random.choice(self.imgSizes)
        else:
            this_short_size = self.imgSizes

        # calculate the BATCH's height and width
        # since we concatenate multiple samples, the batch's h and w must be
        # at least as large as EACH sample's
        batch_widths = np.zeros(self.batch_per_gpu, np.int32)
        batch_heights = np.zeros(self.batch_per_gpu, np.int32)
        for i in range(self.batch_per_gpu):
            img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
            this_scale = min(
                this_short_size / min(img_height, img_width),
                self.imgMaxSize / max(img_height, img_width))
            batch_widths[i] = img_width * this_scale
            batch_heights[i] = img_height * this_scale

        # here we must pad both the input image and the segmentation map to
        # sizes h' and w' such that p | h' and p | w'
        batch_width = np.max(batch_widths)
        batch_height = np.max(batch_heights)
        batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
        batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))

        assert self.padding_constant >= self.segm_downsampling_rate, \
            'padding constant must be greater than or equal to the segm downsampling rate'
        batch_images = torch.zeros(
            self.batch_per_gpu, 3, batch_height, batch_width)
        batch_segms = torch.zeros(
            self.batch_per_gpu,
            batch_height // self.segm_downsampling_rate,
            batch_width // self.segm_downsampling_rate).long()

        for i in range(self.batch_per_gpu):
            this_record = batch_records[i]

            # load image and label
            image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
            segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])

            img = Image.open(image_path).convert('RGB')
            segm = Image.open(segm_path)
            assert(segm.mode == "L")
            assert(img.size[0] == segm.size[0])
            assert(img.size[1] == segm.size[1])

            # random horizontal flip
            if np.random.choice([0, 1]):
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                segm = segm.transpose(Image.FLIP_LEFT_RIGHT)

            # note that each sample within a mini-batch has a different scale param
            img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
            segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')

            # further downsample the seg label; pad first to avoid label misalignment
            segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
            segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
            segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
            segm_rounded.paste(segm, (0, 0))
            segm = imresize(
                segm_rounded,
                (segm_rounded.size[0] // self.segm_downsampling_rate,
                 segm_rounded.size[1] // self.segm_downsampling_rate),
                interp='nearest')

            # image transform, to torch float tensor 3xHxW
            img = self.img_transform(img)

            # segm transform, to torch long tensor HxW
            segm = self.segm_transform(segm)

            # put into batch arrays
            batch_images[i][:, :img.shape[1], :img.shape[2]] = img
            batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm

        output = dict()
        output['img_data'] = batch_images
        output['seg_label'] = batch_segms
        return output

    def __len__(self):
        # it's a fake length, due to the trick that every loader maintains its own list
        # return self.num_sample
        return int(1e10)
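
# Usage sketch (illustrative; not from the original pipeline): because
# TrainDataset.__getitem__ already assembles a full sub-batch of
# `batch_per_gpu` samples, a DataLoader around it should use batch_size=1
# with an identity-style collate so batches are not stacked a second time.
# The `opt` fields and file paths below are assumptions for the example.
def _demo_train_loader():
    from types import SimpleNamespace
    opt = SimpleNamespace(
        imgSizes=(300, 375, 450, 525, 600),  # candidate short-edge sizes
        imgMaxSize=1000,                     # cap on the long edge
        padding_constant=8,                  # network's max downsampling rate
        segm_downsampling_rate=8)            # label downsampling rate
    dataset = TrainDataset('./data/', './data/training.odgt', opt,
                           batch_per_gpu=2)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, collate_fn=lambda batch: batch[0])
    batch = next(iter(loader))
    # batch['img_data']:  float tensor of shape [2, 3, H, W]
    # batch['seg_label']: long tensor of shape [2, H // 8, W // 8]
    print(batch['img_data'].shape, batch['seg_label'].shape)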
class ValDataset(BaseDataset):
    def __init__(self, root_dataset, odgt, opt, **kwargs):
        super(ValDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset

    def __getitem__(self, index):
        this_record = self.list_sample[index]
        # load image and label
        image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
        segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
        img = Image.open(image_path).convert('RGB')
        segm = Image.open(segm_path)
        assert(segm.mode == "L")
        assert(img.size[0] == segm.size[0])
        assert(img.size[1] == segm.size[1])

        ori_width, ori_height = img.size

        img_resized_list = []
        for this_short_size in self.imgSizes:
            # calculate target height and width
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # pad to multiples of the padding constant, to avoid rounding in the network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        # segm transform, to torch long tensor HxW
        segm = self.segm_transform(segm)
        batch_segms = torch.unsqueeze(segm, 0)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['seg_label'] = batch_segms.contiguous()
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample


class TestDataset(BaseDataset):
    def __init__(self, odgt, opt, **kwargs):
        super(TestDataset, self).__init__(odgt, opt, **kwargs)

    def __getitem__(self, index):
        this_record = self.list_sample[index]
        # load image
        image_path = this_record['fpath_img']
        img = Image.open(image_path).convert('RGB')

        ori_width, ori_height = img.size

        img_resized_list = []
        for this_short_size in self.imgSizes:
            # calculate target height and width
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # pad to multiples of the padding constant, to avoid rounding in the network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample
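
# Minimal sketch (assumptions: the same `opt` fields as in the training demo
# above and an ADE20K-style validation .odgt file; not part of the original
# module). Each ValDataset item carries the image resized at every short-edge
# size in opt.imgSizes, which is how multi-scale evaluation consumes it.
if __name__ == '__main__':
    from types import SimpleNamespace
    opt = SimpleNamespace(imgSizes=(300, 375, 450, 525, 600),
                          imgMaxSize=1000, padding_constant=8)
    val_set = ValDataset('./data/', './data/validation.odgt', opt)
    sample = val_set[0]
    for scaled in sample['img_data']:
        # one [1, 3, H_s, W_s] float tensor per short-edge size
        print(scaled.shape)
    print(sample['seg_label'].shape, sample['info'])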