# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
from typing import Literal, Optional, TypedDict

import torch
import tqdm
from jaxtyping import Float32

from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.device import collate_with_cat, to_cpu
from mini_dust3r.utils.geometry import depthmap_to_pts3d, geotrf
from mini_dust3r.utils.image import ImageDict
from mini_dust3r.utils.misc import invalid_to_nans


class Dust3rPred1(TypedDict):
    """Prediction head for view1: pts3d lives in view1's own camera frame."""

    pts3d: Float32[torch.Tensor, "b h w c"]
    conf: Float32[torch.Tensor, "b h w"]


class Dust3rPred2(TypedDict):
    """Prediction head for view2: points are expressed in view1's camera frame."""

    pts3d_in_other_view: Float32[torch.Tensor, "b h w c"]
    conf: Float32[torch.Tensor, "b h w"]


class Dust3rResult(TypedDict):
    """Raw two-view result: input views, both prediction heads, optional loss."""

    view1: ImageDict
    view2: ImageDict
    pred1: Dust3rPred1
    pred2: Dust3rPred2
    loss: Optional[int]  # None when no criterion is supplied


def _interleave_imgs(img1, img2):
    """Merge two view dicts, interleaving their entries pairwise along the batch."""
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(batch):
    """Double a (view1, view2) batch so every pair also appears swapped."""
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2
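
# A minimal sketch of the interleaving order (illustrative only; kept as a
# comment so importing this module stays side-effect free): batched tensors
# come out element-wise interleaved, i.e. [a0, b0, a1, b1].
#
#   a = {"img": torch.zeros(2, 3, 4, 4)}
#   b = {"img": torch.ones(2, 3, 4, 4)}
#   mixed = _interleave_imgs(a, b)
#   assert mixed["img"].shape == (4, 3, 4, 4)
#   assert (mixed["img"][0] == 0).all() and (mixed["img"][1] == 1).all()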


def loss_of_one_batch(
    batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None
):
    """Run `model` on one collated two-view batch, optionally evaluating `criterion`.

    Known tensor fields are moved to `device` first; if `ret` is given, only
    that entry of the result dict is returned.
    """
    view1, view2 = batch
    for view in batch:
        for name in (
            "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split()
        ):  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = (
                criterion(view1, view2, pred1, pred2) if criterion is not None else None
            )

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
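
# Illustrative call (hypothetical names: `my_pairs` stands in for whatever list
# of (ImageDict, ImageDict) tuples your loader produces). With criterion=None,
# as `inference` below uses it, the returned dict carries loss=None:
#
#   out = loss_of_one_batch(collate_with_cat(my_pairs[:2]), model, None, "cuda")
#   pred1 = out["pred1"]          # or pass ret="pred1" to get this directly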


def inference(
    pairs: list[tuple[ImageDict, ImageDict]],
    model: AsymmetricCroCo3DStereo,
    device: Literal["cpu", "cuda", "mps"],
    batch_size: int = 8,
    verbose: bool = True,
) -> Dust3rResult:
    """Run pairwise inference over `pairs` and collate the per-batch results.

    If the images do not all share one size, batching is disabled (batch_size=1).
    """
    if verbose:
        print(f">> Inference with model on {len(pairs)} image pairs")
    result = []

    # first, check if all images have the same size
    multiple_shapes = not check_if_same_size(pairs)
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        res: Dust3rResult = loss_of_one_batch(
            collate_with_cat(pairs[i : i + batch_size]), model, None, device
        )
        result.append(to_cpu(res))

    result = collate_with_cat(result, lists=multiple_shapes)
    return result
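
# Illustrative usage sketch: `load_model` and `make_pairs` below are
# hypothetical placeholders for however you build the model and the list of
# ImageDict pairs; only `inference` itself comes from this module.
#
#   model = load_model()                            # hypothetical
#   pairs = make_pairs(["a.jpg", "b.jpg"])          # hypothetical
#   out = inference(pairs, model, device="cuda", batch_size=8)
#   pts3d = out["pred1"]["pts3d"]                   # (B, H, W, 3) points
#   conf = out["pred1"]["conf"]                     # (B, H, W) confidences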


def check_if_same_size(pairs):
    """True when all first images share one shape and all second images another."""
    shapes1 = [img1["img"].shape[-2:] for img1, _ in pairs]
    shapes2 = [img2["img"].shape[-2:] for _, img2 in pairs]
    return all(shapes1[0] == s for s in shapes1) and all(
        shapes2[0] == s for s in shapes2
    )


def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract predicted 3D points from `pred`, optionally mapped through its pose."""
    if "depth" in pred and "pseudo_focal" in pred:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif "pts3d" in pred:
        # pts3d from my camera
        pts3d = pred["pts3d"]

    elif "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred["pts3d_in_other_view"]  # return!

    if use_pose:
        camera_pose = pred.get("camera_pose")
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d
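
# Branch behaviour in a nutshell (illustrative snippet, runnable with torch
# alone): a prediction carrying raw pts3d is returned unchanged, in its own
# camera frame, when use_pose=False.
#
#   pred = {"pts3d": torch.randn(1, 4, 4, 3)}
#   assert get_pred_pts3d(None, pred) is pred["pts3d"]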


def find_opt_scaling(
    gt_pts1,
    gt_pts2,
    pr_pts1,
    pr_pts2=None,
    fit_mode="weiszfeld_stop_grad",
    valid1=None,
    valid2=None,
):
    """Estimate one scalar per batch element aligning predicted points to GT.

    `fit_mode` selects the estimator: "avg" (least-squares closed form),
    "median" (median of per-point ratios), or "weiszfeld" (robust IRLS);
    a "_stop_grad" suffix detaches the result from the autograd graph.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointclouds, replacing invalid points with NaNs
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = (
        invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
    )
    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = (
        invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
    )
    all_gt = (
        torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1)
        if gt_pts2 is not None
        else nan_gt_pts1
    )
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith("avg"):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith("median"):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith("weiszfeld"):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for _ in range(10):
            # re-weight each point by the inverse of its residual distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f"bad {fit_mode=}")

    if fit_mode.endswith("stop_grad"):
        scaling = scaling.detach()

    return scaling.clip(min=1e-3)
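
# Background note: the closed-form init minimises sum_i ||pr_i - s * gt_i||^2
# over the scalar s, giving s* = sum_i <pr_i, gt_i> / sum_i <gt_i, gt_i>
# (the nanmean 1/N factors cancel). The Weiszfeld loop then re-weights each
# term by 1/||residual||, approximating the robust L1 fit. A quick numeric
# sanity check (illustrative only; valid masks omitted, as the defaults allow):
#
#   gt = torch.randn(1, 2, 8, 3)
#   pr = 2.5 * gt
#   s = find_opt_scaling(gt, None, pr, fit_mode="avg")
#   assert torch.allclose(s, torch.tensor([2.5]))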