from PIL import Image
from .camera.WarperPytorch import Warper
import numpy as np
from einops import rearrange, repeat
import torch
import torch.nn as nn
from .depth_anything_v2.dpt import DepthAnythingV2


def to_pil_image(x):
    # x: c h w in [-1, 1] -> PIL image in [0, 255]
    x_np = ((x + 1) / 2 * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return Image.fromarray(x_np)


def to_npy(x):
    # x: c h w in [-1, 1] -> h w c numpy array in [0, 255]
    return ((x + 1) / 2 * 255).permute(1, 2, 0).detach().cpu().numpy()


def unnormalize_intrinsic(x, size):
    # Convert normalized 3x3 intrinsics to pixel units: the first row
    # (fx, 0, cx) is scaled by the image width, the second row (0, fy, cy)
    # by the image height.
    h, w = size
    x_ = x.detach().clone()
    x_[:, 0:1] = x[:, 0:1].detach().clone() * w
    x_[:, 1:2] = x[:, 1:2].detach().clone() * h
    return x_


class DepthWarping_wrapper(nn.Module):
    """Estimates monocular depth for the context view with Depth Anything V2,
    then forward-warps the context frame to each target view (and back),
    producing forward/backward flows and warped depth maps."""

    def __init__(self, model_config, ckpt_path):
        super().__init__()
        # Alternative: the Hugging Face pipeline with
        # "depth-anything/Depth-Anything-V2-Small-hf" could be used instead
        # of loading the checkpoint directly.
        self.depth_model = DepthAnythingV2(**model_config)
        self.depth_model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        self.depth_model = self.depth_model.eval()
        self.warper = Warper()

    def get_input(self, batch):
        b, v = batch["target"]["intrinsics"].shape[:2]
        h, w = batch["context"]["image"].shape[-2:]

        # Map images from [0, 1] to [-1, 1] and replicate the single context
        # view once per target view.
        image = batch["context"]["image"] * 2 - 1
        image_ctxt = repeat(image, "b c h w -> (b v) c h w", v=v)

        c2w_ctxt = repeat(batch["context"]["extrinsics"], "b t h w -> (b v t) h w", v=v)
        # The target extrinsics are expressed relative to an identity
        # reference pose, so no inverse needs to be applied here.
        c2w_trgt = rearrange(batch["target"]["extrinsics"], "b t h w -> (b t) h w")

        intrinsics_ctxt = unnormalize_intrinsic(
            repeat(batch["context"]["intrinsics"], "b t h w -> (b v t) h w", v=v),
            size=(h, w))
        intrinsics_trgt = unnormalize_intrinsic(
            rearrange(batch["target"]["intrinsics"], "b t h w -> (b t) h w"),
            size=(h, w))

        # Monocular depth per context image: B 1 H W.
        depth_ctxt = torch.stack(
            [self.depth_model.infer_image(to_npy(x)) for x in image],
            dim=0).to(image.device).unsqueeze(1)

        return (image_ctxt, c2w_ctxt, c2w_trgt, intrinsics_ctxt,
                intrinsics_trgt, depth_ctxt, batch["variable_intrinsic"])

    def forward(self, batch):
        (image_ctxt, c2w_ctxt, c2w_trgt, intrinsics_ctxt, intrinsics_trgt,
         depth_ctxt, variable_intrinsic) = self.get_input(batch)

        # Warping is geometric; run it in full precision even under AMP.
        with torch.autocast(device_type="cuda", enabled=False):
            b, v = batch["target"]["intrinsics"].shape[:2]

            # Context -> target: yields the forward flow and the depth map as
            # seen from each target view.
            warped_trgt, mask_trgt, warped_depth_trgt, flow_f = self.warper.forward_warp(
                frame1=image_ctxt,
                mask1=None,
                depth1=repeat(depth_ctxt, "b c h w -> (b t) c h w", t=v),
                transformation1=c2w_ctxt,
                transformation2=c2w_trgt,
                intrinsic1=intrinsics_ctxt,
                intrinsic2=intrinsics_trgt if variable_intrinsic else None)

            # Target -> context: warp the result back to obtain the backward
            # flow. mask_trgt / mask_src mark valid (non-occluded) regions;
            # downstream code can pick the forward or backward mask as needed.
            warped_src, mask_src, warped_depth_src, flow_b = self.warper.forward_warp(
                frame1=warped_trgt,
                mask1=None,
                depth1=warped_depth_trgt,
                transformation1=c2w_trgt,
                transformation2=c2w_ctxt,
                intrinsic1=intrinsics_trgt,
                intrinsic2=None)

        return flow_f, flow_b, warped_trgt, depth_ctxt, warped_depth_trgt
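

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# The config dict and checkpoint path below are assumptions, not values from
# this repo: check depth_anything_v2.dpt for the real ViT-S configuration and
# point ckpt_path at a downloaded Depth-Anything-V2 checkpoint. The batch is
# synthetic and only demonstrates the tensor shapes the wrapper expects; a
# CUDA-capable environment is assumed for the depth model.
if __name__ == "__main__":
    model_config = {  # assumed ViT-S settings; verify against the repo
        "encoder": "vits",
        "features": 64,
        "out_channels": [48, 96, 192, 384],
    }
    wrapper = DepthWarping_wrapper(
        model_config,
        ckpt_path="checkpoints/depth_anything_v2_vits.pth")  # hypothetical path

    b, v, h, w = 1, 2, 256, 256
    c2w_ctxt = torch.eye(4).expand(b, 1, 4, 4).clone()  # context camera at identity
    c2w_trgt = torch.eye(4).expand(b, v, 4, 4).clone()
    c2w_trgt[..., 0, 3] = 0.1                           # small x-translation per target view
    k = torch.tensor([[0.5, 0.0, 0.5],
                      [0.0, 0.5, 0.5],
                      [0.0, 0.0, 1.0]])                 # normalized 3x3 intrinsics

    batch = {
        "context": {
            "image": torch.rand(b, 3, h, w),            # images in [0, 1]
            "extrinsics": c2w_ctxt,
            "intrinsics": k.expand(b, 1, 3, 3).clone(),
        },
        "target": {
            "extrinsics": c2w_trgt,
            "intrinsics": k.expand(b, v, 3, 3).clone(),
        },
        "variable_intrinsic": False,
    }
    flow_f, flow_b, warped_trgt, depth_ctxt, warped_depth_trgt = wrapper(batch)
    print(flow_f.shape, flow_b.shape, warped_trgt.shape, depth_ctxt.shape)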