MakeYourOwnMask_Inpaint / vipainting.py
JiminHeo's picture
vipaint
33bcb61
raw
history blame
9.72 kB
from functools import partial
import os
import argparse
import yaml
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config, get_obj_from_str
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from utils.logger import get_logger
from utils.mask_generator import mask_generator
from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file
from ldm.guided_diffusion.h_posterior import HPosterior
from PIL import Image
import numpy as np
from torchvision.transforms.functional import pil_to_tensor
def load_yaml(file_path: str) -> dict:
    """Read a YAML config file and return its parsed contents as a dict."""
    with open(file_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
def save_segmentation(s, img_path, name):
    """Colorize a (B, C, H, W) segmentation tensor and save batch item 0 as a PNG.

    Classes are projected to RGB with a fixed random palette (seed 1), so the
    same class map always renders with the same colors.
    """
    # First batch element, channels-last, with a singleton axis for matmul:
    # (B, C, H, W) -> (H, W, 1, C)
    seg = s.detach().cpu().numpy().transpose(0, 2, 3, 1)[0, :, :, None, :]
    # Deterministic projection from C class channels down to 3 color channels.
    palette = np.random.RandomState(1).randn(1, 1, seg.shape[-1], 3)
    palette = palette / palette.sum(axis=2, keepdims=True)
    rgb = (seg @ palette)[..., 0, :]
    # Map [-1, 1] to [0, 255] bytes, clamping anything outside the range.
    rgb = ((rgb + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(rgb).save(os.path.join(img_path, name))
def vipaint(num, mask_web, image_queue, sampling_queue):
    """Variational-inference inpainting driver.

    Loads the diffusion model and configs named by CLI flags, builds an
    HPosterior, fits a variational posterior against the masked observation of
    dataset image `num`, then samples reconstructions from the fitted
    parameters, writing all outputs under the working directory.

    Args:
        num: dataset index of the image to inpaint.
        mask_web: mask image from the web UI in 0-255 range; it is divided by
            255 below, so 255 marks kept pixels and intermediate values act as
            soft weights.
        image_queue: queue forwarded to HPosterior (progress images, presumably
            consumed by the web frontend — confirm against caller).
        sampling_queue: queue forwarded to HPosterior for sampled outputs.
    """
    # NOTE(review): parse_args() reads sys.argv even though this function is
    # called programmatically — confirm the entry point leaves argv clean.
    parser = argparse.ArgumentParser()
    parser.add_argument('--inpaint_config', type=str, default='configs/inpainting/lands_config_mountain.yaml') #lsun_config, imagenet_config
    parser.add_argument('--working_directory', type=str, default='results/')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--id', type=int, default=0)
    parser.add_argument('--k_steps', type=int, default=2)
    parser.add_argument('--case', type=str, default="random_all")
    args = parser.parse_args()

    # Device setting
    print("================= Device setting")
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_str)

    # Load configurations
    print("================= Load config")
    inpaint_config = load_yaml(args.inpaint_config)
    working_directory = args.working_directory

    # Load model: LDM diffusion weights plus its first-stage autoencoder,
    # all switched to eval mode.
    print("================= Load model")
    config = OmegaConf.load(inpaint_config['diffusion'])
    vae_config = OmegaConf.load(inpaint_config['autoencoder'])
    diff = instantiate_from_config(config.model)
    diff.load_state_dict(torch.load(inpaint_config['diffusion_model'],
                                    map_location='cpu')["state_dict"], strict=False)
    diff = diff.to(device)
    diff.model.eval()
    diff.first_stage_model.eval()
    diff.eval()

    # Load pre-trained autoencoder loss config; inpaint_config['name'] is the
    # dotted path of the loss class to instantiate.
    print("================= Load pre-trained")
    loss_config = vae_config['model']['params']['lossconfig']
    vae_loss = get_obj_from_str(inpaint_config['name'],
                                reload=False)(**loss_config.get("params", dict()))

    # Load test data
    print("================= Load test data")
    # NOTE(review): if the data file is missing, `loader` is never defined and
    # the inference loop below raises NameError — consider failing fast here.
    if os.path.exists(inpaint_config['data']['file_name']):
        dataset = np.load(inpaint_config['data']['file_name'])
        loader = torch.utils.data.DataLoader(dataset= dataset, batch_size=1)

    # Working directory
    print("================= working directory")
    out_path = working_directory
    os.makedirs(out_path, exist_ok=True)
    #mask = torch.tensor(np.load("masks/mask_" + str(args.id) + ".npy")).to(device)

    # Choose the posterior family and diffusion-timestep hierarchy from
    # --k_steps; k_steps == 1 collapses to a single-level Gaussian posterior.
    posterior = inpaint_config['posterior']
    if args.k_steps == 1:
        posterior = "gauss"
        t_steps_hierarchy = [400]
    else :
        posterior = "hierarchical"
        if args.k_steps == 2: t_steps_hierarchy = [inpaint_config[posterior]['t_steps_hierarchy'][0],
                                                   inpaint_config[posterior]['t_steps_hierarchy'][-1]]
        elif args.k_steps == 4: t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy'] # [550, 500, 450, 400]
        elif args.k_steps == 6: t_steps_hierarchy = [650, 600, 550, 500, 450, 400]
        # NOTE(review): any other k_steps value leaves t_steps_hierarchy
        # undefined and the HPosterior call below raises NameError.

    # Prepare VI method
    print("=================== Prepare VI method")
    h_inpainter = HPosterior(diff, vae_loss,
                             eta = inpaint_config[posterior]["eta"],
                             z0_size = inpaint_config["data"]["latent_size"],
                             img_size = inpaint_config["data"]["image_size"],
                             latent_channels = inpaint_config["data"]["latent_channels"],
                             first_stage=inpaint_config[posterior]["first_stage"],
                             t_steps_hierarchy=t_steps_hierarchy, #inpaint_config[posterior]['t_steps_hierarchy'],
                             posterior = inpaint_config['posterior'], image_queue = image_queue,
                             sampling_queue = sampling_queue)
    h_inpainter.descretize(inpaint_config[posterior]['rho'])
    x_size = inpaint_config['mask_opt']['image_size']
    channels = inpaint_config['data']['channels']

    # Do Inference
    print("=================== Do Inference")
    # Single-image loop kept from the original batch script; only `num` is used.
    imgs = [num]
    for i, random_num in enumerate(imgs):
        img_path = os.path.join(out_path, str(random_num) ) # +str(args.k_steps) + "_h" #"Loss-ablation"
        for img_dir in ['progress', 'params', 'mus']:
            sub_dir = os.path.join(img_path, img_dir)
            os.makedirs(sub_dir, exist_ok=True)

        bs = inpaint_config[posterior]["batch_size"]
        batch_size = bs
        # NOTE(review): hard-coded 182 (segmentation class count?) silently
        # overrides the config `channels` read above — confirm intentional.
        channels = 182

        # For conditional models
        segmentation = loader.dataset["segmentation"][random_num]
        if inpaint_config["conditional_model"] :
            # (H, W, C) -> (1, C, H, W) float tensor on the model device.
            segment_c = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
            segment_c = segment_c.repeat(batch_size, 1, 1, 1)
            # NOTE(review): {diff.cond_stage_key: x}['segmentation'] only
            # returns x when cond_stage_key == 'segmentation'; any other key
            # raises KeyError — verify this is the intended guard.
            uc = diff.get_learned_conditioning(
                {diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
            ).detach()

        #Get Image/Labels
        print("==================== get image/labels")
        #Get Image/Labels
        # A 2-entry dataset is assumed to be {"images", "segmentation"};
        # otherwise the dataset is treated as a flat image array.
        if len(loader.dataset) ==2:
            ref_img = loader.dataset["images"][random_num] #512, 512, 3
            ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
            print(f"ref_img {ref_img.shape}") #1, 512, 512, 3
            # Rescale uint8 pixel range [0, 255] to [-1, 1].
            ref_img = ref_img/127.5 - 1
            label = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
            save_segmentation(label, img_path, 'input.png')
            label = label.repeat(batch_size, 1, 1, 1) # Now shape is [batch_size, 182, 128, 128]
            # NOTE(review): torch.tensor(label) re-copies an existing tensor
            # (PyTorch warns; clone().detach() is the recommended form).
            xc = torch.tensor(label)
            c = diff.get_learned_conditioning({diff.cond_stage_key: xc}['segmentation']).detach()
        else:
            ref_img = loader.dataset[random_num].reshape(1,x_size,x_size,channels)
            c = None
            uc = None
            ref_img = torch.tensor(ref_img).to(device)

        # #Get mask
        mask_tensor = torch.tensor(mask_web).to(device)
        mask_tensor = mask_tensor.float() / 255.0 # Convert to float and normalize to [0, 1]
        # Channels-last -> channels-first for the model.
        ref_img = torch.permute(ref_img, (0,3,1,2))
        # Masked observation, tiled across the batch dimension.
        y = torch.Tensor.repeat(mask_tensor*ref_img, [bs,1,1,1]).float()

        # Encode the masked observation with the configured first stage.
        if inpaint_config[posterior]["first_stage"] == "kl":
            y_encoded = encoder_kl(diff, y)[0]
        else:
            y_encoded = encoder_vq(diff, y)
        # print(f"shape {ref_img.shape} {mask.shape}")

        # Save ground truth and the masked observation for reference.
        plt.imsave(os.path.join(img_path, 'true.png'), to_img(ref_img).astype(np.uint8)[0])
        plt.imsave(os.path.join(img_path, 'observed.png'), to_img(y).astype(np.uint8)[0])

        # Initialize the variational parameters from the encoded observation.
        lambda_ = h_inpainter.init(y_encoded, inpaint_config["init"]["var_scale"],
                                   inpaint_config[posterior]["mean_scale"], inpaint_config["init"]["prior_scale"],
                                   inpaint_config[posterior]["mean_scale_top"])

        # Fit posterior once
        print("============ fit posterior once")
        torch.cuda.empty_cache()
        h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (bs, *y_encoded.shape[1:]),
                        quantize_denoised=False, mask_pixel = mask_tensor, y =y,
                        log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
                        unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
                        unconditional_conditioning=uc, kl_weight_1=inpaint_config[posterior]["beta_1"],
                        kl_weight_2 = inpaint_config[posterior]["beta_2"],
                        debug=True, wdb = False,
                        dir_name = img_path,
                        batch_size = bs,
                        lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
                        recon_weight = inpaint_config[posterior]["recon"],
                        )

        # Load parameters and sample
        print("============= load parameters and sample")
        # fit() checkpoints (mu, logvar, gamma) at its final iteration; reload
        # them and draw reconstructions.
        params_path = os.path.join(img_path, 'params', f'{inpaint_config[posterior]["iterations"]}.pt') #, j+1
        [mu, logvar, gamma] = torch.load(params_path)
        # NOTE(review): .cuda() here assumes a GPU even though `device` may be
        # 'cpu' above — confirm CPU-only runs are unsupported.
        h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
                           mu.cuda(), logvar.cuda(), gamma.cuda(), mask_tensor, y,
                           n_samples=inpaint_config["sampling"]["n_samples"],
                           batch_size = bs, dir_name= img_path, cond=c,
                           unconditional_conditioning=uc,
                           unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                           samples_iteration=inpaint_config[posterior]["iterations"])