import torch
import numpy as np
from tqdm import tqdm
from functools import partial
import torch.distributions as td
import gc
import wandb
import matplotlib.pyplot as plt
from utils.helper import params_train, get_optimizers,clean_directory, time_descretization, to_img, custom_to_np, save_params, save_samples, save_inpaintings, save_plot
import os
import PIL
import glob
from tqdm import trange
import time
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, extract_into_tensor, noise_like
import wandb
class HPosterior(object):
    """Hierarchical Gaussian posterior over a latent-diffusion trajectory.

    The posterior lives at the time steps ``t_steps_hierarchy`` (first entry T_e).
    Level 0 is a Gaussian over z_{T_e}. Each later level is a Gaussian whose mean
    blends a learnable mean with the prior model's DDIM prediction, weighted by
    ``sigmoid(gamma)`` (see ``posterior_mean``).

    ``descretize`` must be called before ``init``/``fit``/``sample``; it
    precomputes the time grids and per-level coefficients these methods read.
    """

    def __init__(self, model, vae_loss, t_steps_hierarchy, eta=0.4, z0_size=32, img_size=256, latent_channels=3,
                 num_hierarchy_steps=5, schedule="linear", first_stage="kl", posterior="hierarchical",
                 image_queue=None, sampling_queue=None, **kwargs):
        """Store the prior model, first-stage loss and hierarchy configuration.

        Args:
            model: prior noise-prediction model (LDM-style: provides ``apply_model``,
                ``decode_first_stage``, ``sqrt_one_minus_alphas_cumprod``, ``alphas_cumprod``).
            vae_loss: first-stage loss callable used by ``recon_loss``.
            t_steps_hierarchy: integer time steps of the hierarchical posterior.
            eta: DDIM eta used for the hierarchy step stds (see ``descretize``).
            z0_size: spatial size of the latent variables.
            img_size: pixel-space image size (used for intermediate buffers).
            latent_channels: number of latent channels.
            first_stage: ``"kl"`` or another value (VQ-style) — selects the
                ``vae_loss`` call signature in ``recon_loss``.
            num_hierarchy_steps, kwargs: accepted but unused here.
        """
        self.model = model  # prior noise prediction model
        self.schedule = schedule  # noise schedule the prior was trained on
        self.vae_loss = vae_loss  # first-stage loss used for the reconstruction term
        self.eta = eta  # DDIM eta for the hierarchy step stds
        self.first_stage = first_stage  # "kl" or VQ-style first stage
        self.posterior = posterior
        self.t_steps_hierarchy = np.array(t_steps_hierarchy)  # converted to a CUDA tensor by descretize
        self.z0size = z0_size  # spatial size of latent variables z
        self.img_size = img_size
        self.latent_size = z0_size
        self.latent_channels = latent_channels
        self.image_queue = image_queue
        self.sampling_queue = sampling_queue

    def q_given_te(self, t, s, shape, zeta_t_star=None):
        """Coefficients of the forward kernel q(z_t | z_s) = N(alpha_t_s * z_s, scale_t_s^2).

        The noise std at ``s`` is ``zeta_t_star`` when given, otherwise the prior
        schedule's std at ``s``. ``t`` must be 1-D; ``s`` may be 0-D or 1-D.
        ``shape`` is unused.

        Returns:
            (alpha_t_s, scale_t_s), each reshaped to (len(t), 1, 1, 1) for broadcasting.
        """
        if zeta_t_star is not None:
            var_s = zeta_t_star**2
            alpha_s = torch.sqrt(1 - zeta_t_star**2)
        else:
            m = 1 if len(s.shape) == 0 else s.shape[0]
            var_s = (self.model.sqrt_one_minus_alphas_cumprod[s].reshape(m, 1, 1, 1))**2
            alpha_s = torch.sqrt(1 - var_s)
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t])**2
        alpha_t = torch.sqrt(1 - var_t)
        alpha_t_s = alpha_t.reshape(len(var_t), 1, 1, 1) / alpha_s
        var_t_s = var_t.reshape(len(var_t), 1, 1, 1) - alpha_t_s**2 * var_s
        return alpha_t_s, torch.sqrt(var_t_s)

    def qpos_given_te(self, t, s, t_star, z_t_star, z_t, zeta_T_star=None):
        """Gaussian q(z_s | z_{t*}, z_t) for t* < s < t, from q(z_s | z_{t*}) and q(z_t | z_s).

        ``zeta_T_star`` optionally overrides the noise std at ``t_star``
        (see ``q_given_te``).

        Returns:
            (mean, std) of the Gaussian posterior over z_s.
        """
        alpha_t_s, scale_t_s = self.q_given_te(t, s, z_t_star.shape)
        alpha_s_t_star, scale_s_t_star = self.q_given_te(s, t_star, z_t_star.shape, zeta_T_star)
        var_t_s = scale_t_s**2
        var_s_t_star = scale_s_t_star**2
        # Product of Gaussians: precision = 1/var_s_t* + alpha_t_s^2/var_t_s.
        # (The denominator must use alpha_t_s — the coefficient of z_s in q(z_t | z_s).)
        var = var_t_s * var_s_t_star / (var_t_s + alpha_t_s**2 * var_s_t_star)
        mean = var * ((alpha_s_t_star / var_s_t_star) * z_t_star + (alpha_t_s / var_t_s) * z_t)
        return mean, torch.sqrt(var)

    def register_buffer(self, name, attr):
        """Set ``attr`` as attribute ``name``, moving tensors to CUDA first (plain setattr, not nn.Module)."""
        if isinstance(attr, torch.Tensor) and not attr.is_cuda:
            attr = attr.cuda()
        setattr(self, name, attr)

    def get_error(self, x, t, c, unconditional_conditioning, unconditional_guidance_scale):
        """Predicted noise for ``x`` at ``t``, with classifier-free guidance when enabled.

        Guidance is skipped (single conditional pass) when ``unconditional_conditioning``
        is None or the scale is 1; otherwise returns e_uncond + scale * (e_cond - e_uncond).
        """
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            return self.model.apply_model(x.float(), t, c)
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([unconditional_conditioning, c])
        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
        return e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

    def descretize(self, rho):
        """Precompute time grids and hierarchy coefficients (must run before init/fit/sample).

        Sets:
            timesteps_1000, sigma_timesteps: a ``time_descretization`` grid (``rho``
                controls its spacing) scaled to integer timesteps, and the prior std at
                each; ``loss_prior`` draws t > T_e from it.
            t_steps_hierarchy: converted to a CUDA tensor.
            h_alphas: alpha-bar at each hierarchy step.
            h_alphas_prev: h_alphas shifted left by one (last entry repeated).
            h_sigmas: h_sigmas[0] is the prior std at the first hierarchy step;
                h_sigmas[i + 1] is the eta-scaled DDIM std for step i -> i + 1.
        """
        # Time discretization for the prior loss (t > T_e).
        self.timesteps_1000 = time_descretization(sigma_min=0.002, sigma_max=0.999, rho=rho, num_t_steps=1000) * 1000
        self.timesteps_1000 = self.timesteps_1000.cuda().long()
        sigma_timesteps = self.model.sqrt_one_minus_alphas_cumprod[self.timesteps_1000]
        self.register_buffer('sigma_timesteps', sigma_timesteps)
        # Prior std at the hierarchical time points.
        sigma_hierarchy = self.model.sqrt_one_minus_alphas_cumprod[self.t_steps_hierarchy]
        self.t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy.copy()).cuda()
        alphas_h = 1 - sigma_hierarchy**2
        alphas_prev = torch.concatenate([alphas_h[1:], alphas_h[-1].reshape(1)])
        a_cur, a_next = alphas_h[:-1], alphas_h[1:]
        # Every entry is assigned below, so uninitialised storage is fine.
        h_sigmas = torch.empty_like(alphas_h)
        h_sigmas[1:] = torch.sqrt(self.eta * (1 - a_next) / (1 - a_cur) * (1 - a_cur / a_next))
        h_sigmas[0] = torch.sqrt(1 - alphas_h[0])
        self.register_buffer('h_alphas', alphas_h)
        self.register_buffer('h_alphas_prev', alphas_prev)
        self.register_buffer('h_sigmas', h_sigmas)

    def init(self, img, std_scale, mean_scale, prior_scale, mean_scale_top=0.1):
        """Build initial posterior parameters from latent ``img`` (requires ``descretize``).

        ``img`` is tiled along dim 0 to one entry per hierarchy step.

        Returns:
            (mean_pos, logvar_pos, gamma), indexed by hierarchy level along dim 0.
        """
        num_h_steps = len(self.t_steps_hierarchy)
        img = torch.Tensor.repeat(img, [num_h_steps, 1, 1, 1])[:num_h_steps]
        sqrt_alphas = torch.sqrt(self.h_alphas)[..., None, None, None].expand(img.shape)
        sqrt_one_minus_alphas = torch.sqrt(1 - self.h_alphas)[..., None, None, None].expand(img.shape)
        # Posterior stds: level 0 uses the prior std; later levels use the DDIM stds
        # rescaled by std_scale / sqrt(eta).
        sigmas = torch.zeros_like(img)
        sigmas[0] = self.h_sigmas[0, None, None, None].expand(img[0].shape)
        sigmas[1:] = std_scale * (1 / np.sqrt(self.eta)) * self.h_sigmas[1:, None, None, None].expand(img[1:].shape)
        logvar_pos = 2 * torch.log(sigmas).float()
        # Means: noised, alpha-scaled copies of img; the top level has its own noise scale.
        mean_pos = sqrt_alphas * img + mean_scale * sqrt_one_minus_alphas * torch.randn_like(img)
        mean_pos[0] = img[0] + mean_scale_top * torch.randn_like(img[0])
        # gamma: logit of the prior/posterior mixing weight (see posterior_mean).
        gamma = torch.tensor(prior_scale)[None, None, None, None].expand(img.shape).cuda()
        return mean_pos, logvar_pos, gamma.float()

    def get_kl(self, mu1, mu2, scale1, scale2, wt):
        """Elementwise KL(N(mu1, scale1^2) || N(mu2, scale2^2)), mean term weighted by ``wt``.

        wt * (mu1 - mu2)^2 / (2 scale2^2) + log(scale2 / scale1) + scale1^2 / (2 scale2^2) - 1/2
        """
        # Note the explicit parentheses: 1/2*scale2**2 would parse as 0.5*scale2**2.
        return (wt * (mu1 - mu2)**2 / (2 * scale2**2)
                + torch.log(scale2 / scale1) + scale1**2 / (2 * scale2**2) - 1 / 2)

    # diffusion loss
    def loss_prior(self, mu_pos, logvar_pos, cond=None,
                   unconditional_guidance_scale=1, K=10, intermediate_mus=None,
                   unconditional_conditioning=None):
        """Diffusion-prior term for the top-level posterior q(z_{T_e}).

        Estimates KL(q(z_{T_e}) || p(z_{T_e})) as an expectation over t > T_e:
        draws K samples z_{T_e} ~ q and K timesteps from the allowed grid (prior std
        above every posterior std). Each z_{T_e} is diffused to z_t with the remaining
        variance (so z_t has the prior's total noise at t). It then compares
        q(z_{t_prev} | z_t, z_{T_e}) against the same kernel with z_{T_e} replaced by
        the model's denoised estimate.

        Returns:
            dict with "loss" (one value per sample), "sample" (the z_{T_e} draws)
            and "intermediate_mus" (passed through).
        """
        t_e = self.t_steps_hierarchy[0]
        # Sample z_{T_e} ~ q(z_{T_e}), K draws.
        tau_te = torch.exp(0.5 * logvar_pos)
        mu_te = torch.Tensor.repeat(mu_pos, [K, 1, 1, 1])
        z_te = torch.sqrt(1 - tau_te**2) * mu_te + tau_te * torch.randn_like(mu_te)
        # Allowed timesteps: prior std exceeds the largest posterior std.
        t_g = torch.where(self.sigma_timesteps > torch.max(tau_te))[0]
        t_allowed = self.timesteps_1000[t_g]

        def sample_uniform(t_allowed):
            # Stratified draw of K indices into t_allowed (one random offset, K even strata).
            t0 = torch.rand(1)
            T_max = len(t_allowed)
            T_min = 2  # stay away from values close to T*; index - 1 is used as t_prev
            t = torch.remainder(t0 + torch.arange(0., 1., step=1. / K), 1.) * (T_max - T_min) + T_min
            return torch.floor(t).long()

        t = sample_uniform(t_allowed)
        t_cur = t_allowed[t]
        t_prev = t_allowed[t - 1]
        # Sample z_t from q(z_t | z_{T_e}) using the posterior std at T_e.
        alpha_t, scale_t = self.q_given_te(t_cur, t_e, z_te.shape, tau_te)
        z_t = alpha_t * z_te + torch.randn_like(z_te) * scale_t
        # Model's estimate of z_{T_e} from z_t (schedule std at T_e).
        e_out = self.get_error(z_t.float(), t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        alpha_t_, scale_t_ = self.q_given_te(t_cur, t_e, z_te.shape)
        mu_t_hat = (z_t - scale_t_ * e_out) / alpha_t_
        pos_mean, pos_scale = self.qpos_given_te(t_cur, t_prev, t_e, z_te, z_t, tau_te)
        prior_mean, prior_scale = self.qpos_given_te(t_cur, t_prev, t_e, mu_t_hat, z_t, None)
        wt = (1000 - t_e) / 2
        kl = self.get_kl(pos_mean, prior_mean, pos_scale, prior_scale, wt=1)
        kl = torch.mean(wt * kl, dim=[1, 2, 3])
        return {"loss": kl, "sample": z_te, "intermediate_mus": intermediate_mus}

    def recon_loss(self, samples_pixel, x0_pixel, mask_pixel, operator=None):
        """Reconstruction loss of ``samples_pixel`` against ``x0_pixel`` via ``self.vae_loss``.

        The ``vae_loss`` call signature depends on ``self.first_stage``: "kl" uses the
        KL-autoencoder form; anything else uses the VQ form (zero qloss, ``operator``).
        """
        global_step = 0
        last_layer = self.model.first_stage_model.get_last_layer()
        if self.first_stage == "kl":
            nll_loss, _ = self.vae_loss(x0_pixel, samples_pixel, mask_pixel, 0, global_step,
                                        last_layer=last_layer, split="val")
        else:
            qloss = torch.tensor([0.]).cuda()
            nll_loss, _ = self.vae_loss(qloss, x0_pixel, samples_pixel, mask_pixel, 0, 0,
                                        last_layer=last_layer, split="val",
                                        predicted_indices=None, operator=operator)
        return {"loss": nll_loss}

    def prior_preds(self, z_t, t_cur, cond, a_t, a_prev, sigma_t, unconditional_conditioning, unconditional_guidance_scale):
        """Deterministic DDIM-style prior step.

        Returns:
            (z_next, pred_x0) where pred_x0 = (z_t - sqrt(1 - a_t) e) / sqrt(a_t) and
            z_next = sqrt(a_prev) pred_x0 + sqrt(1 - a_prev - sigma_t^2) e.
            No noise is added here; callers add it.
        """
        e_out = self.get_error(z_t, t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        pred_x0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        # Direction pointing to x_t.
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
        z_next = a_prev.sqrt() * pred_x0 + dir_xt
        return z_next, pred_x0

    def posterior_mean(self, mu_pos, mu_prior, gamma):
        """Blend means: sigmoid(gamma) * mu_prior + (1 - sigmoid(gamma)) * mu_pos."""
        wt = torch.sigmoid(gamma)
        return wt * mu_prior + (1 - wt) * mu_pos

    def normalize(self, img):
        """Return a copy of ``img`` min-max rescaled to [-1, 1]; ``img`` itself is not modified."""
        # Out-of-place: callers pass slices of learnable parameters.
        img = img - torch.min(img)
        return 2 * img / torch.max(img) - 1

    def loss_posterior(self, z_t, mu_pos, logvar_pos, gamma, cond=None,
                       K=10, iteration=0, to_sample=False, intermediate_mus=None,
                       unconditional_conditioning=None, unconditional_guidance_scale=1):
        """KL between the hierarchical posterior and the prior along the hierarchy.

        Starting from ``z_t`` (samples at T_e), each hierarchy step draws z from the
        blended posterior and accumulates its KL against the prior DDIM step. A final
        one-step denoise yields the z_0 estimate.

        Args:
            mu_pos, logvar_pos, gamma: per-level parameters for levels 1.. (entry i is
                used for the step from hierarchy level i to i + 1).
            intermediate_mus: optional array; entry i + 1 receives a uint8 image of
                the normalized level mean. Skipped when None.
            K, iteration, to_sample: accepted but unused.

        Returns:
            dict with "sample" (z_0 estimate), "loss" (per-sample KL), "entropy" (0),
            "intermediates", "interim_preds" and "intermediate_mus".
        """
        sigma_pos = torch.exp(0.5 * logvar_pos)
        kl_t = torch.zeros(z_t.shape[0]).cuda()
        q_entropy = 0
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        b = z_t.shape[0]
        alphas = self.h_alphas
        with torch.no_grad():
            intermediate_samples[0] = to_img(self.model.decode_first_stage(z_t))[0]
        for i, t_cur in enumerate(self.t_steps_hierarchy[:-1]):
            t_hat_cur = torch.ones(b).cuda() * t_cur
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]
            # Prior prediction for the next hierarchy level.
            z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                               unconditional_conditioning, unconditional_guidance_scale)
            # Posterior: learnable mean (scaled by sqrt(a_prev)) blended with the prior prediction.
            pos_mean = self.posterior_mean(a_prev.sqrt() * mu_pos[i].unsqueeze(0), z_next, gamma[i].unsqueeze(0))
            std_pos = sigma_pos[i]
            z_t = pos_mean + std_pos * torch.randn_like(pos_mean)
            kl = self.get_kl(pos_mean, z_next, std_pos, sigma_t, wt=1)
            kl_t = kl_t + torch.mean(kl, dim=[1, 2, 3])
            with torch.no_grad():
                intermediate_preds[i] = to_img(self.model.decode_first_stage(pred_x0))[0]
                if intermediate_mus is not None:
                    intermediate_mus[i + 1] = to_img(self.normalize(mu_pos[i]).unsqueeze(0)).astype(np.uint8)[0]
        # One-step denoising from the last hierarchy level.
        t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
        e_out = self.get_error(z_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        pred_z0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        with torch.no_grad():
            intermediate_preds[-1] = to_img(self.model.decode_first_stage(pred_z0))[0]
        return {"sample": pred_z0, "loss": kl_t, "entropy": q_entropy,
                "intermediates": intermediate_samples, "interim_preds": intermediate_preds,
                "intermediate_mus": intermediate_mus}

    def grad_and_value(self, x_prev, x_0_hat, measurement, mask_pixel, operator):
        """Mean reconstruction loss of ``x_0_hat`` and its gradient w.r.t. ``x_prev``."""
        nll_loss = torch.mean(self.recon_loss(x_0_hat, measurement, mask_pixel, operator)["loss"])
        norm_grad = torch.autograd.grad(outputs=nll_loss, inputs=x_prev)[0]
        return norm_grad, nll_loss

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, mask_pixel, scale, operator, **kwargs):
        """Gradient step on ``x_t`` (in place): x_t -= scale * d loss / d x_prev.

        Returns:
            (updated x_t, loss value).
        """
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat,
                                              measurement=measurement, mask_pixel=mask_pixel, operator=operator)
        x_t -= norm_grad * scale
        return x_t, norm

    def sample(self, scale, eta, mu_pos, logvar_pos, gamma,
               mask_pixel, y, n_samples=100, cond=None,
               unconditional_conditioning=None, unconditional_guidance_scale=1,
               batch_size=10, dir_name="temp/", temp=1,
               samples_iteration=0, operator=None):
        """Sample the fitted hierarchical posterior, then finish with DPS-guided steps.

        For each of ``n_samples // batch_size`` batches:
          1. draw z_{T_e} from the level-0 posterior;
          2. step through the hierarchy, sampling each level from the blended
             posterior and saving the decoded x0-prediction as ``sample_{i}.png``;
          3. from the last hierarchy step down to 0, take a prior DDIM step (no
             added noise) followed by a gradient step of size ``scale`` on the
             reconstruction loss against ``y``; every 50 steps the decoded
             prediction is saved as ``det_{step}.png``.
        All decoded samples go to ``{shape}-samples.npz`` (with elapsed time). The
        last batch's first sample composited with ``y`` under ``mask_pixel`` is saved
        as ``{samples_iteration}.png``. The last batch is also saved via
        ``save_inpaintings`` (when ``operator`` is None) and ``save_samples``.
        ``temp`` is unused.

        Raises:
            ValueError: if ``n_samples < batch_size`` (no batch would be drawn).
        """
        n_batches = n_samples // batch_size
        if n_batches == 0:
            raise ValueError(f"n_samples ({n_samples}) must be >= batch_size ({batch_size})")
        sigma_pos = torch.exp(0.5 * logvar_pos)
        alphas = self.h_alphas
        all_images = []
        t0 = time.time()
        save_dir = os.path.join(dir_name, "samples_50_" + str(scale))
        os.makedirs(save_dir, exist_ok=True)
        # Guided-stage timesteps: the last hierarchy step, then every integer step down to 0.
        t_last = int(self.t_steps_hierarchy[-1])
        timesteps = np.concatenate(([t_last], np.arange(t_last)[::-1]))
        recons = None
        for _ in trange(n_batches, desc="Sampling Batches"):
            # Hierarchical posterior stage: nothing is back-propagated here.
            with torch.no_grad():
                mu_top = torch.Tensor.repeat(mu_pos[0], [batch_size, 1, 1, 1])
                tau_top = sigma_pos[0]
                z_t = torch.sqrt(1 - tau_top**2) * mu_top + tau_top * torch.randn_like(mu_top)
                for i, t_cur in enumerate(self.t_steps_hierarchy[:-1]):
                    t_hat_cur = torch.ones(batch_size).cuda() * t_cur
                    a_t = torch.full((batch_size, 1, 1, 1), alphas[i]).cuda()
                    a_prev = torch.full((batch_size, 1, 1, 1), alphas[i + 1]).cuda()
                    z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, self.h_sigmas[i + 1],
                                                       unconditional_conditioning, unconditional_guidance_scale)
                    mean_t_1 = self.posterior_mean(a_prev.sqrt() * mu_pos[i + 1].unsqueeze(0), z_next,
                                                   gamma[i + 1].unsqueeze(0))
                    z_t = mean_t_1 + sigma_pos[i + 1] * torch.randn_like(mean_t_1)
                    save_samples(save_dir, self.model.decode_first_stage(pred_x0), k=None, num_to_save=1,
                                 file_name=f'sample_{i}.png')
            # Guided (DPS) stage.
            for i, (step, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                step = int(step)
                t_hat_cur = torch.ones(batch_size).cuda() * step
                a_t = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[step]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[int(t_next)]).cuda()
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
                z_t = z_t.detach().requires_grad_()
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                                   unconditional_conditioning, unconditional_guidance_scale)
                pred_x = self.model.decode_first_stage(pred_x0)
                z_t, _ = self.conditioning(x_prev=z_t, x_t=z_next,
                                           x_0_hat=pred_x, measurement=y,
                                           mask_pixel=mask_pixel, scale=scale, operator=operator)
                z_t = z_t.detach_()
                if i % 50 == 0:
                    save_samples(save_dir, pred_x.detach(), k=None, num_to_save=1, file_name=f'det_{step}.png')
            with torch.no_grad():
                recons = self.model.decode_first_stage(pred_x0)
            all_images.append(to_img(recons).astype(np.uint8))
        t1 = time.time()
        all_img = np.concatenate(all_images, axis=0)[:n_samples]
        shape_str = "x".join(str(x) for x in all_img.shape)
        nppath = os.path.join(save_dir, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img, t1 - t0)
        # Composite: keep observed pixels from y, fill the rest from the sample.
        recon_in = to_img(y * mask_pixel + (1 - mask_pixel) * recons)
        image_path = os.path.join(save_dir, str(samples_iteration) + ".png")
        PIL.Image.fromarray(recon_in.astype(np.uint8)[0], 'RGB').save(image_path)
        if operator is None:
            save_inpaintings(save_dir, recons, y, mask_pixel, num_to_save=batch_size)
        save_samples(save_dir, recons, None, batch_size)

    def fit(self, lambda_, cond, shape, quantize_denoised=False, mask_pixel=None,
            y=None, log_every_t=100, unconditional_guidance_scale=1.,
            unconditional_conditioning=None, dir_name=None, kl_weight_1=50, kl_weight_2=50,
            debug=False, wdb=False, iterations=200, batch_size=10, lr_init_gamma=0.01,
            operator=None, recon_weight=50):
        """Optimise the posterior parameters for ``iterations`` steps.

        Each step combines three terms:
          - the prior (diffusion) loss on level 0;
          - the hierarchical KL on levels 1..;
          - the reconstruction loss of the decoded z_0 estimate against ``y``.
        The loss ``mean(kl_weight_1 * prior + kl_weight_2 * hierarchical +
        recon_weight * recon)`` is back-propagated, and every optimizer and
        scheduler from ``get_optimizers`` is stepped.

        Loss and parameter histories (scalar means per step) are plotted to
        ``dir_name`` each step. Samples, level means and parameters are saved every
        ``log_every_t`` steps and at the last step.
        """
        if wdb:
            wandb.init(project='LDM', dir='/scratch/sakshi/wandb-cache')
            wandb.config.run_type = 'hierarchical'
            wandb.run.name = "hierarchical"
        mu_pos, logvar_pos, gamma = params_train(lambda_)
        optimizers, schedulers = get_optimizers(mu_pos, logvar_pos, gamma, lr_init_gamma)
        rec_loss_all, prior_loss_all, posterior_loss_all = [], [], []
        loss_all = []
        mu_all, logvar_all, gamma_all = [], [], []
        guidance = dict(unconditional_conditioning=unconditional_conditioning,
                        unconditional_guidance_scale=unconditional_guidance_scale)
        for k in range(iterations):
            if k % 100 == 0:
                print(k)
            intermediate_mus = np.zeros((len(self.t_steps_hierarchy), self.latent_size, self.latent_size,
                                         self.latent_channels))
            for opt in optimizers:
                opt.zero_grad()
            stats_prior = self.loss_prior(mu_pos[0], logvar_pos[0], cond=cond,
                                          K=batch_size, intermediate_mus=intermediate_mus, **guidance)
            stats_posterior = self.loss_posterior(stats_prior["sample"], mu_pos[1:], logvar_pos[1:], gamma[1:],
                                                  cond=cond, K=batch_size, iteration=k,
                                                  intermediate_mus=stats_prior["intermediate_mus"], **guidance)
            sample_pixel = self.model.decode_first_stage(stats_posterior["sample"])
            stats_recon = self.recon_loss(sample_pixel, y, mask_pixel, operator)
            loss_total = torch.mean(kl_weight_1 * stats_prior["loss"]
                                    + kl_weight_2 * stats_posterior["loss"]
                                    + recon_weight * stats_recon["loss"])
            loss_total.backward()
            for opt in optimizers:
                opt.step()
            for sch in schedulers:
                sch.step()
            # Histories for the plots (scalar mean per iteration).
            # NOTE(review): assumes mu_pos/logvar_pos/gamma are tensors — confirm params_train.
            rec_loss_all.append(torch.mean(stats_recon["loss"]).item())
            prior_loss_all.append(torch.mean(stats_prior["loss"]).item())
            posterior_loss_all.append(torch.mean(stats_posterior["loss"]).item())
            loss_all.append(loss_total.item())
            mu_all.append(torch.mean(mu_pos).item())
            logvar_all.append(torch.mean(logvar_pos).item())
            gamma_all.append(torch.mean(gamma).item())
            save_plot(dir_name, [rec_loss_all, prior_loss_all, posterior_loss_all],
                      ["Recon loss", "Diffusion loss", "Hierarchical loss"], "loss.png")
            save_plot(dir_name, [loss_all], ["Total Loss"], "loss_t.png")
            save_plot(dir_name, [mu_all], ["mean"], "mean.png")
            save_plot(dir_name, [logvar_all], ["logvar"], "logvar.png")
            save_plot(dir_name, [gamma_all], ["gamma"], "gamma.png")
            if k % log_every_t == 0 or k == iterations - 1:
                save_samples(os.path.join(dir_name, "progress"), sample_pixel, k, batch_size)
                # NOTE(review): one image per hierarchy level is assumed for num_to_save — confirm.
                save_samples(os.path.join(dir_name, "mus"), stats_posterior["intermediate_mus"], k,
                             len(stats_posterior["intermediate_mus"]))
                save_params(os.path.join(dir_name, "params"), mu_pos, logvar_pos, gamma, k)

    # Unconditional sampling for debugging purposes.
    def sample_T(self, x0, cond, unconditional_conditioning, unconditional_guidance_scale, eta=0.4,
                 t_steps_hierarchy=None, dir_="out_temp2"):
        """Debug helper: run the prior alone along the hierarchy steps and save images.

        Starts from fresh noise shaped like ``x0``. Each hierarchy level takes one
        DDIM step plus ``h_sigmas`` noise and saves its decoded x0-prediction as
        ``{i}.png`` in ``dir_``. A final one-step denoise is saved as
        ``{len(hierarchy)}.png``. ``eta`` and ``t_steps_hierarchy`` are accepted but
        unused; the instance's hierarchy and ``h_sigmas`` are used instead.
        """
        b = 5  # NOTE(review): fixed batch for the t/alpha tensors — presumably matches x0's batch
        alphas = self.h_alphas
        hierarchy = torch.as_tensor(self.t_steps_hierarchy).cuda()

        def save_decoded(z, path):
            # Decode latent z and write its first image as an RGB PNG.
            recon = self.model.decode_first_stage(z)
            image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
            PIL.Image.fromarray(image_np, 'RGB').save(path)

        with torch.no_grad():
            x_t = torch.randn_like(x0)
            os.makedirs(dir_, exist_ok=True)
            for i, t in enumerate(hierarchy[:-1]):
                t_hat = torch.ones(b).cuda() * t
                a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
                a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
                sigma_t = self.h_sigmas[i + 1]
                x_t, pred_x0 = self.prior_preds(x_t.float(), t_hat, cond, a_t, a_prev, sigma_t,
                                                unconditional_conditioning, unconditional_guidance_scale)
                x_t = x_t + sigma_t * torch.randn_like(x_t)
                save_decoded(pred_x0, os.path.join(dir_, f'{i}.png'))
            t_hat_cur = torch.ones(b).cuda() * self.t_steps_hierarchy[-1]
            e_out = self.get_error(x_t.float(), t_hat_cur, cond, unconditional_conditioning,
                                   unconditional_guidance_scale)
            a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
            pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
            save_decoded(pred_x0, os.path.join(dir_, f'{len(hierarchy)}.png'))