Spaces:
Runtime error
Runtime error
from pix2pix.data.base_dataset import BaseDataset | |
from pix2pix.data.image_folder import make_dataset | |
from pix2pix.util.guidedfilter import GuidedFilter | |
import numpy as np | |
import os | |
import torch | |
from PIL import Image | |
def normalize(img): | |
img = img * 2 | |
img = img - 1 | |
return img | |
def normalize01(img): | |
return (img - torch.min(img)) / (torch.max(img)-torch.min(img)) | |
class DepthMergeDataset(BaseDataset): | |
def __init__(self, opt): | |
BaseDataset.__init__(self, opt) | |
self.dir_outer = os.path.join(opt.dataroot, opt.phase, 'outer') | |
self.dir_inner = os.path.join(opt.dataroot, opt.phase, 'inner') | |
self.dir_gtfake = os.path.join(opt.dataroot, opt.phase, 'gtfake') | |
self.outer_paths = sorted(make_dataset(self.dir_outer, opt.max_dataset_size)) | |
self.inner_paths = sorted(make_dataset(self.dir_inner, opt.max_dataset_size)) | |
self.gtfake_paths = sorted(make_dataset(self.dir_gtfake, opt.max_dataset_size)) | |
self.dataset_size = len(self.outer_paths) | |
if opt.phase == 'train': | |
self.isTrain = True | |
else: | |
self.isTrain = False | |
def __getitem__(self, index): | |
normalize_coef = np.float32(2 ** 16) | |
data_outer = Image.open(self.outer_paths[index % self.dataset_size]) # needs to be a tensor | |
data_outer = np.array(data_outer, dtype=np.float32) | |
data_outer = data_outer / normalize_coef | |
data_inner = Image.open(self.inner_paths[index % self.dataset_size]) # needs to be a tensor | |
data_inner = np.array(data_inner, dtype=np.float32) | |
data_inner = data_inner / normalize_coef | |
if self.isTrain: | |
data_gtfake = Image.open(self.gtfake_paths[index % self.dataset_size]) # needs to be a tensor | |
data_gtfake = np.array(data_gtfake, dtype=np.float32) | |
data_gtfake = data_gtfake / normalize_coef | |
data_inner = GuidedFilter(data_gtfake, data_inner, 64, 0.00000001).smooth.astype('float32') | |
data_outer = GuidedFilter(data_outer, data_gtfake, 64, 0.00000001).smooth.astype('float32') | |
data_outer = torch.from_numpy(data_outer) | |
data_outer = torch.unsqueeze(data_outer, 0) | |
data_outer = normalize01(data_outer) | |
data_outer = normalize(data_outer) | |
data_inner = torch.from_numpy(data_inner) | |
data_inner = torch.unsqueeze(data_inner, 0) | |
data_inner = normalize01(data_inner) | |
data_inner = normalize(data_inner) | |
if self.isTrain: | |
data_gtfake = torch.from_numpy(data_gtfake) | |
data_gtfake = torch.unsqueeze(data_gtfake, 0) | |
data_gtfake = normalize01(data_gtfake) | |
data_gtfake = normalize(data_gtfake) | |
image_path = self.outer_paths[index % self.dataset_size] | |
if self.isTrain: | |
return {'data_inner': data_inner, 'data_outer': data_outer, | |
'data_gtfake': data_gtfake, 'image_path': image_path} | |
else: | |
return {'data_inner': data_inner, 'data_outer': data_outer, 'image_path': image_path} | |
def __len__(self): | |
"""Return the total number of images.""" | |
return self.dataset_size | |