File size: 4,162 Bytes
8d34f50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os.path
import torch
import torch.utils.data as data
from PIL import Image
import random
import utils
import numpy as np
import torchvision.transforms as transforms
from utils_core import flow_viz
import cv2

class DDDataset(data.Dataset):
    """Dataset of paired (source, target) frames with foreground masks.

    The sample list is a text file (``opt.datapath``) with one sample per
    line: four whitespace-separated paths, in the order
    source color, source mask, target color, target mask.
    Call :meth:`initialize` before indexing the dataset.
    """

    def __init__(self):
        super().__init__()
        # Filled in by initialize(); empty defaults keep len() well-defined
        # if the dataset is inspected before initialize() runs.
        self.paths = []
        self.data_size = 0

    def initialize(self, opt):
        """Read the sample list referenced by ``opt.datapath``.

        Blank lines are ignored. Lines with fewer than four paths, or whose
        first four paths do not all exist, are reported on stdout and skipped.
        Fields beyond the fourth are ignored.

        Args:
            opt: options object; this method reads ``opt.datapath``. The other
                methods read ``opt.width``, ``opt.height``, ``opt.crop_width``
                and ``opt.crop_height``.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        with open(self.dir_txt, "r") as in_file:
            for line_no, line in enumerate(in_file, start=1):
                fields = line.split()
                if not fields:
                    # Blank line (commonly a trailing newline) -- nothing to load.
                    continue
                if len(fields) < 4:
                    print(f"{self.dir_txt}:{line_no}: expected 4 paths, "
                          f"got {len(fields)}; skipped")
                    continue
                path_list = fields[:4]
                # Report only the first missing path, as before.
                missing = next(
                    (p for p in path_list if not os.path.exists(p)), None)
                if missing is not None:
                    print(missing + " not exists")
                    continue
                self.paths.append(path_list)
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask, bound=10):
        """Zero the background of ``color`` and crop it to the mask's bounding box.

        Args:
            color: image array indexed as (row, col, channel).
            mask: array whose non-zero entries mark foreground; assumed 2-D
                (rows, cols) -- TODO confirm for all mask files.
            bound: margin in pixels added around the mask's bounding box
                before clamping to ``[0, opt.width - 1]`` / ``[0, opt.height - 1]``.

        Returns:
            ``(crop_color, crop_params)``. ``crop_color`` is the masked crop
            resized by ``cv2.resize`` to ``opt.crop_height`` rows by
            ``opt.crop_width`` columns (float values). ``crop_params`` is
            ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` has no non-zero pixel.
        """
        non_zero = mask.nonzero()
        if non_zero[0].size == 0:
            raise ValueError("empty mask: cannot compute a crop bounding box")
        min_x = max(0, non_zero[1].min() - bound)
        max_x = min(self.opt.width - 1, non_zero[1].max() + bound)
        min_y = max(0, non_zero[0].min() - bound)
        max_y = min(self.opt.height - 1, non_zero[0].max() + bound)
        # Multiplying by the float mask zeroes background pixels (and promotes to float).
        color = color * (mask != 0).astype(float)[:, :, None]
        # NOTE(review): the slice is end-exclusive, so row max_y / column max_x
        # are not included; crop_params records these same bounds.
        crop_color = color[min_y:max_y, min_x:max_x, :]
        # cv2.resize takes dsize as (width, height).
        crop_color = cv2.resize(np.ascontiguousarray(crop_color),
                                (self.opt.crop_width, self.opt.crop_height),
                                interpolation=cv2.INTER_LINEAR)
        crop_params = [[min_x], [max_x], [min_y], [max_y]]
        return crop_color, crop_params

    @staticmethod
    def _flow_name(src_path, tar_path):
        """Return ``"<src stem>_<tar stem>.oflow"`` built from two image paths."""
        src_stem = os.path.splitext(os.path.basename(src_path))[0]
        tar_stem = os.path.splitext(os.path.basename(tar_path))[0]
        return src_stem + "_" + tar_stem + ".oflow"

    def _load_frame(self, color_path, mask_path):
        """Load one color/mask pair and convert it to tensors.

        Returns:
            ``(raw_color, crop_color, mask, crop_params)``:

            - ``raw_color``: CHW float tensor, the full image divided by 255.
            - ``crop_color``: CHW float tensor, the masked crop divided by 255.
            - ``mask``: bool tensor with a leading singleton axis.
            - ``crop_params``: the list returned by :meth:`process_data`.
        """
        with Image.open(color_path) as img:
            color = np.array(img).astype(np.uint8)
        with Image.open(mask_path) as img:
            mask = np.array(img)
        try:
            crop_color, crop_params = self.process_data(color, mask)
        except ValueError as err:
            # Name the offending file so bad samples are easy to locate.
            raise ValueError(f"{mask_path}: {err}") from err
        # HWC --> CHW, scaled to [0, 1] for uint8 input.
        raw_color = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_color = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.tensor((mask != 0)[np.newaxis, :, :])
        return raw_color, crop_color, mask_tensor, crop_params

    def __getitem__(self, index):
        """Return one source/target sample as a dict of tensors.

        Keys:

        - ``src_color`` / ``tar_color``: full images.
        - ``src_crop_color`` / ``tar_crop_color``: masked, resized crops.
        - ``src_mask`` / ``tar_mask``: boolean masks.
        - ``Crop_param``: tensor of shape (8, 1) holding the source
          ``[min_x, max_x, min_y, max_y]`` followed by the target ones.
        - ``path_flow``: ``"<src stem>_<tar stem>.oflow"``; presumably the name
          of a matching optical-flow file -- confirm against callers.

        ``index`` wraps around modulo the number of samples.
        """
        paths = self.paths[index % self.data_size]
        raw_src_color, src_crop_color, src_mask, src_crop_params = \
            self._load_frame(paths[0], paths[1])
        raw_tar_color, tar_crop_color, tar_mask, tar_crop_params = \
            self._load_frame(paths[2], paths[3])

        Crop_param = torch.tensor(src_crop_params + tar_crop_params)
        path1 = self._flow_name(paths[0], paths[2])

        return {"path_flow": path1,
                "src_crop_color": src_crop_color,
                "tar_crop_color": tar_crop_color,
                "src_color": raw_src_color,
                "tar_color": raw_tar_color,
                "src_mask": src_mask,
                "tar_mask": tar_mask,
                "Crop_param": Crop_param}

    def __len__(self):
        """Number of valid samples read by initialize()."""
        return self.data_size

    def name(self):
        """Return the dataset identifier string."""
        return 'DDDataset'