from functools import partial
import os
import argparse
import yaml
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config, get_obj_from_str
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from utils.logger import get_logger
from utils.mask_generator import mask_generator
from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file
from ldm.guided_diffusion.h_posterior import HPosterior
from PIL import Image
import numpy as np
from torchvision.transforms.functional import pil_to_tensor


def load_yaml(file_path: str) -> dict:
    """Parse a YAML file and return its contents as a dict.

    Uses ``FullLoader``; only use on trusted config files (FullLoader can
    construct arbitrary Python objects from specially-crafted YAML).
    """
    with open(file_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config


def save_segmentation(s, img_path, name):
    """Colorize a one-hot segmentation tensor and save it as an image.

    Args:
        s: segmentation tensor, shape (B, C, H, W); only batch element 0 is saved.
        img_path: directory in which to write the image.
        name: file name (e.g. ``'input.png'``).

    A fixed random seed (1) makes the per-class color map deterministic
    across runs.
    """
    s = s.detach().cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, None, :]
    # Deterministic random projection from C classes to 3 RGB channels.
    colorize = np.random.RandomState(1).randn(1, 1, s.shape[-1], 3)
    colorize = colorize / colorize.sum(axis=2, keepdims=True)
    s = s @ colorize
    s = s[..., 0, :]
    # Map from [-1, 1] to [0, 255] and clip before uint8 conversion.
    s = ((s + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    s = Image.fromarray(s)
    s.save(os.path.join(img_path, name))


def vipaint(num, mask_web, image_queue, sampling_queue):
    """Run variational inpainting (VIPaint) for a single image index.

    Loads the diffusion model and autoencoder described by the inpainting
    config, builds an ``HPosterior``, fits the variational posterior to the
    masked observation, then draws posterior samples.

    Args:
        num: index of the image in the dataset to inpaint.
        mask_web: mask array in [0, 255]; normalized to [0, 1] internally.
        image_queue: queue handed to HPosterior for streaming progress images.
        sampling_queue: queue handed to HPosterior for streaming samples.

    Side effects: writes progress/params/mus artifacts and PNGs under the
    working directory; parses CLI arguments from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--inpaint_config', type=str,
                        default='configs/inpainting/lands_config_mountain.yaml')  # lsun_config, imagenet_config
    parser.add_argument('--working_directory', type=str, default='results/')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--id', type=int, default=0)
    parser.add_argument('--k_steps', type=int, default=2)
    parser.add_argument('--case', type=str, default="random_all")
    args = parser.parse_args()

    # Device setting — fall back to CPU when CUDA is unavailable.
    print("================= Device setting")
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_str)

    # Load configurations
    print("================= Load config")
    inpaint_config = load_yaml(args.inpaint_config)
    working_directory = args.working_directory

    # Load model (weights are loaded on CPU first, then moved to `device`).
    print("================= Load model")
    config = OmegaConf.load(inpaint_config['diffusion'])
    vae_config = OmegaConf.load(inpaint_config['autoencoder'])
    diff = instantiate_from_config(config.model)
    diff.load_state_dict(
        torch.load(inpaint_config['diffusion_model'], map_location='cpu')["state_dict"],
        strict=False)
    diff = diff.to(device)
    diff.model.eval()
    diff.first_stage_model.eval()
    diff.eval()

    # Load pre-trained autoencoder loss config
    print("================= Load pre-trained")
    loss_config = vae_config['model']['params']['lossconfig']
    vae_loss = get_obj_from_str(inpaint_config['name'], reload=False)(
        **loss_config.get("params", dict()))

    # Load test data — fail fast with a clear error instead of a NameError
    # later when `loader`/`dataset` are used.
    print("================= Load test data")
    if not os.path.exists(inpaint_config['data']['file_name']):
        raise FileNotFoundError(
            f"data file not found: {inpaint_config['data']['file_name']}")
    dataset = np.load(inpaint_config['data']['file_name'])
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)

    # Working directory
    print("================= working directory")
    out_path = working_directory
    os.makedirs(out_path, exist_ok=True)
    # mask = torch.tensor(np.load("masks/mask_" + str(args.id) + ".npy")).to(device)

    # Select the posterior family and timestep hierarchy from --k_steps.
    posterior = inpaint_config['posterior']
    if args.k_steps == 1:
        posterior = "gauss"
        t_steps_hierarchy = [400]
    else:
        posterior = "hierarchical"
        if args.k_steps == 2:
            # First and last entries of the configured hierarchy.
            t_steps_hierarchy = [inpaint_config[posterior]['t_steps_hierarchy'][0],
                                 inpaint_config[posterior]['t_steps_hierarchy'][-1]]
        elif args.k_steps == 4:
            t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy']  # [550, 500, 450, 400]
        elif args.k_steps == 6:
            t_steps_hierarchy = [650, 600, 550, 500, 450, 400]

    batch_size = inpaint_config[posterior]["batch_size"]
    # 182 is presumably the number of segmentation classes — TODO confirm
    # against the conditioning-stage config.
    zero_tensor = torch.zeros(batch_size, 182, 512, 512, device=diff.device)
    # NOTE(review): the `{key: t}['segmentation']` pattern only works when
    # diff.cond_stage_key == 'segmentation'; it simply extracts the tensor.
    uc = diff.get_learned_conditioning(
        {diff.cond_stage_key: zero_tensor}['segmentation']).detach()

    # Prepare VI method
    print("=================== Prepare VI method")
    h_inpainter = HPosterior(diff, vae_loss,
                             eta=inpaint_config[posterior]["eta"],
                             z0_size=inpaint_config["data"]["latent_size"],
                             img_size=inpaint_config["data"]["image_size"],
                             latent_channels=inpaint_config["data"]["latent_channels"],
                             first_stage=inpaint_config[posterior]["first_stage"],
                             t_steps_hierarchy=t_steps_hierarchy,  # inpaint_config[posterior]['t_steps_hierarchy']
                             posterior=inpaint_config['posterior'],
                             image_queue=image_queue,
                             sampling_queue=sampling_queue)
    h_inpainter.descretize(inpaint_config[posterior]['rho'])

    x_size = inpaint_config['mask_opt']['image_size']
    channels = inpaint_config['data']['channels']

    # Do Inference
    print("=================== Do Inference")
    imgs = [num]
    for i, random_num in enumerate(imgs):
        img_path = os.path.join(out_path, str(random_num))  # +str(args.k_steps) + "_h"
        for img_dir in ['progress', 'params', 'mus']:
            sub_dir = os.path.join(img_path, img_dir)
            os.makedirs(sub_dir, exist_ok=True)

        # Get Image/Labels — a 2-entry dataset is assumed to be an
        # images+segmentation archive; anything else is raw image data.
        print(f"==================== get image/labels")
        if len(loader.dataset) == 2:
            ref_img = torch.tensor(loader.dataset["images"][random_num][None],
                                   dtype=torch.float32, device=diff.device)  # 1, 512, 512, 3
            ref_img = ref_img / 127.5 - 1  # [0, 255] -> [-1, 1]
            segmentation = torch.tensor(
                dataset["segmentation"][random_num].transpose(2, 0, 1)[None]
            ).to(dtype=torch.float32, device=diff.device)
            segmentation_repeated = segmentation.repeat(batch_size, 1, 1, 1)
            save_segmentation(segmentation, img_path, 'input.png')
            c = diff.get_learned_conditioning(
                {diff.cond_stage_key: segmentation_repeated.to(diff.device)}['segmentation']
            ).detach()
        else:
            ref_img = torch.tensor(
                loader.dataset[random_num].reshape(1, x_size, x_size, channels),
                dtype=torch.float32, device=diff.device)
            c = None
            uc = None

        # Get mask
        mask_tensor = torch.tensor(mask_web).to(device)
        mask_tensor = mask_tensor.float() / 255.0  # Convert to float and normalize to [0, 1]
        ref_img = torch.permute(ref_img, (0, 3, 1, 2))  # NHWC -> NCHW
        y = torch.Tensor.repeat(mask_tensor * ref_img, [batch_size, 1, 1, 1]).float()
        if inpaint_config[posterior]["first_stage"] == "kl":
            y_encoded = encoder_kl(diff, y)[0]
        else:
            y_encoded = encoder_vq(diff, y)
        # print(f"shape {ref_img.shape} {mask.shape}")

        plt.imsave(os.path.join(img_path, 'true.png'),
                   to_img(ref_img).astype(np.uint8)[0])
        plt.imsave(os.path.join(img_path, 'observed.png'),
                   to_img(y).astype(np.uint8)[0])

        lambda_ = h_inpainter.init(y_encoded,
                                   inpaint_config["init"]["var_scale"],
                                   inpaint_config[posterior]["mean_scale"],
                                   inpaint_config["init"]["prior_scale"],
                                   inpaint_config[posterior]["mean_scale_top"])

        # Fit posterior once
        print("============ fit posterior once")
        torch.cuda.empty_cache()
        h_inpainter.fit(lambda_=lambda_, cond=c,
                        shape=(batch_size, *y_encoded.shape[1:]),
                        quantize_denoised=False,
                        mask_pixel=mask_tensor, y=y,
                        log_every_t=25,
                        iterations=inpaint_config[posterior]['iterations'],
                        unconditional_guidance_scale=inpaint_config[posterior]["unconditional_guidance_scale"],
                        unconditional_conditioning=uc,
                        kl_weight_1=inpaint_config[posterior]["beta_1"],
                        kl_weight_2=inpaint_config[posterior]["beta_2"],
                        debug=True, wdb=False,
                        dir_name=img_path,
                        batch_size=batch_size,
                        lr_init_gamma=inpaint_config[posterior]["lr_init_gamma"],
                        recon_weight=inpaint_config[posterior]["recon"])

        # Load parameters and sample
        print("============= load parameters and sample")
        params_path = os.path.join(
            img_path, 'params', f'{inpaint_config[posterior]["iterations"]}.pt')  # , j+1
        [mu, logvar, gamma] = torch.load(params_path)
        # Use .to(device) rather than .cuda() so the CPU fallback chosen
        # above keeps working (identical on a CUDA device).
        h_inpainter.sample(inpaint_config["sampling"]["scale"],
                           inpaint_config[posterior]["eta"],
                           mu.to(device), logvar.to(device), gamma.to(device),
                           mask_tensor, y,
                           n_samples=inpaint_config["sampling"]["n_samples"],
                           batch_size=batch_size,
                           dir_name=img_path,
                           cond=c,
                           unconditional_conditioning=uc,
                           unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                           samples_iteration=inpaint_config[posterior]["iterations"])