Spaces:
Runtime error
Runtime error
Nadine Rueegg
commited on
Commit
·
3cd5df9
1
Parent(s):
0d2dc80
don't save images
Browse files- scripts/gradio_demo.py +1 -23
scripts/gradio_demo.py
CHANGED
@@ -244,7 +244,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
244 |
with open(loss_weight_path, 'r') as j:
|
245 |
losses = json.loads(j.read())
|
246 |
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
|
247 |
-
print(losses)
|
248 |
|
249 |
# prepare dataset and dataset loader
|
250 |
val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
|
@@ -258,7 +257,7 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
258 |
# prepare progress bar
|
259 |
iterable = enumerate(val_loader) # the length of this iterator should be 1
|
260 |
progress = None
|
261 |
-
if
|
262 |
progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
|
263 |
iterable = progress
|
264 |
ind_img_tot = 0
|
@@ -289,29 +288,9 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
289 |
ind_img = 0
|
290 |
name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
|
291 |
|
292 |
-
print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
|
293 |
ind_img_tot += 1
|
294 |
batch_size = 1
|
295 |
|
296 |
-
# save initial visualizations
|
297 |
-
# save the image with keypoints as predicted by the stacked hourglass
|
298 |
-
pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1)
|
299 |
-
inp_img = input[ind_img, :, :, :].detach().clone()
|
300 |
-
out_path = root_out_path + name + '_hg_key.png'
|
301 |
-
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
302 |
-
# save the input image
|
303 |
-
img_inp = input[ind_img, :, :, :].clone()
|
304 |
-
for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m) # inverse to transforms.color_normalize()
|
305 |
-
img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0)
|
306 |
-
img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB')
|
307 |
-
img_init.save(root_out_path_details + name + '_img_ainit.png')
|
308 |
-
# save ground truth silhouette (for visualization only, it is not used during the optimization)
|
309 |
-
target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
310 |
-
target_img_silh.save(root_out_path_details + name + '_target_silh.png')
|
311 |
-
# save the silhouette as predicted by the stacked hourglass
|
312 |
-
hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
313 |
-
hg_img_silh.save(root_out_path + name + '_hg_silh.png')
|
314 |
-
|
315 |
# initialize the variables over which we want to optimize
|
316 |
optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
|
317 |
optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
|
@@ -386,7 +365,6 @@ def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
|
|
386 |
target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
|
387 |
target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
|
388 |
# find out if ground contact constraints should be used for the image at hand
|
389 |
-
# print('is flat: ' + str(res['isflat_prep'][ind_img]))
|
390 |
if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
|
391 |
isflat = [True]
|
392 |
else:
|
|
|
244 |
with open(loss_weight_path, 'r') as j:
|
245 |
losses = json.loads(j.read())
|
246 |
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
|
|
|
247 |
|
248 |
# prepare dataset and dataset loader
|
249 |
val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
|
|
|
257 |
# prepare progress bar
|
258 |
iterable = enumerate(val_loader) # the length of this iterator should be 1
|
259 |
progress = None
|
260 |
+
if False: # not quiet:
|
261 |
progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
|
262 |
iterable = progress
|
263 |
ind_img_tot = 0
|
|
|
288 |
ind_img = 0
|
289 |
name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
|
290 |
|
|
|
291 |
ind_img_tot += 1
|
292 |
batch_size = 1
|
293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
# initialize the variables over which we want to optimize
|
295 |
optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
|
296 |
optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
|
|
|
365 |
target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
|
366 |
target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
|
367 |
# find out if ground contact constraints should be used for the image at hand
|
|
|
368 |
if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
|
369 |
isflat = [True]
|
370 |
else:
|