Spaces:
Sleeping
Sleeping
util
Browse files- utils/helper.py +259 -0
- utils/logger.py +12 -0
- utils/mask_generator.py +198 -0
utils/helper.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
from ldm.util import default
|
6 |
+
import glob
|
7 |
+
import PIL
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
def load_file(filename):
    """Deserialize and return the pickled object stored at *filename*."""
    with open(filename, 'rb') as fh:
        return pickle.load(fh)
|
14 |
+
|
15 |
+
def save_file(filename, x, mode="wb"):
    """Pickle *x* to *filename*, opened with *mode* (binary write by default)."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
|
18 |
+
|
19 |
+
def normalize_np(img):
    """Normalize *img* (arbitrary numeric range) to [0, 1] in place and return it.

    Fix: the original divided by np.max(img) unconditionally, producing
    NaN/inf for a constant image (max becomes 0 after the shift). A constant
    image is now returned as all zeros instead.
    """
    img -= np.min(img)
    peak = np.max(img)
    if peak > 0:  # guard against division by zero on constant input
        img /= peak
    return img
|
24 |
+
|
25 |
+
def clear_color(x):
    """Convert a single-image tensor to a [0, 1]-normalized (H, W, C) numpy array.

    Complex tensors are reduced to their magnitudes first; the leading batch
    dimension is squeezed away before the channel axis is moved last.
    """
    if torch.is_complex(x):
        x = torch.abs(x)
    arr = x.detach().cpu().squeeze().numpy()
    arr = np.transpose(arr, (1, 2, 0))
    return normalize_np(arr)
|
30 |
+
|
31 |
+
def to_img(sample):
    """Map a (B, C, H, W) tensor in roughly [-1, 1] to a (B, H, W, C) array in [0, 255]."""
    arr = sample.detach().cpu().numpy()
    arr = np.transpose(arr, (0, 2, 3, 1))
    return np.clip(arr * 127.5 + 128, 0, 255)
|
33 |
+
|
34 |
+
def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot each 1-D series in *tensors* (labelled by *labels*) on one figure
    and save it to ``dir_name/file_name``.

    Fix: the original indexed a fixed 3-element colour list by series index
    and raised IndexError for more than three series; colours now cycle.
    """
    colours = ["r", "b", "g"]
    xs = np.linspace(0, len(tensors[0]), len(tensors[0]))
    plt.figure()
    for idx, (series, label) in enumerate(zip(tensors, labels)):
        plt.plot(xs, series, color=colours[idx % len(colours)], label=label)
    plt.legend()
    plt.savefig(os.path.join(dir_name, file_name))
    #plt.show()
|
43 |
+
|
44 |
+
def save_samples(dir_name, sample, k=None, num_to_save=5, file_name=None):
    """Write the first *num_to_save* images of *sample* as RGB PNGs to *dir_name*.

    Tensor input is converted via to_img(); numpy input is assumed to already
    be in [0, 255]. When *file_name* is None a per-image name is generated
    (prefixed with chunk index *k* when given); otherwise every image is
    written to *file_name*, each overwriting the previous one (original
    behaviour preserved).
    """
    if isinstance(sample, np.ndarray):
        imgs = sample.astype(np.uint8)
    else:
        imgs = to_img(sample).astype(np.uint8)

    for idx in range(num_to_save):
        if file_name is not None:
            out_name = file_name
        elif k is not None:
            out_name = f'sample_{k+1}{idx}.png'
        else:
            out_name = f'{idx}.png'
        PIL.Image.fromarray(imgs[idx], 'RGB').save(os.path.join(dir_name, out_name))
|
57 |
+
|
58 |
+
def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save=5, file_name=None):
    """Compose measurement *y* where mask_pixel==1 with *sample* elsewhere,
    convert to uint8 images, and save the first *num_to_save* as RGB PNGs.

    File naming matches save_samples: per-image names unless *file_name* is
    given, in which case every image is written to the same path.
    """
    composite = to_img(y * mask_pixel + (1 - mask_pixel) * sample)
    imgs = composite.astype(np.uint8)
    for idx in range(num_to_save):
        if file_name is not None:
            out_name = file_name
        elif k is not None:
            out_name = f'sample_{k+1}{idx}.png'
        else:
            out_name = f'{idx}.png'
        PIL.Image.fromarray(imgs[idx], 'RGB').save(os.path.join(dir_name, out_name))
|
70 |
+
|
71 |
+
def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Detach the variational parameters (mean, log-variance, gamma), mark
    them untrainable, and torch.save them to ``dir_name/<k+1>.pt``."""
    detached = [mu_pos.detach().cpu(), logvar_pos.detach().cpu(), gamma.detach().cpu()]
    torch.save(params_untrain(detached), os.path.join(dir_name, f'{k+1}.pt'))
|
75 |
+
|
76 |
+
def custom_to_np(img):
    """Return *img* as a detached, contiguous CPU tensor.

    Despite the name this performs no numpy conversion or dtype/range change —
    it only moves the tensor off the graph and device and makes it contiguous.
    """
    out = img.detach().cpu()
    return out.contiguous()
|
82 |
+
|
83 |
+
def encoder_kl(diff, img):
    # Encode *img* with the KL first-stage autoencoder of latent-diffusion
    # model *diff*; return a noise-perturbed latent mean and the log-variance.
    # NOTE(review): assumes encode_first_stage(img, return_all=True) returns
    # (sample, params) with params holding concatenated (mean, logvar)
    # channels — confirm against the ldm codebase.
    _, params = diff.encode_first_stage(img, return_all = True)
    params = diff.scale_factor * params
    mean, logvar = torch.chunk(params, 2, dim=1)
    # default(None, fn) always falls through to fn, i.e. fresh Gaussian noise.
    noise = default(None, lambda: torch.randn_like(mean))
    # The mean is jittered by scale_factor-scaled noise rather than by
    # exp(logvar/2) — presumably intentional; verify with the calling sampler.
    mean = mean + diff.scale_factor*noise
    return mean, logvar
|
90 |
+
|
91 |
+
def encoder_vq(diff, img):
    # Encode *img* with a VQ first-stage model and return a noise-perturbed,
    # scale_factor-scaled latent. Unlike encoder_kl there is no log-variance.
    quant = diff.encode_first_stage(img) #, diff, (_,_,ind)
    quant = diff.scale_factor * quant
    #mean, logvar = torch.chunk(params, 2, dim=1)
    # default(None, fn) always falls through to fn, i.e. fresh Gaussian noise.
    noise = default(None, lambda: torch.randn_like(quant))
    mean = quant + diff.scale_factor*noise #
    return mean
|
98 |
+
|
99 |
+
def clean_directory(dir_name):
    """Remove every file matched by the glob pattern *dir_name*.

    NOTE: *dir_name* is treated as a glob pattern (e.g. 'out/*'), not a bare
    directory path — a bare directory path matches only itself.
    """
    for path in glob.glob(dir_name):
        os.remove(path)
|
103 |
+
|
104 |
+
def params_train( params ):
    """Enable gradient tracking on every tensor in *params*; return the same list."""
    for tensor in params:
        tensor.requires_grad = True
    return params
|
108 |
+
|
109 |
+
def params_untrain(params):
    """Disable gradient tracking on every tensor in *params*; return the same list."""
    for tensor in params:
        tensor.requires_grad = False
    return params
|
113 |
+
|
114 |
+
def time_descretization(sigma_min=0.002, sigma_max = 80, rho = 7, num_t_steps = 18):
    # EDM-style (Karras et al., 2022) rho-schedule noise-level discretization.
    # Returns the num_t_steps sigma levels in INCREASING order (sigma_min
    # first). Requires a CUDA device: indices are created with .cuda().
    step_indices = torch.arange(num_t_steps, dtype=torch.float64).cuda()
    # Linear interpolation between sigma_max^(1/rho) and sigma_min^(1/rho),
    # raised back to the rho-th power — yields the decreasing EDM schedule.
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_t_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # Reverse so the schedule runs "forward" (smallest sigma first).
    inv_idx = torch.arange(num_t_steps -1, -1, -1).long()
    t_steps_fwd = t_steps[inv_idx]
    #t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
    return t_steps_fwd
|
121 |
+
|
122 |
+
def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01):
    """Build the three Adam optimizers (means / variances / gamma) used to fit
    the variational posterior, each paired with a StepLR decay schedule.

    Returns ([opt_means, opt_vars, opt_gamma], [sched_means, sched_vars,
    sched_gamma]) in matching order.
    """
    lr, step_size, gamma = 0.1, 10, 0.99  # was 0.999 for right-half: [0.01, 10, 0.99]
    betas = (0.9, 0.99)

    opt_means = torch.optim.Adam([means], lr=lr, betas=betas)
    opt_vars = torch.optim.Adam([variances], lr=0.001, betas=betas)  # 0.001 for lsun
    opt_gamma = torch.optim.Adam([gamma_param], lr=lr_init_gamma, betas=betas)  # 0.01

    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
        for opt in (opt_means, opt_vars, opt_gamma)
    ]
    return [opt_means, opt_vars, opt_gamma], schedulers
|
134 |
+
|
135 |
+
def check_directory(filename_list):
    """Create each directory in *filename_list* if it does not already exist.

    Fix: the original used os.mkdir, which raises FileNotFoundError when an
    intermediate directory is missing and races between the exists() check and
    the mkdir. os.makedirs(..., exist_ok=True) handles both.
    """
    for dirname in filename_list:
        os.makedirs(dirname, exist_ok=True)
|
139 |
+
|
140 |
+
def s_file(filename, x, mode="wb"):
    """Pickle *x* to *filename* (duplicate of save_file, kept for callers)."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
|
143 |
+
|
144 |
+
def r_file(filename, mode="rb"):
    """Unpickle and return the object stored at *filename* (duplicate of load_file)."""
    with open(filename, mode) as fh:
        return pickle.load(fh)
|
148 |
+
|
149 |
+
def sample_from_gaussian(mu, alpha, sigma):
    """Draw one reparameterized Gaussian sample: alpha*mu + sigma*eps, eps ~ N(0, I)."""
    eps = torch.randn_like(mu)
    return alpha * mu + sigma * eps
|
152 |
+
|
153 |
+
'''
|
154 |
+
def make_batch(image, mask=None, device=None):
|
155 |
+
image = torch.permute(image, (0,3,1,2))
|
156 |
+
batch_size = image.shape[0]
|
157 |
+
if mask is None :
|
158 |
+
mask = torch.zeros_like(image)
|
159 |
+
mask[0, :, :256, :128] = 1
|
160 |
+
else :
|
161 |
+
mask = torch.tensor(mask)
|
162 |
+
masked_image = (mask)*image #+ mask*noise*0.2
|
163 |
+
mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
|
164 |
+
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
165 |
+
for k in batch:
|
166 |
+
batch[k] = batch[k].to(device)
|
167 |
+
return batch
|
168 |
+
|
169 |
+
def get_sigma_t_steps(net, n_steps=3, kwargs=None):
|
170 |
+
sigma_min = kwargs["sigma_min"]
|
171 |
+
sigma_max = kwargs["sigma_max"]
|
172 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
173 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
174 |
+
|
175 |
+
##Get the time-steps based on iddpm discretization
|
176 |
+
num_steps = n_steps #11 # kwargs["num_steps"]
|
177 |
+
C_2 = kwargs["C_2"]
|
178 |
+
C_1 = kwargs["C_1"]
|
179 |
+
M = kwargs["M"]
|
180 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
|
181 |
+
u = torch.zeros(M + 1, dtype=torch.float64).cuda()
|
182 |
+
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
|
183 |
+
for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
|
184 |
+
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
|
185 |
+
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
|
186 |
+
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
|
187 |
+
#print(sigma_steps)
|
188 |
+
|
189 |
+
##get noise schedule
|
190 |
+
sigma = lambda t: t
|
191 |
+
sigma_deriv = lambda t: 1
|
192 |
+
sigma_inv = lambda sigma: sigma
|
193 |
+
|
194 |
+
##scaling schedule
|
195 |
+
s = lambda t: 1
|
196 |
+
s_deriv = lambda t: 0
|
197 |
+
|
198 |
+
##compute some final time steps based on the corresponding noise levels.
|
199 |
+
t_steps = sigma_inv(net.round_sigma(sigma_steps))
|
200 |
+
|
201 |
+
return t_steps, sigma_inv, sigma, s, sigma_deriv
|
202 |
+
|
203 |
+
def data_replicate(data, K):
|
204 |
+
if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1])
|
205 |
+
else: data_batch = torch.Tensor.repeat(data,[K,1,1,1])
|
206 |
+
return data_batch
|
207 |
+
|
208 |
+
'''
|
209 |
+
|
210 |
+
|
211 |
+
def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None):
    '''
    sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
    T_max = 1000
    beta_start = 1 # 0.0015*T_max
    beta_end = 15 # 0.0155*T_max
    def var(t):
        return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
    '''
    # Debug/visualization DDIM-style sampler: noise x0 up to the first level
    # in t_steps_hierarchy, then denoise down the hierarchy, writing the
    # decoded x0-prediction of every step to out_temp2/<i>.png. Returns None.
    # NOTE(review): assumes self.model is a latent-diffusion wrapper exposing
    # sqrt_one_minus_alphas_cumprod, model(x, t) (eps-prediction network) and
    # decode_first_stage — confirm against the enclosing class. CUDA required.
    t_steps_hierarchy = torch.tensor(t_steps_hierarchy).cuda()
    var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1 ,1 ,1))**2 # self.var(t_steps_hierarchy[0])
    # Forward-diffuse x0 to the highest retained noise level.
    x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)

    os.makedirs("out_temp2/", exist_ok=True)
    for i, t in enumerate(t_steps_hierarchy):
        # Hard-coded batch size of 10 — presumably matches the caller; verify.
        t_hat = torch.ones(10).cuda() * (t)
        e_out = self.model.model(x_t, t_hat)
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1 ,1 ,1))**2
        #score_out = - e_out / torch.sqrt()
        a_t = 1 - var_t
        #beta_t = 1 - a_t/a_prev
        #std_pos = ((1 - a_prev)/(1 - a_t)).sqrt()*torch.sqrt(beta_t)
        # DDIM x0-prediction from the eps estimate.
        pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()

        if i != len(t_steps_hierarchy) - 1:
            var_t1 = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[i+1]].reshape(1, 1 ,1 ,1))**2
            a_prev = 1 - var_t1 # var(t_steps_hierarchy[i+1]/1000) # torch.full((10, 1, 1, 1), alphas[t_steps_hierarchy[i+1]]).cuda()
            # Stochastic DDIM step (eta controls the noise level).
            sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
            # NOTE(review): sigma_t-scaled noise is added TWICE (two independent
            # randn draws); standard DDIM adds it once — confirm intended.
            x_t = a_prev.sqrt() * pred_x0 + dir_xt + torch.randn_like(x_t) * sigma_t + sigma_t*torch.randn_like(x_t)

        #x_t= (x_t - torch.sqrt( 1 - a_t/a_prev) * e_out ) / (a_t/a_prev).sqrt() + std_pos*torch.randn_like(x_t)

        '''
        def pred_mean(pred_x0, z_t):
            posterior_mean_coef1 = beta_t * torch.sqrt(a_prev) / (1. - a_t)
            posterior_mean_coef2 = (1. - a_prev) * torch.sqrt(a_t/a_prev) / (1. - a_t)
            return posterior_mean_coef1*pred_x0 + posterior_mean_coef2*z_t

        x_t = torch.sqrt(a_prev) * pred_x0 # pred_mean(pred_x0, x_t) #+ 0.4*torch.sqrt(beta_t) *torch.randn_like(x_t)
        '''
        # Decode the current x0 estimate and dump the first batch element.
        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join("out_temp2/", f'{i}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)

    return
|
258 |
+
|
259 |
+
|
utils/logger.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
def get_logger():
    """Return the shared 'DPS' logger, configured for INFO-level console output.

    Fix: the original attached a brand-new StreamHandler on EVERY call, so
    each call to get_logger() multiplied every subsequent log line. The
    handler is now attached only when the logger has none yet.
    """
    logger = logging.getLogger(name='DPS')
    logger.setLevel(logging.INFO)

    if not logger.handlers:  # configure only once per process
        formatter = logging.Formatter("%(asctime)s [%(name)s] >> %(message)s")
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    return logger
|
utils/mask_generator.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageDraw
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
#import tensorflow as tf
|
8 |
+
np.random.seed(10)
|
9 |
+
def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
    """Sample a random axis-aligned box for inpainting.

    Returns (mask, top, bottom, left, right) where mask has shape
    [B, C, H, W] and is 1 INSIDE the box, 0 outside.
    """
    B, H, W, C = img.shape
    h, w = mask_shape
    margin_height, margin_width = margin

    # Sample the top-left corner so the box stays inside the margins.
    top = np.random.randint(margin_height, image_size - margin_height - h)
    left = np.random.randint(margin_width, image_size - margin_width - w)

    # Build the mask directly as zeros-with-a-box-of-ones (equivalent to the
    # original ones -> zero-box -> 1-mask dance).
    mask = torch.zeros([B, C, H, W], device=img.device)
    mask[..., top:top + h, left:left + w] = 1
    return mask, top, top + h, left, left + w
|
29 |
+
|
30 |
+
def RandomBrush(
    max_tries,
    s,
    min_num_vertex = 4,
    max_num_vertex = 18,
    mean_angle = 2*math.pi / 5,
    angle_range = 2*math.pi / 15,
    min_width = 12,
    max_width = 48):
    """Draw up to *max_tries* random polyline "brush strokes" on an s-by-s
    canvas and return the stroke mask as a uint8 array (1 = stroke, 0 = free).

    Fix: PIL's ``Image.transpose`` returns a NEW image and the original
    discarded the result, so the two image-level random flips were silent
    no-ops; the result is now assigned back.
    """
    H, W = s, s
    average_radius = math.sqrt(H*H+W*W) / 8
    mask = Image.new('L', (W, H), 0)
    for _ in range(np.random.randint(max_tries)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        # Alternate turning direction so the stroke wiggles instead of spiraling.
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        h, w = mask.size
        vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
        # Random-walk the remaining vertices with normally distributed radii.
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        # Round the joints so the stroke has no gaps at the vertices.
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    mask = np.asarray(mask, np.uint8)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 0)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 1)
    return mask
|
83 |
+
|
84 |
+
def RandomMask(s, hole_range=[0,1]):
    """Generate an s-by-s inpainting mask (1 = keep, 0 = hole) whose hole
    ratio lies strictly inside *hole_range*; resamples until it does."""
    coef = min(hole_range[0] + hole_range[1], 1.0)
    while True:
        mask = np.ones((s, s), np.uint8)

        def cut_box(max_size):
            # Zero out one random rectangle, clipped to the canvas.
            w, h = np.random.randint(max_size), np.random.randint(max_size)
            ww, hh = w // 2, h // 2
            x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
            mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0

        def cut_many(max_tries, max_size):
            for _ in range(np.random.randint(max_tries)):
                cut_box(max_size)

        cut_many(int(10 * coef), s // 2)
        cut_many(int(5 * coef), s)
        ##comment the following line for lower masking ratios
        #mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
        hole_ratio = 1 - np.mean(mask)
        if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
            continue
        return mask[np.newaxis, ...].astype(np.float32)
|
104 |
+
|
105 |
+
def BatchRandomMask(batch_size, s, hole_range=[0, 1]):
    """Stack *batch_size* independent RandomMask samples into shape (B, 1, s, s)."""
    masks = [RandomMask(s, hole_range=hole_range) for _ in range(batch_size)]
    return np.stack(masks, axis=0)
|
107 |
+
|
108 |
+
def random_rotation(shape):
    # Build n randomly oriented half-plane masks: each is a half-space of
    # ones rotated by a uniform random angle, returned with shape (n, 1, p, q).
    cutoff = 100 #was 30
    # oversized canvas margin so the rotated edge never clips the center crop
    (n , channels, p, q) = shape
    mask = np.zeros((n,p,q))

    for i in range(n):
        # NOTE(review): np.random.choice(360, 1) yields a 1-element ARRAY and
        # mask_one is float64, so Image.fromarray produces mode 'F' — both
        # appear to be accepted by PIL's rotate/crop here, but confirm.
        angle = np.random.choice(360, 1)
        mask_one = np.ones((p+cutoff,q+cutoff))
        # zero the bottom half of the oversized canvas -> half-plane mask
        mask_one[int((p+cutoff)/2):,:] = 0

        im = Image.fromarray(mask_one)
        im = im.rotate(angle)

        # Center-crop coordinates back to (p, q).
        left = (p+cutoff - p)/2
        top = (q+cutoff - q)/2
        right = (p+cutoff + p)/2
        bottom = (q+cutoff + q)/2

        # Crop the center of the image
        im = im.crop((left, top, right, bottom))

        mask[i] = np.array(im)

    #mask = np.repeat(mask.reshape([n,1,p,q]), channels, axis=1)
    # single channel returned; callers broadcast over channels themselves
    mask = mask.reshape([n,1,p,q])
    return mask
|
134 |
+
|
135 |
+
class mask_generator:
    """Callable factory producing inpainting masks of a configured type.

    mask_type is one of 'box', 'random', 'half', 'extreme'.
    """
    def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
                 image_size=256, margin=(16, 16)):
        """
        (mask_len_range): given in (min, max) tuple.
        Specifies the range of box size in each dimension
        (mask_prob_range): for the case of random masking,
        specify the probability of individual pixels being masked
        """
        assert mask_type in ['box', 'random', 'half', 'extreme']
        self.mask_type = mask_type
        self.mask_len_range = mask_len_range
        self.mask_prob_range = mask_prob_range
        self.image_size = image_size
        self.margin = margin

    def _retrieve_box(self, img):
        # Sample box height/width inside mask_len_range, then a random position.
        l, h = self.mask_len_range
        l, h = int(l), int(h)
        mask_h = np.random.randint(l, h)
        mask_w = np.random.randint(l, h)
        mask, t, tl, w, wh = random_sq_bbox(img,
                                            mask_shape=(mask_h, mask_w),
                                            image_size=self.image_size,
                                            margin=self.margin)
        return mask, t, tl, w, wh

    def generate_center_mask(self, shape):
        """Return a (1, 1, H, W) float32 mask with a centered square of ones
        whose side is one quarter of the image side."""
        assert len(shape) == 2
        assert shape[1] % 2 == 0
        center = shape[0] // 2
        center_size = shape[0] // 4
        half_resol = center_size // 2  # for now
        ret = torch.zeros(shape, dtype=torch.float32)
        ret[
            center - half_resol: center + half_resol,
            center - half_resol: center + half_resol,
        ] = 1
        ret = ret.unsqueeze(0).unsqueeze(0)
        return ret

    def __call__(self, img):
        """Produce a mask for *img* according to self.mask_type."""
        if self.mask_type == 'random':
            mask = BatchRandomMask(1, self.image_size, hole_range=self.mask_prob_range) #self._retrieve_random(img)
            return mask
        elif self.mask_type == "half":
            mask = random_rotation((1, 3, self.image_size, self.image_size))
            # Fix: the original fell through without a return here, so 'half'
            # callers silently received None.
            return mask
        elif self.mask_type == 'box':
            #mask, t, th, w, wl = self._retrieve_box(img)
            mask = self.generate_center_mask((self.image_size, self.image_size)) # self._retrieve_box(img)
            return mask #.permute(0,3,1,2)
        elif self.mask_type == 'extreme':
            mask, t, th, w, wl = self._retrieve_box(img)
            mask = 1. - mask
            return mask
|
190 |
+
|
191 |
+
|
192 |
+
'''
|
193 |
+
def tf_mask_generator(s, tf_hole_range):
|
194 |
+
def random_mask_generator(hole_range):
|
195 |
+
while True:
|
196 |
+
yield RandomMask(s, hole_range=hole_range)
|
197 |
+
return tf.data.Dataset.from_generator(random_mask_generator, tf.float32, tf.TensorShape([1, s, s]), (tf_hole_range,))
|
198 |
+
'''
|