JiminHeo commited on
Commit
c429825
·
1 Parent(s): 2731600
Files changed (3) hide show
  1. utils/helper.py +259 -0
  2. utils/logger.py +12 -0
  3. utils/mask_generator.py +198 -0
utils/helper.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ from ldm.util import default
6
+ import glob
7
+ import PIL
8
+ import matplotlib.pyplot as plt
9
+
10
+ def load_file(filename):
11
+ with open(filename , 'rb') as file:
12
+ x = pickle.load(file)
13
+ return x
14
+
15
+ def save_file(filename, x, mode="wb"):
16
+ with open(filename, mode) as file:
17
+ pickle.dump(x, file)
18
+
19
+ def normalize_np(img):
20
+ """ Normalize img in arbitrary range to [0, 1] """
21
+ img -= np.min(img)
22
+ img /= np.max(img)
23
+ return img
24
+
25
+ def clear_color(x):
26
+ if torch.is_complex(x):
27
+ x = torch.abs(x)
28
+ x = x.detach().cpu().squeeze().numpy()
29
+ return normalize_np(np.transpose(x, (1, 2, 0)))
30
+
31
+ def to_img(sample):
32
+ return (sample.detach().cpu().numpy().transpose(0,2,3,1) * 127.5 + 128).clip(0, 255)
33
+
34
+ def save_plot(dir_name, tensors, labels, file_name="loss.png"):
35
+ t = np.linspace(0, len(tensors[0]), len(tensors[0]))
36
+ colours = ["r", "b", "g"]
37
+ plt.figure()
38
+ for j in range(len(tensors)):
39
+ plt.plot(t, tensors[j],color = colours[j], label = labels[j])
40
+ plt.legend()
41
+ plt.savefig(os.path.join(dir_name, file_name))
42
+ #plt.show()
43
+
44
+ def save_samples(dir_name, sample, k=None, num_to_save = 5, file_name = None):
45
+ if type(sample) is not np.ndarray: sample_np = to_img(sample).astype(np.uint8)
46
+ else: sample_np = sample.astype(np.uint8)
47
+
48
+ for j in range(num_to_save):
49
+ if file_name is None:
50
+ if k is not None: file_name_img = f'sample_{k+1}'f'{j}.png'
51
+ else: file_name_img = f'{j}.png'
52
+ else: file_name_img = file_name
53
+ image_path = os.path.join(dir_name,file_name_img)
54
+ image_np = sample_np[j]
55
+ PIL.Image.fromarray(image_np, 'RGB').save(image_path)
56
+ file_name_img = None
57
+
58
+ def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save = 5, file_name = None):
59
+ recon_in = y*(mask_pixel) + ( 1-mask_pixel)*sample
60
+ recon_in = to_img(recon_in)
61
+ for j in range(num_to_save):
62
+ if file_name is None:
63
+ if k is not None: file_name_img = f'sample_{k+1}'f'{j}.png'
64
+ else: file_name_img = f'{j}.png'
65
+ else: file_name_img = file_name
66
+ image_path = os.path.join(dir_name, file_name_img)
67
+ image_np = recon_in.astype(np.uint8)[j]
68
+ PIL.Image.fromarray(image_np, 'RGB').save(image_path)
69
+ file_name_img = None
70
+
71
+ def save_params(dir_name, mu_pos, logvar_pos, gamma,k):
72
+ params_to_fit = params_untrain([mu_pos.detach().cpu(), logvar_pos.detach().cpu(), gamma.detach().cpu()])
73
+ params_path = os.path.join(dir_name, f'{k+1}.pt')
74
+ torch.save(params_to_fit, params_path)
75
+
76
+ def custom_to_np(img):
77
+ sample = img.detach().cpu()
78
+ #sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
79
+ #sample = sample.permute(0, 2, 3, 1)
80
+ sample = sample.contiguous()
81
+ return sample
82
+
83
+ def encoder_kl(diff, img):
84
+ _, params = diff.encode_first_stage(img, return_all = True)
85
+ params = diff.scale_factor * params
86
+ mean, logvar = torch.chunk(params, 2, dim=1)
87
+ noise = default(None, lambda: torch.randn_like(mean))
88
+ mean = mean + diff.scale_factor*noise
89
+ return mean, logvar
90
+
91
+ def encoder_vq(diff, img):
92
+ quant = diff.encode_first_stage(img) #, diff, (_,_,ind)
93
+ quant = diff.scale_factor * quant
94
+ #mean, logvar = torch.chunk(params, 2, dim=1)
95
+ noise = default(None, lambda: torch.randn_like(quant))
96
+ mean = quant + diff.scale_factor*noise #
97
+ return mean
98
+
99
+ def clean_directory(dir_name):
100
+ files = glob.glob(dir_name)
101
+ for f in files:
102
+ os.remove(f)
103
+
104
+ def params_train( params ):
105
+ for item in params:
106
+ item.requires_grad = True
107
+ return params
108
+
109
+ def params_untrain(params):
110
+ for item in params:
111
+ item.requires_grad = False
112
+ return params
113
+
114
+ def time_descretization(sigma_min=0.002, sigma_max = 80, rho = 7, num_t_steps = 18):
115
+ step_indices = torch.arange(num_t_steps, dtype=torch.float64).cuda()
116
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_t_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
117
+ inv_idx = torch.arange(num_t_steps -1, -1, -1).long()
118
+ t_steps_fwd = t_steps[inv_idx]
119
+ #t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
120
+ return t_steps_fwd
121
+
122
+ def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01) :
123
+ [lr, step_size, gamma] = [0.1, 10, 0.99] #was 0.999 for right-half: [0.01, 10, 0.99]
124
+ optimizer = torch.optim.Adam([means], lr=lr, betas=(0.9, 0.99))
125
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
126
+
127
+ optimizer_2 = torch.optim.Adam([variances], lr=0.001, betas=(0.9, 0.99)) #0.001 for lsun
128
+ optimizer_3 = torch.optim.Adam([gamma_param], lr=lr_init_gamma, betas=(0.9, 0.99)) #0.01
129
+
130
+ scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=step_size, gamma=gamma) ##added this
131
+ scheduler_3 = torch.optim.lr_scheduler.StepLR(optimizer_3, step_size=step_size, gamma=gamma)
132
+
133
+ return [optimizer, optimizer_2, optimizer_3 ], [scheduler, scheduler_2, scheduler_3]
134
+
135
+ def check_directory(filename_list):
136
+ for filename in filename_list:
137
+ if not os.path.exists(filename):
138
+ os.mkdir(filename)
139
+
140
+ def s_file(filename, x, mode="wb"):
141
+ with open(filename, mode) as file:
142
+ pickle.dump(x, file)
143
+
144
+ def r_file(filename, mode="rb"):
145
+ with open(filename, mode) as file:
146
+ x = pickle.load(file)
147
+ return x
148
+
149
+ def sample_from_gaussian(mu, alpha, sigma):
150
+ noise = torch.randn_like(mu)
151
+ return alpha*mu + sigma * noise
152
+
153
+ '''
154
+ def make_batch(image, mask=None, device=None):
155
+ image = torch.permute(image, (0,3,1,2))
156
+ batch_size = image.shape[0]
157
+ if mask is None :
158
+ mask = torch.zeros_like(image)
159
+ mask[0, :, :256, :128] = 1
160
+ else :
161
+ mask = torch.tensor(mask)
162
+ masked_image = (mask)*image #+ mask*noise*0.2
163
+ mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
164
+ batch = {"image": image, "mask": mask, "masked_image": masked_image}
165
+ for k in batch:
166
+ batch[k] = batch[k].to(device)
167
+ return batch
168
+
169
+ def get_sigma_t_steps(net, n_steps=3, kwargs=None):
170
+ sigma_min = kwargs["sigma_min"]
171
+ sigma_max = kwargs["sigma_max"]
172
+ sigma_min = max(sigma_min, net.sigma_min)
173
+ sigma_max = min(sigma_max, net.sigma_max)
174
+
175
+ ##Get the time-steps based on iddpm discretization
176
+ num_steps = n_steps #11 # kwargs["num_steps"]
177
+ C_2 = kwargs["C_2"]
178
+ C_1 = kwargs["C_1"]
179
+ M = kwargs["M"]
180
+ step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
181
+ u = torch.zeros(M + 1, dtype=torch.float64).cuda()
182
+ alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
183
+ for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
184
+ u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
185
+ u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
186
+ sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
187
+ #print(sigma_steps)
188
+
189
+ ##get noise schedule
190
+ sigma = lambda t: t
191
+ sigma_deriv = lambda t: 1
192
+ sigma_inv = lambda sigma: sigma
193
+
194
+ ##scaling schedule
195
+ s = lambda t: 1
196
+ s_deriv = lambda t: 0
197
+
198
+ ##compute some final time steps based on the corresponding noise levels.
199
+ t_steps = sigma_inv(net.round_sigma(sigma_steps))
200
+
201
+ return t_steps, sigma_inv, sigma, s, sigma_deriv
202
+
203
+ def data_replicate(data, K):
204
+ if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1])
205
+ else: data_batch = torch.Tensor.repeat(data,[K,1,1,1])
206
+ return data_batch
207
+
208
+ '''
209
+
210
+
211
+ def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None):
212
+ '''
213
+ sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
214
+ T_max = 1000
215
+ beta_start = 1 # 0.0015*T_max
216
+ beta_end = 15 # 0.0155*T_max
217
+ def var(t):
218
+ return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
219
+ '''
220
+ t_steps_hierarchy = torch.tensor(t_steps_hierarchy).cuda()
221
+ var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1 ,1 ,1))**2 # self.var(t_steps_hierarchy[0])
222
+ x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
223
+
224
+ os.makedirs("out_temp2/", exist_ok=True)
225
+ for i, t in enumerate(t_steps_hierarchy):
226
+ t_hat = torch.ones(10).cuda() * (t)
227
+ e_out = self.model.model(x_t, t_hat)
228
+ var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1 ,1 ,1))**2
229
+ #score_out = - e_out / torch.sqrt()
230
+ a_t = 1 - var_t
231
+ #beta_t = 1 - a_t/a_prev
232
+ #std_pos = ((1 - a_prev)/(1 - a_t)).sqrt()*torch.sqrt(beta_t)
233
+ pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
234
+
235
+ if i != len(t_steps_hierarchy) - 1:
236
+ var_t1 = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[i+1]].reshape(1, 1 ,1 ,1))**2
237
+ a_prev = 1 - var_t1 # var(t_steps_hierarchy[i+1]/1000) # torch.full((10, 1, 1, 1), alphas[t_steps_hierarchy[i+1]]).cuda()
238
+ sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
239
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
240
+ x_t = a_prev.sqrt() * pred_x0 + dir_xt + torch.randn_like(x_t) * sigma_t + sigma_t*torch.randn_like(x_t)
241
+
242
+ #x_t= (x_t - torch.sqrt( 1 - a_t/a_prev) * e_out ) / (a_t/a_prev).sqrt() + std_pos*torch.randn_like(x_t)
243
+
244
+ '''
245
+ def pred_mean(pred_x0, z_t):
246
+ posterior_mean_coef1 = beta_t * torch.sqrt(a_prev) / (1. - a_t)
247
+ posterior_mean_coef2 = (1. - a_prev) * torch.sqrt(a_t/a_prev) / (1. - a_t)
248
+ return posterior_mean_coef1*pred_x0 + posterior_mean_coef2*z_t
249
+
250
+ x_t = torch.sqrt(a_prev) * pred_x0 # pred_mean(pred_x0, x_t) #+ 0.4*torch.sqrt(beta_t) *torch.randn_like(x_t)
251
+ '''
252
+ recon = self.model.decode_first_stage(pred_x0)
253
+ image_path = os.path.join("out_temp2/", f'{i}.png')
254
+ image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
255
+ PIL.Image.fromarray(image_np, 'RGB').save(image_path)
256
+
257
+ return
258
+
259
+
utils/logger.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ def get_logger():
4
+ logger = logging.getLogger(name='DPS')
5
+ logger.setLevel(logging.INFO)
6
+
7
+ formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s")
8
+ stream_handler = logging.StreamHandler()
9
+ stream_handler.setFormatter(formatter)
10
+ logger.addHandler(stream_handler)
11
+
12
+ return logger
utils/mask_generator.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import math
5
+ import random
6
+ import torch
7
+ #import tensorflow as tf
8
+ np.random.seed(10)
9
+ def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
10
+ """Generate a random sqaure mask for inpainting
11
+ """
12
+ B, H, W, C = img.shape
13
+ h, w = mask_shape
14
+ margin_height, margin_width = margin
15
+ maxt = image_size - margin_height - h
16
+ maxl = image_size - margin_width - w
17
+
18
+ # bb
19
+ t = np.random.randint(margin_height, maxt)
20
+ l = np.random.randint(margin_width, maxl)
21
+
22
+ # make mask
23
+ mask = torch.ones([B, C, H, W], device=img.device)
24
+ mask[..., t:t+h, l:l+w] = 0
25
+ mask = 1 - mask
26
+ #Fixed mid box
27
+ #mask[..., t:t+h, l:l+w] = 0
28
+ return mask, t, t+h, l, l+w
29
+
30
+ def RandomBrush(
31
+ max_tries,
32
+ s,
33
+ min_num_vertex = 4,
34
+ max_num_vertex = 18,
35
+ mean_angle = 2*math.pi / 5,
36
+ angle_range = 2*math.pi / 15,
37
+ min_width = 12,
38
+ max_width = 48):
39
+ H, W = s, s
40
+ average_radius = math.sqrt(H*H+W*W) / 8
41
+ mask = Image.new('L', (W, H), 0)
42
+ for _ in range(np.random.randint(max_tries)):
43
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
44
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
45
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
46
+ angles = []
47
+ vertex = []
48
+ for i in range(num_vertex):
49
+ if i % 2 == 0:
50
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
51
+ else:
52
+ angles.append(np.random.uniform(angle_min, angle_max))
53
+
54
+ h, w = mask.size
55
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
56
+ for i in range(num_vertex):
57
+ r = np.clip(
58
+ np.random.normal(loc=average_radius, scale=average_radius//2),
59
+ 0, 2*average_radius)
60
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
61
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
62
+ vertex.append((int(new_x), int(new_y)))
63
+
64
+ draw = ImageDraw.Draw(mask)
65
+ width = int(np.random.uniform(min_width, max_width))
66
+ draw.line(vertex, fill=1, width=width)
67
+ for v in vertex:
68
+ draw.ellipse((v[0] - width//2,
69
+ v[1] - width//2,
70
+ v[0] + width//2,
71
+ v[1] + width//2),
72
+ fill=1)
73
+ if np.random.random() > 0.5:
74
+ mask.transpose(Image.FLIP_LEFT_RIGHT)
75
+ if np.random.random() > 0.5:
76
+ mask.transpose(Image.FLIP_TOP_BOTTOM)
77
+ mask = np.asarray(mask, np.uint8)
78
+ if np.random.random() > 0.5:
79
+ mask = np.flip(mask, 0)
80
+ if np.random.random() > 0.5:
81
+ mask = np.flip(mask, 1)
82
+ return mask
83
+
84
+ def RandomMask(s, hole_range=[0,1]):
85
+ coef = min(hole_range[0] + hole_range[1], 1.0)
86
+ while True:
87
+ mask = np.ones((s, s), np.uint8)
88
+ def Fill(max_size):
89
+ w, h = np.random.randint(max_size), np.random.randint(max_size)
90
+ ww, hh = w // 2, h // 2
91
+ x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
92
+ mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
93
+ def MultiFill(max_tries, max_size):
94
+ for _ in range(np.random.randint(max_tries)):
95
+ Fill(max_size)
96
+ MultiFill(int(10 * coef), s // 2)
97
+ MultiFill(int(5 * coef), s)
98
+ ##comment the following line for lower masking ratios
99
+ #mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
100
+ hole_ratio = 1 - np.mean(mask)
101
+ if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
102
+ continue
103
+ return mask[np.newaxis, ...].astype(np.float32)
104
+
105
+ def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
106
+ return np.stack([RandomMask(s, hole_range=hole_range) for _ in range(batch_size)], axis = 0)
107
+
108
+ def random_rotation(shape):
109
+ cutoff = 100 #was 30
110
+ (n , channels, p, q) = shape
111
+ mask = np.zeros((n,p,q))
112
+
113
+ for i in range(n):
114
+ angle = np.random.choice(360, 1)
115
+ mask_one = np.ones((p+cutoff,q+cutoff))
116
+ mask_one[int((p+cutoff)/2):,:] = 0
117
+
118
+ im = Image.fromarray(mask_one)
119
+ im = im.rotate(angle)
120
+
121
+ left = (p+cutoff - p)/2
122
+ top = (q+cutoff - q)/2
123
+ right = (p+cutoff + p)/2
124
+ bottom = (q+cutoff + q)/2
125
+
126
+ # Crop the center of the image
127
+ im = im.crop((left, top, right, bottom))
128
+
129
+ mask[i] = np.array(im)
130
+
131
+ #mask = np.repeat(mask.reshape([n,1,p,q]), channels, axis=1)
132
+ mask = mask.reshape([n,1,p,q])
133
+ return mask
134
+
135
+ class mask_generator:
136
+ def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
137
+ image_size=256, margin=(16, 16)):
138
+ """
139
+ (mask_len_range): given in (min, max) tuple.
140
+ Specifies the range of box size in each dimension
141
+ (mask_prob_range): for the case of random masking,
142
+ specify the probability of individual pixels being masked
143
+ """
144
+ assert mask_type in ['box', 'random', 'half', 'extreme']
145
+ self.mask_type = mask_type
146
+ self.mask_len_range = mask_len_range
147
+ self.mask_prob_range = mask_prob_range
148
+ self.image_size = image_size
149
+ self.margin = margin
150
+
151
+ def _retrieve_box(self, img):
152
+ l, h = self.mask_len_range
153
+ l, h = int(l), int(h)
154
+ mask_h = np.random.randint(l, h)
155
+ mask_w = np.random.randint(l, h)
156
+ mask, t, tl, w, wh = random_sq_bbox(img,
157
+ mask_shape=(mask_h, mask_w),
158
+ image_size=self.image_size,
159
+ margin=self.margin)
160
+ return mask, t, tl, w, wh
161
+
162
+ def generate_center_mask(self, shape):
163
+ assert len(shape) == 2
164
+ assert shape[1] % 2 == 0
165
+ center = shape[0] // 2
166
+ center_size = shape[0] // 4
167
+ half_resol = center_size // 2 # for now
168
+ ret = torch.zeros(shape, dtype=torch.float32)
169
+ ret[
170
+ center - half_resol: center + half_resol,
171
+ center - half_resol: center + half_resol,
172
+ ] = 1
173
+ ret = ret.unsqueeze(0).unsqueeze(0)
174
+ return ret
175
+
176
+ def __call__(self, img):
177
+ if self.mask_type == 'random':
178
+ mask = BatchRandomMask(1, self.image_size, hole_range=self.mask_prob_range) #self._retrieve_random(img)
179
+ return mask
180
+ elif self.mask_type == "half":
181
+ mask = random_rotation((1, 3, self.image_size, self.image_size))
182
+ elif self.mask_type == 'box':
183
+ #mask, t, th, w, wl = self._retrieve_box(img)
184
+ mask = self.generate_center_mask((self.image_size,self.image_size)) # self._retrieve_box(img)
185
+ return mask #.permute(0,3,1,2)
186
+ elif self.mask_type == 'extreme':
187
+ mask, t, th, w, wl = self._retrieve_box(img)
188
+ mask = 1. - mask
189
+ return mask
190
+
191
+
192
+ '''
193
+ def tf_mask_generator(s, tf_hole_range):
194
+ def random_mask_generator(hole_range):
195
+ while True:
196
+ yield RandomMask(s, hole_range=hole_range)
197
+ return tf.data.Dataset.from_generator(random_mask_generator, tf.float32, tf.TensorShape([1, s, s]), (tf_hole_range,))
198
+ '''