"""INFERENCE TIME OPTIMIZATION""" import torch import numpy as np from tqdm import tqdm from functools import partial import torch.distributions as td import gc import wandb import matplotlib.pyplot as plt from utils.helper import params_train, get_optimizers,clean_directory, time_descretization, to_img, custom_to_np, save_params, save_samples, save_inpaintings, save_plot import os import PIL import glob from tqdm import trange import time from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, extract_into_tensor, noise_like import wandb class HPosterior(object): def __init__(self, model, vae_loss, t_steps_hierarchy, eta=0.4, z0_size=32, img_size = 256, latent_channels = 3, num_hierarchy_steps=5, schedule="linear", first_stage = "kl", posterior = "hierarchical", image_queue = None, sampling_queue=None, **kwargs): super().__init__() self.model = model #prior noise prediction model self.schedule = schedule #noise schedule the prior was trained on self.vae_loss = vae_loss #vae loss followed during training self.eta = eta #eta used to produce faster, clean samples self.first_stage= first_stage #first stage training procedure: kl or vq loss self.posterior = posterior self.t_steps_hierarchy = np.array(t_steps_hierarchy) #time steps for hierachical posterior self.z0size = z0_size #dimension of latent space variables z self.img_size = img_size #512 # self.latent_size = z0_size #128 # self.latent_channels = latent_channels self.image_queue = image_queue self.sampling_queue = sampling_queue def q_given_te(self, t, s, shape, zeta_t_star=None): if zeta_t_star is not None: alpha_s = torch.sqrt(1 - zeta_t_star**2) var_s = zeta_t_star**2 else: if len(s.shape) == 0 :m = 1 else: m = s.shape[0] var_s = (self.model.sqrt_one_minus_alphas_cumprod[s].reshape(m, 1 ,1 ,1))**2 alpha_s = torch.sqrt(1 - var_s) var_t = (self.model.sqrt_one_minus_alphas_cumprod[t])**2 alpha_t = torch.sqrt(1 - var_t) alpha_t_s = alpha_t.reshape(len(var_t), 1 ,1 ,1) / alpha_s var_t_s = var_t.reshape(len(var_t), 1 ,1 ,1) - alpha_t_s**2 * var_s return alpha_t_s, torch.sqrt(var_t_s) def qpos_given_te(self, t, s, t_star, z_t_star, z_t, zeta_T_star=None): alpha_t_s, scale_t_s = self.q_given_te(t, s, z_t_star.shape) alpha_s_t_star, scale_s_t_star = self.q_given_te(s, t_star, z_t_star.shape, zeta_T_star) var = scale_t_s**2 * scale_s_t_star**2 / (scale_t_s**2 + alpha_s_t_star**2 * scale_s_t_star**2 ) mean = (var) * ( (alpha_s_t_star/scale_s_t_star**2) * z_t_star + (alpha_t_s/scale_t_s**2) * z_t ) return mean, torch.sqrt(var) def register_buffer(self, name, attr): if isinstance(attr, torch.Tensor): if not attr.is_cuda: attr = attr.cuda() setattr(self, name, attr) def get_error(self,x,t,c, unconditional_conditioning, unconditional_guidance_scale): if unconditional_conditioning is None or unconditional_guidance_scale == 1.: e_t = self.model.apply_model(x.float(), t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) return e_t def descretize(self, rho): #Get time descretization for prior loss (t > T_e) self.timesteps_1000 = time_descretization(sigma_min=0.002, sigma_max = 0.999, rho = rho, num_t_steps = 1000)*1000 self.timesteps_1000 = self.timesteps_1000.cuda().long() sigma_timesteps = self.model.sqrt_one_minus_alphas_cumprod[self.timesteps_1000] self.register_buffer('sigma_timesteps', sigma_timesteps) #Get 
        # Prior std for hierarchical time points
        sigma_hierarchy = self.model.sqrt_one_minus_alphas_cumprod[self.t_steps_hierarchy]
        self.t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy.copy()).cuda()
        alphas_h = 1 - sigma_hierarchy**2
        alphas_prev = torch.cat([alphas_h[1:], alphas_h[-1].reshape(1)])
        h_sigmas = torch.sqrt(self.eta * (1 - alphas_prev) / (1 - alphas_h)
                              * (1 - alphas_h / alphas_prev))
        h_sigmas[1:] = torch.sqrt(self.eta * (1 - alphas_prev[:-1]) / (1 - alphas_h[:-1])
                                  * (1 - alphas_h[:-1] / alphas_prev[:-1]))
        h_sigmas[0] = torch.sqrt(1 - alphas_h[0])

        # Register tensors
        self.register_buffer('h_alphas', alphas_h)
        self.register_buffer('h_alphas_prev', alphas_prev)
        self.register_buffer('h_sigmas', h_sigmas)

    def init(self, img, std_scale, mean_scale, prior_scale, mean_scale_top=0.1):
        num_h_steps = len(self.t_steps_hierarchy)
        img = img.repeat(num_h_steps, 1, 1, 1)[:num_h_steps]

        sigmas = torch.zeros_like(img)
        sqrt_alphas = torch.sqrt(self.h_alphas)[..., None, None, None].expand(img.shape)
        sqrt_one_minus_alphas = torch.sqrt(1 - self.h_alphas)[..., None, None, None].expand(img.shape)

        # Variances for the posterior
        sigmas[0] = self.h_sigmas[0, None, None, None].expand(img[0].shape)
        sigmas[1:] = std_scale * (1 / np.sqrt(self.eta)) \
            * self.h_sigmas[1:, None, None, None].expand(img[1:].shape)
        logvar_pos = 2 * torch.log(sigmas).float()

        # Means
        mean_pos = sqrt_alphas * img + mean_scale * sqrt_one_minus_alphas * torch.randn_like(img)
        mean_pos[0] = img[0] + mean_scale_top * torch.randn_like(img[0])

        # Gammas for the posterior, weighing between prior and posterior means
        gamma = torch.tensor(prior_scale)[None, None, None, None].expand(img.shape).cuda()
        return mean_pos, logvar_pos, gamma.float()

    def get_kl(self, mu1, mu2, scale1, scale2, wt):
        # KL( N(mu1, scale1^2) || N(mu2, scale2^2) ), with the squared-mean
        # term weighted by wt
        return wt * (mu1 - mu2)**2 / (2 * scale2**2) \
            + torch.log(scale2 / scale1) + scale1**2 / (2 * scale2**2) - 1 / 2

    # Diffusion loss
    def loss_prior(self, mu_pos, logvar_pos, cond=None, unconditional_conditioning=None,
                   unconditional_guidance_scale=1, K=10, intermediate_mus=None):
        r'''
        Computes the KL between q(x_{T_e}) and p(x_{T_e}),
            KL( q(x_{T_e}) || p(x_{T_e}) ) = E_{t > T_e}[ (x_{T_e} - \mu_\theta(x_t))^2 ],
        where z_te below holds the samples from q(x_{T_e}).
        z_t is sampled by adding noise scaled by \sqrt{\sigma_t^2 - \zeta_t^2},
        so that z_t matches the total noise at time t.
        '''
        t_e = self.t_steps_hierarchy[0]

        # Sample z_{T_e}
        tau_te = torch.exp(0.5 * logvar_pos)
        mu_te = mu_pos.repeat(K, 1, 1, 1)
        z_te = torch.sqrt(1 - tau_te**2) * mu_te + tau_te * torch.randn_like(mu_te)

        # Sample t: restrict to timesteps whose prior noise exceeds the posterior's
        t_g = torch.where(self.sigma_timesteps > torch.max(tau_te))[0]
        t_allowed = self.timesteps_1000[t_g]

        def sample_uniform(t_allowed):
            t0 = torch.rand(1)
            T_max = len(t_allowed)
            T_min = 2  # stay away from values close to T*
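            # A single uniform offset shifted by K equally spaced points gives
            # a stratified, low-discrepancy draw over [T_min, T_max), so every
            # K-sample batch covers the allowed timesteps evenly.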
            t = torch.remainder(t0 + torch.arange(0., 1., step=1. / K), 1.) \
                * (T_max - T_min) + T_min
            t = torch.floor(t).long()
            return t

        t = sample_uniform(t_allowed)
        t_cur = t_allowed[t]
        t_prev = t_allowed[t - 1]

        # Sample z_t from p(z_t | z_{T_e})
        alpha_t, scale_t = self.q_given_te(t_cur, t_e, z_te.shape, tau_te)
        error = torch.randn_like(z_te)
        z_t = alpha_t * z_te + error * scale_t

        # Prior and posterior means/variances for t_prev
        e_out = self.get_error(z_t.float(), t_cur, cond,
                               unconditional_conditioning, unconditional_guidance_scale)
        alpha_t_, scale_t_ = self.q_given_te(t_cur, t_e, z_te.shape)
        mu_t_hat = (z_t - scale_t_ * e_out) / alpha_t_
        pos_mean, pos_scale = self.qpos_given_te(t_cur, t_prev, t_e, z_te, z_t, tau_te)
        prior_mean, prior_scale = self.qpos_given_te(t_cur, t_prev, t_e, mu_t_hat, z_t, None)

        wt = (1000 - t_e) / 2
        kl = self.get_kl(pos_mean, prior_mean, pos_scale, prior_scale, wt=1)
        kl = torch.mean(wt * kl, dim=[1, 2, 3])
        return {"loss": kl, "sample": z_te, "intermediate_mus": intermediate_mus}

    def recon_loss(self, samples_pixel, x0_pixel, mask_pixel, operator=None):
        global_step = 0
        if self.first_stage == "kl":
            nll_loss, _ = self.vae_loss(x0_pixel, samples_pixel, mask_pixel, 0, global_step,
                                        last_layer=self.model.first_stage_model.get_last_layer(),
                                        split="val")
        else:
            qloss = torch.tensor([0.]).cuda()
            nll_loss, _ = self.vae_loss(qloss, x0_pixel, samples_pixel, mask_pixel, 0, 0,
                                        last_layer=self.model.first_stage_model.get_last_layer(),
                                        split="val", predicted_indices=None, operator=operator)
        return {"loss": nll_loss}

    def prior_preds(self, z_t, t_cur, cond, a_t, a_prev, sigma_t,
                    unconditional_conditioning, unconditional_guidance_scale):
        # Predict noise and x0
        e_out = self.get_error(z_t, t_cur, cond,
                               unconditional_conditioning, unconditional_guidance_scale)
        pred_x0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
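        # Deterministic part of the DDIM step: z_next = sqrt(a_prev) * pred_x0
        # + sqrt(1 - a_prev - sigma_t^2) * e_out; any stochastic noise term is
        # added by the caller (the hierarchical posterior uses its own std).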
        # Direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
        z_next = a_prev.sqrt() * pred_x0 + dir_xt
        return z_next, pred_x0

    def posterior_mean(self, mu_pos, mu_prior, gamma):
        wt = torch.sigmoid(gamma)
        return wt * mu_prior + (1 - wt) * mu_pos

    def normalize(self, img):
        # Shift and rescale to [-1, 1]; work on a copy so the optimized
        # parameters passed in are not mutated in place
        img = img - torch.min(img)
        return 2 * img / torch.max(img) - 1

    def loss_posterior(self, z_t, mu_pos, logvar_pos, gamma, cond=None,
                       unconditional_conditioning=None, unconditional_guidance_scale=1,
                       K=10, iteration=0, to_sample=False, intermediate_mus=None):
        sigma_pos = torch.exp(0.5 * logvar_pos)
        kl_t, q_entropy = torch.zeros(z_t.shape[0]).cuda(), 0
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        b = z_t.shape[0]
        with torch.no_grad():
            recon = self.model.decode_first_stage(z_t)
            intermediate_samples[0] = to_img(recon)[0]

        alphas = self.h_alphas
        for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1],
                                                self.t_steps_hierarchy[1:])):
            t_hat_cur = torch.ones(b).cuda() * t_cur
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]

            # Prior predictions
            z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev,
                                               sigma_t, unconditional_conditioning,
                                               unconditional_guidance_scale)
            std_prior = self.h_sigmas[i + 1]

            # Posterior mean and variance
            pos_mean = self.posterior_mean(a_prev.sqrt() * mu_pos[i].unsqueeze(0),
                                           z_next, gamma[i].unsqueeze(0))
            std_pos = sigma_pos[i]

            # Sample z_t
            z_t = pos_mean + std_pos * torch.randn_like(pos_mean)

            # KL between posterior and prior transition
            kl = self.get_kl(pos_mean, z_next, std_pos, std_prior, wt=1)
            kl_t += torch.mean(kl, dim=[1, 2, 3])

            with torch.no_grad():
                recon_pred = self.model.decode_first_stage(pred_x0)
                intermediate_preds[i] = to_img(recon_pred)[0]
                intermediate_mus[i + 1] = to_img(self.normalize(mu_pos[i])
                                                 .unsqueeze(0)).astype(np.uint8)[0]

        # One-step denoising from the last hierarchy step
        t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
        e_out = self.get_error(z_t.float(), t_hat_cur, cond,
                               unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        sqrt_one_minus_at = torch.sqrt(1 - a_t)
        pred_z0 = (z_t - sqrt_one_minus_at * e_out) / a_t.sqrt()
        with torch.no_grad():
            recon = self.model.decode_first_stage(pred_z0)
            intermediate_preds[-1] = to_img(recon)[0]
        return {"sample": pred_z0, "loss": kl_t, "entropy": q_entropy,
                "intermediates": intermediate_samples,
                "interim_preds": intermediate_preds,
                "intermediate_mus": intermediate_mus}

    def grad_and_value(self, x_prev, x_0_hat, measurement, mask_pixel, operator):
        nll_loss = torch.mean(self.recon_loss(x_0_hat, measurement, mask_pixel, operator)["loss"])
        norm_grad = torch.autograd.grad(outputs=nll_loss, inputs=x_prev)[0]
        return norm_grad, nll_loss

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, mask_pixel, scale, operator, **kwargs):
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat,
                                              measurement=measurement,
                                              mask_pixel=mask_pixel, operator=operator)
        x_t -= norm_grad * scale
        return x_t, norm

    def sample(self, scale, eta, mu_pos, logvar_pos, gamma, mask_pixel, y,
               n_samples=100, cond=None, unconditional_conditioning=None,
               unconditional_guidance_scale=1, batch_size=10, dir_name="temp/",
               temp=1, samples_iteration=0, operator=None):
        sigma_pos = torch.exp(0.5 * logvar_pos)
        num_steps = len(self.t_steps_hierarchy)
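        # Sampling runs in two phases: (i) draw z_t from the fitted hierarchical
        # posterior down to the last hierarchy step, then (ii) refine to t = 0
        # with measurement-guided updates (the DPS algorithm below).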
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_mus = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        alphas = self.h_alphas

        # Batch the sample generation
        all_images = []
        t0 = time.time()
        save_dir = os.path.join(dir_name, "samples_50_" + str(scale))
        os.makedirs(save_dir, exist_ok=True)
        for _ in trange(n_samples // batch_size, desc="Sampling Batches"):
            mu_10 = mu_pos[0].repeat(batch_size, 1, 1, 1)
            tau_t = sigma_pos[0]
            z_t = torch.sqrt(1 - tau_t**2) * mu_10 + tau_t * torch.randn_like(mu_10)

            # Sample from the hierarchical posterior
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_t)
                intermediate_samples[0] = to_img(recon)[0]
            for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1],
                                                    self.t_steps_hierarchy[1:])):
                t_hat_cur = torch.ones(batch_size).cuda() * t_cur
                a_t = torch.full((batch_size, 1, 1, 1), alphas[i]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), alphas[i + 1]).cuda()
                sigma_t = self.h_sigmas[i + 1]

                # Prior predictions
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev,
                                                   sigma_t, unconditional_conditioning,
                                                   unconditional_guidance_scale)

                # Posterior mean and variance
                mean_t_1 = self.posterior_mean(a_prev.sqrt() * mu_pos[i + 1].unsqueeze(0),
                                               z_next, gamma[i + 1].unsqueeze(0))
                std_pos = sigma_pos[i + 1]

                # Sample z_t
                z_t = mean_t_1 + std_pos * torch.randn_like(mean_t_1)
                with torch.no_grad():
                    pred_x = self.model.decode_first_stage(pred_x0)
                    save_samples(save_dir, pred_x, k=None, num_to_save=1,
                                 file_name=f'sample_{i}.png')

            timesteps = np.flip(np.arange(0, self.t_steps_hierarchy[-1].cpu().numpy(), 1))
            timesteps = np.concatenate((self.t_steps_hierarchy[-1].cpu().reshape(1), timesteps))

            # Refine with the DPS algorithm
            for i, (step, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                step = int(step)
                t_hat_cur = torch.ones(batch_size).cuda() * step
                a_t = torch.full((batch_size, 1, 1, 1),
                                 self.model.alphas_cumprod[step]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1),
                                    self.model.alphas_cumprod[int(t_next)]).cuda()
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

                z_t = z_t.requires_grad_()
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev,
                                                   sigma_t, unconditional_conditioning,
                                                   unconditional_guidance_scale)
                pred_x = self.model.decode_first_stage(pred_x0)
                z_t, _ = self.conditioning(x_prev=z_t, x_t=z_next, x_0_hat=pred_x,
                                           measurement=y, mask_pixel=mask_pixel,
                                           scale=scale, operator=operator)
                z_t = z_t.detach_()

                if i % 50 == 0:
                    with torch.no_grad():
                        recons = self.model.decode_first_stage(pred_x0)
                        recons_np = to_img(recons).astype(np.uint8)
                        self.sampling_queue.put(recons_np)
                        save_samples(save_dir, recons, k=None, num_to_save=1,
                                     file_name=f'det_{step}.png')

            z_0 = pred_x0
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_0)
                intermediate_preds[-1] = to_img(recon)[0]

            with torch.no_grad():
                recons = self.model.decode_first_stage(pred_x0)
                recons_np = to_img(recons).astype(np.uint8)
                self.sampling_queue.put(recons_np)
            all_images.append(custom_to_np(recons))

        t1 = time.time()
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(save_dir, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img, t1 - t0)
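        # np.savez stores arrays positionally, so the bundle can be reloaded
        # for offline evaluation with, e.g.:
        #   data = np.load(nppath)
        #   samples, wall_time = data["arr_0"], data["arr_1"]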
".png") image_np = recon_in.astype(np.uint8)[0] PIL.Image.fromarray(image_np, 'RGB').save(image_path) ''' file_name_img = None if operator is None: save_inpaintings(save_dir, recons, y, mask_pixel, num_to_save = batch_size) #recons else: save_samples(save_dir, recons, None, batch_size) recons_np = to_img(recons).astype(np.uint8) self.sampling_queue.put(recons_np) return def fit(self, lambda_, cond, shape, quantize_denoised=False, mask_pixel = None, y = None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, dir_name = None, kl_weight_1=50, kl_weight_2 = 50, debug=False, wdb=False, iterations=200, batch_size = 10, lr_init_gamma=0.01, operator=None, recon_weight = 50): if wdb: wandb.init(project='LDM', dir = '/scratch/sakshi/wandb-cache') wandb.config.run_type = 'hierarchical' wandb.run.name = "hierarchical" params_to_fit = params_train(lambda_) mu_pos, logvar_pos, gamma = params_to_fit optimizers, schedulers = get_optimizers(mu_pos, logvar_pos, gamma, lr_init_gamma) rec_loss_all, prior_loss_all, posterior_loss_all =[], [], [] loss_all = [] mu_all, logvar_all, gamma_all = [], [], [] for k in range(iterations): if k%100==0: print(k) intermediate_mus = np.zeros((len(self.t_steps_hierarchy), self.latent_size, self.latent_size, self.latent_channels)) for opt in optimizers: opt.zero_grad() stats_prior = self.loss_prior(mu_pos[0], logvar_pos[0], cond=cond, unconditional_conditioning=unconditional_conditioning, unconditional_guidance_scale=unconditional_guidance_scale, K=batch_size, intermediate_mus=intermediate_mus) #stats_posterior = self.get_z0_t(stats_prior["sample"], self.t_steps_hierarchy) stats_posterior = self.loss_posterior(stats_prior["sample"], mu_pos[1:], logvar_pos[1:], gamma[1:], cond=cond, unconditional_conditioning=unconditional_conditioning, unconditional_guidance_scale=unconditional_guidance_scale, K=batch_size, iteration=k, intermediate_mus=stats_prior["intermediate_mus"]) sample = self.model.decode_first_stage(stats_posterior["sample"]) stats_recon = self.recon_loss(sample, y, mask_pixel, operator) num_pixels = 3*256*256 #(1000/num_pixels)* (1000/num_pixels)* loss_total = torch.mean(kl_weight_1*stats_prior["loss"] \ + kl_weight_2*stats_posterior["loss"] + recon_weight*stats_recon["loss"] ) # loss_total.backward() for opt in optimizers: opt.step() for sch in schedulers: sch.step() rec_loss_all.append(torch.mean(stats_recon["loss"].detach()).item()) prior_loss_all.append(torch.mean(kl_weight_1*stats_prior["loss"].detach()).item()) posterior_loss_all.append(torch.mean(kl_weight_2*stats_posterior["loss"].detach()).item()) mu_all.append(torch.mean(mu_pos.detach()).item()) logvar_all.append(torch.mean(logvar_pos.detach()).item()) gamma_all.append(torch.mean(torch.sigmoid(gamma).detach()).item()) sample_np = to_img(sample).astype(np.uint8) loss_all.append(loss_total.detach().item()) self.image_queue.put(sample_np) save_plot(dir_name, [rec_loss_all, prior_loss_all, posterior_loss_all], ["Recon loss", "Diffusion loss", "Hierarchical loss"], "loss.png") save_plot(dir_name, [loss_all], ["Total Loss"], "loss_t.png") save_plot(dir_name, [mu_all], ["mean"], "mean.png") save_plot(dir_name, [logvar_all], ["logvar"], "logvar.png") save_plot(dir_name, [gamma_all], ["gamma"], "gamma.png") if k%log_every_t == 0 or k == iterations - 1: save_samples(os.path.join(dir_name , "progress"), sample, k, batch_size) save_samples(os.path.join(dir_name , "mus"), stats_posterior["intermediate_mus"], k, len(stats_posterior["intermediate_mus"])) 
                # save_inpaintings(os.path.join(dir_name, "progress_inpaintings"), sample, y,
                #                  mask_pixel, k, num_to_save=5)
                save_params(os.path.join(dir_name, "params"), mu_pos, logvar_pos, gamma, k)
            gc.collect()
        return

    # Unconditional sampling for debugging purposes:
    '''
    def sample_T(self, x0, cond, unconditional_conditioning, unconditional_guidance_scale,
                 eta=0.4, t_steps_hierarchy=None, dir_="out_temp2"):
        sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max=999,
                                                       rho=7, num_t_steps=10) / 1000
        T_max = 1000
        beta_start = 1  # 0.0015 * T_max
        beta_end = 15   # 0.0155 * T_max

        def var(t):
            return 1.0 - torch.exp(-beta_start * t - 0.5 * (beta_end - beta_start) * t * t)

        x0 = torch.randn_like(x0)
        t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy).cuda()
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]]
                 .reshape(1, 1, 1, 1))**2
        x_t = x0  # torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
        os.makedirs(dir_, exist_ok=True)

        alphas = self.h_alphas
        b = 5
        for i, t in enumerate(t_steps_hierarchy[:-1]):
            t_hat = torch.ones(b).cuda() * t
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]
            x_t, pred_x0 = self.prior_preds(x_t.float(), t_hat, cond, a_t, a_prev, sigma_t,
                                            unconditional_conditioning,
                                            unconditional_guidance_scale)
            var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1, 1, 1))**2
            a_t = 1 - var_t
            x_t = x_t + sigma_t * torch.randn_like(x_t)

            recon = self.model.decode_first_stage(pred_x0)
            image_path = os.path.join(dir_, f'{i}.png')
            image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8) \
                .permute(0, 2, 3, 1).cpu().numpy()[0]
            PIL.Image.fromarray(image_np, 'RGB').save(image_path)

        t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
        e_out = self.get_error(x_t.float(), t_hat_cur, cond,
                               unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        sqrt_one_minus_at = torch.sqrt(1 - a_t)
        pred_x0 = (x_t - sqrt_one_minus_at * e_out) / a_t.sqrt()

        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join(dir_, f'{len(t_steps_hierarchy)}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8) \
            .permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)
        return
    '''
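

# Minimal sanity check (an illustrative addition, not part of the original
# pipeline): verifies get_kl against the closed-form Gaussian KL from
# torch.distributions, which needs no model or data. get_kl never touches
# self, so it can be called unbound with self=None.
if __name__ == "__main__":
    mu1, mu2 = torch.randn(4), torch.randn(4)
    s1, s2 = torch.rand(4) + 0.1, torch.rand(4) + 0.1
    kl_ref = td.kl_divergence(td.Normal(mu1, s1), td.Normal(mu2, s2))
    kl_ours = HPosterior.get_kl(None, mu1, mu2, s1, s2, wt=1.)
    assert torch.allclose(kl_ours, kl_ref, atol=1e-5)
    print("get_kl matches torch.distributions.kl_divergence")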