import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[:4]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[6:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)


def get_relative_pose(cam_params, zero_first_frame_scale):
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses


def get_K(intrinsics, size):
    def normalize_intrinsic(x, size):
        h, w = size
        x[:, :, 0:1] = x[:, :, 0:1] / w
        x[:, :, 1:2] = x[:, :, 1:2] / h
        return x

    b, _, t, _ = intrinsics.shape
    K = torch.zeros((b, t, 9), dtype=intrinsics.dtype, device=intrinsics.device)
    fx, fy, cx, cy = intrinsics.squeeze(1).chunk(4, dim=-1)
    K[:, :, 0:1] = fx
    K[:, :, 2:3] = cx
    K[:, :, 4:5] = fy
    K[:, :, 5:6] = cy
    K[:, :, 8:9] = 1.0
    K = rearrange(K, "b t (h w) -> b t h w", h=3, w=3)
    K = normalize_intrinsic(K, size)
    return K
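
# Usage sketch (illustrative only, not called by the pipeline): builds a Camera
# from the 18-float entry layout documented in get_camera_flow_generator_input
# below (fx, fy, cx, cy, 0, 0, then a row-major 3x4 W2C matrix), computes
# relative poses, and assembles a normalized 3x3 K. The sample resolution and
# entry values here are assumptions made for the example.
def _example_relative_pose_and_K():
    entry = [0.5, 0.5, 0.5, 0.5, 0.0, 0.0,
             1.0, 0.0, 0.0, 0.0,
             0.0, 1.0, 0.0, 0.0,
             0.0, 0.0, 1.0, 2.0]  # identity rotation plus a z-translation
    cams = [Camera(entry), Camera(entry)]
    # First pose is the identity; the rest are relative to the first frame.
    poses = get_relative_pose(cams, zero_first_frame_scale=True)  # (2, 4, 4)
    h, w = 64, 64  # hypothetical sample size
    intrinsic = torch.tensor([[[[0.5 * w, 0.5 * h, 0.5 * w, 0.5 * h]]]])  # [b=1, 1, t=1, 4]
    K = get_K(intrinsic, size=(h, w))  # [1, 1, 3, 3]; fx/cx divided by w, fy/cy by h
    return poses, K
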
def get_camera_flow_generator_input(condition_image, camparams, device, speed=1.0):
    """
    Args
        condition_image: [c h w], scale ~ [0, 255]
        camparams: [b, 18] per frame (fx, fy, cx, cy, 0, 0, row-major 3x4 W2C matrix)
    Intermediates
        intrinsic: [b, 1, t, 4] (fx, fy, cx, cy) in pixels
        c2w: [b, t, 4, 4] relative camera-to-world poses
    """
    condition_image = condition_image.unsqueeze(0) / 255.  # [b c h w], scale ~ [0, 1]
    sample_size = condition_image.shape[2:]

    cam_params = [[float(x) for x in camparam] for camparam in camparams]
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    intrinsic = np.asarray([[cam_param.fx * sample_size[1],
                             cam_param.fy * sample_size[0],
                             cam_param.cx * sample_size[1],
                             cam_param.cy * sample_size[0]]
                            for cam_param in cam_params], dtype=np.float32)
    intrinsic = torch.as_tensor(intrinsic).unsqueeze(0).unsqueeze(0)  # [1, 1, t, 4]

    c2w = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2w[:, :3, -1] = c2w[:, :3, -1] * speed
    c2w = torch.as_tensor(c2w)
    c2w = c2w.unsqueeze(0)  # [1, t, 4, 4]

    b = condition_image.shape[0]
    K = get_K(intrinsic, size=condition_image.shape[2:])  # [b t 3 3]
    c2w_dummy = repeat(torch.eye(4, dtype=c2w.dtype, device=device),
                       "h w -> b 1 h w", b=c2w.shape[0])

    t = 1
    assert t == 1, "The 3D estimation network expects a single-image context view; got more than one image."

    batch = dict()
    batch['context'] = {
        'image': condition_image,
        'intrinsics': K[:, :1],
        'extrinsics': c2w_dummy,
        'near': torch.ones((b, t), device=device),
        'far': torch.ones((b, t), device=device) * 100,
        'index': torch.arange(t).to(device),
    }

    b, t = c2w.shape[:2]
    batch['target'] = {
        'intrinsics': K,
        'extrinsics': c2w,
        'near': torch.ones((b, t), device=device),
        'far': torch.ones((b, t), device=device) * 100,
        'index': repeat(torch.arange(t).to(device), "t -> b t", b=b),
    }
    batch['scene'] = 'random'
    batch['variable_intrinsic'] = None
    return batch


def to_zero_to_one(x):
    return (x + 1) / 2


def instantiate_from_config(config, **additional_kwargs):
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    additional_kwargs.update(config.get("kwargs", dict()))
    return get_obj_from_str(config["target"])(**additional_kwargs)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)
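
# Usage sketch for instantiate_from_config (illustrative): a config dict with a
# dotted "target" path and an optional "kwargs" dict is resolved to a class and
# instantiated. The target below is a stdlib class chosen purely for the example.
def _example_instantiate_from_config():
    config = {"target": "collections.OrderedDict", "kwargs": {}}
    obj = instantiate_from_config(config)  # equivalent to collections.OrderedDict()
    return obj
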
def warp_image(image, flow, use_forward_flow=True):
    """
    Args
        image: context image (source-view image)
        flow: forward (src -> trgt) or backward (trgt -> src) optical flow
    """
    assert image.ndim == 4 and flow.ndim == 4
    h, w = flow.shape[2:]

    if use_forward_flow:
        flow = -flow

    # Create a mesh grid
    meshgrid = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='xy')
    grid = torch.stack(meshgrid, dim=2).float().to(image.device)  # Shape: (h, w, 2)

    # Apply flow to the grid; permute flow to match the grid layout (b, h, w, 2)
    flow_map = repeat(grid, "h w c -> b h w c", b=flow.shape[0]) + flow.permute(0, 2, 3, 1)

    # Normalize the flow map to the [-1, 1] range for grid_sample
    flow_map[..., 0] = 2.0 * flow_map[..., 0] / max(w - 1, 1) - 1.0
    flow_map[..., 1] = 2.0 * flow_map[..., 1] / max(h - 1, 1) - 1.0

    # Warp image using grid_sample
    warped_image = F.grid_sample(image, flow_map, mode='bilinear', align_corners=True)

    # Create the unobserved mask
    # observed_mask = (flow_map[..., 0] >= -1.0) & (flow_map[..., 0] <= 1.0) & (flow_map[..., 1] >= -1.0) & (flow_map[..., 1] <= 1.0)

    return warped_image


def forward_bilinear_splatting(image, flow, mask=None):
    """
    Forward warping (splatting) with bilinear interpolation for an entire batch at once.

    Args:
        image: (B, 3, H, W)  source image
        flow: (B, 2, H, W)   forward flow (dx, dy)
        mask: (B, 1, H, W)   1: valid, 0: invalid
    Returns:
        warped: (B, 3, H, W) forward-warped result
    """
    device = image.device
    B, C_i, H, W = image.shape
    if mask is None:
        mask = torch.ones(B, 1, H, W).to(device, flow.dtype)
    assert C_i == 3, f"image must have 3 channels (got {C_i})"
    assert flow.shape == (B, 2, H, W), "flow must have shape (B, 2, H, W)"
    assert mask.shape == (B, 1, H, W), "mask must have shape (B, 1, H, W)"

    # (B,3,H,W) -> (B,H,W,3)
    image_bhwc = image.permute(0, 2, 3, 1).contiguous()
    # (B,2,H,W) -> (B,H,W,2)
    flow_bhwt = flow.permute(0, 2, 3, 1).contiguous()
    # (B,1,H,W) -> (B,H,W)
    mask_bhw = mask.view(B, H, W)

    # Flatten to 1D so contributions can be accumulated with index_add_ later
    image_flat = image_bhwc.view(-1, C_i)  # source pixel values (B*H*W, 3)
    flow_flat = flow_bhwt.view(-1, 2)      # flow vectors (B*H*W, 2)
    mask_flat = mask_bhw.view(-1)          # mask (B*H*W,)

    # Build the (batch b, y, x) coordinate of every pixel, flattened to 1D
    b_grid = torch.arange(B, device=device).view(B, 1, 1).expand(-1, H, W)  # (B,H,W)
    y_grid = torch.arange(H, device=device).view(1, H, 1).expand(B, -1, W)
    x_grid = torch.arange(W, device=device).view(1, 1, W).expand(B, H, -1)
    b_idx = b_grid.flatten()  # (B*H*W,)
    y_idx = y_grid.flatten()
    x_idx = x_grid.flatten()

    # Apply the flow: target position (x + dx, y + dy)
    dx = flow_flat[:, 0]
    dy = flow_flat[:, 1]
    tx = x_idx + dx  # float
    ty = y_idx + dy  # float

    # floor/ceil neighbors for bilinear splatting
    tx0 = tx.floor().long()
    tx1 = tx0 + 1
    ty0 = ty.floor().long()
    ty1 = ty0 + 1
    alpha = tx - tx.floor()  # (B*H*W,)
    beta = ty - ty.floor()

    # Keep only in-bounds, masked-in pixels
    valid = ((mask_flat == 1) &
             (tx0 >= 0) & (tx1 < W) &
             (ty0 >= 0) & (ty1 < H))
    valid_idx = valid.nonzero(as_tuple=True)  # (N,)

    # Index only the valid entries
    v_b = b_idx[valid_idx]  # (N,)
    v_x0 = tx0[valid_idx]
    v_x1 = tx1[valid_idx]
    v_y0 = ty0[valid_idx]
    v_y1 = ty1[valid_idx]
    v_alpha = alpha[valid_idx]
    v_beta = beta[valid_idx]
    v_src = image_flat[valid_idx]  # (N,3)

    # Bilinear weights
    w00 = (1 - v_alpha) * (1 - v_beta)
    w01 = v_alpha * (1 - v_beta)
    w10 = (1 - v_alpha) * v_beta
    w11 = v_alpha * v_beta

    # Output accumulator (B,H,W,3) and weight map (B,H,W)
    warped_bhwc = torch.zeros_like(image_bhwc)
    weight_map = torch.zeros((B, H, W), dtype=image.dtype, device=device)

    # Flatten again to (B*H*W, ...)
    warped_flat = warped_bhwc.view(-1, C_i)  # (B*H*W,3)
    weight_flat = weight_map.view(-1)        # (B*H*W,)

    # Convert (b, y, x) into a flat (B*H*W) index:
    # offset_b = b*(H*W), then y*W + x
    def flatten_index(b, y, x):
        return b * (H * W) + (y * W) + x

    i00 = flatten_index(v_b, v_y0, v_x0)
    i01 = flatten_index(v_b, v_y0, v_x1)
    i10 = flatten_index(v_b, v_y1, v_x0)
    i11 = flatten_index(v_b, v_y1, v_x1)

    # Accumulate contributions with index_add_
    warped_flat.index_add_(0, i00, w00.unsqueeze(-1) * v_src)
    warped_flat.index_add_(0, i01, w01.unsqueeze(-1) * v_src)
    warped_flat.index_add_(0, i10, w10.unsqueeze(-1) * v_src)
    warped_flat.index_add_(0, i11, w11.unsqueeze(-1) * v_src)

    weight_flat.index_add_(0, i00, w00)
    weight_flat.index_add_(0, i01, w01)
    weight_flat.index_add_(0, i10, w10)
    weight_flat.index_add_(0, i11, w11)

    # Divide accumulated colors by accumulated weights to get the final color
    w_valid = (weight_flat > 0)
    warped_flat[w_valid] /= weight_flat[w_valid].unsqueeze(-1)

    # Restore to (B,H,W,3), then permute to (B,3,H,W)
    warped_bhwc = warped_flat.view(B, H, W, C_i)
    warped = warped_bhwc.permute(0, 3, 1, 2).contiguous()  # (B,3,H,W)
    return warped
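
# Sanity-check sketch (illustrative): with an all-zero flow, backward warping is
# an exact identity, and forward splatting is an identity on the interior (the
# last row/column falls outside the conservative tx1 < W / ty1 < H bounds check
# and comes back as zeros). Shapes below are arbitrary assumptions.
def _example_zero_flow_roundtrip():
    image = torch.rand(1, 3, 8, 8)
    flow = torch.zeros(1, 2, 8, 8)
    backward_warped = warp_image(image, flow, use_forward_flow=True)   # == image
    forward_splatted = forward_bilinear_splatting(image, flow)         # == image except last row/col
    return backward_warped, forward_splatted
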
def run_filtering(flow_f, flow_b, cycle_th=3.):
    """
    Args:
        flow_f: forward flow, [b 2 h w]
        flow_b: backward flow, [b 2 h w]
        cycle_th: distance threshold for inconsistency in pixels (e.g., 3.0)
    Returns:
        valid_mask: binary mask (0: inconsistent, 1: consistent), float, [b 1 h w]
    """
    assert flow_f.ndim == 4 and flow_b.ndim == 4
    device = flow_f.device
    h, w = flow_f.shape[-2:]
    num_imgs = flow_f.shape[0]

    grid = repeat(gen_grid(h, w, device=device).permute(2, 0, 1)[None],
                  "b c h w -> (b v) c h w", v=num_imgs)

    # Follow the forward flow, sample the backward flow there, and measure the
    # forward-backward cycle residual per pixel.
    coord2 = flow_f + grid
    coord2_normed = normalize_coords(coord2.permute(0, 2, 3, 1), h, w)
    flow_21_sampled = F.grid_sample(flow_b, coord2_normed, align_corners=True)
    map_i = flow_f + flow_21_sampled
    fb_discrepancy = torch.norm(map_i, dim=1)
    valid_mask = fb_discrepancy < cycle_th

    return valid_mask.unsqueeze(1).float()


def gen_grid(h, w, device, normalize=False, homogeneous=False):
    if normalize:
        lin_y = torch.linspace(-1., 1., steps=h, device=device)
        lin_x = torch.linspace(-1., 1., steps=w, device=device)
    else:
        lin_y = torch.arange(0, h, device=device)
        lin_x = torch.arange(0, w, device=device)
    grid_y, grid_x = torch.meshgrid((lin_y, lin_x), indexing='ij')
    grid = torch.stack((grid_x, grid_y), -1)
    if homogeneous:
        grid = torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
    return grid  # [h, w, 2 or 3]


def normalize_coords(coords, h, w, no_shift=False):
    assert coords.shape[-1] == 2
    if no_shift:
        return coords / torch.tensor([w - 1., h - 1.], device=coords.device) * 2
    else:
        return coords / torch.tensor([w - 1., h - 1.], device=coords.device) * 2 - 1.
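
# Usage sketch (illustrative): cycle-consistency filtering with a synthetic
# constant flow. A forward flow of +2 px in x and a backward flow of -2 px in x
# cancel exactly, so with cycle_th=3.0 the returned mask should be all ones
# (even at the border, where zero-padded sampling leaves a residual of only 2 px).
def _example_cycle_consistency():
    b, h, w = 1, 16, 16
    flow_f = torch.zeros(b, 2, h, w)
    flow_b = torch.zeros(b, 2, h, w)
    flow_f[:, 0] = 2.0   # forward: +2 px in x
    flow_b[:, 0] = -2.0  # backward: -2 px in x
    mask = run_filtering(flow_f, flow_b, cycle_th=3.0)  # [b, 1, h, w], 1 = consistent
    return mask
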
# --------------------------------------------------------------------------------------------------------------
# Code borrowed from https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch
from typing import Optional, Union

from math import pi as PI


def get_color_wheel(device: torch.device) -> torch.Tensor:
    """
    Generates the color wheel.
    :param device: (torch.device) Device to be used
    :return: (torch.Tensor) Color wheel tensor of the shape [55, 3]
    """
    # Set constants
    RY: int = 15
    YG: int = 6
    GC: int = 4
    CB: int = 11
    BM: int = 13
    MR: int = 6
    # Init color wheel
    color_wheel: torch.Tensor = torch.zeros((RY + YG + GC + CB + BM + MR, 3), dtype=torch.float32)
    # Init counter
    counter: int = 0
    # RY
    color_wheel[0:RY, 0] = 255
    color_wheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    counter: int = counter + RY
    # YG
    color_wheel[counter:counter + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    color_wheel[counter:counter + YG, 1] = 255
    counter: int = counter + YG
    # GC
    color_wheel[counter:counter + GC, 1] = 255
    color_wheel[counter:counter + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    counter: int = counter + GC
    # CB
    color_wheel[counter:counter + CB, 1] = 255 - torch.floor(255 * torch.arange(0, CB) / CB)
    color_wheel[counter:counter + CB, 2] = 255
    counter: int = counter + CB
    # BM
    color_wheel[counter:counter + BM, 2] = 255
    color_wheel[counter:counter + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    counter: int = counter + BM
    # MR
    color_wheel[counter:counter + MR, 2] = 255 - torch.floor(255 * torch.arange(0, MR) / MR)
    color_wheel[counter:counter + MR, 0] = 255
    # To device
    color_wheel: torch.Tensor = color_wheel.to(device)
    return color_wheel


def _flow_hw_to_color(flow_vertical: torch.Tensor, flow_horizontal: torch.Tensor,
                      color_wheel: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Private function applies the flow color wheel to flow components (vertical and horizontal).
    :param flow_vertical: (torch.Tensor) Vertical flow of the shape [height, width]
    :param flow_horizontal: (torch.Tensor) Horizontal flow of the shape [height, width]
    :param color_wheel: (torch.Tensor) Color wheel tensor of the shape [55, 3]
    :param device: (torch.device) Device to be used
    :return: (torch.Tensor) Visualized flow of the shape [3, height, width]
    """
    # Get shapes
    _, height, width = flow_vertical.shape
    # Init flow image
    flow_image: torch.Tensor = torch.zeros(3, height, width, dtype=torch.float32, device=device)
    # Get number of colors
    number_of_colors: int = color_wheel.shape[0]
    # Compute norm, angle and factors
    flow_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt()
    angle: torch.Tensor = torch.atan2(- flow_vertical, - flow_horizontal) / PI
    fk: torch.Tensor = (angle + 1.) / 2. * (number_of_colors - 1.)
    k0: torch.Tensor = torch.floor(fk).long()
    k1: torch.Tensor = k0 + 1
    k1[k1 == number_of_colors] = 0
    f: torch.Tensor = fk - k0
    # Iterate over color components
    for index in range(color_wheel.shape[1]):
        # Get component of all colors
        tmp: torch.Tensor = color_wheel[:, index]
        # Get colors
        color_0: torch.Tensor = tmp[k0] / 255.
        color_1: torch.Tensor = tmp[k1] / 255.
        # Compute color
        color: torch.Tensor = (1. - f) * color_0 + f * color_1
        # Get color index
        color_index: torch.Tensor = flow_norm <= 1
        # Set color saturation
        color[color_index] = 1 - flow_norm[color_index] * (1. - color[color_index])
        color[~color_index] = color[~color_index] * 0.75
        # Set color in image
        flow_image[index] = torch.floor(255 * color)
    return flow_image


def flow_to_color(flow: torch.Tensor, clip_flow: Optional[Union[float, torch.Tensor]] = None,
                  normalize_over_video: bool = False) -> torch.Tensor:
    """
    Function converts a given optical flow map into the classical color schema.
    :param flow: (torch.Tensor) Optical flow tensor of the shape [batch size (optional), 2, height, width].
    :param clip_flow: (Optional[Union[float, torch.Tensor]]) Max value of flow values for clipping (default None).
    :param normalize_over_video: (bool) If true, the scale is normalized over the whole video (batch).
    :return: (torch.Tensor) Flow visualization (float tensor) with the shape [batch size (if used), 3, height, width].
    """
    # Check parameter types
    assert torch.is_tensor(flow), "Given flow map must be a torch.Tensor, {} given".format(type(flow))
    assert torch.is_tensor(clip_flow) or isinstance(clip_flow, float) or clip_flow is None, \
        "Given clip_flow parameter must be a float, a torch.Tensor, or None, {} given".format(type(clip_flow))
    # Check shapes
    assert flow.ndimension() in [3, 4], \
        "Given flow must be a 3D or 4D tensor, given tensor shape {}.".format(flow.shape)
    if torch.is_tensor(clip_flow):
        assert clip_flow.ndimension() == 0, \
            "Given clip_flow tensor must be a scalar, given tensor shape {}.".format(clip_flow.shape)
    # Manage batch dimension
    batch_dimension: bool = True
    if flow.ndimension() == 3:
        flow = flow[None]
        batch_dimension: bool = False
    # Save shape
    batch_size, _, height, width = flow.shape
    # Check flow dimension
    assert flow.shape[1] == 2, "Flow dimension must have the shape 2 but tensor with {} given".format(flow.shape[1])
    # Save device
    device: torch.device = flow.device
    # Clip flow if utilized
    if clip_flow is not None:
        flow = flow.clip(max=clip_flow)
    # Get horizontal and vertical flow
    flow_vertical: torch.Tensor = flow[:, 0:1]
    flow_horizontal: torch.Tensor = flow[:, 1:2]
    # Get max norm of flow
    flow_max_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt().view(batch_size, -1).max(dim=-1)[0]
    flow_max_norm: torch.Tensor = flow_max_norm.view(batch_size, 1, 1, 1)
    if normalize_over_video:
        flow_max_norm: torch.Tensor = flow_max_norm.max(dim=0, keepdim=True)[0]
    # Normalize flow
    flow_vertical: torch.Tensor = flow_vertical / (flow_max_norm + 1e-05)
    flow_horizontal: torch.Tensor = flow_horizontal / (flow_max_norm + 1e-05)
    # Get color wheel
    color_wheel: torch.Tensor = get_color_wheel(device=device)
    # Init flow image
    flow_image = torch.zeros(batch_size, 3, height, width, device=device)
    # Iterate over batch dimension
    for index in range(batch_size):
        flow_image[index] = _flow_hw_to_color(flow_vertical=flow_vertical[index],
                                              flow_horizontal=flow_horizontal[index],
                                              color_wheel=color_wheel, device=device)
    return flow_image if batch_dimension else flow_image[0]
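
# Usage sketch (illustrative): visualize a synthetic flow field with the
# classical color wheel. Note that in this borrowed code channel 0 is the
# vertical component and channel 1 the horizontal one, so a ramp on channel 1
# yields a smooth horizontal hue gradient; output values lie in [0, 255].
def _example_flow_to_color():
    h, w = 32, 32
    flow = torch.zeros(1, 2, h, w)
    flow[:, 1] = torch.linspace(-1.0, 1.0, w).view(1, 1, w)  # horizontal-component ramp
    color = flow_to_color(flow)  # [1, 3, h, w]
    return color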