# Source: yinwentao / DockerFile @ 8d34f50 (raw | history | blame, 4.16 kB)
import os.path
import torch
import torch.utils.data as data
from PIL import Image
import random
import utils
import numpy as np
import torchvision.transforms as transforms
from utils_core import flow_viz
import cv2
class DDDataset(data.Dataset):
    """Dataset of (source, target) colour images with their foreground masks.

    The sample list is a text file (``opt.datapath``) with one sample per
    line, made of four whitespace-separated paths::

        src_color  src_mask  tar_color  tar_mask

    Call :meth:`initialize` after construction to load that list.
    """

    def __init__(self):
        super(DDDataset, self).__init__()
        # Filled in by initialize(); defaults keep len() usable before that.
        self.opt = None
        self.dir_txt = None
        self.paths = []
        self.data_size = 0

    def initialize(self, opt):
        """Read the sample list and keep only entries whose files all exist.

        Args:
            opt: options object. ``opt.datapath`` is the list file read here;
                ``opt.width``, ``opt.height``, ``opt.crop_width`` and
                ``opt.crop_height`` are read later by :meth:`process_data`.

        Blank lines are ignored. Lines with fewer than four fields, or that
        reference a missing file, are reported on stdout and skipped.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        # `with` guarantees the list file is closed even if parsing fails.
        with open(self.dir_txt, "r") as in_file:
            for line in in_file:
                fields = line.split()
                if not fields:
                    continue
                if len(fields) < 4:
                    print(line.strip() + " is malformed (expected 4 paths)")
                    continue
                path_list = fields[:4]
                # Report only the first missing file, then drop the sample.
                missing = next(
                    (p for p in path_list if not os.path.exists(p)), None
                )
                if missing is not None:
                    print(missing + " not exists")
                    continue
                self.paths.append(path_list)
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask):
        """Mask out the background and crop/resize around the foreground.

        Args:
            color: H x W x C image array.
            mask: H x W array; non-zero entries mark the foreground.

        Returns:
            A pair ``(crop_color, crop_params)``: the masked colour image
            cropped to the padded mask bounding box and resized to
            ``opt.crop_width`` x ``opt.crop_height``, and the box as
            ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` has no non-zero entry.
        """
        non_zero = mask.nonzero()
        rows, cols = non_zero[0], non_zero[1]
        if rows.size == 0:
            raise ValueError("mask contains no non-zero pixels")

        bound = 10  # padding (in pixels) added around the mask bounding box
        # The box is clamped to opt.width / opt.height, which are assumed to
        # match the image size -- TODO confirm against the caller.
        min_x = max(0, int(cols.min()) - bound)
        max_x = min(self.opt.width - 1, int(cols.max()) + bound)
        min_y = max(0, int(rows.min()) - bound)
        max_y = min(self.opt.height - 1, int(rows.max()) + bound)

        # Zero every background pixel (result is a float array).
        color = color * (mask != 0).astype(float)[:, :, None]
        # NOTE(review): the slice is half-open, so row max_y and column max_x
        # are not part of the crop; kept as-is because consumers of the crop
        # parameters may rely on it.
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(
            np.ascontiguousarray(crop_color),
            (self.opt.crop_width, self.opt.crop_height),
            interpolation=cv2.INTER_LINEAR,
        )
        crop_params = [[min_x], [max_x], [min_y], [max_y]]
        return crop_color, crop_params

    @staticmethod
    def _read_image(path):
        """Load an image file fully into an array and close the file."""
        with Image.open(path) as img:
            return np.array(img)

    def _load_view(self, color_path, mask_path):
        """Load one colour/mask pair and convert it to tensors.

        Returns ``(raw_color, crop_color, mask, crop_params)`` where the two
        colour tensors are C x H x W floats divided by 255, ``mask`` is a
        1 x H x W boolean tensor and ``crop_params`` is the list produced by
        :meth:`process_data`.
        """
        color = self._read_image(color_path).astype(np.uint8)
        mask = self._read_image(mask_path)
        crop_color, crop_params = self.process_data(color, mask)
        # HWC --> CHW, scaled by 1/255.
        raw_color = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_color = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.tensor((mask != 0)[np.newaxis, :, :])
        return raw_color, crop_color, mask_tensor, crop_params

    def __getitem__(self, index):
        """Return the sample at ``index`` (wrapping around the dataset size).

        The result is a dict with keys ``path_flow`` (output file name
        ``<src stem>_<tar stem>.oflow``), ``src_crop_color``,
        ``tar_crop_color``, ``src_color``, ``tar_color``, ``src_mask``,
        ``tar_mask`` and ``Crop_param`` (source box followed by target box,
        each as min_x, max_x, min_y, max_y).
        """
        paths = self.paths[index % self.data_size]
        src_color, src_crop_color, src_mask, src_crop_params = self._load_view(
            paths[0], paths[1]
        )
        tar_color, tar_crop_color, tar_mask, tar_crop_params = self._load_view(
            paths[2], paths[3]
        )
        Crop_param = torch.tensor(src_crop_params + tar_crop_params)

        # Strip directory and extension whatever the extension length is
        # (the former fixed [:-4] slice only suited three-letter extensions).
        src_stem = os.path.splitext(os.path.basename(paths[0]))[0]
        tar_stem = os.path.splitext(os.path.basename(paths[2]))[0]
        path_flow = src_stem + "_" + tar_stem + ".oflow"

        return {
            "path_flow": path_flow,
            "src_crop_color": src_crop_color,
            "tar_crop_color": tar_crop_color,
            "src_color": src_color,
            "tar_color": tar_color,
            "src_mask": src_mask,
            "tar_mask": tar_mask,
            "Crop_param": Crop_param,
        }

    def __len__(self):
        """Return the number of usable samples loaded by initialize()."""
        return self.data_size

    def name(self):
        """Return the dataset's identifying name."""
        return 'DDDataset'