Nadine Rueegg committed on
Commit 3cd5df9 · 1 Parent(s): 0d2dc80

don't save images

Files changed (1)
  1. scripts/gradio_demo.py +1 -23
scripts/gradio_demo.py CHANGED
@@ -244,7 +244,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
     with open(loss_weight_path, 'r') as j:
         losses = json.loads(j.read())
     shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
-    print(losses)
 
     # prepare dataset and dataset loader
     val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
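The context around the dropped `print(losses)` loads the loss-weight configuration from JSON and archives a copy next to the run outputs. A minimal standalone sketch of that pattern (both paths below are placeholders, not the demo's real locations):

import json
import os
import shutil

loss_weight_path = 'configs/loss_weights.json'    # placeholder path
root_out_path_details = 'out/details/'            # placeholder output dir
os.makedirs(root_out_path_details, exist_ok=True)

# load the loss weights used during test-time optimization
with open(loss_weight_path, 'r') as j:
    losses = json.loads(j.read())
# archive the exact config used for this run next to its outputs
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))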
@@ -258,7 +257,7 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
     # prepare progress bar
     iterable = enumerate(val_loader)    # the length of this iterator should be 1
     progress = None
-    if True:    # not quiet:
+    if False:    # not quiet:
         progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
         iterable = progress
     ind_img_tot = 0
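This hunk silences the tqdm progress bar by hard-coding the condition to `False`; the `# not quiet:` comment suggests a `quiet` flag that was never wired up. A sketch of that wiring (the helper and its `quiet` parameter are hypothetical, not part of the repo):

from tqdm import tqdm

def make_iterable(loader, quiet=True):
    # Hypothetical helper: wrap the loader in a tqdm progress bar unless
    # quiet mode is requested, mirroring the condition this hunk disables.
    iterable = enumerate(loader)
    if not quiet:
        iterable = tqdm(iterable, desc='Train', total=len(loader), ascii=True, leave=False)
    return iterable

With that in place the demo would call make_iterable(val_loader, quiet=True), and the hard-coded `if False:` disappears.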
@@ -289,29 +288,9 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
         ind_img = 0
         name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
 
-        print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
         ind_img_tot += 1
         batch_size = 1
 
-        # save initial visualizations
-        # save the image with keypoints as predicted by the stacked hourglass
-        pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1)
-        inp_img = input[ind_img, :, :, :].detach().clone()
-        out_path = root_out_path + name + '_hg_key.png'
-        save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0)    # threshold=0.3
-        # save the input image
-        img_inp = input[ind_img, :, :, :].clone()
-        for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m)    # inverse to transforms.color_normalize()
-        img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0)
-        img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB')
-        img_init.save(root_out_path_details + name + '_img_ainit.png')
-        # save ground truth silhouette (for visualization only, it is not used during the optimization)
-        target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
-        target_img_silh.save(root_out_path_details + name + '_target_silh.png')
-        # save the silhouette as predicted by the stacked hourglass
-        hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
-        hg_img_silh.save(root_out_path + name + '_hg_silh.png')
-
         # initialize the variables over which we want to optimize
         optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
         optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)    # [1, 1, 6]
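The deleted block wrote four debug images per run: the keypoint overlay, the de-normalized input, the ground-truth silhouette, and the predicted silhouette. If they are ever needed again, one option is to keep them behind an opt-in helper rather than deleting them outright. The sketch below restores the three PIL-only outputs; the function name and `data_info` argument are hypothetical, while the tensor layouts follow the deleted code:

import numpy as np
from PIL import Image

def save_debug_images(input_batch, res, target_dict, data_info, root_out_path, name, ind_img=0):
    # Hypothetical opt-in replacement for the removed visualization block.
    # de-normalized input image: inverse of transforms.color_normalize(),
    # which (per the deleted code) only re-adds the per-channel mean
    img = input_batch[ind_img].detach().clone()
    for t, m, s in zip(img, data_info.rgb_mean, data_info.rgb_stddev):
        t.add_(m)
    img = img.cpu().numpy().transpose(1, 2, 0)
    Image.fromarray(np.uint8(255 * img)).convert('RGB').save(root_out_path + name + '_img_ainit.png')
    # ground-truth silhouette (visualization only, unused by the optimization)
    silh = target_dict['silh'][ind_img].detach().cpu().numpy()
    Image.fromarray(np.uint8(255 * silh)).convert('RGB').save(root_out_path + name + '_target_silh.png')
    # silhouette predicted by the stacked hourglass
    hg_silh = res['hg_silh_prep'][ind_img].detach().cpu().numpy()
    Image.fromarray(np.uint8(255 * hg_silh)).convert('RGB').save(root_out_path + name + '_hg_silh.png')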
@@ -386,7 +365,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
         target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
         target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
         # find out if ground contact constraints should be used for the image at hand
-        # print('is flat: ' + str(res['isflat_prep'][ind_img]))
         if res['isflat_prep'][ind_img] >= 0.5:    # threshold should probably be set higher
             isflat = [True]
         else:
 