"""Run single-image (monocular) depth prediction over an evaluation dataset.

For every image listed by the selected dataset's ``process_func`` the model is
run on that image alone, and the predicted depth is written next to the other
results as ``<stem>depth.npy`` (raw values) and ``<stem>depth.png`` (8-bit
min-max normalised preview).
"""

import torch
import numpy as np
import cv2
import glob
import argparse
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
from scipy.optimize import minimize
import os
import sys

# Make the repository root importable before the project-local imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from collections import defaultdict
from eval.monodepth.metadata import dataset_metadata
from add_ckpt_path import add_path_to_dust3r


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``,
    so any non-empty value used to enable the flag. This converter accepts the
    usual spellings instead. The empty string stays falsy, which was the only
    way to pass ``False`` before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser for the monodepth evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--weights``, ``--device``,
        ``--output_dir``, ``--no_crop``, ``--full_seq``, ``--seq_list`` and
        ``--eval_dataset`` (restricted to the keys of ``dataset_metadata``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weights", type=str, help="path to the model weights", default=""
    )
    parser.add_argument("--device", type=str, default="cuda", help="pytorch device")
    parser.add_argument("--output_dir", type=str, default="", help="value for outdir")
    # NOTE: the flag is negated where it is consumed (``crop=not args.no_crop``),
    # so the default of True means "do not crop".
    parser.add_argument(
        "--no_crop", type=_str2bool, default=True, help="whether to crop input data"
    )
    # NOTE: the script entry point overwrites this value based on the dataset.
    parser.add_argument(
        "--full_seq", type=_str2bool, default=False, help="whether to use all seqs"
    )
    parser.add_argument("--seq_list", default=None)
    parser.add_argument(
        "--eval_dataset", type=str, default="nyu", choices=list(dataset_metadata.keys())
    )
    return parser


def eval_mono_depth_estimation(args, model, device):
    """Run depth prediction for every file list of the selected dataset.

    The dataset entry in ``dataset_metadata`` supplies ``img_path`` (or an
    ``img_path_func(args)`` that overrides it) and a ``process_func(args,
    img_path)`` that yields ``(filelist, save_dir)`` pairs. Each ``save_dir``
    is created if needed before its files are processed.

    Raises:
        ValueError: if the dataset is unknown or has no ``process_func``.
    """
    metadata = dataset_metadata.get(args.eval_dataset)
    if metadata is None:
        raise ValueError(f"Unknown dataset: {args.eval_dataset}")

    img_path = metadata.get("img_path")
    if "img_path_func" in metadata:
        img_path = metadata["img_path_func"](args)

    process_func = metadata.get("process_func")
    if process_func is None:
        raise ValueError(
            f"No processing function defined for dataset: {args.eval_dataset}"
        )

    for filelist, save_dir in process_func(args, img_path):
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        eval_mono_depth(args, model, device, filelist, save_dir=save_dir)


def _build_view(index, image):
    """Wrap one loaded image into the view dict passed to ``inference``."""
    img = image["img"]
    return {
        "img": img,
        # No ray-map input is supplied: fill with NaN and disable via ray_mask.
        "ray_map": torch.full(
            (img.shape[0], 6, img.shape[-2], img.shape[-1]),
            torch.nan,
        ),
        "true_shape": torch.from_numpy(image["true_shape"]),
        "idx": index,
        "instance": str(index),
        "camera_pose": torch.from_numpy(np.eye(4).astype(np.float32)).unsqueeze(0),
        "img_mask": torch.tensor(True).unsqueeze(0),
        "ray_mask": torch.tensor(False).unsqueeze(0),
        "update": torch.tensor(True).unsqueeze(0),
        "reset": torch.tensor(False).unsqueeze(0),
    }


def _save_depth(depth_map, file, save_dir):
    """Write ``depth_map`` as ``<stem>depth.npy`` and ``<stem>depth.png``."""
    # Derive the name from the stem rather than replacing ".png", so inputs
    # with any other extension still get distinct ".npy" / ".png" outputs.
    stem = os.path.splitext(os.path.basename(file))[0]
    np.save(os.path.join(save_dir, f"{stem}depth.npy"), depth_map.cpu().numpy())

    # 8-bit preview: min-max normalise, guarding against a zero (or NaN) range
    # that would otherwise produce NaNs before the uint8 cast.
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        normalized = (depth_map - depth_min) / depth_range
    else:
        normalized = torch.zeros_like(depth_map)
    preview = (normalized * 255).cpu().numpy().astype(np.uint8)
    cv2.imwrite(os.path.join(save_dir, f"{stem}depth.png"), preview)


def eval_mono_depth(args, model, device, filelist, save_dir=None):
    """Predict a depth map for each file in ``filelist`` and optionally save it.

    ``load_images`` and ``inference`` are module globals bound by the script
    entry point at the bottom of this file; they are not available if this
    module is merely imported.

    Args:
        args: parsed arguments; only ``args.no_crop`` is read here.
        model: network passed to ``inference``; switched to eval mode.
        device: device forwarded to ``inference``.
        filelist: iterable of image paths, each processed on its own.
        save_dir: directory for the outputs, or ``None`` to skip saving.
    """
    model.eval()
    load_img_size = 512
    for file in tqdm(filelist):
        # The loader takes a list of paths; each image is run by itself.
        images = load_images(
            [file], size=load_img_size, verbose=False, crop=not args.no_crop
        )
        views = [_build_view(i, image) for i, image in enumerate(images)]

        outputs, _ = inference(views, model, device)
        pts3d_self = outputs["pred"][0]["pts3d_in_self_view"].cpu()
        # Last channel is taken as depth and averaged over the leading dim
        # (presumably the batch dimension -- TODO confirm).
        depth_map = pts3d_self[..., -1].mean(dim=0)

        if save_dir is not None:
            _save_depth(depth_map, file, save_dir)


if __name__ == "__main__":
    args = get_args_parser()
    args = args.parse_args()

    # Only sintel is evaluated on full sequences, regardless of --full_seq.
    args.full_seq = args.eval_dataset == "sintel"

    # The dust3r package only becomes importable after this call, so the
    # imports below must stay here. They must also stay at module scope
    # because eval_mono_depth reads load_images / inference as globals.
    add_path_to_dust3r(args.weights)
    from dust3r.utils.image import load_images_for_eval as load_images
    from dust3r.inference import inference
    from dust3r.model import ARCroco3DStereo

    model = ARCroco3DStereo.from_pretrained(args.weights).to(args.device)
    eval_mono_depth_estimation(args, model, args.device)