# yinwentao
# DockerFile
# 8d34f50
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
class BaseDataset(data.Dataset):
    """Minimal base class for datasets built on ``torch.utils.data.Dataset``.

    It supplies only placeholder behavior: ``name()`` returns
    ``'BaseDataset'``, ``initialize(opt)`` does nothing, and the reported
    length is 0. Concrete datasets presumably override these (and provide
    ``__getitem__``) — confirm against the subclasses.
    """

    def __init__(self):
        # The explicit two-argument ``super`` form works on Python 2 and 3.
        super(BaseDataset, self).__init__()

    def name(self):
        """Return this dataset's name string."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Configure the dataset from ``opt``; a no-op in the base class."""
        return None

    def __len__(self):
        """Return the number of samples, which is always 0 here."""
        return 0
def get_transform(opt):
    """Build the image preprocessing pipeline described by ``opt``.

    ``opt.resize_or_crop`` selects the geometric stage:

    - ``'resize_and_crop'``: bicubic resize to ``[opt.loadSize, opt.loadSize]``,
      then ``RandomCrop(opt.fineSize)``.
    - ``'crop'``: ``RandomCrop(opt.fineSize)`` only.
    - ``'scale_width'``: scale to width ``opt.fineSize`` via ``__scale_width``.
    - ``'scale_width_and_crop'``: scale to width ``opt.loadSize`` via
      ``__scale_width``, then ``RandomCrop(opt.fineSize)``.
    - ``'none'``: round each side up to a multiple of 4 via ``__adjust``.

    When ``opt.isTrain`` is truthy and ``opt.no_flip`` is falsy, a
    ``RandomHorizontalFlip`` follows. The pipeline always ends with
    ``ToTensor`` and ``Normalize`` with mean 0.5 and std 0.5 on three channels.

    Args:
        opt: options object; reads ``resize_or_crop``, ``loadSize``,
            ``fineSize``, ``isTrain`` and ``no_flip``.

    Returns:
        A ``transforms.Compose`` of the selected steps.

    Raises:
        ValueError: if ``opt.resize_or_crop`` names no supported mode.
    """
    mode = opt.resize_or_crop

    if mode == 'resize_and_crop':
        pipeline = [
            transforms.Resize([opt.loadSize, opt.loadSize], Image.BICUBIC),
            transforms.RandomCrop(opt.fineSize),
        ]
    elif mode == 'crop':
        pipeline = [transforms.RandomCrop(opt.fineSize)]
    elif mode == 'scale_width':
        # ``opt.fineSize`` is looked up each time the transform runs.
        pipeline = [transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize))]
    elif mode == 'scale_width_and_crop':
        pipeline = [
            transforms.Lambda(lambda img: __scale_width(img, opt.loadSize)),
            transforms.RandomCrop(opt.fineSize),
        ]
    elif mode == 'none':
        pipeline = [transforms.Lambda(lambda img: __adjust(img))]
    else:
        raise ValueError('--resize_or_crop %s is not a valid option.' % mode)

    if opt.isTrain and not opt.no_flip:
        pipeline.append(transforms.RandomHorizontalFlip())

    pipeline.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transforms.Compose(pipeline)
def __adjust(img):
    """Resize ``img`` so its width and height are multiples of 4.

    The size needs to be a multiple of 4 because going through the generator
    network may change the image size and eventually cause a size-mismatch
    error. Each side is rounded *up* to the next multiple of 4.

    Args:
        img: presumably a PIL image; only ``.size`` (unpacked as
            ``(width, height)``) and ``.resize(size, resample)`` are used.

    Returns:
        ``img`` itself when both sides are already multiples of 4. Otherwise,
        a bicubic resize to the rounded-up size, after emitting the one-time
        size warning.
    """
    mult = 4
    ow, oh = img.size
    if ow % mult == 0 and oh % mult == 0:
        return img
    # Ceil to the next multiple of ``mult`` (negated floor division == ceil).
    w = -(-ow // mult) * mult
    h = -(-oh // mult) * mult
    # The aligned case returned above, so at least one side changed here and
    # the warning always applies (the former ``ow != w or oh != h`` guard was
    # always true).
    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), Image.BICUBIC)
def __scale_width(img, target_width):
    """Resize ``img`` to width ``target_width``, keeping its aspect ratio.

    The new height is the aspect-preserving height rounded up to a multiple
    of 4, and never less than 4. The multiple-of-4 requirement exists because
    going through the generator network may change the image size and
    eventually cause a size-mismatch error.

    Args:
        img: presumably a PIL image; only ``.size`` (unpacked as
            ``(width, height)``) and ``.resize(size, resample)`` are used.
        target_width: desired output width; must be a multiple of 4.

    Returns:
        ``img`` unchanged when its width already equals ``target_width`` and
        its height is a multiple of 4. Otherwise, a bicubic resize to
        ``(target_width, h)``.

    Raises:
        AssertionError: if ``target_width`` is not a multiple of 4. Note that
            ``assert`` checks are stripped under ``python -O``.
    """
    ow, oh = img.size
    mult = 4
    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
    if (ow == target_width and oh % mult == 0):
        return img
    w = target_width
    target_height = int(target_width * oh / ow)
    # Round up to the next multiple of ``mult``. Clamp to at least ``mult``:
    # for a very wide image the scaled height truncates to 0, and rounding 0
    # up stays 0. That would request a zero-height resize, which Pillow
    # rejects.
    h = max(mult, -(-target_height // mult) * mult)
    if target_height != h:
        __print_size_warning(target_width, target_height, w, h)
    return img.resize((w, h), Image.BICUBIC)
def __print_size_warning(ow, oh, w, h):
    """Print a one-time notice that an image size was rounded to multiples of 4.

    The notice is printed to stdout on the first call only. A ``has_printed``
    attribute set on this function object records that it was shown, so later
    calls return silently.

    Args:
        ow, oh: the size reported as the loaded size.
        w, h: the size it was adjusted to.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    template = ("The image size needs to be a multiple of 4. "
                "The loaded image size was (%d, %d), so it was adjusted to "
                "(%d, %d). This adjustment will be done to all images "
                "whose sizes are not multiples of 4")
    print(template % (ow, oh, w, h))
    __print_size_warning.has_printed = True