Realcat's picture
update: major change
499e141
import os
import json
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
def imresize(im, size, interp='bilinear'):
    """Resize a PIL image with a resampling filter selected by name.

    Args:
        im: image object exposing ``resize(size, resample)`` (a PIL image).
        size: target size, passed to ``im.resize`` unchanged.
        interp: one of ``'nearest'``, ``'bilinear'`` or ``'bicubic'``.

    Returns:
        Whatever ``im.resize`` returns for the requested size and filter.

    Raises:
        ValueError: if ``interp`` is not one of the supported names.
            ``ValueError`` derives from ``Exception``, so callers that
            caught the previous generic ``Exception`` keep working.
    """
    # Map the public names to the attribute names on ``PIL.Image``; the
    # attribute itself is only looked up once the name is known to be valid.
    filter_attrs = {
        'nearest': 'NEAREST',
        'bilinear': 'BILINEAR',
        'bicubic': 'BICUBIC',
    }
    # Guard with isinstance so unhashable values are rejected with the same
    # error instead of a TypeError from the dict lookup.
    attr = filter_attrs.get(interp) if isinstance(interp, str) else None
    if attr is None:
        raise ValueError('resample method undefined!')

    return im.resize(size, getattr(Image, attr))
class BaseDataset(torch.utils.data.Dataset):
    """Shared option parsing, sample-list loading and tensor conversion.

    Subclasses are expected to implement ``__getitem__`` / ``__len__``.

    Args:
        odgt: either a list of sample records, or a path to a text file
            holding one JSON object per line.
        opt: options object; ``imgSizes``, ``imgMaxSize`` and
            ``padding_constant`` are read from it.
        **kwargs: forwarded to :meth:`parse_input_list`
            (``max_sample``, ``start_idx``, ``end_idx``).
    """

    def __init__(self, odgt, opt, **kwargs):
        # parse options
        self.imgSizes = opt.imgSizes
        self.imgMaxSize = opt.imgMaxSize
        # max down sampling rate of network to avoid rounding during conv or pooling
        self.padding_constant = opt.padding_constant

        # parse the input list
        self.parse_input_list(odgt, **kwargs)

        # per-channel mean and std applied by img_transform
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])

    def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
        """Load the sample list and store it in ``self.list_sample``.

        Args:
            odgt: a list of records (stored as-is, not copied) or a path
                (``str`` or ``os.PathLike``) to a file with one JSON record
                per line. Blank lines in the file are ignored.
            max_sample: if positive, keep only the first ``max_sample``
                records.
            start_idx: with ``end_idx``, if both are non-negative the list
                is sliced to ``[start_idx:end_idx]`` (applied after
                ``max_sample``).
            end_idx: see ``start_idx``.

        Raises:
            TypeError: if ``odgt`` is neither a list nor a path.
            AssertionError: if the resulting list is empty.
        """
        if isinstance(odgt, list):
            self.list_sample = odgt
        elif isinstance(odgt, (str, os.PathLike)):
            # Context manager guarantees the file handle is released; blank
            # lines (e.g. a trailing newline) are skipped instead of making
            # json.loads fail.
            with open(odgt, 'r') as odgt_file:
                self.list_sample = [
                    json.loads(line.rstrip())
                    for line in odgt_file
                    if line.strip()
                ]
        else:
            raise TypeError(
                'odgt must be a list of records or a path, got {}'.format(
                    type(odgt).__name__))

        if max_sample > 0:
            self.list_sample = self.list_sample[0:max_sample]
        if start_idx >= 0 and end_idx >= 0:     # divide file list
            self.list_sample = self.list_sample[start_idx:end_idx]

        self.num_sample = len(self.list_sample)
        assert self.num_sample > 0
        print('# samples: {}'.format(self.num_sample))

    def img_transform(self, img):
        """Convert an HxWxC image to a normalized CxHxW float tensor.

        The input is cast to float32 and divided by 255, the channel axis is
        moved to the front, and ``self.normalize`` is applied.
        """
        # 0-255 to 0-1
        img = np.float32(np.array(img)) / 255.
        img = img.transpose((2, 0, 1))
        # copy() materializes the transposed view before handing it to torch
        img = self.normalize(torch.from_numpy(img.copy()))
        return img

    def segm_transform(self, segm):
        """Convert a label map to a long tensor with every label shifted by -1.

        Label 0 in the input therefore becomes -1 (e.g. 0..150 -> -1..149).
        """
        # to tensor, -1 to 149
        segm = torch.from_numpy(np.array(segm)).long() - 1
        return segm

    # Round x up to the nearest multiple of p, i.e. the smallest x' >= x with p | x'
    def round2nearest_multiple(self, x, p):
        """Return the smallest multiple of ``p`` that is >= ``x``."""
        return ((x - 1) // p + 1) * p
class TrainDataset(BaseDataset):
    """Training dataset whose every item is a whole padded sub-batch.

    Each ``__getitem__`` call pulls ``batch_per_gpu`` records that share an
    orientation (all ``h > w`` or all ``h <= w``), rescales each image and
    label map, and packs them into zero-initialized batch tensors.

    Args:
        root_dataset: directory joined in front of each record's
            ``fpath_img`` / ``fpath_segm``.
        odgt: sample list or path, handled by ``BaseDataset``.
        opt: options object; additionally ``segm_downsampling_rate`` is read.
        batch_per_gpu: number of samples packed into one returned item.
        **kwargs: forwarded to ``BaseDataset.__init__``.
    """

    def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
        super(TrainDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset
        # down sampling rate of the segmentation label
        self.segm_downsampling_rate = opt.segm_downsampling_rate
        self.batch_per_gpu = batch_per_gpu

        # pending records split by orientation: index 0 for h > w, index 1 for h <= w
        self.batch_record_list = [[], []]

        # read pointer into list_sample; the dataset length is overridden
        # because sampling is driven by this pointer, not by the index
        self.cur_idx = 0
        self.if_shuffled = False

    def _get_sub_batch(self):
        """Consume records until one orientation bucket holds a full batch.

        Returns:
            The list of ``batch_per_gpu`` records from the bucket that filled
            up; that bucket is reset to an empty list.
        """
        while True:
            # fetch the next record and file it under its orientation
            record = self.list_sample[self.cur_idx]
            bucket_id = 0 if record['height'] > record['width'] else 1
            self.batch_record_list[bucket_id].append(record)

            # advance the pointer; on wrap-around start over with a reshuffled list
            self.cur_idx += 1
            if self.cur_idx >= self.num_sample:
                self.cur_idx = 0
                np.random.shuffle(self.list_sample)

            # the h > w bucket is checked first, then the h <= w bucket
            for bucket_id in (0, 1):
                if len(self.batch_record_list[bucket_id]) == self.batch_per_gpu:
                    full_bucket = self.batch_record_list[bucket_id]
                    self.batch_record_list[bucket_id] = []
                    return full_bucket

    def __getitem__(self, index):
        """Build one padded sub-batch.

        Args:
            index: only used to seed numpy's RNG on the first call.

        Returns:
            dict with ``'img_data'`` (float tensor,
            ``batch_per_gpu x 3 x H x W``) and ``'seg_label'`` (long tensor,
            ``batch_per_gpu x H//r x W//r`` with
            ``r = segm_downsampling_rate``).
        """
        # NOTE: random shuffle for the first time. shuffle in __init__ is useless
        if not self.if_shuffled:
            np.random.seed(index)
            np.random.shuffle(self.list_sample)
            self.if_shuffled = True

        # get sub-batch candidates
        records = self._get_sub_batch()

        # pick the short-edge size shared by every image of this batch
        if isinstance(self.imgSizes, (list, tuple)):
            short_size = np.random.choice(self.imgSizes)
        else:
            short_size = self.imgSizes

        # Per-sample target sizes. The batch canvas must be at least as large
        # as the largest sample because several samples are stacked together.
        num_samples = self.batch_per_gpu
        batch_widths = np.zeros(num_samples, np.int32)
        batch_heights = np.zeros(num_samples, np.int32)
        for i in range(num_samples):
            height, width = records[i]['height'], records[i]['width']
            # honour the short-edge target unless the long edge would exceed imgMaxSize
            scale = min(
                short_size / min(height, width),
                self.imgMaxSize / max(height, width))
            batch_widths[i] = width * scale
            batch_heights[i] = height * scale

        # Pad image and segmentation map to h' and w' so that p | h' and p | w'
        batch_width = int(self.round2nearest_multiple(
            batch_widths.max(), self.padding_constant))
        batch_height = int(self.round2nearest_multiple(
            batch_heights.max(), self.padding_constant))

        assert self.padding_constant >= self.segm_downsampling_rate, \
            'padding constant must be equal or large than segm downsamping rate'

        rate = self.segm_downsampling_rate
        batch_images = torch.zeros(num_samples, 3, batch_height, batch_width)
        # NOTE(review): cells never overwritten below keep label 0, whereas the
        # rounding padding added per sample ends up as -1 after segm_transform
        # -- confirm this difference is intended.
        batch_segms = torch.zeros(
            num_samples, batch_height // rate, batch_width // rate).long()

        for i in range(num_samples):
            record = records[i]

            # load image and label
            image_path = os.path.join(self.root_dataset, record['fpath_img'])
            segm_path = os.path.join(self.root_dataset, record['fpath_segm'])
            img = Image.open(image_path).convert('RGB')
            segm = Image.open(segm_path)
            assert segm.mode == "L"
            assert img.size[0] == segm.size[0]
            assert img.size[1] == segm.size[1]

            # random horizontal flip, applied to image and label together
            if np.random.choice([0, 1]):
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                segm = segm.transpose(Image.FLIP_LEFT_RIGHT)

            # each sample within a mini batch has its own scale
            sample_size = (batch_widths[i], batch_heights[i])
            img = imresize(img, sample_size, interp='bilinear')
            segm = imresize(segm, sample_size, interp='nearest')

            # Downsample the label further. It is first pasted onto a canvas
            # whose sides are multiples of the rate to avoid label misalignment.
            canvas_width = self.round2nearest_multiple(segm.size[0], rate)
            canvas_height = self.round2nearest_multiple(segm.size[1], rate)
            canvas = Image.new('L', (canvas_width, canvas_height), 0)
            canvas.paste(segm, (0, 0))
            segm = imresize(
                canvas,
                (canvas.size[0] // rate, canvas.size[1] // rate),
                interp='nearest')

            # image transform, to torch float tensor 3xHxW
            img = self.img_transform(img)

            # segm transform, to torch long tensor HxW
            segm = self.segm_transform(segm)

            # copy into the top-left corner of the batch tensors
            batch_images[i][:, :img.shape[1], :img.shape[2]] = img
            batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm

        output = {
            'img_data': batch_images,
            'seg_label': batch_segms,
        }
        return output

    def __len__(self):
        # Fake length: every loader maintains its own list and read pointer,
        # so the real number of samples (self.num_sample) is not used here.
        return int(1e10)
class ValDataset(BaseDataset):
    """Validation dataset yielding one image at several scales plus its label.

    Args:
        root_dataset: directory joined in front of each record's
            ``fpath_img`` / ``fpath_segm``.
        odgt: sample list or path, handled by ``BaseDataset``.
        opt: options object, handled by ``BaseDataset``.
        **kwargs: forwarded to ``BaseDataset.__init__``.
    """

    def __init__(self, root_dataset, odgt, opt, **kwargs):
        super(ValDataset, self).__init__(odgt, opt, **kwargs)
        self.root_dataset = root_dataset

    def __getitem__(self, index):
        """Load sample ``index`` and resize the image once per short-edge size.

        Returns:
            dict with
            ``'img_ori'``: the RGB image as a numpy array,
            ``'img_data'``: list of ``1 x 3 x H x W`` float tensors, one per
            entry of ``imgSizes``,
            ``'seg_label'``: ``1 x H x W`` long tensor at the original
            resolution (labels shifted by -1 by ``segm_transform``),
            ``'info'``: the record's ``fpath_img``.
        """
        this_record = self.list_sample[index]
        # load image and label
        image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
        segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
        img = Image.open(image_path).convert('RGB')
        segm = Image.open(segm_path)
        assert(segm.mode == "L")
        assert(img.size[0] == segm.size[0])
        assert(img.size[1] == segm.size[1])

        ori_width, ori_height = img.size

        # TrainDataset accepts a single scalar for imgSizes; accept it here
        # too by treating it as a one-element list.
        short_sizes = self.imgSizes
        if np.isscalar(short_sizes):
            short_sizes = [short_sizes]

        img_resized_list = []
        for this_short_size in short_sizes:
            # calculate target height and width: reach the short-edge target
            # unless that would push the long edge beyond imgMaxSize
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # to avoid rounding in network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            # add a leading batch dimension of size 1
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        # segm transform, to torch long tensor HxW
        segm = self.segm_transform(segm)
        batch_segms = torch.unsqueeze(segm, 0)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['seg_label'] = batch_segms.contiguous()
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample
class TestDataset(BaseDataset):
    """Inference dataset yielding one image at several scales, without labels.

    Args:
        odgt: sample list or path, handled by ``BaseDataset``.
        opt: options object, handled by ``BaseDataset``.
        **kwargs: forwarded to ``BaseDataset.__init__``.
    """

    def __init__(self, odgt, opt, **kwargs):
        super(TestDataset, self).__init__(odgt, opt, **kwargs)

    def __getitem__(self, index):
        """Load sample ``index`` and resize the image once per short-edge size.

        The record's ``fpath_img`` is opened as given (no root directory is
        prepended).

        Returns:
            dict with
            ``'img_ori'``: the RGB image as a numpy array,
            ``'img_data'``: list of ``1 x 3 x H x W`` float tensors, one per
            entry of ``imgSizes``,
            ``'info'``: the record's ``fpath_img``.
        """
        this_record = self.list_sample[index]
        # load image
        image_path = this_record['fpath_img']
        img = Image.open(image_path).convert('RGB')

        ori_width, ori_height = img.size

        # TrainDataset accepts a single scalar for imgSizes; accept it here
        # too by treating it as a one-element list.
        short_sizes = self.imgSizes
        if np.isscalar(short_sizes):
            short_sizes = [short_sizes]

        img_resized_list = []
        for this_short_size in short_sizes:
            # calculate target height and width: reach the short-edge target
            # unless that would push the long edge beyond imgMaxSize
            scale = min(this_short_size / float(min(ori_height, ori_width)),
                        self.imgMaxSize / float(max(ori_height, ori_width)))
            target_height, target_width = int(ori_height * scale), int(ori_width * scale)

            # to avoid rounding in network
            target_width = self.round2nearest_multiple(target_width, self.padding_constant)
            target_height = self.round2nearest_multiple(target_height, self.padding_constant)

            # resize images
            img_resized = imresize(img, (target_width, target_height), interp='bilinear')

            # image transform, to torch float tensor 3xHxW
            img_resized = self.img_transform(img_resized)
            # add a leading batch dimension of size 1
            img_resized = torch.unsqueeze(img_resized, 0)
            img_resized_list.append(img_resized)

        output = dict()
        output['img_ori'] = np.array(img)
        output['img_data'] = [x.contiguous() for x in img_resized_list]
        output['info'] = this_record['fpath_img']
        return output

    def __len__(self):
        return self.num_sample