yinwentao
DockerFile
8d34f50
import os.path
import torch
import torch.utils.data as data
from PIL import Image
import random
import utils
import numpy as np
import torchvision.transforms as transforms
from utils_core import flow_viz
import cv2
class DDDataset(data.Dataset):
    """Dataset of source/target image pairs described by a text list file.

    Each usable line of the list file holds four whitespace-separated paths:
    source colour image, source mask, target colour image, target mask.
    Call ``initialize(opt)`` after construction to load the list.
    """

    def __init__(self):
        super(DDDataset, self).__init__()
        # Safe defaults so len() works before initialize() has been called.
        self.opt = None
        self.paths = []
        self.data_size = 0

    def initialize(self, opt):
        """Read the list file at ``opt.datapath`` and collect valid entries.

        Blank lines are ignored. Lines with fewer than four fields are
        reported and skipped. A line is also skipped (after printing the
        first missing path) when any of its first four paths does not exist.
        Fields beyond the fourth are ignored.

        Args:
            opt: options object; ``opt.datapath`` is read here, and
                ``opt.width``, ``opt.height``, ``opt.crop_width`` and
                ``opt.crop_height`` are read later by ``process_data``.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        # Context manager guarantees the file is closed even on error.
        with open(self.dir_txt, "r") as in_file:
            for line in in_file:
                fields = line.split()
                if not fields:
                    continue
                if len(fields) < 4:
                    print("skipping malformed line (expected 4 paths): " + line.strip())
                    continue
                path_list = fields[:4]
                # Report only the first missing path, then drop the entry.
                missing = next(
                    (p for p in path_list if not os.path.exists(p)), None
                )
                if missing is not None:
                    print(missing + " not exists")
                    continue
                self.paths.append(path_list)
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask):
        """Mask out the background and crop/resize around the mask's extent.

        Args:
            color: image array indexed as [row, column, channel].
            mask: array whose non-zero entries mark the region to keep; it
                must broadcast against ``color`` after a trailing axis is
                added (i.e. a 2-D mask for a 3-D colour image).

        Returns:
            A tuple ``(crop_color, crop_params)`` where ``crop_color`` is the
            masked crop resized to ``(opt.crop_width, opt.crop_height)`` and
            ``crop_params`` is ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` contains no non-zero element.
        """
        non_zero = mask.nonzero()
        if non_zero[0].size == 0:
            # Without this guard numpy fails with an opaque
            # "zero-size array to reduction operation" ValueError.
            raise ValueError("mask has no non-zero pixels; cannot compute a crop box")
        bound = 10  # margin, in pixels, added around the mask's bounding box
        min_x = max(0, non_zero[1].min() - bound)
        max_x = min(self.opt.width - 1, non_zero[1].max() + bound)
        min_y = max(0, non_zero[0].min() - bound)
        max_y = min(self.opt.height - 1, non_zero[0].max() + bound)
        # Zero every pixel outside the mask (creates a new float array).
        color = color * (mask != 0).astype(float)[:, :, None]
        # Slice ends are exclusive, so row max_y / column max_x are not included.
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(
            np.ascontiguousarray(crop_color),
            (self.opt.crop_width, self.opt.crop_height),
            interpolation=cv2.INTER_LINEAR,
        )
        crop_params = [[min_x], [max_x], [min_y], [max_y]]
        return crop_color, crop_params

    def _load_view(self, color_path, mask_path):
        """Load one colour/mask pair and convert it to tensors.

        Returns ``(raw_color, crop_color, mask_tensor, crop_params)``: the
        full image and the masked crop as float tensors scaled by 1/255 with
        the channel axis moved first, the boolean mask with a leading
        singleton axis, and the crop parameters from ``process_data``.
        """
        color = np.array(Image.open(color_path)).astype(np.uint8)
        mask = np.array(Image.open(mask_path))
        # process_data does not modify its inputs in place, so no defensive
        # copies of color/mask are needed before calling it.
        crop_color, crop_params = self.process_data(color, mask)
        # HWC --> CHW
        raw_color = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_color = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.tensor((mask != 0)[np.newaxis, :, :])
        return raw_color, crop_color, mask_tensor, crop_params

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` (wrapped modulo the dataset size).

        Returns:
            dict with keys ``path_flow``, ``src_crop_color``,
            ``tar_crop_color``, ``src_color``, ``tar_color``, ``src_mask``,
            ``tar_mask`` and ``Crop_param`` (source crop parameters followed
            by target crop parameters, as one tensor).

        Raises:
            IndexError: if the dataset holds no entries.
        """
        if self.data_size == 0:
            raise IndexError("DDDataset is empty; call initialize() with a valid list file")
        paths = self.paths[index % self.data_size]

        raw_src_color, src_crop_color, src_mask, src_crop_params = self._load_view(
            paths[0], paths[1]
        )
        raw_tar_color, tar_crop_color, tar_mask, tar_crop_params = self._load_view(
            paths[2], paths[3]
        )

        Crop_param = torch.tensor(src_crop_params + tar_crop_params)
        # "<source stem>_<target stem>.oflow"
        path1 = self._stem(paths[0]) + "_" + self._stem(paths[2]) + ".oflow"
        return {
            "path_flow": path1,
            "src_crop_color": src_crop_color,
            "tar_crop_color": tar_crop_color,
            "src_color": raw_src_color,
            "tar_color": raw_tar_color,
            "src_mask": src_mask,
            "tar_mask": tar_mask,
            "Crop_param": Crop_param,
        }

    def __len__(self):
        """Return the number of valid entries loaded by ``initialize``."""
        return self.data_size

    def name(self):
        """Return the dataset's identifying name."""
        return 'DDDataset'