cleaner code
- configs/inpainting/lands_config_mountain.yaml +1 -1
- vipainting.py +18 -29
configs/inpainting/lands_config_mountain.yaml
CHANGED
@@ -3,7 +3,7 @@ data:
   seq: {'half': [200, 300], 'box': [300, 350], 'random': [400,500]} #[400,500] #[350, 450], #, 'val': "random" : [350, 450], half : , val: [0,50]
   file_seq: None
   file_name: data/sflckr_all_images.npz
-  channels:
+  channels: 182
   image_size: 512
   latent_size: 128
   latent_channels: 3
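For reference, a minimal sketch of how a block like this is typically consumed, assuming the file is parsed with PyYAML and that `channels` sits under the `data:` key shown in the hunk header (the variable names below are illustrative, not taken from the repo):

```python
import yaml

# Hypothetical loading snippet; the project reads its own config paths.
with open("configs/inpainting/lands_config_mountain.yaml") as f:
    cfg = yaml.safe_load(f)

channels = cfg["data"]["channels"]      # 182 after this change (previously empty, i.e. None)
image_size = cfg["data"]["image_size"]  # 512
latent_size = cfg["data"]["latent_size"]  # 128
```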
vipainting.py
CHANGED
@@ -94,6 +94,10 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
     elif args.k_steps == 4: t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy'] # [550, 500, 450, 400]
     elif args.k_steps == 6: t_steps_hierarchy = [650, 600, 550, 500, 450, 400]

+    batch_size = inpaint_config[posterior]["batch_size"]
+    zero_tensor = torch.zeros(batch_size, 182, 512, 512, device=diff.device)
+    uc = diff.get_learned_conditioning({diff.cond_stage_key: zero_tensor}['segmentation']).detach()
+

     # Prepare VI method
     print("=================== Prepare VI method")
@@ -121,45 +125,30 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
     sub_dir = os.path.join(img_path, img_dir)
     os.makedirs(sub_dir, exist_ok=True)

-    bs = inpaint_config[posterior]["batch_size"]
-
-    batch_size = bs
-    channels = 182
-    # For conditional models
-    segmentation = loader.dataset["segmentation"][random_num]
-    if inpaint_config["conditional_model"] :
-        segment_c = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
-        segment_c = segment_c.repeat(batch_size, 1, 1, 1)
-        uc = diff.get_learned_conditioning(
-            {diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
-        ).detach()

     #Get Image/Labels
-    print("==================== get image/labels")
-    #Get Image/Labels
+    print(f"==================== get image/labels")
     if len(loader.dataset) ==2:
-        ref_img = loader.dataset["images"][random_num] #512, 512, 3
-        ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
-        print(f"ref_img {ref_img.shape}") #1, 512, 512, 3
+        ref_img = torch.tensor(loader.dataset["images"][random_num][None], dtype=torch.float32, device=diff.device) #1, 512, 512, 3
         ref_img = ref_img/127.5 - 1
-
-
-        save_segmentation(
-
-
-
+        segmentation = torch.tensor(dataset["segmentation"][random_num].transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
+        segmentation_repeated = segmentation.repeat(batch_size, 1, 1, 1)
+        save_segmentation(segmentation, img_path, 'input.png')
+        c = diff.get_learned_conditioning(
+            {diff.cond_stage_key: segmentation_repeated.to(diff.device)}['segmentation']
+        ).detach()
+
     else:
-        ref_img = loader.dataset[random_num].reshape(1,x_size,x_size,channels)
+        ref_img = torch.tensor(loader.dataset[random_num].reshape(1, x_size, x_size, channels), dtype=torch.float32, device=diff.device)
         c = None
         uc = None

-        ref_img = torch.tensor(ref_img).to(device)

     # #Get mask
     mask_tensor = torch.tensor(mask_web).to(device)
     mask_tensor = mask_tensor.float() / 255.0 # Convert to float and normalize to [0, 1]
     ref_img = torch.permute(ref_img, (0,3,1,2))
-    y = torch.Tensor.repeat(mask_tensor*ref_img, [
+    y = torch.Tensor.repeat(mask_tensor*ref_img, [batch_size,1,1,1]).float()

     if inpaint_config[posterior]["first_stage"] == "kl":
         y_encoded = encoder_kl(diff, y)[0]
@@ -176,7 +165,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
     # Fit posterior once
     print("============ fit posterior once")
     torch.cuda.empty_cache()
-    h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (
+    h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (batch_size, *y_encoded.shape[1:]),
                 quantize_denoised=False, mask_pixel = mask_tensor, y =y,
                 log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
                 unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
@@ -184,7 +173,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
                 kl_weight_2 = inpaint_config[posterior]["beta_2"],
                 debug=True, wdb = False,
                 dir_name = img_path,
-                batch_size =
+                batch_size = batch_size,
                 lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
                 recon_weight = inpaint_config[posterior]["recon"],
                 )
@@ -197,7 +186,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
     h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
                 mu.cuda(), logvar.cuda(), gamma.cuda(), mask_tensor, y,
                 n_samples=inpaint_config["sampling"]["n_samples"],
-                batch_size =
+                batch_size = batch_size, dir_name= img_path, cond=c,
                 unconditional_conditioning=uc,
                 unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
                 samples_iteration=inpaint_config[posterior]["iterations"])
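The `c` / `uc` pair built in this diff (a segmentation conditioning and an all-zeros unconditional counterpart) is the usual classifier-free-guidance setup. As a rough sketch of how a sampler typically combines them under `unconditional_guidance_scale` (illustrative names only; the actual sampling logic lives inside `h_inpainter`, not in this snippet):

```python
import torch

def guided_noise(eps_model, x_t, t, c, uc, guidance_scale):
    # Classifier-free guidance: run the noise predictor with and without
    # conditioning, then extrapolate away from the unconditional estimate.
    eps_cond = eps_model(x_t, t, c)
    eps_uncond = eps_model(x_t, t, uc)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```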