JiminHeo commited on
Commit
cc55052
·
1 Parent(s): 29b76c6

cleaner code

Browse files
configs/inpainting/lands_config_mountain.yaml CHANGED
@@ -3,7 +3,7 @@ data:
3
  seq: {'half': [200, 300], 'box': [300, 350], 'random': [400,500]} #[400,500] #[350, 450], #, 'val': "random" : [350, 450], half : , val: [0,50]
4
  file_seq: None
5
  file_name: data/sflckr_all_images.npz
6
- channels: 3
7
  image_size: 512
8
  latent_size: 128
9
  latent_channels: 3
 
3
  seq: {'half': [200, 300], 'box': [300, 350], 'random': [400,500]} #[400,500] #[350, 450], #, 'val': "random" : [350, 450], half : , val: [0,50]
4
  file_seq: None
5
  file_name: data/sflckr_all_images.npz
6
+ channels: 182
7
  image_size: 512
8
  latent_size: 128
9
  latent_channels: 3
vipainting.py CHANGED
@@ -94,6 +94,10 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
94
  elif args.k_steps == 4: t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy'] # [550, 500, 450, 400]
95
  elif args.k_steps == 6: t_steps_hierarchy = [650, 600, 550, 500, 450, 400]
96
 
 
 
 
 
97
 
98
  # Prepare VI method
99
  print("=================== Prepare VI method")
@@ -121,45 +125,30 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
121
  sub_dir = os.path.join(img_path, img_dir)
122
  os.makedirs(sub_dir, exist_ok=True)
123
 
124
- bs = inpaint_config[posterior]["batch_size"]
125
-
126
- batch_size = bs
127
- channels = 182
128
- # For conditional models
129
- segmentation = loader.dataset["segmentation"][random_num]
130
- if inpaint_config["conditional_model"] :
131
- segment_c = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
132
- segment_c = segment_c.repeat(batch_size, 1, 1, 1)
133
- uc = diff.get_learned_conditioning(
134
- {diff.cond_stage_key: segment_c.to(diff.device)}['segmentation']
135
- ).detach()
136
 
137
  #Get Image/Labels
138
- print("==================== get image/labels")
139
- #Get Image/Labels
140
  if len(loader.dataset) ==2:
141
- ref_img = loader.dataset["images"][random_num] #512, 512, 3
142
- ref_img = torch.tensor(ref_img[None]).to(dtype=torch.float32, device=diff.device)
143
- print(f"ref_img {ref_img.shape}") #1, 512, 512, 3
144
  ref_img = ref_img/127.5 - 1
145
-
146
- label = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
147
- save_segmentation(label, img_path, 'input.png')
148
- label = label.repeat(batch_size, 1, 1, 1) # Now shape is [batch_size, 182, 128, 128]
149
- xc = torch.tensor(label)
150
- c = diff.get_learned_conditioning({diff.cond_stage_key: xc}['segmentation']).detach()
 
151
  else:
152
- ref_img = loader.dataset[random_num].reshape(1,x_size,x_size,channels)
153
  c = None
154
  uc = None
155
 
156
- ref_img = torch.tensor(ref_img).to(device)
157
 
158
  # #Get mask
159
  mask_tensor = torch.tensor(mask_web).to(device)
160
  mask_tensor = mask_tensor.float() / 255.0 # Convert to float and normalize to [0, 1]
161
  ref_img = torch.permute(ref_img, (0,3,1,2))
162
- y = torch.Tensor.repeat(mask_tensor*ref_img, [bs,1,1,1]).float()
163
 
164
  if inpaint_config[posterior]["first_stage"] == "kl":
165
  y_encoded = encoder_kl(diff, y)[0]
@@ -176,7 +165,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
176
  # Fit posterior once
177
  print("============ fit posterior once")
178
  torch.cuda.empty_cache()
179
- h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (bs, *y_encoded.shape[1:]),
180
  quantize_denoised=False, mask_pixel = mask_tensor, y =y,
181
  log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
182
  unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
@@ -184,7 +173,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
184
  kl_weight_2 = inpaint_config[posterior]["beta_2"],
185
  debug=True, wdb = False,
186
  dir_name = img_path,
187
- batch_size = bs,
188
  lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
189
  recon_weight = inpaint_config[posterior]["recon"],
190
  )
@@ -197,7 +186,7 @@ def vipaint(num, mask_web, image_queue, sampling_queue):
197
  h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
198
  mu.cuda(), logvar.cuda(), gamma.cuda(), mask_tensor, y,
199
  n_samples=inpaint_config["sampling"]["n_samples"],
200
- batch_size = bs, dir_name= img_path, cond=c,
201
  unconditional_conditioning=uc,
202
  unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
203
  samples_iteration=inpaint_config[posterior]["iterations"])
 
94
  elif args.k_steps == 4: t_steps_hierarchy = inpaint_config[posterior]['t_steps_hierarchy'] # [550, 500, 450, 400]
95
  elif args.k_steps == 6: t_steps_hierarchy = [650, 600, 550, 500, 450, 400]
96
 
97
+ batch_size = inpaint_config[posterior]["batch_size"]
98
+ zero_tensor = torch.zeros(batch_size, 182, 512, 512, device=diff.device)
99
+ uc = diff.get_learned_conditioning({diff.cond_stage_key: zero_tensor}['segmentation']).detach()
100
+
101
 
102
  # Prepare VI method
103
  print("=================== Prepare VI method")
 
125
  sub_dir = os.path.join(img_path, img_dir)
126
  os.makedirs(sub_dir, exist_ok=True)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  #Get Image/Labels
130
+ print(f"==================== get image/labels")
 
131
  if len(loader.dataset) ==2:
132
+ ref_img = torch.tensor(loader.dataset["images"][random_num][None], dtype=torch.float32, device=diff.device) #1, 512, 512, 3
 
 
133
  ref_img = ref_img/127.5 - 1
134
+ segmentation = torch.tensor(dataset["segmentation"][random_num].transpose(2,0,1)[None]).to(dtype=torch.float32, device=diff.device)
135
+ segmentation_repeated = segmentation.repeat(batch_size, 1, 1, 1)
136
+ save_segmentation(segmentation, img_path, 'input.png')
137
+ c = diff.get_learned_conditioning(
138
+ {diff.cond_stage_key: segmentation_repeated.to(diff.device)}['segmentation']
139
+ ).detach()
140
+
141
  else:
142
+ ref_img = torch.tensor(loader.dataset[random_num].reshape(1, x_size, x_size, channels), dtype=torch.float32, device=diff.device)
143
  c = None
144
  uc = None
145
 
 
146
 
147
  # #Get mask
148
  mask_tensor = torch.tensor(mask_web).to(device)
149
  mask_tensor = mask_tensor.float() / 255.0 # Convert to float and normalize to [0, 1]
150
  ref_img = torch.permute(ref_img, (0,3,1,2))
151
+ y = torch.Tensor.repeat(mask_tensor*ref_img, [batch_size,1,1,1]).float()
152
 
153
  if inpaint_config[posterior]["first_stage"] == "kl":
154
  y_encoded = encoder_kl(diff, y)[0]
 
165
  # Fit posterior once
166
  print("============ fit posterior once")
167
  torch.cuda.empty_cache()
168
+ h_inpainter.fit(lambda_ = lambda_, cond=c, shape = (batch_size, *y_encoded.shape[1:]),
169
  quantize_denoised=False, mask_pixel = mask_tensor, y =y,
170
  log_every_t=25, iterations = inpaint_config[posterior]['iterations'],
171
  unconditional_guidance_scale= inpaint_config[posterior]["unconditional_guidance_scale"] ,
 
173
  kl_weight_2 = inpaint_config[posterior]["beta_2"],
174
  debug=True, wdb = False,
175
  dir_name = img_path,
176
+ batch_size = batch_size,
177
  lr_init_gamma = inpaint_config[posterior]["lr_init_gamma"],
178
  recon_weight = inpaint_config[posterior]["recon"],
179
  )
 
186
  h_inpainter.sample(inpaint_config["sampling"]["scale"], inpaint_config[posterior]["eta"],
187
  mu.cuda(), logvar.cuda(), gamma.cuda(), mask_tensor, y,
188
  n_samples=inpaint_config["sampling"]["n_samples"],
189
+ batch_size = batch_size, dir_name= img_path, cond=c,
190
  unconditional_conditioning=uc,
191
  unconditional_guidance_scale=inpaint_config["sampling"]["unconditional_guidance_scale"],
192
  samples_iteration=inpaint_config[posterior]["iterations"])