import os.path
import random

import cv2
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

import utils
from utils_core import flow_viz

class DDDataset(data.Dataset):

    def __init__(self):
        super(DDDataset, self).__init__()

    def initialize(self, opt):
        """Read the data list file at opt.datapath.

        Each line is expected to contain four whitespace-separated paths:
        source color image, source mask, target color image, target mask.
        Lines that reference missing files are skipped.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []

        with open(self.dir_txt, "r") as in_file:
            for line in in_file.readlines():
                line = line.strip().split()
                if len(line) < 4:
                    continue

                # Skip this sample if any of the four files is missing.
                missing = False
                for path in line[:4]:
                    if not os.path.exists(path):
                        print(path + " does not exist")
                        missing = True
                        break
                if missing:
                    continue

                self.paths.append([line[0], line[1], line[2], line[3]])

        self.data_size = len(self.paths)
        print("num data: ", self.data_size)

    def process_data(self, color, mask):
        """Mask out the background, crop a padded bounding box around the
        foreground, and resize the crop to (crop_width, crop_height)."""
        non_zero = mask.nonzero()
        bound = 10  # padding (in pixels) around the mask's bounding box

        min_x = max(0, non_zero[1].min() - bound)
        max_x = min(self.opt.width - 1, non_zero[1].max() + bound)
        min_y = max(0, non_zero[0].min() - bound)
        max_y = min(self.opt.height - 1, non_zero[0].max() + bound)

        # Zero out pixels outside the mask, then crop and resize.
        color = color * (mask != 0).astype(float)[:, :, None]
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(np.ascontiguousarray(crop_color),
                                (self.opt.crop_width, self.opt.crop_height),
                                interpolation=cv2.INTER_LINEAR)
        crop_params = [[min_x], [max_x], [min_y], [max_y]]

        return crop_color, crop_params

    def __getitem__(self, index):
        paths = self.paths[index % self.data_size]

        # Source frame: color image and foreground mask.
        src_color = np.array(Image.open(paths[0])).astype(np.uint8)
        raw_src_color = src_color.copy()
        src_mask = np.array(Image.open(paths[1]))
        src_mask_copy = src_mask.copy()
        src_crop_color, src_crop_params = self.process_data(src_color, src_mask)

        raw_src_color = torch.from_numpy(raw_src_color).permute(2, 0, 1).float() / 255.0
        src_crop_color = torch.from_numpy(src_crop_color).permute(2, 0, 1).float() / 255.0

        src_mask_copy = (src_mask_copy != 0)
        src_mask_copy = torch.tensor(src_mask_copy[np.newaxis, :, :])

        # Target frame: color image and foreground mask.
        tar_color = np.array(Image.open(paths[2])).astype(np.uint8)
        raw_tar_color = tar_color.copy()
        tar_mask = np.array(Image.open(paths[3]))
        tar_mask_copy = tar_mask.copy()
        tar_crop_color, tar_crop_params = self.process_data(tar_color, tar_mask)

        raw_tar_color = torch.from_numpy(raw_tar_color).permute(2, 0, 1).float() / 255.0
        tar_crop_color = torch.from_numpy(tar_crop_color).permute(2, 0, 1).float() / 255.0

        tar_mask_copy = (tar_mask_copy != 0)
        tar_mask_copy = torch.tensor(tar_mask_copy[np.newaxis, :, :])

        # Stack the source and target crop boxes into a single (8, 1) tensor.
        Crop_param = torch.tensor(src_crop_params + tar_crop_params)

        # Name of the output flow file: "<src_stem>_<tar_stem>.oflow".
        src_stem = os.path.splitext(os.path.basename(paths[0]))[0]
        tar_stem = os.path.splitext(os.path.basename(paths[2]))[0]
        path1 = src_stem + "_" + tar_stem + ".oflow"

        return {"path_flow": path1,
                "src_crop_color": src_crop_color, "tar_crop_color": tar_crop_color,
                "src_color": raw_src_color, "tar_color": raw_tar_color,
                "src_mask": src_mask_copy, "tar_mask": tar_mask_copy,
                "Crop_param": Crop_param}

    def __len__(self):
        return self.data_size

    def name(self):
        return 'DDDataset'
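

# A minimal usage sketch, not part of the original pipeline: it assumes an `opt`
# object exposing the attributes read above (datapath, width, height,
# crop_width, crop_height); the project's real option parser may differ, and the
# list-file path and option values below are purely illustrative.
if __name__ == "__main__":
    from argparse import Namespace
    from torch.utils.data import DataLoader

    opt = Namespace(datapath="data/train_list.txt",  # hypothetical list file
                    width=640, height=480,
                    crop_width=256, crop_height=256)

    dataset = DDDataset()
    dataset.initialize(opt)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    for batch in loader:
        # Each entry mirrors the dictionary returned by __getitem__.
        print(batch["path_flow"],
              batch["src_crop_color"].shape,  # (B, 3, crop_height, crop_width)
              batch["Crop_param"].shape)      # (B, 8, 1)
        break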