Spaces:
Build error
Build error
Nadine Rueegg
commited on
Commit
·
3cd5df9
1
Parent(s):
0d2dc80
don't save images
Browse files- scripts/gradio_demo.py +1 -23
scripts/gradio_demo.py
CHANGED
|
@@ -244,7 +244,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
| 244 |
with open(loss_weight_path, 'r') as j:
|
| 245 |
losses = json.loads(j.read())
|
| 246 |
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
|
| 247 |
-
print(losses)
|
| 248 |
|
| 249 |
# prepare dataset and dataset loader
|
| 250 |
val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
|
|
@@ -258,7 +257,7 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
| 258 |
# prepare progress bar
|
| 259 |
iterable = enumerate(val_loader) # the length of this iterator should be 1
|
| 260 |
progress = None
|
| 261 |
-
if
|
| 262 |
progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
|
| 263 |
iterable = progress
|
| 264 |
ind_img_tot = 0
|
|
@@ -289,29 +288,9 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
| 289 |
ind_img = 0
|
| 290 |
name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
|
| 291 |
|
| 292 |
-
print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
|
| 293 |
ind_img_tot += 1
|
| 294 |
batch_size = 1
|
| 295 |
|
| 296 |
-
# save initial visualizations
|
| 297 |
-
# save the image with keypoints as predicted by the stacked hourglass
|
| 298 |
-
pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1)
|
| 299 |
-
inp_img = input[ind_img, :, :, :].detach().clone()
|
| 300 |
-
out_path = root_out_path + name + '_hg_key.png'
|
| 301 |
-
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
| 302 |
-
# save the input image
|
| 303 |
-
img_inp = input[ind_img, :, :, :].clone()
|
| 304 |
-
for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m) # inverse to transforms.color_normalize()
|
| 305 |
-
img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0)
|
| 306 |
-
img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB')
|
| 307 |
-
img_init.save(root_out_path_details + name + '_img_ainit.png')
|
| 308 |
-
# save ground truth silhouette (for visualization only, it is not used during the optimization)
|
| 309 |
-
target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
| 310 |
-
target_img_silh.save(root_out_path_details + name + '_target_silh.png')
|
| 311 |
-
# save the silhouette as predicted by the stacked hourglass
|
| 312 |
-
hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
| 313 |
-
hg_img_silh.save(root_out_path + name + '_hg_silh.png')
|
| 314 |
-
|
| 315 |
# initialize the variables over which we want to optimize
|
| 316 |
optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
|
| 317 |
optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
|
|
@@ -386,7 +365,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
| 386 |
target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
|
| 387 |
target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
|
| 388 |
# find out if ground contact constraints should be used for the image at hand
|
| 389 |
-
# print('is flat: ' + str(res['isflat_prep'][ind_img]))
|
| 390 |
if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
|
| 391 |
isflat = [True]
|
| 392 |
else:
|
|
|
|
| 244 |
with open(loss_weight_path, 'r') as j:
|
| 245 |
losses = json.loads(j.read())
|
| 246 |
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
|
|
|
|
| 247 |
|
| 248 |
# prepare dataset and dataset loader
|
| 249 |
val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
|
|
|
|
| 257 |
# prepare progress bar
|
| 258 |
iterable = enumerate(val_loader) # the length of this iterator should be 1
|
| 259 |
progress = None
|
| 260 |
+
if False: # not quiet:
|
| 261 |
progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
|
| 262 |
iterable = progress
|
| 263 |
ind_img_tot = 0
|
|
|
|
| 288 |
ind_img = 0
|
| 289 |
name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
|
| 290 |
|
|
|
|
| 291 |
ind_img_tot += 1
|
| 292 |
batch_size = 1
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
# initialize the variables over which we want to optimize
|
| 295 |
optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
|
| 296 |
optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
|
|
|
|
| 365 |
target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
|
| 366 |
target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
|
| 367 |
# find out if ground contact constraints should be used for the image at hand
|
|
|
|
| 368 |
if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
|
| 369 |
isflat = [True]
|
| 370 |
else:
|