Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import os | |
import pickle | |
from ldm.util import default | |
import glob | |
import PIL | |
import matplotlib.pyplot as plt | |
def load_file(filename):
    """Unpickle and return the object stored at `filename`.

    NOTE(review): pickle.load on untrusted files can execute arbitrary code —
    only use on files this project wrote itself.
    """
    with open(filename, 'rb') as fh:
        return pickle.load(fh)
def save_file(filename, x, mode="wb"):
    """Serialize `x` to `filename` with pickle (binary write by default)."""
    with open(filename, mode) as out:
        pickle.dump(x, out)
def normalize_np(img):
    """Normalize `img` from an arbitrary range into [0, 1].

    Fixes two defects of the previous version: it mutated the caller's
    array in place (`-=`, `/=`), and it divided by zero on a constant
    image, producing NaN/inf.

    Args:
        img: numpy array of floats.
    Returns:
        A new array of the same shape with values scaled to [0, 1];
        an all-zeros array when the input is constant.
    """
    img = np.array(img, copy=True)  # work on a copy — don't mutate the caller's data
    img -= np.min(img)
    peak = np.max(img)
    if peak > 0:  # constant input would otherwise divide by zero
        img /= peak
    return img
def clear_color(x):
    """Turn a single-image tensor (1, C, H, W) into a normalized HWC numpy array.

    Complex tensors are reduced to magnitudes first; the result is passed
    through normalize_np so values land in [0, 1].
    """
    if torch.is_complex(x):
        x = torch.abs(x)
    arr = x.detach().cpu().squeeze().numpy()
    arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    return normalize_np(arr)
def to_img(sample):
    """Map an NCHW tensor in roughly [-1, 1] to NHWC pixel values clipped to [0, 255]."""
    arr = sample.detach().cpu().numpy()
    arr = arr.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    return (arr * 127.5 + 128).clip(0, 255)
def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot each series in `tensors` against its index and save to dir_name/file_name.

    Args:
        dir_name: directory the figure is written into.
        tensors: sequence of equally-long 1-D series.
        labels: one legend label per series.
        file_name: output file name (default "loss.png").

    Fixes: colours are now cycled, so more than three series no longer raises
    IndexError, and the figure is closed after saving instead of leaking.
    """
    t = np.linspace(0, len(tensors[0]), len(tensors[0]))
    colours = ["r", "b", "g"]
    plt.figure()
    for j, series in enumerate(tensors):
        plt.plot(t, series, color=colours[j % len(colours)], label=labels[j])
    plt.legend()
    plt.savefig(os.path.join(dir_name, file_name))
    plt.close()  # release the figure — repeated calls otherwise accumulate open figures
def save_samples(dir_name, sample, k=None, num_to_save=5, file_name=None):
    """Write the first `num_to_save` images of a batch as RGB PNGs in dir_name.

    `sample` may be an NHWC uint8-range numpy array, or a tensor that is first
    converted with to_img. File names: `file_name` if given, else
    'sample_<k+1><j>.png' when k is given, else '<j>.png'.
    """
    if isinstance(sample, np.ndarray):
        batch = sample.astype(np.uint8)
    else:
        batch = to_img(sample).astype(np.uint8)
    for j in range(num_to_save):
        if file_name is not None:
            name = file_name
        elif k is not None:
            name = f'sample_{k+1}{j}.png'
        else:
            name = f'{j}.png'
        PIL.Image.fromarray(batch[j], 'RGB').save(os.path.join(dir_name, name))
def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save=5, file_name=None):
    """Blend observed pixels (y where mask_pixel==1) with generated pixels
    (sample where mask_pixel==0) and save the first `num_to_save` results as PNGs."""
    blended = mask_pixel * y + (1 - mask_pixel) * sample
    blended = to_img(blended)
    for j in range(num_to_save):
        if file_name is not None:
            name = file_name
        elif k is not None:
            name = f'sample_{k+1}{j}.png'
        else:
            name = f'{j}.png'
        frame = blended.astype(np.uint8)[j]
        PIL.Image.fromarray(frame, 'RGB').save(os.path.join(dir_name, name))
def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Detach the variational parameters to CPU, mark them non-trainable via
    params_untrain, and checkpoint them to <dir_name>/<k+1>.pt."""
    detached = [p.detach().cpu() for p in (mu_pos, logvar_pos, gamma)]
    frozen = params_untrain(detached)
    torch.save(frozen, os.path.join(dir_name, f'{k+1}.pt'))
def custom_to_np(img):
    """Return a detached, CPU-resident, contiguous copy of the tensor.

    (Despite the name, this returns a torch.Tensor — the uint8/permute
    conversion is commented out in the original and is not performed.)
    """
    out = img.detach().cpu()
    return out.contiguous()
def encoder_kl(diff, img):
    """Encode `img` with the first-stage KL autoencoder of `diff`, scale the
    moment parameters, and return (mean + scale_factor*noise, logvar)."""
    _, params = diff.encode_first_stage(img, return_all=True)
    scaled = diff.scale_factor * params
    mean, logvar = torch.chunk(scaled, 2, dim=1)
    # default(None, f) simply evaluates f, i.e. draws fresh Gaussian noise.
    eps = default(None, lambda: torch.randn_like(mean))
    return mean + diff.scale_factor * eps, logvar
def encoder_vq(diff, img):
    """Encode `img` with the first-stage VQ model of `diff`, scale the latent,
    and return it with scale_factor-weighted Gaussian noise added."""
    quant = diff.scale_factor * diff.encode_first_stage(img)
    # default(None, f) simply evaluates f -> fresh noise shaped like quant.
    eps = default(None, lambda: torch.randn_like(quant))
    return quant + diff.scale_factor * eps
def clean_directory(dir_name):
    """Remove every file matched by the glob pattern `dir_name`.

    NOTE: `dir_name` is treated as a glob pattern (e.g. 'out/*'), not a
    directory path — passing a bare directory name matches the directory
    itself and os.remove will fail on it.
    """
    for path in glob.glob(dir_name):
        os.remove(path)
def params_train(params):
    """Enable gradient tracking on every tensor in `params`; returns the same list."""
    for p in params:
        p.requires_grad = True
    return params
def params_untrain(params):
    """Disable gradient tracking on every tensor in `params`; returns the same list."""
    for p in params:
        p.requires_grad = False
    return params
def time_descretization(sigma_min=0.002, sigma_max=80, rho=7, num_t_steps=18, device="cuda"):
    """EDM (Karras et al.) rho-warped noise-level discretization, ascending.

    Args:
        sigma_min: smallest noise level (first entry of the result).
        sigma_max: largest noise level (last entry of the result).
        rho: warping exponent controlling how steps cluster near sigma_min.
        num_t_steps: number of levels to generate.
        device: device for the schedule; default "cuda" preserves the old
            hard-coded behaviour, pass "cpu" for CPU-only runs (fix: the
            device is no longer forced to CUDA).

    Returns:
        float64 tensor of shape (num_t_steps,), increasing from sigma_min
        to sigma_max.
    """
    step_indices = torch.arange(num_t_steps, dtype=torch.float64, device=device)
    # Interpolate in sigma^(1/rho) space, then warp back: yields sigma_max..sigma_min.
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_t_steps - 1)
               * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # The schedule above is descending; flip to forward (ascending) order.
    return t_steps.flip(0)
def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01):
    """Build one Adam optimizer per variational parameter plus matching StepLR
    schedulers.

    Returns:
        ([opt_means, opt_variances, opt_gamma], [sched_means, sched_variances, sched_gamma])
        with lr 0.1 / 0.001 / lr_init_gamma respectively; every scheduler
        decays by 0.99 every 10 steps.
    """
    lr_means, step_size, decay = 0.1, 10, 0.99  # was 0.999 for right-half: [0.01, 10, 0.99]
    betas = (0.9, 0.99)
    opt_means = torch.optim.Adam([means], lr=lr_means, betas=betas)
    opt_vars = torch.optim.Adam([variances], lr=0.001, betas=betas)  # 0.001 for lsun
    opt_gamma = torch.optim.Adam([gamma_param], lr=lr_init_gamma, betas=betas)  # 0.01
    opts = [opt_means, opt_vars, opt_gamma]
    scheds = [torch.optim.lr_scheduler.StepLR(o, step_size=step_size, gamma=decay) for o in opts]
    return opts, scheds
def check_directory(filename_list):
    """Ensure every directory path in `filename_list` exists.

    Fix: uses os.makedirs(..., exist_ok=True) instead of exists()+mkdir —
    that pair was race-prone (TOCTOU) and os.mkdir failed when a parent
    directory was missing; makedirs also creates intermediate directories.
    """
    for path in filename_list:
        os.makedirs(path, exist_ok=True)
def s_file(filename, x, mode="wb"):
    """Pickle `x` to `filename` (duplicate of save_file, kept for existing callers)."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
def r_file(filename, mode="rb"):
    """Unpickle and return the object stored at `filename` (duplicate of load_file)."""
    with open(filename, mode) as fh:
        return pickle.load(fh)
def sample_from_gaussian(mu, alpha, sigma):
    """Draw from N(alpha*mu, sigma^2 I) via the reparameterization trick."""
    eps = torch.randn_like(mu)
    return alpha * mu + sigma * eps
# NOTE(review): the block below is dead code — a bare triple-quoted string
# literal is a no-op expression statement, so none of these helpers exist at
# runtime. Kept verbatim for reference; delete once confirmed unused.
'''
def make_batch(image, mask=None, device=None):
    image = torch.permute(image, (0,3,1,2))
    batch_size = image.shape[0]
    if mask is None :
        mask = torch.zeros_like(image)
        mask[0, :, :256, :128] = 1
    else :
        mask = torch.tensor(mask)
    masked_image = (mask)*image #+ mask*noise*0.2
    mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device)
    return batch

def get_sigma_t_steps(net, n_steps=3, kwargs=None):
    sigma_min = kwargs["sigma_min"]
    sigma_max = kwargs["sigma_max"]
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)
    ##Get the time-steps based on iddpm discretization
    num_steps = n_steps #11 # kwargs["num_steps"]
    C_2 = kwargs["C_2"]
    C_1 = kwargs["C_1"]
    M = kwargs["M"]
    step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
    u = torch.zeros(M + 1, dtype=torch.float64).cuda()
    alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
    for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
        u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
    u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
    sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    ##get noise schedule
    sigma = lambda t: t
    sigma_deriv = lambda t: 1
    sigma_inv = lambda sigma: sigma
    ##scaling schedule
    s = lambda t: 1
    s_deriv = lambda t: 0
    ##compute some final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    return t_steps, sigma_inv, sigma, s, sigma_deriv

def data_replicate(data, K):
    if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1])
    else: data_batch = torch.Tensor.repeat(data,[K,1,1,1])
    return data_batch
'''
def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None):
    """DDIM-style sampling over a fixed hierarchy of timesteps, dumping the
    decoded x0-prediction at every step to out_temp2/ as PNGs.

    Args:
        x0: clean latent batch; forward-noised to the first hierarchy level.
            NOTE(review): batch size appears hard-coded to 10 (t_hat uses
            torch.ones(10)) — confirm against callers.
        eta: DDIM stochasticity coefficient (0 would be deterministic DDIM).
        t_steps_hierarchy: sequence of integer diffusion timesteps ordered
            from most to least noisy.

    Returns:
        None — side effect only (PNG files written under out_temp2/).
    """
    # Dead code from an earlier EDM/VP-schedule experiment (no-op string literal).
    '''
    sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
    T_max = 1000
    beta_start = 1 # 0.0015*T_max
    beta_end = 15 # 0.0155*T_max
    def var(t):
        return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
    '''
    t_steps_hierarchy = torch.tensor(t_steps_hierarchy).cuda()
    # var_t = 1 - alpha_bar(t): the schedule stores sqrt(1 - alpha_bar), hence the square.
    var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1, 1, 1))**2
    # Forward-noise x0 to the first (noisiest) hierarchy level.
    x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
    os.makedirs("out_temp2/", exist_ok=True)
    for i, t in enumerate(t_steps_hierarchy):
        t_hat = torch.ones(10).cuda() * (t)  # per-sample timestep vector (batch of 10)
        e_out = self.model.model(x_t, t_hat)  # epsilon prediction from the denoiser
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1, 1, 1))**2
        a_t = 1 - var_t  # alpha_bar(t)
        # Predicted clean latent: x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
        pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        if i != len(t_steps_hierarchy) - 1:
            var_t1 = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[i+1]].reshape(1, 1, 1, 1))**2
            a_prev = 1 - var_t1  # alpha_bar at the next (less noisy) level
            # DDIM posterior std scaled by eta (Song et al., eq. 16).
            sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out  # "direction pointing to x_t"
            # NOTE(review): noise is injected twice here (two independent
            # randn draws each scaled by sigma_t => total variance 2*sigma_t**2);
            # standard DDIM adds sigma_t * noise once — confirm this is intended.
            x_t = a_prev.sqrt() * pred_x0 + dir_xt + torch.randn_like(x_t) * sigma_t + sigma_t*torch.randn_like(x_t)
        # Dead code: alternative posterior-mean update (no-op string literal).
        '''
        def pred_mean(pred_x0, z_t):
            posterior_mean_coef1 = beta_t * torch.sqrt(a_prev) / (1. - a_t)
            posterior_mean_coef2 = (1. - a_prev) * torch.sqrt(a_t/a_prev) / (1. - a_t)
            return posterior_mean_coef1*pred_x0 + posterior_mean_coef2*z_t
        x_t = torch.sqrt(a_prev) * pred_x0 # pred_mean(pred_x0, x_t) #+ 0.4*torch.sqrt(beta_t) *torch.randn_like(x_t)
        '''
        # Decode the current x0 prediction and save it for visual inspection.
        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join("out_temp2/", f'{i}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)
    return