Spaces:
Sleeping
Sleeping
from functools import partial | |
import os | |
import argparse | |
import yaml | |
from omegaconf import OmegaConf | |
from ldm.util import instantiate_from_config, get_obj_from_str | |
import torch | |
import torchvision.transforms as transforms | |
import matplotlib.pyplot as plt | |
from utils.logger import get_logger | |
from utils.mask_generator import mask_generator | |
from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file | |
from ldm.guided_diffusion.h_posterior import HPosterior | |
from PIL import Image | |
import numpy as np | |
from torchvision.transforms.functional import pil_to_tensor | |
def load_yaml(file_path: str) -> dict:
    """Read a YAML file and return its parsed contents.

    The file is decoded as UTF-8 explicitly: without ``encoding=`` ``open``
    falls back to the platform locale encoding, which can mis-decode
    non-ASCII configs (e.g. on Windows).

    NOTE(review): ``yaml.FullLoader`` is kept so existing configs parse the
    same way; prefer ``yaml.safe_load`` if config paths may ever come from
    untrusted sources.
    """
    with open(file_path, encoding="utf-8") as f:
        return yaml.load(f, Loader=yaml.FullLoader)
def save_segmentation(s, img_path, name):
    """Colourise a per-class segmentation tensor and save it as an image file.

    Only the first item along the leading (batch) axis is rendered. Each class
    channel gets a fixed pseudo-random colour drawn from ``RandomState(1)``,
    so the palette is the same on every call.

    Args:
        s: Torch tensor whose axes are permuted with (0, 2, 3, 1) before use,
            i.e. presumably laid out as (N, C, H, W) — confirm against callers.
        img_path: Directory the image is written into.
        name: File name of the image inside ``img_path``.
    """
    seg = np.transpose(s.detach().cpu().numpy(), (0, 2, 3, 1))[0]  # first item, channels last
    n_classes = seg.shape[-1]

    palette = np.random.RandomState(1).randn(1, 1, n_classes, 3)
    # Divide each of the 3 colour columns by its sum over the class axis.
    palette = palette / palette.sum(axis=2, keepdims=True)

    # (H, W, 1, C) @ (1, 1, C, 3) -> (H, W, 1, 3), then drop the singleton axis.
    rgb = (seg[:, :, None, :] @ palette)[:, :, 0, :]
    pixels = np.clip((rgb + 1.0) * 127.5, 0, 255).astype(np.uint8)

    Image.fromarray(pixels).save(os.path.join(img_path, name))
def _parse_args():
    """Parse the command-line options that configure a VIPaint run.

    ``parse_known_args`` is used rather than ``parse_args``. ``vipaint`` gets
    its image index, mask and queues from a caller rather than from the CLI,
    so ``sys.argv`` may belong to a host process and contain flags this parser
    does not define. ``parse_args`` would call ``sys.exit`` on those flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--inpaint_config', type=str,
                        default='configs/inpainting/lands_config_mountain.yaml')  # lsun_config, imagenet_config
    parser.add_argument('--working_directory', type=str, default='results/')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--id', type=int, default=0)
    parser.add_argument('--k_steps', type=int, default=2)
    parser.add_argument('--case', type=str, default="random_all")
    args, _unknown = parser.parse_known_args()
    return args


def _resolve_t_steps(k_steps, inpaint_config):
    """Map ``--k_steps`` to ``(posterior, t_steps_hierarchy)``.

    ``posterior`` names the config section ("gauss" or "hierarchical") that
    the remaining posterior hyper-parameters are read from.

    Raises:
        ValueError: if ``k_steps`` is not 1, 2, 4 or 6.
    """
    if k_steps == 1:
        return "gauss", [400]
    posterior = "hierarchical"
    if k_steps == 2:
        configured = inpaint_config[posterior]['t_steps_hierarchy']
        return posterior, [configured[0], configured[-1]]
    if k_steps == 4:
        return posterior, inpaint_config[posterior]['t_steps_hierarchy']  # e.g. [550, 500, 450, 400]
    if k_steps == 6:
        return posterior, [650, 600, 550, 500, 450, 400]
    raise ValueError(f"Unsupported --k_steps={k_steps}; expected one of 1, 2, 4, 6.")


def vipaint(num, mask_web, image_queue, sampling_queue):
    """Run VIPaint inpainting for one dataset image with a caller-supplied mask.

    The steps are:
    1. Load the diffusion model and the autoencoder loss named in the YAML config.
    2. Fit the variational posterior (``HPosterior``) to the masked observation.
    3. Reload the fitted parameters from disk and draw samples.

    Outputs (``true.png``, ``observed.png``, ``progress/``, ``params/``,
    ``mus/``) go under ``<working_directory>/<num>/``.

    Args:
        num: Index of the image (and its segmentation map, if the dataset has
            one) in the dataset file ``inpaint_config['data']['file_name']``.
        mask_web: Mask array. It is divided by 255 and multiplied with the
            reference image, so presumably 0..255 with 255 marking observed
            pixels — verify against the caller.
        image_queue, sampling_queue: Passed through unchanged to ``HPosterior``.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
        ValueError: for an unsupported ``--k_steps``, or a conditional model
            used with a dataset that has no ``segmentation`` array.
    """
    args = _parse_args()

    # Device setting
    print("================= Device setting")
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_str)

    # Load configurations
    print("================= Load config")
    inpaint_config = load_yaml(args.inpaint_config)

    # Load model
    print("================= Load model")
    config = OmegaConf.load(inpaint_config['diffusion'])
    vae_config = OmegaConf.load(inpaint_config['autoencoder'])
    diff = instantiate_from_config(config.model)
    diff.load_state_dict(torch.load(inpaint_config['diffusion_model'],
                                    map_location='cpu')["state_dict"], strict=False)
    diff = diff.to(device)
    diff.model.eval()
    diff.first_stage_model.eval()
    diff.eval()

    # Load pre-trained autoencoder loss config
    print("================= Load pre-trained")
    loss_config = vae_config['model']['params']['lossconfig']
    vae_loss = get_obj_from_str(inpaint_config['name'],
                                reload=False)(**loss_config.get("params", dict()))

    # Load test data
    print("================= Load test data")
    data_file = inpaint_config['data']['file_name']
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Test data file not found: {data_file}")
    dataset = np.load(data_file)
    # An .npz archive exposes `.files`; a plain .npy array does not.
    has_labels = "segmentation" in getattr(dataset, "files", ())

    # Working directory
    print("================= working directory")
    out_path = args.working_directory
    os.makedirs(out_path, exist_ok=True)

    posterior, t_steps_hierarchy = _resolve_t_steps(args.k_steps, inpaint_config)

    # Prepare VI method
    print("=================== Prepare VI method")
    h_inpainter = HPosterior(diff, vae_loss,
                             eta=inpaint_config[posterior]["eta"],
                             z0_size=inpaint_config["data"]["latent_size"],
                             img_size=inpaint_config["data"]["image_size"],
                             latent_channels=inpaint_config["data"]["latent_channels"],
                             first_stage=inpaint_config[posterior]["first_stage"],
                             t_steps_hierarchy=t_steps_hierarchy,
                             # NOTE(review): this passes the config's posterior name, not the
                             # k_steps-derived `posterior` used for every other lookup — confirm intended.
                             posterior=inpaint_config['posterior'],
                             image_queue=image_queue,
                             sampling_queue=sampling_queue)
    h_inpainter.descretize(inpaint_config[posterior]['rho'])
    x_size = inpaint_config['mask_opt']['image_size']
    channels = inpaint_config['data']['channels']  # used by the plain-array reshape below

    # Do Inference
    print("=================== Do Inference")
    img_path = os.path.join(out_path, str(num))
    for img_dir in ['progress', 'params', 'mus']:
        os.makedirs(os.path.join(img_path, img_dir), exist_ok=True)
    bs = inpaint_config[posterior]["batch_size"]

    segmentation = dataset["segmentation"][num] if has_labels else None

    # For conditional models
    uc = None
    if inpaint_config["conditional_model"]:
        if segmentation is None:
            raise ValueError("conditional_model requires a dataset with a 'segmentation' array")
        segment_c = torch.tensor(segmentation.transpose(2, 0, 1)[None]).to(
            dtype=torch.float32, device=diff.device)
        segment_c = segment_c.repeat(bs, 1, 1, 1)
        # The dict lookup raises KeyError unless diff.cond_stage_key == 'segmentation'.
        uc = diff.get_learned_conditioning(
            {diff.cond_stage_key: segment_c}['segmentation']
        ).detach()

    # Get Image/Labels
    print("==================== get image/labels")
    if has_labels:
        ref_img = dataset["images"][num]  # 512, 512, 3
        ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
        print(f"ref_img {ref_img.shape}")  # 1, 512, 512, 3
        ref_img = ref_img / 127.5 - 1  # maps [0, 255] to [-1, 1] (assumes uint8-range pixels)
        label = torch.tensor(segmentation.transpose(2, 0, 1)[None]).to(
            dtype=torch.float32, device=diff.device)
        save_segmentation(label, img_path, 'input.png')
        label = label.repeat(bs, 1, 1, 1)  # Now shape is [batch_size, 182, 128, 128]
        c = diff.get_learned_conditioning({diff.cond_stage_key: label}['segmentation']).detach()
    else:
        ref_img = dataset[num].reshape(1, x_size, x_size, channels)
        ref_img = torch.tensor(ref_img).to(device)
        c = None

    # Get mask
    mask_tensor = torch.tensor(mask_web).to(device)
    mask_tensor = mask_tensor.float() / 255.0  # Convert to float and normalize to [0, 1]
    ref_img = torch.permute(ref_img, (0, 3, 1, 2))  # channels-last -> channels-first
    y = (mask_tensor * ref_img).repeat(bs, 1, 1, 1).float()
    if inpaint_config[posterior]["first_stage"] == "kl":
        y_encoded = encoder_kl(diff, y)[0]
    else:
        y_encoded = encoder_vq(diff, y)
    plt.imsave(os.path.join(img_path, 'true.png'), to_img(ref_img).astype(np.uint8)[0])
    plt.imsave(os.path.join(img_path, 'observed.png'), to_img(y).astype(np.uint8)[0])
    lambda_ = h_inpainter.init(y_encoded, inpaint_config["init"]["var_scale"],
                               inpaint_config[posterior]["mean_scale"],
                               inpaint_config["init"]["prior_scale"],
                               inpaint_config[posterior]["mean_scale_top"])

    # Fit posterior once
    print("============ fit posterior once")
    torch.cuda.empty_cache()
    iterations = inpaint_config[posterior]['iterations']
    h_inpainter.fit(lambda_=lambda_, cond=c, shape=(bs, *y_encoded.shape[1:]),
                    quantize_denoised=False, mask_pixel=mask_tensor, y=y,
                    log_every_t=25, iterations=iterations,
                    unconditional_guidance_scale=inpaint_config[posterior]["unconditional_guidance_scale"],
                    unconditional_conditioning=uc, kl_weight_1=inpaint_config[posterior]["beta_1"],
                    kl_weight_2=inpaint_config[posterior]["beta_2"],
                    debug=True, wdb=False,
                    dir_name=img_path,
                    batch_size=bs,
                    lr_init_gamma=inpaint_config[posterior]["lr_init_gamma"],
                    recon_weight=inpaint_config[posterior]["recon"],
                    )

    # Load parameters and sample
    print("============= load parameters and sample")
    params_path = os.path.join(img_path, 'params', f'{iterations}.pt')
    # Load onto the run's device. The old unconditional `.cuda()` failed on
    # CPU-only machines and ignored --gpu.
    mu, logvar, gamma = torch.load(params_path, map_location=device)
    h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
                       mu, logvar, gamma, mask_tensor, y,
                       n_samples=inpaint_config["sampling"]["n_samples"],
                       batch_size=bs, dir_name=img_path, cond=c,
                       unconditional_conditioning=uc,
                       unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                       samples_iteration=iterations)