import random
import numpy as np
import torch.utils.data as data
import utils.utils_image as util



class DatasetPlainPatch(data.Dataset):
    '''
    # -----------------------------------------
    # Get L/H for image-to-image mapping.
    # Both "paths_L" and "paths_H" are needed.
    # -----------------------------------------
    # e.g., train denoiser with L and H patches
    # create a large patch dataset first
    # -----------------------------------------
    '''

    def __init__(self, opt):
        super(DatasetPlainPatch, self).__init__()
        print('Get L/H for image-to-image mapping. Both "paths_L" and "paths_H" are needed.')
        self.opt = opt
        self.n_channels = opt['n_channels'] if opt['n_channels'] else 3
        self.patch_size = self.opt['H_size'] if self.opt['H_size'] else 64

        self.num_patches_per_image = opt['num_patches_per_image'] if opt['num_patches_per_image'] else 40
        self.num_sampled = opt['num_sampled'] if opt['num_sampled'] else 3000

        # -------------------
        # get the path of L/H
        # -------------------
        self.paths_H = util.get_image_paths(opt['dataroot_H'])
        self.paths_L = util.get_image_paths(opt['dataroot_L'])

        assert self.paths_H, 'Error: H path is empty.'
        assert self.paths_L, 'Error: L path is empty. This dataset requires an L path; use dataset_dnpatch if only H images are available.'
        if self.paths_L and self.paths_H:
            assert len(self.paths_L) == len(self.paths_H), 'H and L datasets have different number of images - {}, {}.'.format(len(self.paths_L), len(self.paths_H))

        # ------------------------------------
        # number of sampled images
        # ------------------------------------
        self.num_sampled = min(self.num_sampled, len(self.paths_H))

        # ------------------------------------
        # reserve space with zeros
        # ------------------------------------
        self.total_patches = self.num_sampled * self.num_patches_per_image
        self.H_data = np.zeros([self.total_patches, self.patch_size, self.patch_size, self.n_channels], dtype=np.uint8)
        self.L_data = np.zeros([self.total_patches, self.patch_size, self.patch_size, self.n_channels], dtype=np.uint8)

        # ------------------------------------
        # update H patches
        # ------------------------------------
        self.update_data()


    def update_data(self):
        """
        # ------------------------------------
        # update whole L/H patches
        # ------------------------------------
        """
        self.index_sampled = random.sample(range(len(self.paths_H)), self.num_sampled)
        n_count = 0

        for idx in self.index_sampled:
            L_patches, H_patches = self.get_patches(idx)
            for (L_patch, H_patch) in zip(L_patches, H_patches):
                self.L_data[n_count, :, :, :] = L_patch
                self.H_data[n_count, :, :, :] = H_patch
                n_count += 1

        print('Training data updated! Total number of patches is: %d X %d = %d\n' % (self.num_sampled, self.num_patches_per_image, len(self.H_data)))

    def get_patches(self, index):
        """
        # ------------------------------------
        # get L/H patches from L/H images
        # ------------------------------------
        """
        L_path = self.paths_L[index]
        H_path = self.paths_H[index]
        img_L = util.imread_uint(L_path, self.n_channels)  # uint format
        img_H = util.imread_uint(H_path, self.n_channels)  # uint format

        H, W = img_H.shape[:2]

        L_patches, H_patches = [], []

        num = self.num_patches_per_image
        for _ in range(num):
            rnd_h = random.randint(0, max(0, H - self.patch_size))
            rnd_w = random.randint(0, max(0, W - self.patch_size))
            L_patch = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :]
            H_patch = img_H[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :]
            L_patches.append(L_patch)
            H_patches.append(H_patch)

        return L_patches, H_patches

    def __getitem__(self, index):

        if self.opt['phase'] == 'train':
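            # --------------------------------
            # take a pre-cropped L/H patch pair
            # --------------------------------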

            patch_L, patch_H = self.L_data[index], self.H_data[index]

            # --------------------------------
            # augmentation - flip and/or rotate
            # --------------------------------
            mode = random.randint(0, 7)
            patch_L = util.augment_img(patch_L, mode=mode)
            patch_H = util.augment_img(patch_H, mode=mode)

            patch_L, patch_H = util.uint2tensor3(patch_L), util.uint2tensor3(patch_H)

        else:
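            # --------------------------------
            # read full L/H images for testing
            # --------------------------------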

            L_path, H_path = self.paths_L[index], self.paths_H[index]
            patch_L = util.imread_uint(L_path, self.n_channels)
            patch_H = util.imread_uint(H_path, self.n_channels)

            patch_L, patch_H = util.uint2tensor3(patch_L), util.uint2tensor3(patch_H)

        return {'L': patch_L, 'H': patch_H}


    def __len__(self):
        return self.total_patches
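

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Minimal usage sketch (illustration only): assumes a KAIR-style opt
    # dict; 'dataroot_H'/'dataroot_L' below are placeholder paths and
    # must point to existing folders of paired H/L images.
    # ------------------------------------------------------------------
    from torch.utils.data import DataLoader

    opt = {'phase': 'train',
           'n_channels': 3,
           'H_size': 64,
           'num_patches_per_image': 40,
           'num_sampled': 100,
           'dataroot_H': 'trainsets/trainH',  # placeholder path
           'dataroot_L': 'trainsets/trainL'}  # placeholder path

    train_set = DatasetPlainPatch(opt)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)

    for batch in train_loader:
        print(batch['L'].shape, batch['H'].shape)  # e.g. torch.Size([16, 3, 64, 64]) each
        break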