JiminHeo commited on
Commit
33bcb61
·
1 Parent(s): c1b628d
Files changed (1) hide show
  1. vipainting.py +203 -0
vipainting.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import os
3
+ import argparse
4
+ import yaml
5
+ from omegaconf import OmegaConf
6
+ from ldm.util import instantiate_from_config, get_obj_from_str
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import matplotlib.pyplot as plt
10
+ from utils.logger import get_logger
11
+ from utils.mask_generator import mask_generator
12
+ from utils.helper import encoder_kl, clean_directory, to_img, encoder_vq, load_file
13
+ from ldm.guided_diffusion.h_posterior import HPosterior
14
+ from PIL import Image
15
+ import numpy as np
16
+ from torchvision.transforms.functional import pil_to_tensor
17
+
18
def load_yaml(file_path: str) -> dict:
    """Read a YAML file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.load`` produces for the document (a mapping for the
        configuration files this module reads).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(review): FullLoader can build more than plain data types; only
        # point this at trusted configuration files (consider yaml.safe_load).
        return yaml.load(f, Loader=yaml.FullLoader)
22
+
23
def save_segmentation(s, img_path, name):
    """Render a multi-channel segmentation tensor as an RGB image file.

    Only the first item of the batch is rendered.  Every channel is given a
    pseudo-random colour drawn from a generator with a fixed seed, so the
    channel -> colour mapping is the same on every call.

    Args:
        s: 4-D tensor; the channel axis is expected at position 1
            (presumably ``(N, C, H, W)`` - verify against callers).
        img_path: Directory the image is written to.
        name: File name of the image inside ``img_path``.
    """
    # Move the channel axis last, keep batch item 0 and insert a singleton
    # axis so the colour projection can be written as a batched matmul.
    seg = s.detach().cpu().numpy()
    seg = seg.transpose(0, 2, 3, 1)[0, :, :, None, :]

    # One random RGB triple per channel, normalised over the channel axis.
    n_channels = seg.shape[-1]
    palette = np.random.RandomState(1).randn(1, 1, n_channels, 3)
    palette = palette / palette.sum(axis=2, keepdims=True)

    # Project channels onto colours and drop the singleton axis again.
    rgb = (seg @ palette)[..., 0, :]

    # Map [-1, 1] to [0, 255], clip anything outside and convert to bytes.
    rgb = ((rgb + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

    Image.fromarray(rgb).save(os.path.join(img_path, name))
32
+
33
def _parse_vipaint_args():
    """Parse the command-line options read by ``vipaint``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--inpaint_config', type=str,
                        default='configs/inpainting/lands_config_mountain.yaml')  # lsun_config, imagenet_config
    parser.add_argument('--working_directory', type=str, default='results/')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--id', type=int, default=0)
    parser.add_argument('--k_steps', type=int, default=2)
    parser.add_argument('--case', type=str, default="random_all")
    # NOTE(review): this parses the arguments of the *running process*; when
    # vipaint() is called from another application its own arguments end up
    # here.  Confirm this is intended.
    return parser.parse_args()


def _select_posterior(k_steps, inpaint_config):
    """Return ``(config section name, t_steps_hierarchy)`` for ``k_steps``.

    Raises:
        ValueError: If ``k_steps`` is not one of 1, 2, 4 or 6.
    """
    if k_steps == 1:
        return "gauss", [400]

    posterior = "hierarchical"
    if k_steps == 2:
        # Keep only the first and the last configured time step.
        configured = inpaint_config[posterior]['t_steps_hierarchy']
        return posterior, [configured[0], configured[-1]]
    if k_steps == 4:
        return posterior, inpaint_config[posterior]['t_steps_hierarchy']  # [550, 500, 450, 400]
    if k_steps == 6:
        return posterior, [650, 600, 550, 500, 450, 400]
    raise ValueError(f"Unsupported --k_steps value {k_steps}; expected 1, 2, 4 or 6")


def vipaint(num, mask_web, image_queue, sampling_queue):
    """Fit a variational posterior for one masked image and sample from it.

    The function loads the inpainting configuration, the diffusion model and
    the test data, masks the reference image with ``mask_web``, fits the
    posterior with ``HPosterior.fit`` and finally calls ``HPosterior.sample``.
    Intermediate and final files are written below
    ``<working_directory>/<num>/``.

    Args:
        num: Index of the sample inside the test data file; also used as the
            name of the output sub-directory.
        mask_web: Mask data convertible with ``torch.tensor``; it is divided
            by 255 and multiplied with the reference image (presumably a
            0-255 mask that broadcasts against ``(1, C, H, W)`` - verify
            against the caller).
        image_queue: Passed through unchanged to ``HPosterior``.
        sampling_queue: Passed through unchanged to ``HPosterior``.

    Raises:
        ValueError: If ``--k_steps`` is not one of 1, 2, 4 or 6.
        FileNotFoundError: If the configured test data file does not exist.
    """
    args = _parse_vipaint_args()

    # Device setting
    print("================= Device setting")
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_str)

    # Load configurations
    print("================= Load config")
    inpaint_config = load_yaml(args.inpaint_config)
    out_path = args.working_directory

    # Resolve the posterior settings before the expensive model load so that
    # an unsupported --k_steps fails immediately with a clear message.
    posterior, t_steps_hierarchy = _select_posterior(args.k_steps, inpaint_config)

    # Load model
    print("================= Load model")
    config = OmegaConf.load(inpaint_config['diffusion'])
    vae_config = OmegaConf.load(inpaint_config['autoencoder'])

    diff = instantiate_from_config(config.model)
    checkpoint = torch.load(inpaint_config['diffusion_model'], map_location='cpu')
    diff.load_state_dict(checkpoint["state_dict"], strict=False)
    diff = diff.to(device)
    diff.model.eval()
    diff.first_stage_model.eval()
    diff.eval()

    # Load pre-trained autoencoder loss config
    print("================= Load pre-trained")
    loss_config = vae_config['model']['params']['lossconfig']
    vae_loss = get_obj_from_str(inpaint_config['name'],
                                reload=False)(**loss_config.get("params", dict()))

    # Load test data
    print("================= Load test data")
    data_file = inpaint_config['data']['file_name']
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Test data file not found: {data_file}")
    dataset = np.load(data_file)
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)

    # Working directory
    print("================= working directory")
    os.makedirs(out_path, exist_ok=True)

    # Prepare VI method
    print("=================== Prepare VI method")
    # NOTE(review): the settings are read from the section chosen via
    # --k_steps, while the ``posterior`` argument still comes from the
    # configuration file; confirm the two are meant to differ.
    h_inpainter = HPosterior(diff, vae_loss,
                             eta=inpaint_config[posterior]["eta"],
                             z0_size=inpaint_config["data"]["latent_size"],
                             img_size=inpaint_config["data"]["image_size"],
                             latent_channels=inpaint_config["data"]["latent_channels"],
                             first_stage=inpaint_config[posterior]["first_stage"],
                             t_steps_hierarchy=t_steps_hierarchy,
                             posterior=inpaint_config['posterior'],
                             image_queue=image_queue,
                             sampling_queue=sampling_queue)

    h_inpainter.descretize(inpaint_config[posterior]['rho'])

    x_size = inpaint_config['mask_opt']['image_size']
    channels = inpaint_config['data']['channels']

    # Do Inference
    print("=================== Do Inference")
    for random_num in [num]:
        img_path = os.path.join(out_path, str(random_num))
        for img_dir in ('progress', 'params', 'mus'):
            os.makedirs(os.path.join(img_path, img_dir), exist_ok=True)

        bs = inpaint_config[posterior]["batch_size"]

        # NOTE(review): hard-coded value that overrides the configured
        # channel count; it is only used by the reshape in the ``else``
        # branch below.  Confirm whether the configured value was intended.
        channels = 182

        # Defaults so that both names exist on every path below.
        c = None
        uc = None

        # For conditional models
        segmentation = loader.dataset["segmentation"][random_num]
        if inpaint_config["conditional_model"]:
            segment_c = torch.tensor(segmentation.transpose(2, 0, 1)[None]).to(
                dtype=torch.float32, device=diff.device)
            segment_c = segment_c.repeat(bs, 1, 1, 1)
            # The one-entry dict is indexed with the literal 'segmentation',
            # so this raises KeyError unless diff.cond_stage_key has exactly
            # that value.
            uc = diff.get_learned_conditioning(
                {diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
            ).detach()

        # Get Image/Labels
        print("==================== get image/labels")
        if len(loader.dataset) == 2:
            ref_img = loader.dataset["images"][random_num]  # 512, 512, 3
            ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
            print(f"ref_img {ref_img.shape}")  # 1, 512, 512, 3
            ref_img = ref_img / 127.5 - 1

            label = torch.tensor(segmentation.transpose(2, 0, 1)[None]).to(
                dtype=torch.float32, device=diff.device)
            save_segmentation(label, img_path, 'input.png')
            label = label.repeat(bs, 1, 1, 1)  # Now shape is [batch_size, 182, 128, 128]
            # Copy of an existing tensor: clone().detach() avoids the
            # copy-construct warning of torch.tensor(tensor).
            xc = label.clone().detach()
            c = diff.get_learned_conditioning({diff.cond_stage_key: xc}['segmentation']).detach()
        else:
            ref_img = loader.dataset[random_num].reshape(1, x_size, x_size, channels)
            c = None
            uc = None

        # ref_img is a tensor in the first branch and an array in the second;
        # as_tensor accepts both.
        ref_img = torch.as_tensor(ref_img).to(device)

        # Get mask
        mask_tensor = torch.tensor(mask_web).to(device)
        mask_tensor = mask_tensor.float() / 255.0  # Convert to float and normalize to [0, 1]
        ref_img = torch.permute(ref_img, (0, 3, 1, 2))
        y = (mask_tensor * ref_img).repeat(bs, 1, 1, 1).float()

        if inpaint_config[posterior]["first_stage"] == "kl":
            y_encoded = encoder_kl(diff, y)[0]
        else:
            y_encoded = encoder_vq(diff, y)

        plt.imsave(os.path.join(img_path, 'true.png'), to_img(ref_img).astype(np.uint8)[0])
        plt.imsave(os.path.join(img_path, 'observed.png'), to_img(y).astype(np.uint8)[0])

        lambda_ = h_inpainter.init(y_encoded, inpaint_config["init"]["var_scale"],
                                   inpaint_config[posterior]["mean_scale"],
                                   inpaint_config["init"]["prior_scale"],
                                   inpaint_config[posterior]["mean_scale_top"])

        # Fit posterior once
        print("============ fit posterior once")
        torch.cuda.empty_cache()
        h_inpainter.fit(lambda_=lambda_, cond=c, shape=(bs, *y_encoded.shape[1:]),
                        quantize_denoised=False, mask_pixel=mask_tensor, y=y,
                        log_every_t=25, iterations=inpaint_config[posterior]['iterations'],
                        unconditional_guidance_scale=inpaint_config[posterior]["unconditional_guidance_scale"],
                        unconditional_conditioning=uc,
                        kl_weight_1=inpaint_config[posterior]["beta_1"],
                        kl_weight_2=inpaint_config[posterior]["beta_2"],
                        debug=True, wdb=False,
                        dir_name=img_path,
                        batch_size=bs,
                        lr_init_gamma=inpaint_config[posterior]["lr_init_gamma"],
                        recon_weight=inpaint_config[posterior]["recon"],
                        )

        # Load parameters and sample
        print("============= load parameters and sample")
        params_path = os.path.join(img_path, 'params', f'{inpaint_config[posterior]["iterations"]}.pt')
        # Load straight onto the selected device so the CPU fallback and a
        # non-default --gpu index both work.
        mu, logvar, gamma = torch.load(params_path, map_location=device)

        h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
                           mu.to(device), logvar.to(device), gamma.to(device), mask_tensor, y,
                           n_samples=inpaint_config["sampling"]["n_samples"],
                           batch_size=bs, dir_name=img_path, cond=c,
                           unconditional_conditioning=uc,
                           unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                           samples_iteration=inpaint_config[posterior]["iterations"])