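"""Relative camera-pose evaluation script.

Runs an ARCroco3DStereo checkpoint over the sequences of a chosen dataset
(default: Sintel), saves the predicted trajectories, focal lengths, and
intrinsics per sequence, and reports ATE / RPE-translation / RPE-rotation
against the ground-truth trajectories. Multi-process execution is handled
through `accelerate.PartialState`.
"""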
import os
import sys
# Make the repository root importable when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import numpy as np
import torch
import argparse
from copy import deepcopy
from eval.relpose.metadata import dataset_metadata
from eval.relpose.utils import *
from accelerate import PartialState
from add_ckpt_path import add_path_to_dust3r
from tqdm import tqdm
def get_args_parser():
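    """Build the command-line interface for pose evaluation."""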
parser = argparse.ArgumentParser()
parser.add_argument(
"--weights",
type=str,
help="path to the model weights",
default="",
)
parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
parser.add_argument(
"--output_dir",
type=str,
default="",
help="value for outdir",
)
    # NOTE: argparse's `type=bool` treats any non-empty string as True, so this
    # option is exposed as a flag instead (it is overridden in __main__ anyway).
    parser.add_argument(
        "--no_crop",
        action="store_true",
        help="skip cropping the input images",
    )
parser.add_argument(
"--eval_dataset",
type=str,
default="sintel",
choices=list(dataset_metadata.keys()),
)
parser.add_argument("--size", type=int, default="224")
parser.add_argument(
"--pose_eval_stride", default=1, type=int, help="stride for pose evaluation"
)
parser.add_argument("--shuffle", action="store_true", default=False)
parser.add_argument(
"--full_seq",
action="store_true",
default=False,
help="use full sequence for pose evaluation",
)
parser.add_argument(
"--seq_list",
nargs="+",
default=None,
help="list of sequences for pose evaluation",
)
parser.add_argument("--revisit", type=int, default=1)
parser.add_argument("--freeze_state", action="store_true", default=False)
parser.add_argument("--solve_pose", action="store_true", default=False)
return parser
def eval_pose_estimation(args, model, save_dir=None):
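    """Resolve dataset paths from `dataset_metadata` and run the distributed evaluation."""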
metadata = dataset_metadata.get(args.eval_dataset)
img_path = metadata["img_path"]
mask_path = metadata["mask_path"]
ate_mean, rpe_trans_mean, rpe_rot_mean = eval_pose_estimation_dist(
args, model, save_dir=save_dir, img_path=img_path, mask_path=mask_path
)
return ate_mean, rpe_trans_mean, rpe_rot_mean
def eval_pose_estimation_dist(args, model, img_path, save_dir=None, mask_path=None):
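    """Evaluate pose estimation with the sequence list sharded across processes.

    Each process runs inference on its share of the sequences, writes the
    predicted trajectory, focals, and intrinsics under `save_dir/<seq>`, and
    appends per-sequence ATE/RPE metrics to a process-local log. The main
    process then merges the per-process logs and writes the averaged metrics.
    """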
from dust3r.inference import inference
metadata = dataset_metadata.get(args.eval_dataset)
anno_path = metadata.get("anno_path", None)
seq_list = args.seq_list
if seq_list is None:
if metadata.get("full_seq", False):
args.full_seq = True
else:
seq_list = metadata.get("seq_list", [])
if args.full_seq:
seq_list = os.listdir(img_path)
seq_list = [
seq for seq in seq_list if os.path.isdir(os.path.join(img_path, seq))
]
seq_list = sorted(seq_list)
if save_dir is None:
save_dir = args.output_dir
distributed_state = PartialState()
model.to(distributed_state.device)
device = distributed_state.device
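    # Shard the sequence list across processes; each process evaluates its own subset.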
with distributed_state.split_between_processes(seq_list) as seqs:
ate_list = []
rpe_trans_list = []
rpe_rot_list = []
load_img_size = args.size
error_log_path = f"{save_dir}/_error_log_{distributed_state.process_index}.txt" # Unique log file per process
for seq in tqdm(seqs):
try:
dir_path = metadata["dir_path_func"](img_path, seq)
# Handle skip_condition
skip_condition = metadata.get("skip_condition", None)
if skip_condition is not None and skip_condition(save_dir, seq):
continue
mask_path_seq_func = metadata.get(
"mask_path_seq_func", lambda mask_path, seq: None
)
mask_path_seq = mask_path_seq_func(mask_path, seq)
filelist = [
os.path.join(dir_path, name) for name in os.listdir(dir_path)
]
filelist.sort()
filelist = filelist[:: args.pose_eval_stride]
views = prepare_input(
filelist,
[True for _ in filelist],
size=load_img_size,
crop=not args.no_crop,
revisit=args.revisit,
update=not args.freeze_state,
)
outputs, _ = inference(views, model, device)
(
colors,
pts3ds_self,
pts3ds_other,
conf_self,
conf_other,
cam_dict,
pr_poses,
) = prepare_output(
outputs, revisit=args.revisit, solve_pose=args.solve_pose
)
pred_traj = get_tum_poses(pr_poses)
os.makedirs(f"{save_dir}/{seq}", exist_ok=True)
save_tum_poses(pr_poses, f"{save_dir}/{seq}/pred_traj.txt")
save_focals(cam_dict, f"{save_dir}/{seq}/pred_focal.txt")
save_intrinsics(cam_dict, f"{save_dir}/{seq}/pred_intrinsics.txt")
                # Optional extra outputs (disabled by default):
                # save_depth_maps(pts3ds_self, f"{save_dir}/{seq}", conf_self=conf_self)
                # save_conf_maps(conf_self, f"{save_dir}/{seq}")
                # save_rgb_imgs(colors, f"{save_dir}/{seq}")
gt_traj_file = metadata["gt_traj_func"](img_path, anno_path, seq)
traj_format = metadata.get("traj_format", None)
if args.eval_dataset == "sintel":
gt_traj = load_traj(
gt_traj_file=gt_traj_file, stride=args.pose_eval_stride
)
elif traj_format is not None:
gt_traj = load_traj(
gt_traj_file=gt_traj_file,
traj_format=traj_format,
stride=args.pose_eval_stride,
)
else:
gt_traj = None
if gt_traj is not None:
ate, rpe_trans, rpe_rot = eval_metrics(
pred_traj,
gt_traj,
seq=seq,
filename=f"{save_dir}/{seq}_eval_metric.txt",
)
plot_trajectory(
pred_traj, gt_traj, title=seq, filename=f"{save_dir}/{seq}.png"
)
                else:
                    # No ground-truth trajectory available; record zero metrics.
                    ate, rpe_trans, rpe_rot = 0, 0, 0
ate_list.append(ate)
rpe_trans_list.append(rpe_trans)
rpe_rot_list.append(rpe_rot)
# Write to error log after each sequence
with open(error_log_path, "a") as f:
f.write(
f"{args.eval_dataset}-{seq: <16} | ATE: {ate:.5f}, RPE trans: {rpe_trans:.5f}, RPE rot: {rpe_rot:.5f}\n"
)
f.write(f"{ate:.5f}\n")
f.write(f"{rpe_trans:.5f}\n")
f.write(f"{rpe_rot:.5f}\n")
except Exception as e:
if "out of memory" in str(e):
# Handle OOM
torch.cuda.empty_cache() # Clear the CUDA memory
with open(error_log_path, "a") as f:
f.write(
f"OOM error in sequence {seq}, skipping this sequence.\n"
)
print(f"OOM error in sequence {seq}, skipping...")
elif "Degenerate covariance rank" in str(
e
) or "Eigenvalues did not converge" in str(e):
# Handle Degenerate covariance rank exception and Eigenvalues did not converge exception
with open(error_log_path, "a") as f:
f.write(f"Exception in sequence {seq}: {str(e)}\n")
print(f"Traj evaluation error in sequence {seq}, skipping.")
else:
raise e # Rethrow if it's not an expected exception
distributed_state.wait_for_everyone()
results = process_directory(save_dir)
avg_ate, avg_rpe_trans, avg_rpe_rot = calculate_averages(results)
# Write the averages to the error log (only on the main process)
if distributed_state.is_main_process:
with open(f"{save_dir}/_error_log.txt", "a") as f:
# Copy the error log from each process to the main error log
            for i in range(distributed_state.num_processes):
                # A process may not have written a log (e.g. it received no sequences).
                if not os.path.exists(f"{save_dir}/_error_log_{i}.txt"):
                    continue
with open(f"{save_dir}/_error_log_{i}.txt", "r") as f_sub:
f.write(f_sub.read())
f.write(
f"Average ATE: {avg_ate:.5f}, Average RPE trans: {avg_rpe_trans:.5f}, Average RPE rot: {avg_rpe_rot:.5f}\n"
)
return avg_ate, avg_rpe_trans, avg_rpe_rot
if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()
add_path_to_dust3r(args.weights)
from dust3r.utils.image import load_images_for_eval as load_images
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.model import ARCroco3DStereo
from dust3r.utils.camera import pose_encoding_to_camera
from dust3r.utils.geometry import weighted_procrustes, geotrf
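    # Fixed settings for this evaluation: use the dataset's default sequence
    # list and keep input cropping enabled (no_crop=False).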
args.full_seq = False
args.no_crop = False
def recover_cam_params(pts3ds_self, pts3ds_other, conf_self, conf_other):
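        """Recover per-frame camera parameters from predicted point maps.

        The focal length is estimated from the self-view point maps with the
        Weiszfeld solver (principal point assumed at the image center), and
        the camera-to-world transform is solved by confidence-weighted
        Procrustes alignment between self-view and other-view point maps.
        """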
B, H, W, _ = pts3ds_self.shape
pp = (
torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
.float()
.repeat(B, 1)
.reshape(B, 1, 2)
)
focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")
pts3ds_self = pts3ds_self.reshape(B, -1, 3)
pts3ds_other = pts3ds_other.reshape(B, -1, 3)
conf_self = conf_self.reshape(B, -1)
conf_other = conf_other.reshape(B, -1)
        # Confidence-weighted Procrustes alignment of self-view to other-view points.
c2w = weighted_procrustes(
pts3ds_self,
pts3ds_other,
torch.log(conf_self) * torch.log(conf_other),
use_weights=True,
return_T=True,
)
return c2w, focal, pp.reshape(B, 2)
def prepare_input(
img_paths,
img_mask,
size,
raymaps=None,
raymap_mask=None,
revisit=1,
update=True,
crop=True,
):
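        """Pack images (and optional ray maps) into the list of view dicts
        consumed by `inference`; with `revisit > 1` the sequence is repeated.
        """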
images = load_images(img_paths, size=size, crop=crop)
views = []
if raymaps is None and raymap_mask is None:
num_views = len(images)
for i in range(num_views):
view = {
"img": images[i]["img"],
"ray_map": torch.full(
(
images[i]["img"].shape[0],
6,
images[i]["img"].shape[-2],
images[i]["img"].shape[-1],
),
torch.nan,
),
"true_shape": torch.from_numpy(images[i]["true_shape"]),
"idx": i,
"instance": str(i),
"camera_pose": torch.from_numpy(
np.eye(4).astype(np.float32)
).unsqueeze(0),
"img_mask": torch.tensor(True).unsqueeze(0),
"ray_mask": torch.tensor(False).unsqueeze(0),
"update": torch.tensor(True).unsqueeze(0),
"reset": torch.tensor(False).unsqueeze(0),
}
views.append(view)
else:
num_views = len(images) + len(raymaps)
assert len(img_mask) == len(raymap_mask) == num_views
assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)
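            # Interleave image and ray-map views according to the two masks,
            # filling the unused modality of each view with NaN placeholders.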
j = 0
k = 0
for i in range(num_views):
view = {
"img": (
images[j]["img"]
if img_mask[i]
else torch.full_like(images[0]["img"], torch.nan)
),
"ray_map": (
raymaps[k]
if raymap_mask[i]
else torch.full_like(raymaps[0], torch.nan)
),
"true_shape": (
torch.from_numpy(images[j]["true_shape"])
if img_mask[i]
else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
),
"idx": i,
"instance": str(i),
"camera_pose": torch.from_numpy(
np.eye(4).astype(np.float32)
).unsqueeze(0),
"img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
"ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
"update": torch.tensor(img_mask[i]).unsqueeze(0),
"reset": torch.tensor(False).unsqueeze(0),
}
if img_mask[i]:
j += 1
if raymap_mask[i]:
k += 1
views.append(view)
assert j == len(images) and k == len(raymaps)
if revisit > 1:
            # Repeat the input views `revisit` times; after the first pass,
            # state updates are disabled unless `update` is set.
new_views = []
for r in range(revisit):
for i in range(len(views)):
new_view = deepcopy(views[i])
new_view["idx"] = r * len(views) + i
new_view["instance"] = str(r * len(views) + i)
if r > 0:
if not update:
new_view["update"] = torch.tensor(False).unsqueeze(0)
new_views.append(new_view)
return new_views
return views
def prepare_output(outputs, revisit=1, solve_pose=False):
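        """Keep only the final revisit pass of the model outputs and convert
        them into colors, point maps, confidences, camera parameters, and
        per-frame poses.
        """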
valid_length = len(outputs["pred"]) // revisit
outputs["pred"] = outputs["pred"][-valid_length:]
outputs["views"] = outputs["views"][-valid_length:]
if solve_pose:
pts3ds_self = [
output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
]
pts3ds_other = [
output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
]
conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
conf_other = [output["conf"].cpu() for output in outputs["pred"]]
pr_poses, focal, pp = recover_cam_params(
torch.cat(pts3ds_self, 0),
torch.cat(pts3ds_other, 0),
torch.cat(conf_self, 0),
torch.cat(conf_other, 0),
)
pts3ds_self = torch.cat(pts3ds_self, 0)
else:
pts3ds_self = [
output["pts3d_in_self_view"].cpu() for output in outputs["pred"]
]
pts3ds_other = [
output["pts3d_in_other_view"].cpu() for output in outputs["pred"]
]
conf_self = [output["conf_self"].cpu() for output in outputs["pred"]]
conf_other = [output["conf"].cpu() for output in outputs["pred"]]
pts3ds_self = torch.cat(pts3ds_self, 0)
pr_poses = [
pose_encoding_to_camera(pred["camera_pose"].clone()).cpu()
for pred in outputs["pred"]
]
pr_poses = torch.cat(pr_poses, 0)
B, H, W, _ = pts3ds_self.shape
pp = (
torch.tensor([W // 2, H // 2], device=pts3ds_self.device)
.float()
.repeat(B, 1)
.reshape(B, 2)
)
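            # Estimate per-frame focal lengths from the self-view point maps
            # (Weiszfeld solver), assuming a principal point at the image center.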
focal = estimate_focal_knowing_depth(
pts3ds_self, pp, focal_mode="weiszfeld"
)
colors = [0.5 * (output["rgb"][0] + 1.0) for output in outputs["pred"]]
cam_dict = {
"focal": focal.cpu().numpy(),
"pp": pp.cpu().numpy(),
}
return (
colors,
pts3ds_self,
pts3ds_other,
conf_self,
conf_other,
cam_dict,
pr_poses,
)
model = ARCroco3DStereo.from_pretrained(args.weights)
eval_pose_estimation(args, model, save_dir=args.output_dir)
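
    # Example invocation (script and checkpoint paths are placeholders; the
    # flags match the parser above):
    #   accelerate launch eval/relpose/eval_pose.py \
    #       --weights checkpoints/model.pth --output_dir results/sintel \
    #       --eval_dataset sintel --size 224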