Commit: ldm

Files changed:
- ldm/guided_diffusion/h_posterior.py (+506, -0)
- ldm/guided_diffusion/loss_vq.py (+203, -0)
- ldm/guided_diffusion/losses.py (+116, -0)
ldm/guided_diffusion/h_posterior.py
ADDED
@@ -0,0 +1,506 @@
"""INFERENCE TIME OPTIMIZATION"""

import gc
import os
import time

import numpy as np
import PIL
import torch
import wandb
from tqdm import trange

from utils.helper import (params_train, get_optimizers, time_descretization, to_img,
                          custom_to_np, save_params, save_samples, save_inpaintings, save_plot)


class HPosterior(object):
    def __init__(self, model, vae_loss, t_steps_hierarchy, eta=0.4, z0_size=32, img_size=256, latent_channels=3,
                 num_hierarchy_steps=5, schedule="linear", first_stage="kl", posterior="hierarchical",
                 image_queue=None, sampling_queue=None, **kwargs):
        super().__init__()
        self.model = model                                    # prior noise-prediction model
        self.schedule = schedule                              # noise schedule the prior was trained on
        self.vae_loss = vae_loss                              # VAE loss used during first-stage training
        self.eta = eta                                        # eta used to produce faster, cleaner samples
        self.first_stage = first_stage                        # first-stage training procedure: "kl" or "vq"
        self.posterior = posterior
        self.t_steps_hierarchy = np.array(t_steps_hierarchy)  # time steps of the hierarchical posterior
        self.z0size = z0_size                                 # spatial size of the latent variables z
        self.img_size = img_size
        self.latent_size = z0_size
        self.latent_channels = latent_channels
        self.image_queue = image_queue
        self.sampling_queue = sampling_queue

    def q_given_te(self, t, s, shape, zeta_t_star=None):
        # Transition q(z_t | z_s): returns alpha_{t|s} and the std sqrt(sigma_t^2 - alpha_{t|s}^2 sigma_s^2).
        if zeta_t_star is not None:
            alpha_s = torch.sqrt(1 - zeta_t_star**2)
            var_s = zeta_t_star**2
        else:
            m = 1 if len(s.shape) == 0 else s.shape[0]
            var_s = (self.model.sqrt_one_minus_alphas_cumprod[s].reshape(m, 1, 1, 1))**2
            alpha_s = torch.sqrt(1 - var_s)

        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t])**2
        alpha_t = torch.sqrt(1 - var_t)
        alpha_t_s = alpha_t.reshape(len(var_t), 1, 1, 1) / alpha_s
        var_t_s = var_t.reshape(len(var_t), 1, 1, 1) - alpha_t_s**2 * var_s
        return alpha_t_s, torch.sqrt(var_t_s)

    def qpos_given_te(self, t, s, t_star, z_t_star, z_t, zeta_T_star=None):
        # Gaussian posterior q(z_s | z_t, z_{t*}) obtained by combining the two transitions.
        alpha_t_s, scale_t_s = self.q_given_te(t, s, z_t_star.shape)
        alpha_s_t_star, scale_s_t_star = self.q_given_te(s, t_star, z_t_star.shape, zeta_T_star)

        var = scale_t_s**2 * scale_s_t_star**2 / (scale_t_s**2 + alpha_s_t_star**2 * scale_s_t_star**2)
        mean = var * ((alpha_s_t_star / scale_s_t_star**2) * z_t_star + (alpha_t_s / scale_t_s**2) * z_t)
        return mean, torch.sqrt(var)

    def register_buffer(self, name, attr):
        if isinstance(attr, torch.Tensor) and attr.device != torch.device("cuda"):
            attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def get_error(self, x, t, c, unconditional_conditioning, unconditional_guidance_scale):
        # Noise prediction with optional classifier-free guidance.
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x.float(), t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
        return e_t

    def descretize(self, rho):
        # Time discretization for the prior loss (t > T_e)
        self.timesteps_1000 = time_descretization(sigma_min=0.002, sigma_max=0.999, rho=rho, num_t_steps=1000) * 1000
        self.timesteps_1000 = self.timesteps_1000.cuda().long()
        sigma_timesteps = self.model.sqrt_one_minus_alphas_cumprod[self.timesteps_1000]
        self.register_buffer('sigma_timesteps', sigma_timesteps)

        # Prior std for the hierarchical time points
        sigma_hierarchy = self.model.sqrt_one_minus_alphas_cumprod[self.t_steps_hierarchy]
        self.t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy.copy()).cuda()
        alphas_h = 1 - sigma_hierarchy**2
        alphas_prev = torch.concatenate([alphas_h[1:], alphas_h[-1].reshape(1)])
        h_sigmas = torch.sqrt(self.eta * (1 - alphas_prev) / (1 - alphas_h) * (1 - alphas_h / alphas_prev))
        h_sigmas[1:] = torch.sqrt(self.eta * (1 - alphas_prev[:-1]) / (1 - alphas_h[:-1]) * (1 - alphas_h[:-1] / alphas_prev[:-1]))
        h_sigmas[0] = torch.sqrt(1 - alphas_h[0])

        # Register tensors
        self.register_buffer('h_alphas', alphas_h)
        self.register_buffer('h_alphas_prev', alphas_prev)
        self.register_buffer('h_sigmas', h_sigmas)

    def init(self, img, std_scale, mean_scale, prior_scale, mean_scale_top=0.1):
        # Initialize the variational parameters: means, log-variances, and gating logits.
        num_h_steps = len(self.t_steps_hierarchy)
        img = torch.Tensor.repeat(img, [num_h_steps, 1, 1, 1])[:num_h_steps]
        sigmas = torch.zeros_like(img)
        sqrt_alphas = torch.sqrt(self.h_alphas)[..., None, None, None].expand(img.shape)
        sqrt_one_minus_alphas = torch.sqrt(1 - self.h_alphas)[..., None, None, None].expand(img.shape)
        # Variances of the posterior
        sigmas[0] = self.h_sigmas[0, None, None, None].expand(img[0].shape)
        sigmas[1:] = std_scale * (1 / np.sqrt(self.eta)) * self.h_sigmas[1:, None, None, None].expand(img[1:].shape)
        logvar_pos = 2 * torch.log(sigmas).float()
        # Means
        mean_pos = sqrt_alphas * img + mean_scale * sqrt_one_minus_alphas * torch.randn_like(img)
        mean_pos[0] = img[0] + mean_scale_top * torch.randn_like(img[0])
        # Gammas that weigh the posterior mean between prior and posterior
        gamma = torch.tensor(prior_scale)[None, None, None, None].expand(img.shape).cuda()
        return mean_pos, logvar_pos, gamma.float()

    def get_kl(self, mu1, mu2, scale1, scale2, wt):
        # Closed-form KL( N(mu1, scale1^2) || N(mu2, scale2^2) ); wt scales the mean term.
        return wt * (mu1 - mu2)**2 / (2 * scale2**2) \
            + torch.log(scale2 / scale1) + scale1**2 / (2 * scale2**2) - 1/2
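For reference, the Gaussian algebra implemented by q_given_te, qpos_given_te, and get_kl above is, reconstructed from the code (with the DDPM convention sigma_t = sqrt(1 - alpha_bar_t)):

\[
\alpha_{t|s} = \frac{\alpha_t}{\alpha_s}, \qquad
\sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2,
\]
\[
q(z_s \mid z_t, z_{t^*}) = \mathcal{N}\!\left( \sigma^2 \Big[ \tfrac{\alpha_{s|t^*}}{\sigma_{s|t^*}^2} z_{t^*} + \tfrac{\alpha_{t|s}}{\sigma_{t|s}^2} z_t \Big],\; \sigma^2 \right),
\qquad
\sigma^2 = \frac{\sigma_{t|s}^2 \, \sigma_{s|t^*}^2}{\sigma_{t|s}^2 + \alpha_{s|t^*}^2 \sigma_{s|t^*}^2},
\]
\[
\mathrm{KL}\big(\mathcal{N}(\mu_1, s_1^2) \,\|\, \mathcal{N}(\mu_2, s_2^2)\big)
= \frac{w\,(\mu_1 - \mu_2)^2}{2 s_2^2} + \log\frac{s_2}{s_1} + \frac{s_1^2}{2 s_2^2} - \frac{1}{2}.
\]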
    # diffusion loss
    def loss_prior(self, mu_pos, logvar_pos, cond=None,
                   unconditional_conditioning=None,
                   unconditional_guidance_scale=1, K=10, intermediate_mus=None):
        '''
        Estimates KL( q(x_{T_e}) || p(x_{T_e}) ) = E_{t > T_e}[ (x_{T_e} - \mu_\theta(x_t))^2 ].
        x_{T_e} = z_te below, sampled from q(x_{T_e}).
        z_t is sampled with noise scaled by sqrt(\sigma_t^2 - \zeta_t^2), so that z_t carries
        the total noise expected at time t.
        '''
        t_e = self.t_steps_hierarchy[0]
        # Sample z_{T_e}
        tau_te = torch.exp(0.5 * logvar_pos)
        mu_te = torch.Tensor.repeat(mu_pos, [K, 1, 1, 1])
        z_te = torch.sqrt(1 - tau_te**2) * mu_te + tau_te * torch.randn_like(mu_te)

        # Sample t from the allowed timesteps (> T_e)
        t_g = torch.where(self.sigma_timesteps > torch.max(tau_te))[0]
        t_allowed = self.timesteps_1000[t_g]

        def sample_uniform(t_allowed):
            # Stratified-uniform draw of K indices in [T_min, T_max)
            t0 = torch.rand(1)
            T_max = len(t_allowed)
            T_min = 2  # stay away from values close to T*
            t = torch.remainder(t0 + torch.arange(0., 1., step=1. / K), 1.) * (T_max - T_min) + T_min
            return torch.floor(t).long()

        t = sample_uniform(t_allowed)
        t_cur = t_allowed[t]
        t_prev = t_allowed[t - 1]

        # Sample z_t from p(z_t | z_{T_e})
        alpha_t, scale_t = self.q_given_te(t_cur, t_e, z_te.shape, tau_te)
        z_t = alpha_t * z_te + scale_t * torch.randn_like(z_te)

        # Prior and posterior means/variances for t_prev
        e_out = self.get_error(z_t.float(), t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        alpha_t_, scale_t_ = self.q_given_te(t_cur, t_e, z_te.shape)
        mu_t_hat = (z_t - scale_t_ * e_out) / alpha_t_
        pos_mean, pos_scale = self.qpos_given_te(t_cur, t_prev, t_e, z_te, z_t, tau_te)
        prior_mean, prior_scale = self.qpos_given_te(t_cur, t_prev, t_e, mu_t_hat, z_t, None)

        wt = (1000 - t_e) / 2
        kl = self.get_kl(pos_mean, prior_mean, pos_scale, prior_scale, wt=1)
        kl = torch.mean(wt * kl, dim=[1, 2, 3])

        return {"loss": kl, "sample": z_te, "intermediate_mus": intermediate_mus}

    def recon_loss(self, samples_pixel, x0_pixel, mask_pixel, operator=None):
        global_step = 0
        if self.first_stage == "kl":
            nll_loss, _ = self.vae_loss(x0_pixel, samples_pixel, mask_pixel, 0, global_step,
                                        last_layer=self.model.first_stage_model.get_last_layer(), split="val")
        else:
            qloss = torch.tensor([0.]).cuda()
            nll_loss, _ = self.vae_loss(qloss, x0_pixel, samples_pixel, mask_pixel, 0, 0,
                                        last_layer=self.model.first_stage_model.get_last_layer(), split="val",
                                        predicted_indices=None, operator=operator)
        return {"loss": nll_loss}

    def prior_preds(self, z_t, t_cur, cond, a_t, a_prev, sigma_t, unconditional_conditioning, unconditional_guidance_scale):
        # DDIM step: predict x0, then move along the direction pointing to x_t
        e_out = self.get_error(z_t, t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        pred_x0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
        z_next = a_prev.sqrt() * pred_x0 + dir_xt
        return z_next, pred_x0

    def posterior_mean(self, mu_pos, mu_prior, gamma):
        # Convex combination of prior and posterior means, gated by sigmoid(gamma)
        wt = torch.sigmoid(gamma)
        return wt * mu_prior + (1 - wt) * mu_pos

    def normalize(self, img):
        img -= torch.min(img)
        return 2 * img / torch.max(img) - 1

    def loss_posterior(self, z_t, mu_pos, logvar_pos, gamma, cond=None,
                       unconditional_conditioning=None,
                       unconditional_guidance_scale=1,
                       K=10, iteration=0, to_sample=False, intermediate_mus=None):

        sigma_pos = torch.exp(0.5 * logvar_pos)
        kl_t, q_entropy = torch.zeros(z_t.shape[0]).cuda(), 0
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        b = z_t.shape[0]
        with torch.no_grad():
            recon = self.model.decode_first_stage(z_t)
            intermediate_samples[0] = to_img(recon)[0]

        alphas = self.h_alphas
        for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
            t_hat_cur = torch.ones(b).cuda() * t_cur
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]
            # Prior predictions
            z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                               unconditional_conditioning, unconditional_guidance_scale)
            std_prior = self.h_sigmas[i + 1]

            # Posterior mean and variance
            pos_mean = self.posterior_mean(a_prev.sqrt() * mu_pos[i].unsqueeze(0), z_next, gamma[i].unsqueeze(0))
            std_pos = sigma_pos[i]

            # Sample z_t and accumulate the KL
            z_t = pos_mean + std_pos * torch.randn_like(pos_mean)
            kl = self.get_kl(pos_mean, z_next, std_pos, std_prior, wt=1)
            kl_t += torch.mean(kl, dim=[1, 2, 3])

            with torch.no_grad():
                recon_pred = self.model.decode_first_stage(pred_x0)
                intermediate_preds[i] = to_img(recon_pred)[0]
                intermediate_mus[i + 1] = to_img(self.normalize(mu_pos[i]).unsqueeze(0)).astype(np.uint8)[0]

        # One-step denoising from the last hierarchy level
        t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
        e_out = self.get_error(z_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        sqrt_one_minus_at = torch.sqrt(1 - a_t)
        pred_z0 = (z_t - sqrt_one_minus_at * e_out) / a_t.sqrt()

        with torch.no_grad():
            recon = self.model.decode_first_stage(pred_z0)
            intermediate_preds[-1] = to_img(recon)[0]

        return {"sample": pred_z0, "loss": kl_t, "entropy": q_entropy,
                "intermediates": intermediate_samples, "interim_preds": intermediate_preds,
                "intermediate_mus": intermediate_mus}
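Written out, each hierarchical step in loss_posterior gates the learned mean against the DDIM prior mean before sampling (reconstructed from posterior_mean and the sampling line above):

\[
w_i = \operatorname{sigmoid}(\gamma_i), \qquad
z_{i+1} \sim \mathcal{N}\big( w_i\, z_{\text{next}} + (1 - w_i)\, \sqrt{\bar{\alpha}_{i+1}}\, \mu_i,\; \sigma_{\text{pos},i}^2 \big),
\]

where z_next is the deterministic DDIM update a_prev.sqrt() * pred_x0 + dir_xt from prior_preds.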
    def grad_and_value(self, x_prev, x_0_hat, measurement, mask_pixel, operator):
        # Gradient of the reconstruction NLL w.r.t. the latent that produced x_0_hat
        nll_loss = torch.mean(self.recon_loss(x_0_hat, measurement, mask_pixel, operator)["loss"])
        norm_grad = torch.autograd.grad(outputs=nll_loss, inputs=x_prev)[0]
        return norm_grad, nll_loss

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, mask_pixel, scale, operator, **kwargs):
        # DPS-style measurement guidance: step x_t against the NLL gradient
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat,
                                              measurement=measurement, mask_pixel=mask_pixel, operator=operator)
        x_t -= norm_grad * scale
        return x_t, norm
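conditioning implements the diffusion-posterior-sampling (DPS) update used in the refinement loop of sample below:

\[
z_{t-1} \;\leftarrow\; z_{t-1}' \;-\; \zeta\, \nabla_{z_t}\, \mathcal{L}_{\text{recon}}\big( \mathcal{A}(\hat{x}_0(z_t)),\, y \big),
\]

where z'_{t-1} is the unconditional DDIM step from prior_preds, A is the measurement operator (the pixel mask for inpainting), and zeta is the `scale` argument. Standard DPS uses the squared measurement error; here it is replaced by the first-stage reconstruction NLL from recon_loss.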
    def sample(self, scale, eta, mu_pos, logvar_pos, gamma,
               mask_pixel, y, n_samples=100, cond=None,
               unconditional_conditioning=None, unconditional_guidance_scale=1,
               batch_size=10, dir_name="temp/", temp=1,
               samples_iteration=0, operator=None):
        sigma_pos = torch.exp(0.5 * logvar_pos)
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        alphas = self.h_alphas

        # Batch the sample generation
        all_images = []
        t0 = time.time()
        save_dir = os.path.join(dir_name, "samples_50_" + str(scale))
        os.makedirs(save_dir, exist_ok=True)
        for _ in trange(n_samples // batch_size, desc="Sampling Batches"):
            # Draw from the top level of the hierarchical posterior
            mu_10 = torch.Tensor.repeat(mu_pos[0], [batch_size, 1, 1, 1])
            tau_t = sigma_pos[0]
            z_t = torch.sqrt(1 - tau_t**2) * mu_10 + tau_t * torch.randn_like(mu_10)
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_t)
                intermediate_samples[0] = to_img(recon)[0]
            # Walk down the hierarchy
            for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
                t_hat_cur = torch.ones(batch_size).cuda() * t_cur
                a_t = torch.full((batch_size, 1, 1, 1), alphas[i]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), alphas[i + 1]).cuda()
                sigma_t = self.h_sigmas[i + 1]
                # Prior predictions
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                                   unconditional_conditioning, unconditional_guidance_scale)
                # Posterior mean/variance, then sample z_t
                mean_t_1 = self.posterior_mean(a_prev.sqrt() * mu_pos[i + 1].unsqueeze(0), z_next, gamma[i + 1].unsqueeze(0))
                std_pos = sigma_pos[i + 1]
                z_t = mean_t_1 + std_pos * torch.randn_like(mean_t_1)

                with torch.no_grad():
                    pred_x = self.model.decode_first_stage(pred_x0)
                    save_samples(save_dir, pred_x, k=None, num_to_save=1, file_name=f'sample_{i}.png')

            # Refine below the last hierarchy level with the DPS algorithm
            timesteps = np.flip(np.arange(0, self.t_steps_hierarchy[-1].cpu().numpy(), 1))
            timesteps = np.concatenate((self.t_steps_hierarchy[-1].cpu().reshape(1), timesteps))
            for i, (step, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                step = int(step)
                t_hat_cur = torch.ones(batch_size).cuda() * step
                a_t = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[step]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[int(t_next)]).cuda()
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
                z_t = z_t.requires_grad_()
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                                   unconditional_conditioning, unconditional_guidance_scale)
                pred_x = self.model.decode_first_stage(pred_x0)
                z_t, _ = self.conditioning(x_prev=z_t, x_t=z_next,
                                           x_0_hat=pred_x, measurement=y,
                                           mask_pixel=mask_pixel, scale=scale, operator=operator)
                z_t = z_t.detach_()

                if i % 50 == 0:
                    with torch.no_grad():
                        recons = self.model.decode_first_stage(pred_x0)
                        recons_np = to_img(recons).astype(np.uint8)
                        self.sampling_queue.put(recons_np)
                        save_samples(save_dir, recons, k=None, num_to_save=1, file_name=f'det_{step}.png')

            z_0 = pred_x0
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_0)
                intermediate_preds[-1] = to_img(recon)[0]

            with torch.no_grad():
                recons = self.model.decode_first_stage(pred_x0)
                recons_np = to_img(recons).astype(np.uint8)
                self.sampling_queue.put(recons_np)
                all_images.append(custom_to_np(recons))

        t1 = time.time()

        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(save_dir, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img, t1 - t0)

        # Optionally save the re-composited inpainting:
        # recon_in = to_img(y * mask_pixel + (1 - mask_pixel) * recons)
        # image_path = os.path.join(save_dir, str(samples_iteration) + ".png")
        # PIL.Image.fromarray(recon_in.astype(np.uint8)[0], 'RGB').save(image_path)

        if operator is None:
            save_inpaintings(save_dir, recons, y, mask_pixel, num_to_save=batch_size)
        else:
            save_samples(save_dir, recons, None, batch_size)
        recons_np = to_img(recons).astype(np.uint8)
        self.sampling_queue.put(recons_np)
        return
    def fit(self, lambda_, cond, shape, quantize_denoised=False, mask_pixel=None,
            y=None, log_every_t=100, unconditional_guidance_scale=1.,
            unconditional_conditioning=None, dir_name=None, kl_weight_1=50, kl_weight_2=50,
            debug=False, wdb=False, iterations=200, batch_size=10, lr_init_gamma=0.01,
            operator=None, recon_weight=50):

        if wdb:
            wandb.init(project='LDM', dir='/scratch/sakshi/wandb-cache')
            wandb.config.run_type = 'hierarchical'
            wandb.run.name = "hierarchical"

        mu_pos, logvar_pos, gamma = params_train(lambda_)
        optimizers, schedulers = get_optimizers(mu_pos, logvar_pos, gamma, lr_init_gamma)
        rec_loss_all, prior_loss_all, posterior_loss_all = [], [], []
        loss_all = []
        mu_all, logvar_all, gamma_all = [], [], []
        for k in range(iterations):
            if k % 100 == 0:
                print(k)
            intermediate_mus = np.zeros((len(self.t_steps_hierarchy), self.latent_size, self.latent_size, self.latent_channels))

            for opt in optimizers:
                opt.zero_grad()
            # KL against the diffusion prior at the top of the hierarchy
            stats_prior = self.loss_prior(mu_pos[0], logvar_pos[0], cond=cond,
                                          unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          K=batch_size, intermediate_mus=intermediate_mus)
            # KL of the remaining hierarchical levels
            stats_posterior = self.loss_posterior(stats_prior["sample"], mu_pos[1:], logvar_pos[1:], gamma[1:],
                                                  cond=cond,
                                                  unconditional_conditioning=unconditional_conditioning,
                                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                                  K=batch_size, iteration=k, intermediate_mus=stats_prior["intermediate_mus"])
            sample = self.model.decode_first_stage(stats_posterior["sample"])

            # Reconstruction term on the decoded sample
            stats_recon = self.recon_loss(sample, y, mask_pixel, operator)
            loss_total = torch.mean(kl_weight_1 * stats_prior["loss"]
                                    + kl_weight_2 * stats_posterior["loss"]
                                    + recon_weight * stats_recon["loss"])
            loss_total.backward()
            for opt in optimizers:
                opt.step()
            for sch in schedulers:
                sch.step()

            rec_loss_all.append(torch.mean(stats_recon["loss"].detach()).item())
            prior_loss_all.append(torch.mean(kl_weight_1 * stats_prior["loss"].detach()).item())
            posterior_loss_all.append(torch.mean(kl_weight_2 * stats_posterior["loss"].detach()).item())
            mu_all.append(torch.mean(mu_pos.detach()).item())
            logvar_all.append(torch.mean(logvar_pos.detach()).item())
            gamma_all.append(torch.mean(torch.sigmoid(gamma).detach()).item())
            sample_np = to_img(sample).astype(np.uint8)
            loss_all.append(loss_total.detach().item())
            self.image_queue.put(sample_np)

            save_plot(dir_name, [rec_loss_all, prior_loss_all, posterior_loss_all],
                      ["Recon loss", "Diffusion loss", "Hierarchical loss"], "loss.png")
            save_plot(dir_name, [loss_all], ["Total Loss"], "loss_t.png")
            save_plot(dir_name, [mu_all], ["mean"], "mean.png")
            save_plot(dir_name, [logvar_all], ["logvar"], "logvar.png")
            save_plot(dir_name, [gamma_all], ["gamma"], "gamma.png")

            if k % log_every_t == 0 or k == iterations - 1:
                save_samples(os.path.join(dir_name, "progress"), sample, k, batch_size)
                save_samples(os.path.join(dir_name, "mus"), stats_posterior["intermediate_mus"], k,
                             len(stats_posterior["intermediate_mus"]))
                # save_inpaintings(os.path.join(dir_name, "progress_inpaintings"), sample, y, mask_pixel, k, num_to_save=5)
                save_params(os.path.join(dir_name, "params"), mu_pos, logvar_pos, gamma, k)

            gc.collect()
        return
    # Unconditional sampling, for debugging purposes:
    '''
    def sample_T(self, x0, cond, unconditional_conditioning, unconditional_guidance_scale, eta=0.4,
                 t_steps_hierarchy=None, dir_="out_temp2"):
        # sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max=999, rho=7, num_t_steps=10) / 1000
        # T_max = 1000
        # beta_start = 1   # 0.0015 * T_max
        # beta_end = 15    # 0.0155 * T_max
        # def var(t):
        #     return 1.0 - torch.exp(-beta_start * t - 0.5 * (beta_end - beta_start) * t * t)

        x0 = torch.randn_like(x0)
        t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy).cuda()
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1, 1, 1))**2
        x_t = x0  # torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)

        os.makedirs(dir_, exist_ok=True)
        alphas = self.h_alphas
        b = 5
        for i, t in enumerate(t_steps_hierarchy[:-1]):
            t_hat = torch.ones(b).cuda() * t
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]
            x_t, pred_x0 = self.prior_preds(x_t.float(), t_hat, cond, a_t, a_prev, sigma_t,
                                            unconditional_conditioning, unconditional_guidance_scale)

            var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1, 1, 1))**2
            a_t = 1 - var_t
            x_t = x_t + sigma_t * torch.randn_like(x_t)
            recon = self.model.decode_first_stage(pred_x0)
            image_path = os.path.join(dir_, f'{i}.png')
            image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
            PIL.Image.fromarray(image_np, 'RGB').save(image_path)

        t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
        e_out = self.get_error(x_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        sqrt_one_minus_at = torch.sqrt(1 - a_t)
        pred_x0 = (x_t - sqrt_one_minus_at * e_out) / a_t.sqrt()

        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join(dir_, f'{len(t_steps_hierarchy)}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)
        return
    '''
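For orientation, here is a minimal driver sketch for HPosterior, reconstructed from the method signatures above. load_model, vae_loss_fn, z_init (a latent-shaped tensor around which the posterior is initialized), y_pixel, and mask are hypothetical stand-ins for objects the surrounding repo provides; the hierarchy steps and scale values are illustrative, not the authors' settings, and the packing that params_train expects for lambda_ is likewise an assumption.

# Hypothetical driver; load_model / vae_loss_fn / z_init / y_pixel / mask are placeholders.
import queue

model = load_model("configs/ldm_inpainting.yaml")        # LatentDiffusion with apply_model / decode_first_stage
h_post = HPosterior(model, vae_loss_fn,
                    t_steps_hierarchy=[550, 400, 250, 100],  # descending diffusion times (illustrative)
                    eta=0.4, z0_size=64, img_size=256, latent_channels=3,
                    first_stage="kl",
                    image_queue=queue.Queue(), sampling_queue=queue.Queue())
h_post.descretize(rho=7)                                 # builds timesteps_1000, h_alphas, h_sigmas

# Variational parameters initialized around the encoded observation z_init.
mu, logvar, gamma = h_post.init(z_init, std_scale=1.0, mean_scale=0.1, prior_scale=4.0)
lambda_ = (mu, logvar, gamma)                            # assumed packing for params_train inside fit

h_post.fit(lambda_, None, None, mask_pixel=mask, y=y_pixel,
           dir_name="runs/demo", iterations=200, batch_size=4)
h_post.sample(scale=0.3, eta=1.0, mu_pos=mu, logvar_pos=logvar, gamma=gamma,
              mask_pixel=mask, y=y_pixel, n_samples=8, batch_size=4, dir_name="runs/demo")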
ldm/guided_diffusion/loss_vq.py
ADDED
@@ -0,0 +1,203 @@
import torch
from torch import nn
import torch.nn.functional as F

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # Evaluates cluster perplexity; when perplexity == num_embeddings, all clusters are used equally.
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow(x - y, 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval().to(device="cuda")
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        self.pixel_loss = l1 if pixel_loss == "l1" else l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init).cuda()
        self.discriminator.eval()
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, mask, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None,
                operator=None, noiser=None):
        # Masked pixel loss: compare only where the mask / measurement operator observes.
        if operator is not None:
            x = operator.forward(reconstructions)
        else:
            x = reconstructions.contiguous() * mask
            inputs = inputs.contiguous() * mask
        rec_loss = self.pixel_loss(inputs, x)
        std = 0.566  # fixed std used in the Gaussian NLL below

        if self.perceptual_weight > 0 and operator is None:
            p_loss = self.perceptual_loss(mask * inputs.contiguous().float(),
                                          mask * reconstructions.contiguous().float())
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss / (2 * std**2)
        nll_loss = 100 * torch.mean(nll_loss) + 100 * self.perceptual_weight * p_loss.squeeze()

        # NOTE: early return — the GAN branches below are unreachable in this inference-time variant.
        return nll_loss, nll_loss

        # GAN part (kept from the original training loss)
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            loss = nll_loss + g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }

            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
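A small usage sketch for the masked reconstruction loss above. Since forward returns at the nll_loss line, only codebook_loss, inputs, reconstructions, and mask matter on this code path; the tensors below are random placeholders, and the sketch assumes taming-transformers is installed and a CUDA device is available.

import torch

loss_fn = VQLPIPSWithDiscriminator(disc_start=0, perceptual_weight=1.0, pixel_loss="l1")
x     = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1   # ground-truth image in [-1, 1]
x_hat = torch.rand(1, 3, 256, 256, device="cuda") * 2 - 1   # decoded reconstruction
mask  = torch.ones_like(x)                                  # 1 = observed pixel
qloss = torch.tensor([0.], device="cuda")                   # codebook term, unused on this path
nll, _ = loss_fn(qloss, x, x_hat, mask, optimizer_idx=0, global_step=0,
                 last_layer=None, split="val")              # scalar masked NLL + LPIPS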
ldm/guided_diffusion/losses.py
ADDED
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn

from taming.modules.losses.vqperceptual import *  # TODO: keep the taming dependency?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):

        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval().cuda()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init).cuda()
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, reconstructions, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, reconstructions, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, reconstructions, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, mask, optimizer_idx,
                global_step, posteriors=None, last_layer=None, cond=None, split="train",
                weights=None):
        # Masked L1 reconstruction loss plus LPIPS on the observed region.
        rec_loss = torch.abs(inputs.contiguous() * mask - reconstructions.contiguous() * mask)
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous() * mask, reconstructions.contiguous() * mask)
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        nll_loss = 100 * torch.mean(nll_loss, dim=[1, 2, 3]) + 100 * self.perceptual_weight * p_loss.squeeze()

        # NOTE: early return — the GAN branches below are unreachable in this inference-time variant.
        return nll_loss, nll_loss

        # GAN part (kept from the original training loss)
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, reconstructions, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss  # + d_weight * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
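For reference, calculate_adaptive_weight in both loss classes balances the GAN gradient against the reconstruction gradient (only reachable once the early return above is removed):

\[
w_d \;=\; \operatorname{clamp}\!\left( \frac{\lVert \nabla_{\ell}\, \mathcal{L}_{\text{NLL}} \rVert}{\lVert \nabla_{\ell}\, \mathcal{L}_{G} \rVert + 10^{-4}},\; 0,\; 10^{4} \right) \cdot w_{\text{disc}},
\]

where the gradients are taken with respect to the decoder's last layer (or the reconstructions, in losses.py) and w_disc is the configured discriminator weight.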