"""INFERENCE TIME OPTIMIZATION""" | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from functools import partial | |
import torch.distributions as td | |
import gc | |
import wandb | |
import matplotlib.pyplot as plt | |
from utils.helper import params_train, get_optimizers,clean_directory, time_descretization, to_img, custom_to_np, save_params, save_samples, save_inpaintings, save_plot | |
import os | |
import PIL | |
import glob | |
from tqdm import trange | |
import time | |
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, extract_into_tensor, noise_like | |
import wandb | |
class HPosterior(object):
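    """Hierarchical variational posterior for inference-time optimization of a
    latent diffusion prior: `fit` learns per-timestep Gaussian parameters
    (means, log-variances, and prior/posterior mixing weights gamma) against a
    measurement y, and `sample` draws measurement-consistent samples."""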
    def __init__(self, model, vae_loss, t_steps_hierarchy, eta=0.4, z0_size=32, img_size=256, latent_channels=3,
                 num_hierarchy_steps=5, schedule="linear", first_stage="kl", posterior="hierarchical",
                 image_queue=None, sampling_queue=None, **kwargs):
        super().__init__()
        self.model = model                                    # prior noise-prediction model
        self.schedule = schedule                              # noise schedule the prior was trained on
        self.vae_loss = vae_loss                              # VAE loss used during first-stage training
        self.eta = eta                                        # eta used to produce faster, clean samples
        self.first_stage = first_stage                        # first-stage training procedure: "kl" or "vq" loss
        self.posterior = posterior
        self.t_steps_hierarchy = np.array(t_steps_hierarchy)  # time steps for the hierarchical posterior
        self.z0size = z0_size                                 # spatial size of the latent variables z
        self.img_size = img_size                              # pixel-space image size
        self.latent_size = z0_size                            # spatial size of the latent grid
        self.latent_channels = latent_channels
        self.image_queue = image_queue
        self.sampling_queue = sampling_queue
    def q_given_te(self, t, s, shape, zeta_t_star=None):
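        """Return (alpha_{t|s}, sigma_{t|s}) of the transition q(z_t | z_s) for
        t > s under the variance-preserving schedule, i.e.
        z_t = alpha_{t|s} * z_s + sigma_{t|s} * eps. If zeta_t_star is given, it
        replaces the schedule's marginal std at s with the learned one."""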
        if zeta_t_star is not None:
            alpha_s = torch.sqrt(1 - zeta_t_star**2)
            var_s = zeta_t_star**2
        else:
            m = 1 if len(s.shape) == 0 else s.shape[0]
            var_s = (self.model.sqrt_one_minus_alphas_cumprod[s].reshape(m, 1, 1, 1))**2
            alpha_s = torch.sqrt(1 - var_s)
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t])**2
        alpha_t = torch.sqrt(1 - var_t)
        alpha_t_s = alpha_t.reshape(len(var_t), 1, 1, 1) / alpha_s
        var_t_s = var_t.reshape(len(var_t), 1, 1, 1) - alpha_t_s**2 * var_s
        return alpha_t_s, torch.sqrt(var_t_s)
    def qpos_given_te(self, t, s, t_star, z_t_star, z_t, zeta_T_star=None):
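        """Mean and std of the Gaussian q(z_s | z_t, z_{t*}) obtained by
        combining the transition q(z_t | z_s) with q(z_s | z_{t*})
        (a product of Gaussians)."""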
        alpha_t_s, scale_t_s = self.q_given_te(t, s, z_t_star.shape)
        alpha_s_t_star, scale_s_t_star = self.q_given_te(s, t_star, z_t_star.shape, zeta_T_star)
        var = scale_t_s**2 * scale_s_t_star**2 / (scale_t_s**2 + alpha_s_t_star**2 * scale_s_t_star**2)
        mean = var * ((alpha_s_t_star / scale_s_t_star**2) * z_t_star + (alpha_t_s / scale_t_s**2) * z_t)
        return mean, torch.sqrt(var)
    def register_buffer(self, name, attr):
        if isinstance(attr, torch.Tensor):
            if not attr.is_cuda:
                attr = attr.cuda()
        setattr(self, name, attr)
    def get_error(self, x, t, c, unconditional_conditioning, unconditional_guidance_scale):
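        """Predict the noise e_t, applying classifier-free guidance when an
        unconditional conditioning and a guidance scale != 1 are provided."""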
        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x.float(), t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
        return e_t
    def descretize(self, rho):
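        """Precompute the time discretization used by the losses: a fine grid of
        timesteps (with their marginal stds) for the prior loss at t > T_e, and
        DDIM-style alphas/sigmas for the hierarchical time points."""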
        # Time discretization for the prior loss (t > T_e)
        self.timesteps_1000 = time_descretization(sigma_min=0.002, sigma_max=0.999, rho=rho, num_t_steps=1000) * 1000
        self.timesteps_1000 = self.timesteps_1000.cuda().long()
        sigma_timesteps = self.model.sqrt_one_minus_alphas_cumprod[self.timesteps_1000]
        self.register_buffer('sigma_timesteps', sigma_timesteps)
        # Prior std for the hierarchical time points
        sigma_hierarchy = self.model.sqrt_one_minus_alphas_cumprod[self.t_steps_hierarchy]
        self.t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy.copy()).cuda()
        alphas_h = 1 - sigma_hierarchy**2
        alphas_prev = torch.cat([alphas_h[1:], alphas_h[-1].reshape(1)])
        # DDIM variances, shifted so that h_sigmas[i+1] matches the step from
        # t_steps_hierarchy[i] to t_steps_hierarchy[i+1]
        h_sigmas = torch.sqrt(self.eta * (1 - alphas_prev) / (1 - alphas_h) * (1 - alphas_h / alphas_prev))
        h_sigmas[1:] = torch.sqrt(self.eta * (1 - alphas_prev[:-1]) / (1 - alphas_h[:-1]) * (1 - alphas_h[:-1] / alphas_prev[:-1]))
        h_sigmas[0] = torch.sqrt(1 - alphas_h[0])
        # register tensors
        self.register_buffer('h_alphas', alphas_h)
        self.register_buffer('h_alphas_prev', alphas_prev)
        self.register_buffer('h_sigmas', h_sigmas)
    def init(self, img, std_scale, mean_scale, prior_scale, mean_scale_top=0.1):
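        """Initialize the variational parameters from a latent image: one
        Gaussian per hierarchy step, with means seeded from the noised image,
        log-variances from the DDIM sigmas, and mixing weights gamma from
        prior_scale. Returns (mean_pos, logvar_pos, gamma)."""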
        num_h_steps = len(self.t_steps_hierarchy)
        img = img.repeat(num_h_steps, 1, 1, 1)[:num_h_steps]
        sigmas = torch.zeros_like(img)
        sqrt_alphas = torch.sqrt(self.h_alphas)[..., None, None, None].expand(img.shape)
        sqrt_one_minus_alphas = torch.sqrt(1 - self.h_alphas)[..., None, None, None].expand(img.shape)
        ## Variances of the posterior
        sigmas[0] = self.h_sigmas[0, None, None, None].expand(img[0].shape)
        sigmas[1:] = std_scale * (1 / np.sqrt(self.eta)) * self.h_sigmas[1:, None, None, None].expand(img[1:].shape)
        logvar_pos = 2 * torch.log(sigmas).float()
        ## Means
        mean_pos = sqrt_alphas * img + mean_scale * sqrt_one_minus_alphas * torch.randn_like(img)
        mean_pos[0] = img[0] + mean_scale_top * torch.randn_like(img[0])
        ## Gammas weighing the prior against the posterior
        gamma = torch.tensor(prior_scale)[None, None, None, None].expand(img.shape).cuda()
        return mean_pos, logvar_pos, gamma.float()
    def get_kl(self, mu1, mu2, scale1, scale2, wt):
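        """Elementwise KL( N(mu1, scale1^2) || N(mu2, scale2^2) ), with the
        squared-mean term weighted by wt:
            wt * (mu1 - mu2)^2 / (2 * scale2^2)
            + log(scale2 / scale1) + scale1^2 / (2 * scale2^2) - 1/2
        """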
        return wt * (mu1 - mu2)**2 / (2 * scale2**2) \
               + torch.log(scale2 / scale1) + scale1**2 / (2 * scale2**2) - 1/2
    # diffusion (prior) loss
    def loss_prior(self, mu_pos, logvar_pos, cond=None,
                   unconditional_conditioning=None,
                   unconditional_guidance_scale=1, K=10, intermediate_mus=None):
        '''
        KL between q(z_{T_e}) and p(z_{T_e}), estimated as
        E_{t > T_e}[ KL( q(z_{t_prev} | z_t, z_{T_e}) || p_theta(z_{t_prev} | z_t) ) ].
        z_te ~ q(z_{T_e}); z_t is sampled by adding noise scaled by
        sqrt(sigma_t^2 - zeta_t^2), so that z_t carries the total noise of step t.
        '''
        t_e = self.t_steps_hierarchy[0]
        ## Sample z_{T_e}
        tau_te = torch.exp(0.5 * logvar_pos)
        mu_te = mu_pos.repeat(K, 1, 1, 1)
        z_te = torch.sqrt(1 - tau_te**2) * mu_te + tau_te * torch.randn_like(mu_te)
        ## Sample t: keep timesteps t > T_e, i.e. those whose marginal std
        ## exceeds the largest posterior std
        t_g = torch.where(self.sigma_timesteps > torch.max(tau_te))[0]
        t_allowed = self.timesteps_1000[t_g]
        def sample_uniform(t_allowed):
            # stratified uniform sample of K indices into t_allowed
            t0 = torch.rand(1)
            T_max = len(t_allowed)
            T_min = 2  # stay away from values too close to T_e
            t = torch.remainder(t0 + torch.arange(0., 1., step=1. / K), 1.) * (T_max - T_min) + T_min
            t = torch.floor(t).long()
            return t
        t = sample_uniform(t_allowed)
        t_cur = t_allowed[t]
        t_prev = t_allowed[t - 1]
        # Sample z_t from q(z_t | z_{T_e})
        alpha_t, scale_t = self.q_given_te(t_cur, t_e, z_te.shape, tau_te)
        error = torch.randn_like(z_te)
        z_t = alpha_t * z_te + error * scale_t
        # Prior and posterior means/variances for t_prev
        e_out = self.get_error(z_t.float(), t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        alpha_t_, scale_t_ = self.q_given_te(t_cur, t_e, z_te.shape)
        mu_t_hat = (z_t - scale_t_ * e_out) / alpha_t_
        pos_mean, pos_scale = self.qpos_given_te(t_cur, t_prev, t_e, z_te, z_t, tau_te)
        prior_mean, prior_scale = self.qpos_given_te(t_cur, t_prev, t_e, mu_t_hat, z_t, None)
        wt = (1000 - t_e) / 2  # weight for the uniformly sampled timesteps above T_e
        kl = self.get_kl(pos_mean, prior_mean, pos_scale, prior_scale, wt=1)
        kl = torch.mean(wt * kl, dim=[1, 2, 3])
        return {"loss": kl, "sample": z_te, "intermediate_mus": intermediate_mus}
    def recon_loss(self, samples_pixel, x0_pixel, mask_pixel, operator=None):
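        """Reconstruction (measurement) loss in pixel space, dispatching to the
        KL- or VQ-flavoured first-stage loss."""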
        global_step = 0
        if self.first_stage == "kl":
            nll_loss, _ = self.vae_loss(x0_pixel, samples_pixel, mask_pixel, 0, global_step,
                                        last_layer=self.model.first_stage_model.get_last_layer(), split="val")
        else:
            qloss = torch.tensor([0.]).cuda()
            nll_loss, _ = self.vae_loss(qloss, x0_pixel, samples_pixel, mask_pixel, 0, 0,
                                        last_layer=self.model.first_stage_model.get_last_layer(), split="val",
                                        predicted_indices=None, operator=operator)
        return {"loss": nll_loss}
    def prior_preds(self, z_t, t_cur, cond, a_t, a_prev, sigma_t, unconditional_conditioning, unconditional_guidance_scale):
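        """One DDIM-style prior step: predict the noise, recover pred_x0, and
        return the noise-free mean z_next at the previous step together with
        pred_x0; stochastic noise with std sigma_t is added by the caller."""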
        # Get the noise estimate and pred_x0
        e_out = self.get_error(z_t, t_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        pred_x0 = (z_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
        z_next = a_prev.sqrt() * pred_x0 + dir_xt
        return z_next, pred_x0
    def posterior_mean(self, mu_pos, mu_prior, gamma):
        wt = torch.sigmoid(gamma)
        mean_t_1 = wt * mu_prior + (1 - wt) * mu_pos
        return mean_t_1
    def normalize(self, img):
        img -= torch.min(img)
        return 2 * img / torch.max(img) - 1
    def loss_posterior(self, z_t, mu_pos, logvar_pos, gamma, cond=None,
                       unconditional_conditioning=None,
                       unconditional_guidance_scale=1,
                       K=10, iteration=0, to_sample=False, intermediate_mus=None):
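        """Accumulate, down the hierarchy, KL( q(z_{t_next}) || p_theta(z_{t_next} | z_t) ),
        sampling z_t from the learned posterior at each step and finishing with
        a one-step denoising to pred_z0; decoded intermediates are kept for logging."""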
        sigma_pos = torch.exp(0.5 * logvar_pos)
        kl_t, q_entropy = torch.zeros(z_t.shape[0]).cuda(), 0
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        b = z_t.shape[0]
        with torch.no_grad():
            recon = self.model.decode_first_stage(z_t)
            intermediate_samples[0] = to_img(recon)[0]
        alphas = self.h_alphas
        for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
            t_hat_cur = torch.ones(b).cuda() * t_cur
            a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
            a_prev = torch.full((b, 1, 1, 1), alphas[i + 1]).cuda()
            sigma_t = self.h_sigmas[i + 1]
            # Prior predictions
            z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                               unconditional_conditioning, unconditional_guidance_scale)
            std_prior = self.h_sigmas[i + 1]
            ## Posterior means and variances
            pos_mean = self.posterior_mean(a_prev.sqrt() * mu_pos[i].unsqueeze(0), z_next, gamma[i].unsqueeze(0))
            std_pos = sigma_pos[i]
            ## Sample z_t
            z_t = pos_mean + std_pos * torch.randn_like(pos_mean)
            # KL for this step
            kl = self.get_kl(pos_mean, z_next, std_pos, std_prior, wt=1)
            kl_t += torch.mean(kl, dim=[1, 2, 3])
            with torch.no_grad():
                recon_pred = self.model.decode_first_stage(pred_x0)
                intermediate_preds[i] = to_img(recon_pred)[0]
                intermediate_mus[i + 1] = to_img(self.normalize(mu_pos[i]).unsqueeze(0)).astype(np.uint8)[0]
        ## One-step denoising from the last hierarchy time
        t_hat_cur = torch.ones(b).cuda() * (self.t_steps_hierarchy[-1])
        e_out = self.get_error(z_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
        a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
        sqrt_one_minus_at = torch.sqrt(1 - a_t)
        pred_z0 = (z_t - sqrt_one_minus_at * e_out) / a_t.sqrt()
        with torch.no_grad():
            recon = self.model.decode_first_stage(pred_z0)
            intermediate_preds[-1] = to_img(recon)[0]
        return {"sample": pred_z0, "loss": kl_t, "entropy": q_entropy,
                "intermediates": intermediate_samples, "interim_preds": intermediate_preds,
                "intermediate_mus": intermediate_mus}
    def grad_and_value(self, x_prev, x_0_hat, measurement, mask_pixel, operator):
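        """Gradient of the reconstruction loss with respect to x_prev (as in
        diffusion posterior sampling); returns the gradient and the loss value."""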
        nll_loss = torch.mean(self.recon_loss(x_0_hat, measurement, mask_pixel, operator)["loss"])
        norm_grad = torch.autograd.grad(outputs=nll_loss, inputs=x_prev)[0]
        return norm_grad, nll_loss
    def conditioning(self, x_prev, x_t, x_0_hat, measurement, mask_pixel, scale, operator, **kwargs):
        norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat,
                                              measurement=measurement, mask_pixel=mask_pixel, operator=operator)
        x_t -= norm_grad * scale
        return x_t, norm
    def sample(self, scale, eta, mu_pos, logvar_pos, gamma,
               mask_pixel, y, n_samples=100, cond=None,
               unconditional_conditioning=None, unconditional_guidance_scale=1,
               batch_size=10, dir_name="temp/", temp=1,
               samples_iteration=0, operator=None):
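        """Draw n_samples from the fitted posterior: sample down the hierarchy
        with the learned parameters, then refine from the last hierarchy time
        to t = 0 with measurement-guided (DPS-style) ancestral steps, saving
        decoded images along the way."""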
        sigma_pos = torch.exp(0.5 * logvar_pos)
        num_steps = len(self.t_steps_hierarchy)
        intermediate_samples = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_preds = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        intermediate_mus = np.zeros((num_steps, 1, self.img_size, self.img_size, 3))
        alphas = self.h_alphas
        ## Batch the sample generation
        all_images = []
        t0 = time.time()
        save_dir = os.path.join(dir_name, "samples_50_" + str(scale))
        os.makedirs(save_dir, exist_ok=True)
        for _ in trange(n_samples // batch_size, desc="Sampling Batches"):
            mu_10 = mu_pos[0].repeat(batch_size, 1, 1, 1)
            tau_t = sigma_pos[0]
            z_t = torch.sqrt(1 - tau_t**2) * mu_10 + tau_t * torch.randn_like(mu_10)
            ## Sample down the hierarchy from the posterior
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_t)
                intermediate_samples[0] = to_img(recon)[0]
            for i, (t_cur, t_next) in enumerate(zip(self.t_steps_hierarchy[:-1], self.t_steps_hierarchy[1:])):
                t_hat_cur = torch.ones(batch_size).cuda() * t_cur
                a_t = torch.full((batch_size, 1, 1, 1), alphas[i]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), alphas[i + 1]).cuda()
                sigma_t = self.h_sigmas[i + 1]
                # Prior predictions
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                                   unconditional_conditioning, unconditional_guidance_scale)
                ## Posterior means and variances
                mean_t_1 = self.posterior_mean(a_prev.sqrt() * mu_pos[i + 1].unsqueeze(0), z_next, gamma[i + 1].unsqueeze(0))
                std_pos = sigma_pos[i + 1]
                # Sample z_t
                z_t = mean_t_1 + std_pos * torch.randn_like(mean_t_1)
                with torch.no_grad():
                    pred_x = self.model.decode_first_stage(pred_x0)
                    save_samples(save_dir, pred_x, k=None, num_to_save=1, file_name=f'sample_{i}.png')
            timesteps = np.flip(np.arange(0, self.t_steps_hierarchy[-1].cpu().numpy(), 1))
            timesteps = np.concatenate((self.t_steps_hierarchy[-1].cpu().reshape(1), timesteps))
            ## Refine with the DPS algorithm
            for i, (step, t_next) in enumerate(zip(timesteps[:-1], timesteps[1:])):
                step = int(step)
                t_hat_cur = torch.ones(batch_size).cuda() * step
                a_t = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[step]).cuda()
                a_prev = torch.full((batch_size, 1, 1, 1), self.model.alphas_cumprod[int(t_next)]).cuda()
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
                z_t = z_t.requires_grad_()
                z_next, pred_x0 = self.prior_preds(z_t.float(), t_hat_cur, cond, a_t, a_prev, sigma_t,
                                                   unconditional_conditioning, unconditional_guidance_scale)
                pred_x = self.model.decode_first_stage(pred_x0)
                z_t, _ = self.conditioning(x_prev=z_t, x_t=z_next,
                                           x_0_hat=pred_x, measurement=y,
                                           mask_pixel=mask_pixel, scale=scale, operator=operator)
                z_t = z_t.detach_()
                if i % 50 == 0:
                    with torch.no_grad():
                        recons = self.model.decode_first_stage(pred_x0)
                        recons_np = to_img(recons).astype(np.uint8)
                        self.sampling_queue.put(recons_np)
                        save_samples(save_dir, recons, k=None, num_to_save=1, file_name=f'det_{step}.png')
            z_0 = pred_x0
            with torch.no_grad():
                recon = self.model.decode_first_stage(z_0)
                intermediate_preds[-1] = to_img(recon)[0]
            with torch.no_grad():
                recons = self.model.decode_first_stage(pred_x0)
                recons_np = to_img(recons).astype(np.uint8)
                self.sampling_queue.put(recons_np)
            all_images.append(custom_to_np(recons))
        t1 = time.time()
        all_img = np.concatenate(all_images, axis=0)
        all_img = all_img[:n_samples]
        shape_str = "x".join([str(x) for x in all_img.shape])
        nppath = os.path.join(save_dir, f"{shape_str}-samples.npz")
        np.savez(nppath, all_img, t1 - t0)
        if operator is None:
            save_inpaintings(save_dir, recons, y, mask_pixel, num_to_save=batch_size)
        else:
            save_samples(save_dir, recons, None, batch_size)
        recons_np = to_img(recons).astype(np.uint8)
        self.sampling_queue.put(recons_np)
        return
    def fit(self, lambda_, cond, shape, quantize_denoised=False, mask_pixel=None,
            y=None, log_every_t=100, unconditional_guidance_scale=1.,
            unconditional_conditioning=None, dir_name=None, kl_weight_1=50, kl_weight_2=50,
            debug=False, wdb=False, iterations=200, batch_size=10, lr_init_gamma=0.01,
            operator=None, recon_weight=50):
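        """Fit the variational parameters (means, log-variances, gammas) by
        minimizing kl_weight_1 * prior KL + kl_weight_2 * hierarchical KL
        + recon_weight * reconstruction loss, logging losses, parameter
        statistics, and intermediate samples as it goes."""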
        if wdb:
            wandb.init(project='LDM', dir='/scratch/sakshi/wandb-cache')
            wandb.config.run_type = 'hierarchical'
            wandb.run.name = "hierarchical"
        params_to_fit = params_train(lambda_)
        mu_pos, logvar_pos, gamma = params_to_fit
        optimizers, schedulers = get_optimizers(mu_pos, logvar_pos, gamma, lr_init_gamma)
        rec_loss_all, prior_loss_all, posterior_loss_all = [], [], []
        loss_all = []
        mu_all, logvar_all, gamma_all = [], [], []
        for k in range(iterations):
            if k % 100 == 0: print(k)
            intermediate_mus = np.zeros((len(self.t_steps_hierarchy), self.latent_size, self.latent_size, self.latent_channels))
            for opt in optimizers: opt.zero_grad()
            stats_prior = self.loss_prior(mu_pos[0], logvar_pos[0], cond=cond,
                                          unconditional_conditioning=unconditional_conditioning,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          K=batch_size, intermediate_mus=intermediate_mus)
            stats_posterior = self.loss_posterior(stats_prior["sample"], mu_pos[1:], logvar_pos[1:], gamma[1:],
                                                  cond=cond,
                                                  unconditional_conditioning=unconditional_conditioning,
                                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                                  K=batch_size, iteration=k, intermediate_mus=stats_prior["intermediate_mus"])
            sample = self.model.decode_first_stage(stats_posterior["sample"])
            stats_recon = self.recon_loss(sample, y, mask_pixel, operator)
            loss_total = torch.mean(kl_weight_1 * stats_prior["loss"]
                                    + kl_weight_2 * stats_posterior["loss"] + recon_weight * stats_recon["loss"])
            loss_total.backward()
            for opt in optimizers: opt.step()
            for sch in schedulers: sch.step()
            rec_loss_all.append(torch.mean(stats_recon["loss"].detach()).item())
            prior_loss_all.append(torch.mean(kl_weight_1 * stats_prior["loss"].detach()).item())
            posterior_loss_all.append(torch.mean(kl_weight_2 * stats_posterior["loss"].detach()).item())
            mu_all.append(torch.mean(mu_pos.detach()).item())
            logvar_all.append(torch.mean(logvar_pos.detach()).item())
            gamma_all.append(torch.mean(torch.sigmoid(gamma).detach()).item())
            sample_np = to_img(sample).astype(np.uint8)
            loss_all.append(loss_total.detach().item())
            self.image_queue.put(sample_np)
            save_plot(dir_name, [rec_loss_all, prior_loss_all, posterior_loss_all],
                      ["Recon loss", "Diffusion loss", "Hierarchical loss"], "loss.png")
            save_plot(dir_name, [loss_all], ["Total Loss"], "loss_t.png")
            save_plot(dir_name, [mu_all], ["mean"], "mean.png")
            save_plot(dir_name, [logvar_all], ["logvar"], "logvar.png")
            save_plot(dir_name, [gamma_all], ["gamma"], "gamma.png")
            if k % log_every_t == 0 or k == iterations - 1:
                save_samples(os.path.join(dir_name, "progress"), sample, k, batch_size)
                save_samples(os.path.join(dir_name, "mus"), stats_posterior["intermediate_mus"], k,
                             len(stats_posterior["intermediate_mus"]))
                save_params(os.path.join(dir_name, "params"), mu_pos, logvar_pos, gamma, k)
            gc.collect()
        return
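## Minimal usage sketch (hypothetical, for orientation only): the model, loss,
## conditioning, measurement y, mask, and latent init z0 come from the
## surrounding harness (utils.helper / ldm); the names below are illustrative.
#
#   h_post = HPosterior(model, vae_loss, t_steps_hierarchy=[550, 400, 250],
#                       eta=0.4, z0_size=64, img_size=256, latent_channels=3)
#   h_post.descretize(rho=7)
#   lambda_ = h_post.init(z0, std_scale=1.0, mean_scale=0.5, prior_scale=1.0)
#   h_post.fit(lambda_, cond, shape=None, mask_pixel=mask_pixel, y=y,
#              dir_name="out/", iterations=200, batch_size=10)
#   mu_pos, logvar_pos, gamma = lambda_
#   h_post.sample(scale=0.3, eta=0.4, mu_pos=mu_pos, logvar_pos=logvar_pos,
#                 gamma=gamma, mask_pixel=mask_pixel, y=y, n_samples=10,
#                 cond=cond, batch_size=10, dir_name="out/")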
##unconditional sampling for debugging purposes:
'''
def sample_T(self, x0, cond, unconditional_conditioning, unconditional_guidance_scale, eta=0.4, t_steps_hierarchy=None, dir_="out_temp2"):
    ''
    sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
    T_max = 1000
    beta_start = 1   # 0.0015*T_max
    beta_end = 15    # 0.0155*T_max
    def var(t):
        return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
    ''
    x0 = torch.randn_like(x0)
    t_steps_hierarchy = torch.tensor(self.t_steps_hierarchy).cuda()
    var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1, 1, 1))**2
    x_t = x0  # torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
    os.makedirs(dir_, exist_ok=True)
    alphas = self.h_alphas
    b = 5
    for i, t in enumerate(t_steps_hierarchy[:-1]):
        t_hat = torch.ones(b).cuda() * t
        a_t = torch.full((b, 1, 1, 1), alphas[i]).cuda()
        a_prev = torch.full((b, 1, 1, 1), alphas[i+1]).cuda()
        sigma_t = self.h_sigmas[i+1]
        x_t, pred_x0 = self.prior_preds(x_t.float(), t_hat, cond, a_t, a_prev, sigma_t,
                                        unconditional_conditioning, unconditional_guidance_scale)
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1, 1, 1))**2
        a_t = 1 - var_t
        x_t = x_t + sigma_t * torch.randn_like(x_t)
        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join(dir_, f'{i}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)
    t_hat_cur = torch.ones(b).cuda() * (self.t_steps_hierarchy[-1])
    e_out = self.get_error(x_t.float(), t_hat_cur, cond, unconditional_conditioning, unconditional_guidance_scale)
    a_t = torch.full((b, 1, 1, 1), alphas[-1]).cuda()
    sqrt_one_minus_at = torch.sqrt(1 - a_t)
    pred_x0 = (x_t - sqrt_one_minus_at * e_out) / a_t.sqrt()
    recon = self.model.decode_first_stage(pred_x0)
    image_path = os.path.join(dir_, f'{len(t_steps_hierarchy)}.png')
    image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
    PIL.Image.fromarray(image_np, 'RGB').save(image_path)
    return
'''