import os
import json
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
def imresize(im, size, interp='bilinear'):
    if interp == 'nearest':
        resample = Image.NEAREST
    elif interp == 'bilinear':
        resample = Image.BILINEAR
    elif interp == 'bicubic':
        resample = Image.BICUBIC
    else:
        raise ValueError('undefined resample method: {}'.format(interp))
    return im.resize(size, resample)
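# A minimal usage sketch (sizes are hypothetical): `size` follows PIL's
# (width, height) convention, so this yields a 320x240 image:
#   thumb = imresize(Image.new('RGB', (640, 480)), (320, 240), interp='bilinear')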
class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, odgt, opt, **kwargs):
        # parse options
        self.imgSizes = opt.imgSizes
        self.imgMaxSize = opt.imgMaxSize
        # max downsampling rate of the network, used to avoid rounding
        # during conv or pooling
        self.padding_constant = opt.padding_constant

        # parse the input list
        self.parse_input_list(odgt, **kwargs)

        # ImageNet mean and std for input normalization
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
        if isinstance(odgt, list):
            self.list_sample = odgt
        elif isinstance(odgt, str):
            # each line of the .odgt file is one JSON record
            self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]
        if start_idx >= 0 and end_idx >= 0:  # divide file list
            self.list_sample = self.list_sample[start_idx:end_idx]

        self.num_sample = len(self.list_sample)
        assert self.num_sample > 0
        print('# samples: {}'.format(self.num_sample))
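    # Example .odgt record (field names inferred from the usage below; the
    # paths and sizes are hypothetical):
    #   {"fpath_img": "images/train/0001.jpg",
    #    "fpath_segm": "annotations/train/0001.png",
    #    "width": 683, "height": 512}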
    def img_transform(self, img):
        # 0-255 to 0-1
        img = np.float32(np.array(img)) / 255.
        # HWC -> CHW
        img = img.transpose((2, 0, 1))
        img = self.normalize(torch.from_numpy(img.copy()))
        return img
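    # e.g. a 683x512 RGB PIL image (width x height) becomes a normalized
    # float tensor of shape (3, 512, 683):
    #   x = self.img_transform(Image.new('RGB', (683, 512)))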
    def segm_transform(self, segm):
        # to tensor; shift labels from [0, 150] to [-1, 149]
        segm = torch.from_numpy(np.array(segm)).long() - 1
        return segm
    # Round x up to the nearest multiple of p, so the result x' satisfies
    # x' >= x and p | x'
    def round2nearest_multiple(self, x, p):
        return ((x - 1) // p + 1) * p
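    # e.g. round2nearest_multiple(450, 32) == 480, while exact multiples are
    # returned unchanged: round2nearest_multiple(480, 32) == 480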
class TrainDataset(BaseDataset):
    def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
        super(TrainDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset
        # downsampling rate of segm label
        self.segm_downsampling_rate = opt.segm_downsampling_rate
        self.batch_per_gpu = batch_per_gpu

        # classify images into two classes: 1. h > w and 2. h <= w
        self.batch_record_list = [[], []]

        # override dataset length when training with batch_per_gpu > 1
        self.cur_idx = 0
        self.if_shuffled = False
    def _get_sub_batch(self):
        while True:
            # get a sample record
            this_sample = self.list_sample[self.cur_idx]
            if this_sample['height'] > this_sample['width']:
                self.batch_record_list[0].append(this_sample)  # h > w, go to 1st class
            else:
                self.batch_record_list[1].append(this_sample)  # h <= w, go to 2nd class

            # update current sample pointer
            self.cur_idx += 1
            if self.cur_idx >= self.num_sample:
                self.cur_idx = 0
                np.random.shuffle(self.list_sample)

            # emit a sub-batch as soon as either orientation class is full
            if len(self.batch_record_list[0]) == self.batch_per_gpu:
                batch_records = self.batch_record_list[0]
                self.batch_record_list[0] = []
                break
            elif len(self.batch_record_list[1]) == self.batch_per_gpu:
                batch_records = self.batch_record_list[1]
                self.batch_record_list[1] = []
                break
        return batch_records
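    # Bucketing by orientation keeps portrait and landscape images in
    # separate sub-batches, so zero-padding each sub-batch to a common size
    # wastes fewer pixels than mixing aspect ratios would.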
    def __getitem__(self, index):
        # NOTE: random shuffle on first access; shuffling in __init__ is useless
        if not self.if_shuffled:
            np.random.seed(index)
            np.random.shuffle(self.list_sample)
            self.if_shuffled = True

        # get sub-batch candidates
        batch_records = self._get_sub_batch()

        # resize all images' short edges to the chosen size
        if isinstance(self.imgSizes, (list, tuple)):
            this_short_size = np.random.choice(self.imgSizes)
        else:
            this_short_size = self.imgSizes

        # calculate the BATCH's height and width
        # since we concat more than one sample, the batch's h and w shall be
        # larger than EACH sample's
        batch_widths = np.zeros(self.batch_per_gpu, np.int32)
        batch_heights = np.zeros(self.batch_per_gpu, np.int32)
        for i in range(self.batch_per_gpu):
            img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
            this_scale = min(
                this_short_size / min(img_height, img_width),
                self.imgMaxSize / max(img_height, img_width))
            batch_widths[i] = img_width * this_scale
            batch_heights[i] = img_height * this_scale

        # here we must pad both input image and segmentation map to size
        # h' and w' so that p | h' and p | w'
        batch_width = np.max(batch_widths)
        batch_height = np.max(batch_heights)
        batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
        batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))

        assert self.padding_constant >= self.segm_downsampling_rate, \
            'padding constant must be greater than or equal to the segm downsampling rate'
        batch_images = torch.zeros(
            self.batch_per_gpu, 3, batch_height, batch_width)
        batch_segms = torch.zeros(
            self.batch_per_gpu,
            batch_height // self.segm_downsampling_rate,
            batch_width // self.segm_downsampling_rate).long()
        for i in range(self.batch_per_gpu):
            this_record = batch_records[i]

            # load image and label
            image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
            segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])

            img = Image.open(image_path).convert('RGB')
            segm = Image.open(segm_path)
            assert(segm.mode == "L")
            assert(img.size[0] == segm.size[0])
            assert(img.size[1] == segm.size[1])

            # random horizontal flip
            if np.random.choice([0, 1]):
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                segm = segm.transpose(Image.FLIP_LEFT_RIGHT)

            # note that each sample within a mini-batch has its own scale param
            img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
            segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')

            # further downsample seg label; pad first to avoid label misalignment
            segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
            segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
            segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
            segm_rounded.paste(segm, (0, 0))
            segm = imresize(
                segm_rounded,
                (segm_rounded.size[0] // self.segm_downsampling_rate,
                 segm_rounded.size[1] // self.segm_downsampling_rate),
                interp='nearest')

            # image transform, to torch float tensor 3xHxW
            img = self.img_transform(img)

            # segm transform, to torch long tensor HxW
            segm = self.segm_transform(segm)

            # put into batch arrays (zero-padded on the right/bottom)
            batch_images[i][:, :img.shape[1], :img.shape[2]] = img
            batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm

        output = dict()
        output['img_data'] = batch_images
        output['seg_label'] = batch_segms
        return output
    def __len__(self):
        # a fake length, due to the trick that every loader maintains its own list
        return int(1e10)
        # return self.num_sample
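# A minimal usage sketch (argument values are hypothetical; the real training
# script wires this up from its config). Note batch_size=1: each __getitem__
# call already returns a full sub-batch of batch_per_gpu samples.
#   dataset_train = TrainDataset('./data/', './data/training.odgt', opt, batch_per_gpu=2)
#   loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=1, num_workers=2)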
class ValDataset(BaseDataset):
    def __init__(self, root_dataset, odgt, opt, **kwargs):
        super(ValDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset

    def __getitem__(self, index):
        this_record = self.list_sample[index]

        # load image and label
        image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
        segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
        img = Image.open(image_path).convert('RGB')
        segm = Image.open(segm_path)
        assert(segm.mode == "L")
        assert(img.size[0] == segm.size[0])
        assert(img.size[1] == segm.size[1])
        ori_width, ori_height = img.size

        img_resized_list = []
        for this_short_size in self.imgSizes:
            # calculate target height and width
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # to avoid rounding in network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        # segm transform, to torch long tensor HxW
        segm = self.segm_transform(segm)
        batch_segms = torch.unsqueeze(segm, 0)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['seg_label'] = batch_segms.contiguous()
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample
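# ValDataset returns one resized tensor per entry in opt.imgSizes, so an
# evaluation loop can average multi-scale predictions, e.g. (sketch; `model`
# and its calling convention are hypothetical):
#   scores = sum(model(x) for x in output['img_data']) / len(output['img_data'])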
class TestDataset(BaseDataset):
    def __init__(self, odgt, opt, **kwargs):
        super(TestDataset, self).__init__(odgt, opt, **kwargs)

    def __getitem__(self, index):
        this_record = self.list_sample[index]

        # load image
        image_path = this_record['fpath_img']
        img = Image.open(image_path).convert('RGB')
        ori_width, ori_height = img.size

        img_resized_list = []
        for this_short_size in self.imgSizes:
            # calculate target height and width
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # to avoid rounding in network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample
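# Since parse_input_list also accepts an in-memory list, a single-image test
# set can be built directly (the path and opt are hypothetical):
#   test_set = TestDataset([{'fpath_img': 'demo.jpg'}], opt)
#   sample = test_set[0]  # dict with 'img_ori', 'img_data', 'info'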