Spaces:
Sleeping
Sleeping
vipaint
Browse files- vipainting.py +203 -0
vipainting.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import yaml
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
from ldm.util import instantiate_from_config, get_obj_from_str
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from utils.logger import get_logger
|
| 11 |
+
from utils.mask_generator import mask_generator
|
| 12 |
+
from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file
|
| 13 |
+
from ldm.guided_diffusion.h_posterior import HPosterior
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import numpy as np
|
| 16 |
+
from torchvision.transforms.functional import pil_to_tensor
|
| 17 |
+
|
| 18 |
+
def load_yaml(file_path: str) -> dict:
    """Read a YAML file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the document.
    """
    # Open explicitly as UTF-8 so parsing does not depend on the platform's
    # default text encoding.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct more than plain data types.
        # It is kept here to preserve behaviour; prefer yaml.safe_load if
        # config files may ever come from an untrusted source.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config
|
| 22 |
+
|
| 23 |
+
def save_segmentation(s, img_path, name):
    """Colourise the first segmentation map in a batch and save it as an image.

    Each channel is assigned a fixed pseudo-random RGB weight (seed 1, so the
    palette is the same on every call); the weights are normalised to sum to
    one over the channel axis for each colour component.

    Args:
        s: Tensor to render; only batch element 0 is used.
            # assumes layout (batch, channels, height, width) - TODO confirm
        img_path: Directory the image is written into.
        name: File name of the image inside ``img_path``.
    """
    # Channels-last view of the first sample, with a singleton axis inserted
    # so the matmul below is applied independently at every pixel.
    seg = s.detach().cpu().numpy()
    seg = np.transpose(seg, (0, 2, 3, 1))[0]
    seg = seg[:, :, None, :]

    n_channels = seg.shape[-1]
    palette = np.random.RandomState(1).randn(1, 1, n_channels, 3)
    palette = palette / palette.sum(axis=2, keepdims=True)

    rgb = np.matmul(seg, palette)[..., 0, :]

    # Map from [-1, 1] to [0, 255], clamping anything outside that range.
    pixels = np.clip((rgb + 1.0) * 127.5, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(os.path.join(img_path, name))
|
| 32 |
+
|
| 33 |
+
def _select_posterior(k_steps, inpaint_config):
    """Translate the requested number of hierarchy steps into posterior settings.

    Args:
        k_steps: Number of hierarchy steps; 1, 2, 4 and 6 are supported.
        inpaint_config: Parsed inpainting configuration; its
            ``["hierarchical"]["t_steps_hierarchy"]`` entry is read for 2 and
            4 steps.

    Returns:
        A ``(posterior, t_steps_hierarchy)`` tuple, where ``posterior`` is
        ``"gauss"`` for a single step and ``"hierarchical"`` otherwise.

    Raises:
        ValueError: If ``k_steps`` is not one of the supported values.
    """
    if k_steps == 1:
        return "gauss", [400]
    if k_steps == 2:
        configured = inpaint_config["hierarchical"]['t_steps_hierarchy']
        # Only the first and last configured time steps are used.
        return "hierarchical", [configured[0], configured[-1]]
    if k_steps == 4:
        return "hierarchical", inpaint_config["hierarchical"]['t_steps_hierarchy']  # [550, 500, 450, 400]
    if k_steps == 6:
        return "hierarchical", [650, 600, 550, 500, 450, 400]
    raise ValueError(
        f"Unsupported --k_steps value {k_steps!r}; expected one of 1, 2, 4 or 6"
    )


def vipaint(num, mask_web, image_queue, sampling_queue):
    """Run inpainting for one dataset sample with an externally supplied mask.

    The function loads the diffusion model and dataset named in the inpainting
    config, builds the masked observation, fits the variational posterior with
    ``HPosterior.fit`` and then draws samples with ``HPosterior.sample``.
    Results are written under ``<working_directory>/<num>/``: ``input.png``
    (paired-data branch only), ``true.png``, ``observed.png`` and the
    ``progress``, ``params`` and ``mus`` sub-directories.

    Args:
        num: Index of the sample to inpaint; also used as the output
            directory name.
        mask_web: Mask that is converted to a tensor and divided by 255.
            # presumably 0-255 valued and broadcastable against the image
            # in (1, C, H, W) layout - TODO confirm against the caller
        image_queue: Forwarded unchanged to ``HPosterior``.
        sampling_queue: Forwarded unchanged to ``HPosterior``.

    Raises:
        FileNotFoundError: If the dataset file named in the config is missing.
        ValueError: If ``--k_steps`` is not one of 1, 2, 4 or 6.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--inpaint_config', type=str, default='configs/inpainting/lands_config_mountain.yaml') #lsun_config, imagenet_config
    parser.add_argument('--working_directory', type=str, default='results/')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--id', type=int, default=0)
    parser.add_argument('--k_steps', type=int, default=2)
    parser.add_argument('--case', type=str, default="random_all")
    # This function is driven by another program rather than being the script
    # entry point, so sys.argv may hold options that belong to the host.
    # parse_known_args ignores those instead of exiting the process.
    args, _ = parser.parse_known_args()

    # Device setting
    print("================= Device setting")
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_str)

    # Load configurations
    print("================= Load config")
    inpaint_config = load_yaml(args.inpaint_config)
    working_directory = args.working_directory

    # Load model
    print("================= Load model")
    config = OmegaConf.load(inpaint_config['diffusion'])
    vae_config = OmegaConf.load(inpaint_config['autoencoder'])

    diff = instantiate_from_config(config.model)
    diff.load_state_dict(torch.load(inpaint_config['diffusion_model'],
                                    map_location='cpu')["state_dict"], strict=False)
    diff = diff.to(device)
    diff.model.eval()
    diff.first_stage_model.eval()
    diff.eval()

    # Load pre-trained autoencoder loss config
    print("================= Load pre-trained")
    loss_config = vae_config['model']['params']['lossconfig']
    vae_loss = get_obj_from_str(inpaint_config['name'],
                                reload=False)(**loss_config.get("params", dict()))

    # Load test data
    print("================= Load test data")
    data_file = inpaint_config['data']['file_name']
    if not os.path.exists(data_file):
        # Fail here with a clear message instead of hitting an unbound
        # variable further down.
        raise FileNotFoundError(f"Dataset file not found: {data_file}")
    dataset = np.load(data_file)

    # Working directory
    print("================= working directory")
    out_path = working_directory
    os.makedirs(out_path, exist_ok=True)

    #mask = torch.tensor(np.load("masks/mask_" + str(args.id) + ".npy")).to(device)
    # The posterior family used for config look-ups is decided by --k_steps.
    posterior, t_steps_hierarchy = _select_posterior(args.k_steps, inpaint_config)

    # Prepare VI method
    print("=================== Prepare VI method")
    # NOTE(review): the `posterior` keyword below is taken from the config,
    # while every other setting uses the k_steps-derived `posterior` above.
    # The two can disagree - confirm which one HPosterior expects.
    h_inpainter = HPosterior(diff, vae_loss,
                             eta = inpaint_config[posterior]["eta"],
                             z0_size = inpaint_config["data"]["latent_size"],
                             img_size = inpaint_config["data"]["image_size"],
                             latent_channels = inpaint_config["data"]["latent_channels"],
                             first_stage=inpaint_config[posterior]["first_stage"],
                             t_steps_hierarchy=t_steps_hierarchy, #inpaint_config[posterior]['t_steps_hierarchy'],
                             posterior = inpaint_config['posterior'], image_queue = image_queue,
                             sampling_queue = sampling_queue)

    h_inpainter.descretize(inpaint_config[posterior]['rho'])

    x_size = inpaint_config['mask_opt']['image_size']
    channels = inpaint_config['data']['channels']

    # Do Inference
    print("=================== Do Inference")
    imgs = [num]
    for random_num in imgs:
        img_path = os.path.join(out_path, str(random_num) ) # +str(args.k_steps) + "_h" #"Loss-ablation"
        for img_dir in ['progress', 'params', 'mus']:
            sub_dir = os.path.join(img_path, img_dir)
            os.makedirs(sub_dir, exist_ok=True)

        bs = inpaint_config[posterior]["batch_size"]

        # NOTE(review): this overrides the value read from the config above
        # and is only used by the unpaired-data reshape - confirm it is intended.
        channels = 182

        # For conditional models
        # NOTE(review): read unconditionally, so the data file must provide a
        # "segmentation" entry even when the unpaired branch below is taken.
        segmentation = dataset["segmentation"][random_num]
        # Default to no unconditional conditioning so `uc` is always bound.
        uc = None
        if inpaint_config["conditional_model"] :
            segment_c = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
            segment_c = segment_c.repeat(bs, 1, 1, 1)
            # The dict look-up only succeeds when cond_stage_key is 'segmentation'.
            uc = diff.get_learned_conditioning(
                {diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
            ).detach()

        #Get Image/Labels
        print("==================== get image/labels")
        if len(dataset) == 2:
            ref_img = dataset["images"][random_num] #512, 512, 3
            ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
            print(f"ref_img {ref_img.shape}") #1, 512, 512, 3
            # Rescale from [0, 255] to [-1, 1].
            ref_img = ref_img/127.5 - 1

            label = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
            save_segmentation(label, img_path, 'input.png')
            label = label.repeat(bs, 1, 1, 1) # Now shape is [batch_size, 182, 128, 128]
            # clone().detach() is the supported way to copy an existing tensor.
            xc = label.clone().detach()
            c = diff.get_learned_conditioning({diff.cond_stage_key: xc}['segmentation']).detach()
        else:
            ref_img = dataset[random_num].reshape(1,x_size,x_size,channels)
            c = None
            uc = None

        # Only wrap array data; an existing tensor is moved as-is.
        if not torch.is_tensor(ref_img):
            ref_img = torch.tensor(ref_img)
        ref_img = ref_img.to(device)

        # #Get mask
        mask_tensor = torch.tensor(mask_web).to(device)
        mask_tensor = mask_tensor.float() / 255.0  # Convert to float and normalize to [0, 1]
        # Channels-last -> channels-first.
        ref_img = torch.permute(ref_img, (0,3,1,2))
        # Masked observation, repeated once per batch element.
        y = (mask_tensor*ref_img).repeat(bs, 1, 1, 1).float()

        if inpaint_config[posterior]["first_stage"] == "kl":
            y_encoded = encoder_kl(diff, y)[0]
        else:
            y_encoded = encoder_vq(diff, y)

        # print(f"shape {ref_img.shape} {mask.shape}")
        plt.imsave(os.path.join(img_path, 'true.png'), to_img(ref_img).astype(np.uint8)[0])
        plt.imsave(os.path.join(img_path, 'observed.png'), to_img(y).astype(np.uint8)[0])

        lambda_ = h_inpainter.init(y_encoded, inpaint_config["init"]["var_scale"],
                                   inpaint_config[posterior]["mean_scale"], inpaint_config["init"]["prior_scale"],
                                   inpaint_config[posterior]["mean_scale_top"])
        # Fit posterior once
        print("============ fit posterior once")
        torch.cuda.empty_cache()
        h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (bs, *y_encoded.shape[1:]),
                        quantize_denoised=False, mask_pixel = mask_tensor, y =y,
                        log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
                        unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
                        unconditional_conditioning=uc, kl_weight_1=inpaint_config[posterior]["beta_1"],
                        kl_weight_2 = inpaint_config[posterior]["beta_2"],
                        debug=True, wdb = False,
                        dir_name = img_path,
                        batch_size = bs,
                        lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
                        recon_weight = inpaint_config[posterior]["recon"],
                        )

        # Load parameters and sample
        print("============= load parameters and sample")
        # This file is expected to exist once fit() has finished.
        params_path = os.path.join(img_path, 'params', f'{inpaint_config[posterior]["iterations"]}.pt') #, j+1
        [mu, logvar, gamma] = torch.load(params_path)

        # Move to the selected device rather than calling .cuda(), which fails
        # on CPU-only hosts and ignores --gpu.
        h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
                           mu.to(device), logvar.to(device), gamma.to(device), mask_tensor, y,
                           n_samples=inpaint_config["sampling"]["n_samples"],
                           batch_size = bs, dir_name= img_path, cond=c,
                           unconditional_conditioning=uc,
                           unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                           samples_iteration=inpaint_config[posterior]["iterations"])
|