Spaces:
Sleeping
Sleeping
vipaint
Browse files- vipainting.py +203 -0
vipainting.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import yaml
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from ldm.util import instantiate_from_config, get_obj_from_str
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from utils.logger import get_logger
|
11 |
+
from utils.mask_generator import mask_generator
|
12 |
+
from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file
|
13 |
+
from ldm.guided_diffusion.h_posterior import HPosterior
|
14 |
+
from PIL import Image
|
15 |
+
import numpy as np
|
16 |
+
from torchvision.transforms.functional import pil_to_tensor
|
17 |
+
|
18 |
+
def load_yaml(file_path: str) -> dict:
|
19 |
+
with open(file_path) as f:
|
20 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
21 |
+
return config
|
22 |
+
|
23 |
+
def save_segmentation(s, img_path, name):
|
24 |
+
s = s.detach().cpu().numpy().transpose(0,2,3,1)[0,:,:,None,:]
|
25 |
+
colorize = np.random.RandomState(1).randn(1,1,s.shape[-1],3)
|
26 |
+
colorize = colorize / colorize.sum(axis=2, keepdims=True)
|
27 |
+
s = s@colorize
|
28 |
+
s = s[...,0,:]
|
29 |
+
s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
|
30 |
+
s = Image.fromarray(s)
|
31 |
+
s.save(os.path.join(img_path, name))
|
32 |
+
|
33 |
+
def vipaint(num, mask_web, image_queue, sampling_queue):
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument('--inpaint_config', type=str, default='configs/inpainting/lands_config_mountain.yaml') #lsun_config, imagenet_config
|
36 |
+
parser.add_argument('--working_directory', type=str, default='results/')
|
37 |
+
parser.add_argument('--gpu', type=int, default=0)
|
38 |
+
parser.add_argument('--id', type=int, default=0)
|
39 |
+
parser.add_argument('--k_steps', type=int, default=2)
|
40 |
+
parser.add_argument('--case', type=str, default="random_all")
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
# Device setting
|
45 |
+
print("================= Device setting")
|
46 |
+
device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
|
47 |
+
device = torch.device(device_str)
|
48 |
+
|
49 |
+
# Load configurations
|
50 |
+
print("================= Load config")
|
51 |
+
inpaint_config = load_yaml(args.inpaint_config)
|
52 |
+
working_directory = args.working_directory
|
53 |
+
|
54 |
+
# Load model
|
55 |
+
print("================= Load model")
|
56 |
+
config = OmegaConf.load(inpaint_config['diffusion'])
|
57 |
+
vae_config = OmegaConf.load(inpaint_config['autoencoder'])
|
58 |
+
|
59 |
+
diff = instantiate_from_config(config.model)
|
60 |
+
diff.load_state_dict(torch.load(inpaint_config['diffusion_model'],
|
61 |
+
map_location='cpu')["state_dict"], strict=False)
|
62 |
+
diff = diff.to(device)
|
63 |
+
diff.model.eval()
|
64 |
+
diff.first_stage_model.eval()
|
65 |
+
diff.eval()
|
66 |
+
|
67 |
+
# Load pre-trained autoencoder loss config
|
68 |
+
print("================= Load pre-trained")
|
69 |
+
loss_config = vae_config['model']['params']['lossconfig']
|
70 |
+
vae_loss = get_obj_from_str(inpaint_config['name'],
|
71 |
+
reload=False)(**loss_config.get("params", dict()))
|
72 |
+
|
73 |
+
# Load test data
|
74 |
+
print("================= Load test data")
|
75 |
+
if os.path.exists(inpaint_config['data']['file_name']):
|
76 |
+
dataset = np.load(inpaint_config['data']['file_name'])
|
77 |
+
loader = torch.utils.data.DataLoader(dataset= dataset, batch_size=1)
|
78 |
+
|
79 |
+
# Working directory
|
80 |
+
print("================= working directory")
|
81 |
+
out_path = working_directory
|
82 |
+
os.makedirs(out_path, exist_ok=True)
|
83 |
+
|
84 |
+
|
85 |
+
#mask = torch.tensor(np.load("masks/mask_" + str(args.id) + ".npy")).to(device)
|
86 |
+
posterior = inpaint_config['posterior']
|
87 |
+
if args.k_steps == 1:
|
88 |
+
posterior = "gauss"
|
89 |
+
t_steps_hierarchy = [400]
|
90 |
+
else :
|
91 |
+
posterior = "hierarchical"
|
92 |
+
if args.k_steps == 2: t_steps_hierarchy = [inpaint_config[posterior]['t_steps_hierarchy'][0],
|
93 |
+
inpaint_config[posterior]['t_steps_hierarchy'][-1]]
|
94 |
+
elif args.k_steps == 4: t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy'] # [550, 500, 450, 400]
|
95 |
+
elif args.k_steps == 6: t_steps_hierarchy = [650, 600, 550, 500, 450, 400]
|
96 |
+
|
97 |
+
|
98 |
+
# Prepare VI method
|
99 |
+
print("=================== Prepare VI method")
|
100 |
+
h_inpainter = HPosterior(diff, vae_loss,
|
101 |
+
eta = inpaint_config[posterior]["eta"],
|
102 |
+
z0_size = inpaint_config["data"]["latent_size"],
|
103 |
+
img_size = inpaint_config["data"]["image_size"],
|
104 |
+
latent_channels = inpaint_config["data"]["latent_channels"],
|
105 |
+
first_stage=inpaint_config[posterior]["first_stage"],
|
106 |
+
t_steps_hierarchy=t_steps_hierarchy, #inpaint_config[posterior]['t_steps_hierarchy'],
|
107 |
+
posterior = inpaint_config['posterior'], image_queue = image_queue,
|
108 |
+
sampling_queue = sampling_queue)
|
109 |
+
|
110 |
+
h_inpainter.descretize(inpaint_config[posterior]['rho'])
|
111 |
+
|
112 |
+
x_size = inpaint_config['mask_opt']['image_size']
|
113 |
+
channels = inpaint_config['data']['channels']
|
114 |
+
|
115 |
+
# Do Inference
|
116 |
+
print("=================== Do Inference")
|
117 |
+
imgs = [num]
|
118 |
+
for i, random_num in enumerate(imgs):
|
119 |
+
img_path = os.path.join(out_path, str(random_num) ) # +str(args.k_steps) + "_h" #"Loss-ablation"
|
120 |
+
for img_dir in ['progress', 'params', 'mus']:
|
121 |
+
sub_dir = os.path.join(img_path, img_dir)
|
122 |
+
os.makedirs(sub_dir, exist_ok=True)
|
123 |
+
|
124 |
+
bs = inpaint_config[posterior]["batch_size"]
|
125 |
+
|
126 |
+
batch_size = bs
|
127 |
+
channels = 182
|
128 |
+
# For conditional models
|
129 |
+
segmentation = loader.dataset["segmentation"][random_num]
|
130 |
+
if inpaint_config["conditional_model"] :
|
131 |
+
segment_c = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
|
132 |
+
segment_c = segment_c.repeat(batch_size, 1, 1, 1)
|
133 |
+
uc = diff.get_learned_conditioning(
|
134 |
+
{diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
|
135 |
+
).detach()
|
136 |
+
|
137 |
+
#Get Image/Labels
|
138 |
+
print("==================== get image/labels")
|
139 |
+
#Get Image/Labels
|
140 |
+
if len(loader.dataset) ==2:
|
141 |
+
ref_img = loader.dataset["images"][random_num] #512, 512, 3
|
142 |
+
ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
|
143 |
+
print(f"ref_img {ref_img.shape}") #1, 512, 512, 3
|
144 |
+
ref_img = ref_img/127.5 - 1
|
145 |
+
|
146 |
+
label = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
|
147 |
+
save_segmentation(label, img_path, 'input.png')
|
148 |
+
label = label.repeat(batch_size, 1, 1, 1) # Now shape is [batch_size, 182, 128, 128]
|
149 |
+
xc = torch.tensor(label)
|
150 |
+
c = diff.get_learned_conditioning({diff.cond_stage_key: xc}['segmentation']).detach()
|
151 |
+
else:
|
152 |
+
ref_img = loader.dataset[random_num].reshape(1,x_size,x_size,channels)
|
153 |
+
c = None
|
154 |
+
uc = None
|
155 |
+
|
156 |
+
ref_img = torch.tensor(ref_img).to(device)
|
157 |
+
|
158 |
+
# #Get mask
|
159 |
+
mask_tensor = torch.tensor(mask_web).to(device)
|
160 |
+
mask_tensor = mask_tensor.float() / 255.0 # Convert to float and normalize to [0, 1]
|
161 |
+
ref_img = torch.permute(ref_img, (0,3,1,2))
|
162 |
+
y = torch.Tensor.repeat(mask_tensor*ref_img, [bs,1,1,1]).float()
|
163 |
+
|
164 |
+
if inpaint_config[posterior]["first_stage"] == "kl":
|
165 |
+
y_encoded = encoder_kl(diff, y)[0]
|
166 |
+
else:
|
167 |
+
y_encoded = encoder_vq(diff, y)
|
168 |
+
|
169 |
+
# print(f"shape {ref_img.shape} {mask.shape}")
|
170 |
+
plt.imsave(os.path.join(img_path, 'true.png'), to_img(ref_img).astype(np.uint8)[0])
|
171 |
+
plt.imsave(os.path.join(img_path, 'observed.png'), to_img(y).astype(np.uint8)[0])
|
172 |
+
|
173 |
+
lambda_ = h_inpainter.init(y_encoded, inpaint_config["init"]["var_scale"],
|
174 |
+
inpaint_config[posterior]["mean_scale"], inpaint_config["init"]["prior_scale"],
|
175 |
+
inpaint_config[posterior]["mean_scale_top"])
|
176 |
+
# Fit posterior once
|
177 |
+
print("============ fit posterior once")
|
178 |
+
torch.cuda.empty_cache()
|
179 |
+
h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (bs, *y_encoded.shape[1:]),
|
180 |
+
quantize_denoised=False, mask_pixel = mask_tensor, y =y,
|
181 |
+
log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
|
182 |
+
unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
|
183 |
+
unconditional_conditioning=uc, kl_weight_1=inpaint_config[posterior]["beta_1"],
|
184 |
+
kl_weight_2 = inpaint_config[posterior]["beta_2"],
|
185 |
+
debug=True, wdb = False,
|
186 |
+
dir_name = img_path,
|
187 |
+
batch_size = bs,
|
188 |
+
lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
|
189 |
+
recon_weight = inpaint_config[posterior]["recon"],
|
190 |
+
)
|
191 |
+
|
192 |
+
# Load parameters and sample
|
193 |
+
print("============= load parameters and sample")
|
194 |
+
params_path = os.path.join(img_path, 'params', f'{inpaint_config[posterior]["iterations"]}.pt') #, j+1
|
195 |
+
[mu, logvar, gamma] = torch.load(params_path)
|
196 |
+
|
197 |
+
h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
|
198 |
+
mu.cuda(), logvar.cuda(), gamma.cuda(), mask_tensor, y,
|
199 |
+
n_samples=inpaint_config["sampling"]["n_samples"],
|
200 |
+
batch_size = bs, dir_name= img_path, cond=c,
|
201 |
+
unconditional_conditioning=uc,
|
202 |
+
unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
|
203 |
+
samples_iteration=inpaint_config[posterior]["iterations"])
|