import os import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) import math import cv2 import numpy as np import torch import argparse from copy import deepcopy from eval.relpose.metadata import dataset_metadata from eval.relpose.utils import * from accelerate import PartialState from add_ckpt_path import add_path_to_dust3r from tqdm import tqdm def get_args_parser(): parser = argparse.ArgumentParser() parser.add_argument( "--weights", type=str, help="path to the model weights", default="", ) parser.add_argument("--device", type=str, default="cuda", help="pytorch device") parser.add_argument( "--output_dir", type=str, default="", help="value for outdir", ) parser.add_argument( "--no_crop", type=bool, default=True, help="whether to crop input data" ) parser.add_argument( "--eval_dataset", type=str, default="sintel", choices=list(dataset_metadata.keys()), ) parser.add_argument("--size", type=int, default="224") parser.add_argument( "--pose_eval_stride", default=1, type=int, help="stride for pose evaluation" ) parser.add_argument("--shuffle", action="store_true", default=False) parser.add_argument( "--full_seq", action="store_true", default=False, help="use full sequence for pose evaluation", ) parser.add_argument( "--seq_list", nargs="+", default=None, help="list of sequences for pose evaluation", ) parser.add_argument("--revisit", type=int, default=1) parser.add_argument("--freeze_state", action="store_true", default=False) parser.add_argument("--solve_pose", action="store_true", default=False) return parser def eval_pose_estimation(args, model, save_dir=None): metadata = dataset_metadata.get(args.eval_dataset) img_path = metadata["img_path"] mask_path = metadata["mask_path"] ate_mean, rpe_trans_mean, rpe_rot_mean = eval_pose_estimation_dist( args, model, save_dir=save_dir, img_path=img_path, mask_path=mask_path ) return ate_mean, rpe_trans_mean, rpe_rot_mean def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None): from dust3r.inference import inference metadata = dataset_metadata.get(args.eval_dataset) anno_path = metadata.get("anno_path", None) seq_list = args.seq_list if seq_list is None: if metadata.get("full_seq", False): args.full_seq = True else: seq_list = metadata.get("seq_list", []) if args.full_seq: seq_list = os.listdir(img_path) seq_list = [ seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq)) ] seq_list = sorted(seq_list) if save_dir is None: save_dir = args.output_dir distributed_state = PartialState() model.to(distributed_state.device) device = distributed_state.device with distributed_state.split_between_processes(seq_list) as seqs: ate_list = [] rpe_trans_list = [] rpe_rot_list = [] load_img_size = args.size error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt" # Unique log file per process bug = False for seq in tqdm(seqs): try: dir_path = metadata["dir_path_func"](img_path, seq) # Handle skip_condition skip_condition = metadata.get("skip_condition", None) if skip_condition is not None and skip_condition(save_dir, seq): continue mask_path_seq_func = metadata.get( "mask_path_seq_func", lambda mask_path, seq: None ) mask_path_seq = mask_path_seq_func(mask_path, seq) filelist = [ os.path.join(dir_path, name) for name in os.listdir(dir_path) ] filelist.sort() filelist = filelist[:: args.pose_eval_stride] views = prepare_input( filelist, [True for _ in filelist], size=load_img_size, crop=not args.no_crop, revisit=args.revisit, update=not args.freeze_state, ) outputs, _ = inference(views, model, device) ( colors, pts3ds_self, pts3ds_other, conf_self, conf_other, cam_dict, pr_poses, ) = prepare_output( outputs, revisit=args.revisit, solve_pose=args.solve_pose ) pred_traj = get_tum_poses(pr_poses) os.makedirs(f"{save_dir}/{seq}", exist_ok=True) save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt") save_focals(cam_dict, f"{save_dir}/{seq}/pred_focal.txt") save_intrinsics(cam_dict, f"{save_dir}/{seq}/pred_intrinsics.txt") # save_depth_maps(pts3ds_self,f'{save_dir}/{seq}', conf_self=conf_self) # save_conf_maps(conf_self,f'{save_dir}/{seq}') # save_rgb_imgs(colors,f'{save_dir}/{seq}') gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, seq) traj_format = metadata.get("traj_format", None) if args.eval_dataset == "sintel": gt_traj = load_traj( gt_traj_file=gt_traj_file, stride=args.pose_eval_stride ) elif traj_format is not None: gt_traj = load_traj( gt_traj_file=gt_traj_file, traj_format=traj_format, stride=args.pose_eval_stride, ) else: gt_traj = None if gt_traj is not None: ate, rpe_trans, rpe_rot = eval_metrics( pred_traj, gt_traj, seq=seq, filename=f"{save_dir}/{seq}_eval_metric.txt", ) plot_trajectory( pred_traj, gt_traj, title=seq, filename=f"{save_dir}/{seq}.png" ) else: ate, rpe_trans, rpe_rot = 0, 0, 0 bug = True ate_list.append(ate) rpe_trans_list.append(rpe_trans) rpe_rot_list.append(rpe_rot) # Write to error log after each sequence with open(error_log_path, "a") as f: f.write( f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n" ) f.write(f"{ate:.5f}\n") f.write(f"{rpe_trans:.5f}\n") f.write(f"{rpe_rot:.5f}\n") except Exception as e: if "out of memory" in str(e): # Handle OOM torch.cuda.empty_cache() # Clear the CUDA memory with open(error_log_path, "a") as f: f.write( f"OOM error in sequence {seq}, skipping this sequence.\n" ) print(f"OOM error in sequence {seq}, skipping...") elif "Degenerate covariance rank" in str( e ) or "Eigenvalues did not converge" in str(e): # Handle Degenerate covariance rank exception and Eigenvalues did not converge exception with open(error_log_path, "a") as f: f.write(f"Exception in sequence {seq}: {str(e)}\n") print(f"Traj evaluation error in sequence {seq}, skipping.") else: raise e # Rethrow if it's not an expected exception distributed_state.wait_for_everyone() results = process_directory(save_dir) avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results) # Write the averages to the error log (only on the main process) if distributed_state.is_main_process: with open(f"{save_dir}/_error_log.txt", "a") as f: # Copy the error log from each process to the main error log for i in range(distributed_state.num_processes): if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"): break with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub: f.write(f_sub.read()) f.write( f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n" ) return avg_ate, avg_rpe_trans, avg_rpe_rot if __name__ == "__main__": args = get_args_parser() args = args.parse_args() add_path_to_dust3r(args.weights) from dust3r.utils.image import load_images_for_eval as load_images from dust3r.post_process import estimate_focal_knowing_depth from dust3r.model import ARCroco3DStereo from dust3r.utils.camera import pose_encoding_to_camera from dust3r.utils.geometry import weighted_procrustes, geotrf args.full_seq = False args.no_crop = False def recover_cam_params(pts3ds_self, pts3ds_other, conf_self, conf_other): B, H, W, _ = pts3ds_self.shape pp = ( torch.tensor([W // 2, H // 2], device=pts3ds_self.device) .float() .repeat(B, 1) .reshape(B, 1, 2) ) focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld") pts3ds_self = pts3ds_self.reshape(B, -1, 3) pts3ds_other = pts3ds_other.reshape(B, -1, 3) conf_self = conf_self.reshape(B, -1) conf_other = conf_other.reshape(B, -1) # weighted procrustes c2w = weighted_procrustes( pts3ds_self, pts3ds_other, torch.log(conf_self) * torch.log(conf_other), use_weights=True, return_T=True, ) return c2w, focal, pp.reshape(B, 2) def prepare_input( img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True, crop=True, ): images = load_images(img_paths, size=size, crop=crop) views = [] if raymaps is None and raymap_mask is None: num_views = len(images) for i in range(num_views): view = { "img": images[i]["img"], "ray_map": torch.full( ( images[i]["img"].shape[0], 6, images[i]["img"].shape[-2], images[i]["img"].shape[-1], ), torch.nan, ), "true_shape": torch.from_numpy(images[i]["true_shape"]), "idx": i, "instance": str(i), "camera_pose": torch.from_numpy( np.eye(4).astype(np.float32) ).unsqueeze(0), "img_mask": torch.tensor(True).unsqueeze(0), "ray_mask": torch.tensor(False).unsqueeze(0), "update": torch.tensor(True).unsqueeze(0), "reset": torch.tensor(False).unsqueeze(0), } views.append(view) else: num_views = len(images) + len(raymaps) assert len(img_mask) == len(raymap_mask) == num_views assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps) j = 0 k = 0 for i in range(num_views): view = { "img": ( images[j]["img"] if img_mask[i] else torch.full_like(images[0]["img"], torch.nan) ), "ray_map": ( raymaps[k] if raymap_mask[i] else torch.full_like(raymaps[0], torch.nan) ), "true_shape": ( torch.from_numpy(images[j]["true_shape"]) if img_mask[i] else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]])) ), "idx": i, "instance": str(i), "camera_pose": torch.from_numpy( np.eye(4).astype(np.float32) ).unsqueeze(0), "img_mask": torch.tensor(img_mask[i]).unsqueeze(0), "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0), "update": torch.tensor(img_mask[i]).unsqueeze(0), "reset": torch.tensor(False).unsqueeze(0), } if img_mask[i]: j += 1 if raymap_mask[i]: k += 1 views.append(view) assert j == len(images) and k == len(raymaps) if revisit > 1: # repeat input for 'revisit' times new_views = [] for r in range(revisit): for i in range(len(views)): new_view = deepcopy(views[i]) new_view["idx"] = r * len(views) + i new_view["instance"] = str(r * len(views) + i) if r > 0: if not update: new_view["update"] = torch.tensor(False).unsqueeze(0) new_views.append(new_view) return new_views return views def prepare_output(outputs, revisit=1, solve_pose=False): valid_length = len(outputs["pred"]) // revisit outputs["pred"] = outputs["pred"][-valid_length:] outputs["views"] = outputs["views"][-valid_length:] if solve_pose: pts3ds_self = [ output["pts3d_in_self_view"].cpu() for output in outputs["pred"] ] pts3ds_other = [ output["pts3d_in_other_view"].cpu() for output in outputs["pred"] ] conf_self = [output["conf_self"].cpu() for output in outputs["pred"]] conf_other = [output["conf"].cpu() for output in outputs["pred"]] pr_poses, focal, pp = recover_cam_params( torch.cat(pts3ds_self, 0), torch.cat(pts3ds_other, 0), torch.cat(conf_self, 0), torch.cat(conf_other, 0), ) pts3ds_self = torch.cat(pts3ds_self, 0) else: pts3ds_self = [ output["pts3d_in_self_view"].cpu() for output in outputs["pred"] ] pts3ds_other = [ output["pts3d_in_other_view"].cpu() for output in outputs["pred"] ] conf_self = [output["conf_self"].cpu() for output in outputs["pred"]] conf_other = [output["conf"].cpu() for output in outputs["pred"]] pts3ds_self = torch.cat(pts3ds_self, 0) pr_poses = [ pose_encoding_to_camera(pred["camera_pose"].clone()).cpu() for pred in outputs["pred"] ] pr_poses = torch.cat(pr_poses, 0) B, H, W, _ = pts3ds_self.shape pp = ( torch.tensor([W // 2, H // 2], device=pts3ds_self.device) .float() .repeat(B, 1) .reshape(B, 2) ) focal = estimate_focal_knowing_depth( pts3ds_self, pp, focal_mode="weiszfeld" ) colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]] cam_dict = { "focal": focal.cpu().numpy(), "pp": pp.cpu().numpy(), } return ( colors, pts3ds_self, pts3ds_other, conf_self, conf_other, cam_dict, pr_poses, ) model = ARCroco3DStereo.from_pretrained(args.weights) eval_pose_estimation(args, model, save_dir=args.output_dir)