File size: 2,263 Bytes
7634bde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform
from PIL import Image
import linecache

class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        self.fine_height=256
        self.fine_width=192

        self.text = './test_pairs.txt'

        dir_I = '_ma_img'
        self.dir_I = os.path.join(opt.dataroot, opt.phase + dir_I)

        dir_C = '_clothes'
        self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)

        dir_E = '_edge'
        self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)

        self.im_name = []
        self.c_name = []
        self.e_name = []
        self.get_file_name()
        #import ipdb; ipdb.set_trace()
        self.dataset_size = len(self.im_name)

    def get_file_name(self):

        with open(self.text, 'r') as f:
            for line in f.readlines():
                im_name, c_name = line.strip().split()
                self.im_name.append(os.path.join(self.dir_I, im_name))
                self.c_name.append(os.path.join(self.dir_C, c_name))
                self.e_name.append(os.path.join(self.dir_E, c_name))

    def __getitem__(self, index):        

        #file_path ='demo.txt'
        #im_name, c_name = linecache.getline(file_path, index+1).strip().split()

        I_path = os.path.join(self.im_name[index])
        I = Image.open(I_path).convert('RGB')

        params = get_params(self.opt, I.size)
        transform = get_transform(self.opt, params)
        transform_E = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)

        I_tensor = transform(I)
        #import ipdb; ipdb.set_trace()
        C_path = os.path.join(self.c_name[index])
        #print(self.c_name[index])
        C = Image.open(C_path).convert('RGB')
        C_tensor = transform(C)

        E_path = os.path.join(self.e_name[index])
        E = Image.open(E_path).convert('L')
        E_tensor = transform_E(E)

        input_dict = { 'image': I_tensor,'clothes': C_tensor, 'edge': E_tensor, 'p_name':self.im_name[index].split('/')[-1]}
        return input_dict

    def __len__(self):
        return self.dataset_size 

    def name(self):
        return 'AlignedDataset'