JiminHeo committed
Commit 2cc1551 · 1 Parent(s): 2bfb23d
ldm/guided_diffusion/h_posterior.py ADDED
@@ -0,0 +1,506 @@
+ """INFERENCE TIME OPTIMIZATION"""
+
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from functools import partial
+ import torch.distributions as td
+ import gc
+ import wandb
+ import matplotlib.pyplot as plt
+ from utils.helper import params_train, get_optimizers, clean_directory, time_descretization, to_img, custom_to_np, save_params, save_samples, save_inpaintings, save_plot
+ import os
+ import PIL
+ import glob
+ from tqdm import trange
+ import time
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, extract_into_tensor, noise_like
+
+
+ class HPosterior(object):
+     def __init__(self, model, vae_loss, t_steps_hierarchy, eta=0.4, z0_size=32, img_size=256, latent_channels=3,
+                  num_hierarchy_steps=5, schedule="linear", first_stage="kl", posterior="hierarchical",
+                  image_queue=None, sampling_queue=None, **kwargs):
+         super().__init__()
+         self.model = model            # prior noise-prediction model
+         self.schedule = schedule      # noise schedule the prior was trained on
+         self.vae_loss = vae_loss      # VAE loss used during first-stage training
+         self.eta = eta                # DDIM eta used to produce faster, cleaner samples
+         self.first_stage = first_stage  # first-stage training procedure: kl or vq loss
+         self.posterior = posterior
+         self.t_steps_hierarchy = np.array(t_steps_hierarchy)  # time steps for the hierarchical posterior
+         self.z0size = z0_size         # spatial size of the latent variables z
+         self.img_size = img_size
+         self.latent_size = z0_size
+         self.latent_channels = latent_channels
+         self.image_queue = image_queue
+         self.sampling_queue = sampling_queue
+
+     def q_given_te(self, t, s, shape, zeta_t_star=None):
+         # Marginal q(z_t | z_s): returns alpha_{t|s} and the std sqrt(var_{t|s})
+         if zeta_t_star is not None:
+             alpha_s = torch.sqrt(1 - zeta_t_star**2)
+             var_s = zeta_t_star**2
+         else:
+             m = 1 if len(s.shape) == 0 else s.shape[0]
+             var_s = (self.model.sqrt_one_minus_alphas_cumprod[s].reshape(m, 1, 1, 1))**2
+             alpha_s = torch.sqrt(1 - var_s)
+
+         var_t = (self.model.sqrt_one_minus_alphas_cumprod[t])**2
+         alpha_t = torch.sqrt(1 - var_t)
+         alpha_t_s = alpha_t.reshape(len(var_t), 1, 1, 1) / alpha_s
+         var_t_s = var_t.reshape(len(var_t), 1, 1, 1) - alpha_t_s**2 * var_s
+         return alpha_t_s, torch.sqrt(var_t_s)
+
+     def qpos_given_te(self, t, s, t_star, z_t_star, z_t, zeta_T_star=None):
+         # Gaussian posterior q(z_s | z_t, z_{t*}): precision-weighted combination of the
+         # forward marginal from z_{t*} and the reverse conditional from z_t
+         alpha_t_s, scale_t_s = self.q_given_te(t, s, z_t_star.shape)
+         alpha_s_t_star, scale_s_t_star = self.q_given_te(s, t_star, z_t_star.shape, zeta_T_star)
+
+         var = scale_t_s**2 * scale_s_t_star**2 / (scale_t_s**2 + alpha_s_t_star**2 * scale_s_t_star**2)
+         mean = var * ((alpha_s_t_star / scale_s_t_star**2) * z_t_star + (alpha_t_s / scale_t_s**2) * z_t)
+         return mean, torch.sqrt(var)
+
+     def register_buffer(self, name, attr):
+         if isinstance(attr, torch.Tensor):
+             if attr.device != torch.device("cuda"):
+                 attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+
+     def get_error(self, x, t, c, unconditional_conditioning, unconditional_guidance_scale):
+         # Classifier-free guidance: blend conditional and unconditional noise predictions
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+             e_t = self.model.apply_model(x.float(), t, c)
+         else:
+             x_in = torch.cat([x] * 2)
+             t_in = torch.cat([t] * 2)
+             c_in = torch.cat([unconditional_conditioning, c])
+             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+         return e_t
+
+     def descretize(self, rho):
+         # Time discretization for the prior loss (t > T_e)
+         self.timesteps_1000 = time_descretization(sigma_min=0.002, sigma_max=0.999, rho=rho, num_t_steps=1000) * 1000
+         self.timesteps_1000 = self.timesteps_1000.cuda().long()
+         sigma_timesteps = self.model.sqrt_one_minus_alphas_cumprod[self.timesteps_1000]
+         self.register_buffer('sigma_timesteps', sigma_timesteps)
+
+         # Prior std for the hierarchical time points
+         sigma_hierarchy = self.model.sqrt_one_minus_alphas_cumprod[self.t_steps_hierarchy]
+         self.t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy.copy()).cuda()
+         alphas_h = 1 - sigma_hierarchy**2
+         alphas_prev = torch.cat([alphas_h[1:], alphas_h[-1].reshape(1)])
+         h_sigmas = torch.sqrt(self.eta * (1 - alphas_prev) / (1 - alphas_h) * (1 - alphas_h / alphas_prev))
+         h_sigmas[1:] = torch.sqrt(self.eta * (1 - alphas_prev[:-1]) / (1 - alphas_h[:-1]) * (1 - alphas_h[:-1] / alphas_prev[:-1]))
+         h_sigmas[0] = torch.sqrt(1 - alphas_h[0])  # index 0 holds the full marginal std at the top level
+
+         # Register tensors
+         self.register_buffer('h_alphas', alphas_h)
+         self.register_buffer('h_alphas_prev', alphas_prev)
+         self.register_buffer('h_sigmas', h_sigmas)
+
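+     # For reference: h_sigmas above is the DDIM posterior variance of Song et al. (2021),
+     #   sigma_t^2 = eta * (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev),
+     # with eta folded inside the square root here, so eta -> 0 recovers deterministic DDIM
+     # noise levels and eta = 1 recovers ancestral-style noise levels.
+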
+     def init(self, img, std_scale, mean_scale, prior_scale, mean_scale_top=0.1):
+         # Initialize variational parameters (means, log-variances, gammas) around the
+         # encoded measurement `img`, one set per hierarchy step
+         num_h_steps = len(self.t_steps_hierarchy)
+         img = torch.Tensor.repeat(img, [num_h_steps, 1, 1, 1])[:num_h_steps]
+         sigmas = torch.zeros_like(img)
+         sqrt_alphas = torch.sqrt(self.h_alphas)[..., None, None, None].expand(img.shape)
+         sqrt_one_minus_alphas = torch.sqrt(1 - self.h_alphas)[..., None, None, None].expand(img.shape)
+         # Variances for the posterior
+         sigmas[0] = self.h_sigmas[0, None, None, None].expand(img[0].shape)
+         sigmas[1:] = std_scale * (1 / np.sqrt(self.eta)) * self.h_sigmas[1:, None, None, None].expand(img[1:].shape)
+         logvar_pos = 2 * torch.log(sigmas).float()
+         # Means
+         mean_pos = sqrt_alphas * img + mean_scale * sqrt_one_minus_alphas * torch.randn_like(img)
+         mean_pos[0] = img[0] + mean_scale_top * torch.randn_like(img[0])
+         # Gammas weighing the posterior against the prior
+         gamma = torch.tensor(prior_scale)[None, None, None, None].expand(img.shape).cuda()
+         return mean_pos, logvar_pos, gamma.float()
+
+     # KL divergence between diagonal Gaussians, KL( N(mu1, scale1^2) || N(mu2, scale2^2) )
+     def get_kl(self, mu1, mu2, scale1, scale2, wt):
+         return wt * (mu1 - mu2)**2 / (2 * scale2**2) \
+             + torch.log(scale2 / scale1) + scale1**2 / (2 * scale2**2) - 1 / 2
+
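+     # Illustrative cross-check (scalar tensors, wt=1): get_kl agrees with torch.distributions,
+     # e.g.
+     #   p, q = td.Normal(torch.tensor(0.0), torch.tensor(1.0)), td.Normal(torch.tensor(0.5), torch.tensor(2.0))
+     #   td.kl_divergence(p, q)  # == get_kl(mu1=0.0, mu2=0.5, scale1=1.0, scale2=2.0, wt=1)
+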
+     # Diffusion (prior) loss
+     def loss_prior(self, mu_pos, logvar_pos, cond=None,
+                    unconditional_conditioning=None,
+                    unconditional_guidance_scale=1, K=10, intermediate_mus=None):
+         '''
+         Computes the KL between q(x_{T_e}) and p(x_{T_e}), i.e. E_{t > T_e}[(x_{T_e} - \mu_\theta(x_t))^2].
+         x_{T_e} = z_te are samples from q(x_{T_e}).
+         z_t is sampled by adding noise scaled by sqrt(\sigma_t^2 - \zeta_t^2) so that z_t matches
+         the total noise level at t.
+         '''
+         t_e = self.t_steps_hierarchy[0]
+         # Sample z_{T_e}
+         tau_te = torch.exp(0.5 * logvar_pos)
+         mu_te = torch.Tensor.repeat(mu_pos, [K, 1, 1, 1])
+         z_te = torch.sqrt(1 - tau_te**2) * mu_te + tau_te * torch.randn_like(mu_te)
+
+         # Sample t: get the allowed timesteps > T_e
+         t_g = torch.where(self.sigma_timesteps > torch.max(tau_te))[0]
+         t_allowed = self.timesteps_1000[t_g]
+
+         def sample_uniform(t_allowed):
+             # Stratified uniform sampling of K indices in [T_min, T_max)
+             t0 = torch.rand(1)
+             T_max = len(t_allowed)
+             T_min = 2  # stay away from values too close to T*
+             t = torch.remainder(t0 + torch.arange(0., 1., step=1. / K), 1.) * (T_max - T_min) + T_min
+             t = torch.floor(t).long()
+             return t
+
+         t = sample_uniform(t_allowed)
+         t_cur = t_allowed[t]
+         t_prev = t_allowed[t - 1]
+
+         # Sample z_t from p(z_t | z_{T_e})
+         alpha_t, scale_t = self.q_given_te(t_cur, t_e, z_te.shape, tau_te)
+         error = torch.randn_like(z_te)
+         z_t = alpha_t * z_te + error * scale_t
+
+         # Prior and posterior mean/variance for t_prev
+         e_out = self.get_error(z_t.float(), t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
+         alpha_t_, scale_t_ = self.q_given_te(t_cur, t_e, z_te.shape)
+         mu_t_hat = (z_t - scale_t_ * e_out) / alpha_t_
+         pos_mean, pos_scale = self.qpos_given_te(t_cur, t_prev, t_e, z_te, z_t, tau_te)
+         prior_mean, prior_scale = self.qpos_given_te(t_cur, t_prev, t_e, mu_t_hat, z_t, None)
+
+         wt = (1000 - t_e) / 2
+         kl = self.get_kl(pos_mean, prior_mean, pos_scale, prior_scale, wt=1)
+         kl = torch.mean(wt * kl, dim=[1, 2, 3])
+
+         return {"loss": kl, "sample": z_te, "intermediate_mus": intermediate_mus}
+
+     def recon_loss(self, samples_pixel, x0_pixel, mask_pixel, operator=None):
+         global_step = 0
+         if self.first_stage == "kl":
+             nll_loss, _ = self.vae_loss(x0_pixel, samples_pixel, mask_pixel, 0, global_step,
+                                         last_layer=self.model.first_stage_model.get_last_layer(), split="val")
+         else:
+             qloss = torch.tensor([0.]).cuda()
+             nll_loss, _ = self.vae_loss(qloss, x0_pixel, samples_pixel, mask_pixel, 0, 0,
+                                         last_layer=self.model.first_stage_model.get_last_layer(), split="val",
+                                         predicted_indices=None, operator=operator)
+         return {"loss": nll_loss}
+
+     def prior_preds(self, z_t, t_cur, cond, a_t, a_prev, sigma_t, unconditional_conditioning, unconditional_guidance_scale):
+         # Predict the noise and x0
+         e_out = self.get_error(z_t, t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
+         pred_x0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
+         # Direction pointing to x_t
+         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
+         z_next = a_prev.sqrt() * pred_x0 + dir_xt
+         return z_next, pred_x0
+
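+     # The update above is the deterministic part of the DDIM step (Song et al., 2021):
+     #   z_next = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_out;
+     # any stochastic sigma_t * noise term is left to the caller.
+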
+     def posterior_mean(self, mu_pos, mu_prior, gamma):
+         # Convex combination of prior and posterior means, gated by sigmoid(gamma)
+         wt = torch.sigmoid(gamma)
+         mean_t_1 = wt * mu_prior + (1 - wt) * mu_pos
+         return mean_t_1
+
+     def normalize(self, img):
+         # Rescale to [-1, 1]
+         img -= torch.min(img)
+         return 2 * img / torch.max(img) - 1
+
+     def loss_posterior(self, z_t, mu_pos, logvar_pos, gamma, cond=None,
+                        unconditional_conditioning=None,
+                        unconditional_guidance_scale=1,
+                        K=10, iteration=0, to_sample=False, intermediate_mus=None):
+
+         sigma_pos = torch.exp(0.5 * logvar_pos)
+         kl_t, t0, q_entropy = torch.zeros(z_t.shape[0]).cuda(), 100, 0
+         num_steps = len(self.t_steps_hierarchy)
+         intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
+         intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
+         b = z_t.shape[0]
+         with torch.no_grad():
+             recon = self.model.decode_first_stage(z_t)
+             intermediate_samples[0] = to_img(recon)[0]
+
+         alphas = self.h_alphas
+         for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
+             t_hat_cur = torch.ones(b).cuda() * t_cur
+             a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
+             a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
+             a_t_prev = a_t / a_prev
+             sigma_t = self.h_sigmas[i + 1]
+             # Prior predictions
+             z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
+                                                unconditional_conditioning, unconditional_guidance_scale)
+             std_prior = self.h_sigmas[i + 1]
+
+             # Posterior means and variances
+             pos_mean = self.posterior_mean(a_prev.sqrt() * mu_pos[i].unsqueeze(0), z_next, gamma[i].unsqueeze(0))
+             std_pos = sigma_pos[i]
+
+             # Sample z_t
+             z_t = pos_mean + std_pos * torch.randn_like(pos_mean)
+             # KL between posterior and prior transitions
+             kl = self.get_kl(pos_mean, z_next, std_pos, std_prior, wt=1)
+             kl_t += torch.mean(kl, dim=[1, 2, 3])
+
+             with torch.no_grad():
+                 recon_pred = self.model.decode_first_stage(pred_x0)
+                 intermediate_preds[i] = to_img(recon_pred)[0]
+                 intermediate_mus[i + 1] = to_img(self.normalize(mu_pos[i]).unsqueeze(0)).astype(np.uint8)[0]
+
+         # One-step denoising from the last hierarchy level
+         t_hat_cur = torch.ones(b).cuda() * (self.t_steps_hierarchy[-1])
+         e_out = self.get_error(z_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
+         a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
+         sqrt_one_minus_at = torch.sqrt(1 - a_t)
+         pred_z0 = (z_t - sqrt_one_minus_at * e_out) / a_t.sqrt()
+
+         with torch.no_grad():
+             recon = self.model.decode_first_stage(pred_z0)
+             intermediate_preds[-1] = to_img(recon)[0]
+
+         return {"sample": pred_z0, "loss": kl_t, "entropy": q_entropy,
+                 "intermediates": intermediate_samples, "interim_preds": intermediate_preds,
+                 "intermediate_mus": intermediate_mus}
+
+     def grad_and_value(self, x_prev, x_0_hat, measurement, mask_pixel, operator):
+         nll_loss = torch.mean(self.recon_loss(x_0_hat, measurement, mask_pixel, operator)["loss"])
+         norm_grad = torch.autograd.grad(outputs=nll_loss, inputs=x_prev)[0]
+         return norm_grad, nll_loss
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, mask_pixel, scale, operator, **kwargs):
+         # DPS-style measurement-consistency step: move x_t along the gradient of the data term
+         norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat,
+                                               measurement=measurement, mask_pixel=mask_pixel, operator=operator)
+         x_t -= norm_grad * scale
+         return x_t, norm
+
+     def sample(self, scale, eta, mu_pos, logvar_pos, gamma,
+                mask_pixel, y, n_samples=100, cond=None,
+                unconditional_conditioning=None, unconditional_guidance_scale=1,
+                batch_size=10, dir_name="temp/", temp=1,
+                samples_iteration=0, operator=None):
+         sigma_pos = torch.exp(0.5 * logvar_pos)
+         num_steps = len(self.t_steps_hierarchy)
+         intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
+         intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
+         intermediate_mus = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
+         alphas = self.h_alphas
+
+         # Batch the sample generation
+         all_images = []
+         t0 = time.time()
+         save_dir = os.path.join(dir_name, "samples_50_" + str(scale))
+         os.makedirs(save_dir, exist_ok=True)
+         for _ in trange(n_samples // batch_size, desc="Sampling Batches"):
+             mu_10 = torch.Tensor.repeat(mu_pos[0], [batch_size, 1, 1, 1])
+             tau_t = sigma_pos[0]
+             z_t = torch.sqrt(1 - tau_t**2) * mu_10 + tau_t * torch.randn_like(mu_10)
+             # Sample from the fitted hierarchical posterior
+             with torch.no_grad():
+                 recon = self.model.decode_first_stage(z_t)
+                 intermediate_samples[0] = to_img(recon)[0]
+             for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
+                 t_hat_cur = torch.ones(batch_size).cuda() * t_cur
+                 a_t = torch.full((batch_size, 1, 1, 1), alphas[i]).cuda()
+                 a_prev = torch.full((batch_size, 1, 1, 1), alphas[i + 1]).cuda()
+                 sigma_t = self.h_sigmas[i + 1]
+                 # Prior predictions
+                 z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
+                                                    unconditional_conditioning, unconditional_guidance_scale)
+                 # Posterior means and variances
+                 mean_t_1 = self.posterior_mean(a_prev.sqrt() * mu_pos[i + 1].unsqueeze(0), z_next, gamma[i + 1].unsqueeze(0))
+                 std_pos = sigma_pos[i + 1]
+                 # Sample z_t
+                 z_t = mean_t_1 + std_pos * torch.randn_like(mean_t_1)
+
+                 with torch.no_grad():
+                     pred_x = self.model.decode_first_stage(pred_x0)
+                     save_samples(save_dir, pred_x, k=None, num_to_save=1, file_name=f'sample_{i}.png')
+
+             timesteps = np.flip(np.arange(0, self.t_steps_hierarchy[-1].cpu().numpy(), 1))
+             timesteps = np.concatenate((self.t_steps_hierarchy[-1].cpu().reshape(1), timesteps))
+             # Refine with the DPS algorithm from the last hierarchy step down to t = 0
+             for i, (step, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+                 step = int(step)
+                 t_hat_cur = torch.ones(batch_size).cuda() * step
+                 a_t = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[step]).cuda()
+                 a_prev = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[int(t_next)]).cuda()
+                 sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
+                 z_t = z_t.requires_grad_()
+                 z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
+                                                    unconditional_conditioning, unconditional_guidance_scale)
+                 pred_x = self.model.decode_first_stage(pred_x0)
+                 z_t, _ = self.conditioning(x_prev=z_t, x_t=z_next,
+                                            x_0_hat=pred_x, measurement=y,
+                                            mask_pixel=mask_pixel, scale=scale, operator=operator)
+                 z_t = z_t.detach_()
+
+                 if i % 50 == 0:
+                     with torch.no_grad():
+                         recons = self.model.decode_first_stage(pred_x0)
+                         recons_np = to_img(recons).astype(np.uint8)
+                         self.sampling_queue.put(recons_np)
+                         save_samples(save_dir, recons, k=None, num_to_save=1, file_name=f'det_{step}.png')
+
+             z_0 = pred_x0
+             with torch.no_grad():
+                 recon = self.model.decode_first_stage(z_0)
+                 intermediate_preds[-1] = to_img(recon)[0]
+
+             with torch.no_grad():
+                 recons = self.model.decode_first_stage(pred_x0)
+                 recons_np = to_img(recons).astype(np.uint8)
+                 self.sampling_queue.put(recons_np)
+                 all_images.append(custom_to_np(recons))
+
+         t1 = time.time()
+
+         all_img = np.concatenate(all_images, axis=0)
+         all_img = all_img[:n_samples]
+         shape_str = "x".join([str(x) for x in all_img.shape])
+         nppath = os.path.join(save_dir, f"{shape_str}-samples.npz")
+         np.savez(nppath, all_img, t1 - t0)
+
+         if operator is None:
+             save_inpaintings(save_dir, recons, y, mask_pixel, num_to_save=batch_size)
+         else:
+             save_samples(save_dir, recons, None, batch_size)
+         recons_np = to_img(recons).astype(np.uint8)
+         self.sampling_queue.put(recons_np)
+         return
+
+     def fit(self, lambda_, cond, shape, quantize_denoised=False, mask_pixel=None,
+             y=None, log_every_t=100, unconditional_guidance_scale=1.,
+             unconditional_conditioning=None, dir_name=None, kl_weight_1=50, kl_weight_2=50,
+             debug=False, wdb=False, iterations=200, batch_size=10, lr_init_gamma=0.01,
+             operator=None, recon_weight=50):
+
+         if wdb:
+             wandb.init(project='LDM', dir='/scratch/sakshi/wandb-cache')
+             wandb.config.run_type = 'hierarchical'
+             wandb.run.name = "hierarchical"
+
+         params_to_fit = params_train(lambda_)
+         mu_pos, logvar_pos, gamma = params_to_fit
+         optimizers, schedulers = get_optimizers(mu_pos, logvar_pos, gamma, lr_init_gamma)
+         rec_loss_all, prior_loss_all, posterior_loss_all = [], [], []
+         loss_all = []
+         mu_all, logvar_all, gamma_all = [], [], []
+         for k in range(iterations):
+             if k % 100 == 0: print(k)
+             intermediate_mus = np.zeros((len(self.t_steps_hierarchy), self.latent_size, self.latent_size, self.latent_channels))
+
+             for opt in optimizers: opt.zero_grad()
+             stats_prior = self.loss_prior(mu_pos[0], logvar_pos[0], cond=cond,
+                                           unconditional_conditioning=unconditional_conditioning,
+                                           unconditional_guidance_scale=unconditional_guidance_scale,
+                                           K=batch_size, intermediate_mus=intermediate_mus)
+             stats_posterior = self.loss_posterior(stats_prior["sample"], mu_pos[1:], logvar_pos[1:], gamma[1:],
+                                                   cond=cond,
+                                                   unconditional_conditioning=unconditional_conditioning,
+                                                   unconditional_guidance_scale=unconditional_guidance_scale,
+                                                   K=batch_size, iteration=k, intermediate_mus=stats_prior["intermediate_mus"])
+             sample = self.model.decode_first_stage(stats_posterior["sample"])
+
+             stats_recon = self.recon_loss(sample, y, mask_pixel, operator)
+             loss_total = torch.mean(kl_weight_1 * stats_prior["loss"]
+                                     + kl_weight_2 * stats_posterior["loss"] + recon_weight * stats_recon["loss"])
+             loss_total.backward()
+             for opt in optimizers: opt.step()
+             for sch in schedulers: sch.step()
+
+             rec_loss_all.append(torch.mean(stats_recon["loss"].detach()).item())
+             prior_loss_all.append(torch.mean(kl_weight_1 * stats_prior["loss"].detach()).item())
+             posterior_loss_all.append(torch.mean(kl_weight_2 * stats_posterior["loss"].detach()).item())
+             mu_all.append(torch.mean(mu_pos.detach()).item())
+             logvar_all.append(torch.mean(logvar_pos.detach()).item())
+             gamma_all.append(torch.mean(torch.sigmoid(gamma).detach()).item())
+             sample_np = to_img(sample).astype(np.uint8)
+             loss_all.append(loss_total.detach().item())
+             self.image_queue.put(sample_np)
+
+             save_plot(dir_name, [rec_loss_all, prior_loss_all, posterior_loss_all],
+                       ["Recon loss", "Diffusion loss", "Hierarchical loss"], "loss.png")
+             save_plot(dir_name, [loss_all], ["Total Loss"], "loss_t.png")
+             save_plot(dir_name, [mu_all], ["mean"], "mean.png")
+             save_plot(dir_name, [logvar_all], ["logvar"], "logvar.png")
+             save_plot(dir_name, [gamma_all], ["gamma"], "gamma.png")
+
+             if k % log_every_t == 0 or k == iterations - 1:
+                 save_samples(os.path.join(dir_name, "progress"), sample, k, batch_size)
+                 save_samples(os.path.join(dir_name, "mus"), stats_posterior["intermediate_mus"], k,
+                              len(stats_posterior["intermediate_mus"]))
+                 save_params(os.path.join(dir_name, "params"), mu_pos, logvar_pos, gamma, k)
+
+             gc.collect()
+         return
+
+     # Unconditional sampling for debugging purposes:
+     '''
+     def sample_T(self, x0, cond, unconditional_conditioning, unconditional_guidance_scale, eta=0.4, t_steps_hierarchy=None, dir_="out_temp2"):
+         ''
+         sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max=999, rho=7, num_t_steps=10) / 1000
+         T_max = 1000
+         beta_start = 1   # 0.0015*T_max
+         beta_end = 15    # 0.0155*T_max
+         def var(t):
+             return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
+         ''
+
+         x0 = torch.randn_like(x0)
+         t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy).cuda()
+         var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1, 1, 1))**2
+         x_t = x0  # torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
+
+         os.makedirs(dir_, exist_ok=True)
+         alphas = self.h_alphas
+         b = 5
+         for i, t in enumerate(t_steps_hierarchy[:-1]):
+             t_hat = torch.ones(b).cuda() * t
+             a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
+             a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
+             sigma_t = self.h_sigmas[i + 1]
+             x_t, pred_x0 = self.prior_preds(x_t.float(), t_hat, cond, a_t, a_prev, sigma_t,
+                                             unconditional_conditioning, unconditional_guidance_scale)
+
+             var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1, 1, 1))**2
+             a_t = 1 - var_t
+             x_t = x_t + sigma_t * torch.randn_like(x_t)
+             recon = self.model.decode_first_stage(pred_x0)
+             image_path = os.path.join(dir_, f'{i}.png')
+             image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
+             PIL.Image.fromarray(image_np, 'RGB').save(image_path)
+
+         t_hat_cur = torch.ones(b).cuda() * (self.t_steps_hierarchy[-1])
+         e_out = self.get_error(x_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
+         a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
+         sqrt_one_minus_at = torch.sqrt(1 - a_t)
+         pred_x0 = (x_t - sqrt_one_minus_at * e_out) / a_t.sqrt()
+
+         recon = self.model.decode_first_stage(pred_x0)
+         image_path = os.path.join(dir_, f'{len(t_steps_hierarchy)}.png')
+         image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
+         PIL.Image.fromarray(image_np, 'RGB').save(image_path)
+         return
+     '''
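
A minimal end-to-end driver for HPosterior might look like the sketch below. It is illustrative only: the model and loss objects, the encoded measurement z_y, the hierarchy schedule, and every hyperparameter value are assumptions, not part of this commit.

    # Hypothetical usage sketch (all values are placeholders)
    posterior = HPosterior(model, vae_loss, t_steps_hierarchy=[550, 400, 250, 150],
                           eta=0.4, first_stage="kl")
    posterior.descretize(rho=7)   # builds the time discretization and h_alphas / h_sigmas
    mu, logvar, gamma = posterior.init(z_y, std_scale=1.0, mean_scale=0.1, prior_scale=0.0)
    lambda_ = (mu, logvar, gamma)  # assumes utils.helper.params_train unpacks this triple
    posterior.fit(lambda_, cond, shape=None, mask_pixel=mask, y=y_pixel,
                  unconditional_conditioning=uc, dir_name="out/", iterations=200, batch_size=10)
    posterior.sample(scale=0.3, eta=1.0, mu_pos=mu, logvar_pos=logvar, gamma=gamma,
                     mask_pixel=mask, y=y_pixel, n_samples=10, cond=cond,
                     unconditional_conditioning=uc, batch_size=10, dir_name="out/")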
ldm/guided_diffusion/loss_vq.py ADDED
@@ -0,0 +1,203 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from einops import repeat
+
+ from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+ from taming.modules.losses.lpips import LPIPS
+ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+     assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+     loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
+     loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
+     loss_real = (weights * loss_real).sum() / weights.sum()
+     loss_fake = (weights * loss_fake).sum() / weights.sum()
+     d_loss = 0.5 * (loss_real + loss_fake)
+     return d_loss
+
+ def adopt_weight(weight, global_step, threshold=0, value=0.):
+     if global_step < threshold:
+         weight = value
+     return weight
+
+
+ def measure_perplexity(predicted_indices, n_embed):
+     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+     # Eval cluster perplexity: when perplexity == num_embeddings, all clusters are used exactly equally
+     encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+     avg_probs = encodings.mean(0)
+     perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+     cluster_use = torch.sum(avg_probs > 0)
+     return perplexity, cluster_use
+
+ def l1(x, y):
+     return torch.abs(x - y)
+
+
+ def l2(x, y):
+     return torch.pow((x - y), 2)
+
+
+ class VQLPIPSWithDiscriminator(nn.Module):
+     def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                  perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                  disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+                  pixel_loss="l1"):
+         super().__init__()
+         assert disc_loss in ["hinge", "vanilla"]
+         assert perceptual_loss in ["lpips", "clips", "dists"]
+         assert pixel_loss in ["l1", "l2"]
+         self.codebook_weight = codebook_weight
+         self.pixel_weight = pixelloss_weight
+         if perceptual_loss == "lpips":
+             print(f"{self.__class__.__name__}: Running with LPIPS.")
+             self.perceptual_loss = LPIPS().eval().to(device="cuda")
+         else:
+             raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+         self.perceptual_weight = perceptual_weight
+
+         if pixel_loss == "l1":
+             self.pixel_loss = l1
+         else:
+             self.pixel_loss = l2
+
+         self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                  n_layers=disc_num_layers,
+                                                  use_actnorm=use_actnorm,
+                                                  ndf=disc_ndf
+                                                  ).apply(weights_init).cuda()
+         self.discriminator.eval()
+         self.discriminator_iter_start = disc_start
+         if disc_loss == "hinge":
+             self.disc_loss = hinge_d_loss
+         elif disc_loss == "vanilla":
+             self.disc_loss = vanilla_d_loss
+         else:
+             raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+         print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+         self.disc_factor = disc_factor
+         self.discriminator_weight = disc_weight
+         self.disc_conditional = disc_conditional
+         self.n_classes = n_classes
+
+     def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+         if last_layer is not None:
+             nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+         else:
+             nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+         d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+         d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+         d_weight = d_weight * self.discriminator_weight
+         return d_weight
+
+     def forward(self, codebook_loss, inputs, reconstructions, mask, optimizer_idx,
+                 global_step, last_layer=None, cond=None, split="train", predicted_indices=None,
+                 operator=None, noiser=None):
+
+         # Masked pixel reconstruction term
+         if operator is not None:
+             x = operator.forward(reconstructions)
+         else:
+             x = reconstructions.contiguous() * mask
+             inputs = inputs.contiguous() * mask
+         rec_loss = self.pixel_loss(inputs, x)
+         std = 0.566  # assumed observation noise std
+
+         if self.perceptual_weight > 0:
+             if operator is None:
+                 p_loss = self.perceptual_loss(mask * inputs.contiguous().float(), mask * reconstructions.contiguous().float())
+             else:
+                 p_loss = torch.tensor([0.0])
+         else:
+             p_loss = torch.tensor([0.0])
+
+         nll_loss = rec_loss / (2 * std**2)
+         nll_loss = 100 * torch.mean(nll_loss) + 100 * self.perceptual_weight * p_loss.squeeze()
+
+         return nll_loss, nll_loss
+
+         # NOTE: the GAN branches below are unreachable because of the early return above;
+         # they are kept for reference.
+         if optimizer_idx == 0:
+             # generator update
+             if cond is None:
+                 assert not self.disc_conditional
+                 logits_fake = self.discriminator(reconstructions.contiguous())
+             else:
+                 assert self.disc_conditional
+                 logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+             g_loss = -torch.mean(logits_fake)
+
+             loss = nll_loss + g_loss + self.codebook_weight * codebook_loss.mean()
+
+             log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+                    "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+                    "{}/nll_loss".format(split): nll_loss.detach().mean(),
+                    "{}/rec_loss".format(split): rec_loss.detach().mean(),
+                    "{}/g_loss".format(split): g_loss.detach().mean(),
+                    }
+
+             if predicted_indices is not None:
+                 assert self.n_classes is not None
+                 with torch.no_grad():
+                     perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+                 log[f"{split}/perplexity"] = perplexity
+                 log[f"{split}/cluster_usage"] = cluster_usage
+             return loss, log
+
+         if optimizer_idx == 1:
+             # second pass for discriminator update
+             if cond is None:
+                 logits_real = self.discriminator(inputs.contiguous().detach())
+                 logits_fake = self.discriminator(reconstructions.contiguous().detach())
+             else:
+                 logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+                 logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+             disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+             d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+             log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+                    "{}/logits_real".format(split): logits_real.detach().mean(),
+                    "{}/logits_fake".format(split): logits_fake.detach().mean()
+                    }
+             return d_loss, log
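
A hedged sketch of how the masked reconstruction branch above is driven (it mirrors the vq call in HPosterior.recon_loss; the tensors are random stand-ins, and a CUDA device plus the taming LPIPS weights are assumed):

    import torch
    loss_fn = VQLPIPSWithDiscriminator(disc_start=0, perceptual_weight=1.0, pixel_loss="l1")
    x0   = torch.randn(1, 3, 256, 256).cuda()   # measurement, pixel space
    rec  = torch.randn(1, 3, 256, 256).cuda()   # decoded sample
    mask = torch.ones_like(x0)                  # inpainting mask (1 = observed)
    qloss = torch.tensor([0.]).cuda()
    nll, _ = loss_fn(qloss, x0, rec, mask, 0, 0, split="val",
                     predicted_indices=None, operator=None)  # returns before the GAN branches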
ldm/guided_diffusion/losses.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ import torch.nn as nn
+
+ from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?
+
+ class LPIPSWithDiscriminator(nn.Module):
+     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                  perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                  disc_loss="hinge"):
+
+         super().__init__()
+         assert disc_loss in ["hinge", "vanilla"]
+         self.kl_weight = kl_weight
+         self.pixel_weight = pixelloss_weight
+         self.perceptual_loss = LPIPS().eval().cuda()
+         self.perceptual_weight = perceptual_weight
+         # output log variance
+         self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+         self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                  n_layers=disc_num_layers,
+                                                  use_actnorm=use_actnorm
+                                                  ).apply(weights_init).cuda()
+         self.discriminator_iter_start = disc_start
+         self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+         self.disc_factor = disc_factor
+         self.discriminator_weight = disc_weight
+         self.disc_conditional = disc_conditional
+
+     def calculate_adaptive_weight(self, nll_loss, g_loss, reconstructions, last_layer=None):
+         if last_layer is not None:
+             nll_grads = torch.autograd.grad(nll_loss, reconstructions, retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, reconstructions, retain_graph=True)[0]
+         else:
+             nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+             g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+         d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+         d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+         d_weight = d_weight * self.discriminator_weight
+         return d_weight
+
+     def forward(self, inputs, reconstructions, mask, optimizer_idx,
+                 global_step, posteriors=None, last_layer=None, cond=None, split="train",
+                 weights=None):
+         # Masked L1 reconstruction plus LPIPS perceptual term
+         rec_loss = torch.abs(inputs.contiguous() * mask - reconstructions.contiguous() * mask)
+         if self.perceptual_weight > 0:
+             p_loss = self.perceptual_loss(inputs.contiguous() * mask, reconstructions.contiguous() * mask)
+         else:
+             p_loss = torch.tensor([0.0])
+
+         nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+         nll_loss = 100 * torch.mean(nll_loss, dim=[1, 2, 3]) + 100 * self.perceptual_weight * p_loss.squeeze()
+
+         return nll_loss, nll_loss
+
+         # NOTE: the GAN branches below are unreachable because of the early return above;
+         # they are kept for reference.
+         if optimizer_idx == 0:
+             # generator update
+             if cond is None:
+                 assert not self.disc_conditional
+                 logits_fake = self.discriminator(reconstructions.contiguous())
+             else:
+                 assert self.disc_conditional
+                 logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+             g_loss = -torch.mean(logits_fake)
+
+             if self.disc_factor > 0.0:
+                 try:
+                     d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, reconstructions, last_layer=last_layer)
+                 except RuntimeError:
+                     assert not self.training
+                     d_weight = torch.tensor(0.0)
+             else:
+                 d_weight = torch.tensor(0.0)
+
+             disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+             loss = nll_loss  # + d_weight * g_loss
+
+             log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+                    "{}/logvar".format(split): self.logvar.detach(),
+                    "{}/rec_loss".format(split): rec_loss.detach().mean(),
+                    "{}/d_weight".format(split): d_weight.detach(),
+                    "{}/disc_factor".format(split): torch.tensor(disc_factor),
+                    "{}/g_loss".format(split): g_loss.detach().mean(),
+                    }
+             return loss, log
+
+         if optimizer_idx == 1:
+             # second pass for discriminator update
+             if cond is None:
+                 logits_real = self.discriminator(inputs.contiguous().detach())
+                 logits_fake = self.discriminator(reconstructions.contiguous().detach())
+             else:
+                 logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+                 logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+             disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+             d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+             log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+                    "{}/logits_real".format(split): logits_real.detach().mean(),
+                    "{}/logits_fake".format(split): logits_fake.detach().mean()
+                    }
+             return d_loss, log
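
Similarly, the kl branch in HPosterior.recon_loss calls this loss as sketched below (illustrative; random stand-in tensors, CUDA and the LPIPS weights assumed):

    import torch
    loss_fn = LPIPSWithDiscriminator(disc_start=0)
    x0   = torch.randn(1, 3, 256, 256).cuda()
    rec  = torch.randn(1, 3, 256, 256).cuda()
    mask = torch.ones_like(x0)
    nll, _ = loss_fn(x0, rec, mask, 0, 0, last_layer=None, split="val")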