diff --git a/.gitignore b/.gitignore
index ba0430d26c996e7f078385407f959c96c271087c..208b0be95d9ed2968365b9367a95b7ca1fbe58ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
-__pycache__/
\ No newline at end of file
+__pycache__/
+*.DS_Store
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf b/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf
new file mode 100644
index 0000000000000000000000000000000000000000..dacbc09968c2f4cd6f7348dd93552ea5d8876236
--- /dev/null
+++ b/SparseNeuS_demo_v1/confs/blender_general_lod1_val_new.conf
@@ -0,0 +1,137 @@
+# - for the lod1 geometry network, use adaptive cost for the sparse cost regularization network
+# - for the lod1 rendering network, use depth-adaptive rendering
+
+general {
+ base_exp_dir = ./exp/val/1_4_only_narrow_lod1
+
+ recording = [
+ ./,
+ ./data
+ ./ops
+ ./models
+ ./loss
+ ]
+}
+
+dataset {
+ # local path
+ trainpath = /objaverse-processed/zero12345_img/eval_selected
+ valpath = /objaverse-processed/zero12345_img/eval_selected
+ testpath = /objaverse-processed/zero12345_img/eval_selected
+ # trainpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/
+ # valpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/
+ # testpath = /objaverse-processed/zero12345_img/zero12345_2stage_5pred_sample/
+ imgScale_train = 1.0
+ imgScale_test = 1.0
+ nviews = 5
+ clean_image = True
+ importance_sample = True
+ test_ref_views = [23]
+
+ # test dataset
+ test_n_views = 2
+ test_img_wh = [256, 256]
+ test_clip_wh = [0, 0]
+ test_scan_id = scan110
+ train_img_idx = [49, 50, 52, 53, 54, 56, 58] #[21, 22, 23, 24, 25] #
+ test_img_idx = [51, 55, 57] #[32, 33, 34] #
+
+ test_dir_comment = train
+}
+
+train {
+ learning_rate = 2e-4
+ learning_rate_milestone = [100000, 150000, 200000]
+ learning_rate_factor = 0.5
+ end_iter = 200000
+ save_freq = 5000
+ val_freq = 1
+    val_mesh_freq = 1
+ report_freq = 100
+
+ N_rays = 512
+
+ validate_resolution_level = 4
+ anneal_start = 0
+ anneal_end = 25000
+ anneal_start_lod1 = 0
+ anneal_end_lod1 = 15000
+
+ use_white_bkgd = True
+
+ # Loss
+    # ! when training the lod1 network, disable this regularization for the first 10k steps, then enable it
+ sdf_igr_weight = 0.1
+ sdf_sparse_weight = 0.02 # 0.002 for lod1 network; 0.02 for lod0 network
+    sdf_decay_param = 100 # should not be too large; it decides the TSDF range
+ fg_bg_weight = 0.01 # first 0.01
+ bg_ratio = 0.3
+
+ if_fix_lod0_networks = True
+}
+
+model {
+ num_lods = 2
+
+ sdf_network_lod0 {
+ lod = 0,
+ ch_in = 56, # the channel num of fused pyramid features
+        voxel_size = 0.02105263, # 2 / (vol_dims - 1) = 2/95 for the [-1, 1] volume (the earlier 0.02083333 = 2/96 was incorrect)
+ vol_dims = [96, 96, 96],
+ hidden_dim = 128,
+ cost_type = variance_mean
+ d_pyramid_feature_compress = 16,
+ regnet_d_out = 16,
+ num_sdf_layers = 4,
+ # position embedding
+ multires = 6
+ }
+
+
+ sdf_network_lod1 {
+ lod = 1,
+ ch_in = 56, # the channel num of fused pyramid features
+        voxel_size = 0.0104712, # 2 / (vol_dims - 1) = 2/191 for the [-1, 1] volume (the earlier 0.01041667 = 2/192 was incorrect)
+ vol_dims = [192, 192, 192],
+ hidden_dim = 128,
+ cost_type = variance_mean
+ d_pyramid_feature_compress = 8,
+ regnet_d_out = 8,
+ num_sdf_layers = 4,
+ # position embedding
+ multires = 6
+ }
+
+
+ variance_network {
+ init_val = 0.2
+ }
+
+ variance_network_lod1 {
+ init_val = 0.2
+ }
+
+ rendering_network {
+ in_geometry_feat_ch = 16
+ in_rendering_feat_ch = 56
+ anti_alias_pooling = True
+ }
+
+ rendering_network_lod1 {
+ in_geometry_feat_ch = 8
+ in_rendering_feat_ch = 56
+ anti_alias_pooling = True
+
+ }
+
+
+ trainer {
+ n_samples_lod0 = 64
+ n_importance_lod0 = 64
+ n_samples_lod1 = 64
+ n_importance_lod1 = 64
+ n_outside = 0 # 128 if render_outside_uniform_sampling
+ perturb = 1.0
+ alpha_type = div
+ }
+}
diff --git a/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf
new file mode 100644
index 0000000000000000000000000000000000000000..7be6d4098d66473f63252c42d0a1bd25e2338a6b
--- /dev/null
+++ b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf
@@ -0,0 +1,137 @@
+# - for the lod1 geometry network, use adaptive cost for the sparse cost regularization network
+# - for the lod1 rendering network, use depth-adaptive rendering
+
+general {
+
+ base_exp_dir = exp/lod0 # !!! where you store the results and checkpoints to be used
+ recording = [
+ ./,
+ ./data
+ ./ops
+ ./models
+ ./loss
+ ]
+}
+
+dataset {
+ trainpath = ../
+ valpath = ../ # !!! where you store the validation data
+ testpath = ../
+
+
+
+ imgScale_train = 1.0
+ imgScale_test = 1.0
+ nviews = 5
+ clean_image = True
+ importance_sample = True
+ test_ref_views = [23]
+
+ # test dataset
+ test_n_views = 2
+ test_img_wh = [256, 256]
+ test_clip_wh = [0, 0]
+ test_scan_id = scan110
+ train_img_idx = [49, 50, 52, 53, 54, 56, 58] #[21, 22, 23, 24, 25] #
+ test_img_idx = [51, 55, 57] #[32, 33, 34] #
+
+ test_dir_comment = train
+}
+
+train {
+ learning_rate = 2e-4
+ learning_rate_milestone = [100000, 150000, 200000]
+ learning_rate_factor = 0.5
+ end_iter = 200000
+ save_freq = 5000
+ val_freq = 1
+ val_mesh_freq = 1
+ report_freq = 100
+
+ N_rays = 512
+
+ validate_resolution_level = 4
+ anneal_start = 0
+ anneal_end = 25000
+ anneal_start_lod1 = 0
+ anneal_end_lod1 = 15000
+
+ use_white_bkgd = True
+
+ # Loss
+    # ! when training the lod1 network, disable this regularization for the first 10k steps, then enable it
+ sdf_igr_weight = 0.1
+ sdf_sparse_weight = 0.02 # 0.002 for lod1 network; 0.02 for lod0 network
+    sdf_decay_param = 100 # should not be too large; it decides the TSDF range
+ fg_bg_weight = 0.01 # first 0.01
+ bg_ratio = 0.3
+
+ if_fix_lod0_networks = False
+}
+
+model {
+ num_lods = 1
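+    # with num_lods = 1, only the lod0 networks below should be active; the lod1 blocks are presumably kept for config compatibility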
+
+ sdf_network_lod0 {
+ lod = 0,
+ ch_in = 56, # the channel num of fused pyramid features
+        voxel_size = 0.02105263, # 2 / (vol_dims - 1) = 2/95 for the [-1, 1] volume (the earlier 0.02083333 = 2/96 was incorrect)
+ vol_dims = [96, 96, 96],
+ hidden_dim = 128,
+ cost_type = variance_mean
+ d_pyramid_feature_compress = 16,
+ regnet_d_out = 16,
+ num_sdf_layers = 4,
+ # position embedding
+ multires = 6
+ }
+
+
+ sdf_network_lod1 {
+ lod = 1,
+ ch_in = 56, # the channel num of fused pyramid features
+        voxel_size = 0.0104712, # 2 / (vol_dims - 1) = 2/191 for the [-1, 1] volume (the earlier 0.01041667 = 2/192 was incorrect)
+ vol_dims = [192, 192, 192],
+ hidden_dim = 128,
+ cost_type = variance_mean
+ d_pyramid_feature_compress = 8,
+ regnet_d_out = 16,
+ num_sdf_layers = 4,
+
+ # position embedding
+ multires = 6
+ }
+
+
+ variance_network {
+ init_val = 0.2
+ }
+
+ variance_network_lod1 {
+ init_val = 0.2
+ }
+
+ rendering_network {
+ in_geometry_feat_ch = 16
+ in_rendering_feat_ch = 56
+ anti_alias_pooling = True
+ }
+
+ rendering_network_lod1 {
+ in_geometry_feat_ch = 16 # default 8
+ in_rendering_feat_ch = 56
+ anti_alias_pooling = True
+
+ }
+
+
+ trainer {
+ n_samples_lod0 = 64
+ n_importance_lod0 = 64
+ n_samples_lod1 = 64
+ n_importance_lod1 = 64
+ n_outside = 0 # 128 if render_outside_uniform_sampling
+ perturb = 1.0
+ alpha_type = div
+ }
+}
diff --git a/SparseNeuS_demo_v1/data/__init__.py b/SparseNeuS_demo_v1/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/data/blender.py b/SparseNeuS_demo_v1/data/blender.py
new file mode 100644
index 0000000000000000000000000000000000000000..c027f3e05367497c91026b362af4378fe31ff24a
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender.py
@@ -0,0 +1,340 @@
+import torch
+from torch.utils.data import Dataset
+import json
+import numpy as np
+import os
+from PIL import Image
+from torchvision import transforms as T
+from kornia import create_meshgrid
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import cv2 as cv
+from data.scene import get_boundingbox
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def get_rays(directions, c2w):
+ """
+ Get ray origin and normalized directions in world coordinate for all pixels in one image.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ directions: (H, W, 3) precomputed ray directions in camera coordinate
+ c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
+ Outputs:
+ rays_o: (H*W, 3), the origin of the rays in world coordinate
+ rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
+ """
+ # Rotate ray directions from camera coordinate to the world coordinate
+ rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
+ # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
+ # The origin of all rays is the camera origin in world coordinate
+ rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
+
+ rays_d = rays_d.view(-1, 3)
+ rays_o = rays_o.view(-1, 3)
+
+ return rays_o, rays_d
+
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to build the cam2world pose
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+class BlenderDataset(Dataset):
+ def __init__(self, root_dir, split, scan_id, n_views, train_img_idx=[], test_img_idx=[],
+ img_wh=[800, 800], clip_wh=[0, 0], original_img_wh=[800, 800],
+ N_rays=512, h_patch_size=5, near=2.0, far=6.0):
+ self.root_dir = root_dir
+ self.split = split
+ self.img_wh = img_wh
+ self.clip_wh = clip_wh
+ self.define_transforms()
+ self.train_img_idx = train_img_idx
+ self.test_img_idx = test_img_idx
+ self.N_rays = N_rays
+ self.h_patch_size = h_patch_size # used to extract patch for supervision
+ self.n_views = n_views
+ self.near, self.far = near, far
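+        # Blender/OpenGL cameras look along -z with +y up, while OpenCV looks along +z with +y down, so flip the y and z axes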
+ self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+
+ with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
+ self.meta = json.load(f)
+
+
+ self.read_meta(near, far)
+ # import ipdb; ipdb.set_trace()
+ self.raw_near_fars = np.stack([np.array([self.near, self.far]) for i in range(len(self.meta['frames']))])
+
+
+ # ! estimate scale_mat
+ self.scale_mat, self.scale_factor = self.cal_scale_mat(
+ img_hw=[self.img_wh[1], self.img_wh[0]],
+ intrinsics=self.all_intrinsics[self.train_img_idx],
+ extrinsics=self.all_w2cs[self.train_img_idx],
+ near_fars=self.raw_near_fars[self.train_img_idx],
+ factor=1.1)
+ # self.scale_mat = np.eye(4)
+ # self.scale_factor = 1.0
+ # import ipdb; ipdb.set_trace()
+ # * after scaling and translation, unit bounding box
+ self.scaled_intrinsics, self.scaled_w2cs, self.scaled_c2ws, \
+ self.scaled_affine_mats, self.scaled_near_fars = self.scale_cam_info()
+
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+ self.partial_vol_origin = torch.Tensor([-1., -1., -1.])
+ self.white_back = True
+
+ def read_meta(self, near=2.0, far=6.0):
+
+
+ self.ref_img_idx = self.train_img_idx[0]
+ ref_c2w = np.array(self.meta['frames'][self.ref_img_idx]['transform_matrix']) @ self.blender2opencv
+ # ref_c2w = torch.FloatTensor(ref_c2w)
+ self.ref_c2w = ref_c2w
+ self.ref_w2c = np.linalg.inv(ref_c2w)
+
+
+ w, h = self.img_wh
+ self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
+ self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
+
+ # bounds, common for all scenes
+ self.near = near
+ self.far = far
+ self.bounds = np.array([self.near, self.far])
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3)
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = np.array([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).astype(np.float32)
+ self.intrinsics = intrinsics
+
+ self.image_paths = []
+ self.poses = []
+ self.all_rays = []
+ self.all_images = []
+ self.all_masks = []
+ self.all_w2cs = []
+ self.all_intrinsics = []
+ for frame in self.meta['frames']:
+ pose = np.array(frame['transform_matrix']) @ self.blender2opencv
+ self.poses += [pose]
+ c2w = torch.FloatTensor(pose)
+ w2c = np.linalg.inv(c2w)
+ image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
+ self.image_paths += [image_path]
+ img = Image.open(image_path)
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (4, h, w)
+
+ self.all_masks += [img[-1:,:]>0]
+ # img = img[:3, :] * img[ -1:,:] + (1 - img[-1:, :]) # blend A to RGB
+ img = img[:3, :] * img[ -1:,:]
+ img = img.numpy() # (3, h, w)
+ self.all_images += [img]
+
+
+ self.all_masks += []
+ self.all_intrinsics.append(self.intrinsics)
+ # - transform from world system to ref-camera system
+ self.all_w2cs.append(w2c @ np.linalg.inv(self.ref_w2c))
+
+ self.all_images = torch.from_numpy(np.stack(self.all_images)).to(torch.float32)
+ self.all_intrinsics = torch.from_numpy(np.stack(self.all_intrinsics)).to(torch.float32)
+ self.all_w2cs = torch.from_numpy(np.stack(self.all_w2cs)).to(torch.float32)
+ # self.img_wh = [self.img_wh[0] - self.clip_wh[0] - self.clip_wh[2],
+ # self.img_wh[1] - self.clip_wh[1] - self.clip_wh[3]]
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+ center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def scale_cam_info(self):
+ new_intrinsics = []
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ for idx in range(len(self.all_images)):
+
+ intrinsics = self.all_intrinsics[idx]
+ # import ipdb; ipdb.set_trace()
+ P = intrinsics @ self.all_w2cs[idx] @ self.scale_mat
+ P = P.cpu().numpy()[:3, :4]
+
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ new_intrinsics.append(intrinsics)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsics[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
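+            # after applying scale_mat the scene is normalized into the unit sphere, so the distance from the
+            # camera to the origin bounds the depth range: near = dist - 1, far = dist + 1, plus a small margin below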
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars = \
+ np.stack(new_intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), \
+ np.stack(new_affine_mats), np.stack(new_near_fars)
+
+ new_intrinsics = torch.from_numpy(np.float32(new_intrinsics))
+ new_w2cs = torch.from_numpy(np.float32(new_w2cs))
+ new_c2ws = torch.from_numpy(np.float32(new_c2ws))
+ new_affine_mats = torch.from_numpy(np.float32(new_affine_mats))
+ new_near_fars = torch.from_numpy(np.float32(new_near_fars))
+
+ return new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars
+
+    def load_poses_all(self, file="transforms_train.json"):
+ with open(os.path.join(self.root_dir, file), 'r') as f:
+ meta = json.load(f)
+
+ c2ws = []
+ for i,frame in enumerate(meta['frames']):
+ c2ws.append(np.array(frame['transform_matrix']) @ self.blender2opencv)
+ return np.stack(c2ws)
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+
+
+ def get_conditional_sample(self):
+ sample = {}
+ support_idxs = self.train_img_idx
+
+ sample['images'] = self.all_images[support_idxs] # (V, 3, H, W)
+ sample['w2cs'] = self.scaled_w2cs[self.train_img_idx] # (V, 4, 4)
+ sample['c2ws'] = self.scaled_c2ws[self.train_img_idx] # (V, 4, 4)
+ sample['near_fars'] = self.scaled_near_fars[self.train_img_idx] # (V, 2)
+ sample['intrinsics'] = self.scaled_intrinsics[self.train_img_idx][:, :3, :3] # (V, 3, 3)
+ sample['affine_mats'] = self.scaled_affine_mats[self.train_img_idx] # ! in world space
+
+ # sample['scan'] = self.scan_id
+ sample['scale_factor'] = torch.tensor(self.scale_factor)
+ sample['scale_mat'] = torch.from_numpy(self.scale_mat)
+ sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c))
+ sample['img_wh'] = torch.from_numpy(np.array(self.img_wh))
+ sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32)
+
+ return sample
+
+
+
+ def __len__(self):
+ if self.split == 'train':
+ return self.n_views * 1000
+ else:
+ return len(self.test_img_idx) * 1000
+
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ if self.split == 'train':
+ render_idx = self.train_img_idx[idx % self.n_views]
+ support_idxs = [idx for idx in self.train_img_idx if idx != render_idx]
+ else:
+ # render_idx = idx % self.n_test_images + self.n_train_images
+ render_idx = self.test_img_idx[idx % len(self.test_img_idx)]
+ support_idxs = [render_idx]
+
+ sample['images'] = self.all_images[support_idxs] # (V, 3, H, W)
+ sample['w2cs'] = self.scaled_w2cs[support_idxs] # (V, 4, 4)
+ sample['c2ws'] = self.scaled_c2ws[support_idxs] # (V, 4, 4)
+ sample['intrinsics'] = self.scaled_intrinsics[support_idxs][:, :3, :3] # (V, 3, 3)
+ sample['affine_mats'] = self.scaled_affine_mats[support_idxs] # ! in world space
+ # sample['scan'] = self.scan_id
+ sample['scale_factor'] = torch.tensor(self.scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(self.img_wh))
+ sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32)
+ sample['img_index'] = torch.tensor(render_idx)
+
+ # - query image
+ sample['query_image'] = self.all_images[render_idx]
+ sample['query_c2w'] = self.scaled_c2ws[render_idx]
+ sample['query_w2c'] = self.scaled_w2cs[render_idx]
+ sample['query_intrinsic'] = self.scaled_intrinsics[render_idx]
+ sample['query_near_far'] = self.scaled_near_fars[render_idx]
+ # sample['meta'] = str(self.scan_id) + "_" + os.path.basename(self.images_list[render_idx])
+ sample['scale_mat'] = torch.from_numpy(self.scale_mat)
+ sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c))
+ sample['rendering_c2ws'] = self.scaled_c2ws[self.test_img_idx]
+ sample['rendering_imgs_idx'] = torch.Tensor(np.array(self.test_img_idx).astype(np.int32))
+
+ # - generate rays
+ if self.split == 'val' or self.split == 'test':
+ sample_rays = gen_rays_from_single_image(
+ self.img_wh[1], self.img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=None,
+ mask=None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ self.img_wh[1], self.img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=None,
+ mask=None,
+ dilated_mask=None,
+ importance_sample=False,
+ h_patch_size=self.h_patch_size
+ )
+
+ sample['rays'] = sample_rays
+
+ return sample
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/data/blender_general.py b/SparseNeuS_demo_v1/data/blender_general.py
new file mode 100644
index 0000000000000000000000000000000000000000..871bcd6e9e2542110213e34ac5e7bde97184d938
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general.py
@@ -0,0 +1,432 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # the +0.5 above shifts ray origins to pixel centers
+    # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to build the cam2world pose
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600)
+ depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
+ interpolation=cv2.INTER_NEAREST) # (600, 800)
+ depth_h = depth_h[44:556, 80:720] # (512, 640)
+ depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
+ interpolation=cv2.INTER_NEAREST)
+
+ return depth, depth_h
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ depth_h = cv2.imread(filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 65535 * 1.4 + 0.5
+
+ depth_h[depth_h < near_bound+1e-3] = 0.0
+
+ depth = {}
+ for l in range(3):
+ depth[f"level_{l}"] = cv2.resize(
+ depth_h,
+ None,
+ fx=1.0 / (2**l),
+ fy=1.0 / (2**l),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if self.split == "train":
+ cutout = np.ones_like(depth[f"level_2"])
+ h0 = int(np.random.randint(0, high=cutout.shape[0] // 5, size=1))
+ h1 = int(
+ np.random.randint(
+ 4 * cutout.shape[0] // 5, high=cutout.shape[0], size=1
+ )
+ )
+ w0 = int(np.random.randint(0, high=cutout.shape[1] // 5, size=1))
+ w1 = int(
+ np.random.randint(
+ 4 * cutout.shape[1] // 5, high=cutout.shape[1], size=1
+ )
+ )
+ cutout[h0:h1, w0:w1] = 0
+ depth_aug = depth[f"level_2"] * cutout
+ else:
+ depth_aug = depth[f"level_2"].copy()
+
+ return depth, depth_h, depth_aug
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
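+        # each object contributes 8 samples, one per reference view: idx // 8 selects the object, idx % 8 the reference view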
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ # idx = idx % 8
+ # uid = 'c40d63d5d740405e91c7f5fce855076e'
+ # folder_id = '000-123'
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+        depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0  # the depth_mm file stores millimeters; convert to meters
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
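+        # source views of target view idx: img_ids appears to list the 8 input views first, then 4 generated
+        # nearby views per input view, so the sources for view idx sit at indices 8 + 4*idx .. 8 + 4*idx + 3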
+ src_views = range(8+idx*4, 8+(idx+1)*4)
+
+
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+ # print(scale_mat)
+ # print(scale_factor)
+ # ! calculate the new w2cs after scaling
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_12_narrow.py b/SparseNeuS_demo_v1/data/blender_general_12_narrow.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb1183fb695101bac1f8f33da9438a84378b3dca
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_12_narrow.py
@@ -0,0 +1,427 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # the +0.5 above shifts ray origins to pixel centers
+    # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to build the cam2world pose
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 12
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/narrow_12_split_upd.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow_8 = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow_8, 'r') as f:
+ narrow_8_meta = json.load(f)
+
+ pose_json_path_narrow_4 = "/objaverse-processed/zero12345_img/zero12345_2stage_12_pose.json"
+ with open(pose_json_path_narrow_4, 'r') as f:
+ narrow_4_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_8_meta["c2ws"].keys()) + list(narrow_4_meta["c2ws"].keys()) # (8 + 8*4) + (4 + 4*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_8_meta["c2ws"].values()) + list(narrow_4_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_8_meta["intrinsics"] == narrow_4_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_8_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_8_meta["near_far"] == narrow_4_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_8_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+        idx_original = idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ idx = idx % self.imgs_per_instance # [0, 11]
+ if idx < 8:
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+ else:
+ # target view
+ c2w = self.c2ws[idx-8+40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+        depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0  # the depth_mm file stores millimeters; convert to meters
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8, 8 + 8 * 4 + 4 + 4*4)
+ src_views_used = []
+ skipped_idx = [40, 41, 42, 43]
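+        # index layout, per the (8 + 8*4) + (4 + 4*4) note above: 0-7 narrow target views, 8-39 their 4 source
+        # views each, 40-43 the two-stage target views (skipped as sources), 44-59 their 4 source views each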
+ for vid in src_views:
+ if vid in skipped_idx:
+ continue
+
+ src_views_used.append(vid)
+            cur_view_id = (vid - 8) // 4  # 0-7 for narrow source groups, 9-12 for two-stage groups (shifted down by 1 below)
+
+ # choose narrow
+ if cur_view_id < 8:
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png')
+ else: # choose 2-stage
+ cur_view_id = cur_view_id - 1
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12", folder_id, uid, f'view_{cur_view_id}_{vid%4}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+        # print("img number: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+        if view_ids[0] < 8:
+            meta_end = "_narrow" + "_refview" + str(view_ids[0])
+        else:
+            meta_end = "_two_stage" + "_refview" + str(view_ids[0] - 8)
+        sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py b/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py
new file mode 100644
index 0000000000000000000000000000000000000000..467dc5d4d1df3b6d3c8aa4384a1048bec9910973
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_12_narrow_8.py
@@ -0,0 +1,427 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # the +0.5 above shifts ray origins to pixel centers
+    # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to build the cam2world pose
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 8
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/narrow_12_split_upd.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow_8 = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow_8, 'r') as f:
+ narrow_8_meta = json.load(f)
+
+ pose_json_path_narrow_4 = "/objaverse-processed/zero12345_img/zero12345_2stage_12_pose.json"
+ with open(pose_json_path_narrow_4, 'r') as f:
+ narrow_4_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_8_meta["c2ws"].keys()) + list(narrow_4_meta["c2ws"].keys()) # (8 + 8*4) + (4 + 4*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_8_meta["c2ws"].values()) + list(narrow_4_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_8_meta["intrinsics"] == narrow_4_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_8_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_8_meta["near_far"] == narrow_4_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_8_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+        idx_original = idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+        idx = idx % self.imgs_per_instance # [0, 7]
+ if idx < 8:
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+ else:
+ # target view
+ c2w = self.c2ws[idx-8+40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12/", folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+        depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0  # the depth_mm file stores millimeters; convert to meters
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
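+        # source-view ids: 8-39 are the 4 predicted views for each of the 8 narrow reference views,
+        # 40-43 are the four extra reference views (views 8-11) themselves and are skipped below,
+        # and 44-59 are the 4 predicted views for each of those extra reference views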
+ src_views = range(8, 8 + 8 * 4 + 4 + 4*4)
+ src_views_used = []
+ skipped_idx = [40, 41, 42, 43]
+ for vid in src_views:
+ if vid in skipped_idx:
+ continue
+
+ src_views_used.append(vid)
+            cur_view_id = (vid - 8) // 4  # 0-7 for the narrow views, 9-12 for the two-stage views (ids 40-43 are skipped above)
+
+ # choose narrow
+ if cur_view_id < 8:
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png')
+ else: # choose 2-stage
+ cur_view_id = cur_view_id - 1
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow_12", folder_id, uid, f'view_{cur_view_id}_{vid%4}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+
+
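+        # ! estimate scale_mat: normalise the scene so its bounding box (with a 1.1 margin) fits the unit sphere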
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
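+        # re-derive each pose in the normalised frame: project with scale_mat, decompose P back into K and c2w,
+        # and recompute near/far from the camera-to-origin distance (the scaled scene is assumed to fit in a unit sphere)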
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+        # print("img number: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+            meta_end = "_narrow" + "_refview" + str(view_ids[0])
+        else:
+            meta_end = "_two_stage" + "_refview" + str(view_ids[0] - 8)
+        sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_360.py b/SparseNeuS_demo_v1/data/blender_general_360.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e8664613a614c03227375d8a0b25224d694bdc
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_360.py
@@ -0,0 +1,412 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+    Get ray directions for all pixels in camera coordinates.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinates
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_wide_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+
+
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600)
+ depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
+ interpolation=cv2.INTER_NEAREST) # (600, 800)
+ depth_h = depth_h[44:556, 80:720] # (512, 640)
+ depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
+ interpolation=cv2.INTER_NEAREST)
+
+ return depth, depth_h
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 36*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//36]
+
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ idx = idx % 36 # [0, 35]
+ gt_view_idx = idx // 12 # [0, 2]
+ target_view_idx = idx % 12 # [0, 11]
+
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}_gt.png')
+
+ depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}_gt_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12)
+
+ idx_of_12 = idx - 12 * gt_view_idx # idx % 12
+
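+        # source views: a window of 5 consecutive views centred on the target (wrapping modulo 12) within the same gt-view group; the target itself is included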
+        src_views = [i % 12 + 12 * gt_view_idx for i in range(idx_of_12 - 2, idx_of_12 + 3)]
+
+
+ for vid in src_views:
+ # if vid == idx:
+ # continue
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{gt_view_idx}_{target_view_idx}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+ # print(scale_mat)
+ # print(scale_factor)
+ # ! calculate the new w2cs after scaling
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ad72bbfb336fa3e0d8b69f74c94afbea1593b7
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_3.py
@@ -0,0 +1,406 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+    Get ray directions for all pixels in camera coordinates.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinates
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_2stage_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600)
+ depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
+ interpolation=cv2.INTER_NEAREST) # (600, 800)
+ depth_h = depth_h[44:556, 80:720] # (512, 640)
+ depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
+ interpolation=cv2.INTER_NEAREST)
+
+ return depth, depth_h
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 6*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//6]
+ idx = idx % 6
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ # idx = idx % 24 # [0, 23]
+
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_gt.png')
+
+ depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_gt_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12)
+
+
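+        # the 4 predicted source views attached to target view idx (global pose ids 6 + idx*4 ... 6 + idx*4 + 3)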
+ src_views = range(6+idx*4, 6+(idx+1)*4)
+
+ for vid in src_views:
+ # if vid == idx:
+ # continue
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_{vid % 4}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+ # print(scale_mat)
+ # print(scale_factor)
+ # ! calculate the new w2cs after scaling
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py
new file mode 100644
index 0000000000000000000000000000000000000000..380706615bfe4a183b302f127af9913bfc2f4790
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_360_2_stage_1_4.py
@@ -0,0 +1,411 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+    Get ray directions for all pixels in camera coordinates.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinates
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_2stage_5pred_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0_0", "view_0_5", "view_1_7"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600)
+ depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
+ interpolation=cv2.INTER_NEAREST) # (600, 800)
+ depth_h = depth_h[44:556, 80:720] # (512, 640)
+ depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
+ interpolation=cv2.INTER_NEAREST)
+
+ return depth, depth_h
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 6*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//6]
+ idx = idx % 6
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ # idx = idx % 24 # [0, 23]
+
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage", folder_id, uid, f'view_0_{idx}_gt.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage", folder_id, uid, f'view_0_{idx}_gt_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+ # print("depth_h", depth_h.shape)
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12)
+
+
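+        # the 4 predicted source views attached to target view idx; in this variant the files are numbered _1.png to _4.png (hence the +1 below)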
+ src_views = range(6+idx*4, 6+(idx+1)*4)
+
+ for vid in src_views:
+ # if vid == idx:
+ # continue
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{idx}_{vid % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ # print("img shape1: ", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img shape2: ", img.shape)
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+ # print(scale_mat)
+ # print(scale_factor)
+ # ! calculate the new w2cs after scaling
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("imgs: ", len(imgs))
+ # print("img1 shape:", imgs[0].shape)
+ # print("img2 shape:", imgs[1].shape)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb1f976907680936b20b37d76133589804d40c5
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_4_2_stage_mix.py
@@ -0,0 +1,480 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+    Get ray directions for all pixels in camera coordinates.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinates
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 16
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+        self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) narrow ids + (8 + 8*4) two-stage ids
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original=idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
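+        # alternate between the even- and odd-indexed source-view groups so each sample uses only 4 of the 8 predicted groups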
+ if idx % 2 == 0:
+ valid_list = [0, 2, 4, 6]
+ else:
+ valid_list = [1, 3, 5, 7]
+
+ if idx % 16 < 8:
+ idx = idx % 16 # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png'))
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+
+ src_views = range(8, 8 + 8 * 4)
+ src_views_used = []
+ for vid in src_views:
+ view_dix_to_use = (vid - 8) // 4
+ if view_dix_to_use not in valid_list:
+ continue
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ else:
+ idx = idx % 16 - 8 # [0, 7]
+
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png')
+
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
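+            # the two-stage reference view has no ground-truth depth; fill with -1 as a sentinel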
+            depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ # print("depth_h", depth_h.shape)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+
+ src_views = range(40+8, 40+8+32)
+ src_views_used = []
+ for vid in src_views:
+ view_dix_to_use = (vid - 40 - 8) // 4
+ if view_dix_to_use not in valid_list:
+ continue
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-48) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ # print("img shape1: ", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img shape2: ", img.shape)
+ imgs += [img]
+ depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
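+ # estimate a scale matrix (bounding sphere of the selected cameras, enlarged by factor 1.1) used to normalise the scene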
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
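+ # rescale the cameras: rebuild P = K[R|t] under scale_mat, re-decompose into c2w/w2c, form per-view affine projection matrices, and recompute near/far from the camera-to-origin distance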
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("img numeber: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
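+ # keep the reference view in the returned stacks for training; drop it from the source views for val/test (see the [start_idx:] slicing below)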
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow" + "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage" + "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
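+ # full-image rays for val/test; N_rays randomly sampled rays (optionally importance-sampled inside the mask) for training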
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..e80567fe34ee51cb49355ee26ea8ce80dff706e6
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_4_narrow_and_6_2_stage_mix.py
@@ -0,0 +1,476 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
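+ # decompose P = K[R|t]: cv2.decomposeProjectionMatrix returns K, R and the homogeneous camera centre t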
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_5pred_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
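+ # narrow poses occupy indices 0..39 (8 reference views + 8*4 predictions); the two-stage poses are appended after them, hence the idx + 40 offsets below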
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (6 + 6*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
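+ # bounding sphere (center, radius) of the selected views as returned by get_boundingbox; scale_mat encodes it and scale_factor = 1 / radius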
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
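+ # 12 samples per object: the 8 narrow reference views plus the 4 usable two-stage reference views (see valid_list in __getitem__)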
+ def __len__(self):
+ return 12*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original=idx
+
+ folder_uid_dict = self.lvis_paths[idx//12]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ if idx % 12 < 8:
+ idx = idx % 12 # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+
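+ # keep only source predictions whose anchor view has the same parity as the target view (for vid in 8..39, vid // 4 == anchor + 2)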
+ src_views = range(8, 8 + 8 * 4)
+ src_views_used = []
+ for vid in src_views:
+ if (vid // 4) % 2 != idx % 2:
+ continue
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ else:
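+ # per-instance indices 8..11: a two-stage (zero12345_2stage_5pred) prediction is the target view; valid_list below selects the 4 usable anchor views, whose poses start at offset 40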
+ idx = idx % 12 - 8 # [0, 3]
+ valid_list = [0, 2, 3, 5]
+ idx = valid_list[idx] # one of the usable anchor views: 0, 2, 3 or 5
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_5pred/", folder_id, uid, f'view_0_{idx}_0.png')
+
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
+ depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ # print("depth_h", depth_h.shape)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(gt_view_idx * 12, (gt_view_idx + 1) * 12)
+
+
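+ # candidate source views: the 6*4 two-stage predictions, filtered to anchors in valid_list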
+ src_views = range(40+6, 40+6+24)
+ src_views_used = []
+ for vid in src_views:
+ view_idx_to_use = (vid - 40 - 6) // 4
+ if view_idx_to_use not in valid_list:
+ continue
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_5pred/", folder_id, uid, f'view_0_{idx}_{(vid-46) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ # print("img shape1: ", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img shape2: ", img.shape)
+ imgs += [img]
+ depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("img numeber: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % 12] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow" + "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage" + "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py b/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..248e9f9591b95a711406b0e1efb3568e05e2414a
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_6_narrow_and_6_2_stage_blend_mix.py
@@ -0,0 +1,449 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ if self.split == 'train':
+ self.imgs_per_instance = 12
+ else:
+ self.imgs_per_instance = 16
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 4*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original=idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
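+ # note: these comparisons use the raw dataset index rather than idx % 16, and remap a few training samples onto other reference views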
+ if self.split == 'train':
+ if idx == 4:
+ idx = 5
+ elif idx == 5:
+ idx = 7
+ elif idx == 10:
+ idx = 13
+ elif idx == 11:
+ idx = 15
+
+ if idx % 16 < 8: # narrow image as target
+ idx = idx % 16 # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+ else:
+ idx = idx % 16 - 8 # [0, 7]
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+ depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
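+ # for training, randomly decide per anchor view whether its source images come from the narrow predictions or from the two-stage predictions; for val/test the choice follows the reference-view type (the checks below use the raw origin_idx)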
+ if_use_narrow = []
+ if self.split == 'train':
+ for i in range(8):
+ if np.random.random() > 0.5:
+ if_use_narrow.append(True) # use narrow
+ else:
+ if_use_narrow.append(False) # 2-stage prediction
+ if_use_narrow[origin_idx % 8] = True if origin_idx < 8 else False
+ else:
+ for i in range(8):
+ if_use_narrow.append( True if origin_idx < 8 else False)
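+ # candidate source views: the 8*4 predictions; anchor views 4 and 6 are skipped below, so 6 anchors contribute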
+ src_views = range(8, 8 + 8 * 4)
+ src_views_used = []
+ for vid in src_views:
+ if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6):
+ continue
+ src_views_used.append(vid)
+ cur_view_id = (vid - 8) // 4
+ # choose narrow
+ if if_use_narrow[cur_view_id]:
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png')
+ else: # choose 2-stage
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{(vid - 8) // 4}_{(vid-8) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow" + "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage" + "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1fd371e5fc7be9685b81efa3d607018b2a9bdb1
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_2_stage.py
@@ -0,0 +1,396 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+
+ self.imgs_per_instance = 8
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original=idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
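+ # the reference view defines the world frame, so its relative w2c below is the identity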
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+
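+ # all 8*4 two-stage predictions serve as source views (no filtering in this loader)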
+ src_views = range(8, 8+32)
+ src_views_used = []
+ for vid in src_views:
+ view_idx_to_use = (vid - 8) // 4 # unused here: every prediction is kept
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-8) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+ depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py b/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1072d6a3e02f1908add474963aa6c6acaf69055
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_4_gt.py
@@ -0,0 +1,396 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+
+ self.imgs_per_instance = 8
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original=idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+
+ src_views = range(8, 8+32)
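+        # source views 8..39: the 4 auxiliary narrow-branch images attached to each of the 8 anchor views (8 x 4 = 32)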
+ src_views_used = []
+ for vid in src_views:
+            view_idx_to_use = (vid - 8) // 4
+            src_views_used.append(vid)
+            img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{view_idx_to_use}_{vid % 4}_10_gt.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+            depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
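+        # scale_mat re-centres and re-scales the scene so the reconstructed region fits a unit sphere;
+        # scale_factor (= 1 / radius) is applied to the depth maps so they match the normalized space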
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
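+            # after normalization the object lies inside a unit sphere, so near/far are bounded by the camera-to-origin distance ± 1 (with a 5% margin)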
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
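+        # start_idx controls whether the query (reference) view stays in the source set: kept for training, dropped for val/test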
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
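+        # val/test: rays for every pixel of the query view; train: N_rays randomly sampled rays (importance-sampled inside the mask when enabled)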
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa97eb6ca99c254548e501f2e05d883f2b015e1c
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_3_views.py
@@ -0,0 +1,446 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
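+    # cv2.decomposeProjectionMatrix returns the world-to-camera rotation and the homogeneous camera centre,
+    # so R.T gives the camera-to-world rotation and t[:3] / t[3] the camera position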
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 16
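+        # 16 target views per object: indices 0-7 are narrow-branch renderings (with depth), indices 8-15 are two-stage predictions (no depth)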
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+        self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+        return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+        idx_original = idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ if idx % 16 < 8: # narrow image as target
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+            depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+ else:
+            idx = idx % self.imgs_per_instance - 8  # [0, 7]
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png')
+
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+            depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
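+        # per anchor view, decide whether its source images come from the narrow branch or the two-stage branch:
+        # random during training (the anchor matching the target view follows the target's own branch), fixed to the target's branch otherwise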
+ if_use_narrow = []
+ if self.split == 'train':
+ for i in range(8):
+ if np.random.random() > 0.5:
+ if_use_narrow.append(True) # use narrow
+ else:
+ if_use_narrow.append(False) # 2-stage prediction
+            if_use_narrow[origin_idx % 8] = (origin_idx % 16) < 8
+        else:
+            for i in range(8):
+                if_use_narrow.append((origin_idx % 16) < 8)
+
+ src_views = list()
+ for i in range(8):
+            # pick 3 of the 4 predicted sub-views per anchor
+            # (the random choice below is immediately overridden; the first three sub-views are always used)
+            local_idxs = np.random.choice(4, 3, replace=False)
+            local_idxs = [0, 1, 2]
+ local_idxs = [8+i*4+local_idx for local_idx in local_idxs]
+ src_views += local_idxs
+ src_views_used = []
+ for vid in src_views:
+ src_views_used.append(vid)
+ cur_view_id = (vid - 8) // 4
+ # choose narrow
+ if if_use_narrow[cur_view_id]:
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png')
+ else: # choose 2-stage
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{(vid - 8) // 4}_{(vid-8) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("img numeber: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow"+ "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
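+
+
+# Minimal usage sketch (assumed mount point and split name; not part of the training pipeline):
+#   dataset = BlenderPerView(root_dir='/objaverse-processed/zero12345_img/', split='val',
+#                            img_wh=(256, 256), N_rays=512, clean_image=True, importance_sample=True)
+#   sample = dataset[0]
+#   print(sample['images'].shape, sample['meta'])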
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..740bb81125a297fc1d504f4c119c7f9a76630507
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_blend_mix.py
@@ -0,0 +1,439 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 16
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+        return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+        idx_original = idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ if idx % 16 < 8: # gt image as target
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+            depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+ else:
+ idx = idx % self.imgs_per_instance - 8 # [0, 7]
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png')
+
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+ if_use_narrow = []
+ if self.split == 'train':
+ for i in range(8):
+ if np.random.random() > 0.5:
+ if_use_narrow.append(True) # use narrow
+ else:
+ if_use_narrow.append(False) # 2-stage prediction
+            if_use_narrow[origin_idx % 8] = (origin_idx % 16) < 8
+        else:
+            for i in range(8):
+                if_use_narrow.append((origin_idx % 16) < 8)
+ src_views = range(8, 8 + 8 * 4)
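+        # use all 32 auxiliary views (4 per anchor) as sources; if_use_narrow decides per anchor which branch supplies the image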
+ src_views_used = []
+ for vid in src_views:
+ src_views_used.append(vid)
+ cur_view_id = (vid - 8) // 4 # [0, 7]
+ # choose narrow
+ if if_use_narrow[cur_view_id]:
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{cur_view_id}_{vid%4}_10.png')
+ else: # choose 2-stage
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{cur_view_id}_{(vid) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("img numeber: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow"+ "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d860e521935b529c4240a0299d892ff90f683b2
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_narrow_and_8_2_stage_mix.py
@@ -0,0 +1,470 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+ self.imgs_per_instance = 16
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+        idx_original = idx
+
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ if idx % self.imgs_per_instance < 8:
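+            # indices 0-7: narrow-branch target view with a real depth map; its sources are the 32 narrow auxiliary views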
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+            depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+
+ src_views = range(8, 8 + 8 * 4)
+ src_views_used = []
+ for vid in src_views:
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ else:
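+            # indices 8-15: two-stage predicted target view (depth filled with -1); source poses are taken from the two-stage entries 48..79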
+            idx = idx % self.imgs_per_instance - 8  # [0, 7]
+
+ c2w = self.c2ws[idx + 40]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_0.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ # depth_h = torch.fill((img.shape[1], img.shape[2]), -1.0)
+ # print("depth_h", depth_h.shape)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+
+ src_views = range(40+8, 40+8+32)
+ src_views_used = []
+ for vid in src_views:
+                view_idx_to_use = (vid - 40 - 8) // 4  # anchor index of this source view (currently unused)
+
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{idx}_{(vid-48) % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ # print("img shape1: ", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img shape2: ", img.shape)
+ imgs += [img]
+            depth_h = torch.full((img.shape[1], img.shape[2]), -1.0, dtype=torch.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ # print("img numeber: ", len(imgs))
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ if view_ids[0] < 8:
+ meta_end = "_narrow"+ "_refview" + str(view_ids[0])
+ else:
+ meta_end = "_two_stage"+ "_refview" + str(view_ids[0] - 8)
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..9609f20a733486544347d7fec78ae16bf1b9e2a3
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_8_wide_from_2_stage.py
@@ -0,0 +1,395 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+
+ self.imgs_per_instance = 8
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/random32_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path_narrow = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path_narrow, 'r') as f:
+ narrow_meta = json.load(f)
+
+ pose_json_path_two_stage = "/objaverse-processed/zero12345_img/zero12345_2stage_8_pose.json"
+ with open(pose_json_path_two_stage, 'r') as f:
+ two_stage_meta = json.load(f)
+
+
+ self.img_ids = list(narrow_meta["c2ws"].keys()) + list(two_stage_meta["c2ws"].keys()) # (8 + 8*4) + (8 + 8*4)
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(narrow_meta["c2ws"].values()) + list(two_stage_meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ assert narrow_meta["intrinsics"] == two_stage_meta["intrinsics"], "intrinsics not equal"
+ intrinsic[:3, :3] = np.array(narrow_meta["intrinsics"])
+ self.intrinsic = intrinsic
+ assert narrow_meta["near_far"] == two_stage_meta["near_far"], "near_far not equal"
+ self.near_far = np.array(narrow_meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+
+
+
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return self.imgs_per_instance * len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+ idx_original = idx
+
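+ # each object contributes `imgs_per_instance` (8) samples: idx // imgs_per_instance picks the object,
+ # idx % imgs_per_instance picks the reference view within it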
+ folder_uid_dict = self.lvis_paths[idx//self.imgs_per_instance]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+ idx = idx % self.imgs_per_instance # [0, 7]
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
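+ # depth is stored as a 16-bit png in millimeters; convert to meters, then replace the per-pixel
+ # z-depth by the Euclidean ray length so it is consistent with distances along camera rays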
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+
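+ # the 8 two-stage predicted views act as source views; they carry no ground-truth depth,
+ # so placeholder depths (-1) and all-ones masks are appended for them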
+ src_views = range(0, 8)
+ src_views_used = []
+ for vid in src_views:
+ src_views_used.append(vid)
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_2stage_8/", folder_id, uid, f'view_0_{vid}_0.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
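+ # re-normalize all cameras into the unit-sphere space defined by scale_mat:
+ # project with P = K @ w2c @ scale_mat, then recover the normalized pose via load_K_Rt_from_P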
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_depths_h.append(depth * scale_factor)
+
+
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx_original % self.imgs_per_instance] + src_views_used
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ meta_end = "_two_stage" + "_refview" + str(view_ids[0])
+ sample['meta'] = str(folder_id) + "_" + str(uid) + meta_end
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py b/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..bacd68d0d8cc7b578bf546e4484590f985920051
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_4_1_eval_new_data.py
@@ -0,0 +1,418 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the +0.5 above places ray origins at pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R from decomposeProjectionMatrix is world-to-camera; transpose it to get the cam2world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ self.downSample = downSample # stored so read_mask() can use it
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+ assert self.split in ['val', 'export_mesh'], 'only support val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+ # return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ # idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
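+ # idx // 8 selects the shape folder, idx % 8 selects the reference view within it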
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for image_idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
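+ # use the 4 stage-2 predictions associated with reference view `idx` as source views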
+ # src_views = range(8, 8 + 8 * 4)
+ src_views = range(8+idx*4, 8+(idx+1)*4)
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
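+ # re-derive every pose in the normalized (unit-sphere) space defined by scale_mat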
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
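+ # also export the w2cs of all 8 canonical views, normalized the same way, as candidate
+ # target poses (stored below as sample['target_candidate_w2cs'])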
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_6.py b/SparseNeuS_demo_v1/data/blender_general_narrow_6.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8333986bb15b3e3fd495f1ee4600e22ef93246
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_6.py
@@ -0,0 +1,399 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the +0.5 above places ray origins at pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R from decomposeProjectionMatrix is world-to-camera; transpose it to get the cam2world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ self.downSample = downSample # stored so read_mask() can use it
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ if self.split == 'train':
+ return 6*len(self.lvis_paths)
+ else:
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
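+ # training uses only 6 of the 8 reference views (views 4 and 6 are skipped), hence the index
+ # remapping below; validation keeps all 8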
+ if self.split == 'train':
+ folder_uid_dict = self.lvis_paths[idx//6]
+ idx = idx % 6 # [0, 5]
+ if idx == 4:
+ idx = 5
+ elif idx == 5:
+ idx = 7
+ else:
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
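+ # skip the source-view groups belonging to reference views 4 and 6, matching the 6-view training split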
+ for vid in src_views:
+ if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6):
+ continue
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+ # print("len(imges)", len(imgs))
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c26348e73b44fdcb33bad81b1fddba66efeffc
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_fixed.py
@@ -0,0 +1,393 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the +0.5 above places ray origins at pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R from decomposeProjectionMatrix is world-to-camera; transpose it to get the cam2world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ self.downSample = downSample # stored so read_mask() can use it
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
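+ # from each of the 8 groups of stage-2 predictions, take the fixed subset [0, 2, 3]
+ # (3 of the 4 predicted views), giving 24 source views in total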
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = list()
+ for i in range(8):
+ # randomly choose 3 different number from [0,3]
+ # local_idxs = np.random.choice(4, 3, replace=False)
+ local_idxs = [0, 2, 3]
+ # local_idxs = np.random.choice(4, 3, replace=False)
+
+ local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs]
+ src_views += local_idxs
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ # print("len(imgs)", len(imgs))
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52542595e8d39dff91f18e63a0b504c4c4d2d48
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_3_random.py
@@ -0,0 +1,395 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the +0.5 above places ray origins at pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R from decomposeProjectionMatrix is world-to-camera; transpose it to get the cam2world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ self.downSample = downSample # stored so read_mask() can use it
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+ depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
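+ # per group of 4 stage-2 predictions, randomly sample 3 views during training and use the
+ # fixed subset [0, 2, 3] at validation (24 source views either way)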
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = list()
+ for i in range(8):
+
+ if self.split == 'train':
+ local_idxs = np.random.choice(4, 3, replace=False)
+ else:
+ local_idxs = [0, 2, 3]
+ # local_idxs = np.random.choice(4, 3, replace=False)
+
+ local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs]
+ src_views += local_idxs
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ # print("len(imgs)", len(imgs))
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py b/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py
new file mode 100644
index 0000000000000000000000000000000000000000..e120367ce96847e9fb60b2ae038a812583fe75e3
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_8_4_random_shading.py
@@ -0,0 +1,432 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+    Inputs:
+        H, W: image height and width
+        focal: (fx, fy) focal lengths in pixels
+        center: optional principal point (cx, cy); defaults to the image center
+    Outputs:
+        directions: (H, W, 3), per-pixel ray directions in camera coordinates
+                    (not normalized; the z component is 1)
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the meshgrid above centers the rays on pixel centers
+    # (the original NeRF convention omits this offset, see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
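+# Illustrative check (hypothetical values, not taken from the original pipeline): for a
+# 256x256 image with fx = fy = 280 px, the center pixel should map to a ray direction of
+# approximately (0, 0, 1) in camera coordinates:
+#   dirs = get_ray_directions(256, 256, [280.0, 280.0])
+#   assert torch.allclose(dirs[128, 128], torch.tensor([0.0, 0.0, 1.0]), atol=1e-2)
+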
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # decomposeProjectionMatrix returns the world-to-camera rotation, so transpose to get camera-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
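+# Minimal sanity check added for illustration (the 280-px focal length and the camera pose
+# below are hypothetical, not values used by this dataset): build P = K @ [R|t] from a
+# known camera-to-world pose and confirm load_K_Rt_from_P recovers both factors, which is
+# why the rotation returned by decomposeProjectionMatrix is transposed above.
+def _check_load_K_Rt_from_P():
+    K = np.eye(4, dtype=np.float32)
+    K[0, 0] = K[1, 1] = 280.0   # fx = fy
+    K[0, 2] = K[1, 2] = 128.0   # principal point at the center of a 256x256 image
+    c2w = np.eye(4, dtype=np.float32)
+    c2w[:3, 3] = [0.0, 0.0, 1.5]  # camera 1.5 units along +z
+    P = (K @ np.linalg.inv(c2w))[:3, :4]  # 3x4 world-to-pixel projection
+    K_rec, pose_rec = load_K_Rt_from_P(None, P)
+    assert np.allclose(pose_rec, c2w, atol=1e-4)
+    assert np.allclose(K_rec[:3, :3], K[:3, :3], atol=1e-3)
+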
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+                'both img_wh dimensions must be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
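+    # Worked example with hypothetical numbers: if get_boundingbox reports a scene center
+    # of (0.0, 0.0, 0.2) and radius 0.5, then with factor=1.1 the returned scale_mat is
+    # diag(0.55, 0.55, 0.55, 1.0) with translation (0.0, 0.0, 0.2), i.e. it maps the unit
+    # sphere of the normalized frame onto the padded scene bounding sphere; the second
+    # return value, 1/0.55, is what __getitem__ uses to rescale metric depths into that
+    # normalized frame.
+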
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+    def read_depth(self, filename, near_bound, noisy_factor=1.0):
+        # placeholder; note that this definition overrides the read_depth stub above
+        pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
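+        # The stored map is z-depth along the optical axis; multiplying the (z = 1) ray
+        # directions by it and taking the norm converts it to Euclidean distance from the
+        # camera center, presumably to match the ray-distance parameterization used by the
+        # renderer.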
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
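+        # View-index layout used by this loader: ids 0-7 are the 8 base views and ids 8-39
+        # are their source views, 4 per base view; id v maps to view_{(v-8)//4}_{v%4}_10.png
+        # below. This variant uses all 32 source views.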
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
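+            # After applying scale_mat the object is assumed to fit inside the unit sphere
+            # at the origin, so conservative near/far bounds are the camera distance -/+ 1,
+            # padded by 5% below.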
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ if self.split == 'train':
+ # randomly select one view from eight views as reference view
+ idx_to_select = np.random.randint(0, 8)
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx_to_select}.png')
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs[0] = img
+
+ w2c_selected = self.all_extrinsics[idx_to_select] @ w2c_ref_inv
+ P = self.all_intrinsics[idx_to_select] @ w2c_selected @ scale_mat
+ P = P[:3, :4]
+
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = self.all_intrinsics[idx_to_select][:3, :3] @ w2c[:3, :4]
+ new_affine_mats[0] = affine_mat
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+ new_near_fars[0] = [0.95 * near, 1.05 * far]
+
+ new_w2cs[0] = w2c
+ new_c2ws[0] = c2w
+
+            depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx_to_select}_depth_mm.png')
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance * scale_factor
+
+ new_depths_h[0] = depth_h
+ masks_h[0] = mask_h
+
+
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..50b85d133707e83b36d926b7acf1cb121dd4d04d
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all.py
@@ -0,0 +1,386 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+    Inputs:
+        H, W: image height and width
+        focal: (fx, fy) focal lengths in pixels
+        center: optional principal point (cx, cy); defaults to the image center
+    Outputs:
+        directions: (H, W, 3), per-pixel ray directions in camera coordinates
+                    (not normalized; the z component is 1)
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the meshgrid above centers the rays on pixel centers
+    # (the original NeRF convention omits this offset, see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # decomposeProjectionMatrix returns the world-to-camera rotation, so transpose to get camera-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+                'both img_wh dimensions must be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+    def read_depth(self, filename, near_bound, noisy_factor=1.0):
+        # placeholder; note that this definition overrides the read_depth stub above
+        pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
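+            # affine_mat stacks K @ [R|t] into a 4x4 matrix: it projects homogeneous points
+            # in the normalized world frame directly to pixel coordinates (presumably used
+            # downstream to warp image features when building the cost volume).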
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b832beccd85c8a0be98edf95f0d244c1cbf8b17
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage.py
@@ -0,0 +1,410 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+    Inputs:
+        H, W: image height and width
+        focal: (fx, fy) focal lengths in pixels
+        center: optional principal point (cx, cy); defaults to the image center
+    Outputs:
+        directions: (H, W, 3), per-pixel ray directions in camera coordinates
+                    (not normalized; the z component is 1)
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the meshgrid above centers the rays on pixel centers
+    # (the original NeRF convention omits this offset, see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # decomposeProjectionMatrix returns the world-to-camera rotation, so transpose to get camera-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+                'both img_wh dimensions must be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+    def read_depth(self, filename, near_bound, noisy_factor=1.0):
+        # placeholder; note that this definition overrides the read_depth stub above
+        pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+ # print("depth_h", depth_h.shape)
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
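+            # Unlike the narrow-view loaders above, this 2-stage eval variant reads source
+            # images named view_0_{group}_{k}.png with k in 1..4 (presumably second-stage
+            # predictions conditioned on view 0).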
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{(vid - 8) // 4}_{vid % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
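+        # Also export the poses of the 8 base views, transformed into the same normalized
+        # frame as the source views, as candidate target poses for the 2-stage evaluation.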
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c2dbebd00ed9e0293c26029c97ab77b7880fcf0
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_2_stage_temp.py
@@ -0,0 +1,411 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+    Inputs:
+        H, W: image height and width
+        focal: (fx, fy) focal lengths in pixels
+        center: optional principal point (cx, cy); defaults to the image center
+    Outputs:
+        directions: (H, W, 3), per-pixel ray directions in camera coordinates
+                    (not normalized; the z component is 1)
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the meshgrid above centers the rays on pixel centers
+    # (the original NeRF convention omits this offset, see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # decomposeProjectionMatrix returns the world-to-camera rotation, so transpose to get camera-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+                'both img_wh dimensions must be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 10
+
+
+    def read_depth(self, filename, near_bound, noisy_factor=1.0):
+        # placeholder; note that this definition overrides the read_depth stub above
+        pass
+
+
+ def __getitem__(self, idx):
+ idx = idx * 8
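+        # temp/debug variant: __len__ is hard-coded to 10 and idx is multiplied by 8, so
+        # only the first 10 objects are visited and the reference is always base view 0
+        # (idx % 8 == 0 below).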
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join("/objaverse-processed/zero12345_img/zero12345_narrow/", folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+ # print("img_pre", img.shape)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ # print("img", img.shape)
+ imgs += [img]
+
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+ # print("depth_h", depth_h.shape)
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_0_{(vid - 8) // 4}_{vid % 4 + 1}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..194cf007f54d2d377ce6561050f82e38dc246e73
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py
@@ -0,0 +1,418 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the grid above centers the rays on pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
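+    # decomposeProjectionMatrix returns K, the world-to-camera rotation R and the
+    # camera center t in homogeneous coordinates, hence R.T and t[:3] / t[3] below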
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+        assert self.split in ('val', 'export_mesh'), 'only supports val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = [""] # os.listdir(main_folder) # MODIFIED
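+        # NOTE: with [""] the specific dataset folder itself is treated as the single shape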
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
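+        # each object folder provides 8 first-stage views; idx now selects the reference view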
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+        for image_idx, img_id in enumerate(self.img_ids):
+            pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
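+        # all poses below are expressed relative to the reference view, so the reference
+        # camera sits at the origin of the (pre-scaling) world frame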
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8, 8 + 8 * 4)
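+        # indices 8..39 index the second-stage (stage2_8) images, 4 per first-stage view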
+
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce059be019a360b193c526c358057ffc9b48d1a
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data3_1.py
@@ -0,0 +1,414 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the grid above centers the rays on pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ self.specific_dataset_name = 'Objaverse'
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+        assert self.split in ('val', 'export_mesh'), 'only supports val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+        for image_idx, img_id in enumerate(self.img_ids):
+            pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+ # print(self.img_ids)
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
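+            # every 4th source view falls back to the corresponding first-stage
+            # (stage1_8) image; the remaining views are read from stage2_8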
+ if vid % 4 == 0:
+ vid = (vid - 8) // 4
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[vid]}')
+ else:
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
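+            # scale_mat normalizes the scene to (roughly) a unit sphere at the origin,
+            # so near/far can be bounded by the camera-to-origin distance +/- 1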
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69ece26bdd88955bf5612f2f6f66ae7f9262e19
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_32_wide.py
@@ -0,0 +1,465 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+def calc_pose(phis, thetas, size, radius = 1.2):
+ def normalize(vectors):
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
+ # device = torch.device('cuda')
+ thetas = torch.FloatTensor(thetas)
+ phis = torch.FloatTensor(phis)
+
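+    # place the camera centers on a sphere of the given radius from the (phi, theta)
+    # angles, then build a look-at camera frame starting from up = +z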
+ centers = torch.stack([
+ radius * torch.sin(thetas) * torch.sin(phis),
+ -radius * torch.cos(thetas) * torch.sin(phis),
+ radius * torch.cos(phis),
+ ], dim=-1) # [B, 3]
+
+ # lookat
+ forward_vector = normalize(centers).squeeze(0)
+ up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
+ right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
+ if right_vector.pow(2).sum() < 0.01:
+ right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1)
+ up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
+
+ poses = torch.eye(4, dtype=torch.float)[:3].unsqueeze(0).repeat(size, 1, 1)
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+ poses[:, :3, 3] = centers
+ return poses
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the grid above centers the rays on pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+        assert self.split in ('val', 'export_mesh'), 'only supports val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid in range(self.input_poses.shape[0]):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ # pose_json_path = os.path.join(folder_path, "pose.json")
+ # with open(pose_json_path, 'r') as f:
+ # meta = json.load(f)
+
+ # self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ # self.img_wh = (256, 256)
+ # self.input_poses = np.array(list(meta["c2ws"].values()))
+ # intrinsic = np.eye(4)
+ # intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ # self.intrinsic = intrinsic
+ # self.near_far = np.array(meta["near_far"])
+ # self.near_far[1] = 1.8
+ # self.define_transforms()
+ # self.blender2opencv = np.array(
+ # [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ # )
+
+ pose_file = os.path.join(folder_path, '32_random', 'views.npz')
+ pose_array = np.load(pose_file)
+ pose = calc_pose(pose_array['elevations'], pose_array['azimuths'], 32) # [32, 3, 4] c2ws
+
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(pose)
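+        # pad the [32, 3, 4] look-at poses to homogeneous [32, 4, 4] matrices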
+ self.input_poses = np.concatenate([self.input_poses, np.tile(np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :], [self.input_poses.shape[0], 1, 1])], axis=1)
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+        for image_idx in range(self.input_poses.shape[0]):
+            pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, '32_random', f'{idx}.png')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(0, 8 * 4)
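+        # use the 32 images in '32_random' (one file per pose index) as source views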
+
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, '32_random', f'{vid}.png')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py
new file mode 100644
index 0000000000000000000000000000000000000000..6263a9ff47edc8f7b65600786c244fafb809240b
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_4_4.py
@@ -0,0 +1,419 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+    # note: the +0.5 added to the grid above centers the rays on pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # ? why need transpose here
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+        assert self.split in ('val', 'export_mesh'), 'only supports val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+        for image_idx, img_id in enumerate(self.img_ids):
+            pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
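+            # keep only source views whose group index (vid // 4) is even,
+            # i.e. 4 of the 8 groups of second-stage views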
+ if (vid // 4) % 2 != 0:
+ continue
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88c0d9b37402f970d9b2d7686b774943366e9a8
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_6_4.py
@@ -0,0 +1,420 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # note: +0.5 is added above so that rays pass through pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R is the world-to-camera rotation, so its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+ assert self.split in ('val', 'export_mesh'), 'only support val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
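+ # estimate a bounding sphere (center, radius) of the region covered by the views,
+ # then return a similarity transform scale_mat together with the inverse scale 1/radius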
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for image_idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
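+ # all poses below are expressed relative to the reference view, so the reference camera sits at the identity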
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8, 8 + 8 * 4)
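+ # indices 0-7 are the stage-1 views; indices 8-39 are the stage-2 images (4 per stage-1 view);
+ # the loop below skips groups 4 and 6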
+
+ for vid in src_views:
+ if ((vid - 8) // 4 == 4) or ((vid - 8) // 4 == 6):
+ continue
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
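+ # scale_mat roughly normalizes the scene into a unit sphere; scale_factor (= 1/radius) rescales the depths below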
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
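+ # assuming the normalized object fits in a unit sphere at the origin,
+ # near/far bracket the camera distance by -/+ 1 with a 5% margin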
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
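+ # for val/test the query (reference) view is dropped from the source views further below, hence start_idx = 1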
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
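+ # re-derive the 8 stage-1 view poses in the normalized reference frame; exported below as 'target_candidate_w2cs'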
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py
new file mode 100644
index 0000000000000000000000000000000000000000..512c3db02edc8e68208167b7d1715f1f67025cdf
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_3.py
@@ -0,0 +1,428 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # note: +0.5 is added above so that rays pass through pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
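+ # t from decomposeProjectionMatrix is the homogeneous camera center, hence t[:3] / t[3] below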
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R is the world-to-camera rotation, so its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+ assert self.split in ('val', 'export_mesh'), 'only support val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for image_idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ # src_views = range(8, 8 + 8 * 4)
+
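+ # '8_3' setting: keep 3 of the 4 stage-2 predictions for each of the 8 stage-1 views (24 source views in total)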
+ src_views = list()
+ for i in range(8):
+ # pick 3 of the 4 stage-2 predictions for this stage-1 view (fixed to [0, 2, 3]; random choice left commented out)
+ # local_idxs = np.random.choice(4, 3, replace=False)
+ local_idxs = [0, 2, 3]
+
+ local_idxs = [8 + i * 4 + local_idx for local_idx in local_idxs]
+ src_views += local_idxs
+
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c1a23183a388175c2212bf552fb15ae385737ab
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_8_wide.py
@@ -0,0 +1,420 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # note: +0.5 is added above so that rays pass through pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R is the world-to-camera rotation, so its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ # self.specific_dataset_name = 'Zero123'
+
+ self.specific_dataset_name = specific_dataset_name
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+ assert self.split in ('val', 'export_mesh'), 'only support val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ['barrel_render']
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
+ pose_json_path = os.path.join(folder_path, "pose.json")
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for image_idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8)
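+ # 'wide' setting: the 8 stage-1 views themselves are used as source views (no stage-2 predictions)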
+
+
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ # img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[vid]}')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2c7f6b2306cca93f476c2c233956e4cff0dcfb
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data_temp.py
@@ -0,0 +1,417 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # note: +0.5 is added above so that rays pass through pixel centers
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # R is the world-to-camera rotation, so its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[],
+ specific_dataset_name = 'GSO'
+ ):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+ # self.specific_dataset_name = 'Realfusion'
+ # self.specific_dataset_name = 'GSO'
+ # self.specific_dataset_name = 'Objaverse'
+ self.specific_dataset_name = 'Objaverse_archived'
+
+ # self.specific_dataset_name = specific_dataset_name
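+ # note: the specific_dataset_name argument is ignored here; the hard-coded value above is used instead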
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+ assert self.split in ('val', 'export_mesh'), 'only support val or export_mesh'
+ # find all subfolders
+ main_folder = os.path.join(root_dir, self.specific_dataset_name)
+ self.shape_list = os.listdir(main_folder)
+ self.shape_list.sort()
+
+ # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED
+
+
+ self.lvis_paths = []
+ for shape_name in self.shape_list:
+ self.lvis_paths.append(os.path.join(main_folder, shape_name))
+
+ # print("lvis_paths: ", self.lvis_paths)
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ # return 8*len(self.lvis_paths)
+ return len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ idx = idx * 8 # to be deleted
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views
+
+ folder_path = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+
+ # last subdir name
+ shape_name = os.path.split(folder_path)[-1]
+
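+ # this variant reads camera poses from a single shared json instead of a per-shape pose.json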
+ pose_json_path = os.path.join('/objaverse-processed/zero12345_img/zero12345_narrow_pose.json')
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for image_idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ # img_filename = os.path.join(folder_path, 'stage1_8_debug', f'{self.img_ids[idx]}')
+ img_filename = os.path.join(folder_path, 'stage1_8', f'{idx}.png')
+
+ img = Image.open(img_filename)
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+ mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+
+ src_views = range(8, 8 + 8 * 4)
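+ # stage-2 images are stored as '<stage1_view>_<prediction>.png', hence the (vid-8)//4 and (vid-8)%4 indexing below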
+
+ for vid in src_views:
+
+ # img_filename = os.path.join(folder_path, 'stage2_8_debug', f'{self.img_ids[vid]}')
+ img_filename = os.path.join(folder_path, 'stage2_8', f'{(vid-8)//4}_{(vid-8)%4}.png')
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+
+
+ target_w2cs = []
+ target_intrinsics = []
+ new_target_w2cs = []
+ for i_idx in range(8):
+ target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
+ target_intrinsics.append(self.all_intrinsics[i_idx])
+
+ for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_target_w2cs.append(w2c)
+ target_w2cs = np.stack(new_target_w2cs)
+
+
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = shape_name
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
+ # print("meta: ", sample['meta'])
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..33a4ecf7de541049e3b89cc98f74106b59d418c7
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_no_depth.py
@@ -0,0 +1,388 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get cam-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ # directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ # surface_points = directions * depth_h[..., None] # [H, W, 3]
+ # distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ # depth_h = distance
+
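+        # this "no depth" variant discards the loaded depth and fills it with -1, so depth
+        # supervision is effectively disabled while the object mask (depth > 0) is kept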
+ depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
+ depth_h = depth_h.fill_(-1.0)
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
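+        # source views 8..39 index the 4 auxiliary images per input view
+        # (the view_{input}_{pred}_10.png files loaded below)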
+ src_views = range(8, 8 + 8 * 4)
+
+ for vid in src_views:
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
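+            # approximate near/far from the camera-to-origin distance (the scene is
+            # normalized into a unit sphere), with a small safety margin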
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
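+        # split the stacked views into a query (reference) view and source views; for val/test
+        # the query view is excluded from the sources (start_idx = 1), for training it is kept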
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
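+        #   full-image rays for val/test; N_rays randomly sampled rays for training
+        #   (with optional importance sampling around the foreground)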
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py
new file mode 100644
index 0000000000000000000000000000000000000000..f811326da45563ae870350f78ccdbe358411f3b6
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4.py
@@ -0,0 +1,389 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get cam-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 4*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ idx = idx * 2
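+        # only the 4 even-indexed target views (0, 2, 4, 6) of each object are used,
+        # matching __len__ = 4 * len(self.lvis_paths)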
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(8, 8 + 8 * 4)
+
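+        # keep only the auxiliary views whose parent input view has an even index,
+        # i.e. 4 of the 8 groups ("only_4"): 4 groups x 4 predictions = 16 source views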
+ for vid in src_views:
+ if (vid // 4) % 2 != 0:
+ continue
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+ # print("len(imgs)", len(imgs))
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+        view_ids = [idx] + [vid for vid in src_views if (vid // 4) % 2 == 0]  # only the source views loaded above
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py
new file mode 100644
index 0000000000000000000000000000000000000000..76b9fccad69f6929e086074b55807ef5a0a17eee
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_only_4_and_4.py
@@ -0,0 +1,395 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get cam-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/lvis_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+
+ self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10"
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(list(meta["c2ws"].values()))
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+ for idx, img_id in enumerate(self.img_ids):
+ pose = self.input_poses[idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid, img_id in enumerate(self.img_ids):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 8*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//8]
+ idx = idx % 8 # [0, 7]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}.png')
+
+        depth_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+ # print("valid pixels", np.sum(mask_h))
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+
+ src_views = range(8, 8 + 8 * 4)
+
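+        # alternate between the two halves of the input views: only auxiliary views whose
+        # parent input view matches the parity of the target index are kept (4 groups x 4 = 16)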
+ vid_list = []
+ for vid in src_views:
+ if (vid // 4) % 2 != idx % 2:
+ continue
+ vid_list.append(vid)
+ img_filename = os.path.join(self.root_dir, folder_id, uid, f'view_{(vid - 8) // 4}_{vid%4}_10.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # print("idx:", idx)
+ # print("len(imgs)", len(imgs))
+ # print("vid_list", vid_list)
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+        view_ids = [idx] + vid_list  # only the source views actually loaded above
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/blender_gt_32.py b/SparseNeuS_demo_v1/data/blender_gt_32.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec6f0075febfcd46061e61ae10cd68b05dfb5fc
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/blender_gt_32.py
@@ -0,0 +1,419 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+import json
+from termcolor import colored
+import imageio
+from kornia import create_meshgrid
+import open3d as o3d
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2
+
+ i, j = grid.unbind(-1)
+ # the direction here is without +0.5 pixel centering as calibration is not so accurate
+ # see https://github.com/bmild/nerf/issues/24
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
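+# calc_pose builds look-at camera-to-world matrices from the stored elevation/azimuth
+# angles on a sphere of the given radius (used below to recover the 32 random GT views)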
+def calc_pose(phis, thetas, size, radius=1.2):
+    def normalize(vectors):
+        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
+ thetas = torch.FloatTensor(thetas)
+ phis = torch.FloatTensor(phis)
+
+ centers = torch.stack([
+ radius * torch.sin(thetas) * torch.sin(phis),
+ -radius * torch.cos(thetas) * torch.sin(phis),
+ radius * torch.cos(phis),
+ ], dim=-1) # [B, 3]
+
+ # lookat
+ forward_vector = normalize(centers).squeeze(0)
+ up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
+ right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
+ if right_vector.pow(2).sum() < 0.01:
+ right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1)
+ up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
+
+ poses = torch.eye(4, dtype=torch.float)[:3].unsqueeze(0).repeat(size, 1, 1)
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+ poses[:, :3, 3] = centers
+ return poses
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+    pose[:3, :3] = R.transpose()  # R from decomposeProjectionMatrix is world-to-camera, so transpose it to get cam-to-world
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class BlenderPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ # print("root_dir: ", root_dir)
+ self.root_dir = root_dir
+ self.split = split
+
+ self.n_views = n_views
+ self.N_rays = N_rays
+        self.batch_size = batch_size  # - used to construct new metas for GRU fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ lvis_json_path = '/objaverse-processed/zero12345_img/random32_split.json' # folder_id and uid
+ with open(lvis_json_path, 'r') as f:
+ lvis_paths = json.load(f)
+ if self.split == 'train':
+ self.lvis_paths = lvis_paths['train']
+ else:
+ self.lvis_paths = lvis_paths['val']
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'img_wh must both be multiples of 32!'
+
+ pose_json_path = "/objaverse-processed/zero12345_img/zero12345_narrow_pose.json"
+
+ with open(pose_json_path, 'r') as f:
+ meta = json.load(f)
+ intrinsic = np.eye(4)
+ intrinsic[:3, :3] = np.array(meta["intrinsics"])
+ self.intrinsic = intrinsic
+ self.near_far = np.array(meta["near_far"])
+ self.near_far[1] = 1.8
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)
+
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+
+
+ def load_cam_info(self):
+ for vid in range(self.input_poses.shape[0]):
+ intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ pass
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+
+ center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ # print("center", center)
+ # print("radius", radius)
+ # print("bounds", bounds)
+ # import ipdb; ipdb.set_trace()
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return 32*len(self.lvis_paths)
+
+
+ def read_depth(self, filename, near_bound, noisy_factor=1.0):
+ pass
+
+
+ def __getitem__(self, idx):
+ sample = {}
+ origin_idx = idx
+ imgs, depths_h, masks_h = [], [], [] # full size (256, 256)
+ intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj mats between views
+
+
+ folder_uid_dict = self.lvis_paths[idx//32]
+        idx = idx % 32  # [0, 31]
+ folder_id = folder_uid_dict['folder_id']
+ uid = folder_uid_dict['uid']
+
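+        # unlike the narrow-pose datasets, the 32 GT camera poses are stored per object and
+        # reconstructed here on the fly from the saved elevation/azimuth angles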
+ pose_file = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, 'views.npz')
+ pose_array = np.load(pose_file)
+ pose = calc_pose(pose_array['elevations'], pose_array['azimuths'], 32) # [32, 3, 4] c2ws
+
+ self.img_wh = (256, 256)
+ self.input_poses = np.array(pose)
+ self.input_poses = np.concatenate([self.input_poses, np.tile(np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :], [self.input_poses.shape[0], 1, 1])], axis=1)
+ self.define_transforms()
+ self.blender2opencv = np.array(
+ [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
+ )
+
+ self.c2ws = []
+ self.w2cs = []
+ self.near_fars = []
+ # self.root_dir = root_dir
+        for image_idx in range(pose.shape[0]):
+            pose = self.input_poses[image_idx]
+ c2w = pose @ self.blender2opencv
+ self.c2ws.append(c2w)
+ self.w2cs.append(np.linalg.inv(c2w))
+ self.near_fars.append(self.near_far)
+ self.c2ws = np.stack(self.c2ws, axis=0)
+ self.w2cs = np.stack(self.w2cs, axis=0)
+
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+ self.load_cam_info()
+
+
+
+ # target view
+ c2w = self.c2ws[idx]
+ w2c = np.linalg.inv(c2w)
+ w2c_ref = w2c
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ w2cs.append(w2c @ w2c_ref_inv)
+ c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))
+
+ img_filename = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{idx}.png')
+
+        depth_filename = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{idx}_depth_mm.png')
+
+
+ img = Image.open(img_filename)
+
+ img = self.transform(img) # (4, h, w)
+
+
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+ imgs += [img]
+
+ depth_h = cv2.imread(depth_filename, cv2.IMREAD_UNCHANGED).astype(np.uint16) / 1000.0
+ mask_h = depth_h > 0
+
+ directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsic[0, 0], self.intrinsic[1, 1]]) # [H, W, 3]
+ surface_points = directions * depth_h[..., None] # [H, W, 3]
+ distance = np.linalg.norm(surface_points, axis=-1) # [H, W]
+ depth_h = distance
+
+
+ depths_h.append(depth_h)
+ masks_h.append(mask_h)
+
+ intrinsic = self.intrinsic
+ intrinsics.append(intrinsic)
+
+
+ near_fars.append(self.near_fars[idx])
+ image_perm = 0 # only supervised on reference view
+
+ mask_dilated = None
+
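+        # all 32 ground-truth renderings of the object (including the target view) are used
+        # as source views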
+ # src_views = range(8+idx*4, 8+(idx+1)*4)
+ src_views = range(0, 8 * 4)
+
+ for vid in src_views:
+ img_filename = os.path.join('/objaverse-processed/zero12345_img/random32/', folder_id, uid, f'{vid}.png')
+
+ img = Image.open(img_filename)
+ img_wh = self.img_wh
+
+ img = self.transform(img)
+ if img.shape[0] == 4:
+ img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB
+
+ imgs += [img]
+ depth_h = np.ones(img.shape[1:], dtype=np.float32)
+ depths_h.append(depth_h)
+ masks_h.append(np.ones(img.shape[1:], dtype=np.int32))
+
+ near_fars.append(self.all_near_fars[vid])
+ intrinsics.append(self.all_intrinsics[vid])
+
+ w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)
+
+
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(
+ img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1
+ )
+
+
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ # print(new_near_fars)
+ imgs = torch.stack(imgs).float()
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if self.split == 'train':
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ view_ids = [idx] + list(src_views)
+ sample['origin_idx'] = origin_idx
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ # sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = folder_id
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(img_wh))
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = self.partial_vol_origin
+ sample['meta'] = str(folder_id) + "_" + str(uid) + "_refview" + str(view_ids[0])
+
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt b/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bd0d79868f196991c06ec2a496dbe06e5ded0fd2
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/dtu/dtu_pairs.txt
@@ -0,0 +1,93 @@
+46
+0
+10 10 2346.410000 1 2036.530000 9 1243.890000 12 1052.870000 11 1000.840000 13 703.583000 2 604.456000 8 439.759000 14 327.419000 27 249.278000
+1
+10 9 2850.870000 10 2583.940000 2 2105.590000 0 2052.840000 8 1868.240000 13 1184.230000 14 1017.510000 12 961.966000 7 670.208000 15 657.218000
+2
+10 8 2501.240000 1 2106.880000 7 1856.500000 9 1782.340000 3 1141.770000 15 1061.760000 14 815.457000 16 762.153000 6 709.789000 10 699.921000
+3
+10 7 1294.390000 6 1159.130000 2 1134.270000 4 905.717000 8 687.320000 5 600.015000 17 496.958000 16 481.969000 1 379.011000 15 307.450000
+4
+10 5 1333.740000 6 1145.150000 3 895.254000 7 486.504000 18 446.420000 2 418.517000 17 326.528000 8 161.115000 16 149.154000 1 103.626000
+5
+10 6 1676.060000 18 1555.060000 4 1335.550000 17 868.416000 3 593.755000 7 467.816000 20 440.579000 19 428.255000 16 242.327000 21 210.253000
+6
+10 17 2332.350000 7 1848.240000 18 1812.740000 5 1696.070000 16 1273.000000 3 1157.990000 4 1155.410000 20 771.624000 21 744.945000 2 700.368000
+7
+10 16 2709.460000 8 2439.700000 15 2078.210000 6 1864.160000 2 1846.600000 17 1791.710000 3 1296.860000 22 957.793000 9 879.088000 21 782.277000
+8
+10 15 3124.010000 9 3099.920000 14 2756.290000 2 2501.220000 7 2449.320000 1 1875.940000 16 1726.040000 13 1325.760000 23 1177.090000 24 1108.820000
+9
+10 13 3355.620000 14 3226.070000 8 3098.800000 10 3097.070000 1 2861.420000 12 1873.630000 2 1785.980000 15 1753.320000 25 1365.450000 0 1261.590000
+10
+10 12 3750.700000 9 3085.870000 13 3028.390000 1 2590.550000 0 2369.790000 11 2266.670000 14 1524.160000 26 1448.150000 27 1293.600000 8 1041.840000
+11
+10 12 3543.760000 27 3056.050000 10 2248.070000 26 1524.280000 28 1273.330000 13 1265.900000 29 1129.550000 0 998.164000 9 591.176000 30 572.919000
+12
+10 27 3889.870000 10 3754.540000 13 3745.210000 11 3584.260000 26 3574.560000 25 1877.110000 9 1866.340000 29 1482.720000 30 1418.510000 14 1341.860000
+13
+10 12 3773.140000 26 3699.280000 25 3657.170000 14 3652.040000 9 3356.290000 10 3049.270000 24 2098.910000 27 1900.960000 31 1460.960000 30 1349.620000
+14
+10 13 3663.520000 24 3610.690000 9 3232.550000 25 3216.400000 15 3128.840000 8 2758.040000 23 2219.910000 26 1567.450000 10 1536.600000 32 1419.330000
+15
+10 23 3194.920000 14 3126.000000 8 3120.430000 16 2897.020000 24 2562.490000 7 2084.050000 22 2041.630000 9 1752.080000 33 1232.290000 13 1137.550000
+16
+10 15 2884.140000 7 2713.880000 22 2708.570000 17 2448.500000 21 2173.300000 23 1908.030000 8 1718.790000 6 1281.960000 35 1047.380000 34 980.064000
+17
+10 21 2632.480000 16 2428.000000 6 2343.570000 18 2250.230000 20 2149.750000 7 1779.420000 22 1380.250000 36 957.046000 5 878.398000 15 789.068000
+18
+9 17 2219.150000 20 2173.020000 6 1802.390000 19 1575.770000 5 1564.810000 21 1160.130000 16 660.317000 7 589.484000 36 559.983000
+19
+7 20 1828.970000 18 1564.630000 17 685.249000 36 613.420000 21 572.770000 5 427.597000 6 368.651000
+20
+8 21 2569.790000 36 2258.330000 18 2186.710000 17 2130.670000 19 1865.060000 35 996.122000 16 799.808000 40 778.721000
+21
+9 36 2704.590000 35 2639.690000 17 2638.190000 20 2605.430000 22 2604.260000 16 2158.250000 34 1239.250000 18 1178.240000 40 1128.570000
+22
+10 23 3232.680000 34 3175.150000 35 2831.090000 16 2712.510000 21 2632.190000 15 2033.390000 33 1712.670000 17 1393.860000 36 1290.960000 24 1195.330000
+23
+10 24 3710.900000 33 3603.070000 22 3244.200000 15 3190.620000 34 3086.490000 14 2220.110000 32 2100.000000 16 1917.100000 35 1359.790000 25 1356.710000
+24
+10 25 3844.600000 32 3750.750000 23 3710.600000 14 3609.090000 33 3091.040000 15 2559.240000 31 2423.710000 13 2109.360000 26 1440.580000 34 1410.030000
+25
+10 26 3951.740000 31 3888.570000 24 3833.070000 13 3667.350000 14 3208.210000 32 2993.460000 30 2681.520000 12 1900.230000 45 1484.030000 27 1462.880000
+26
+10 30 4033.350000 27 3970.470000 25 3925.250000 13 3686.340000 12 3595.590000 29 2943.870000 31 2917.000000 14 1556.340000 11 1554.750000 46 1503.840000
+27
+10 29 4027.840000 26 3929.940000 12 3875.580000 11 3085.030000 28 2908.600000 30 2792.670000 13 1878.420000 25 1438.550000 47 1425.200000 10 1290.250000
+28
+10 29 3687.020000 48 3209.130000 27 2872.860000 47 2014.530000 30 1361.950000 11 1273.600000 26 1062.850000 12 840.841000 46 672.985000 31 271.952000
+29
+10 27 4029.430000 30 3909.550000 28 3739.930000 47 3695.230000 48 3135.870000 26 2910.970000 46 2229.550000 12 1479.160000 31 1430.260000 11 1144.560000
+30
+10 26 4029.860000 29 3953.720000 31 3811.120000 46 3630.460000 47 3105.960000 27 2824.430000 25 2657.890000 45 2347.750000 32 1459.110000 12 1429.620000
+31
+10 25 3882.210000 30 3841.880000 32 3808.500000 45 3649.820000 46 3000.670000 26 2939.940000 24 2409.930000 44 2381.300000 13 1467.590000 29 1459.560000
+32
+10 31 3826.500000 24 3744.140000 33 3613.240000 44 3552.040000 25 3004.600000 45 2884.590000 43 2393.340000 23 2095.270000 30 1478.600000 14 1420.780000
+33
+10 32 3618.110000 23 3598.100000 34 3530.530000 43 3462.370000 24 3091.530000 44 2608.080000 42 2426.000000 22 1717.940000 31 1407.650000 25 1324.780000
+34
+10 33 3523.370000 42 3356.550000 35 3210.340000 22 3178.850000 23 3079.030000 43 2396.450000 41 2386.860000 24 1408.020000 32 1301.340000 21 1256.450000
+35
+10 34 3187.880000 41 3106.440000 36 2866.040000 22 2817.740000 21 2654.870000 40 2416.980000 42 2137.810000 23 1346.860000 33 1150.330000 16 1044.660000
+36
+8 40 2910.700000 35 2832.660000 21 2689.960000 20 2280.460000 41 1787.970000 22 1268.490000 34 981.636000 17 954.229000
+40
+7 36 2918.140000 41 2852.620000 35 2392.960000 21 1124.300000 42 1056.480000 34 877.946000 20 788.701000
+41
+9 35 3111.050000 42 3049.710000 40 2885.360000 34 2371.020000 36 1813.690000 43 1164.710000 22 1126.900000 21 906.536000 33 903.238000
+42
+10 34 3356.980000 43 3183.000000 41 3070.540000 33 2421.770000 35 2155.080000 44 1278.410000 23 1183.520000 22 1147.070000 40 1077.080000 32 899.646000
+43
+10 33 3461.240000 44 3380.740000 42 3188.700000 34 2400.600000 32 2399.090000 45 1359.370000 23 1314.080000 41 1176.120000 24 1159.620000 31 901.556000
+44
+10 32 3550.810000 45 3510.160000 43 3373.110000 33 2602.330000 31 2395.930000 24 1410.430000 46 1386.310000 42 1279.000000 25 1095.240000 34 968.440000
+45
+10 31 3650.090000 46 3555.090000 44 3491.150000 32 2868.390000 30 2373.590000 25 1485.370000 47 1405.280000 43 1349.540000 33 1104.770000 26 1046.810000
+46
+10 30 3635.640000 47 3562.170000 45 3524.170000 31 2976.820000 29 2264.040000 26 1508.870000 44 1367.410000 48 1352.100000 32 1211.240000 25 1102.170000
+47
+10 29 3705.310000 46 3519.760000 48 3450.480000 30 3074.770000 28 2054.630000 27 1434.570000 45 1377.340000 31 1268.230000 26 1223.830000 25 471.111000
+48
+10 47 3401.950000 28 3224.840000 29 3101.160000 46 1317.100000 30 1306.700000 27 1235.070000 26 537.731000 31 291.919000 45 276.869000 11 258.856000
diff --git a/SparseNeuS_demo_v1/data/dtu/lists/test.txt b/SparseNeuS_demo_v1/data/dtu/lists/test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b1420254bbe0fe15e9ad9358cdbaedf34605a558
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/dtu/lists/test.txt
@@ -0,0 +1,15 @@
+scan24
+scan37
+scan40
+scan55
+scan63
+scan65
+scan69
+scan83
+scan97
+scan105
+scan106
+scan110
+scan114
+scan118
+scan122
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/data/dtu/lists/train.txt b/SparseNeuS_demo_v1/data/dtu/lists/train.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4259e846edcee621baf19875e2900e169849f5e3
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/dtu/lists/train.txt
@@ -0,0 +1,75 @@
+scan1
+scan4
+scan5
+scan6
+scan8
+scan9
+scan10
+scan11
+scan12
+scan13
+scan14
+scan15
+scan16
+scan17
+scan18
+scan19
+scan20
+scan21
+scan22
+scan23
+scan28
+scan29
+scan30
+scan31
+scan32
+scan33
+scan34
+scan35
+scan36
+scan38
+scan39
+scan41
+scan42
+scan43
+scan44
+scan45
+scan46
+scan47
+scan48
+scan49
+scan50
+scan51
+scan52
+scan59
+scan60
+scan61
+scan62
+scan64
+scan74
+scan75
+scan76
+scan77
+scan84
+scan85
+scan86
+scan87
+scan88
+scan89
+scan90
+scan91
+scan92
+scan93
+scan94
+scan95
+scan96
+scan98
+scan99
+scan100
+scan101
+scan102
+scan103
+scan104
+scan126
+scan127
+scan128
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/data/dtu_fit.py b/SparseNeuS_demo_v1/data/dtu_fit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4a97d28b635a9158c49e2a651c7799ad1009021
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/dtu_fit.py
@@ -0,0 +1,278 @@
+import torch
+import torch.nn as nn
+import cv2 as cv
+import numpy as np
+import re
+import os
+import logging
+from glob import glob
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+
+from data.scene import get_boundingbox
+
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # decomposeProjectionMatrix returns the world-to-camera rotation; its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+class DtuFit:
+ def __init__(self, root_dir, split, scan_id, n_views, train_img_idx=[], test_img_idx=[],
+ img_wh=[800, 600], clip_wh=[0, 0], original_img_wh=[1600, 1200],
+ N_rays=512, h_patch_size=5, near=425, far=900):
+ super(DtuFit, self).__init__()
+ logging.info('Load data: Begin')
+
+ self.root_dir = root_dir
+ self.split = split
+ self.scan_id = scan_id
+ self.n_views = n_views
+
+ self.near = near
+ self.far = far
+
+ if self.scan_id is not None:
+ self.data_dir = os.path.join(self.root_dir, self.scan_id)
+ else:
+ self.data_dir = self.root_dir
+
+ self.img_wh = img_wh
+ self.clip_wh = clip_wh
+
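+ # a 2-element clip_wh = [clip_w, clip_h] is expanded below to four values: [left, top, right, bottom]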
+ if len(self.clip_wh) == 2:
+ self.clip_wh = self.clip_wh + self.clip_wh
+
+ self.original_img_wh = original_img_wh
+ self.N_rays = N_rays
+ self.h_patch_size = h_patch_size # used to extract patch for supervision
+ self.train_img_idx = train_img_idx
+ self.test_img_idx = test_img_idx
+
+ camera_dict = np.load(os.path.join(self.data_dir, 'cameras.npz'), allow_pickle=True)
+ self.images_list = sorted(glob(os.path.join(self.data_dir, "image/*.png")))
+ # world_mat: projection matrix: world to image
+ self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in
+ range(len(self.images_list))]
+
+ self.raw_near_fars = np.stack([np.array([self.near, self.far]) for i in range(len(self.images_list))])
+
+ # - reference image; transform the world system to the ref-camera system
+ self.ref_img_idx = self.train_img_idx[0]
+ ref_world_mat = self.world_mats_np[self.ref_img_idx]
+ self.ref_w2c = np.linalg.inv(load_K_Rt_from_P(None, ref_world_mat[:3, :4])[1])
+
+ self.all_images = []
+ self.all_intrinsics = []
+ self.all_w2cs = []
+
+ self.load_scene() # load the scene
+
+ # ! estimate scale_mat
+ self.scale_mat, self.scale_factor = self.cal_scale_mat(
+ img_hw=[self.img_wh[1], self.img_wh[0]],
+ intrinsics=self.all_intrinsics[self.train_img_idx],
+ extrinsics=self.all_w2cs[self.train_img_idx],
+ near_fars=self.raw_near_fars[self.train_img_idx],
+ factor=1.1)
+
+ # * after scaling and translation, unit bounding box
+ self.scaled_intrinsics, self.scaled_w2cs, self.scaled_c2ws, \
+ self.scaled_affine_mats, self.scaled_near_fars = self.scale_cam_info()
+ # import ipdb; ipdb.set_trace()
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+ self.partial_vol_origin = torch.Tensor([-1., -1., -1.])
+
+ logging.info('Load data: End')
+
+ def load_scene(self):
+
+ scale_x = self.img_wh[0] / self.original_img_wh[0]
+ scale_y = self.img_wh[1] / self.original_img_wh[1]
+
+ for idx in range(len(self.images_list)):
+ image = cv.imread(self.images_list[idx])
+ image = cv.resize(image, (self.img_wh[0], self.img_wh[1])) / 255.
+
+ image = image[self.clip_wh[1]:self.img_wh[1] - self.clip_wh[3],
+ self.clip_wh[0]:self.img_wh[0] - self.clip_wh[2]]
+ self.all_images.append(np.transpose(image[:, :, ::-1], (2, 0, 1))) # BGR -> RGB, then HWC -> CHW: (3, H, W)
+
+ P = self.world_mats_np[idx]
+ P = P[:3, :4]
+ intrinsics, c2w = load_K_Rt_from_P(None, P)
+ w2c = np.linalg.inv(c2w)
+
+ intrinsics[:1] *= scale_x
+ intrinsics[1:2] *= scale_y
+
+ intrinsics[0, 2] -= self.clip_wh[0]
+ intrinsics[1, 2] -= self.clip_wh[1]
+
+ self.all_intrinsics.append(intrinsics)
+ # - transform from world system to ref-camera system
+ self.all_w2cs.append(w2c @ np.linalg.inv(self.ref_w2c))
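+ # note: right-multiplying by the reference camera's c2w makes the reference camera frame the new world frame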
+
+
+ self.all_images = torch.from_numpy(np.stack(self.all_images)).to(torch.float32)
+ self.all_intrinsics = torch.from_numpy(np.stack(self.all_intrinsics)).to(torch.float32)
+ self.all_w2cs = torch.from_numpy(np.stack(self.all_w2cs)).to(torch.float32)
+ self.img_wh = [self.img_wh[0] - self.clip_wh[0] - self.clip_wh[2],
+ self.img_wh[1] - self.clip_wh[1] - self.clip_wh[3]]
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
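+ # returns a 4x4 similarity transform mapping the normalized (unit-sphere) frame back to the original scene,
+ # plus scale_factor = 1 / radius of the (enlarged) bounding sphere of all view frusta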
+ center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def scale_cam_info(self):
+ new_intrinsics = []
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ for idx in range(len(self.all_images)):
+ intrinsics = self.all_intrinsics[idx]
+ P = intrinsics @ self.all_w2cs[idx] @ self.scale_mat
+ P = P.cpu().numpy()[:3, :4]
+
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ new_intrinsics.append(intrinsics)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsics[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
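+ # the normalized scene lies inside the unit sphere, so valid depths fall within dist ± 1 (a 5% margin is added below)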
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+
+ new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars = \
+ np.stack(new_intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), \
+ np.stack(new_affine_mats), np.stack(new_near_fars)
+
+ new_intrinsics = torch.from_numpy(np.float32(new_intrinsics))
+ new_w2cs = torch.from_numpy(np.float32(new_w2cs))
+ new_c2ws = torch.from_numpy(np.float32(new_c2ws))
+ new_affine_mats = torch.from_numpy(np.float32(new_affine_mats))
+ new_near_fars = torch.from_numpy(np.float32(new_near_fars))
+
+ return new_intrinsics, new_w2cs, new_c2ws, new_affine_mats, new_near_fars
+
+
+ def get_conditional_sample(self):
+ sample = {}
+ support_idxs = self.train_img_idx
+
+ sample['images'] = self.all_images[support_idxs] # (V, 3, H, W)
+ sample['w2cs'] = self.scaled_w2cs[self.train_img_idx] # (V, 4, 4)
+ sample['c2ws'] = self.scaled_c2ws[self.train_img_idx] # (V, 4, 4)
+ sample['near_fars'] = self.scaled_near_fars[self.train_img_idx] # (V, 2)
+ sample['intrinsics'] = self.scaled_intrinsics[self.train_img_idx][:, :3, :3] # (V, 3, 3)
+ sample['affine_mats'] = self.scaled_affine_mats[self.train_img_idx] # ! in world space
+
+ sample['scan'] = self.scan_id
+ sample['scale_factor'] = torch.tensor(self.scale_factor)
+ sample['scale_mat'] = torch.from_numpy(self.scale_mat)
+ sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c))
+ sample['img_wh'] = torch.from_numpy(np.array(self.img_wh))
+ sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32)
+
+ return sample
+
+ def __len__(self):
+ if self.split == 'train':
+ return self.n_views * 1000
+ else:
+ return len(self.test_img_idx) * 1000
+
+ def __getitem__(self, idx):
+ sample = {}
+
+ if self.split == 'train':
+ render_idx = self.train_img_idx[idx % self.n_views]
+ support_idxs = [idx for idx in self.train_img_idx if idx != render_idx]
+ else:
+ # render_idx = idx % self.n_test_images + self.n_train_images
+ render_idx = self.test_img_idx[idx % len(self.test_img_idx)]
+ support_idxs = [render_idx]
+
+ sample['images'] = self.all_images[support_idxs] # (V, 3, H, W)
+ sample['w2cs'] = self.scaled_w2cs[support_idxs] # (V, 4, 4)
+ sample['c2ws'] = self.scaled_c2ws[support_idxs] # (V, 4, 4)
+ sample['intrinsics'] = self.scaled_intrinsics[support_idxs][:, :3, :3] # (V, 3, 3)
+ sample['affine_mats'] = self.scaled_affine_mats[support_idxs] # ! in world space
+ sample['scan'] = self.scan_id
+ sample['scale_factor'] = torch.tensor(self.scale_factor)
+ sample['img_wh'] = torch.from_numpy(np.array(self.img_wh))
+ sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32)
+ sample['img_index'] = torch.tensor(render_idx)
+
+ # - query image
+ sample['query_image'] = self.all_images[render_idx]
+ sample['query_c2w'] = self.scaled_c2ws[render_idx]
+ sample['query_w2c'] = self.scaled_w2cs[render_idx]
+ sample['query_intrinsic'] = self.scaled_intrinsics[render_idx]
+ sample['query_near_far'] = self.scaled_near_fars[render_idx]
+ sample['meta'] = str(self.scan_id) + "_" + os.path.basename(self.images_list[render_idx])
+ sample['scale_mat'] = torch.from_numpy(self.scale_mat)
+ sample['trans_mat'] = torch.from_numpy(np.linalg.inv(self.ref_w2c))
+ sample['rendering_c2ws'] = self.scaled_c2ws[self.test_img_idx]
+ sample['rendering_imgs_idx'] = torch.Tensor(np.array(self.test_img_idx).astype(np.int32))
+
+ # - generate rays
+ if self.split == 'val' or self.split == 'test':
+ sample_rays = gen_rays_from_single_image(
+ self.img_wh[1], self.img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=None,
+ mask=None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ self.img_wh[1], self.img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=None,
+ mask=None,
+ dilated_mask=None,
+ importance_sample=False,
+ h_patch_size=self.h_patch_size
+ )
+
+ sample['rays'] = sample_rays
+
+ return sample
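+
+# Usage sketch (hypothetical root_dir and view indices, shown only for illustration):
+#   dataset = DtuFit(root_dir='/path/to/dtu', split='test', scan_id='scan110',
+#                    n_views=3, train_img_idx=[23, 24, 33], test_img_idx=[22, 25, 34])
+#   cond = dataset.get_conditional_sample()  # support (conditioning) views of the scene
+#   sample = dataset[0]                      # query view with cameras and generated rays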
diff --git a/SparseNeuS_demo_v1/data/dtu_general.py b/SparseNeuS_demo_v1/data/dtu_general.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6c7734df6072dd618ccdde71ca428f983a605e8
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/dtu_general.py
@@ -0,0 +1,376 @@
+from torch.utils.data import Dataset
+from utils.misc_utils import read_pfm
+import os
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms as T
+from data.scene import get_boundingbox
+
+from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
+
+from termcolor import colored
+import pdb
+import random
+
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv2.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # decomposeProjectionMatrix returns the world-to-camera rotation; its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+# ! load one ref-image with multiple src-images in camera coordinate system
+class MVSDatasetDtuPerView(Dataset):
+ def __init__(self, root_dir, split, n_views=3, img_wh=(640, 512), downSample=1.0,
+ split_filepath=None, pair_filepath=None,
+ N_rays=512,
+ vol_dims=[128, 128, 128], batch_size=1,
+ clean_image=False, importance_sample=False, test_ref_views=[]):
+
+ self.root_dir = root_dir
+ self.split = split
+
+ self.img_wh = img_wh
+ self.downSample = downSample
+ self.num_all_imgs = 49 # this preprocessed DTU dataset has 49 images
+ self.n_views = n_views
+ self.N_rays = N_rays
+ self.batch_size = batch_size # - used for construct new metas for gru fusion training
+
+ self.clean_image = clean_image
+ self.importance_sample = importance_sample
+ self.test_ref_views = test_ref_views # used for testing
+ self.scale_factor = 1.0
+ self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
+
+ if img_wh is not None:
+ assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
+ 'both dimensions of img_wh must be multiples of 32!'
+
+ self.split_filepath = f'data/dtu/lists/{self.split}.txt' if split_filepath is None else split_filepath
+ self.pair_filepath = f'data/dtu/dtu_pairs.txt' if pair_filepath is None else pair_filepath
+
+ print(colored("loading all scenes together", 'red'))
+ with open(self.split_filepath) as f:
+ self.scans = [line.rstrip() for line in f.readlines()]
+
+ self.all_intrinsics = [] # the cam info of the whole scene
+ self.all_extrinsics = []
+ self.all_near_fars = []
+
+ self.metas, self.ref_src_pairs = self.build_metas() # load ref-srcs view pairs info of the scene
+
+ self.allview_ids = [i for i in range(self.num_all_imgs)]
+
+ self.load_cam_info() # load camera info of DTU, and estimate scale_mat
+
+ self.build_remap()
+ self.define_transforms()
+ print(f'==> image down scale: {self.downSample}')
+
+ # * bounding box for rendering
+ self.bbox_min = np.array([-1.0, -1.0, -1.0])
+ self.bbox_max = np.array([1.0, 1.0, 1.0])
+
+ # - used for cost volume regularization
+ self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
+ self.partial_vol_origin = torch.Tensor([-1., -1., -1.])
+
+ def build_remap(self):
+ self.remap = np.zeros(np.max(self.allview_ids) + 1).astype('int')
+ for i, item in enumerate(self.allview_ids):
+ self.remap[item] = i
+
+ def define_transforms(self):
+ self.transform = T.Compose([T.ToTensor()])
+
+ def build_metas(self):
+ metas = []
+ ref_src_pairs = {}
+ # light conditions 0-6 for training
+ # light condition 3 for testing (the brightest?)
+ light_idxs = [3] if 'train' not in self.split else range(7)
+
+ with open(self.pair_filepath) as f:
+ num_viewpoint = int(f.readline())
+ # viewpoints (49)
+ for _ in range(num_viewpoint):
+ ref_view = int(f.readline().rstrip())
+ src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
+
+ ref_src_pairs[ref_view] = src_views
+
+ for light_idx in light_idxs:
+ for scan in self.scans:
+ with open(self.pair_filepath) as f:
+ num_viewpoint = int(f.readline())
+ # viewpoints (49)
+ for _ in range(num_viewpoint):
+ ref_view = int(f.readline().rstrip())
+ src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
+
+ # ! only for validation
+ if len(self.test_ref_views) > 0 and ref_view not in self.test_ref_views:
+ continue
+
+ metas += [(scan, light_idx, ref_view, src_views)]
+
+ return metas, ref_src_pairs
+
+ def read_cam_file(self, filename):
+ with open(filename) as f:
+ lines = [line.rstrip() for line in f.readlines()]
+ # extrinsics: line [1,5), 4x4 matrix
+ extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
+ extrinsics = extrinsics.reshape((4, 4))
+ # intrinsics: line [7, 10), 3x3 matrix
+ intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
+ intrinsics = intrinsics.reshape((3, 3))
+ # depth_min & depth_interval: line 11
+ depth_min = float(lines[11].split()[0])
+ depth_max = depth_min + float(lines[11].split()[1]) * 192
+ self.depth_interval = float(lines[11].split()[1])
+ intrinsics_ = np.float32(np.diag([1, 1, 1, 1]))
+ intrinsics_[:3, :3] = intrinsics
+ return intrinsics_, extrinsics, [depth_min, depth_max]
+
+ def load_cam_info(self):
+ for vid in range(self.num_all_imgs):
+ proj_mat_filename = os.path.join(self.root_dir,
+ f'Cameras/train/{vid:08d}_cam.txt')
+ intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
+ intrinsic[:2] *= 4 # * the provided intrinsics is 4x downsampled, now keep the same scale with image
+ self.all_intrinsics.append(intrinsic)
+ self.all_extrinsics.append(extrinsic)
+ self.all_near_fars.append(near_far)
+
+ def read_depth(self, filename):
+ # import ipdb; ipdb.set_trace()
+ depth_h = np.array(read_pfm(filename)[0], dtype=np.float32) # (1200, 1600)
+ depth_h = np.ones((1200, 1600))
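+ # NOTE: the PFM depth loaded above is immediately overwritten with a constant map of ones, so the ground-truth depth values are not actually used here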
+ # print(depth_h.shape)
+ depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
+ interpolation=cv2.INTER_NEAREST) # (600, 800)
+ depth_h = depth_h[44:556, 80:720] # (512, 640)
+ # print(depth_h.shape)
+ # import ipdb; ipdb.set_trace()
+ depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
+ interpolation=cv2.INTER_NEAREST)
+
+ return depth, depth_h
+
+ def read_mask(self, filename):
+ mask_h = cv2.imread(filename, 0)
+ mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
+ interpolation=cv2.INTER_NEAREST)
+ mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
+ interpolation=cv2.INTER_NEAREST)
+
+ mask[mask > 0] = 1 # the masks stored in png are not binary
+ mask_h[mask_h > 0] = 1
+
+ return mask, mask_h
+
+ def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):
+ center, radius, _ = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)
+ radius = radius * factor
+ scale_mat = np.diag([radius, radius, radius, 1.0])
+ scale_mat[:3, 3] = center.cpu().numpy()
+ scale_mat = scale_mat.astype(np.float32)
+
+ return scale_mat, 1. / radius.cpu().numpy()
+
+ def __len__(self):
+ return len(self.metas)
+
+ def __getitem__(self, idx):
+ sample = {}
+ scan, light_idx, ref_view, src_views = self.metas[idx % len(self.metas)]
+
+ # generalized, load some images at once
+ view_ids = [ref_view] + src_views[:self.n_views]
+ # * transform from world system to camera system
+ w2c_ref = self.all_extrinsics[self.remap[ref_view]]
+ w2c_ref_inv = np.linalg.inv(w2c_ref)
+
+ image_perm = 0 # only supervised on reference view
+
+ imgs, depths_h, masks_h = [], [], [] # full size (640, 512)
+ intrinsics, w2cs, near_fars = [], [], [] # record proj mats between views
+ mask_dilated = None
+ for i, vid in enumerate(view_ids):
+ # NOTE that the id in image file names is from 1 to 49 (not 0~48)
+ img_filename = os.path.join(self.root_dir,
+ f'Rectified/{scan}_train/rect_{vid + 1:03d}_{light_idx}_r5000.png')
+ depth_filename = os.path.join(self.root_dir,
+ f'Depths/{scan}_train/depth_map_{vid:04d}.pfm')
+ # print(depth_filename)
+ mask_filename = os.path.join(self.root_dir,
+ f'Masks_clean_dilated/{scan}_train/mask_{vid:04d}.png')
+
+ img = Image.open(img_filename)
+ img_wh = np.round(np.array(img.size) * self.downSample).astype('int')
+ img = img.resize(img_wh, Image.BILINEAR)
+
+ if os.path.exists(mask_filename) and self.clean_image:
+ mask_l, mask_h = self.read_mask(mask_filename)
+ else:
+ # print(self.split, "don't find mask file", mask_filename)
+ mask_h = np.ones([img_wh[1], img_wh[0]])
+ masks_h.append(mask_h)
+
+ if i == 0:
+ kernel_size = 101 # default 101
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+ mask_dilated = np.float32(cv2.dilate(np.uint8(mask_h * 255), kernel, iterations=1) > 128)
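+ # the dilated reference-view mask is later passed to the ray sampler as dilated_mask for importance sampling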
+
+ if self.clean_image:
+ img = np.array(img)
+ img[mask_h < 0.5] = 0.0
+
+ img = self.transform(img)
+
+ imgs += [img]
+
+ index_mat = self.remap[vid]
+ near_fars.append(self.all_near_fars[index_mat])
+ intrinsics.append(self.all_intrinsics[index_mat])
+
+ w2cs.append(self.all_extrinsics[index_mat] @ w2c_ref_inv)
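+ # note: composing with w2c_ref_inv expresses every pose relative to the reference view (it becomes the world frame)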
+
+ # print(depth_filename)
+ if os.path.exists(depth_filename): # and i == 0
+ # print("file exists")
+ depth_l, depth_h = self.read_depth(depth_filename)
+ depths_h.append(depth_h)
+ # ! estimate scale_mat
+ scale_mat, scale_factor = self.cal_scale_mat(img_hw=[img_wh[1], img_wh[0]],
+ intrinsics=intrinsics, extrinsics=w2cs,
+ near_fars=near_fars, factor=1.1)
+
+ # ! calculate the new w2cs after scaling
+ new_near_fars = []
+ new_w2cs = []
+ new_c2ws = []
+ new_affine_mats = []
+ new_depths_h = []
+ for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):
+ P = intrinsic @ extrinsic @ scale_mat
+ P = P[:3, :4]
+ # - should use load_K_Rt_from_P() to obtain c2w
+ c2w = load_K_Rt_from_P(None, P)[1]
+ w2c = np.linalg.inv(c2w)
+ new_w2cs.append(w2c)
+ new_c2ws.append(c2w)
+ affine_mat = np.eye(4)
+ affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
+ new_affine_mats.append(affine_mat)
+
+ camera_o = c2w[:3, 3]
+ dist = np.sqrt(np.sum(camera_o ** 2))
+ near = dist - 1
+ far = dist + 1
+
+ new_near_fars.append([0.95 * near, 1.05 * far])
+ new_depths_h.append(depth * scale_factor)
+
+ imgs = torch.stack(imgs).float()
+ # print(new_near_fars)  # leftover debug print, disabled
+ depths_h = np.stack(new_depths_h)
+ masks_h = np.stack(masks_h)
+
+ affine_mats = np.stack(new_affine_mats)
+ intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
+ new_near_fars)
+
+ if 'train' in self.split:
+ start_idx = 0
+ else:
+ start_idx = 1
+
+ sample['images'] = imgs # (V, 3, H, W)
+ sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W)
+ sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W)
+ sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4)
+ sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4)
+ sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2)
+ sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3)
+ sample['view_ids'] = torch.from_numpy(np.array(view_ids))
+ sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space
+
+ sample['light_idx'] = torch.tensor(light_idx)
+ sample['scan'] = scan
+
+ sample['scale_factor'] = torch.tensor(scale_factor)
+ sample['img_wh'] = torch.from_numpy(img_wh)
+ sample['render_img_idx'] = torch.tensor(image_perm)
+ sample['partial_vol_origin'] = torch.tensor(self.partial_vol_origin, dtype=torch.float32)
+ sample['meta'] = str(scan) + "_light" + str(light_idx) + "_refview" + str(ref_view)
+
+ # - image to render
+ sample['query_image'] = sample['images'][0]
+ sample['query_c2w'] = sample['c2ws'][0]
+ sample['query_w2c'] = sample['w2cs'][0]
+ sample['query_intrinsic'] = sample['intrinsics'][0]
+ sample['query_depth'] = sample['depths_h'][0]
+ sample['query_mask'] = sample['masks_h'][0]
+ sample['query_near_far'] = sample['near_fars'][0]
+
+ sample['images'] = sample['images'][start_idx:] # (V, 3, H, W)
+ sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W)
+ sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W)
+ sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4)
+ sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4)
+ sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3)
+ sample['view_ids'] = sample['view_ids'][start_idx:]
+ sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space
+
+ sample['scale_mat'] = torch.from_numpy(scale_mat)
+ sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)
+
+ # - generate rays
+ if ('val' in self.split) or ('test' in self.split):
+ sample_rays = gen_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None)
+ else:
+ sample_rays = gen_random_rays_from_single_image(
+ img_wh[1], img_wh[0],
+ self.N_rays,
+ sample['query_image'],
+ sample['query_intrinsic'],
+ sample['query_c2w'],
+ depth=sample['query_depth'],
+ mask=sample['query_mask'] if self.clean_image else None,
+ dilated_mask=mask_dilated,
+ importance_sample=self.importance_sample)
+
+ sample['rays'] = sample_rays
+
+ return sample
diff --git a/SparseNeuS_demo_v1/data/scene.py b/SparseNeuS_demo_v1/data/scene.py
new file mode 100644
index 0000000000000000000000000000000000000000..49183c65418338864ecabdd1af914bbb0f055579
--- /dev/null
+++ b/SparseNeuS_demo_v1/data/scene.py
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+import pdb
+
+
+def rigid_transform(xyz, transform):
+ """Applies a rigid transform (c2w) to an (N, 3) pointcloud.
+ """
+ device = xyz.device
+ xyz_h = torch.cat([xyz, torch.ones((len(xyz), 1)).to(device)], dim=1) # (N, 4)
+ xyz_t_h = (transform @ xyz_h.T).T # apply the 4x4 transform to the homogeneous points
+
+ return xyz_t_h[:, :3]
+
+
+def get_view_frustum(min_depth, max_depth, size, cam_intr, c2w):
+ """Get corners of 3D camera view frustum of depth image
+ """
+ device = cam_intr.device
+ im_h, im_w = size
+ im_h = int(im_h)
+ im_w = int(im_w)
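+ # back-project the four image corners at the near and far depths to obtain the 8 frustum corner points in camera coordinates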
+ view_frust_pts = torch.stack([
+ (torch.tensor([0, 0, im_w, im_w, 0, 0, im_w, im_w]).to(device) - cam_intr[0, 2]) * torch.tensor(
+ [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) /
+ cam_intr[0, 0],
+ (torch.tensor([0, im_h, 0, im_h, 0, im_h, 0, im_h]).to(device) - cam_intr[1, 2]) * torch.tensor(
+ [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) /
+ cam_intr[1, 1],
+ torch.tensor([min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(
+ device)
+ ])
+ view_frust_pts = view_frust_pts.type(torch.float32)
+ c2w = c2w.type(torch.float32)
+ view_frust_pts = rigid_transform(view_frust_pts.T, c2w).T
+ return view_frust_pts
+
+
+def set_pixel_coords(h, w):
+ i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type(torch.float32) # [1, H, W]
+ j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type(torch.float32) # [1, H, W]
+ ones = torch.ones(1, h, w).type(torch.float32)
+
+ pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W]
+
+ return pixel_coords
+
+
+def get_boundingbox(img_hw, intrinsics, extrinsics, near_fars):
+ """
+ # get the minimum bounding box of all visual hulls
+ :param img_hw:
+ :param intrinsics:
+ :param extrinsics:
+ :param near_fars:
+ :return:
+ """
+
+ bnds = torch.zeros((3, 2))
+ bnds[:, 0] = np.inf
+ bnds[:, 1] = -np.inf
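+ # bnds[:, 0] / bnds[:, 1] accumulate the per-axis min / max over all view-frustum corners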
+
+ if isinstance(intrinsics, list):
+ num = len(intrinsics)
+ else:
+ num = intrinsics.shape[0]
+ # print("num: ", num)
+ view_frust_pts_list = []
+ for i in range(num):
+ if not isinstance(intrinsics[i], torch.Tensor):
+ cam_intr = torch.tensor(intrinsics[i])
+ w2c = torch.tensor(extrinsics[i])
+ c2w = torch.inverse(w2c)
+ else:
+ cam_intr = intrinsics[i]
+ w2c = extrinsics[i]
+ c2w = torch.inverse(w2c)
+ min_depth, max_depth = near_fars[i][0], near_fars[i][1]
+ # todo: check that the corresponding points are matched
+
+ view_frust_pts = get_view_frustum(min_depth, max_depth, img_hw, cam_intr, c2w)
+ bnds[:, 0] = torch.min(bnds[:, 0], torch.min(view_frust_pts, dim=1)[0])
+ bnds[:, 1] = torch.max(bnds[:, 1], torch.max(view_frust_pts, dim=1)[0])
+ view_frust_pts_list.append(view_frust_pts)
+ all_view_frust_pts = torch.cat(view_frust_pts_list, dim=1)
+
+ # print("all_view_frust_pts: ", all_view_frust_pts.shape)
+ # distance = torch.norm(all_view_frust_pts, dim=0)
+ # print("distance: ", distance)
+
+ # print("all_view_frust_pts_z: ", all_view_frust_pts[2, :])
+
+ center = torch.tensor(((bnds[0, 1] + bnds[0, 0]) / 2, (bnds[1, 1] + bnds[1, 0]) / 2,
+ (bnds[2, 1] + bnds[2, 0]) / 2))
+
+ lengths = bnds[:, 1] - bnds[:, 0]
+
+ max_length, _ = torch.max(lengths, dim=0)
+ radius = max_length / 2
+
+ # print("radius: ", radius)
+ return center, radius, bnds
diff --git a/SparseNeuS_demo_v1/evaluation/__init__.py b/SparseNeuS_demo_v1/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/evaluation/clean_mesh.py b/SparseNeuS_demo_v1/evaluation/clean_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab65cc72d3be615b71ec852a7adea933355aa250
--- /dev/null
+++ b/SparseNeuS_demo_v1/evaluation/clean_mesh.py
@@ -0,0 +1,283 @@
+import numpy as np
+import cv2 as cv
+import os
+from glob import glob
+from scipy.io import loadmat
+import trimesh
+import open3d as o3d
+import torch
+from tqdm import tqdm
+
+import sys
+
+sys.path.append("../")
+
+
+def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None):
+ """
+ generate rays in world space for a single image
+ :param H: image height
+ :param W: image width
+ :param intrinsic: [3, 3] camera intrinsics
+ :param c2w: [4, 4] camera-to-world pose
+ :return: dict of per-ray origins, directions, ndc uvs, colors and masks
+ """
+ device = image.device
+ ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
+ torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
+ p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
+
+ # normalized ndc uv coordinates, (-1, 1)
+ ndc_u = 2 * xs / (W - 1) - 1
+ ndc_v = 2 * ys / (H - 1) - 1
+ rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)
+
+ intrinsic_inv = torch.inverse(intrinsic)
+
+ p = p.view(-1, 3).float().to(device) # N_rays, 3
+ p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3
+ rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3
+ rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3
+
+ image = image.permute(1, 2, 0)
+ color = image.view(-1, 3)
+ depth = depth.view(-1, 1) if depth is not None else None
+ mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device)
+ sample = {
+ 'rays_o': rays_o,
+ 'rays_v': rays_v,
+ 'rays_ndc_uv': rays_ndc_uv,
+ 'rays_color': color,
+ # 'rays_depth': depth,
+ 'rays_mask': mask,
+ 'rays_norm_XYZ_cam': p # - camera-space XYZ at unit depth (multiply by depth to get 3D points)
+ }
+ if depth is not None:
+ sample['rays_depth'] = depth
+
+ return sample
+
+
+def load_K_Rt_from_P(filename, P=None):
+ if P is None:
+ lines = open(filename).read().splitlines()
+ if len(lines) == 4:
+ lines = lines[1:]
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
+ P = np.asarray(lines).astype(np.float32).squeeze()
+
+ out = cv.decomposeProjectionMatrix(P)
+ K = out[0]
+ R = out[1]
+ t = out[2]
+
+ K = K / K[2, 2]
+ intrinsics = np.eye(4)
+ intrinsics[:3, :3] = K
+
+ pose = np.eye(4, dtype=np.float32)
+ pose[:3, :3] = R.transpose() # decomposeProjectionMatrix returns the world-to-camera rotation; its transpose is the camera-to-world rotation
+ pose[:3, 3] = (t[:3] / t[3])[:, 0]
+
+ return intrinsics, pose # ! return cam2world matrix here
+
+
+def clean_points_by_mask(points, scan, imgs_idx=None, minimal_vis=0, mask_dilated_size=11):
+ cameras = np.load('{}/scan{}/cameras.npz'.format(DTU_DIR, scan))
+ mask_lis = sorted(glob('{}/scan{}/mask/*.png'.format(DTU_DIR, scan)))
+ n_images = 49 if scan < 83 else 64
+ inside_mask = np.zeros(len(points))
+
+ if imgs_idx is None:
+ imgs_idx = [i for i in range(n_images)]
+
+ # imgs_idx = [i for i in range(n_images)]
+ for i in imgs_idx:
+ P = cameras['world_mat_{}'.format(i)]
+ pts_image = np.matmul(P[None, :3, :3], points[:, :, None]).squeeze() + P[None, :3, 3]
+ pts_image = pts_image / pts_image[:, 2:]
+ pts_image = np.round(pts_image).astype(np.int32) + 1
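+ # the +1 offset matches the one-pixel border padded around the mask image below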
+
+ mask_image = cv.imread(mask_lis[i])
+ kernel_size = mask_dilated_size # default 101
+ kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (kernel_size, kernel_size))
+ mask_image = cv.dilate(mask_image, kernel, iterations=1)
+ mask_image = (mask_image[:, :, 0] > 128)
+
+ mask_image = np.concatenate([np.ones([1, 1600]), mask_image, np.ones([1, 1600])], axis=0)
+ mask_image = np.concatenate([np.ones([1202, 1]), mask_image, np.ones([1202, 1])], axis=1)
+
+ in_mask = (pts_image[:, 0] >= 0) * (pts_image[:, 0] <= 1600) * (pts_image[:, 1] >= 0) * (
+ pts_image[:, 1] <= 1200) > 0
+ curr_mask = mask_image[(pts_image[:, 1].clip(0, 1201), pts_image[:, 0].clip(0, 1601))]
+
+ curr_mask = curr_mask.astype(np.float32) * in_mask
+
+ inside_mask += curr_mask
+
+ return inside_mask > minimal_vis
+
+
+def clean_mesh_faces_by_mask(mesh_file, new_mesh_file, scan, imgs_idx, minimal_vis=0, mask_dilated_size=11):
+ old_mesh = trimesh.load(mesh_file)
+ old_vertices = old_mesh.vertices[:]
+ old_faces = old_mesh.faces[:]
+ mask = clean_points_by_mask(old_vertices, scan, imgs_idx, minimal_vis, mask_dilated_size)
+ indexes = np.ones(len(old_vertices)) * -1
+ indexes = indexes.astype(np.int64)
+ indexes[np.where(mask)] = np.arange(len(np.where(mask)[0]))
+
+ faces_mask = mask[old_faces[:, 0]] & mask[old_faces[:, 1]] & mask[old_faces[:, 2]]
+ new_faces = old_faces[np.where(faces_mask)]
+ new_faces[:, 0] = indexes[new_faces[:, 0]]
+ new_faces[:, 1] = indexes[new_faces[:, 1]]
+ new_faces[:, 2] = indexes[new_faces[:, 2]]
+ new_vertices = old_vertices[np.where(mask)]
+
+ new_mesh = trimesh.Trimesh(new_vertices, new_faces)
+
+ new_mesh.export(new_mesh_file)
+
+
+def clean_mesh_by_faces_num(mesh, faces_num=500):
+ old_vertices = mesh.vertices[:]
+ old_faces = mesh.faces[:]
+
+ cc = trimesh.graph.connected_components(mesh.face_adjacency, min_len=faces_num)
+ mask = np.zeros(len(mesh.faces), dtype=bool)
+ mask[np.concatenate(cc)] = True
+
+ indexes = np.ones(len(old_vertices)) * -1
+ indexes = indexes.astype(np.int64)
+ indexes[np.where(mask)] = np.arange(len(np.where(mask)[0]))
+
+ faces_mask = mask[old_faces[:, 0]] & mask[old_faces[:, 1]] & mask[old_faces[:, 2]]
+ new_faces = old_faces[np.where(faces_mask)]
+ new_faces[:, 0] = indexes[new_faces[:, 0]]
+ new_faces[:, 1] = indexes[new_faces[:, 1]]
+ new_faces[:, 2] = indexes[new_faces[:, 2]]
+ new_vertices = old_vertices[np.where(mask)]
+
+ new_mesh = trimesh.Trimesh(new_vertices, new_faces)
+
+ return new_mesh
+
+
+def clean_mesh_faces_outside_frustum(old_mesh_file, new_mesh_file, imgs_idx, H=1200, W=1600, mask_dilated_size=11,
+ isolated_face_num=500, keep_largest=True):
+ '''Remove mesh faces that cannot be observed from any of the given camera views
+ '''
+ # if path_mask_npz:
+ # path_save_clean = IOUtils.add_file_name_suffix(path_save_clean, '_mask')
+
+ cameras = np.load('{}/scan{}/cameras.npz'.format(DTU_DIR, scan))
+ mask_lis = sorted(glob('{}/scan{}/mask/*.png'.format(DTU_DIR, scan)))
+
+ mesh = trimesh.load(old_mesh_file)
+ intersector = trimesh.ray.ray_pyembree.RayMeshIntersector(mesh)
+
+ all_indices = []
+ chunk_size = 5120
+ for i in imgs_idx:
+ mask_image = cv.imread(mask_lis[i])
+ kernel_size = mask_dilated_size # default 101
+ kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (kernel_size, kernel_size))
+ mask_image = cv.dilate(mask_image, kernel, iterations=1)
+
+ P = cameras['world_mat_{}'.format(i)]
+
+ intrinsic, pose = load_K_Rt_from_P(None, P[:3, :])
+
+ rays = gen_rays_from_single_image(H, W, torch.from_numpy(mask_image).permute(2, 0, 1).float(),
+ torch.from_numpy(intrinsic)[:3, :3].float(),
+ torch.from_numpy(pose).float())
+ rays_o = rays['rays_o']
+ rays_d = rays['rays_v']
+ rays_mask = rays['rays_color']
+
+ rays_o = rays_o.split(chunk_size)
+ rays_d = rays_d.split(chunk_size)
+ rays_mask = rays_mask.split(chunk_size)
+
+ for rays_o_batch, rays_d_batch, rays_mask_batch in tqdm(zip(rays_o, rays_d, rays_mask)):
+ rays_mask_batch = rays_mask_batch[:, 0] > 128
+ rays_o_batch = rays_o_batch[rays_mask_batch]
+ rays_d_batch = rays_d_batch[rays_mask_batch]
+
+ idx_faces_hits = intersector.intersects_first(rays_o_batch.cpu().numpy(), rays_d_batch.cpu().numpy())
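+ # index of the first mesh face hit by each ray (-1 where the ray misses the mesh)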
+ all_indices.append(idx_faces_hits)
+
+ values = np.unique(np.concatenate(all_indices, axis=0))
+ mask_faces = np.ones(len(mesh.faces))
+ mask_faces[values[1:]] = 0
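+ # values[1:] skips the leading -1 "no hit" entry; faces that were hit keep mask 0 and survive, un-hit faces keep mask 1 and are removed below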
+ print(f'Surfaces/Kept: {len(mesh.faces)}/{len(values)}')
+
+ mesh_o3d = o3d.io.read_triangle_mesh(old_mesh_file)
+ print("removing triangles by mask")
+ mesh_o3d.remove_triangles_by_mask(mask_faces)
+
+ o3d.io.write_triangle_mesh(new_mesh_file, mesh_o3d)
+
+ # # clean meshes
+ new_mesh = trimesh.load(new_mesh_file)
+ cc = trimesh.graph.connected_components(new_mesh.face_adjacency, min_len=500)
+ mask = np.zeros(len(new_mesh.faces), dtype=bool)
+ mask[np.concatenate(cc)] = True
+ new_mesh.update_faces(mask)
+ new_mesh.remove_unreferenced_vertices()
+ new_mesh.export(new_mesh_file)
+
+ # meshes = new_mesh.split(only_watertight=False)
+ #
+ # if not keep_largest:
+ # meshes = [mesh for mesh in meshes if len(mesh.faces) > isolated_face_num]
+ # # new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])]
+ # merged_mesh = trimesh.util.concatenate(meshes)
+ # merged_mesh.export(new_mesh_file)
+ # else:
+ # new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])]
+ # new_mesh.export(new_mesh_file)
+
+ o3d.io.write_triangle_mesh(new_mesh_file.replace(".ply", "_raw.ply"), mesh_o3d)
+ print("finishing removing triangles")
+
+
+def clean_outliers(old_mesh_file, new_mesh_file):
+ new_mesh = trimesh.load(old_mesh_file)
+
+ meshes = new_mesh.split(only_watertight=False)
+ new_mesh = meshes[np.argmax([len(mesh.faces) for mesh in meshes])]
+
+ new_mesh.export(new_mesh_file)
+
+
+if __name__ == "__main__":
+
+ scans = [24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122]
+
+ mask_kernel_size = 11
+
+ imgs_idx = [0, 1, 2]
+ # imgs_idx = [42, 43, 44]
+ # imgs_idx = [1, 8, 9]
+
+ DTU_DIR = "/home/xiaoxiao/dataset/DTU_IDR/DTU"
+ # DTU_DIR = "/userhome/cs/xxlong/dataset/DTU_IDR/DTU"
+
+ base_path = "/home/xiaoxiao/Workplace/nerf_reconstruction/Volume_NeuS/neus_camsys/exp/dtu/evaluation_23_24_33_new/volsdf"
+
+ for scan in scans:
+ print("processing scan%d" % scan)
+ dir_path = os.path.join(base_path, "scan%d" % scan)
+
+ old_mesh_file = glob(os.path.join(dir_path, "*.ply"))[0]
+
+ clean_mesh_file = os.path.join(dir_path, "clean_%03d.ply" % scan)
+ final_mesh_file = os.path.join(dir_path, "final_%03d.ply" % scan)
+
+ clean_mesh_faces_by_mask(old_mesh_file, clean_mesh_file, scan, imgs_idx, minimal_vis=1,
+ mask_dilated_size=mask_kernel_size)
+ clean_mesh_faces_outside_frustum(clean_mesh_file, final_mesh_file, imgs_idx, mask_dilated_size=mask_kernel_size)
+
+ print("finish processing scan%d" % scan)
diff --git a/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py b/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py
new file mode 100644
index 0000000000000000000000000000000000000000..a60230705ab3f8c7c2a0ed64a20634c7ab4d2eea
--- /dev/null
+++ b/SparseNeuS_demo_v1/evaluation/eval_dtu_python.py
@@ -0,0 +1,369 @@
+import numpy as np
+import open3d as o3d
+import sklearn.neighbors as skln
+from tqdm import tqdm
+from scipy.io import loadmat
+import multiprocessing as mp
+import argparse, os, sys
+import cv2 as cv
+
+from pathlib import Path
+
+
+def get_path_components(path):
+ path = Path(path)
+ ppath = str(path.parent)
+ stem = str(path.stem)
+ ext = str(path.suffix)
+ return ppath, stem, ext
+
+
+def sample_single_tri(input_):
+ n1, n2, v1, v2, tri_vert = input_
+ c = np.mgrid[:n1 + 1, :n2 + 1]
+ c += 0.5
+ c[0] /= max(n1, 1e-7)
+ c[1] /= max(n2, 1e-7)
+ c = np.transpose(c, (1, 2, 0))
+ k = c[c.sum(axis=-1) < 1] # keep (u, v) pairs that lie inside the triangle (u + v < 1)
+ q = v1 * k[:, :1] + v2 * k[:, 1:] + tri_vert
+ return q
+
+
+def write_vis_pcd(file, points, colors):
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+ pcd.colors = o3d.utility.Vector3dVector(colors)
+ o3d.io.write_point_cloud(file, pcd)
+
+
+def eval_cloud(args, num_cpu_cores=-1):
+ mp.freeze_support()
+ os.makedirs(args.vis_out_dir, exist_ok=True)
+
+ thresh = args.downsample_density
+ if args.mode == 'mesh':
+ pbar = tqdm(total=9)
+ pbar.set_description('read data mesh')
+ data_mesh = o3d.io.read_triangle_mesh(args.data)
+
+ vertices = np.asarray(data_mesh.vertices)
+ triangles = np.asarray(data_mesh.triangles)
+ tri_vert = vertices[triangles]
+
+ pbar.update(1)
+ pbar.set_description('sample pcd from mesh')
+ v1 = tri_vert[:, 1] - tri_vert[:, 0]
+ v2 = tri_vert[:, 2] - tri_vert[:, 0]
+ l1 = np.linalg.norm(v1, axis=-1, keepdims=True)
+ l2 = np.linalg.norm(v2, axis=-1, keepdims=True)
+ area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True)
+ non_zero_area = (area2 > 0)[:, 0]
+ l1, l2, area2, v1, v2, tri_vert = [
+ arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]
+ ]
+ thr = thresh * np.sqrt(l1 * l2 / area2)
+ n1 = np.floor(l1 / thr)
+ n2 = np.floor(l2 / thr)
+
+ with mp.Pool() as mp_pool:
+ new_pts = mp_pool.map(sample_single_tri,
+ ((n1[i, 0], n2[i, 0], v1[i:i + 1], v2[i:i + 1], tri_vert[i:i + 1, 0]) for i in
+ range(len(n1))), chunksize=1024)
+
+ new_pts = np.concatenate(new_pts, axis=0)
+ data_pcd = np.concatenate([vertices, new_pts], axis=0)
+
+ elif args.mode == 'pcd':
+ pbar = tqdm(total=8)
+ pbar.set_description('read data pcd')
+ data_pcd_o3d = o3d.io.read_point_cloud(args.data)
+ data_pcd = np.asarray(data_pcd_o3d.points)
+
+ pbar.update(1)
+ pbar.set_description('random shuffle pcd index')
+ shuffle_rng = np.random.default_rng()
+ shuffle_rng.shuffle(data_pcd, axis=0)
+
+ pbar.update(1)
+ pbar.set_description('downsample pcd')
+ nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=num_cpu_cores)
+ nn_engine.fit(data_pcd)
+ rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)
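+ # greedy radius-based downsampling: keep a point, then suppress all of its neighbours within thresh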
+ mask = np.ones(data_pcd.shape[0], dtype=np.bool_)
+ for curr, idxs in enumerate(rnn_idxs):
+ if mask[curr]:
+ mask[idxs] = 0
+ mask[curr] = 1
+ data_down = data_pcd[mask]
+
+ pbar.update(1)
+ pbar.set_description('masking data pcd')
+ obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat')
+ ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']]
+ BB = BB.astype(np.float32)
+
+ patch = args.patch_size
+ inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(axis=-1) == 3
+ data_in = data_down[inbound]
+
+ data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)
+ grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) == 3
+ data_grid_in = data_grid[grid_inbound]
+ in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(np.bool_)
+ data_in_obs = data_in[grid_inbound][in_obs]
+
+ pbar.update(1)
+ pbar.set_description('read STL pcd')
+ stl_pcd = o3d.io.read_point_cloud(args.gt)
+ stl = np.asarray(stl_pcd.points)
+
+ pbar.update(1)
+ pbar.set_description('compute data2stl')
+ nn_engine.fit(stl)
+ dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)
+ max_dist = args.max_dist
+ mean_d2s = dist_d2s[dist_d2s < max_dist].mean()
+
+ pbar.update(1)
+ pbar.set_description('compute stl2data')
+ ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P']
+
+ stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1)
+ above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0
+ stl_above = stl[above]
+
+ nn_engine.fit(data_in)
+ dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)
+ mean_s2d = dist_s2d[dist_s2d < max_dist].mean()
+
+ pbar.update(1)
+ pbar.set_description('visualize error')
+ vis_dist = args.visualize_threshold
+ R = np.array([[1, 0, 0]], dtype=np.float64)
+ G = np.array([[0, 1, 0]], dtype=np.float64)
+ B = np.array([[0, 0, 1]], dtype=np.float64)
+ W = np.array([[1, 1, 1]], dtype=np.float64)
+ data_color = np.tile(B, (data_down.shape[0], 1))
+ data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist
+ data_color[np.where(inbound)[0][grid_inbound][in_obs]] = R * data_alpha + W * (1 - data_alpha)
+ data_color[np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:, 0] >= max_dist]] = G
+ write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2gt.ply', data_down, data_color)
+ stl_color = np.tile(B, (stl.shape[0], 1))
+ stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist
+ stl_color[np.where(above)[0]] = R * stl_alpha + W * (1 - stl_alpha)
+ stl_color[np.where(above)[0][dist_s2d[:, 0] >= max_dist]] = G
+ write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_gt2d.ply', stl, stl_color)
+
+ pbar.update(1)
+ pbar.set_description('done')
+ pbar.close()
+ over_all = (mean_d2s + mean_s2d) / 2
+ print(f'mean_d2gt: {mean_d2s}; mean_gt2d: {mean_s2d}; over_all: {over_all}')
+
+ pparent, stem, ext = get_path_components(args.data)
+ if args.log is None:
+ path_log = os.path.join(pparent, 'eval_result.txt')
+ else:
+ path_log = args.log
+ with open(path_log, 'a+') as fLog:
+ fLog.write(f'mean_d2gt {np.round(mean_d2s, 3)} '
+ f'mean_gt2d {np.round(mean_s2d, 3)} '
+ f'Over_all {np.round(over_all, 3)} '
+ f'[{stem}] \n')
+
+ return over_all, mean_d2s, mean_s2d
+
+
+if __name__ == '__main__':
+ from glob import glob
+
+ mp.freeze_support()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data', type=str, default='data_in.ply')
+ parser.add_argument('--gt', type=str, help='ground truth')
+ parser.add_argument('--scan', type=int, default=1)
+ parser.add_argument('--mode', type=str, default='mesh', choices=['mesh', 'pcd'])
+ parser.add_argument('--dataset_dir', type=str, default='/dataset/dtu_official/SampleSet/MVS_Data')
+ parser.add_argument('--vis_out_dir', type=str, default='.')
+ parser.add_argument('--downsample_density', type=float, default=0.2)
+ parser.add_argument('--patch_size', type=float, default=60)
+ parser.add_argument('--max_dist', type=float, default=20)
+ parser.add_argument('--visualize_threshold', type=float, default=10)
+ parser.add_argument('--log', type=str, default=None)
+ args = parser.parse_args()
+
+ base_dir = "./exp"
+
+ GT_DIR = "./gt_pcd"
+
+ scans = [24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122]
+
+ for scan in scans:
+
+ print("processing scan%d" % scan)
+
+ args.data = os.path.join(base_dir, "scan{}".format(scan), "final_%03d.ply" % scan)
+
+ if not os.path.exists(args.data):
+ continue
+
+ args.gt = os.path.join(GT_DIR, "stl%03d_total.ply" % scan)
+ args.vis_out_dir = os.path.join(base_dir, "scan{}".format(scan))
+ args.scan = scan
+ os.makedirs(args.vis_out_dir, exist_ok=True)
+
+ dist_thred1 = 1
+ dist_thred2 = 2
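+ # distance thresholds for precision / recall (the DTU ground truth is in millimetres, i.e. 1 mm and 2 mm)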
+
+ thresh = args.downsample_density
+
+ if args.mode == 'mesh':
+ pbar = tqdm(total=9)
+ pbar.set_description('read data mesh')
+ data_mesh = o3d.io.read_triangle_mesh(args.data)
+
+ vertices = np.asarray(data_mesh.vertices)
+ triangles = np.asarray(data_mesh.triangles)
+ tri_vert = vertices[triangles]
+
+ pbar.update(1)
+ pbar.set_description('sample pcd from mesh')
+ v1 = tri_vert[:, 1] - tri_vert[:, 0]
+ v2 = tri_vert[:, 2] - tri_vert[:, 0]
+ l1 = np.linalg.norm(v1, axis=-1, keepdims=True)
+ l2 = np.linalg.norm(v2, axis=-1, keepdims=True)
+ area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True)
+ non_zero_area = (area2 > 0)[:, 0]
+ l1, l2, area2, v1, v2, tri_vert = [
+ arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]
+ ]
+ thr = thresh * np.sqrt(l1 * l2 / area2)
+ n1 = np.floor(l1 / thr)
+ n2 = np.floor(l2 / thr)
+
+ with mp.Pool() as mp_pool:
+ new_pts = mp_pool.map(sample_single_tri,
+ ((n1[i, 0], n2[i, 0], v1[i:i + 1], v2[i:i + 1], tri_vert[i:i + 1, 0]) for i in
+ range(len(n1))), chunksize=1024)
+
+ new_pts = np.concatenate(new_pts, axis=0)
+ data_pcd = np.concatenate([vertices, new_pts], axis=0)
+
+ elif args.mode == 'pcd':
+ pbar = tqdm(total=8)
+ pbar.set_description('read data pcd')
+ data_pcd_o3d = o3d.io.read_point_cloud(args.data)
+ data_pcd = np.asarray(data_pcd_o3d.points)
+
+ pbar.update(1)
+ pbar.set_description('random shuffle pcd index')
+ shuffle_rng = np.random.default_rng()
+ shuffle_rng.shuffle(data_pcd, axis=0)
+
+ pbar.update(1)
+ pbar.set_description('downsample pcd')
+ nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1)
+ nn_engine.fit(data_pcd)
+ rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)
+ mask = np.ones(data_pcd.shape[0], dtype=np.bool_)
+ for curr, idxs in enumerate(rnn_idxs):
+ if mask[curr]:
+ mask[idxs] = 0
+ mask[curr] = 1
+ data_down = data_pcd[mask]
+
+ pbar.update(1)
+ pbar.set_description('masking data pcd')
+ obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat')
+ ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']]
+ BB = BB.astype(np.float32)
+
+ patch = args.patch_size
+ inbound = ((data_down >= BB[:1] - patch) & (data_down < BB[1:] + patch * 2)).sum(axis=-1) == 3
+ data_in = data_down[inbound]
+
+ data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)
+ grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) == 3
+ data_grid_in = data_grid[grid_inbound]
+ in_obs = ObsMask[data_grid_in[:, 0], data_grid_in[:, 1], data_grid_in[:, 2]].astype(np.bool_)
+ data_in_obs = data_in[grid_inbound][in_obs]
+
+ pbar.update(1)
+ pbar.set_description('read STL pcd')
+ stl_pcd = o3d.io.read_point_cloud(args.gt)
+ stl = np.asarray(stl_pcd.points)
+
+ pbar.update(1)
+ pbar.set_description('compute data2stl')
+ nn_engine.fit(stl)
+ dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)
+ max_dist = args.max_dist
+ mean_d2s = dist_d2s[dist_d2s < max_dist].mean()
+
+ precision_1 = len(dist_d2s[dist_d2s < dist_thred1]) / len(dist_d2s)
+ precision_2 = len(dist_d2s[dist_d2s < dist_thred2]) / len(dist_d2s)
+
+ pbar.update(1)
+ pbar.set_description('compute stl2data')
+ ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P']
+
+ stl_hom = np.concatenate([stl, np.ones_like(stl[:, :1])], -1)
+ above = (ground_plane.reshape((1, 4)) * stl_hom).sum(-1) > 0
+
+ stl_above = stl[above]
+
+ nn_engine.fit(data_in)
+ dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)
+ mean_s2d = dist_s2d[dist_s2d < max_dist].mean()
+
+ recall_1 = len(dist_s2d[dist_s2d < dist_thred1]) / len(dist_s2d)
+ recall_2 = len(dist_s2d[dist_s2d < dist_thred2]) / len(dist_s2d)
+
+ pbar.update(1)
+ pbar.set_description('visualize error')
+ vis_dist = args.visualize_threshold
+ R = np.array([[1, 0, 0]], dtype=np.float64)
+ G = np.array([[0, 1, 0]], dtype=np.float64)
+ B = np.array([[0, 0, 1]], dtype=np.float64)
+ W = np.array([[1, 1, 1]], dtype=np.float64)
+ data_color = np.tile(B, (data_down.shape[0], 1))
+ data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist
+ data_color[np.where(inbound)[0][grid_inbound][in_obs]] = R * data_alpha + W * (1 - data_alpha)
+ data_color[np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:, 0] >= max_dist]] = G
+ write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2gt.ply', data_down, data_color)
+ stl_color = np.tile(B, (stl.shape[0], 1))
+ stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist
+ stl_color[np.where(above)[0]] = R * stl_alpha + W * (1 - stl_alpha)
+ stl_color[np.where(above)[0][dist_s2d[:, 0] >= max_dist]] = G
+ write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_gt2d.ply', stl, stl_color)
+
+ pbar.update(1)
+ pbar.set_description('done')
+ pbar.close()
+ over_all = (mean_d2s + mean_s2d) / 2
+
+ fscore_1 = 2 * precision_1 * recall_1 / (precision_1 + recall_1 + 1e-6)
+ fscore_2 = 2 * precision_2 * recall_2 / (precision_2 + recall_2 + 1e-6)
+
+ print(f'over_all: {over_all}; mean_d2gt: {mean_d2s}; mean_gt2d: {mean_s2d}.')
+ print(f'precision_1mm: {precision_1}; recall_1mm: {recall_1}; fscore_1mm: {fscore_1}')
+ print(f'precision_2mm: {precision_2}; recall_2mm: {recall_2}; fscore_2mm: {fscore_2}')
+
+ pparent, stem, ext = get_path_components(args.data)
+ if args.log is None:
+ path_log = os.path.join(pparent, 'eval_result.txt')
+ else:
+ path_log = args.log
+ with open(path_log, 'w+') as fLog:
+ fLog.write(f'over_all {np.round(over_all, 3)} '
+ f'mean_d2gt {np.round(mean_d2s, 3)} '
+ f'mean_gt2d {np.round(mean_s2d, 3)} \n'
+ f'precision_1mm {np.round(precision_1, 3)} '
+ f'recall_1mm {np.round(recall_1, 3)} '
+ f'fscore_1mm {np.round(fscore_1, 3)} \n'
+ f'precision_2mm {np.round(precision_2, 3)} '
+ f'recall_2mm {np.round(recall_2, 3)} '
+ f'fscore_2mm {np.round(fscore_2, 3)} \n'
+ f'[{stem}] \n')
diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth
new file mode 100644
index 0000000000000000000000000000000000000000..043937847350af33b459128ade1a470064ce261c
--- /dev/null
+++ b/SparseNeuS_demo_v1/exp/lod0/checkpoint_trash/ckpt_285000.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:763c2a4934928cc089342905ba61481d6f9efc977b9729d7fc2d3eae4f0e1f9b
+size 5310703
diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b5ba43d31ad82a3ccd5e5be45087e602fb98260e
--- /dev/null
+++ b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_340000.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a947469b4b1a7044b2dcdd576e52279ed48a05d52135231137ece9f0ef810c8
+size 5310703
diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth
new file mode 100644
index 0000000000000000000000000000000000000000..90e582ba7a02d6b46dc2366a8b9ef61e195dd9ef
--- /dev/null
+++ b/SparseNeuS_demo_v1/exp/lod0/checkpoints_white/ckpt_245000.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f40cd7db7f7a5ff16bb2bfbfcbd7c6a8c7e6d3032698cc5779eeaf507225cd97
+size 5310703
diff --git a/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a94a670988c6c354fdface782fe88870cd891c5
--- /dev/null
+++ b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py
@@ -0,0 +1,656 @@
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import argparse
+import os
+import logging
+import numpy as np
+import cv2 as cv
+import trimesh
+from shutil import copyfile
+from torch.utils.tensorboard import SummaryWriter
+from icecream import ic
+from tqdm import tqdm
+from pyhocon import ConfigFactory
+
+from models.fields import SingleVarianceNetwork
+
+from models.featurenet import FeatureNet
+
+from models.trainer_generic import GenericTrainer
+
+from models.sparse_sdf_network import SparseSdfNetwork
+
+from models.rendering_network import GeneralRenderingNetwork
+
+from datetime import datetime
+
+from data.dtu_general import MVSDatasetDtuPerView
+
+from utils.training_utils import tocuda
+from data.blender_general_narrow_all_eval_new_data import BlenderPerView
+
+from termcolor import colored
+
+from datetime import datetime
+
+class Runner:
+ def __init__(self, conf_path, mode='train', is_continue=False,
+ is_restore=False, restore_lod0=False, local_rank=0):
+
+ # Initial setting
+ self.device = torch.device('cuda:%d' % local_rank)
+ # self.device = torch.device('cuda')
+ self.num_devices = torch.cuda.device_count()
+ self.is_continue = is_continue
+ self.is_restore = is_restore
+ self.restore_lod0 = restore_lod0
+ self.mode = mode
+ self.model_list = []
+ self.logger = logging.getLogger('exp_logger')
+
+ print(colored("detected %d GPUs" % self.num_devices, "red"))
+
+ self.conf_path = conf_path
+ self.conf = ConfigFactory.parse_file(conf_path)
+ self.timestamp = None
+ if not self.is_continue:
+ self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
+            self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp  # append a timestamp when training; comment this out when testing
+ else:
+ self.base_exp_dir = self.conf['general.base_exp_dir']
+            self.conf['general.base_exp_dir'] = self.base_exp_dir  # reuse the fixed experiment dir when testing
+ print(colored("base_exp_dir: " + self.base_exp_dir, 'yellow'))
+ os.makedirs(self.base_exp_dir, exist_ok=True)
+ self.iter_step = 0
+ self.val_step = 0
+
+        # training parameters
+ self.end_iter = self.conf.get_int('train.end_iter')
+ self.save_freq = self.conf.get_int('train.save_freq')
+ self.report_freq = self.conf.get_int('train.report_freq')
+ self.val_freq = self.conf.get_int('train.val_freq')
+ self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
+        self.batch_size = self.num_devices  # use DataParallel to wrap
+ self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
+ self.learning_rate = self.conf.get_float('train.learning_rate')
+ self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone')
+ self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor')
+ self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
+ self.N_rays = self.conf.get_int('train.N_rays')
+
+ # warmup params for sdf gradient
+ self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0)
+ self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0)
+ self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0)
+ self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0)
+
+ self.writer = None
+
+ # Networks
+ self.num_lods = self.conf.get_int('model.num_lods')
+
+ self.rendering_network_outside = None
+ self.sdf_network_lod0 = None
+ self.sdf_network_lod1 = None
+ self.variance_network_lod0 = None
+ self.variance_network_lod1 = None
+ self.rendering_network_lod0 = None
+ self.rendering_network_lod1 = None
+ self.pyramid_feature_network = None # extract 2d pyramid feature maps from images, used for geometry
+ self.pyramid_feature_network_lod1 = None # may use different feature network for different lod
+
+ # * pyramid_feature_network
+ self.pyramid_feature_network = FeatureNet().to(self.device)
+ self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device)
+ self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)
+
+ if self.num_lods > 1:
+ self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device)
+ self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)
+
+ self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to(
+ self.device)
+
+ if self.num_lods > 1:
+ self.pyramid_feature_network_lod1 = FeatureNet().to(self.device)
+ self.rendering_network_lod1 = GeneralRenderingNetwork(
+ **self.conf['model.rendering_network_lod1']).to(self.device)
+ if self.mode == 'export_mesh' or self.mode == 'val':
+ # base_exp_dir_to_store = os.path.join(self.base_exp_dir, '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()))
+ print("save mesh to:", os.path.join("../", args.specific_dataset_name))
+ base_exp_dir_to_store = os.path.join("../", args.specific_dataset_name) #"../gradio_tmp" # MODIFIED
+ else:
+ base_exp_dir_to_store = self.base_exp_dir
+
+ print(colored(f"Store in: {base_exp_dir_to_store}", "blue"))
+ # Renderer model
+ self.trainer = GenericTrainer(
+ self.rendering_network_outside,
+ self.pyramid_feature_network,
+ self.pyramid_feature_network_lod1,
+ self.sdf_network_lod0,
+ self.sdf_network_lod1,
+ self.variance_network_lod0,
+ self.variance_network_lod1,
+ self.rendering_network_lod0,
+ self.rendering_network_lod1,
+ **self.conf['model.trainer'],
+ timestamp=self.timestamp,
+ base_exp_dir=base_exp_dir_to_store,
+ conf=self.conf)
+
+ self.data_setup() # * data setup
+
+ self.optimizer_setup()
+
+ # Load checkpoint
+ latest_model_name = None
+ if is_continue:
+ model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints'))
+ model_list = []
+ for model_name in model_list_raw:
+ if model_name.startswith('ckpt'):
+ if model_name[-3:] == 'pth': # and int(model_name[5:-4]) <= self.end_iter:
+ model_list.append(model_name)
+ model_list.sort()
+ latest_model_name = model_list[-1]
+
+ if latest_model_name is not None:
+ self.logger.info('Find checkpoint: {}'.format(latest_model_name))
+ self.load_checkpoint(latest_model_name)
+
+ self.trainer = torch.nn.DataParallel(self.trainer).to(self.device)
+
+ if self.mode[:5] == 'train':
+ self.file_backup()
+
+ def optimizer_setup(self):
+ self.params_to_train = self.trainer.get_trainable_params()
+ self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate)
+
+ def data_setup(self):
+ """
+        If using DDP, use setup() rather than prepare_data();
+        prepare_data() is only called on a single GPU/TPU in distributed training.
+ :return:
+ """
+
+ self.train_dataset = BlenderPerView(
+ root_dir=self.conf['dataset.trainpath'],
+ split=self.conf.get_string('dataset.train_split', default='train'),
+ split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None),
+ n_views=self.conf['dataset.nviews'],
+ downSample=self.conf['dataset.imgScale_train'],
+ N_rays=self.N_rays,
+ batch_size=self.batch_size,
+ clean_image=True, # True for training
+ importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
+ specific_dataset_name = args.specific_dataset_name
+ )
+
+ self.val_dataset = BlenderPerView(
+ root_dir=self.conf['dataset.valpath'],
+ split=self.conf.get_string('dataset.test_split', default='test'),
+ split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None),
+ n_views=3,
+ downSample=self.conf['dataset.imgScale_test'],
+ N_rays=self.N_rays,
+ batch_size=self.batch_size,
+ clean_image=self.conf.get_bool('dataset.mask_out_image',
+ default=False) if self.mode != 'train' else False,
+ importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
+ test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]),
+ specific_dataset_name = args.specific_dataset_name
+ )
+
+ # item = self.train_dataset.__getitem__(0)
+ self.train_dataloader = DataLoader(self.train_dataset,
+ shuffle=True,
+ num_workers=4 * self.batch_size,
+ # num_workers=1,
+ batch_size=self.batch_size,
+ pin_memory=True,
+ drop_last=True
+ )
+
+ self.val_dataloader = DataLoader(self.val_dataset,
+ # shuffle=False if self.mode == 'train' else True,
+ shuffle=False,
+ num_workers=4 * self.batch_size,
+ # num_workers=1,
+ batch_size=self.batch_size,
+ pin_memory=True,
+ drop_last=False
+ )
+
+ self.val_dataloader_iterator = iter(self.val_dataloader) # - should be after "reconstruct_metas_for_gru_fusion"
+
+ def train(self):
+ self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
+ res_step = self.end_iter - self.iter_step
+
+ dataloader = self.train_dataloader
+
+ epochs = int(1 + res_step // len(dataloader))
+
+ self.adjust_learning_rate()
+ print(colored("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr']), "yellow"))
+
+ background_rgb = None
+ if self.use_white_bkgd:
+ # background_rgb = torch.ones([1, 3]).to(self.device)
+ background_rgb = 1.0
+
+ for epoch_i in range(epochs):
+
+ print(colored("current epoch %d" % epoch_i, 'red'))
+ dataloader = tqdm(dataloader)
+
+ for batch in dataloader:
+ # print("Checker1:, fetch data")
+ batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) # used to get meta
+
+ # - warmup params
+ if self.num_lods == 1:
+ alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
+ else:
+ alpha_inter_ratio_lod0 = 1.
+ alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
+
+ losses = self.trainer(
+ batch,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=self.iter_step,
+ mode='train',
+ )
+
+ loss_types = ['loss_lod0', 'loss_lod1']
+ # print("[TEST]: weights_sum in trainer return", losses['losses_lod0']['weights_sum'].mean())
+
+ losses_lod0 = losses['losses_lod0']
+ losses_lod1 = losses['losses_lod1']
+ # import ipdb; ipdb.set_trace()
+ loss = 0
+ for loss_type in loss_types:
+ if losses[loss_type] is not None:
+ loss = loss + losses[loss_type].mean()
+ # print("Checker4:, begin BP")
+ self.optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0)
+ self.optimizer.step()
+ # print("Checker5:, end BP")
+ self.iter_step += 1
+
+ if self.iter_step % self.report_freq == 0:
+ self.writer.add_scalar('Loss/loss', loss, self.iter_step)
+
+ if losses_lod0 is not None:
+ self.writer.add_scalar('Loss/d_loss_lod0',
+ losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('Loss/sparse_loss_lod0',
+ losses_lod0[
+ 'sparse_loss'].mean() if losses_lod0 is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('Loss/color_loss_lod0',
+ losses_lod0['color_fine_loss'].mean()
+ if losses_lod0['color_fine_loss'] is not None else 0,
+ self.iter_step)
+
+ self.writer.add_scalar('statis/psnr_lod0',
+ losses_lod0['psnr'].mean()
+ if losses_lod0['psnr'] is not None else 0,
+ self.iter_step)
+
+ self.writer.add_scalar('param/variance_lod0',
+ 1. / torch.exp(self.variance_network_lod0.variance * 10),
+ self.iter_step)
+ self.writer.add_scalar('param/eikonal_loss', losses_lod0['gradient_error_loss'].mean() if losses_lod0 is not None else 0,
+ self.iter_step)
+
+ ######## - lod 1
+ if self.num_lods > 1:
+ self.writer.add_scalar('Loss/d_loss_lod1',
+ losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('Loss/sparse_loss_lod1',
+ losses_lod1[
+ 'sparse_loss'].mean() if losses_lod1 is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('Loss/color_loss_lod1',
+ losses_lod1['color_fine_loss'].mean()
+ if losses_lod1['color_fine_loss'] is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('statis/sdf_mean_lod1',
+ losses_lod1['sdf_mean'].mean() if losses_lod1 is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('statis/psnr_lod1',
+ losses_lod1['psnr'].mean()
+ if losses_lod1['psnr'] is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('statis/sparseness_0.01_lod1',
+ losses_lod1['sparseness_1'].mean()
+ if losses_lod1['sparseness_1'] is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('statis/sparseness_0.02_lod1',
+ losses_lod1['sparseness_2'].mean()
+ if losses_lod1['sparseness_2'] is not None else 0,
+ self.iter_step)
+ self.writer.add_scalar('param/variance_lod1',
+ 1. / torch.exp(self.variance_network_lod1.variance * 10),
+ self.iter_step)
+
+ print(self.base_exp_dir)
+ print(
+                    'iter:{:>8d} '
+                    'loss = {:.4f} '
+                    'd_loss_lod0 = {:.4f} '
+                    'color_loss_lod0 = {:.4f} '
+                    'sparse_loss_lod0 = {:.4f} '
+                    'd_loss_lod1 = {:.4f} '
+                    'color_loss_lod1 = {:.4f} '
+                    'lr = {:.5f}'.format(
+ self.iter_step, loss,
+ losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0,
+ losses_lod0['color_fine_loss'].mean() if losses_lod0 is not None else 0,
+ losses_lod0['sparse_loss'].mean() if losses_lod0 is not None else 0,
+ losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0,
+ losses_lod1['color_fine_loss'].mean() if losses_lod1 is not None else 0,
+ self.optimizer.param_groups[0]['lr']))
+
+ print(colored('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format(
+ alpha_inter_ratio_lod0, alpha_inter_ratio_lod1), 'green'))
+
+ if losses_lod0 is not None:
+ # print("[TEST]: weights_sum in print", losses_lod0['weights_sum'].mean())
+ # import ipdb; ipdb.set_trace()
+ print(
+                        'iter:{:>8d} '
+                        'variance = {:.5f} '
+                        'weights_sum = {:.4f} '
+                        'weights_sum_fg = {:.4f} '
+                        'alpha_sum = {:.4f} '
+                        'sparse_weight = {:.4f} '
+                        'background_loss = {:.4f} '
+                        'background_weight = {:.4f} '
+ .format(
+ self.iter_step,
+ losses_lod0['variance'].mean(),
+ losses_lod0['weights_sum'].mean(),
+ losses_lod0['weights_sum_fg'].mean(),
+ losses_lod0['alpha_sum'].mean(),
+ losses_lod0['sparse_weight'].mean(),
+ losses_lod0['fg_bg_loss'].mean(),
+ losses_lod0['fg_bg_weight'].mean(),
+ ))
+
+ if losses_lod1 is not None:
+ print(
+                        'iter:{:>8d} '
+                        'variance = {:.5f} '
+                        'weights_sum = {:.4f} '
+                        'alpha_sum = {:.4f} '
+                        'fg_bg_loss = {:.4f} '
+                        'fg_bg_weight = {:.4f} '
+                        'sparse_weight = {:.4f} '
+                        .format(
+                            self.iter_step,
+                            losses_lod1['variance'].mean(),
+                            losses_lod1['weights_sum'].mean(),
+                            losses_lod1['alpha_sum'].mean(),
+                            losses_lod1['fg_bg_loss'].mean(),
+                            losses_lod1['fg_bg_weight'].mean(),
+                            losses_lod1['sparse_weight'].mean(),
+                        ))
+
+ if self.iter_step % self.save_freq == 0:
+ self.save_checkpoint()
+
+ if self.iter_step % self.val_freq == 0:
+ self.validate()
+
+            # - adjust learning rate
+ self.adjust_learning_rate()
+
+ def adjust_learning_rate(self):
+        # - adjust learning rate with a cosine schedule
+ learning_rate = (np.cos(np.pi * self.iter_step / self.end_iter) + 1.0) * 0.5 * 0.9 + 0.1
+ learning_rate = self.learning_rate * learning_rate
+ for g in self.optimizer.param_groups:
+ g['lr'] = learning_rate
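+        # A worked example of this schedule (illustrative numbers): with a base rate of 1e-4 and
+        # end_iter = 100k, the factor (cos(pi * t / T) + 1) * 0.45 + 0.1 decays smoothly from 1.0
+        # to 0.1, so lr is 1e-4 at step 0, 5.5e-5 at step 50k, and 1e-5 at step 100k.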
+
+ def get_alpha_inter_ratio(self, start, end):
+ if end == 0.0:
+ return 1.0
+ elif self.iter_step < start:
+ return 0.0
+ else:
+ return np.min([1.0, (self.iter_step - start) / (end - start)])
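+        # Illustrative example: with start = 0 and end = 10k the ratio ramps linearly from 0 to 1,
+        # e.g. 0.25 at iteration 2500; end = 0 disables annealing and always returns 1.0.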
+
+ def file_backup(self):
+ # copy python file
+ dir_lis = self.conf['general.recording']
+ os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
+ for dir_name in dir_lis:
+ cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
+ os.makedirs(cur_dir, exist_ok=True)
+ files = os.listdir(dir_name)
+ for f_name in files:
+ if f_name[-3:] == '.py':
+ copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
+
+ # copy configs
+ copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
+
+ def load_checkpoint(self, checkpoint_name):
+
+ def load_state_dict(network, checkpoint, comment):
+ if network is not None:
+ try:
+ pretrained_dict = checkpoint[comment]
+
+ model_dict = network.state_dict()
+
+ # 1. filter out unnecessary keys
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+ # 2. overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+ # 3. load the new state dict
+                    network.load_state_dict(model_dict)
+ except:
+ print(colored(comment + " load fails", 'yellow'))
+
+ checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name),
+ map_location=self.device)
+
+ load_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside')
+
+ load_state_dict(self.sdf_network_lod0, checkpoint, 'sdf_network_lod0')
+ load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod1')
+
+ load_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network')
+ load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1')
+
+ load_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0')
+ load_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1')
+
+ load_state_dict(self.rendering_network_lod0, checkpoint, 'rendering_network_lod0')
+ load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod1')
+
+ if self.restore_lod0: # use the trained lod0 networks to initialize lod1 networks
+ load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0')
+ load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network')
+ load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0')
+
+ if self.is_continue and (not self.restore_lod0):
+ try:
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ except:
+ print(colored("load optimizer fails", "yellow"))
+ self.iter_step = checkpoint['iter_step']
+ self.val_step = checkpoint['val_step'] if 'val_step' in checkpoint.keys() else 0
+
+ self.logger.info('End')
+
+ def save_checkpoint(self):
+
+ def save_state_dict(network, checkpoint, comment):
+ if network is not None:
+ checkpoint[comment] = network.state_dict()
+
+ checkpoint = {
+ 'optimizer': self.optimizer.state_dict(),
+ 'iter_step': self.iter_step,
+ 'val_step': self.val_step,
+ }
+
+ save_state_dict(self.sdf_network_lod0, checkpoint, "sdf_network_lod0")
+ save_state_dict(self.sdf_network_lod1, checkpoint, "sdf_network_lod1")
+
+ save_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside')
+ save_state_dict(self.rendering_network_lod0, checkpoint, "rendering_network_lod0")
+ save_state_dict(self.rendering_network_lod1, checkpoint, "rendering_network_lod1")
+
+ save_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0')
+ save_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1')
+
+ save_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network')
+ save_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1')
+
+ os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
+ torch.save(checkpoint,
+ os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
+
+ def validate(self, idx=-1, resolution_level=-1):
+ # validate image
+
+ ic(self.iter_step, idx)
+ self.logger.info('Validate begin')
+
+ if idx < 0:
+ idx = self.val_step
+ # idx = np.random.randint(len(self.val_dataset))
+ self.val_step += 1
+
+ try:
+ batch = next(self.val_dataloader_iterator)
+ except:
+ self.val_dataloader_iterator = iter(self.val_dataloader) # reset
+
+ batch = next(self.val_dataloader_iterator)
+
+
+ background_rgb = None
+ if self.use_white_bkgd:
+ # background_rgb = torch.ones([1, 3]).to(self.device)
+ background_rgb = 1.0
+
+ batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)])
+
+ # - warmup params
+ if self.num_lods == 1:
+ alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
+ else:
+ alpha_inter_ratio_lod0 = 1.
+ alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
+
+ self.trainer(
+ batch,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=self.iter_step,
+ save_vis=True,
+ mode='val',
+ )
+
+
+ def export_mesh(self, idx=-1, resolution_level=-1):
+        # export mesh
+
+        ic(self.iter_step, idx)
+        self.logger.info('Export mesh begin')
+ import time
+ start1 = time.time()
+ if idx < 0:
+ idx = self.val_step
+ # idx = np.random.randint(len(self.val_dataset))
+ self.val_step += 1
+
+ try:
+ batch = next(self.val_dataloader_iterator)
+ except:
+ self.val_dataloader_iterator = iter(self.val_dataloader) # reset
+
+ batch = next(self.val_dataloader_iterator)
+
+
+ background_rgb = None
+ if self.use_white_bkgd:
+ # background_rgb = torch.ones([1, 3]).to(self.device)
+ background_rgb = 1.0
+
+ batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)])
+
+ # - warmup params
+ if self.num_lods == 1:
+ alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
+ else:
+ alpha_inter_ratio_lod0 = 1.
+ alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
+ end1 = time.time()
+ print("time for getting data", end1 - start1)
+ self.trainer(
+ batch,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=self.iter_step,
+ save_vis=True,
+ mode='export_mesh',
+ )
+
+
+if __name__ == '__main__':
+ # torch.set_default_tensor_type('torch.cuda.FloatTensor')
+ torch.set_default_dtype(torch.float32)
+ FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
+ logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--conf', type=str, default='./confs/base.conf')
+ parser.add_argument('--mode', type=str, default='train')
+ parser.add_argument('--threshold', type=float, default=0.0)
+ parser.add_argument('--is_continue', default=False, action="store_true")
+ parser.add_argument('--is_restore', default=False, action="store_true")
+ parser.add_argument('--is_finetune', default=False, action="store_true")
+ parser.add_argument('--train_from_scratch', default=False, action="store_true")
+ parser.add_argument('--restore_lod0', default=False, action="store_true")
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument('--specific_dataset_name', type=str, default='GSO')
+
+
+ args = parser.parse_args()
+
+ torch.cuda.set_device(args.local_rank)
+ torch.backends.cudnn.benchmark = True # ! make training 2x faster
+
+ runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0,
+ args.local_rank)
+
+ if args.mode == 'train':
+ runner.train()
+ elif args.mode == 'val':
+ for i in range(len(runner.val_dataset)):
+ runner.validate()
+ elif args.mode == 'export_mesh':
+ for i in range(len(runner.val_dataset)):
+ runner.export_mesh()
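+
+# Example invocations (illustrative; point --conf at an actual config under ./confs/ and adjust
+# --specific_dataset_name to your data):
+#   python exp_runner_generic_blender_val.py --conf ./confs/base.conf --mode train
+#   python exp_runner_generic_blender_val.py --conf ./confs/base.conf --mode export_mesh \
+#       --is_continue --specific_dataset_name GSO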
diff --git a/SparseNeuS_demo_v1/loss/__init__.py b/SparseNeuS_demo_v1/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/loss/color_loss.py b/SparseNeuS_demo_v1/loss/color_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb5d1d0d6f7b71416b010ad8d167f4c2eb04f1c
--- /dev/null
+++ b/SparseNeuS_demo_v1/loss/color_loss.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from icecream import ic
+from loss.ncc import NCC
+from termcolor import colored
+
+
+class Normalize(nn.Module):
+ def __init__(self):
+ super(Normalize, self).__init__()
+
+ def forward(self, bottom):
+ qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12
+ top = bottom.div(qn)
+
+ return top
+
+
+class OcclusionColorLoss(nn.Module):
+ def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=[0.6]):
+ super(OcclusionColorLoss, self).__init__()
+ self.alpha = alpha
+ self.beta = beta
+ self.gama = gama
+ self.occlusion_aware = occlusion_aware
+ self.eps = 1e-4
+
+ self.weight_thred = weight_thred
+ self.adjuster = ParamAdjuster(self.weight_thred, self.beta)
+
+ def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True):
+ """
+
+ :param pred: [N_pts, 3]
+ :param gt: [N_pts, 3]
+ :param weight: [N_pts]
+ :param mask: [N_pts]
+ :return:
+ """
+ if detach:
+ weight = weight.detach()
+
+ error = torch.abs(pred - gt).sum(dim=-1, keepdim=False) # [N_pts]
+ error = error[mask]
+
+ if not (self.occlusion_aware and occlusion_aware):
+ return torch.mean(error), torch.mean(error)
+
+ beta = self.adjuster(weight.mean())
+
+ # weight = weight[mask]
+ weight = weight.clamp(0.0, 1.0)
+ term1 = self.alpha * torch.mean(weight[mask] * error)
+ term2 = beta * torch.log(1 - weight + self.eps).mean()
+ term3 = self.gama * torch.log(weight + self.eps).mean()
+
+ return term1 + term2 + term3, term1
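+        # In short (notation follows the code): the returned loss is
+        #   alpha * E[w * |c_pred - c_gt|_1] + beta * E[log(1 - w + eps)] + gama * E[log(w + eps)],
+        # where the second term lets low weights explain away occluded pixels and the third keeps
+        # the weights from collapsing to zero; beta is adapted online by ParamAdjuster from the mean weight.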
+
+
+class OcclusionColorPatchLoss(nn.Module):
+ def __init__(self, alpha=1, beta=0.025, gama=0.015,
+ occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=[0.6]):
+ super(OcclusionColorPatchLoss, self).__init__()
+ self.alpha = alpha
+ self.beta = beta
+ self.gama = gama
+ self.occlusion_aware = occlusion_aware
+ self.type = type # 'l1' or 'ncc' loss
+ self.ncc = NCC(h_patch_size=h_patch_size)
+ self.eps = 1e-4
+ self.weight_thred = weight_thred
+
+ self.adjuster = ParamAdjuster(self.weight_thred, self.beta)
+
+ print("type {} patch_size {} beta {} gama {} weight_thred {}".format(type, h_patch_size, beta, gama,
+ weight_thred))
+
+ def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True):
+ """
+
+ :param pred: [N_pts, Npx, 3]
+ :param gt: [N_pts, Npx, 3]
+ :param weight: [N_pts]
+ :param mask: [N_pts]
+ :return:
+ """
+
+ if detach:
+ weight = weight.detach()
+
+ if self.type == 'l1':
+ error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False) # [N_pts]
+ elif self.type == 'ncc':
+ error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0] # ncc 1 positive, -1 negative
+ error, indices = torch.sort(error)
+ mask = torch.index_select(mask, 0, index=indices)
+ mask[int(penalize_ratio * mask.shape[0]):] = False # can help boundaries
+ elif self.type == 'ssd':
+            error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)
+
+ error = error[mask]
+ if not (self.occlusion_aware and occlusion_aware):
+ return torch.mean(error), torch.mean(error), 0.
+
+ # * weight adjuster
+ beta = self.adjuster(weight.mean())
+
+ # weight = weight[mask]
+ weight = weight.clamp(0.0, 1.0)
+
+ term1 = self.alpha * torch.mean(weight[mask] * error)
+ term2 = beta * torch.log(1 - weight + self.eps).mean()
+ term3 = self.gama * torch.log(weight + self.eps).mean()
+
+ return term1 + term2 + term3, term1, beta
+
+
+class ParamAdjuster(nn.Module):
+ def __init__(self, weight_thred, param):
+ super(ParamAdjuster, self).__init__()
+ self.weight_thred = weight_thred
+ self.thred_num = len(weight_thred)
+ self.param = param
+ self.global_step = 0
+ self.statis_window = 100
+ self.counter = 0
+ self.adjusted = False
+ self.adjusted_step = 0
+ self.thred_idx = 0
+
+ def reset(self):
+ self.counter = 0
+ self.adjusted = False
+
+ def adjust(self):
+ if (self.counter / self.statis_window) > 0.3:
+ self.param = self.param + 0.005
+ self.adjusted = True
+ self.adjusted_step = self.global_step
+ self.thred_idx += 1
+            print(colored("adjusted param, now {}".format(self.param), 'red'))
+
+ def forward(self, weight_mean):
+ self.global_step += 1
+
+ if (self.global_step % self.statis_window == 0) and self.adjusted is False:
+ self.adjust()
+ self.reset()
+
+ if self.thred_idx < self.thred_num:
+ if weight_mean < self.weight_thred[self.thred_idx] and (not self.adjusted):
+ self.counter += 1
+
+ return self.param
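+
+
+if __name__ == "__main__":
+    # Minimal smoke test, illustrative only (assumes torch is available and the module is run from
+    # the SparseNeuS_demo_v1 root so that `loss.ncc` resolves).
+    loss_fn = OcclusionColorLoss()
+    pred, gt = torch.rand(128, 3), torch.rand(128, 3)
+    weight = torch.rand(128)
+    mask = torch.ones(128, dtype=torch.bool)
+    total, fidelity = loss_fn(pred, gt, weight, mask)
+    print("total:", total.item(), "fidelity term:", fidelity.item())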
diff --git a/SparseNeuS_demo_v1/loss/depth_loss.py b/SparseNeuS_demo_v1/loss/depth_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cba92851a79857ff6edd5c2f2eb12a2972b85bdc
--- /dev/null
+++ b/SparseNeuS_demo_v1/loss/depth_loss.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DepthLoss(nn.Module):
+ def __init__(self, type='l1'):
+ super(DepthLoss, self).__init__()
+ self.type = type
+
+
+ def forward(self, depth_pred, depth_gt, mask=None):
+ if (depth_gt < 0).sum() > 0:
+ # print("no depth loss")
+ return torch.tensor(0.0).to(depth_pred.device)
+ if mask is not None:
+ mask_d = (depth_gt > 0).float()
+
+ mask = mask * mask_d
+
+ mask_sum = mask.sum() + 1e-5
+ depth_error = (depth_pred - depth_gt) * mask
+ depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device),
+ reduction='sum') / mask_sum
+ else:
+ depth_error = depth_pred - depth_gt
+ depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device),
+ reduction='mean')
+ return depth_loss
+
+
+class DepthSmoothLoss(nn.Module):
+ def __init__(self):
+ super(DepthSmoothLoss, self).__init__()
+
+ def forward(self, disp, img, mask):
+ """
+ Computes the smoothness loss for a disparity image
+ The color image is used for edge-aware smoothness
+ :param disp: [B, 1, H, W]
+ :param img: [B, 1, H, W]
+ :param mask: [B, 1, H, W]
+ :return:
+ """
+ grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
+ grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
+
+ grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
+ grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
+
+ grad_disp_x *= torch.exp(-grad_img_x)
+ grad_disp_y *= torch.exp(-grad_img_y)
+
+ grad_disp = (grad_disp_x * mask[:, :, :, :-1]).mean() + (grad_disp_y * mask[:, :, :-1, :]).mean()
+
+ return grad_disp
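+
+
+if __name__ == "__main__":
+    # Minimal smoke test, illustrative only (assumes torch is available).
+    d_loss = DepthLoss()
+    pred = torch.rand(2, 64) + 0.5
+    gt = torch.rand(2, 64) + 0.5
+    mask = torch.ones_like(gt)
+    print("masked depth l1:", d_loss(pred, gt, mask).item())
+
+    smooth = DepthSmoothLoss()
+    disp = torch.rand(1, 1, 32, 32)
+    img = torch.rand(1, 3, 32, 32)
+    print("edge-aware smoothness:", smooth(disp, img, torch.ones(1, 1, 32, 32)).item())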
diff --git a/SparseNeuS_demo_v1/loss/depth_metric.py b/SparseNeuS_demo_v1/loss/depth_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8b6249ac6a06906e20a344f468fc1c6e4b992ae
--- /dev/null
+++ b/SparseNeuS_demo_v1/loss/depth_metric.py
@@ -0,0 +1,240 @@
+import numpy as np
+
+
+def l1(depth1, depth2):
+ """
+ Computes the l1 errors between the two depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+        L1
+
+ """
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ diff = depth1 - depth2
+ num_pixels = float(diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sum(np.absolute(diff)) / num_pixels
+
+
+def l1_inverse(depth1, depth2):
+ """
+ Computes the l1 errors between inverses of two depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+        L1 of inverse depths
+
+ """
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ diff = np.reciprocal(depth1) - np.reciprocal(depth2)
+ num_pixels = float(diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sum(np.absolute(diff)) / num_pixels
+
+
+def rmse_log(depth1, depth2):
+ """
+    Computes the root mean square error between the logs of two depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+ RMSE(log)
+
+ """
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ log_diff = np.log(depth1) - np.log(depth2)
+ num_pixels = float(log_diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sqrt(np.sum(np.square(log_diff)) / num_pixels)
+
+
+def rmse(depth1, depth2):
+ """
+    Computes the root mean square error between the two depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+        RMSE
+
+ """
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ diff = depth1 - depth2
+ num_pixels = float(diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sqrt(np.sum(np.square(diff)) / num_pixels)
+
+
+def scale_invariant(depth1, depth2):
+ """
+ Computes the scale invariant loss based on differences of logs of depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+ scale_invariant_distance
+
+ """
+ # sqrt(Eq. 3)
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ log_diff = np.log(depth1) - np.log(depth2)
+ num_pixels = float(log_diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sqrt(np.sum(np.square(log_diff)) / num_pixels - np.square(np.sum(log_diff)) / np.square(num_pixels))
+
+
+def abs_relative(depth_pred, depth_gt):
+ """
+ Computes relative absolute distance.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth_pred: depth map prediction
+ depth_gt: depth map ground truth
+
+ Returns:
+ abs_relative_distance
+
+ """
+ assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0)))
+ diff = depth_pred - depth_gt
+ num_pixels = float(diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sum(np.absolute(diff) / depth_gt) / num_pixels
+
+
+def avg_log10(depth1, depth2):
+ """
+ Computes average log_10 error (Liu, Neural Fields, 2015).
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+        average log10 error
+
+ """
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ log_diff = np.log10(depth1) - np.log10(depth2)
+ num_pixels = float(log_diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sum(np.absolute(log_diff)) / num_pixels
+
+
+def sq_relative(depth_pred, depth_gt):
+ """
+ Computes relative squared distance.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth_pred: depth map prediction
+ depth_gt: depth map ground truth
+
+ Returns:
+ squared_relative_distance
+
+ """
+ assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0)))
+ diff = depth_pred - depth_gt
+ num_pixels = float(diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return np.sum(np.square(diff) / depth_gt) / num_pixels
+
+
+def ratio_threshold(depth1, depth2, threshold):
+ """
+ Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
+ Takes preprocessed depths (no nans, infs and non-positive values)
+
+ depth1: one depth map
+ depth2: another depth map
+
+ Returns:
+ percentage of pixels with ratio less than the threshold
+
+ """
+ assert (threshold > 0.)
+ assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
+ log_diff = np.log(depth1) - np.log(depth2)
+ num_pixels = float(log_diff.size)
+
+ if num_pixels == 0:
+ return np.nan
+ else:
+ return float(np.sum(np.absolute(log_diff) < np.log(threshold))) / num_pixels
+
+
+def compute_depth_errors(depth_pred, depth_gt, valid_mask):
+ """
+ Computes different distance measures between two depth maps.
+
+ depth_pred: depth map prediction
+ depth_gt: depth map ground truth
+ distances_to_compute: which distances to compute
+
+ Returns:
+ a dictionary with computed distances, and the number of valid pixels
+
+ """
+ depth_pred = depth_pred[valid_mask]
+ depth_gt = depth_gt[valid_mask]
+ num_valid = np.sum(valid_mask)
+
+ distances_to_compute = ['l1',
+ 'l1_inverse',
+ 'scale_invariant',
+ 'abs_relative',
+ 'sq_relative',
+ 'avg_log10',
+ 'rmse_log',
+ 'rmse',
+ 'ratio_threshold_1.25',
+ 'ratio_threshold_1.5625',
+ 'ratio_threshold_1.953125']
+
+ results = {'num_valid': num_valid}
+ for dist in distances_to_compute:
+ if dist.startswith('ratio_threshold'):
+ threshold = float(dist.split('_')[-1])
+ results[dist] = ratio_threshold(depth_pred, depth_gt, threshold)
+ else:
+ results[dist] = globals()[dist](depth_pred, depth_gt)
+
+ return results
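+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (assumes numpy only): compare a noisy prediction against a
+    # synthetic positive ground-truth depth map.
+    rng = np.random.default_rng(0)
+    depth_gt = rng.uniform(0.5, 2.0, size=(64, 64))
+    depth_pred = depth_gt * rng.uniform(0.95, 1.05, size=depth_gt.shape)
+    valid_mask = depth_gt > 0
+    for name, value in compute_depth_errors(depth_pred, depth_gt, valid_mask).items():
+        print(f"{name}: {value}")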
diff --git a/SparseNeuS_demo_v1/loss/ncc.py b/SparseNeuS_demo_v1/loss/ncc.py
new file mode 100644
index 0000000000000000000000000000000000000000..768fcefc3aab55d8e3fed49f23ffb4a974eec4ec
--- /dev/null
+++ b/SparseNeuS_demo_v1/loss/ncc.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from math import exp, sqrt
+
+
+class NCC(torch.nn.Module):
+ def __init__(self, h_patch_size, mode='rgb'):
+ super(NCC, self).__init__()
+ self.window_size = 2 * h_patch_size + 1
+ self.mode = mode # 'rgb' or 'gray'
+ self.channel = 3
+ self.register_buffer("window", create_window(self.window_size, self.channel))
+
+ def forward(self, img_pred, img_gt):
+ """
+ :param img_pred: [Npx, nviews, npatch, c]
+ :param img_gt: [Npx, npatch, c]
+ :return:
+ """
+ ntotpx, nviews, npatch, channels = img_pred.shape
+
+ patch_size = int(sqrt(npatch))
+ patch_img_pred = img_pred.reshape(ntotpx, nviews, patch_size, patch_size, channels).permute(0, 1, 4, 2,
+ 3).contiguous()
+ patch_img_gt = img_gt.reshape(ntotpx, patch_size, patch_size, channels).permute(0, 3, 1, 2)
+
+ return _ncc(patch_img_pred, patch_img_gt, self.window, self.channel)
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel, std=1.5):
+ _1D_window = gaussian(window_size, std).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+ return window
+
+
+def _ncc(pred, gt, window, channel):
+ ntotpx, nviews, nc, h, w = pred.shape
+ flat_pred = pred.view(-1, nc, h, w)
+ mu1 = F.conv2d(flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc)
+ mu2 = F.conv2d(gt, window, padding=0, groups=channel).view(ntotpx, nc)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2).unsqueeze(1) # (ntotpx, 1, nc)
+
+ sigma1_sq = F.conv2d(flat_pred * flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) - mu1_sq
+ sigma2_sq = F.conv2d(gt * gt, window, padding=0, groups=channel).view(ntotpx, 1, 3) - mu2_sq
+
+ sigma1 = torch.sqrt(sigma1_sq + 1e-4)
+ sigma2 = torch.sqrt(sigma2_sq + 1e-4)
+
+ pred_norm = (pred - mu1[:, :, :, None, None]) / (sigma1[:, :, :, None, None] + 1e-8) # [ntotpx, nviews, nc, h, w]
+ gt_norm = (gt[:, None, :, :, :] - mu2[:, None, :, None, None]) / (
+ sigma2[:, :, :, None, None] + 1e-8) # ntotpx, nc, h, w
+
+ ncc = F.conv2d((pred_norm * gt_norm).view(-1, nc, h, w), window, padding=0, groups=channel).view(
+ ntotpx, nviews, nc)
+
+ return torch.mean(ncc, dim=2)
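+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (assumes torch): NCC over 7x7 RGB patches for 4 rays and 2 source views.
+    h_patch_size = 3
+    npx, nviews, npatch = 4, 2, (2 * h_patch_size + 1) ** 2
+    ncc = NCC(h_patch_size=h_patch_size)
+    pred = torch.rand(npx, nviews, npatch, 3)
+    gt = torch.rand(npx, npatch, 3)
+    score = ncc(pred, gt)  # [npx, nviews]; roughly +1 for matching patches, -1 for anti-correlated ones
+    print(score.shape, score.min().item(), score.max().item())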
diff --git a/SparseNeuS_demo_v1/models/__init__.py b/SparseNeuS_demo_v1/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/models/embedder.py b/SparseNeuS_demo_v1/models/embedder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d327d92d9f64c0b32908dbee864160b65daa450e
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/embedder.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+
+""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
+
+
+class Embedder:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs['input_dims']
+ out_dim = 0
+ if self.kwargs['include_input']:
+ embed_fns.append(lambda x: x)
+ out_dim += d
+
+ max_freq = self.kwargs['max_freq_log2']
+ N_freqs = self.kwargs['num_freqs']
+
+ if self.kwargs['log_sampling']:
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
+ else:
+ freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs['periodic_fns']:
+ if self.kwargs['normalize']:
+ embed_fns.append(lambda x, p_fn=p_fn,
+ freq=freq: p_fn(x * freq) / freq)
+ else:
+ embed_fns.append(lambda x, p_fn=p_fn,
+ freq=freq: p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+
+def get_embedder(multires, normalize=False, input_dims=3):
+ embed_kwargs = {
+ 'include_input': True,
+ 'input_dims': input_dims,
+ 'max_freq_log2': multires - 1,
+ 'num_freqs': multires,
+ 'normalize': normalize,
+ 'log_sampling': True,
+ 'periodic_fns': [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+
+ def embed(x, eo=embedder_obj): return eo.embed(x)
+
+ return embed, embedder_obj.out_dim
+
+
+class Embedding(nn.Module):
+ def __init__(self, in_channels, N_freqs, logscale=True, normalize=False):
+ """
+ Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
+ in_channels: number of input channels (3 for both xyz and direction)
+ """
+ super(Embedding, self).__init__()
+ self.N_freqs = N_freqs
+ self.in_channels = in_channels
+ self.funcs = [torch.sin, torch.cos]
+ self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
+ self.normalize = normalize
+
+ if logscale:
+ self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
+ else:
+ self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)
+
+ def forward(self, x):
+ """
+ Embeds x to (x, sin(2^k x), cos(2^k x), ...)
+ Different from the paper, "x" is also in the output
+ See https://github.com/bmild/nerf/issues/12
+
+ Inputs:
+ x: (B, self.in_channels)
+
+ Outputs:
+ out: (B, self.out_channels)
+ """
+ out = [x]
+ for freq in self.freq_bands:
+ for func in self.funcs:
+ if self.normalize:
+ out += [func(freq * x) / freq]
+ else:
+ out += [func(freq * x)]
+
+ return torch.cat(out, -1)
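+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch (assumes torch): embed 3D points with 6 frequency bands.
+    embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
+    pts = torch.rand(5, 3)
+    print(out_dim)              # 3 + 3 * 2 * 6 = 39
+    print(embed_fn(pts).shape)  # torch.Size([5, 39])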
diff --git a/SparseNeuS_demo_v1/models/fast_renderer.py b/SparseNeuS_demo_v1/models/fast_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1faeba85e5b156d0de12e430287d90f4a803aa92
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/fast_renderer.py
@@ -0,0 +1,316 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from icecream import ic
+
+
+# - NeuS: use sphere tracing to speed up depth map extraction
+# This code snippet is heavily borrowed from IDR.
+class FastRenderer(nn.Module):
+ def __init__(self):
+ super(FastRenderer, self).__init__()
+
+ self.sdf_threshold = 5e-5
+ self.line_search_step = 0.5
+ self.line_step_iters = 1
+ self.sphere_tracing_iters = 10
+ self.n_steps = 100
+ self.n_secant_steps = 8
+
+ # - use sdf_network to inference sdf value or directly interpolate sdf value from precomputed sdf_volume
+ self.network_inference = False
+
+ def extract_depth_maps(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
+ with torch.no_grad():
+ curr_start_points, network_object_mask, acc_start_dis = self.get_intersection(
+ rays_o, rays_d, near, far,
+ sdf_network, conditional_volume)
+
+ network_object_mask = network_object_mask.reshape(-1)
+
+ return network_object_mask, acc_start_dis
+
+ def get_intersection(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
+ device = rays_o.device
+ num_pixels, _ = rays_d.shape
+
+ curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \
+ self.sphere_tracing(rays_o, rays_d, near, far, sdf_network, conditional_volume)
+
+ network_object_mask = (acc_start_dis < acc_end_dis)
+
+        # The non-convergent rays should be handled by the sampler
+ sampler_mask = unfinished_mask_start
+ sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().to(device)
+ if sampler_mask.sum() > 0:
+ # sampler_min_max = torch.zeros((num_pixels, 2)).to(device)
+ # sampler_min_max[sampler_mask, 0] = acc_start_dis[sampler_mask]
+ # sampler_min_max[sampler_mask, 1] = acc_end_dis[sampler_mask]
+
+ # ray_sampler(self, rays_o, rays_d, near, far, sampler_mask):
+ sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(rays_o,
+ rays_d,
+ acc_start_dis,
+ acc_end_dis,
+ sampler_mask,
+ sdf_network,
+ conditional_volume
+ )
+
+ curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
+ acc_start_dis[sampler_mask] = sampler_dists[sampler_mask][:, None]
+ network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask][:, None]
+
+ # print('----------------------------------------------------------------')
+ # print('RayTracing: object = {0}/{1}, secant on {2}/{3}.'
+ # .format(network_object_mask.sum(), len(network_object_mask), sampler_net_obj_mask.sum(),
+ # sampler_mask.sum()))
+ # print('----------------------------------------------------------------')
+
+ return curr_start_points, network_object_mask, acc_start_dis
+
+ def sphere_tracing(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
+ ''' Run sphere tracing algorithm for max iterations from both sides of unit sphere intersection '''
+
+ device = rays_o.device
+
+ unfinished_mask_start = (near < far).reshape(-1).clone()
+ unfinished_mask_end = (near < far).reshape(-1).clone()
+
+ # Initialize start current points
+ curr_start_points = rays_o + rays_d * near
+ acc_start_dis = near.clone()
+
+ # Initialize end current points
+ curr_end_points = rays_o + rays_d * far
+ acc_end_dis = far.clone()
+
+        # Initialize min and max depth
+ min_dis = acc_start_dis.clone()
+ max_dis = acc_end_dis.clone()
+
+ # Iterate on the rays (from both sides) till finding a surface
+ iters = 0
+
+ next_sdf_start = torch.zeros_like(acc_start_dis).to(device)
+
+ if self.network_inference:
+ sdf_func = sdf_network.sdf
+ else:
+ sdf_func = sdf_network.sdf_from_sdfvolume
+
+ next_sdf_start[unfinished_mask_start] = sdf_func(
+ curr_start_points[unfinished_mask_start],
+ conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]
+
+ next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
+ next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
+ conditional_volume, lod=0, gru_fusion=False)[
+ 'sdf_pts_scale%d' % 0]
+
+ while True:
+ # Update sdf
+ curr_sdf_start = torch.zeros_like(acc_start_dis).to(device)
+ curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start]
+ curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0
+
+ curr_sdf_end = torch.zeros_like(acc_end_dis).to(device)
+ curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
+ curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0
+
+ # Update masks
+ unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold).reshape(-1)
+ unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold).reshape(-1)
+
+ if (
+ unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) or iters == self.sphere_tracing_iters:
+ break
+ iters += 1
+
+ # Make step
+ # Update distance
+ acc_start_dis = acc_start_dis + curr_sdf_start
+ acc_end_dis = acc_end_dis - curr_sdf_end
+
+ # Update points
+ curr_start_points = rays_o + acc_start_dis * rays_d
+ curr_end_points = rays_o + acc_end_dis * rays_d
+
+ # Fix points which wrongly crossed the surface
+ next_sdf_start = torch.zeros_like(acc_start_dis).to(device)
+ if unfinished_mask_start.sum() > 0:
+ next_sdf_start[unfinished_mask_start] = sdf_func(curr_start_points[unfinished_mask_start],
+ conditional_volume, lod=0, gru_fusion=False)[
+ 'sdf_pts_scale%d' % 0]
+
+ next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
+ if unfinished_mask_end.sum() > 0:
+ next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
+ conditional_volume, lod=0, gru_fusion=False)[
+ 'sdf_pts_scale%d' % 0]
+
+ not_projected_start = (next_sdf_start < 0).reshape(-1)
+ not_projected_end = (next_sdf_end < 0).reshape(-1)
+ not_proj_iters = 0
+
+ while (
+ not_projected_start.sum() > 0 or not_projected_end.sum() > 0) and not_proj_iters < self.line_step_iters:
+ # Step backwards
+ if not_projected_start.sum() > 0:
+ acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
+ curr_sdf_start[not_projected_start]
+ curr_start_points[not_projected_start] = (rays_o + acc_start_dis * rays_d)[not_projected_start]
+
+ next_sdf_start[not_projected_start] = sdf_func(
+ curr_start_points[not_projected_start],
+ conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]
+
+ if not_projected_end.sum() > 0:
+ acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
+ curr_sdf_end[
+ not_projected_end]
+ curr_end_points[not_projected_end] = (rays_o + acc_end_dis * rays_d)[not_projected_end]
+
+ # Calc sdf
+
+ next_sdf_end[not_projected_end] = sdf_func(
+ curr_end_points[not_projected_end],
+ conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]
+
+ # Update mask
+ not_projected_start = (next_sdf_start < 0).reshape(-1)
+ not_projected_end = (next_sdf_end < 0).reshape(-1)
+ not_proj_iters += 1
+
+ unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis).reshape(-1)
+ unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis).reshape(-1)
+
+ return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis
+
+ def ray_sampler(self, rays_o, rays_d, near, far, sampler_mask, sdf_network, conditional_volume):
+ ''' Sample the ray in a given range and run secant on rays which have sign transition '''
+ device = rays_o.device
+ num_pixels, _ = rays_d.shape
+ sampler_pts = torch.zeros(num_pixels, 3).to(device).float()
+ sampler_dists = torch.zeros(num_pixels).to(device).float()
+
+ intervals_dist = torch.linspace(0, 1, steps=self.n_steps).to(device).view(1, -1)
+
+ pts_intervals = near + intervals_dist * (far - near)
+ points = rays_o[:, None, :] + pts_intervals[:, :, None] * rays_d[:, None, :]
+
+ # Get the non convergent rays
+ mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
+ points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
+ pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]
+
+ if self.network_inference:
+ sdf_func = sdf_network.sdf
+ else:
+ sdf_func = sdf_network.sdf_from_sdfvolume
+
+ sdf_val_all = []
+ for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
+ sdf_val_all.append(sdf_func(pnts,
+ conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0])
+ sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)
+
+ tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).to(device).float().reshape(
+ (1, self.n_steps)) # Force argmin to return the first min value
+ sampler_pts_ind = torch.argmin(tmp, -1)
+ sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :]
+ sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind]
+
+ net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0)
+
+ # take points with minimal SDF value for P_out pixels
+ p_out_mask = ~net_surface_pts
+ n_p_out = p_out_mask.sum()
+ if n_p_out > 0:
+ out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
+ sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out), out_pts_idx,
+ :]
+ sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][
+ torch.arange(n_p_out), out_pts_idx]
+
+ # Get Network object mask
+ sampler_net_obj_mask = sampler_mask.clone()
+ sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False
+
+ # Run Secant method
+ secant_pts = net_surface_pts
+ n_secant_pts = secant_pts.sum()
+ if n_secant_pts > 0:
+ # Get secant z predictions
+ z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts]
+ sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts]
+ z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]
+ sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]
+
+ cam_loc_secant = rays_o[mask_intersect_idx[secant_pts]]
+ ray_directions_secant = rays_d[mask_intersect_idx[secant_pts]]
+ z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant,
+ sdf_network, conditional_volume)
+
+ # Get points
+ sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant[:,
+ None] * ray_directions_secant
+ sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant
+
+ return sampler_pts, sampler_net_obj_mask, sampler_dists
+
+ def secant(self, sdf_low, sdf_high, z_low, z_high, rays_o, rays_d, sdf_network, conditional_volume):
+ ''' Runs the secant method for interval [z_low, z_high] for n_secant_steps '''
+
+ if self.network_inference:
+ sdf_func = sdf_network.sdf
+ else:
+ sdf_func = sdf_network.sdf_from_sdfvolume
+
+ z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
+ for i in range(self.n_secant_steps):
+ p_mid = rays_o + z_pred[:, None] * rays_d
+ sdf_mid = sdf_func(p_mid,
+ conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0].reshape(-1)
+ ind_low = (sdf_mid > 0).reshape(-1)
+ if ind_low.sum() > 0:
+ z_low[ind_low] = z_pred[ind_low]
+ sdf_low[ind_low] = sdf_mid[ind_low]
+ ind_high = sdf_mid < 0
+ if ind_high.sum() > 0:
+ z_high[ind_high] = z_pred[ind_high]
+ sdf_high[ind_high] = sdf_mid[ind_high]
+
+ z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
+
+ return z_pred # 1D tensor
+
+ def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis):
+ ''' Find points with minimal SDF value on rays for P_out pixels '''
+ device = sdf.device
+ n_mask_points = mask.sum()
+
+ n = self.n_steps
+ # steps = torch.linspace(0.0, 1.0,n).to(device)
+ steps = torch.empty(n).uniform_(0.0, 1.0).to(device)
+ mask_max_dis = max_dis[mask].unsqueeze(-1)
+ mask_min_dis = min_dis[mask].unsqueeze(-1)
+ steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis
+
+ mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask]
+ mask_rays = ray_directions[mask, :]
+
+ mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze(
+ 1).repeat(1, n, 1)
+ points = mask_points_all.reshape(-1, 3)
+
+ mask_sdf_all = []
+ for pnts in torch.split(points, 100000, dim=0):
+ mask_sdf_all.append(sdf(pnts))
+
+ mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
+ min_vals, min_idx = mask_sdf_all.min(-1)
+ min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx]
+ min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
+
+ return min_mask_points, min_mask_dist
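+
+
+if __name__ == "__main__":
+    # Illustrative sketch only (assumes torch): the same secant update used in FastRenderer.secant,
+    # applied to a toy 1D SDF f(z) = z - 0.7 whose zero crossing lies between z_low and z_high.
+    def f(z):
+        return z - 0.7
+
+    z_low, z_high = torch.tensor([0.2]), torch.tensor([1.5])
+    sdf_low, sdf_high = f(z_low), f(z_high)
+    z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
+    for _ in range(8):
+        sdf_mid = f(z_pred)
+        ind_low, ind_high = (sdf_mid > 0), (sdf_mid < 0)
+        z_low[ind_low], sdf_low[ind_low] = z_pred[ind_low], sdf_mid[ind_low]
+        z_high[ind_high], sdf_high[ind_high] = z_pred[ind_high], sdf_mid[ind_high]
+        z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
+    print(z_pred)  # converges to ~0.7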
diff --git a/SparseNeuS_demo_v1/models/featurenet.py b/SparseNeuS_demo_v1/models/featurenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..652e65967708f57a1722c5951d53e72f05ddf1d3
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/featurenet.py
@@ -0,0 +1,91 @@
+import torch
+
+# ! amazing!!!! autograd.grad with set_detect_anomaly(True) will cause memory leak
+# ! https://github.com/pytorch/pytorch/issues/51349
+# torch.autograd.set_detect_anomaly(True)
+import torch.nn as nn
+import torch.nn.functional as F
+from inplace_abn import InPlaceABN
+
+
+############################################# MVS Net models ################################################
+class ConvBnReLU(nn.Module):
+ def __init__(self, in_channels, out_channels,
+ kernel_size=3, stride=1, pad=1,
+ norm_act=InPlaceABN):
+ super(ConvBnReLU, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels,
+ kernel_size, stride=stride, padding=pad, bias=False)
+ self.bn = norm_act(out_channels)
+
+ def forward(self, x):
+ return self.bn(self.conv(x))
+
+
+class ConvBnReLU3D(nn.Module):
+ def __init__(self, in_channels, out_channels,
+ kernel_size=3, stride=1, pad=1,
+ norm_act=InPlaceABN):
+ super(ConvBnReLU3D, self).__init__()
+ self.conv = nn.Conv3d(in_channels, out_channels,
+ kernel_size, stride=stride, padding=pad, bias=False)
+ self.bn = norm_act(out_channels)
+ # self.bn = nn.ReLU()
+
+ def forward(self, x):
+ return self.bn(self.conv(x))
+
+
+################################### feature net ######################################
+class FeatureNet(nn.Module):
+ """
+ output 3 levels of features using a FPN structure
+ """
+
+ def __init__(self, norm_act=InPlaceABN):
+ super(FeatureNet, self).__init__()
+
+ self.conv0 = nn.Sequential(
+ ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
+ ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
+
+ self.conv1 = nn.Sequential(
+ ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
+ ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
+ ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
+
+ self.conv2 = nn.Sequential(
+ ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
+ ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
+ ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))
+
+ self.toplayer = nn.Conv2d(32, 32, 1)
+ self.lat1 = nn.Conv2d(16, 32, 1)
+ self.lat0 = nn.Conv2d(8, 32, 1)
+
+ # to reduce channel size of the outputs from FPN
+ self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
+ self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
+
+ def _upsample_add(self, x, y):
+ return F.interpolate(x, scale_factor=2,
+ mode="bilinear", align_corners=True) + y
+
+ def forward(self, x):
+ # x: (B, 3, H, W)
+ conv0 = self.conv0(x) # (B, 8, H, W)
+ conv1 = self.conv1(conv0) # (B, 16, H//2, W//2)
+ conv2 = self.conv2(conv1) # (B, 32, H//4, W//4)
+ feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4)
+ feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2)
+ feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W)
+
+ # reduce output channels
+ feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
+ feat0 = self.smooth0(feat0) # (B, 8, H, W)
+
+ # feats = {"level_0": feat0,
+ # "level_1": feat1,
+ # "level_2": feat2}
+
+ return [feat2, feat1, feat0] # coarser to finer features
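+
+
+# Usage sketch (hypothetical input; inplace_abn must be installed): for a (B, 3, H, W)
+# image batch, the three returned levels are
+# feat2: (B, 32, H//4, W//4), feat1: (B, 16, H//2, W//2), feat0: (B, 8, H, W).
+# net = FeatureNet()
+# feat2, feat1, feat0 = net(torch.randn(2, 3, 256, 256))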
diff --git a/SparseNeuS_demo_v1/models/fields.py b/SparseNeuS_demo_v1/models/fields.py
new file mode 100644
index 0000000000000000000000000000000000000000..184e4a55399f56f8f505379ce4a14add8821c4c4
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/fields.py
@@ -0,0 +1,333 @@
+# The codes are from NeuS
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from models.embedder import get_embedder
+
+
+class SDFNetwork(nn.Module):
+ def __init__(self,
+ d_in,
+ d_out,
+ d_hidden,
+ n_layers,
+ skip_in=(4,),
+ multires=0,
+ bias=0.5,
+ scale=1,
+ geometric_init=True,
+ weight_norm=True,
+ activation='softplus',
+ conditional_type='multiply'):
+ super(SDFNetwork, self).__init__()
+
+ dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
+
+ self.embed_fn_fine = None
+
+ if multires > 0:
+ embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
+ self.embed_fn_fine = embed_fn
+ dims[0] = input_ch
+
+ self.num_layers = len(dims)
+ self.skip_in = skip_in
+ self.scale = scale
+
+ for l in range(0, self.num_layers - 1):
+ if l + 1 in self.skip_in:
+ out_dim = dims[l + 1] - dims[0]
+ else:
+ out_dim = dims[l + 1]
+
+ lin = nn.Linear(dims[l], out_dim)
+
+ if geometric_init:
+ if l == self.num_layers - 2:
+ torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
+ torch.nn.init.constant_(lin.bias, -bias)
+ elif multires > 0 and l == 0:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
+ torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ elif multires > 0 and l in self.skip_in:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) # ? why dims[0] - 3
+ else:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ if activation == 'softplus':
+ self.activation = nn.Softplus(beta=100)
+ else:
+ assert activation == 'relu'
+ self.activation = nn.ReLU()
+
+ def forward(self, inputs):
+ inputs = inputs * self.scale
+ if self.embed_fn_fine is not None:
+ inputs = self.embed_fn_fine(inputs)
+
+ x = inputs
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ if l in self.skip_in:
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.activation(x)
+ return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
+
+ def sdf(self, x):
+ return self.forward(x)[:, :1]
+
+ def sdf_hidden_appearance(self, x):
+ return self.forward(x)
+
+ def gradient(self, x):
+ x.requires_grad_(True)
+ y = self.sdf(x)
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
+ gradients = torch.autograd.grad(
+ outputs=y,
+ inputs=x,
+ grad_outputs=d_output,
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+ return gradients.unsqueeze(1)
+
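+# A minimal usage sketch (values are illustrative, not this repo's config): the first
+# output channel of SDFNetwork is the SDF (rescaled by 1/scale), the remaining channels
+# are hidden features.
+# net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, skip_in=(4,), multires=6)
+# x = torch.rand(1024, 3) * 2 - 1 # points in the normalized unit cube
+# sdf_vals = net.sdf(x) # (1024, 1)
+# normals = F.normalize(net.gradient(x).squeeze(1), dim=-1) # (1024, 3) unit normals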
+
+class VarianceNetwork(nn.Module):
+ def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0):
+ super(VarianceNetwork, self).__init__()
+
+ dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
+
+ self.embed_fn_fine = None
+
+ if multires > 0:
+ embed_fn, input_ch = get_embedder(multires, normalize=False)
+ self.embed_fn_fine = embed_fn
+ dims[0] = input_ch
+
+ self.num_layers = len(dims)
+ self.skip_in = skip_in
+
+ for l in range(0, self.num_layers - 1):
+ if l + 1 in self.skip_in:
+ out_dim = dims[l + 1] - dims[0]
+ else:
+ out_dim = dims[l + 1]
+
+ lin = nn.Linear(dims[l], out_dim)
+ setattr(self, "lin" + str(l), lin)
+
+ self.relu = nn.ReLU()
+ self.softplus = nn.Softplus(beta=100)
+
+ def forward(self, inputs):
+ if self.embed_fn_fine is not None:
+ inputs = self.embed_fn_fine(inputs)
+
+ x = inputs
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ if l in self.skip_in:
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.relu(x)
+
+ # return torch.exp(x)
+ return 1.0 / (self.softplus(x + 0.5) + 1e-3)
+
+ def coarse(self, inputs):
+ return self.forward(inputs)[:, :1]
+
+ def fine(self, inputs):
+ return self.forward(inputs)[:, 1:]
+
+
+class FixVarianceNetwork(nn.Module):
+ def __init__(self, base):
+ super(FixVarianceNetwork, self).__init__()
+ self.base = base
+ self.iter_step = 0
+
+ def set_iter_step(self, iter_step):
+ self.iter_step = iter_step
+
+ def forward(self, x):
+ return torch.ones([len(x), 1]) * np.exp(-self.iter_step / self.base)
+
+
+class SingleVarianceNetwork(nn.Module):
+ def __init__(self, init_val=1.0):
+ super(SingleVarianceNetwork, self).__init__()
+ self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
+
+ def forward(self, x):
+ return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0)
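+ # Note: `variance` is a single learnable scalar stored in log-space; the forward pass
+ # broadcasts exp(10 * variance) to one value per query point, which the renderer uses
+ # as the inverse standard deviation (sharpness) of the sigmoid applied to SDF values.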
+
+
+
+class RenderingNetwork(nn.Module):
+ def __init__(
+ self,
+ d_feature,
+ mode,
+ d_in,
+ d_out,
+ d_hidden,
+ n_layers,
+ weight_norm=True,
+ multires_view=0,
+ squeeze_out=True,
+ d_conditional_colors=0
+ ):
+ super().__init__()
+
+ self.mode = mode
+ self.squeeze_out = squeeze_out
+ dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
+
+ self.embedview_fn = None
+ if multires_view > 0:
+ embedview_fn, input_ch = get_embedder(multires_view)
+ self.embedview_fn = embedview_fn
+ dims[0] += (input_ch - 3)
+
+ self.num_layers = len(dims)
+
+ for l in range(0, self.num_layers - 1):
+ out_dim = dims[l + 1]
+ lin = nn.Linear(dims[l], out_dim)
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ self.relu = nn.ReLU()
+
+ def forward(self, points, normals, view_dirs, feature_vectors):
+ if self.embedview_fn is not None:
+ view_dirs = self.embedview_fn(view_dirs)
+
+ rendering_input = None
+
+ if self.mode == 'idr':
+ rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_view_dir':
+ rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_normal':
+ rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
+ elif self.mode == 'no_points':
+ rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_points_no_view_dir':
+ rendering_input = torch.cat([normals, feature_vectors], dim=-1)
+
+ x = rendering_input
+
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.relu(x)
+
+ if self.squeeze_out:
+ x = torch.sigmoid(x)
+ return x
+
+
+# Code from nerf-pytorch
+class NeRF(nn.Module):
+ def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4],
+ use_viewdirs=False):
+ """
+ """
+ super(NeRF, self).__init__()
+ self.D = D
+ self.W = W
+ self.d_in = d_in
+ self.d_in_view = d_in_view
+ self.input_ch = 3
+ self.input_ch_view = 3
+ self.embed_fn = None
+ self.embed_fn_view = None
+
+ if multires > 0:
+ embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
+ self.embed_fn = embed_fn
+ self.input_ch = input_ch
+
+ if multires_view > 0:
+ embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False)
+ self.embed_fn_view = embed_fn_view
+ self.input_ch_view = input_ch_view
+
+ self.skips = skips
+ self.use_viewdirs = use_viewdirs
+
+ self.pts_linears = nn.ModuleList(
+ [nn.Linear(self.input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W)
+ for i in
+ range(D - 1)])
+
+ ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
+ self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])
+
+ ### Implementation according to the paper
+ # self.views_linears = nn.ModuleList(
+ # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
+
+ if use_viewdirs:
+ self.feature_linear = nn.Linear(W, W)
+ self.alpha_linear = nn.Linear(W, 1)
+ self.rgb_linear = nn.Linear(W // 2, 3)
+ else:
+ self.output_linear = nn.Linear(W, output_ch)
+
+ def forward(self, input_pts, input_views):
+ if self.embed_fn is not None:
+ input_pts = self.embed_fn(input_pts)
+ if self.embed_fn_view is not None:
+ input_views = self.embed_fn_view(input_views)
+
+ h = input_pts
+ for i, l in enumerate(self.pts_linears):
+ h = self.pts_linears[i](h)
+ h = F.relu(h)
+ if i in self.skips:
+ h = torch.cat([input_pts, h], -1)
+
+ if self.use_viewdirs:
+ alpha = self.alpha_linear(h)
+ feature = self.feature_linear(h)
+ h = torch.cat([feature, input_views], -1)
+
+ for i, l in enumerate(self.views_linears):
+ h = self.views_linears[i](h)
+ h = F.relu(h)
+
+ rgb = self.rgb_linear(h)
+ return alpha + 1.0, rgb
+ else:
+ assert False
diff --git a/SparseNeuS_demo_v1/models/patch_projector.py b/SparseNeuS_demo_v1/models/patch_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf9ca424c588e49d754988814233069b2cf127fa
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/patch_projector.py
@@ -0,0 +1,211 @@
+"""
+Patch Projector
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from models.render_utils import sample_ptsFeatures_from_featureMaps
+
+
+class PatchProjector():
+ def __init__(self, patch_size):
+ self.h_patch_size = patch_size
+ self.offsets = build_patch_offset(patch_size) # the warping patch offsets index
+
+ self.z_axis = torch.tensor([0, 0, 1]).float()
+
+ self.plane_dist_thresh = 0.001
+
+ # * correctness checked
+ def pixel_warp(self, pts, imgs, intrinsics,
+ w2cs, img_wh=None):
+ """
+
+ :param pts: [N_rays, n_samples, 3]
+ :param imgs: [N_views, 3, H, W]
+ :param intrinsics: [N_views, 4, 4]
+ :param c2ws: [N_views, 4, 4]
+ :param img_wh:
+ :return:
+ """
+ if img_wh is None:
+ N_views, _, sizeH, sizeW = imgs.shape
+ img_wh = [sizeW, sizeH]
+
+ pts_color, valid_mask = sample_ptsFeatures_from_featureMaps(
+ pts, imgs, w2cs, intrinsics, img_wh,
+ proj_matrix=None, return_mask=True) # [N_views, c, N_rays, n_samples], [N_views, N_rays, n_samples]
+
+ pts_color = pts_color.permute(2, 3, 0, 1)
+ valid_mask = valid_mask.permute(1, 2, 0)
+
+ return pts_color, valid_mask # [N_rays, n_samples, N_views, 3] , [N_rays, n_samples, N_views]
+
+ def patch_warp(self, pts, uv, normals, src_imgs,
+ ref_intrinsic, src_intrinsics,
+ ref_c2w, src_c2ws, img_wh=None
+ ):
+ """
+
+ :param pts: [N_rays, n_samples, 3]
+ :param uv : [N_rays, 2] normalized in (-1, 1)
+ :param normals: [N_rays, n_samples, 3] The normal of pt in world space
+ :param src_imgs: [N_src, 3, h, w]
+ :param ref_intrinsic: [4,4]
+ :param src_intrinsics: [N_src, 4, 4]
+ :param ref_c2w: [4,4]
+ :param src_c2ws: [N_src, 4, 4]
+ :return:
+ """
+ device = pts.device
+
+ N_rays, n_samples, _ = pts.shape
+ N_pts = N_rays * n_samples
+
+ N_src, _, sizeH, sizeW = src_imgs.shape
+
+ if img_wh is not None:
+ sizeW, sizeH = img_wh[0], img_wh[1]
+
+ # scale uv from (-1, 1) to (0, W/H)
+ uv[:, 0] = (uv[:, 0] + 1) / 2. * (sizeW - 1)
+ uv[:, 1] = (uv[:, 1] + 1) / 2. * (sizeH - 1)
+
+ ref_intr = ref_intrinsic[:3, :3]
+ inv_ref_intr = torch.inverse(ref_intr)
+ src_intrs = src_intrinsics[:, :3, :3]
+ inv_src_intrs = torch.inverse(src_intrs)
+
+ ref_pose = ref_c2w
+ inv_ref_pose = torch.inverse(ref_pose)
+ src_poses = src_c2ws
+ inv_src_poses = torch.inverse(src_poses)
+
+ ref_cam_loc = ref_pose[:3, 3].unsqueeze(0) # [1, 3]
+ sampled_dists = torch.norm(pts - ref_cam_loc, dim=-1) # [N_pts, 1]
+
+ relative_proj = inv_src_poses @ ref_pose
+ R_rel = relative_proj[:, :3, :3]
+ t_rel = relative_proj[:, :3, 3:]
+ R_ref = inv_ref_pose[:3, :3]
+ t_ref = inv_ref_pose[:3, 3:]
+
+ pts = pts.view(-1, 3)
+ normals = normals.view(-1, 3)
+
+ with torch.no_grad():
+ rot_normals = R_ref @ normals.unsqueeze(-1) # [N_pts, 3, 1]
+ points_in_ref = R_ref @ pts.unsqueeze(
+ -1) + t_ref # [N_pts, 3, 1] points in the reference frame coordinate system
+ d1 = torch.sum(rot_normals * points_in_ref, dim=1).unsqueeze(
+ 1) # distance from the plane to ref camera center
+
+ d2 = torch.sum(rot_normals.unsqueeze(1) * (-R_rel.transpose(1, 2) @ t_rel).unsqueeze(0),
+ dim=2) # distance from the plane to src camera center
+ valid_hom = (torch.abs(d1) > self.plane_dist_thresh) & (
+ torch.abs(d1 - d2) > self.plane_dist_thresh) & ((d2 / d1) < 1)
+
+ d1 = d1.squeeze()
+ sign = torch.sign(d1)
+ sign[sign == 0] = 1
+ d = torch.clamp(torch.abs(d1), 1e-8) * sign
+
+ H = src_intrs.unsqueeze(1) @ (
+ R_rel.unsqueeze(1)
+ + t_rel.unsqueeze(1) @ rot_normals.view(1, N_pts, 1, 3) / d.view(1, N_pts, 1, 1)
+ ) @ inv_ref_intr.view(1, 1, 3, 3)
+
+ # replace invalid homs with fronto-parallel homographies
+ H_invalid = src_intrs.unsqueeze(1) @ (
+ R_rel.unsqueeze(1)
+ + t_rel.unsqueeze(1) @ self.z_axis.to(device).view(1, 1, 1, 3).expand(-1, N_pts, -1, -1)
+ / sampled_dists.view(1, N_pts, 1, 1)
+ ) @ inv_ref_intr.view(1, 1, 3, 3)
+ tmp_m = ~valid_hom.view(-1, N_src).t()
+ H[tmp_m] = H_invalid[tmp_m]
+
+ pixels = uv.view(N_rays, 1, 2) + self.offsets.float().to(device)
+ Npx = pixels.shape[1]
+ grid, warp_mask_full = self.patch_homography(H, pixels)
+
+ warp_mask_full = warp_mask_full & (grid[..., 0] < (sizeW - self.h_patch_size)) & (
+ grid[..., 1] < (sizeH - self.h_patch_size)) & (grid >= self.h_patch_size).all(dim=-1)
+ warp_mask_full = warp_mask_full.view(N_src, N_rays, n_samples, Npx)
+
+ grid = torch.clamp(normalize(grid, sizeH, sizeW), -10, 10)
+
+ sampled_rgb_val = F.grid_sample(src_imgs, grid.view(N_src, -1, 1, 2), align_corners=True).squeeze(
+ -1).transpose(1, 2)
+ sampled_rgb_val = sampled_rgb_val.view(N_src, N_rays, n_samples, Npx, 3)
+
+ warp_mask_full = warp_mask_full.permute(1, 2, 0, 3).contiguous() # (N_rays, n_samples, N_src, Npx)
+ sampled_rgb_val = sampled_rgb_val.permute(1, 2, 0, 3, 4).contiguous() # (N_rays, n_samples, N_src, Npx, 3)
+
+ return sampled_rgb_val, warp_mask_full
+
+ def patch_homography(self, H, uv):
+ N, Npx = uv.shape[:2]
+ Nsrc = H.shape[0]
+ H = H.view(Nsrc, N, -1, 3, 3)
+ hom_uv = add_hom(uv)
+
+ # einsum is 30 times faster
+ # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3)
+ tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3)
+
+ grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8)
+ mask = tmp[..., 2] > 0
+ return grid, mask
+
+
+def add_hom(pts):
+ try:
+ dev = pts.device
+ ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1)
+ return torch.cat((pts, ones), dim=-1)
+
+ except AttributeError:
+ ones = np.ones((pts.shape[0], 1))
+ return np.concatenate((pts, ones), axis=1)
+
+
+def normalize(flow, h, w, clamp=None):
+ # either h and w are simple float or N torch.tensor where N batch size
+ try:
+ h.device
+
+ except AttributeError:
+ h = torch.tensor(h, device=flow.device).float().unsqueeze(0)
+ w = torch.tensor(w, device=flow.device).float().unsqueeze(0)
+
+ if len(flow.shape) == 4:
+ w = w.unsqueeze(1).unsqueeze(2)
+ h = h.unsqueeze(1).unsqueeze(2)
+ elif len(flow.shape) == 3:
+ w = w.unsqueeze(1)
+ h = h.unsqueeze(1)
+ elif len(flow.shape) == 5:
+ w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2)
+ h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2)
+
+ res = torch.empty_like(flow)
+ if res.shape[-1] == 3:
+ res[..., 2] = 1
+
+ # for grid_sample with align_corners=True
+ # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33
+ res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1
+ res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1
+
+ if clamp:
+ return torch.clamp(res, -clamp, clamp)
+ else:
+ return res
+
+
+def build_patch_offset(h_patch_size):
+ offsets = torch.arange(-h_patch_size, h_patch_size + 1)
+ return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2
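+
+
+# Worked example: build_patch_offset(3) enumerates a 7x7 patch and returns integer
+# (dx, dy) offsets of shape (1, 49, 2); adding them to a pixel's (u, v) coordinate
+# gives the patch pixel locations used by the homography warp above.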
diff --git a/SparseNeuS_demo_v1/models/projector.py b/SparseNeuS_demo_v1/models/projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa58d3f896edefff25cbb6fa713e7342d9b84a1d
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/projector.py
@@ -0,0 +1,425 @@
+# The codes are partly from IBRNet
+
+import torch
+import torch.nn.functional as F
+from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume
+
+def safe_l2_normalize(x, dim=None, eps=1e-6):
+ return F.normalize(x, p=2, dim=dim, eps=eps)
+
+class Projector():
+ """
+ Obtain features from geometryVolume and rendering_feature_maps for generalized rendering
+ """
+
+ def compute_angle(self, xyz, query_c2w, supporting_c2ws):
+ """
+
+ :param xyz: [N_rays, n_samples,3 ]
+ :param query_c2w: [1,4,4]
+ :param supporting_c2ws: [n,4,4]
+ :return:
+ """
+ N_rays, n_samples, _ = xyz.shape
+ num_views = supporting_c2ws.shape[0]
+ xyz = xyz.reshape(-1, 3)
+
+ ray2tar_pose = (query_c2w[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
+ ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
+ ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
+ ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
+ ray_diff = ray2tar_pose - ray2support_pose
+ ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
+ ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
+ ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
+ ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
+ ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last channel is the dot product
+ return ray_diff.detach()
+
+
+ def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
+ """
+
+ :param xyz: [N_rays, n_samples,3 ]
+ :param surface_normals: [N_rays, n_samples,3 ]
+ :param supporting_c2ws: [n,4,4]
+ :return:
+ """
+ N_rays, n_samples, _ = xyz.shape
+ num_views = supporting_c2ws.shape[0]
+ xyz = xyz.reshape(-1, 3)
+
+ ray2tar_pose = surface_normals
+ ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
+ ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
+ ray_diff = ray2tar_pose - ray2support_pose
+ ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
+ ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
+ ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
+ ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
+ ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last channel is the dot product,
+ # and the first three channels are the normalized ray-difference vector
+ return ray_diff.detach()
+
+ @torch.no_grad()
+ def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
+ """
+ compute the difference between the depths of query pts projected onto each image and that image's predicted depth values
+ :param xyz: [N_rays, n_samples, 3]
+ :param w2cs: [N_views, 4, 4]
+ :param intrinsics: [N_views, 3, 3]
+ :param pred_depth_values: [N_views, N_rays, n_samples, 1]
+ :return:
+ """
+ device = xyz.device
+ N_views = w2cs.shape[0]
+ N_rays, n_samples, _ = xyz.shape
+ proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])
+
+ proj_rot = proj_matrix[:, :3, :3]
+ proj_trans = proj_matrix[:, :3, 3:]
+
+ batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1)
+
+ proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans
+
+ # X = proj_xyz[:, 0]
+ # Y = proj_xyz[:, 1]
+ Z = proj_xyz[:, 2].clamp(min=1e-3) # [N_views, N_rays*n_samples]
+ proj_z = Z.view(N_views, N_rays, n_samples, 1)
+
+ z_diff = proj_z - pred_depth_values # [N_views, N_rays, n_samples,1 ]
+
+ return z_diff
+
+ def compute(self,
+ pts,
+ # * 3d geometry feature volumes
+ geometryVolume=None,
+ geometryVolumeMask=None,
+ vol_dims=None,
+ partial_vol_origin=None,
+ vol_size=None,
+ # * 2d rendering feature maps
+ rendering_feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=None,
+ pred_depth_maps=None, # no use here
+ pred_depth_masks=None # no use here
+ ):
+ """
+ extract features of pts for rendering
+ :param pts:
+ :param geometryVolume:
+ :param vol_dims:
+ :param partial_vol_origin:
+ :param vol_size:
+ :param rendering_feature_maps:
+ :param color_maps:
+ :param w2cs:
+ :param intrinsics:
+ :param img_wh:
+ :param query_img_idx: by default, we render the first view of w2cs
+ :return:
+ """
+ device = pts.device
+ c2ws = torch.inverse(w2cs)
+
+ if len(pts.shape) == 2:
+ pts = pts[None, :, :]
+
+ N_rays, n_samples, _ = pts.shape
+ N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W)
+
+ supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
+ query_img_idx = torch.LongTensor([query_img_idx]).to(device)
+
+ if query_c2w is None and query_img_idx > -1:
+ query_c2w = torch.index_select(c2ws, 0, query_img_idx)
+ supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
+ supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
+ supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
+ supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
+ supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)
+
+ if pred_depth_maps is not None:
+ supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
+ supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
+ # print("N_supporting_views: ", N_views - 1)
+ N_supporting_views = N_views - 1
+ else:
+ supporting_c2ws = c2ws
+ supporting_w2cs = w2cs
+ supporting_rendering_feature_maps = rendering_feature_maps
+ supporting_color_maps = color_maps
+ supporting_intrinsics = intrinsics
+ supporting_depth_maps = pred_depth_maps
+ supporting_depth_masks = pred_depth_masks
+ # print("N_supporting_views: ", N_views)
+ N_supporting_views = N_views
+ # import ipdb; ipdb.set_trace()
+ if geometryVolume is not None:
+ # * sample feature of pts from 3D feature volume
+ pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
+ pts, geometryVolume, vol_dims,
+ partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples]
+
+ if len(geometryVolumeMask.shape) == 3:
+ geometryVolumeMask = geometryVolumeMask[None, :, :, :]
+
+ pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
+ pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
+ partial_vol_origin, vol_size) # [N_rays, n_samples, C]
+
+ pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
+ else:
+ pts_geometry_feature = None
+ pts_geometry_masks = None
+
+ # * sample feature of pts from 2D feature maps
+ pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
+ pts, supporting_rendering_feature_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh,
+ return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples]
+ # import ipdb; ipdb.set_trace()
+ # * size (N_views, N_rays*n_samples, c)
+ pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()
+
+ pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh)
+ # * size (N_views, N_rays*n_samples, c)
+ pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()
+
+ rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c]
+
+
+ ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4]
+ # import ipdb; ipdb.set_trace()
+ if pts_geometry_masks is not None:
+ final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
+ pts_rendering_mask # [N_views, N_rays, n_samples]
+ else:
+ final_mask = pts_rendering_mask
+ # import ipdb; ipdb.set_trace()
+ z_diff, pts_pred_depth_masks = None, None
+
+ if pred_depth_maps is not None:
+ pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh)
+ pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3,
+ 1).contiguous() # (N_views, N_rays*n_samples, 1)
+
+ # - pts_pred_depth_masks are more critical than final_mask:
+ # a ray containing even a few invalid pts will be treated as invalid
+ pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
+ supporting_w2cs,
+ supporting_intrinsics, img_wh)
+
+ pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :,
+ 0] # (N_views, N_rays*n_samples)
+
+ z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)
+ # import ipdb; ipdb.set_trace()
+ return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks
+
+
+ def compute_view_independent(
+ self,
+ pts,
+ # * 3d geometry feature volumes
+ geometryVolume=None,
+ geometryVolumeMask=None,
+ sdf_network=None,
+ lod=0,
+ vol_dims=None,
+ partial_vol_origin=None,
+ vol_size=None,
+ # * 2d rendering feature maps
+ rendering_feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ target_candidate_w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=None,
+ pred_depth_maps=None, # no use here
+ pred_depth_masks=None # no use here
+ ):
+ """
+ extract features of pts for rendering
+ :param pts:
+ :param geometryVolume:
+ :param vol_dims:
+ :param partial_vol_origin:
+ :param vol_size:
+ :param rendering_feature_maps:
+ :param color_maps:
+ :param w2cs:
+ :param intrinsics:
+ :param img_wh:
+ :param query_img_idx: by default, we render the first view of w2cs
+ :return:
+ """
+ device = pts.device
+ c2ws = torch.inverse(w2cs)
+
+ if len(pts.shape) == 2:
+ pts = pts[None, :, :]
+
+ N_rays, n_samples, _ = pts.shape
+ N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W)
+
+ supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
+ query_img_idx = torch.LongTensor([query_img_idx]).to(device)
+
+ if query_c2w is None and query_img_idx > -1:
+ query_c2w = torch.index_select(c2ws, 0, query_img_idx)
+ supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
+ supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
+ supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
+ supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
+ supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)
+
+ if pred_depth_maps is not None:
+ supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
+ supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
+ # print("N_supporting_views: ", N_views - 1)
+ N_supporting_views = N_views - 1
+ else:
+ supporting_c2ws = c2ws
+ supporting_w2cs = w2cs
+ supporting_rendering_feature_maps = rendering_feature_maps
+ supporting_color_maps = color_maps
+ supporting_intrinsics = intrinsics
+ supporting_depth_maps = pred_depth_maps
+ supporting_depth_masks = pred_depth_masks
+ # print("N_supporting_views: ", N_views)
+ N_supporting_views = N_views
+ # import ipdb; ipdb.set_trace()
+ if geometryVolume is not None:
+ # * sample feature of pts from 3D feature volume
+ pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
+ pts, geometryVolume, vol_dims,
+ partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples]
+
+ if len(geometryVolumeMask.shape) == 3:
+ geometryVolumeMask = geometryVolumeMask[None, :, :, :]
+
+ pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
+ pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
+ partial_vol_origin, vol_size) # [N_rays, n_samples, C]
+
+ pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
+ else:
+ pts_geometry_feature = None
+ pts_geometry_masks = None
+
+ # * sample feature of pts from 2D feature maps
+ pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
+ pts, supporting_rendering_feature_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh,
+ return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples]
+
+ # * size (N_views, N_rays*n_samples, c)
+ pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()
+
+ pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh)
+ # * size (N_views, N_rays*n_samples, c)
+ pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()
+
+ rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c]
+
+ # import ipdb; ipdb.set_trace()
+
+ gradients = sdf_network.gradient(
+ pts.reshape(-1, 3), # pts.squeeze(0),
+ geometryVolume.unsqueeze(0),
+ lod=lod
+ ).squeeze()
+
+ surface_normals = safe_l2_normalize(gradients, dim=-1) # [npts, 3]
+ # input normals
+ ren_ray_diff = self.compute_angle_view_independent(
+ xyz=pts,
+ surface_normals=surface_normals,
+ supporting_c2ws=supporting_c2ws
+ )
+
+ # # choose closest target view direction from 32 candidate views
+ # # choose the closest source view as view direction instead of the normals vectors
+ # pts2src_centers = safe_l2_normalize((supporting_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3]
+
+ # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1]
+ # # choose the largest cosine distance as the view direction
+ # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1]
+
+ # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3]
+ # ren_ray_diff = self.compute_angle_view_independent(
+ # xyz=pts,
+ # surface_normals=chosen_view_direction,
+ # supporting_c2ws=supporting_c2ws
+ # )
+
+
+
+ # # choose closest target view direction from 8 candidate views
+ # # choose the closest source view as view direction instead of the normals vectors
+ # target_candidate_c2ws = torch.inverse(target_candidate_w2cs)
+ # pts2src_centers = safe_l2_normalize((target_candidate_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3]
+
+ # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1]
+ # # choose the largest cosine distance as the view direction
+ # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1]
+
+ # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3]
+ # ren_ray_diff = self.compute_angle_view_independent(
+ # xyz=pts,
+ # surface_normals=chosen_view_direction,
+ # supporting_c2ws=supporting_c2ws
+ # )
+
+
+ # ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4]
+ # import ipdb; ipdb.set_trace()
+
+
+ # input_directions = safe_l2_normalize(pts)
+ # ren_ray_diff = self.compute_angle_view_independent(
+ # xyz=pts,
+ # surface_normals=input_directions,
+ # supporting_c2ws=supporting_c2ws
+ # )
+
+ if pts_geometry_masks is not None:
+ final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
+ pts_rendering_mask # [N_views, N_rays, n_samples]
+ else:
+ final_mask = pts_rendering_mask
+ # import ipdb; ipdb.set_trace()
+ z_diff, pts_pred_depth_masks = None, None
+
+ if pred_depth_maps is not None:
+ pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
+ supporting_intrinsics, img_wh)
+ pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3,
+ 1).contiguous() # (N_views, N_rays*n_samples, 1)
+
+ # - pts_pred_depth_masks are more critical than final_mask:
+ # a ray containing even a few invalid pts will be treated as invalid
+ pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
+ supporting_w2cs,
+ supporting_intrinsics, img_wh)
+
+ pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :,
+ 0] # (N_views, N_rays*n_samples)
+
+ z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)
+ # import ipdb; ipdb.set_trace()
+ return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
diff --git a/SparseNeuS_demo_v1/models/rays.py b/SparseNeuS_demo_v1/models/rays.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31df93e727fd79adaaa3e934c67378b611d4ee0
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/rays.py
@@ -0,0 +1,325 @@
+import os, torch, cv2, re
+import numpy as np
+
+from PIL import Image
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+from random import random
+
+
+def build_patch_offset(h_patch_size):
+ offsets = torch.arange(-h_patch_size, h_patch_size + 1)
+ return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2
+
+
+def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None):
+ """
+ generate rays in world space for a single image
+ :param H:
+ :param W:
+ :param intrinsic: [3,3]
+ :param c2w: [4,4]
+ :return:
+ """
+ device = image.device
+ ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
+ torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
+ p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
+
+ # normalized ndc uv coordinates, (-1, 1)
+ ndc_u = 2 * xs / (W - 1) - 1
+ ndc_v = 2 * ys / (H - 1) - 1
+ rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)
+
+ intrinsic_inv = torch.inverse(intrinsic)
+
+ p = p.view(-1, 3).float().to(device) # N_rays, 3
+ p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3
+ rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3
+ rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3
+
+ image = image.permute(1, 2, 0)
+ color = image.view(-1, 3)
+ depth = depth.view(-1, 1) if depth is not None else None
+ mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device)
+ sample = {
+ 'rays_o': rays_o,
+ 'rays_v': rays_v,
+ 'rays_ndc_uv': rays_ndc_uv,
+ 'rays_color': color,
+ # 'rays_depth': depth,
+ 'rays_mask': mask,
+ 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiplying by depth
+ }
+ if depth is not None:
+ sample['rays_depth'] = depth
+
+ return sample
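+ # Note: each pixel (u, v) is back-projected as p = K^(-1) [u, v, 1]^T, normalized to a
+ # unit direction, and rotated into world space with c2w[:3, :3]; all rays share the
+ # camera center c2w[:3, 3] as their origin.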
+
+
+def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None, dilated_mask=None,
+ importance_sample=False, h_patch_size=3):
+ """
+ generate random rays in world space, for a single image
+ :param H:
+ :param W:
+ :param N_rays:
+ :param image: [3, H, W]
+ :param intrinsic: [3,3]
+ :param c2w: [4,4]
+ :param depth: [H, W]
+ :param mask: [H, W]
+ :return:
+ """
+ device = image.device
+
+ if dilated_mask is None:
+ dilated_mask = mask
+
+ if not importance_sample:
+ pixels_x = torch.randint(low=0, high=W, size=[N_rays])
+ pixels_y = torch.randint(low=0, high=H, size=[N_rays])
+ elif importance_sample and dilated_mask is not None: # sample more pts in the valid mask regions
+ pixels_x_1 = torch.randint(low=0, high=W, size=[N_rays // 4])
+ pixels_y_1 = torch.randint(low=0, high=H, size=[N_rays // 4])
+
+ ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
+ torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
+ p = torch.stack([xs, ys], dim=-1) # H, W, 2
+
+ try:
+ p_valid = p[dilated_mask > 0] # [num, 2]
+ random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[N_rays // 4 * 3])
+ except Exception:
+ print("dilated_mask.shape: ", dilated_mask.shape)
+ print("dilated_mask valid number", dilated_mask.sum())
+
+ raise ValueError("importance sampling failed: dilated_mask has no valid pixels to sample from")
+ p_select = p_valid[random_idx] # [N_rays//2, 2]
+ pixels_x_2 = p_select[:, 0]
+ pixels_y_2 = p_select[:, 1]
+
+ pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64)
+ pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64)
+
+ # - crop patch from images
+ offsets = build_patch_offset(h_patch_size).to(device)
+ grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float() # [N_pts, Npx, 2]
+ patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * (pixels_y > h_patch_size) * (
+ pixels_y < H - h_patch_size) # [N_pts]
+ grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1
+ grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1
+ grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1) # [N_pts, Npx, 2]
+ patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear',
+ padding_mode='zeros', align_corners=True)[0] # [3, N_pts, Npx]
+ patch_color = patch_color.permute(1, 2, 0).contiguous()
+
+ # normalized ndc uv coordinates, (-1, 1)
+ ndc_u = 2 * pixels_x / (W - 1) - 1
+ ndc_v = 2 * pixels_y / (H - 1) - 1
+ rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)
+
+ image = image.permute(1, 2, 0) # H ,W, C
+ color = image[(pixels_y, pixels_x)] # N_rays, 3
+
+ if mask is not None:
+ mask = mask[(pixels_y, pixels_x)] # N_rays
+ patch_mask = patch_mask * mask # N_rays
+ mask = mask.view(-1, 1)
+ else:
+ mask = torch.ones([N_rays, 1])
+
+ if depth is not None:
+ depth = depth[(pixels_y, pixels_x)] # N_rays
+ depth = depth.view(-1, 1)
+
+ intrinsic_inv = torch.inverse(intrinsic)
+
+ p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays, 3
+ p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3
+ rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3
+ rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3
+
+ sample = {
+ 'rays_o': rays_o,
+ 'rays_v': rays_v,
+ 'rays_ndc_uv': rays_ndc_uv,
+ 'rays_color': color,
+ # 'rays_depth': depth,
+ 'rays_mask': mask,
+ 'rays_norm_XYZ_cam': p, # - XYZ_cam, before multiplying by depth
+ 'rays_patch_color': patch_color,
+ 'rays_patch_mask': patch_mask.view(-1, 1)
+ }
+
+ if depth is not None:
+ sample['rays_depth'] = depth
+
+ return sample
+
+
+def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size,
+ image, intrinsic, c2w, depth=None, mask=None):
+ """
+ generate random rays in world space, for a single image
+ sample rays from local patches
+ :param H:
+ :param W:
+ :param N_rays: the number of center rays of patches
+ :param image: [3, H, W]
+ :param intrinsic: [3,3]
+ :param c2w: [4,4]
+ :param depth: [H, W]
+ :param mask: [H, W]
+ :return:
+ """
+ device = image.device
+ patch_radius_max = patch_size // 2
+
+ unit_u = 2 / (W - 1)
+ unit_v = 2 / (H - 1)
+
+ pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays])
+ pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays])
+
+ # normalized ndc uv coordinates, (-1, 1)
+ ndc_u_center = 2 * pixels_x_center / (W - 1) - 1
+ ndc_v_center = 2 * pixels_y_center / (H - 1) - 1
+ ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None,
+ :] # [N_rays, 1, 2]
+
+ shift_u, shift_v = torch.rand([N_rays, num_neighboring_pts, 1]), torch.rand(
+ [N_rays, num_neighboring_pts, 1]) # uniform distribution of [0,1)
+ shift_u = 2 * (shift_u - 0.5) # mapping to [-1, 1)
+ shift_v = 2 * (shift_v - 0.5)
+
+ # - avoid sample points which are too close to center point
+ shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u, (shift_v * patch_radius_max) * unit_v],
+ dim=-1) # [N_rays, num_npts, 2]
+ neighboring_pts_uv = ndc_uv_center + shift_uv # [N_rays, num_npts, 2]
+
+ sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1) # concat the center point
+
+ # sample the gts
+ color = F.grid_sample(image[None, :, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
+ align_corners=True)[0] # [3, N_rays, num_npts]
+ depth = F.grid_sample(depth[None, None, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
+ align_corners=True)[0] # [1, N_rays, num_npts]
+
+ mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sampled_pts_uv[None, :, :, :], mode='nearest',
+ align_corners=True).to(torch.int64)[0] # [1, N_rays, num_npts]
+
+ intrinsic_inv = torch.inverse(intrinsic)
+
+ sampled_pts_uv = sampled_pts_uv.view(N_rays * (1 + num_neighboring_pts), 2)
+ color = color.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 3)
+ depth = depth.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)
+ mask = mask.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)
+
+ pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2
+ pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2
+ p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays*num_pts, 3
+ p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays*num_pts, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays*num_pts, 3
+ rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays*num_pts, 3
+ rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays*num_pts, 3
+
+ sample = {
+ 'rays_o': rays_o,
+ 'rays_v': rays_v,
+ 'rays_ndc_uv': sampled_pts_uv,
+ 'rays_color': color,
+ 'rays_depth': depth,
+ 'rays_mask': mask,
+ # 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth
+ }
+
+ return sample
+
+
+def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None):
+ """
+
+ :param H:
+ :param W:
+ :param N_rays:
+ :param images: [B,3,H,W]
+ :param intrinsics: [B, 3, 3]
+ :param c2ws: [B, 4, 4]
+ :param depths: [B,H,W]
+ :param masks: [B,H,W]
+ :return:
+ """
+ assert len(images.shape) == 4
+
+ rays_o = []
+ rays_v = []
+ rays_color = []
+ rays_depth = []
+ rays_mask = []
+ for i in range(images.shape[0]):
+ sample = gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i],
+ depth=depths[i] if depths is not None else None,
+ mask=masks[i] if masks is not None else None)
+ rays_o.append(sample['rays_o'])
+ rays_v.append(sample['rays_v'])
+ rays_color.append(sample['rays_color'])
+ if depths is not None:
+ rays_depth.append(sample['rays_depth'])
+ if masks is not None:
+ rays_mask.append(sample['rays_mask'])
+
+ sample = {
+ 'rays_o': torch.stack(rays_o, dim=0), # [batch, N_rays, 3]
+ 'rays_v': torch.stack(rays_v, dim=0),
+ 'rays_color': torch.stack(rays_color, dim=0),
+ 'rays_depth': torch.stack(rays_depth, dim=0) if depths is not None else None,
+ 'rays_mask': torch.stack(rays_mask, dim=0) if masks is not None else None
+ }
+ return sample
+
+
+from scipy.spatial.transform import Rotation as Rot
+from scipy.spatial.transform import Slerp
+
+
+def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1):
+ device = c2w_0.device
+
+ l = resolution_level
+ tx = torch.linspace(0, W - 1, W // l)
+ ty = torch.linspace(0, H - 1, H // l)
+ pixels_x, pixels_y = torch.meshgrid(tx, ty)
+ p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device) # W, H, 3
+
+ intrinsic_inv = torch.inverse(intrinsic[:3, :3])
+ p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
+ trans = c2w_0[:3, 3] * (1.0 - ratio) + c2w_1[:3, 3] * ratio
+
+ pose_0 = c2w_0.detach().cpu().numpy()
+ pose_1 = c2w_1.detach().cpu().numpy()
+ pose_0 = np.linalg.inv(pose_0)
+ pose_1 = np.linalg.inv(pose_1)
+ rot_0 = pose_0[:3, :3]
+ rot_1 = pose_1[:3, :3]
+ rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
+ key_times = [0, 1]
+ key_rots = [rot_0, rot_1]
+ slerp = Slerp(key_times, rots)
+ rot = slerp(ratio)
+ pose = np.diag([1.0, 1.0, 1.0, 1.0])
+ pose = pose.astype(np.float32)
+ pose[:3, :3] = rot.as_matrix()
+ pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
+ pose = np.linalg.inv(pose)
+
+ c2w = torch.from_numpy(pose).to(device)
+ rot = torch.from_numpy(pose[:3, :3]).to(device)
+ trans = torch.from_numpy(pose[:3, 3]).to(device)
+ rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
+ rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
+ return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3)
diff --git a/SparseNeuS_demo_v1/models/render_utils.py b/SparseNeuS_demo_v1/models/render_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d3d8fc4ca7bf5e306733a213dec96a517a71c7
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/render_utils.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+
+from ops.back_project import cam2pixel
+import pdb
+
+
+def sample_pdf(bins, weights, n_samples, det=False):
+ '''
+ :param bins: tensor of shape [N_rays, M+1], M is the number of bins
+ :param weights: tensor of shape [N_rays, M]
+ :param n_samples: number of samples along each ray
+ :param det: if True, will perform deterministic sampling
+ :return: [N_rays, N_samples]
+ '''
+ device = weights.device
+
+ weights = weights + 1e-5 # prevent nans
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
+ cdf = torch.cumsum(pdf, -1)
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1)
+
+ # if bins.shape[1] != weights.shape[1]: # - minor modification, add this constraint
+ # cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1)
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(device)
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
+ else:
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)
+
+ # Invert CDF
+ u = u.contiguous()
+ # inds = searchsorted(cdf, u, side='right')
+ inds = torch.searchsorted(cdf, u, right=True)
+
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
+ inds_g = torch.stack([below, above], -1) # (batch, n_samples, 2)
+
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u - cdf_g[..., 0]) / denom
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+ # pdb.set_trace()
+ return samples
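+
+# Usage sketch (hypothetical tensors): classic inverse-CDF (hierarchical) sampling of
+# extra depths from per-bin weights.
+# z_vals = torch.linspace(0., 1., 65).repeat(512, 1) # 64 bins per ray
+# weights = torch.rand(512, 64)
+# new_z = sample_pdf(z_vals, weights, n_samples=16, det=True) # (512, 16)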
+
+
+def sample_ptsFeatures_from_featureVolume(pts, featureVolume, vol_dims=None, partial_vol_origin=None, vol_size=None):
+ """
+ sample feature of pts_wrd from featureVolume, all in world space
+ :param pts: [N_rays, n_samples, 3]
+ :param featureVolume: [C,wX,wY,wZ]
+ :param vol_dims: [3] "3" for dimX, dimY, dimZ
+ :param partial_vol_origin: [3]
+ :return: pts_feature: [N_rays, n_samples, C]
+ :return: valid_mask: [N_rays, n_samples]
+ """
+
+ N_rays, n_samples, _ = pts.shape
+
+ if vol_dims is None:
+ pts_normalized = pts
+ else:
+ # normalized to (-1, 1)
+ pts_normalized = 2 * (pts - partial_vol_origin[None, None, :]) / (vol_size * (vol_dims[None, None, :] - 1)) - 1
+
+ valid_mask = (torch.abs(pts_normalized[:, :, 0]) < 1.0) & (
+ torch.abs(pts_normalized[:, :, 1]) < 1.0) & (
+ torch.abs(pts_normalized[:, :, 2]) < 1.0) # (N_rays, n_samples)
+
+ pts_normalized = torch.flip(pts_normalized, dims=[-1]) # ! reverse the xyz for grid_sample
+
+ # ! checked grid_sample, (x,y,z) is for (D,H,W), reverse for (W,H,D)
+ pts_feature = F.grid_sample(featureVolume[None, :, :, :, :], pts_normalized[None, None, :, :, :],
+ padding_mode='zeros',
+ align_corners=True).view(-1, N_rays, n_samples) # [C, N_rays, n_samples]
+
+ pts_feature = pts_feature.permute(1, 2, 0) # [N_rays, n_samples, C]
+ return pts_feature, valid_mask
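+
+# Usage sketch (hypothetical tensors): with a C=16 feature volume on a 96^3 grid that
+# covers [-1, 1]^3, vol_size is the voxel edge length and partial_vol_origin is the
+# world coordinate of voxel (0, 0, 0); world points are mapped into grid_sample's
+# (-1, 1) range and trilinearly interpolated.
+# feats, valid = sample_ptsFeatures_from_featureVolume(
+# pts, # (N_rays, n_samples, 3)
+# volume, # (16, 96, 96, 96)
+# vol_dims=torch.tensor([96, 96, 96]),
+# partial_vol_origin=torch.tensor([-1., -1., -1.]),
+# vol_size=2.0 / 95)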
+
+
+def sample_ptsFeatures_from_featureMaps(pts, featureMaps, w2cs, intrinsics, WH, proj_matrix=None, return_mask=False):
+ """
+ sample features of pts from 2d feature maps
+ :param pts: [N_rays, N_samples, 3]
+ :param featureMaps: [N_views, C, H, W]
+ :param w2cs: [N_views, 4, 4]
+ :param intrinsics: [N_views, 3, 3]
+ :param proj_matrix: [N_views, 4, 4]
+ :param WH: [W, H]
+ :return:
+ """
+ # normalized to (-1, 1)
+ N_rays, n_samples, _ = pts.shape
+ N_views = featureMaps.shape[0]
+
+ if proj_matrix is None:
+ proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])
+
+ pts = pts.permute(2, 0, 1).contiguous().view(1, 3, N_rays, n_samples).repeat(N_views, 1, 1, 1)
+ pixel_grids = cam2pixel(pts, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:],
+ 'zeros', sizeH=WH[1], sizeW=WH[0]) # (nviews, N_rays, n_samples, 2)
+
+ valid_mask = (torch.abs(pixel_grids[:, :, :, 0]) < 1.0) & (
+ torch.abs(pixel_grids[:, :, :, 1]) < 1.00) # (nviews, N_rays, n_samples)
+
+ pts_feature = F.grid_sample(featureMaps, pixel_grids,
+ padding_mode='zeros',
+ align_corners=True) # [N_views, C, N_rays, n_samples]
+
+ if return_mask:
+ return pts_feature, valid_mask
+ else:
+ return pts_feature
diff --git a/SparseNeuS_demo_v1/models/rendering_network.py b/SparseNeuS_demo_v1/models/rendering_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfc984223450a609024a65956439ff741a6b133d
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/rendering_network.py
@@ -0,0 +1,129 @@
+# the codes are partly borrowed from IBRNet
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+
+
+# default tensorflow initialization of linear layers
+def weights_init(m):
+ if isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight.data)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias.data)
+
+
+@torch.jit.script
+def fused_mean_variance(x, weight):
+ mean = torch.sum(x * weight, dim=2, keepdim=True)
+ var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True)
+ return mean, var
+
+
+class GeneralRenderingNetwork(nn.Module):
+ """
+ This model is not sensitive to finetuning
+ """
+
+ def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True):
+ super(GeneralRenderingNetwork, self).__init__()
+
+ self.in_geometry_feat_ch = in_geometry_feat_ch
+ self.in_rendering_feat_ch = in_rendering_feat_ch
+ self.anti_alias_pooling = anti_alias_pooling
+
+ if self.anti_alias_pooling:
+ self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
+ activation_func = nn.ELU(inplace=True)
+
+ self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
+ activation_func,
+ nn.Linear(16, in_rendering_feat_ch + 3),
+ activation_func)
+
+ self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64),
+ activation_func,
+ nn.Linear(64, 32),
+ activation_func)
+
+ self.vis_fc = nn.Sequential(nn.Linear(32, 32),
+ activation_func,
+ nn.Linear(32, 33),
+ activation_func,
+ )
+
+ self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
+ activation_func,
+ nn.Linear(32, 1),
+ nn.Sigmoid()
+ )
+
+ self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
+ activation_func,
+ nn.Linear(16, 8),
+ activation_func,
+ nn.Linear(8, 1))
+
+ self.base_fc.apply(weights_init)
+ self.vis_fc2.apply(weights_init)
+ self.vis_fc.apply(weights_init)
+ self.rgb_fc.apply(weights_init)
+
+ def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
+ '''
+ :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat]
+ :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat]
+ :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions,
+ last channel is inner product
+ :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples]
+ :return: rgb and density output, [n_rays, n_samples, 4]
+ '''
+
+ rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
+ ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
+ mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
+ num_views = rgb_feat.shape[2]
+ geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1)
+
+ direction_feat = self.ray_dir_fc(ray_diff)
+ rgb_in = rgb_feat[..., :3]
+ rgb_feat = rgb_feat + direction_feat
+
+ if self.anti_alias_pooling:
+ _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
+ exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
+ weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask
+ weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
+ else:
+ weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)
+
+ # compute mean and variance across different views for each point
+ mean, var = fused_mean_variance(rgb_feat, weight) # [n_rays, n_samples, 1, n_feat]
+ globalfeat = torch.cat([mean, var], dim=-1) # [n_rays, n_samples, 1, 2*n_feat]
+
+ x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
+ dim=-1) # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat]
+ x = self.base_fc(x)
+
+ x_vis = self.vis_fc(x * weight)
+ x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
+ vis = torch.sigmoid(vis) * mask
+ x = x + x_res
+ vis = self.vis_fc2(x * vis) * mask
+
+ # rgb computation
+ x = torch.cat([x, vis, ray_diff], dim=-1)
+ x = self.rgb_fc(x)
+ x = x.masked_fill(mask == 0, -1e9)
+ blending_weights_valid = F.softmax(x, dim=2) # color blending
+ rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)
+
+ mask = mask.detach().to(rgb_out.dtype) # [n_rays, n_samples, n_views, 1]
+ mask = torch.sum(mask, dim=2, keepdim=False)
+ mask = mask >= 2 # at least 2 views see the point
+ mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False)
+ valid_mask = mask > 8 # valid rays, more than 8 valid samples
+ return rgb_out, valid_mask # (N_rays, n_samples, 3), (N_rays, 1)
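+
+
+# Usage sketch (shapes follow the docstring of forward; tensors are hypothetical):
+# net = GeneralRenderingNetwork(in_geometry_feat_ch=16, in_rendering_feat_ch=56)
+# rgb, valid = net(
+# geometry_feat, # (n_rays, n_samples, 16)
+# rgb_feat, # (n_views, n_rays, n_samples, 3 + 56)
+# ray_diff, # (n_views, n_rays, n_samples, 4)
+# mask) # (n_views, n_rays, n_samples)
+# # rgb: (n_rays, n_samples, 3); valid: (n_rays, 1) ray-validity mask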
diff --git a/SparseNeuS_demo_v1/models/sparse_neus_renderer.py b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8015669f349f5b61ca1cb234ec2fcdf71cd10407
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py
@@ -0,0 +1,990 @@
+"""
+The codes are heavily borrowed from NeuS
+"""
+
+import os
+import cv2 as cv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+from models.render_utils import sample_pdf
+
+from models.projector import Projector
+from tsparse.torchsparse_utils import sparse_to_dense_channel
+
+from models.fast_renderer import FastRenderer
+
+from models.patch_projector import PatchProjector
+
+from models.rays import gen_rays_between
+
+import pdb
+
+
+class SparseNeuSRenderer(nn.Module):
+ """
+ conditional NeuS renderer;
+ optimized in normalized world space;
+ wrapped in nn.Module to support DataParallel training
+ """
+
+ def __init__(self,
+ rendering_network_outside,
+ sdf_network,
+ variance_network,
+ rendering_network,
+ n_samples,
+ n_importance,
+ n_outside,
+ perturb,
+ alpha_type='div',
+ conf=None
+ ):
+ super(SparseNeuSRenderer, self).__init__()
+
+ self.conf = conf
+ self.base_exp_dir = conf['general.base_exp_dir']
+
+ # network setups
+ self.rendering_network_outside = rendering_network_outside
+ self.sdf_network = sdf_network
+ self.variance_network = variance_network
+ self.rendering_network = rendering_network
+
+ self.n_samples = n_samples
+ self.n_importance = n_importance
+ self.n_outside = n_outside
+ self.perturb = perturb
+ self.alpha_type = alpha_type
+
+ self.rendering_projector = Projector() # used to obtain features for generalized rendering
+
+ self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3)
+ self.patch_projector = PatchProjector(self.h_patch_size)
+
+ self.ray_tracer = FastRenderer() # ray_tracer to extract depth maps from sdf_volume
+
+ # - fitted rendering or general rendering
+ try:
+ self.if_fitted_rendering = self.sdf_network.if_fitted_rendering
+ except AttributeError:
+ self.if_fitted_rendering = False
+
+ def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance,
+ conditional_valid_mask_volume=None):
+ device = rays_o.device
+ batch_size, n_samples = z_vals.shape
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
+
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(batch_size, n_samples)
+ pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] # [batch_size, n_samples-1]
+ else:
+ pts_mask = torch.ones([batch_size, n_samples]).to(pts.device)
+
+ sdf = sdf.reshape(batch_size, n_samples)
+ prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
+ prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
+ mid_sdf = (prev_sdf + next_sdf) * 0.5
+ dot_val = None
+ if self.alpha_type == 'uniform':
+ dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0
+ else:
+ dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
+ prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1)
+ dot_val = torch.stack([prev_dot_val, dot_val], dim=-1)
+ dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False)
+ dot_val = dot_val.clip(-10.0, 0.0) * pts_mask
+ dist = (next_z_vals - prev_z_vals)
+ prev_esti_sdf = mid_sdf - dot_val * dist * 0.5
+ next_esti_sdf = mid_sdf + dot_val * dist * 0.5
+ prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance)
+ next_cdf = torch.sigmoid(next_esti_sdf * inv_variance)
+ alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
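+ # NeuS-style section opacity: alpha = (Phi(sdf_prev) - Phi(sdf_next)) / Phi(sdf_prev), where Phi is the sigmoid CDF at the given inverse variance and the section-end sdf values are extrapolated from mid_sdf using the clipped slope dot_val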
+
+ alpha = alpha_sdf
+
+ # - apply pts_mask
+ alpha = pts_mask * alpha
+
+ weights = alpha * torch.cumprod(
+ torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+
+ z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
+ return z_samples
+
+ def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
+ sdf_network, gru_fusion,
+ # * related to conditional feature
+ conditional_volume=None,
+ conditional_valid_mask_volume=None
+ ):
+ device = rays_o.device
+ batch_size, n_samples = z_vals.shape
+ _, n_importance = new_z_vals.shape
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
+
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(batch_size, n_importance)
+ pts_mask_bool = (pts_mask > 0).view(-1)
+ else:
+ pts_mask = torch.ones([batch_size, n_importance]).to(pts.device)
+
+ new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100
+
+ if torch.sum(pts_mask) > 1:
+ new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod)
+ new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance)
+
+ new_sdf = new_sdf.view(batch_size, n_importance)
+
+ z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
+ sdf = torch.cat([sdf, new_sdf], dim=-1)
+
+ z_vals, index = torch.sort(z_vals, dim=-1)
+ xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
+ index = index.reshape(-1)
+ sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
+
+ return z_vals, sdf
+
+ @torch.no_grad()
+ def get_pts_mask_for_conditional_volume(self, pts, mask_volume):
+ """
+
+ :param pts: [N, 3]
+ :param mask_volume: [1, 1, X, Y, Z]
+ :return:
+ """
+ num_pts = pts.shape[0]
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
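+ # F.grid_sample expects the sampling coordinates ordered (z, y, x) with respect to the [X, Y, Z] volume dims, hence the flip of the last dimension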
+
+ pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts]
+ pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1]
+
+ return pts_mask
+
+ def render_core(self,
+ rays_o,
+ rays_d,
+ z_vals,
+ sample_dist,
+ lod,
+ sdf_network,
+ rendering_network,
+ background_alpha=None, # - no use here
+ background_sampled_color=None, # - no use here
+ background_rgb=None, # - no use here
+ alpha_inter_ratio=0.0,
+ # * related to conditional feature
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_c2w=None, # - used for testing
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ # * used for blending mlp rendering network
+ img_index=None,
+ rays_uv=None,
+ # * used for clear bg and fg
+ bg_num=0
+ ):
+ device = rays_o.device
+ N_rays = rays_o.shape[0]
+ _, n_samples = z_vals.shape
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)
+
+ mid_z_vals = z_vals + dists * 0.5
+ mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1]
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
+ dirs = rays_d[:, None, :].expand(pts.shape)
+
+ pts = pts.reshape(-1, 3)
+ dirs = dirs.reshape(-1, 3)
+
+ # * if conditional_volume is restored from sparse volume, need mask for pts
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach()
+ pts_mask_bool = (pts_mask > 0).view(-1)
+
+ if torch.sum(pts_mask_bool.float()) < 1: # ! when rendering a full image, a batch of rays may contain no valid samples; force a few on to avoid empty indexing
+ pts_mask_bool[:100] = True
+
+ else:
+ pts_mask = torch.ones([N_rays, n_samples]).to(pts.device)
+ # import ipdb; ipdb.set_trace()
+ # pts_valid = pts[pts_mask_bool]
+ sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod)
+
+ sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100
+ sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1]
+ feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod]
+ feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device)
+ feature_vector[pts_mask_bool] = feature_vector_valid
+
+ # * estimate alpha from sdf
+ gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
+ # import ipdb; ipdb.set_trace()
+ gradients[pts_mask_bool] = sdf_network.gradient(
+ pts[pts_mask_bool], conditional_volume, lod=lod).squeeze()
+
+ sampled_color_mlp = None
+ rendering_valid_mask_mlp = None
+ sampled_color_patch = None
+ rendering_patch_mask = None
+
+ if self.if_fitted_rendering: # used for fine-tuning
+ position_latent = sdf_nn_output['sampled_latent_scale%d' % lod]
+ sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
+ sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device)
+
+ # - extract pixel
+ pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp(
+ pts[pts_mask_bool][:, None, :], color_maps, intrinsics,
+ w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views]
+ pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3]
+ pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views]
+
+ # - extract patch
+ if_patch_blending = False if rays_uv is None else True
+ pts_patch_color, pts_patch_mask = None, None
+ if if_patch_blending:
+ pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp(
+ pts.reshape([N_rays, n_samples, 3]),
+ rays_uv, gradients.reshape([N_rays, n_samples, 3]),
+ color_maps,
+ intrinsics[0], intrinsics,
+ query_c2w[0], torch.inverse(w2cs), img_wh=None
+ ) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx)
+ N_src, Npx = pts_patch_mask.shape[2:]
+ pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool]
+ pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool]
+
+ sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device)
+ sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device)
+
+ sampled_color_mlp_, sampled_color_mlp_mask_, \
+ sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend(
+ pts[pts_mask_bool],
+ position_latent,
+ gradients[pts_mask_bool],
+ dirs[pts_mask_bool],
+ feature_vector[pts_mask_bool],
+ img_index=img_index,
+ pts_pixel_color=pts_pixel_color,
+ pts_pixel_mask=pts_pixel_mask,
+ pts_patch_color=pts_patch_color,
+ pts_patch_mask=pts_patch_mask
+
+ ) # [n, 3], [n, 1]
+ sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_
+ sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float()
+ sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3)
+ sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples)
+ rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5
+
+ # patch blending
+ if if_patch_blending:
+ sampled_color_patch[pts_mask_bool] = sampled_color_patch_
+ sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float()
+ sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3)
+ sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples)
+ rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1,
+ keepdim=True) > 0.5 # [N_rays, 1]
+ else:
+ sampled_color_patch, rendering_patch_mask = None, None
+
+ if if_general_rendering: # used for general training
+ # [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4]
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute(
+ pts.view(N_rays, n_samples, 3),
+ # * 3d geometry feature volumes
+ geometryVolume=conditional_volume[0],
+ geometryVolumeMask=conditional_valid_mask_volume[0],
+ # * 2d rendering feature maps
+ rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256]
+ color_maps=color_maps,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=img_wh,
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=query_c2w,
+ )
+
+ # (N_rays, n_samples, 3)
+ if if_render_with_grad:
+ # import ipdb; ipdb.set_trace()
+ # [nrays, 3] [nrays, 1]
+ sampled_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+ # import ipdb; ipdb.set_trace()
+ else:
+ with torch.no_grad():
+ sampled_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+ else:
+ sampled_color, rendering_valid_mask = None, None
+
+ inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6)
+
+ true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * cos between the ray direction and the sdf gradient (surface normal)
+
+ iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu(
+ -true_dot_val) * alpha_inter_ratio) # always non-positive
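+ # cos annealing (NeuS): with alpha_inter_ratio near 0, a relaxed half-cosine keeps all samples contributing; as the ratio anneals to 1, the true front-facing cosine relu(-true_dot_val) takes over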
+
+ iter_cos = iter_cos * pts_mask.view(-1, 1)
+
+ true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
+ true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
+
+ prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance)
+ next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance)
+
+ p = prev_cdf - next_cdf
+ c = prev_cdf
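+ # discrete opacity per section (NeuS): alpha = (prev_cdf - next_cdf) / prev_cdf, clipped to [0, 1] below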
+
+ if self.alpha_type == 'div':
+ alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0)
+ elif self.alpha_type == 'uniform':
+ uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5
+ uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5
+ uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance)
+ uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance)
+ uniform_alpha = F.relu(
+ (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape(
+ N_rays, n_samples).clip(0.0, 1.0)
+ alpha_sdf = uniform_alpha
+ else:
+ assert False
+
+ alpha = alpha_sdf
+
+ # - apply pts_mask
+ alpha = alpha * pts_mask
+
+ # pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples)
+ # inside_sphere = (pts_radius < 1.0).float().detach()
+ # relax_inside_sphere = (pts_radius < 1.2).float().detach()
+ inside_sphere = pts_mask
+ relax_inside_sphere = pts_mask
+
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:,
+ :-1] # n_rays, n_samples
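+ # front-to-back compositing: weight_i = alpha_i * prod_{j<i} (1 - alpha_j)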
+ weights_sum = weights.sum(dim=-1, keepdim=True)
+ alpha_sum = alpha.sum(dim=-1, keepdim=True)
+
+ if bg_num > 0:
+ weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True)
+ else:
+ weights_sum_fg = weights_sum
+
+ if sampled_color is not None:
+ color = (sampled_color * weights[:, :, None]).sum(dim=1)
+ else:
+ color = None
+ # import ipdb; ipdb.set_trace()
+
+ if background_rgb is not None and color is not None:
+ color = color + background_rgb * (1.0 - weights_sum)
+ # print("color device:" + str(color.device))
+ # if color is not None:
+ # # import ipdb; ipdb.set_trace()
+ # color = color + (1.0 - weights_sum)
+
+
+ ###################* mlp color rendering #####################
+ color_mlp = None
+ # import ipdb; ipdb.set_trace()
+ if sampled_color_mlp is not None:
+ color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1)
+
+ if background_rgb is not None and color_mlp is not None:
+ color_mlp = color_mlp + background_rgb * (1.0 - weights_sum)
+
+ ############################ * patch blending ################
+ blended_color_patch = None
+ if sampled_color_patch is not None:
+ blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3]
+
+ ######################################################
+
+ gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2,
+ dim=-1) - 1.0) ** 2
+ # ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized
+ gradient_error = (pts_mask * gradient_error).sum() / (
+ (pts_mask).sum() + 1e-5)
+
+ depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
+ # print("[TEST]: weights_sum in render_core", weights_sum.mean())
+ # print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum())
+ # if weights_sum.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ return {
+ 'color': color,
+ 'color_mask': rendering_valid_mask, # (N_rays, 1)
+ 'color_mlp': color_mlp,
+ 'color_mlp_mask': rendering_valid_mask_mlp,
+ 'sdf': sdf, # (N_rays, n_samples)
+ 'depth': depth, # (N_rays, 1)
+ 'dists': dists,
+ 'gradients': gradients.reshape(N_rays, n_samples, 3),
+ 'variance': 1.0 / inv_variance,
+ 'mid_z_vals': mid_z_vals,
+ 'weights': weights,
+ 'weights_sum': weights_sum,
+ 'alpha_sum': alpha_sum,
+ 'alpha_mean': alpha.mean(),
+ 'cdf': c.reshape(N_rays, n_samples),
+ 'gradient_error': gradient_error,
+ 'inside_sphere': inside_sphere,
+ 'blended_color_patch': blended_color_patch,
+ 'blended_color_patch_mask': rendering_patch_mask,
+ 'weights_sum_fg': weights_sum_fg
+ }
+
+ def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio=0.0,
+ # * related to conditional feature
+ lod=None,
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_c2w=None, # -used for testing
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ # * used for blending mlp rendering network
+ img_index=None,
+ rays_uv=None,
+ # * importance sample for second lod network
+ pre_sample=False, # no use here
+ # * for clear foreground
+ bg_ratio=0.0
+ ):
+ device = rays_o.device
+ N_rays = len(rays_o)
+ # sample_dist = 2.0 / self.n_samples
+ sample_dist = ((far - near) / self.n_samples).mean().item()
+ z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ bg_num = int(self.n_samples * bg_ratio)
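+ # the last bg_num samples of each ray are reserved as background samples; they are split off here and appended back after importance sampling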
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ if bg_num > 0:
+ z_vals_bg = z_vals[:, self.n_samples - bg_num:]
+ z_vals = z_vals[:, :self.n_samples - bg_num]
+
+ n_samples = self.n_samples - bg_num
+ perturb = self.perturb
+
+ # - significantly speed up training, for the second lod network
+ if pre_sample:
+ z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far,
+ conditional_valid_mask_volume)
+
+ if perturb_overwrite >= 0:
+ perturb = perturb_overwrite
+ if perturb > 0:
+ # get intervals between samples
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
+ lower = torch.cat([z_vals[..., :1], mids], -1)
+ # stratified samples in those intervals
+ t_rand = torch.rand(z_vals.shape).to(device)
+ z_vals = lower + (upper - lower) * t_rand
+
+ background_alpha = None
+ background_sampled_color = None
+ z_val_before = z_vals.clone()
+ # Up sample
+ if self.n_importance > 0:
+ with torch.no_grad():
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+ sdf_outputs = sdf_network.sdf(
+ pts.reshape(-1, 3), conditional_volume, lod=lod)
+ # pdb.set_trace()
+ sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num)
+
+ n_steps = 4
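+ # hierarchical importance sampling: each round adds n_importance // n_steps samples, and the fixed inverse variance 64 * 2**i sharpens the proposal around the surface from round to round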
+ for i in range(n_steps):
+ new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps,
+ 64 * 2 ** i,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ )
+
+ # if new_z_vals.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+
+ z_vals, sdf = self.cat_z_vals(
+ rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
+ sdf_network, gru_fusion=False,
+ conditional_volume=conditional_volume,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ )
+
+ del sdf
+
+ n_samples = self.n_samples + self.n_importance
+
+ # Background
+ ret_outside = None
+
+ # Render
+ if bg_num > 0:
+ z_vals = torch.cat([z_vals, z_vals_bg], dim=1)
+ # if z_vals.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ ret_fine = self.render_core(rays_o,
+ rays_d,
+ z_vals,
+ sample_dist,
+ lod,
+ sdf_network,
+ rendering_network,
+ background_rgb=background_rgb,
+ background_alpha=background_alpha,
+ background_sampled_color=background_sampled_color,
+ alpha_inter_ratio=alpha_inter_ratio,
+ # * related to conditional feature
+ conditional_volume=conditional_volume,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ # * 2d feature maps
+ feature_maps=feature_maps,
+ color_maps=color_maps,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=img_wh,
+ query_c2w=query_c2w,
+ if_general_rendering=if_general_rendering,
+ if_render_with_grad=if_render_with_grad,
+ # * used for blending mlp rendering network
+ img_index=img_index,
+ rays_uv=rays_uv
+ )
+
+ color_fine = ret_fine['color']
+
+ if self.n_outside > 0:
+ color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask'])
+ else:
+ color_fine_mask = ret_fine['color_mask']
+
+ weights = ret_fine['weights']
+ weights_sum = ret_fine['weights_sum']
+
+ gradients = ret_fine['gradients']
+ mid_z_vals = ret_fine['mid_z_vals']
+
+ # depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
+ depth = ret_fine['depth']
+ depth_variance = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True)
+ variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True)
+
+ # - randomly sample points from the volume, and maximize the sdf
+ pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1)
+ sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod]
+
+ result = {
+ 'depth': depth,
+ 'color_fine': color_fine,
+ 'color_fine_mask': color_fine_mask,
+ 'color_outside': ret_outside['color'] if ret_outside is not None else None,
+ 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None,
+ 'color_mlp': ret_fine['color_mlp'],
+ 'color_mlp_mask': ret_fine['color_mlp_mask'],
+ 'variance': variance.mean(),
+ 'cdf_fine': ret_fine['cdf'],
+ 'depth_variance': depth_variance,
+ 'weights_sum': weights_sum,
+ 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0],
+ 'alpha_sum': ret_fine['alpha_sum'].mean(),
+ 'alpha_mean': ret_fine['alpha_mean'],
+ 'gradients': gradients,
+ 'weights': weights,
+ 'gradient_error_fine': ret_fine['gradient_error'],
+ 'inside_sphere': ret_fine['inside_sphere'],
+ 'sdf': ret_fine['sdf'],
+ 'sdf_random': sdf_random,
+ 'blended_color_patch': ret_fine['blended_color_patch'],
+ 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'],
+ 'weights_sum_fg': ret_fine['weights_sum_fg']
+ }
+
+ return result
+
+ @torch.no_grad()
+ def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume):
+ # ? importance sampling guided by the sdf volume; tends to be too biased towards the pre-estimated surface
+ device = rays_o.device
+ N_rays = len(rays_o)
+ n_samples = self.n_samples * 2
+
+ z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+ sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples])
+
+ new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples,
+ 200,
+ conditional_valid_mask_volume=mask_volume,
+ )
+ return new_z_vals
+
+ @torch.no_grad()
+ def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): # don't use
+ device = rays_o.device
+ N_rays = len(rays_o)
+ n_samples = self.n_samples * 2
+
+ z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
+
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape(
+ [N_rays, n_samples - 1])
+
+ # empty voxel set to 0.1, non-empty voxel set to 1
+ weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device),
+ 0.1 * torch.ones_like(pts_mask).to(device))
+
+ # sample more pts in non-empty voxels
+ z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach()
+ return z_samples
+
+ @torch.no_grad()
+ def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums):
+ """
+ Use the predicted depth maps to remove redundant pts (sdf pruning keeps both sides of the zero level set; the back-side points are useless)
+ :param coords: [n, 3] int coords
+ :param pred_depth_maps: [N_views, 1, h, w]
+ :param proj_matrices: [N_views, 4, 4]
+ :param partial_vol_origin: [3]
+ :param voxel_size: 1
+ :param near: 1
+ :param far: 1
+ :param depth_interval: 1
+ :param d_plane_nums: 1
+ :return:
+ """
+ device = pred_depth_maps.device
+ n_views, _, sizeH, sizeW = pred_depth_maps.shape
+
+ if len(partial_vol_origin.shape) == 1:
+ partial_vol_origin = partial_vol_origin[None, :]
+ pts = coords * voxel_size + partial_vol_origin
+
+ rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1)
+ rs_grid = rs_grid.permute(0, 2, 1).contiguous() # [n_views, 3, n_pts]
+ nV = rs_grid.shape[-1]
+ rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) # [n_views, 4, n_pts]
+
+ # Project grid
+ im_p = proj_matrices @ rs_grid # - transform world pts to image UV space # [n_views, 4, n_pts]
+ im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
+ im_x = im_x / im_z
+ im_y = im_y / im_z
+
+ im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
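+ # normalize pixel coordinates to [-1, 1] as required by F.grid_sample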
+
+ im_grid = im_grid.view(n_views, 1, -1, 2)
+ sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True)[:, 0, 0, :] # [n_views, n_pts]
+ sampled_depths_valid = (sampled_depths > 0.5 * near).float()
+ valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(),
+ far.item()) * sampled_depths_valid
+ valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(),
+ far.item()) * sampled_depths_valid
+
+ mask = im_grid.abs() <= 1
+ mask = mask[:, 0] # [n_views, n_pts, 2]
+ mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max)
+
+ mask = mask.view(n_views, -1)
+ mask = mask.permute(1, 0).contiguous() # [num_pts, nviews]
+
+ mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0
+
+ return mask_final
+
+ @torch.no_grad()
+ def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume,
+ pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums,
+ threshold=0.02, maximum_pts=110000):
+ """
+ assume batch size == 1, from the first lod to get sparse voxels
+ :param sdf_volume: [1, X, Y, Z]
+ :param coords_volume: [3, X, Y, Z]
+ :param mask_volume: [1, X, Y, Z]
+ :param feature_volume: [C, X, Y, Z]
+ :param threshold:
+ :return:
+ """
+ device = coords_volume.device
+ _, dX, dY, dZ = coords_volume.shape
+
+ def prune(sdf_pts, coords_pts, mask_volume, threshold):
+ occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) # [num_pts]
+ valid_coords = coords_pts[occupancy_mask]
+
+ # - filter backside surface by depth maps
+ mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums)
+ valid_coords = valid_coords[mask_filtered]
+
+ # - dilate
+ occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) # [dX, dY, dZ, 1]
+
+ # - dilate
+ occupancy_mask = occupancy_mask.float()
+ occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
+ occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
+ occupancy_mask = occupancy_mask.view(-1, 1) > 0
+
+ final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts]
+
+ return final_mask, torch.sum(final_mask.float())
+
+ C, dX, dY, dZ = feature_volume.shape
+ sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
+ mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)
+
+ # - for check
+ # sdf_volume = torch.rand_like(sdf_volume).float().to(sdf_volume.device) * 0.02
+
+ final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)
+
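+ # shrink the sdf threshold until the number of surviving voxels fits the memory budget (maximum_pts)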
+ while (valid_num > maximum_pts) and (threshold > 0.003):
+ threshold = threshold - 0.002
+ final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)
+
+ valid_coords = coords_volume[final_mask] # [N, 3]
+ valid_feature = feature_volume[final_mask] # [N, C]
+
+ valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
+ valid_coords], dim=1) # [N, 4], append batch idx
+
+ # ! if the valid_num is still larger than maximum_pts, sample part of pts
+ if valid_num > maximum_pts:
+ valid_num = valid_num.long()
+ occupancy = torch.ones([valid_num]).to(device) > 0
+ choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
+ replace=False)
+ ind = torch.nonzero(occupancy).to(device)
+ occupancy[ind[choice]] = False
+ valid_coords = valid_coords[occupancy]
+ valid_feature = valid_feature[occupancy]
+
+ print(threshold, "randomly sample to save memory")
+
+ return valid_coords, valid_feature
+
+ @torch.no_grad()
+ def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02,
+ maximum_pts=110000):
+ """
+ assume batch size == 1, from the first lod to get sparse voxels
+ :param sdf_volume: [num_pts, 1]
+ :param coords_volume: [3, X, Y, Z]
+ :param mask_volume: [1, X, Y, Z]
+ :param feature_volume: [C, X, Y, Z]
+ :param threshold:
+ :return:
+ """
+
+ def prune(sdf_volume, mask_volume, threshold):
+ occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1]
+
+ # - dilate
+ occupancy_mask = occupancy_mask.float()
+ occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
+ occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
+ occupancy_mask = occupancy_mask.view(-1, 1) > 0
+
+ final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts]
+
+ return final_mask, torch.sum(final_mask.float())
+
+ C, dX, dY, dZ = feature_volume.shape
+ coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
+ mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)
+
+ final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)
+
+ while (valid_num > maximum_pts) and (threshold > 0.003):
+ threshold = threshold - 0.002
+ final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)
+
+ valid_coords = coords_volume[final_mask] # [N, 3]
+ valid_feature = feature_volume[final_mask] # [N, C]
+
+ valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
+ valid_coords], dim=1) # [N, 4], append batch idx
+
+ # ! if the valid_num is still larger than maximum_pts, sample part of pts
+ if valid_num > maximum_pts:
+ device = sdf_volume.device
+ valid_num = valid_num.long()
+ occupancy = torch.ones([valid_num]).to(device) > 0
+ choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
+ replace=False)
+ ind = torch.nonzero(occupancy).to(device)
+ occupancy[ind[choice]] = False
+ valid_coords = valid_coords[occupancy]
+ valid_feature = valid_feature[occupancy]
+
+ print(threshold, "randomly sample to save memory")
+
+ return valid_coords, valid_feature
+
+ @torch.no_grad()
+ def extract_fields(self, bound_min, bound_max, resolution, query_func, device,
+ # * related to conditional feature
+ **kwargs
+ ):
+ N = 64
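+ # evaluate the sdf in chunks of up to 64 points per axis to bound peak memory; the sign is flipped below so marching cubes extracts the level set at the given threshold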
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N)
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N)
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N)
+
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
+ with torch.no_grad():
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
+
+ # ! attention, the query function is different for extract geometry and fields
+ output = query_func(pts, **kwargs)
+ sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys),
+ len(zs)).detach().cpu().numpy()
+
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf
+ return u
+
+ @torch.no_grad()
+ def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None,
+ # * 3d feature volume
+ **kwargs
+ ):
+ # logging.info('threshold: {}'.format(threshold))
+
+ u = self.extract_fields(bound_min, bound_max, resolution,
+ lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs),
+ # - sdf need to be multiplied by -1
+ device,
+ # * 3d feature volume
+ **kwargs
+ )
+ if occupancy_mask is not None:
+ dX, dY, dZ = occupancy_mask.shape
+ empty_mask = 1 - occupancy_mask
+ empty_mask = empty_mask.view(1, 1, dX, dY, dZ)
+ # - dilation
+ # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3)
+ empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest')
+ empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0
+ u[empty_mask] = -100
+ del empty_mask
+
+ vertices, triangles = mcubes.marching_cubes(u, threshold)
+ b_max_np = bound_max.detach().cpu().numpy()
+ b_min_np = bound_min.detach().cpu().numpy()
+
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+ return vertices, triangles, u
+
+ @torch.no_grad()
+ def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far):
+ """
+ extract depth maps from the volume by sphere tracing
+ :param con_volume: [1, 1+C, dX, dY, dZ]; can be a conditional feature volume or an sdf volume
+ :param c2ws: [B, 4, 4]
+ :param H:
+ :param W:
+ :param near:
+ :param far:
+ :return:
+ """
+ device = con_volume.device
+ batch_size = intrinsics.shape[0]
+
+ with torch.no_grad():
+ ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
+ torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
+ p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
+
+ intrinsics_inv = torch.inverse(intrinsics)
+
+ p = p.view(-1, 3).float().to(device) # N_rays, 3
+ p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3
+ rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3
+ rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3
+ rays_d = rays_v
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ ################## - sphere tracer to extract depth maps ######################
+ depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps(
+ rays_o, rays_d,
+ near[None, :].repeat(rays_o.shape[0], 1),
+ far[None, :].repeat(rays_o.shape[0], 1),
+ sdf_network, con_volume
+ )
+
+ depth_maps = depth_maps_sphere.view(batch_size, 1, H, W)
+ depth_masks = depth_masks_sphere.view(batch_size, 1, H, W)
+
+ depth_maps = torch.where(depth_masks, depth_maps,
+ torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0
+
+ return depth_maps, depth_masks
diff --git a/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py b/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e22aa312312b4fc7e8225e15f1eea5a2de71d1
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/sparse_neus_renderer_normals_new.py
@@ -0,0 +1,992 @@
+"""
+The code is heavily borrowed from NeuS.
+"""
+
+import os
+import cv2 as cv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+from models.render_utils import sample_pdf
+
+from models.projector import Projector
+from tsparse.torchsparse_utils import sparse_to_dense_channel
+
+from models.fast_renderer import FastRenderer
+
+from models.patch_projector import PatchProjector
+
+from models.rays import gen_rays_between
+
+import pdb
+
+
+class SparseNeuSRenderer(nn.Module):
+ """
+ conditional NeuS renderer;
+ optimizes in normalized world space;
+ wrapped in nn.Module to support DataParallel training
+ """
+
+ def __init__(self,
+ rendering_network_outside,
+ sdf_network,
+ variance_network,
+ rendering_network,
+ n_samples,
+ n_importance,
+ n_outside,
+ perturb,
+ alpha_type='div',
+ conf=None
+ ):
+ super(SparseNeuSRenderer, self).__init__()
+
+ self.conf = conf
+ self.base_exp_dir = conf['general.base_exp_dir']
+
+ # network setups
+ self.rendering_network_outside = rendering_network_outside
+ self.sdf_network = sdf_network
+ self.variance_network = variance_network
+ self.rendering_network = rendering_network
+
+ self.n_samples = n_samples
+ self.n_importance = n_importance
+ self.n_outside = n_outside
+ self.perturb = perturb
+ self.alpha_type = alpha_type
+
+ self.rendering_projector = Projector() # used to obtain features for generalized rendering
+
+ self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3)
+ self.patch_projector = PatchProjector(self.h_patch_size)
+
+ self.ray_tracer = FastRenderer() # ray_tracer to extract depth maps from sdf_volume
+
+ # - fitted rendering or general rendering
+ try:
+ self.if_fitted_rendering = self.sdf_network.if_fitted_rendering
+ except AttributeError:
+ self.if_fitted_rendering = False
+
+ def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance,
+ conditional_valid_mask_volume=None):
+ device = rays_o.device
+ batch_size, n_samples = z_vals.shape
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
+
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(batch_size, n_samples)
+ pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] # [batch_size, n_samples-1]
+ else:
+ pts_mask = torch.ones([batch_size, n_samples]).to(pts.device)
+
+ sdf = sdf.reshape(batch_size, n_samples)
+ prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
+ prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
+ mid_sdf = (prev_sdf + next_sdf) * 0.5
+ dot_val = None
+ if self.alpha_type == 'uniform':
+ dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0
+ else:
+ dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
+ prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1)
+ dot_val = torch.stack([prev_dot_val, dot_val], dim=-1)
+ dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False)
+ dot_val = dot_val.clip(-10.0, 0.0) * pts_mask
+ dist = (next_z_vals - prev_z_vals)
+ prev_esti_sdf = mid_sdf - dot_val * dist * 0.5
+ next_esti_sdf = mid_sdf + dot_val * dist * 0.5
+ prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance)
+ next_cdf = torch.sigmoid(next_esti_sdf * inv_variance)
+ alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
+
+ alpha = alpha_sdf
+
+ # - apply pts_mask
+ alpha = pts_mask * alpha
+
+ weights = alpha * torch.cumprod(
+ torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+
+ z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
+ return z_samples
+
+ def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
+ sdf_network, gru_fusion,
+ # * related to conditional feature
+ conditional_volume=None,
+ conditional_valid_mask_volume=None
+ ):
+ device = rays_o.device
+ batch_size, n_samples = z_vals.shape
+ _, n_importance = new_z_vals.shape
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
+
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(batch_size, n_importance)
+ pts_mask_bool = (pts_mask > 0).view(-1)
+ else:
+ pts_mask = torch.ones([batch_size, n_importance]).to(pts.device)
+
+ new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100
+
+ if torch.sum(pts_mask) > 1:
+ new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod)
+ new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance)
+
+ new_sdf = new_sdf.view(batch_size, n_importance)
+
+ z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
+ sdf = torch.cat([sdf, new_sdf], dim=-1)
+
+ z_vals, index = torch.sort(z_vals, dim=-1)
+ xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
+ index = index.reshape(-1)
+ sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
+
+ return z_vals, sdf
+
+ @torch.no_grad()
+ def get_pts_mask_for_conditional_volume(self, pts, mask_volume):
+ """
+
+ :param pts: [N, 3]
+ :param mask_volume: [1, 1, X, Y, Z]
+ :return:
+ """
+ num_pts = pts.shape[0]
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
+
+ pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts]
+ pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1]
+
+ return pts_mask
+
+ def render_core(self,
+ rays_o,
+ rays_d,
+ z_vals,
+ sample_dist,
+ lod,
+ sdf_network,
+ rendering_network,
+ background_alpha=None, # - no use here
+ background_sampled_color=None, # - no use here
+ background_rgb=None, # - no use here
+ alpha_inter_ratio=0.0,
+ # * related to conditional feature
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_c2w=None, # - used for testing
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ # * used for blending mlp rendering network
+ img_index=None,
+ rays_uv=None,
+ # * used for clear bg and fg
+ bg_num=0
+ ):
+ device = rays_o.device
+ N_rays = rays_o.shape[0]
+ _, n_samples = z_vals.shape
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)
+
+ mid_z_vals = z_vals + dists * 0.5
+ mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1]
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
+ dirs = rays_d[:, None, :].expand(pts.shape)
+
+ pts = pts.reshape(-1, 3)
+ dirs = dirs.reshape(-1, 3)
+
+ # * if conditional_volume is restored from sparse volume, need mask for pts
+ if conditional_valid_mask_volume is not None:
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume)
+ pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach()
+ pts_mask_bool = (pts_mask > 0).view(-1)
+
+ if torch.sum(pts_mask_bool.float()) < 1: # ! when rendering a full image, a batch of rays may contain no valid samples; force a few on to avoid empty indexing
+ pts_mask_bool[:100] = True
+
+ else:
+ pts_mask = torch.ones([N_rays, n_samples]).to(pts.device)
+ # import ipdb; ipdb.set_trace()
+ # pts_valid = pts[pts_mask_bool]
+ sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod)
+
+ sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100
+ sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1]
+ feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod]
+ feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device)
+ feature_vector[pts_mask_bool] = feature_vector_valid
+
+ # * estimate alpha from sdf
+ gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
+ # import ipdb; ipdb.set_trace()
+ gradients[pts_mask_bool] = sdf_network.gradient(
+ pts[pts_mask_bool], conditional_volume, lod=lod).squeeze()
+
+ sampled_color_mlp = None
+ rendering_valid_mask_mlp = None
+ sampled_color_patch = None
+ rendering_patch_mask = None
+
+ if self.if_fitted_rendering: # used for fine-tuning
+ position_latent = sdf_nn_output['sampled_latent_scale%d' % lod]
+ sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
+ sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device)
+
+ # - extract pixel
+ pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp(
+ pts[pts_mask_bool][:, None, :], color_maps, intrinsics,
+ w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views]
+ pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3]
+ pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views]
+
+ # - extract patch
+ if_patch_blending = False if rays_uv is None else True
+ pts_patch_color, pts_patch_mask = None, None
+ if if_patch_blending:
+ pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp(
+ pts.reshape([N_rays, n_samples, 3]),
+ rays_uv, gradients.reshape([N_rays, n_samples, 3]),
+ color_maps,
+ intrinsics[0], intrinsics,
+ query_c2w[0], torch.inverse(w2cs), img_wh=None
+ ) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx)
+ N_src, Npx = pts_patch_mask.shape[2:]
+ pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool]
+ pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool]
+
+ sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device)
+ sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device)
+
+ sampled_color_mlp_, sampled_color_mlp_mask_, \
+ sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend(
+ pts[pts_mask_bool],
+ position_latent,
+ gradients[pts_mask_bool],
+ dirs[pts_mask_bool],
+ feature_vector[pts_mask_bool],
+ img_index=img_index,
+ pts_pixel_color=pts_pixel_color,
+ pts_pixel_mask=pts_pixel_mask,
+ pts_patch_color=pts_patch_color,
+ pts_patch_mask=pts_patch_mask
+
+ ) # [n, 3], [n, 1]
+ sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_
+ sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float()
+ sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3)
+ sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples)
+ rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5
+
+ # patch blending
+ if if_patch_blending:
+ sampled_color_patch[pts_mask_bool] = sampled_color_patch_
+ sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float()
+ sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3)
+ sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples)
+ rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1,
+ keepdim=True) > 0.5 # [N_rays, 1]
+ else:
+ sampled_color_patch, rendering_patch_mask = None, None
+
+ if if_general_rendering: # used for general training
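+ # this variant differs from sparse_neus_renderer.py here: features are gathered via compute_view_independent, which additionally receives the sdf network and lod (presumably to build view-independent, normal-aware rendering features)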
+ # [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4]
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute_view_independent(
+ pts.view(N_rays, n_samples, 3),
+ # * 3d geometry feature volumes
+ geometryVolume=conditional_volume[0],
+ geometryVolumeMask=conditional_valid_mask_volume[0],
+ sdf_network=sdf_network,
+ lod=lod,
+ # * 2d rendering feature maps
+ rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256]
+ color_maps=color_maps,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=img_wh,
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=query_c2w,
+ )
+
+ # (N_rays, n_samples, 3)
+ if if_render_with_grad:
+ # import ipdb; ipdb.set_trace()
+ # [nrays, 3] [nrays, 1]
+ sampled_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+ # import ipdb; ipdb.set_trace()
+ else:
+ with torch.no_grad():
+ sampled_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+ else:
+ sampled_color, rendering_valid_mask = None, None
+
+ inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6)
+
+ true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * cos between the ray direction and the sdf gradient (surface normal)
+
+ iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu(
+ -true_dot_val) * alpha_inter_ratio) # always non-positive
+
+ iter_cos = iter_cos * pts_mask.view(-1, 1)
+
+ true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
+ true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
+
+ prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance)
+ next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance)
+
+ p = prev_cdf - next_cdf
+ c = prev_cdf
+
+ if self.alpha_type == 'div':
+ alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0)
+ elif self.alpha_type == 'uniform':
+ uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5
+ uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5
+ uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance)
+ uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance)
+ uniform_alpha = F.relu(
+ (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape(
+ N_rays, n_samples).clip(0.0, 1.0)
+ alpha_sdf = uniform_alpha
+ else:
+ assert False
+
+ alpha = alpha_sdf
+
+ # - apply pts_mask
+ alpha = alpha * pts_mask
+
+ # pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples)
+ # inside_sphere = (pts_radius < 1.0).float().detach()
+ # relax_inside_sphere = (pts_radius < 1.2).float().detach()
+ inside_sphere = pts_mask
+ relax_inside_sphere = pts_mask
+
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:,
+ :-1] # n_rays, n_samples
+ weights_sum = weights.sum(dim=-1, keepdim=True)
+ alpha_sum = alpha.sum(dim=-1, keepdim=True)
+
+ if bg_num > 0:
+ weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True)
+ else:
+ weights_sum_fg = weights_sum
+
+ if sampled_color is not None:
+ color = (sampled_color * weights[:, :, None]).sum(dim=1)
+ else:
+ color = None
+ # import ipdb; ipdb.set_trace()
+
+ if background_rgb is not None and color is not None:
+ color = color + background_rgb * (1.0 - weights_sum)
+ # print("color device:" + str(color.device))
+ # if color is not None:
+ # # import ipdb; ipdb.set_trace()
+ # color = color + (1.0 - weights_sum)
+
+
+ ###################* mlp color rendering #####################
+ color_mlp = None
+ # import ipdb; ipdb.set_trace()
+ if sampled_color_mlp is not None:
+ color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1)
+
+ if background_rgb is not None and color_mlp is not None:
+ color_mlp = color_mlp + background_rgb * (1.0 - weights_sum)
+
+ ############################ * patch blending ################
+ blended_color_patch = None
+ if sampled_color_patch is not None:
+ blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3]
+
+ ######################################################
+
+ gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2,
+ dim=-1) - 1.0) ** 2
+ # ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized
+ gradient_error = (pts_mask * gradient_error).sum() / (
+ (pts_mask).sum() + 1e-5)
+
+ depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
+ # print("[TEST]: weights_sum in render_core", weights_sum.mean())
+ # print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum())
+ # if weights_sum.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ return {
+ 'color': color,
+ 'color_mask': rendering_valid_mask, # (N_rays, 1)
+ 'color_mlp': color_mlp,
+ 'color_mlp_mask': rendering_valid_mask_mlp,
+ 'sdf': sdf, # (N_rays, n_samples)
+ 'depth': depth, # (N_rays, 1)
+ 'dists': dists,
+ 'gradients': gradients.reshape(N_rays, n_samples, 3),
+ 'variance': 1.0 / inv_variance,
+ 'mid_z_vals': mid_z_vals,
+ 'weights': weights,
+ 'weights_sum': weights_sum,
+ 'alpha_sum': alpha_sum,
+ 'alpha_mean': alpha.mean(),
+ 'cdf': c.reshape(N_rays, n_samples),
+ 'gradient_error': gradient_error,
+ 'inside_sphere': inside_sphere,
+ 'blended_color_patch': blended_color_patch,
+ 'blended_color_patch_mask': rendering_patch_mask,
+ 'weights_sum_fg': weights_sum_fg
+ }
+
+ def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio=0.0,
+ # * related to conditional feature
+ lod=None,
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=None,
+ w2cs=None,
+ intrinsics=None,
+ img_wh=None,
+ query_c2w=None, # -used for testing
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ # * used for blending mlp rendering network
+ img_index=None,
+ rays_uv=None,
+ # * importance sample for second lod network
+ pre_sample=False, # no use here
+ # * for clear foreground
+ bg_ratio=0.0
+ ):
+ device = rays_o.device
+ N_rays = len(rays_o)
+ # sample_dist = 2.0 / self.n_samples
+ sample_dist = ((far - near) / self.n_samples).mean().item()
+ z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ bg_num = int(self.n_samples * bg_ratio)
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ if bg_num > 0:
+ z_vals_bg = z_vals[:, self.n_samples - bg_num:]
+ z_vals = z_vals[:, :self.n_samples - bg_num]
+
+ n_samples = self.n_samples - bg_num
+ perturb = self.perturb
+
+ # - significantly speed up training, for the second lod network
+ if pre_sample:
+ z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far,
+ conditional_valid_mask_volume)
+
+ if perturb_overwrite >= 0:
+ perturb = perturb_overwrite
+ if perturb > 0:
+ # get intervals between samples
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
+ lower = torch.cat([z_vals[..., :1], mids], -1)
+ # stratified samples in those intervals
+ t_rand = torch.rand(z_vals.shape).to(device)
+ z_vals = lower + (upper - lower) * t_rand
+
+ background_alpha = None
+ background_sampled_color = None
+ z_val_before = z_vals.clone()
+ # Up sample
+ if self.n_importance > 0:
+ with torch.no_grad():
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+ sdf_outputs = sdf_network.sdf(
+ pts.reshape(-1, 3), conditional_volume, lod=lod)
+ # pdb.set_trace()
+ sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num)
+
+ n_steps = 4
+ for i in range(n_steps):
+ new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps,
+ 64 * 2 ** i,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ )
+
+ # if new_z_vals.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+
+ z_vals, sdf = self.cat_z_vals(
+ rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
+ sdf_network, gru_fusion=False,
+ conditional_volume=conditional_volume,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ )
+
+ del sdf
+
+ n_samples = self.n_samples + self.n_importance
+
+ # Background
+ ret_outside = None
+
+ # Render
+ if bg_num > 0:
+ z_vals = torch.cat([z_vals, z_vals_bg], dim=1)
+ # if z_vals.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ ret_fine = self.render_core(rays_o,
+ rays_d,
+ z_vals,
+ sample_dist,
+ lod,
+ sdf_network,
+ rendering_network,
+ background_rgb=background_rgb,
+ background_alpha=background_alpha,
+ background_sampled_color=background_sampled_color,
+ alpha_inter_ratio=alpha_inter_ratio,
+ # * related to conditional feature
+ conditional_volume=conditional_volume,
+ conditional_valid_mask_volume=conditional_valid_mask_volume,
+ # * 2d feature maps
+ feature_maps=feature_maps,
+ color_maps=color_maps,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=img_wh,
+ query_c2w=query_c2w,
+ if_general_rendering=if_general_rendering,
+ if_render_with_grad=if_render_with_grad,
+ # * used for blending mlp rendering network
+ img_index=img_index,
+ rays_uv=rays_uv
+ )
+
+ color_fine = ret_fine['color']
+
+ if self.n_outside > 0:
+ color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask'])
+ else:
+ color_fine_mask = ret_fine['color_mask']
+
+ weights = ret_fine['weights']
+ weights_sum = ret_fine['weights_sum']
+
+ gradients = ret_fine['gradients']
+ mid_z_vals = ret_fine['mid_z_vals']
+
+ # depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
+ depth = ret_fine['depth']
+ depth_variance = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True)
+ variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True)
+
+ # - randomly sample points from the volume, and maximize the sdf
+ pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1)
+ sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod]
+
+ result = {
+ 'depth': depth,
+ 'color_fine': color_fine,
+ 'color_fine_mask': color_fine_mask,
+ 'color_outside': ret_outside['color'] if ret_outside is not None else None,
+ 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None,
+ 'color_mlp': ret_fine['color_mlp'],
+ 'color_mlp_mask': ret_fine['color_mlp_mask'],
+ 'variance': variance.mean(),
+ 'cdf_fine': ret_fine['cdf'],
+ 'depth_variance': depth_variance,
+ 'weights_sum': weights_sum,
+ 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0],
+ 'alpha_sum': ret_fine['alpha_sum'].mean(),
+ 'alpha_mean': ret_fine['alpha_mean'],
+ 'gradients': gradients,
+ 'weights': weights,
+ 'gradient_error_fine': ret_fine['gradient_error'],
+ 'inside_sphere': ret_fine['inside_sphere'],
+ 'sdf': ret_fine['sdf'],
+ 'sdf_random': sdf_random,
+ 'blended_color_patch': ret_fine['blended_color_patch'],
+ 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'],
+ 'weights_sum_fg': ret_fine['weights_sum_fg']
+ }
+
+ return result
+
+ @torch.no_grad()
+ def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume):
+        # ? importance sampling based on the sdf volume; seems too biased towards the pre-estimated sdf
+ device = rays_o.device
+ N_rays = len(rays_o)
+ n_samples = self.n_samples * 2
+
+ z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+ sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples])
+
+ new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples,
+ 200,
+ conditional_valid_mask_volume=mask_volume,
+ )
+ return new_z_vals
+
+ @torch.no_grad()
+ def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): # don't use
+ device = rays_o.device
+ N_rays = len(rays_o)
+ n_samples = self.n_samples * 2
+
+ z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
+ z_vals = near + (far - near) * z_vals[None, :]
+
+ if z_vals.shape[0] == 1:
+ z_vals = z_vals.repeat(N_rays, 1)
+
+ mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5
+
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]
+
+ pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape(
+ [N_rays, n_samples - 1])
+
+        # empty voxels get weight 0.1, non-empty voxels get weight 1
+ weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device),
+ 0.1 * torch.ones_like(pts_mask).to(device))
+
+ # sample more pts in non-empty voxels
+ z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach()
+ return z_samples
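+        # - a minimal sketch (not part of the original code) of how the 0.1 / 1.0
+        #   weighting biases sampling, assuming sample_pdf behaves as it is used above:
+        #       bins = torch.linspace(0., 1., 9)[None, :]                    # 8 intervals
+        #       weights = torch.tensor([[.1, .1, 1., 1., 1., .1, .1, .1]])   # 3 occupied bins
+        #       z = sample_pdf(bins, weights, 16, det=True)
+        #   the 3 occupied bins carry ~86% of the probability mass, so most of the 16
+        #   samples land inside them.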
+
+ @torch.no_grad()
+ def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums):
+ """
+        Use the predicted depth maps to remove redundant pts (the sdf-based pruning keeps both sides of the surface; the back side is useless)
+ :param coords: [n, 3] int coords
+ :param pred_depth_maps: [N_views, 1, h, w]
+ :param proj_matrices: [N_views, 4, 4]
+ :param partial_vol_origin: [3]
+ :param voxel_size: 1
+ :param near: 1
+ :param far: 1
+ :param depth_interval: 1
+ :param d_plane_nums: 1
+ :return:
+ """
+ device = pred_depth_maps.device
+ n_views, _, sizeH, sizeW = pred_depth_maps.shape
+
+ if len(partial_vol_origin.shape) == 1:
+ partial_vol_origin = partial_vol_origin[None, :]
+ pts = coords * voxel_size + partial_vol_origin
+
+ rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1)
+ rs_grid = rs_grid.permute(0, 2, 1).contiguous() # [n_views, 3, n_pts]
+ nV = rs_grid.shape[-1]
+ rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) # [n_views, 4, n_pts]
+
+ # Project grid
+ im_p = proj_matrices @ rs_grid # - transform world pts to image UV space # [n_views, 4, n_pts]
+ im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
+ im_x = im_x / im_z
+ im_y = im_y / im_z
+
+ im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
+
+ im_grid = im_grid.view(n_views, 1, -1, 2)
+ sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear',
+ padding_mode='zeros',
+ align_corners=True)[:, 0, 0, :] # [n_views, n_pts]
+ sampled_depths_valid = (sampled_depths > 0.5 * near).float()
+ valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(),
+ far.item()) * sampled_depths_valid
+ valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(),
+ far.item()) * sampled_depths_valid
+
+ mask = im_grid.abs() <= 1
+ mask = mask[:, 0] # [n_views, n_pts, 2]
+ mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max)
+
+ mask = mask.view(n_views, -1)
+ mask = mask.permute(1, 0).contiguous() # [num_pts, nviews]
+
+ mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0
+
+ return mask_final
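+        # - a small numeric illustration (not from the original code): with
+        #   depth_interval = 0.01 and d_plane_nums = 10, a point that projects onto a
+        #   pixel whose predicted depth is 1.50 is accepted by that view only if its own
+        #   depth im_z lies in (1.40, 1.60) (clamped to [near, far]); a point survives
+        #   if at least one view accepts it.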
+
+ @torch.no_grad()
+ def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume,
+ pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums,
+ threshold=0.02, maximum_pts=110000):
+ """
+        assumes batch size == 1; extracts the sparse voxels around the surface predicted by the first lod
+ :param sdf_volume: [1, X, Y, Z]
+ :param coords_volume: [3, X, Y, Z]
+ :param mask_volume: [1, X, Y, Z]
+ :param feature_volume: [C, X, Y, Z]
+ :param threshold:
+ :return:
+ """
+ device = coords_volume.device
+ _, dX, dY, dZ = coords_volume.shape
+
+ def prune(sdf_pts, coords_pts, mask_volume, threshold):
+ occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) # [num_pts]
+ valid_coords = coords_pts[occupancy_mask]
+
+ # - filter backside surface by depth maps
+ mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices,
+ partial_vol_origin, voxel_size,
+ near, far, depth_interval, d_plane_nums)
+ valid_coords = valid_coords[mask_filtered]
+
+            # - scatter the filtered coords into a dense occupancy volume
+ occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) # [dX, dY, dZ, 1]
+
+ # - dilate
+ occupancy_mask = occupancy_mask.float()
+ occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
+ occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
+ occupancy_mask = occupancy_mask.view(-1, 1) > 0
+
+ final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts]
+
+ return final_mask, torch.sum(final_mask.float())
+
+ C, dX, dY, dZ = feature_volume.shape
+ sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
+ mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)
+
+ # - for check
+ # sdf_volume = torch.rand_like(sdf_volume).float().to(sdf_volume.device) * 0.02
+
+ final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)
+
+ while (valid_num > maximum_pts) and (threshold > 0.003):
+ threshold = threshold - 0.002
+ final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)
+
+ valid_coords = coords_volume[final_mask] # [N, 3]
+ valid_feature = feature_volume[final_mask] # [N, C]
+
+ valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
+ valid_coords], dim=1) # [N, 4], append batch idx
+
+ # ! if the valid_num is still larger than maximum_pts, sample part of pts
+ if valid_num > maximum_pts:
+ valid_num = valid_num.long()
+ occupancy = torch.ones([valid_num]).to(device) > 0
+ choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
+ replace=False)
+ ind = torch.nonzero(occupancy).to(device)
+ occupancy[ind[choice]] = False
+ valid_coords = valid_coords[occupancy]
+ valid_feature = valid_feature[occupancy]
+
+ print(threshold, "randomly sample to save memory")
+
+ return valid_coords, valid_feature
+
+ @torch.no_grad()
+ def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02,
+ maximum_pts=110000):
+ """
+        assumes batch size == 1; extracts the sparse voxels around the surface predicted by the first lod
+ :param sdf_volume: [num_pts, 1]
+ :param coords_volume: [3, X, Y, Z]
+ :param mask_volume: [1, X, Y, Z]
+ :param feature_volume: [C, X, Y, Z]
+ :param threshold:
+ :return:
+ """
+
+ def prune(sdf_volume, mask_volume, threshold):
+ occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1]
+
+ # - dilate
+ occupancy_mask = occupancy_mask.float()
+ occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
+ occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
+ occupancy_mask = occupancy_mask.view(-1, 1) > 0
+
+ final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts]
+
+ return final_mask, torch.sum(final_mask.float())
+
+ C, dX, dY, dZ = feature_volume.shape
+ coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
+ mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
+ feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)
+
+ final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)
+
+ while (valid_num > maximum_pts) and (threshold > 0.003):
+ threshold = threshold - 0.002
+ final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)
+
+ valid_coords = coords_volume[final_mask] # [N, 3]
+ valid_feature = feature_volume[final_mask] # [N, C]
+
+ valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
+ valid_coords], dim=1) # [N, 4], append batch idx
+
+ # ! if the valid_num is still larger than maximum_pts, sample part of pts
+ if valid_num > maximum_pts:
+ device = sdf_volume.device
+ valid_num = valid_num.long()
+ occupancy = torch.ones([valid_num]).to(device) > 0
+ choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
+ replace=False)
+ ind = torch.nonzero(occupancy).to(device)
+ occupancy[ind[choice]] = False
+ valid_coords = valid_coords[occupancy]
+ valid_feature = valid_feature[occupancy]
+
+ print(threshold, "randomly sample to save memory")
+
+ return valid_coords, valid_feature
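+        # - note (added for clarity): the loop above starts from the given threshold
+        #   (0.02 by default) and shrinks it by 0.002 until at most maximum_pts voxels
+        #   survive or the threshold reaches 0.003; if the set is still too large, a
+        #   random subset of (valid_num - maximum_pts) voxels is dropped to bound memory.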
+
+ @torch.no_grad()
+ def extract_fields(self, bound_min, bound_max, resolution, query_func, device,
+ # * related to conditional feature
+ **kwargs
+ ):
+ N = 64
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
+
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
+ with torch.no_grad():
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device)
+
+                        # ! attention, the query function is different for extract_geometry and extract_fields
+ output = query_func(pts, **kwargs)
+ sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys),
+ len(zs)).detach().cpu().numpy()
+
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf
+ return u
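+        # - note (added for clarity): the query grid is evaluated in chunks of at most
+        #   N = 64 samples per axis (64^3 points per block) to bound peak memory, e.g.
+        #   resolution = 128 gives 2 x 2 x 2 blocks. The sign flip (-1 * sdf) makes the
+        #   stored field positive inside the surface, matching the threshold convention
+        #   used by marching cubes in extract_geometry.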
+
+ @torch.no_grad()
+ def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None,
+ # * 3d feature volume
+ **kwargs
+ ):
+ # logging.info('threshold: {}'.format(threshold))
+
+ u = self.extract_fields(bound_min, bound_max, resolution,
+ lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs),
+ # - sdf need to be multiplied by -1
+ device,
+ # * 3d feature volume
+ **kwargs
+ )
+ if occupancy_mask is not None:
+ dX, dY, dZ = occupancy_mask.shape
+ empty_mask = 1 - occupancy_mask
+ empty_mask = empty_mask.view(1, 1, dX, dY, dZ)
+ # - dilation
+ # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3)
+ empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest')
+ empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0
+ u[empty_mask] = -100
+ del empty_mask
+
+ vertices, triangles = mcubes.marching_cubes(u, threshold)
+ b_max_np = bound_max.detach().cpu().numpy()
+ b_min_np = bound_min.detach().cpu().numpy()
+
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+ return vertices, triangles, u
+
+ @torch.no_grad()
+ def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far):
+ """
+        extract depth maps from the sdf volume by sphere tracing
+        :param con_volume: [1, 1+C, dX, dY, dZ] can be a conditional feature volume or an sdf volume
+ :param c2ws: [B, 4, 4]
+ :param H:
+ :param W:
+ :param near:
+ :param far:
+ :return:
+ """
+ device = con_volume.device
+ batch_size = intrinsics.shape[0]
+
+ with torch.no_grad():
+ ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
+ torch.linspace(0, W - 1, W)) # pytorch's meshgrid has indexing='ij'
+ p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
+
+ intrinsics_inv = torch.inverse(intrinsics)
+
+ p = p.view(-1, 3).float().to(device) # N_rays, 3
+ p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3
+ rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3
+ rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3
+ rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3
+ rays_d = rays_v
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ ################## - sphere tracer to extract depth maps ######################
+ depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps(
+ rays_o, rays_d,
+ near[None, :].repeat(rays_o.shape[0], 1),
+ far[None, :].repeat(rays_o.shape[0], 1),
+ sdf_network, con_volume
+ )
+
+ depth_maps = depth_maps_sphere.view(batch_size, 1, H, W)
+ depth_masks = depth_masks_sphere.view(batch_size, 1, H, W)
+
+ depth_maps = torch.where(depth_masks, depth_maps,
+ torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0
+
+ return depth_maps, depth_masks
diff --git a/SparseNeuS_demo_v1/models/sparse_sdf_network.py b/SparseNeuS_demo_v1/models/sparse_sdf_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..817f40ed08b7cb65fb284a4666d6f6a4a3c52683
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/sparse_sdf_network.py
@@ -0,0 +1,907 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchsparse.tensor import PointTensor, SparseTensor
+import torchsparse.nn as spnn
+
+from tsparse.modules import SparseCostRegNet
+from tsparse.torchsparse_utils import sparse_to_dense_channel
+from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d
+
+# from .gru_fusion import GRUFusion
+from ops.back_project import back_project_sparse_type
+from ops.generate_grids import generate_grid
+
+from inplace_abn import InPlaceABN
+
+from models.embedder import Embedding
+from models.featurenet import ConvBnReLU
+
+import pdb
+import random
+
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+
+
+@torch.jit.script
+def fused_mean_variance(x, weight):
+ mean = torch.sum(x * weight, dim=1, keepdim=True)
+ var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True)
+ return mean, var
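+# - a minimal sanity sketch (illustration only, not part of the original code): with
+#   uniform weights over dim=1 the fused result reduces to the plain mean / biased variance:
+#       x = torch.randn(2, 4, 8)
+#       w = torch.full((2, 4, 1), 0.25)
+#       mean, var = fused_mean_variance(x, w)
+#       # mean ~= x.mean(dim=1, keepdim=True); var ~= x.var(dim=1, unbiased=False, keepdim=True)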
+
+
+class LatentSDFLayer(nn.Module):
+ def __init__(self,
+ d_in=3,
+ d_out=129,
+ d_hidden=128,
+ n_layers=4,
+ skip_in=(4,),
+ multires=0,
+ bias=0.5,
+ geometric_init=True,
+ weight_norm=True,
+ activation='softplus',
+ d_conditional_feature=16):
+ super(LatentSDFLayer, self).__init__()
+
+ self.d_conditional_feature = d_conditional_feature
+
+        # concat the latent code to the input of each layer except the first and the last layer
+ dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden]
+ dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out]
+
+ self.embed_fn_fine = None
+
+ if multires > 0:
+ embed_fn = Embedding(in_channels=d_in, N_freqs=multires) # * include the input
+ self.embed_fn_fine = embed_fn
+ dims_in[0] = embed_fn.out_channels
+
+ self.num_layers = n_layers
+ self.skip_in = skip_in
+
+ for l in range(0, self.num_layers - 1):
+ if l in self.skip_in:
+ in_dim = dims_in[l] + dims_in[0]
+ else:
+ in_dim = dims_in[l]
+
+ out_dim = dims_out[l]
+ lin = nn.Linear(in_dim, out_dim)
+
+ if geometric_init: # - from IDR code,
+ if l == self.num_layers - 2:
+ torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001)
+ torch.nn.init.constant_(lin.bias, -bias)
+ # the channels for latent codes are set to 0
+ torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
+ torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0)
+
+ elif multires > 0 and l == 0: # the first layer
+ torch.nn.init.constant_(lin.bias, 0.0)
+ # * the channels for position embeddings are set to 0
+ torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
+                    # * the channels for the xyz coordinates (3 channels) are initialized by a normal distribution
+ torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ elif multires > 0 and l in self.skip_in:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ # * the channels for position embeddings (and conditional_feature) are initialized to 0
+ torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0)
+ else:
+ torch.nn.init.constant_(lin.bias, 0.0)
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
+ # the channels for latent code are initialized to 0
+ torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ if activation == 'softplus':
+ self.activation = nn.Softplus(beta=100)
+ else:
+ assert activation == 'relu'
+ self.activation = nn.ReLU()
+
+ def forward(self, inputs, latent):
+ if self.embed_fn_fine is not None:
+ inputs = self.embed_fn_fine(inputs)
+
+ # - only for lod1 network can use the pretrained params of lod0 network
+ if latent.shape[1] != self.d_conditional_feature:
+ latent = torch.cat([latent, latent], dim=1)
+
+ x = inputs
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ # * due to the conditional bias, different from original neus version
+ if l in self.skip_in:
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
+
+ if 0 < l < self.num_layers - 1:
+ x = torch.cat([x, latent], 1)
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.activation(x)
+
+ return x
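+        # - shape walk-through (added for clarity): inputs [N, 3] are positionally
+        #   embedded, the conditional latent [N, d_conditional_feature] is concatenated
+        #   to the hidden-layer inputs, and the output is [N, d_out] whose first column
+        #   is the sdf value and whose remaining columns form a geometry feature vector.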
+
+
+class SparseSdfNetwork(nn.Module):
+ '''
+    Coarse-to-fine sparse cost regularization network;
+    returns a sparse volume feature for extracting the sdf
+ '''
+
+ def __init__(self, lod, ch_in, voxel_size, vol_dims,
+ hidden_dim=128, activation='softplus',
+ cost_type='variance_mean',
+ d_pyramid_feature_compress=16,
+ regnet_d_out=8, num_sdf_layers=4,
+ multires=6,
+ ):
+ super(SparseSdfNetwork, self).__init__()
+
+ self.lod = lod # - gradually training, the current regularization lod
+ self.ch_in = ch_in
+ self.voxel_size = voxel_size # - the voxel size of the current volume
+ self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume
+
+ self.selected_views_num = 2 # the number of selected views for feature aggregation
+ self.hidden_dim = hidden_dim
+ self.activation = activation
+ self.cost_type = cost_type
+ self.d_pyramid_feature_compress = d_pyramid_feature_compress
+ self.gru_fusion = None
+
+ self.regnet_d_out = regnet_d_out
+ self.multires = multires
+
+ self.pos_embedder = Embedding(3, self.multires)
+
+ self.compress_layer = ConvBnReLU(
+ self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
+ norm_act=InPlaceABN)
+ sparse_ch_in = self.d_pyramid_feature_compress * 2
+
+ sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
+ self.sparse_costreg_net = SparseCostRegNet(
+ d_in=sparse_ch_in, d_out=self.regnet_d_out)
+ # self.regnet_d_out = self.sparse_costreg_net.d_out
+
+ if activation == 'softplus':
+ self.activation = nn.Softplus(beta=100)
+ else:
+ assert activation == 'relu'
+ self.activation = nn.ReLU()
+
+ self.sdf_layer = LatentSDFLayer(d_in=3,
+ d_out=self.hidden_dim + 1,
+ d_hidden=self.hidden_dim,
+ n_layers=num_sdf_layers,
+ multires=multires,
+ geometric_init=True,
+ weight_norm=True,
+ activation=activation,
+ d_conditional_feature=16 # self.regnet_d_out
+ )
+
+ def upsample(self, pre_feat, pre_coords, interval, num=8):
+ '''
+
+ :param pre_feat: (Tensor), features from last level, (N, C)
+ :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z)
+ :param interval: interval of voxels, interval = scale ** 2
+    :param num: number of child voxels per input voxel (1 -> 8)
+ :return: up_feat : (Tensor), upsampled features, (N*8, C)
+ :return: up_coords: (N*8, 4), upsampled coordinates, (4 : Batch ind, x, y, z)
+ '''
+ with torch.no_grad():
+ pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
+ n, c = pre_feat.shape
+ up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
+ up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
+ for i in range(num - 1):
+ up_coords[:, i + 1, pos_list[i]] += interval
+
+ up_feat = up_feat.view(-1, c)
+ up_coords = up_coords.view(-1, 4)
+
+ return up_feat, up_coords
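+        # - a small illustration (not from the original code): with interval = 1 the
+        #   parent voxel (b, x, y, z) spawns the 8 children
+        #   (b, x, y, z), (b, x+1, y, z), (b, x, y+1, z), (b, x, y, z+1),
+        #   (b, x+1, y+1, z), (b, x+1, y, z+1), (b, x, y+1, z+1), (b, x+1, y+1, z+1),
+        #   each inheriting the parent feature.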
+
+ def aggregate_multiview_features(self, multiview_features, multiview_masks):
+ """
+        aggregate multi-view features by computing their cost variance and mean
+ :param multiview_features: (num of voxels, num_of_views, c)
+ :param multiview_masks: (num of voxels, num_of_views)
+ :return:
+ """
+ num_pts, n_views, C = multiview_features.shape
+
+ counts = torch.sum(multiview_masks, dim=1, keepdim=False) # [num_pts]
+
+ assert torch.all(counts > 0) # the point is visible for at least 1 view
+
+ volume_sum = torch.sum(multiview_features, dim=1, keepdim=False) # [num_pts, C]
+ volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)
+
+ if volume_sum.isnan().sum() > 0:
+ import ipdb; ipdb.set_trace()
+
+ del multiview_features
+
+ counts = 1. / (counts + 1e-5)
+ costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2
+
+ costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1)
+ del volume_sum, volume_sq_sum, counts
+
+ return costvar_mean
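+        # - note (added for clarity): costvar_mean concatenates the per-channel variance
+        #   and the per-channel mean over the visible views, giving 2 * C channels; this
+        #   matches the cost_type = variance_mean setting in the configs. The statistics
+        #   are exact only if invisible views contribute zero features, which
+        #   back_project_sparse_type is assumed to guarantee here.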
+
+ def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
+ """
+        convert the sparse volume into a dense volume to enable trilinear sampling;
+        the volume is kept sparse elsewhere to save GPU memory
+ :param coords: [num_pts, 3]
+ :param feature: [num_pts, C]
+ :param vol_dims: [3] dX, dY, dZ
+ :param interval:
+ :return:
+ """
+
+ # * assume batch size is 1
+ if device is None:
+ device = feature.device
+
+ coords_int = (coords / interval).to(torch.int64)
+ vol_dims = (vol_dims / interval).to(torch.int64)
+
+ # - if stored in CPU, too slow
+ dense_volume = sparse_to_dense_channel(
+ coords_int.to(device), feature.to(device), vol_dims.to(device),
+ feature.shape[1], 0, device) # [X, Y, Z, C]
+
+ valid_mask_volume = sparse_to_dense_channel(
+ coords_int.to(device),
+ torch.ones([feature.shape[0], 1]).to(feature.device),
+ vol_dims.to(device),
+ 1, 0, device) # [X, Y, Z, 1]
+
+ dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z]
+ valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z]
+
+ return dense_volume, valid_mask_volume
+
+ def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
+ pre_coords=None, pre_feats=None,
+ ):
+ """
+
+        :param feature_maps: fused pyramid feature maps (B,V,C0+C1+C2,H,W)
+ :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0)
+        :param proj_mats: projection matrices that transform world pts into image space, [B,V,4,4], defined for the original image size
+ :param sizeH: the H of original image size
+ :param sizeW: the W of original image size
+ :param pre_coords: the coordinates of sparse volume from the prior lod
+ :param pre_feats: the features of sparse volume from the prior lod
+ :return:
+ """
+ device = proj_mats.device
+ bs = feature_maps.shape[0]
+ N_views = feature_maps.shape[1]
+ minimum_visible_views = np.min([1, N_views - 1])
+ # import ipdb; ipdb.set_trace()
+ outputs = {}
+ pts_samples = []
+
+ # ----coarse to fine----
+
+        # * using the fused pyramid feature maps is very important
+ if self.compress_layer is not None:
+ feats = self.compress_layer(feature_maps[0])
+ else:
+ feats = feature_maps[0]
+ feats = feats[:, None, :, :, :] # [V, B, C, H, W]
+ KRcam = proj_mats.permute(1, 0, 2, 3).contiguous() # [V, B, 4, 4]
+ interval = 1
+
+ if self.lod == 0:
+ # ----generate new coords----
+ coords = generate_grid(self.vol_dims, 1)[0]
+ coords = coords.view(3, -1).to(device) # [3, num_pts]
+ up_coords = []
+ for b in range(bs):
+ up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords]))
+ up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()
+            # * since we only estimate the geometry of one input reference image at a time,
+ # * mask the outside of the camera frustum
+ # import ipdb; ipdb.set_trace()
+ frustum_mask = back_project_sparse_type(
+ up_coords, partial_vol_origin, self.voxel_size,
+ feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True) # [num_pts, n_views]
+ frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views # ! here should be large
+ up_coords = up_coords[frustum_mask] # [num_pts_valid, 4]
+
+ else:
+ # ----upsample coords----
+ assert pre_feats is not None
+ assert pre_coords is not None
+ up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)
+
+ # ----back project----
+ # give each valid 3d grid point all valid 2D features and masks
+ multiview_features, multiview_masks = back_project_sparse_type(
+ up_coords, partial_vol_origin, self.voxel_size, feats,
+ KRcam, sizeH=sizeH, sizeW=sizeW) # (num of voxels, num_of_views, c), (num of voxels, num_of_views)
+ # num_of_views = all views
+
+ # if multiview_features.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+
+ # import ipdb; ipdb.set_trace()
+ if self.lod > 0:
+ # ! need another invalid voxels filtering
+ frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
+ up_feat = up_feat[frustum_mask]
+ up_coords = up_coords[frustum_mask]
+ multiview_features = multiview_features[frustum_mask]
+ multiview_masks = multiview_masks[frustum_mask]
+ # if multiview_features.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+        volume = self.aggregate_multiview_features(multiview_features, multiview_masks)  # compute the variance and mean over all image features
+ # import ipdb; ipdb.set_trace()
+
+ # if volume.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+
+ del multiview_features, multiview_masks
+
+ # ----concat feature from last stage----
+ if self.lod != 0:
+ feat = torch.cat([volume, up_feat], dim=1)
+ else:
+ feat = volume
+
+ # batch index is in the last position
+ r_coords = up_coords[:, [1, 2, 3, 0]]
+
+ # if feat.isnan().sum() > 0:
+ # print('feat has nan:', feat.isnan().sum())
+ # import ipdb; ipdb.set_trace()
+
+ sparse_feat = SparseTensor(feat, r_coords.to(
+ torch.int32)) # - directly use sparse tensor to avoid point2voxel operations
+ # import ipdb; ipdb.set_trace()
+ feat = self.sparse_costreg_net(sparse_feat)
+
+ dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
+ device=None) # [1, C/1, X, Y, Z]
+
+ # if dense_volume.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+
+
+ outputs['dense_volume_scale%d' % self.lod] = dense_volume # [1, 16, 96, 96, 96]
+ outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96]
+ outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96]
+ outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)
+ # import ipdb; ipdb.set_trace()
+ return outputs
+
+ def sdf(self, pts, conditional_volume, lod):
+ num_pts = pts.shape[0]
+ device = pts.device
+ pts_ = pts.clone()
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
+ # import ipdb; ipdb.set_trace()
+ sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts]
+ sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)
+
+ sdf_pts = self.sdf_layer(pts_, sampled_feature)
+
+ outputs = {}
+ outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
+ outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
+ outputs['sampled_latent_scale%d' % lod] = sampled_feature
+
+ return outputs
+
+ @torch.no_grad()
+ def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
+ num_pts = pts.shape[0]
+ device = pts.device
+ pts_ = pts.clone()
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
+
+ sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
+ padding_mode='border')
+ sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)
+
+ outputs = {}
+ outputs['sdf_pts_scale%d' % lod] = sdf
+
+ return outputs
+
+ @torch.no_grad()
+ def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
+ """
+
+ :param conditional_volume: [1,C, dX,dY,dZ]
+ :param mask_volume: [1,1, dX,dY,dZ]
+ :param coords_volume: [1,3, dX,dY,dZ]
+ :return:
+ """
+ device = conditional_volume.device
+ chunk_size = 10240
+
+ _, C, dX, dY, dZ = conditional_volume.shape
+ conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
+ mask_volume = mask_volume.view(-1)
+ coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()
+
+ pts = coords_volume * self.voxel_size + partial_origin # [dX*dY*dZ, 3]
+
+ sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)
+
+ conditional_volume = conditional_volume[mask_volume > 0]
+ pts = pts[mask_volume > 0]
+ conditional_volume = conditional_volume.split(chunk_size)
+ pts = pts.split(chunk_size)
+
+ sdf_all = []
+ for pts_part, feature_part in zip(pts, conditional_volume):
+ sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
+ sdf_all.append(sdf_part)
+
+ sdf_all = torch.cat(sdf_all, dim=0)
+ sdf_volume[mask_volume > 0] = sdf_all
+ sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
+ return sdf_volume
+
+ def gradient(self, x, conditional_volume, lod):
+ """
+ return the gradient of specific lod
+ :param x:
+ :param lod:
+ :return:
+ """
+ x.requires_grad_(True)
+ # import ipdb; ipdb.set_trace()
+ with torch.enable_grad():
+ output = self.sdf(x, conditional_volume, lod)
+ y = output['sdf_pts_scale%d' % lod]
+
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
+ # ! Distributed Data Parallel doesn’t work with torch.autograd.grad()
+ # ! (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).
+ gradients = torch.autograd.grad(
+ outputs=y,
+ inputs=x,
+ grad_outputs=d_output,
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+ return gradients.unsqueeze(1)
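+        # - note (added for clarity): this gradient of the sdf serves as the surface
+        #   normal and feeds the eikonal regularizer (sdf_igr_weight in the configs),
+        #   which is expected to keep its norm close to 1.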
+
+
+def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
+ """
+    convert the sparse volume into a dense volume to enable trilinear sampling;
+    the volume is kept sparse elsewhere to save GPU memory
+ :param coords: [num_pts, 3]
+ :param feature: [num_pts, C]
+ :param vol_dims: [3] dX, dY, dZ
+ :param interval:
+ :return:
+ """
+
+ # * assume batch size is 1
+ if device is None:
+ device = feature.device
+
+ coords_int = (coords / interval).to(torch.int64)
+ vol_dims = (vol_dims / interval).to(torch.int64)
+
+ # - if stored in CPU, too slow
+ dense_volume = sparse_to_dense_channel(
+ coords_int.to(device), feature.to(device), vol_dims.to(device),
+ feature.shape[1], 0, device) # [X, Y, Z, C]
+
+ valid_mask_volume = sparse_to_dense_channel(
+ coords_int.to(device),
+ torch.ones([feature.shape[0], 1]).to(feature.device),
+ vol_dims.to(device),
+ 1, 0, device) # [X, Y, Z, 1]
+
+ dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z]
+ valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z]
+
+ return dense_volume, valid_mask_volume
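+# - a minimal usage sketch (illustration only, shapes taken from the docstring above):
+#       coords = torch.tensor([[0, 0, 0], [1, 2, 3]])          # [num_pts, 3]
+#       feats = torch.randn(2, 8)                               # [num_pts, C]
+#       dense, mask = sparse_to_dense_volume(coords, feats, torch.tensor([4, 4, 4]), interval=1)
+#       # dense: [1, 8, 4, 4, 4]; mask: [1, 1, 4, 4, 4] with ones only at the two coords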
+
+
+class SdfVolume(nn.Module):
+ def __init__(self, volume, coords=None, type='dense'):
+ super(SdfVolume, self).__init__()
+ self.volume = torch.nn.Parameter(volume, requires_grad=True)
+ self.coords = coords
+ self.type = type
+
+ def forward(self):
+ return self.volume
+
+
+class FinetuneOctreeSdfNetwork(nn.Module):
+ '''
+    After obtaining the conditional volume from the generalized network,
+    directly optimize the conditional volume.
+    The conditional volume is kept sparse.
+ '''
+
+ def __init__(self, voxel_size, vol_dims,
+ origin=[-1., -1., -1.],
+ hidden_dim=128, activation='softplus',
+ regnet_d_out=8,
+ multires=6,
+ if_fitted_rendering=True,
+ num_sdf_layers=4,
+ ):
+ super(FinetuneOctreeSdfNetwork, self).__init__()
+
+ self.voxel_size = voxel_size # - the voxel size of the current volume
+ self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume
+
+ self.origin = torch.tensor(origin).to(torch.float32)
+
+ self.hidden_dim = hidden_dim
+ self.activation = activation
+
+ self.regnet_d_out = regnet_d_out
+
+ self.if_fitted_rendering = if_fitted_rendering
+ self.multires = multires
+ # d_in_embedding = self.regnet_d_out if self.pos_add_type == 'latent' else 3
+ # self.pos_embedder = Embedding(d_in_embedding, self.multires)
+
+ # - the optimized parameters
+ self.sparse_volume_lod0 = None
+ self.sparse_coords_lod0 = None
+
+ if activation == 'softplus':
+ self.activation = nn.Softplus(beta=100)
+ else:
+ assert activation == 'relu'
+ self.activation = nn.ReLU()
+
+ self.sdf_layer = LatentSDFLayer(d_in=3,
+ d_out=self.hidden_dim + 1,
+ d_hidden=self.hidden_dim,
+ n_layers=num_sdf_layers,
+ multires=multires,
+ geometric_init=True,
+ weight_norm=True,
+ activation=activation,
+ d_conditional_feature=16 # self.regnet_d_out
+ )
+
+ # - add mlp rendering when finetuning
+ self.renderer = None
+
+ d_in_renderer = 3 + self.regnet_d_out + 3 + 3
+ self.renderer = BlendingRenderingNetwork(
+ d_feature=self.hidden_dim - 1,
+            mode='idr',  # ! the view direction influences the result a lot
+ d_in=d_in_renderer,
+ d_out=50, # maximum 50 images
+ d_hidden=self.hidden_dim,
+ n_layers=3,
+ weight_norm=True,
+ multires_view=4,
+ squeeze_out=True,
+ )
+
+ def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
+ sparse_volume_lod0=None, sparse_coords_lod0=None):
+ """
+
+ :param dense_volume_lod0: [1,C,dX,dY,dZ]
+ :param dense_volume_mask_lod0: [1,1,dX,dY,dZ]
+        :param sparse_volume_lod0: optional pre-computed sparse volume features
+        :param sparse_coords_lod0: optional pre-computed sparse volume coords
+ :return:
+ """
+
+ if sparse_volume_lod0 is None:
+ device = dense_volume_lod0.device
+ _, C, dX, dY, dZ = dense_volume_lod0.shape
+
+ dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
+ mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0
+
+ self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')
+
+ coords = generate_grid(self.vol_dims, 1)[0] # [3, dX, dY, dZ]
+ coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
+ self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
+ else:
+ self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
+ self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)
+
+ def get_conditional_volume(self):
+ dense_volume, valid_mask_volume = sparse_to_dense_volume(
+ self.sparse_coords_lod0,
+ self.sparse_volume_lod0(), self.vol_dims, interval=1,
+ device=None) # [1, C/1, X, Y, Z]
+
+ # valid_mask_volume = self.dense_volume_mask_lod0
+
+ outputs = {}
+ outputs['dense_volume_scale%d' % 0] = dense_volume
+ outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume
+
+ return outputs
+
+ def tv_regularizer(self):
+ dense_volume, valid_mask_volume = sparse_to_dense_volume(
+ self.sparse_coords_lod0,
+ self.sparse_volume_lod0(), self.vol_dims, interval=1,
+ device=None) # [1, C/1, X, Y, Z]
+
+ dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2 # [1, C/1, X-1, Y, Z]
+ dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2 # [1, C/1, X, Y-1, Z]
+ dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2 # [1, C/1, X, Y, Z-1]
+
+ tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :] # [1, C/1, X-1, Y-1, Z-1]
+
+ mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
+ valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]
+
+ tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask
+ # tv = tv.mean(dim=1, keepdim=True) * mask
+
+ assert torch.all(~torch.isnan(tv))
+
+ return torch.mean(tv)
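+        # - note (added for clarity): this is an anisotropic total-variation term: the
+        #   per-voxel sqrt of the summed squared forward differences (plus 1e-6),
+        #   averaged over feature channels, zeroed wherever a forward neighbour is
+        #   invalid, and finally averaged over the whole grid.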
+
+ def sdf(self, pts, conditional_volume, lod):
+
+ outputs = {}
+
+ num_pts = pts.shape[0]
+ device = pts.device
+ pts_ = pts.clone()
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
+
+ sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts]
+ sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
+ outputs['sampled_latent_scale%d' % lod] = sampled_feature
+
+ sdf_pts = self.sdf_layer(pts_, sampled_feature)
+
+ lod = 0
+ outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
+ outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
+
+ return outputs
+
+ def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
+ pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
+
+ return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
+ img_index, pts_pixel_color, pts_pixel_mask,
+ pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)
+
+ def gradient(self, x, conditional_volume, lod):
+ """
+ return the gradient of specific lod
+ :param x:
+ :param lod:
+ :return:
+ """
+ x.requires_grad_(True)
+ output = self.sdf(x, conditional_volume, lod)
+ y = output['sdf_pts_scale%d' % 0]
+
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
+
+ gradients = torch.autograd.grad(
+ outputs=y,
+ inputs=x,
+ grad_outputs=d_output,
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+ return gradients.unsqueeze(1)
+
+ @torch.no_grad()
+ def prune_dense_mask(self, threshold=0.02):
+ """
+        Gradually prune the mask of the dense volume to decrease the number of sdf network inferences
+ :return:
+ """
+ chunk_size = 10240
+ coords = generate_grid(self.vol_dims_lod0, 1)[0] # [3, dX, dY, dZ]
+
+ _, dX, dY, dZ = coords.shape
+
+ pts = coords.view(3, -1).permute(1,
+ 0).contiguous() * self.voxel_size_lod0 + self.origin[None, :] # [dX*dY*dZ, 3]
+
+ # dense_volume = self.dense_volume_lod0() # [1,C,dX,dY,dZ]
+ dense_volume, _ = sparse_to_dense_volume(
+ self.sparse_coords_lod0,
+ self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1,
+ device=None) # [1, C/1, X, Y, Z]
+
+ sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100
+
+ mask = self.dense_volume_mask_lod0.view(-1) > 0
+
+ pts_valid = pts[mask].to(dense_volume.device)
+ feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask]
+
+ pts_valid = pts_valid.split(chunk_size)
+ feature_valid = feature_valid.split(chunk_size)
+
+ sdf_list = []
+
+ for pts_part, feature_part in zip(pts_valid, feature_valid):
+ sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
+ sdf_list.append(sdf_part)
+
+ sdf_list = torch.cat(sdf_list, dim=0)
+
+ sdf_volume[mask] = sdf_list
+
+ occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1]
+
+ # - dilate
+ occupancy_mask = occupancy_mask.float()
+ occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
+ occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
+ occupancy_mask = occupancy_mask > 0
+
+ self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0,
+ occupancy_mask).float() # (1, 1, dX, dY, dZ)
+
+
+class BlendingRenderingNetwork(nn.Module):
+ def __init__(
+ self,
+ d_feature,
+ mode,
+ d_in,
+ d_out,
+ d_hidden,
+ n_layers,
+ weight_norm=True,
+ multires_view=0,
+ squeeze_out=True,
+ ):
+ super(BlendingRenderingNetwork, self).__init__()
+
+ self.mode = mode
+ self.squeeze_out = squeeze_out
+ dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
+
+ self.embedder = None
+ if multires_view > 0:
+ self.embedder = Embedding(3, multires_view)
+ dims[0] += (self.embedder.out_channels - 3)
+
+ self.num_layers = len(dims)
+
+ for l in range(0, self.num_layers - 1):
+ out_dim = dims[l + 1]
+ lin = nn.Linear(dims[l], out_dim)
+
+ if weight_norm:
+ lin = nn.utils.weight_norm(lin)
+
+ setattr(self, "lin" + str(l), lin)
+
+ self.relu = nn.ReLU()
+
+ self.color_volume = None
+
+ self.softmax = nn.Softmax(dim=1)
+
+ self.type = 'blending'
+
+ def sample_pts_from_colorVolume(self, pts):
+ device = pts.device
+ num_pts = pts.shape[0]
+ pts_ = pts.clone()
+ pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
+
+ pts = torch.flip(pts, dims=[-1])
+
+ sampled_color = grid_sample_3d(self.color_volume, pts) # [1, c, 1, 1, num_pts]
+ sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device)
+
+ return sampled_color
+
+ def forward(self, position, normals, view_dirs, feature_vectors, img_index,
+ pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
+ """
+
+ :param position: can be 3d coord or interpolated volume latent
+ :param normals:
+ :param view_dirs:
+ :param feature_vectors:
+ :param img_index: [N_views], used to extract corresponding weights
+ :param pts_pixel_color: [N_pts, N_views, 3]
+ :param pts_pixel_mask: [N_pts, N_views]
+ :param pts_patch_color: [N_pts, N_views, Npx, 3]
+ :return:
+ """
+ if self.embedder is not None:
+ view_dirs = self.embedder(view_dirs)
+
+ rendering_input = None
+
+ if self.mode == 'idr':
+ rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_view_dir':
+ rendering_input = torch.cat([position, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_normal':
+ rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1)
+ elif self.mode == 'no_points':
+ rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
+ elif self.mode == 'no_points_no_view_dir':
+ rendering_input = torch.cat([normals, feature_vectors], dim=-1)
+
+ x = rendering_input
+
+ for l in range(0, self.num_layers - 1):
+ lin = getattr(self, "lin" + str(l))
+
+ x = lin(x)
+
+ if l < self.num_layers - 2:
+ x = self.relu(x) # [n_pts, d_out]
+
+        # extract per-view blending logits based on img_index
+ x_extracted = torch.index_select(x, 1, img_index.long())
+
+ weights_pixel = self.softmax(x_extracted) # [n_pts, N_views]
+ weights_pixel = weights_pixel * pts_pixel_mask
+ weights_pixel = weights_pixel / (
+ torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views]
+ final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
+ keepdim=False) # [N_pts, 3]
+
+ final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0 # [N_pts, 1]
+
+ final_patch_color, final_patch_mask = None, None
+ # pts_patch_color [N_pts, N_views, Npx, 3]; pts_patch_mask [N_pts, N_views, Npx]
+ if pts_patch_color is not None:
+ N_pts, N_views, Npx, _ = pts_patch_color.shape
+ patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1 # [N_pts, N_views]
+
+ weights_patch = self.softmax(x_extracted) # [N_pts, N_views]
+ weights_patch = weights_patch * patch_mask
+ weights_patch = weights_patch / (
+ torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views]
+
+ final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
+ keepdim=False) # [N_pts, Npx, 3]
+ final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0 # [N_pts, 1] at least one image sees
+
+ return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
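+        # - note (added for clarity): the MLP outputs up to d_out per-view blending
+        #   logits; torch.index_select keeps the logits of the actual source views given
+        #   by img_index, and the masked softmax above turns them into normalized
+        #   blending weights for the per-view pixel (and optionally patch) colors.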
diff --git a/SparseNeuS_demo_v1/models/trainer_finetune.py b/SparseNeuS_demo_v1/models/trainer_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6203976b2a72dea61e1e728a3b1a225366f56a2
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/trainer_finetune.py
@@ -0,0 +1,979 @@
+"""
+Trainer for fine-tuning
+"""
+import os
+import cv2 as cv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+from models.render_utils import sample_pdf
+from utils.misc_utils import visualize_depth_numpy
+
+from utils.training_utils import tocuda, numpy2tensor
+from loss.depth_metric import compute_depth_errors
+from loss.color_loss import OcclusionColorLoss, OcclusionColorPatchLoss
+from loss.depth_loss import DepthLoss, DepthSmoothLoss
+
+from models.projector import Projector
+
+from models.rays import gen_rays_between
+
+from models.sparse_neus_renderer import SparseNeuSRenderer
+
+import pdb
+
+
+class FinetuneTrainer(nn.Module):
+ """
+ Trainer used for fine-tuning
+ """
+
+ def __init__(self,
+ rendering_network_outside,
+ pyramid_feature_network_lod0,
+ pyramid_feature_network_lod1,
+ sdf_network_lod0,
+ sdf_network_lod1,
+ variance_network_lod0,
+ variance_network_lod1,
+ sdf_network_finetune,
+ finetune_lod, # which lod fine-tuning use
+ n_samples,
+ n_importance,
+ n_outside,
+ perturb,
+ alpha_type='div',
+ conf=None
+ ):
+ super(FinetuneTrainer, self).__init__()
+
+ self.conf = conf
+ self.base_exp_dir = conf['general.base_exp_dir']
+
+ self.finetune_lod = finetune_lod
+
+ self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
+ self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
+ self.end_iter = self.conf.get_int('train.end_iter')
+
+ # network setups
+ self.rendering_network_outside = rendering_network_outside
+ self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry
+        self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1  # use different networks for the two lods
+
+ self.sdf_network_lod0 = sdf_network_lod0 # the first lod is density_network
+ self.sdf_network_lod1 = sdf_network_lod1
+
+        # - wrapped by ModuleList to support DataParallel
+ self.variance_network_lod0 = variance_network_lod0
+ self.variance_network_lod1 = variance_network_lod1
+ self.variance_network_finetune = variance_network_lod0 if self.finetune_lod == 0 else variance_network_lod1
+
+ self.sdf_network_finetune = sdf_network_finetune
+
+ self.n_samples = n_samples
+ self.n_importance = n_importance
+ self.n_outside = n_outside
+ self.perturb = perturb
+ self.alpha_type = alpha_type
+
+ self.sdf_renderer_finetune = SparseNeuSRenderer(
+ self.rendering_network_outside,
+ self.sdf_network_finetune,
+ self.variance_network_finetune,
+ None, # rendering_network
+ self.n_samples,
+ self.n_importance,
+ self.n_outside,
+ self.perturb,
+ alpha_type='div',
+ conf=self.conf)
+
+ # sdf network weights
+ self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
+ self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
+
+ self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
+ self.color_pixel_weight = self.conf.get_float('train.color_pixel_weight', default=1.0)
+ self.color_patch_weight = self.conf.get_float('train.color_patch_weight', default=0.)
+ self.tv_weight = self.conf.get_float('train.tv_weight', default=0.001) # no use
+ self.visibility_beta = self.conf.get_float('train.visibility_beta', default=0.025)
+ self.visibility_gama = self.conf.get_float('train.visibility_gama', default=0.015)
+ self.visibility_penalize_ratio = self.conf.get_float('train.visibility_penalize_ratio', default=0.8)
+ self.visibility_weight_thred = self.conf.get_list('train.visibility_weight_thred', default=[0.7])
+ self.if_visibility_aware = self.conf.get_bool('train.if_visibility_aware', default=True)
+ self.train_from_scratch = self.conf.get_bool('train.train_from_scratch', default=False)
+
+ self.depth_criterion = DepthLoss()
+ self.depth_smooth_criterion = DepthSmoothLoss()
+ self.occlusion_color_criterion = OcclusionColorLoss(beta=self.visibility_beta,
+ gama=self.visibility_gama,
+ weight_thred=self.visibility_weight_thred,
+ occlusion_aware=self.if_visibility_aware)
+ self.occlusion_color_patch_criterion = OcclusionColorPatchLoss(
+ type=self.conf.get_string('train.patch_loss_type', default='ncc'),
+ h_patch_size=self.conf.get_int('model.h_patch_size', default=5),
+ beta=self.visibility_beta, gama=self.visibility_gama,
+ weight_thred=self.visibility_weight_thred,
+ occlusion_aware=self.if_visibility_aware
+ )
+
+ # self.iter_step = 0
+ self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
+
+ # - True if fine-tuning
+ self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)
+
+ def get_trainable_params(self):
+ # set trainable params
+
+ params = []
+ faster_params = []
+ slower_params = []
+
+ params += self.variance_network_finetune.parameters()
+ slower_params += self.sdf_network_finetune.sparse_volume_lod0.parameters()
+ params += self.sdf_network_finetune.sdf_layer.parameters()
+
+ faster_params += self.sdf_network_finetune.renderer.parameters()
+
+ self.params_to_train = {
+ 'slower_params': slower_params,
+ 'params': params,
+ 'faster_params': faster_params
+ }
+
+ return self.params_to_train
+
+ @torch.no_grad()
+ def prepare_con_volume(self, sample):
+ # * only support batch_size==1
+ sizeW = sample['img_wh'][0]
+ sizeH = sample['img_wh'][1]
+ partial_vol_origin = sample['partial_vol_origin'][None, :] # [B, 3]
+ near, far = sample['near_fars'][0, :1], sample['near_fars'][0, 1:]
+ near = 0.8 * near
+ far = 1.2 * far
+
+ imgs = sample['images']
+ intrinsics = sample['intrinsics']
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs']
+ c2ws = sample['c2ws']
+ proj_matrices = sample['affine_mats'][None, :, :, :]
+
+ # *********************** Lod==0 ***********************
+
+ with torch.no_grad():
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)
+ # import ipdb; ipdb.set_trace()
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ if self.finetune_lod == 0:
+ return con_volume_lod0, con_valid_mask_volume_lod0, coords_lod0
+
+ # * extract depth maps for all the images for adaptive rendering_network
+ depth_maps_lod0, depth_masks_lod0 = None, None
+ if self.finetune_lod == 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ if self.finetune_lod == 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+
+ pre_coords, pre_feats = self.sdf_renderer_finetune.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ maximum_pts=200000)
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+ coords_lod1 = conditional_features_lod1['coords_scale1'] # [1,3,wX,wY,wZ]
+ con_valid_mask_volume_lod0 = F.interpolate(con_valid_mask_volume_lod0, scale_factor=2)
+
+ return con_volume_lod1, con_valid_mask_volume_lod1, coords_lod1
+
+ def initialize_finetune_network(self, sample, sparse_con_volume=None, sparse_coords_volume=None,
+ train_from_scratch=False):
+
+ if not train_from_scratch:
+            if sparse_con_volume is None:  # if the sparse volume is not provided, compute it from the generalized network
+
+ con_volume, con_mask_volume, _ = self.prepare_con_volume(sample)
+
+ device = con_volume.device
+
+ self.sdf_network_finetune.initialize_conditional_volumes(
+ con_volume,
+ con_mask_volume
+ )
+ else:
+ self.sdf_network_finetune.initialize_conditional_volumes(
+ None,
+ None,
+ sparse_con_volume,
+ sparse_coords_volume
+ )
+ else:
+ device = sample['images'].device
+ vol_dims = self.sdf_network_finetune.vol_dims
+ con_volume = torch.zeros(
+ [1, self.sdf_network_finetune.regnet_d_out, vol_dims[0], vol_dims[1], vol_dims[2]]).to(device)
+ con_mask_volume = torch.ones([1, 1, vol_dims[0], vol_dims[1], vol_dims[2]]).to(device)
+ self.sdf_network_finetune.initialize_conditional_volumes(
+ con_volume,
+ con_mask_volume
+ )
+
+ self.sdf_network_lod0, self.sdf_network_lod1 = None, None
+ self.pyramid_feature_network_geometry_lod0, self.pyramid_feature_network_geometry_lod1 = None, None
+
+ def train_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=False,
+ ):
+
+ # * finetune on one specific scene
+ # * only support batch_size==1
+        # ! attention: a list of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ img_index = sample['img_index'][0] # [n]
+
+ # the full-size ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ w2cs = sample['w2cs'][0]
+ proj_matrices = sample['affine_mats']
+ scale_mat = sample['scale_mat']
+ trans_mat = sample['trans_mat']
+
+ query_c2w = sample['query_c2w']
+
+ # *********************** Lod==0 ***********************
+
+ conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume()
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+
+ # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ # # - extract mesh
+ if iter_step % self.val_mesh_freq == 0:
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_finetune,
+ self.sdf_renderer_finetune.extract_geometry,
+ conditional_volume=con_volume_lod0,
+ lod=0,
+ threshold=0.,
+ occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='ft', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ torch.cuda.empty_cache()
+
+ render_out = self.sdf_renderer_finetune.render(
+ rays_o, rays_d, near, far,
+ self.sdf_network_finetune,
+ None, # rendering_network
+ background_rgb=background_rgb,
+ alpha_inter_ratio=1.0,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_general_rendering=False,
+ img_index=img_index,
+ rays_uv=rays_ndc_uv if self.color_patch_weight > 0 else None,
+ )
+
+        # * optional TV regularizer; we don't use it in this paper
+ if self.tv_weight > 0:
+ tv = self.sdf_network_finetune.tv_regularizer()
+ else:
+ tv = 0.0
+ render_out['tv'] = tv
+ loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, iter_step)
+
+ losses = {
+ # - lod 0
+ 'loss_lod0': loss_lod0,
+ 'losses_lod0': losses_lod0,
+ 'depth_statis_lod0': depth_statis_lod0,
+ }
+
+ return losses
+
+ def val_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=True,
+ ):
+ # * only support batch_size==1
+        # ! attention: a list of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ img_index = sample['img_index'][0] # [n]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
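+        # convert the query image from a CHW float tensor in [0, 1] to an HWC uint8 array for OpenCV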
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
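+        # flatten the rays and split them into chunks so validation rendering stays within GPU memory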
+
+ # - obtain conditional features
+ with torch.no_grad():
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume()
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ out_rgb_mlp = []
+
+ if save_vis:
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+
+ # ****** lod 0 ****
+ render_out = self.sdf_renderer_finetune.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_finetune,
+ None,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=1.,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_general_rendering=False,
+ if_render_with_grad=False,
+ img_index=img_index,
+ # rays_uv=rays_ndc_uv
+ )
+
+ feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_fine'):
+ out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
+
+ if feasible('color_mlp'):
+ out_rgb_mlp.append(render_out['color_mlp'].detach().cpu().numpy())
+
+ if feasible('gradients') and feasible('weights'):
+ if render_out['inside_sphere'] is not None:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples + self.n_importance,
+ None] * render_out['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples + self.n_importance,
+ None]).sum(dim=1).detach().cpu().numpy())
+ del render_out
+
+ # - save visualization of lod 0
+
+ self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
+ query_w2c[0], out_rgb_fine, H, W,
+ depth_min, depth_max, iter_step, meta, "val_lod0",
+ out_color_mlp=out_rgb_mlp, true_depth=true_depth)
+
+ # - extract mesh
+ if (iter_step % self.val_mesh_freq == 0):
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_finetune,
+ self.sdf_renderer_finetune.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='val', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ torch.cuda.empty_cache()
+
+ def export_mesh_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=True,
+ ):
+ # * only support batch_size==1
+        # ! attention: lists of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ # meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+        meta = ''
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
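+        # intrinsics at 1/4 of the image resolution (focal lengths and principal point scaled by 0.25)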
+
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
+
+ # import ipdb; ipdb.set_trace()
+ # - obtain conditional features
+ with torch.no_grad():
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume()
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+
+ # - extract mesh
+
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_finetune,
+ self.sdf_renderer_finetune.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='val', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ torch.cuda.empty_cache()
+
+ def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
+ depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None):
+ if len(out_color) > 0:
+ img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_color_mlp) > 0:
+ img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_normal) > 0:
+ normal_img = np.concatenate(out_normal, axis=0)
+ rot = w2cs[:3, :3].detach().cpu().numpy()
+ # - convert normal from world space to camera space
+ normal_img = (np.matmul(rot[None, :, :],
+ normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)
+ if len(out_depth) > 0:
+ pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
+ pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]
+
+ if len(out_depth) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True)
+ if true_colored_depth is not None:
+
+ if true_depth is not None:
+ depth_error_map = np.abs(true_depth - pred_depth) * 5.0
+ depth_visualized = np.concatenate(
+ [depth_error_map, true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1]
+ else:
+ depth_visualized = np.concatenate(
+ [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized
+ )
+ else:
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [pred_depth_colored, true_img])[:, :, ::-1])
+ if len(out_color) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_fine, true_img])[:, :, ::-1]) # bgr2rgb
+
+ if len(out_color_mlp) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb
+
+ if len(out_normal) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ normal_img[:, :, ::-1])
+
+ def forward(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ iter_step=0,
+ mode='train',
+ save_vis=False,
+ ):
+
+ if mode == 'train':
+ return self.train_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ iter_step=iter_step,
+ )
+ elif mode == 'val':
+ return self.val_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ iter_step=iter_step, save_vis=save_vis,
+ )
+ elif mode == 'export_mesh':
+ return self.export_mesh_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ iter_step=iter_step, save_vis=save_vis,
+ )
+
+ def obtain_pyramid_feature_maps(self, imgs, lod=0):
+        """
+        Extract pyramid feature maps for all conditional images.
+        :param imgs: conditional images
+        :return: fused pyramid feature maps
+        """
+
+ if lod == 0:
+ extractor = self.pyramid_feature_network_geometry_lod0
+ elif lod >= 1:
+ extractor = self.pyramid_feature_network_geometry_lod1
+
+ pyramid_feature_maps = extractor(imgs)
+
+        # * the pyramid features are important; using only the coarsest level makes optimization hard
+ fused_feature_maps = torch.cat([
+ F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
+ F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
+ pyramid_feature_maps[2]
+ ], dim=1)
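+        # the two coarser pyramid levels are upsampled to the finest resolution so the three levels can be concatenated channel-wise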
+
+ return fused_feature_maps
+
+ def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1):
+
+ def get_weight(iter_step, weight):
+ if iter_step < 0:
+ return weight
+
+ if self.anneal_end == 0.0:
+ return weight
+ elif iter_step < self.anneal_start:
+ return 0.0
+ else:
+ return np.min(
+ [1.0,
+ (iter_step - self.anneal_start) / (self.anneal_end * 2 - self.anneal_start)]) * weight
+
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ true_rgb = sample_rays['rays_color'][0]
+
+ if 'rays_depth' in sample_rays.keys():
+ true_depth = sample_rays['rays_depth'][0]
+ else:
+ true_depth = None
+ mask = sample_rays['rays_mask'][0]
+
+ color_fine = render_out['color_fine']
+ color_fine_mask = render_out['color_fine_mask']
+ depth_pred = render_out['depth']
+
+ variance = render_out['variance']
+ cdf_fine = render_out['cdf_fine']
+ weight_sum = render_out['weights_sum']
+
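+        # when training from scratch, the occlusion-aware color loss is enabled only after the first 5k iterations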
+ if self.train_from_scratch:
+            occlusion_aware = iter_step >= 5000
+ else:
+ occlusion_aware = True
+
+ gradient_error_fine = render_out['gradient_error_fine']
+
+ sdf = render_out['sdf']
+
+ # * color generated by mlp
+ color_mlp = render_out['color_mlp']
+ color_mlp_mask = render_out['color_mlp_mask']
+
+ if color_mlp is not None:
+ # Color loss
+ color_mlp_mask = color_mlp_mask[..., 0]
+
+ color_mlp_loss, color_mlp_error = self.occlusion_color_criterion(pred=color_mlp, gt=true_rgb,
+ weight=weight_sum.squeeze(),
+ mask=color_mlp_mask,
+ occlusion_aware=occlusion_aware)
+
+ psnr_mlp = 20.0 * torch.log10(
+ 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt())
+ else:
+ color_mlp_loss = 0.
+ psnr_mlp = 0.
+
+ # - blended patch loss
+ blended_color_patch = render_out['blended_color_patch'] # [N_pts, Npx, 3]
+ blended_color_patch_mask = render_out['blended_color_patch_mask'] # [N_pts, 1]
+ color_patch_loss = 0.0
+ color_patch_error = 0.0
+ visibility_beta = 0.0
+ if blended_color_patch is not None:
+ rays_patch_color = sample_rays['rays_patch_color'][0]
+ rays_patch_mask = sample_rays['rays_patch_mask'][0]
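+            # a patch contributes to the loss only when both the sampled GT patch and the blended patch are valid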
+ patch_mask = (rays_patch_mask * blended_color_patch_mask).float()[:, 0] > 0 # [N_pts]
+
+ color_patch_loss, color_patch_error, visibility_beta = self.occlusion_color_patch_criterion(
+ blended_color_patch,
+ rays_patch_color,
+ weight=weight_sum.squeeze(),
+ mask=patch_mask,
+ penalize_ratio=self.visibility_penalize_ratio,
+ occlusion_aware=occlusion_aware
+ )
+
+ if true_depth is not None:
+ depth_loss = self.depth_criterion(depth_pred, true_depth, mask)
+
+ # depth evaluation
+ depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy(),
+ mask.cpu().numpy() > 0)
+ depth_statis = numpy2tensor(depth_statis, device=rays_o.device)
+ else:
+ depth_loss = 0.
+ depth_statis = None
+
+        # - without sparse_loss, the mean sdf is about 0.02
+        # - use sparse_loss to prevent occluded points from having near-zero sdf
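+        # - sdf_decay_param controls how quickly the exponential penalty below decays with |sdf|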
+ sparse_loss_1 = torch.exp(-1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param * 10).mean()
+ sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
+ sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2
+
+ sdf_mean = torch.abs(sdf).mean()
+ sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
+ sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()
+
+ # Eikonal loss
+ gradient_error_loss = gradient_error_fine
+
+ # * optional TV regularizer
+ if 'tv' in render_out.keys():
+ tv = render_out['tv']
+ else:
+ tv = 0.0
+
+ loss = color_mlp_loss + \
+ color_patch_loss * self.color_patch_weight + \
+ sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \
+ gradient_error_loss * self.sdf_igr_weight
+
+ losses = {
+ "loss": loss,
+ "depth_loss": depth_loss,
+ "color_mlp_loss": color_mlp_error,
+ "gradient_error_loss": gradient_error_loss,
+ "sparse_loss": sparse_loss,
+ "sparseness_1": sparseness_1,
+ "sparseness_2": sparseness_2,
+ "sdf_mean": sdf_mean,
+ "psnr_mlp": psnr_mlp,
+ "weights_sum": render_out['weights_sum'],
+ "alpha_sum": render_out['alpha_sum'],
+ "variance": render_out['variance'],
+ "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight),
+ 'color_patch_loss': color_patch_error,
+ 'visibility_beta': visibility_beta,
+ 'tv': tv,
+ }
+
+ losses = numpy2tensor(losses, device=rays_o.device)
+
+ return loss, losses, depth_statis
+
+ def validate_mesh(self, sdf_network, func_extract_geometry, world_space=True, resolution=256,
+ threshold=0.0, mode='val',
+ # * 3d feature volume
+ conditional_volume=None, lod=None, occupancy_mask=None,
+ bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
+ trans_mat=None
+ ):
+ bound_min = torch.tensor(bound_min, dtype=torch.float32)
+ bound_max = torch.tensor(bound_max, dtype=torch.float32)
+
+ vertices, triangles, fields = func_extract_geometry(
+ sdf_network,
+ bound_min, bound_max, resolution=resolution,
+ threshold=threshold, device=conditional_volume.device,
+ # * 3d feature volume
+ conditional_volume=conditional_volume, lod=lod,
+ # occupancy_mask=occupancy_mask
+ )
+
+
+
+ if scale_mat is not None:
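+            # scale_mat maps the vertices from the normalized unit volume back to the true scene scale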
+ scale_mat_np = scale_mat.cpu().numpy()
+ vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]
+
+ if trans_mat is not None:
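+            # trans_mat (the inverse reference w2c) moves the mesh from the reference-camera frame into world space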
+ trans_mat_np = trans_mat.cpu().numpy()
+ vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
+ vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]
+
+ mesh = trimesh.Trimesh(vertices, triangles)
+ os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True)
+ mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode,
+ 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
+
+ def gen_video(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ iter_step=0,
+ chunk_size=1024,
+ ):
+ # * only support batch_size==1
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] * 0.8
+
+ img_index = sample['img_index'][0] # [n]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+ rendering_c2ws = sample['rendering_c2ws'][0] # [n, 4, 4]
+ rendering_imgs_idx = sample['rendering_imgs_idx'][0]
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ # - obtain conditional features
+ with torch.no_grad():
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_finetune.get_conditional_volume()
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ # coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ inter_views_num = 60
+ resolution_level = 2
+ for r_idx in range(rendering_c2ws.shape[0] - 1):
+ for idx in range(inter_views_num):
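+                # interpolate a camera pose between the two key views using a sinusoidal ease-in/ease-out ratio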
+ query_c2w, rays_o, rays_d = gen_rays_between(
+ rendering_c2ws[r_idx], rendering_c2ws[r_idx + 1], intrinsics[0],
+ np.sin(((idx / 60.0) - 0.5) * np.pi) * 0.5 + 0.5,
+ H, W, resolution_level=resolution_level)
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+ # ****** lod 0 ****
+ render_out = self.sdf_renderer_finetune.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_finetune,
+ None,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=1.,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=None,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_general_rendering=False,
+ if_render_with_grad=False,
+ img_index=img_index,
+ # rays_uv=rays_ndc_uv
+ )
+ # pdb.set_trace()
+ feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_mlp'):
+ out_rgb_fine.append(render_out['color_mlp'].detach().cpu().numpy())
+ if feasible('gradients') and feasible('weights'):
+ if render_out['inside_sphere'] is not None:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples + self.n_importance,
+ None] * render_out['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples + self.n_importance,
+ None]).sum(dim=1).detach().cpu().numpy())
+ del render_out
+
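+                # stitch the per-chunk colors back into a single image at the reduced render resolution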
+ img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape(
+ [H // resolution_level, W // resolution_level, 3, -1]) * 256).clip(0, 255)
+ save_dir = os.path.join(self.base_exp_dir, 'render_{}_{}'.format(rendering_imgs_idx[r_idx],
+ rendering_imgs_idx[r_idx + 1]))
+ os.makedirs(save_dir, exist_ok=True)
+ # ic(img_fine.shape)
+ print(cv.imwrite(
+ os.path.join(save_dir, '{}.png'.format(idx + r_idx * inter_views_num)),
+ img_fine.squeeze()[:, :, ::-1]))
+ print(os.path.join(save_dir, '{}.png'.format(idx + r_idx * inter_views_num)))
diff --git a/SparseNeuS_demo_v1/models/trainer_generic.py b/SparseNeuS_demo_v1/models/trainer_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c87d61d5c7feb93dadd40099a5ebe0a9db81924
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/trainer_generic.py
@@ -0,0 +1,1224 @@
+"""
+decouple the trainer from the renderer
+"""
+import os
+import cv2 as cv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+
+from utils.misc_utils import visualize_depth_numpy
+
+from utils.training_utils import numpy2tensor
+from loss.depth_metric import compute_depth_errors
+
+from loss.depth_loss import DepthLoss, DepthSmoothLoss
+
+from models.rays import gen_rays_between
+
+from models.sparse_neus_renderer import SparseNeuSRenderer
+
+def safe_l2_normalize(x, dim=None, eps=1e-6):
+ return F.normalize(x, p=2, dim=dim, eps=eps)
+
+
+class GenericTrainer(nn.Module):
+ def __init__(self,
+ rendering_network_outside,
+ pyramid_feature_network_lod0,
+ pyramid_feature_network_lod1,
+ sdf_network_lod0,
+ sdf_network_lod1,
+ variance_network_lod0,
+ variance_network_lod1,
+ rendering_network_lod0,
+ rendering_network_lod1,
+ n_samples_lod0,
+ n_importance_lod0,
+ n_samples_lod1,
+ n_importance_lod1,
+ n_outside,
+ perturb,
+ alpha_type='div',
+ conf=None,
+ timestamp="",
+ mode='train',
+ base_exp_dir=None,
+ ):
+ super(GenericTrainer, self).__init__()
+
+ self.conf = conf
+ self.timestamp = timestamp
+
+
+ self.base_exp_dir = base_exp_dir
+
+
+ self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
+ self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
+ self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0)
+ self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0)
+
+ # network setups
+ self.rendering_network_outside = rendering_network_outside
+ self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry
+        self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1  # use different networks for the two lods
+
+        # when num_lods==2, this may consume too much memory
+ self.sdf_network_lod0 = sdf_network_lod0
+ self.sdf_network_lod1 = sdf_network_lod1
+
+        # - wrapped by ModuleList to support DataParallel
+ self.variance_network_lod0 = variance_network_lod0
+ self.variance_network_lod1 = variance_network_lod1
+
+ self.rendering_network_lod0 = rendering_network_lod0
+ self.rendering_network_lod1 = rendering_network_lod1
+
+ self.n_samples_lod0 = n_samples_lod0
+ self.n_importance_lod0 = n_importance_lod0
+ self.n_samples_lod1 = n_samples_lod1
+ self.n_importance_lod1 = n_importance_lod1
+ self.n_outside = n_outside
+ self.num_lods = conf.get_int('model.num_lods') # the number of octree lods
+ self.perturb = perturb
+ self.alpha_type = alpha_type
+
+ # - the two renderers
+ self.sdf_renderer_lod0 = SparseNeuSRenderer(
+ self.rendering_network_outside,
+ self.sdf_network_lod0,
+ self.variance_network_lod0,
+ self.rendering_network_lod0,
+ self.n_samples_lod0,
+ self.n_importance_lod0,
+ self.n_outside,
+ self.perturb,
+ alpha_type='div',
+ conf=self.conf)
+
+ self.sdf_renderer_lod1 = SparseNeuSRenderer(
+ self.rendering_network_outside,
+ self.sdf_network_lod1,
+ self.variance_network_lod1,
+ self.rendering_network_lod1,
+ self.n_samples_lod1,
+ self.n_importance_lod1,
+ self.n_outside,
+ self.perturb,
+ alpha_type='div',
+ conf=self.conf)
+
+ self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks')
+
+ # sdf network weights
+ self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
+ self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
+ self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
+ self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00)
+ self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0)
+
+ self.depth_criterion = DepthLoss()
+
+ # - DataParallel mode, cannot modify attributes in forward()
+ # self.iter_step = 0
+ self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
+
+ # - True for finetuning; False for general training
+ self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)
+
+ self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False)
+
+ def get_trainable_params(self):
+ # set trainable params
+
+ self.params_to_train = []
+
+ if not self.if_fix_lod0_networks:
+ # load pretrained featurenet
+ self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters())
+ self.params_to_train += list(self.sdf_network_lod0.parameters())
+ self.params_to_train += list(self.variance_network_lod0.parameters())
+
+ if self.rendering_network_lod0 is not None:
+ self.params_to_train += list(self.rendering_network_lod0.parameters())
+
+ if self.sdf_network_lod1 is not None:
+ # load pretrained featurenet
+ self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters())
+
+ self.params_to_train += list(self.sdf_network_lod1.parameters())
+ self.params_to_train += list(self.variance_network_lod1.parameters())
+ if self.rendering_network_lod1 is not None:
+ self.params_to_train += list(self.rendering_network_lod1.parameters())
+
+ return self.params_to_train
+
+ def train_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ ):
+ # * only support batch_size==1
+        # ! attention: lists of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:]
+
+ # the full-size ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+ scale_mat = sample['scale_mat']
+ trans_mat = sample['trans_mat']
+
+ # *********************** Lod==0 ***********************
+ if not self.if_fix_lod0_networks:
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)
+
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ else:
+ with torch.no_grad():
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+ # print("Checker2:, construct cost volume")
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ # import ipdb; ipdb.set_trace()
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ # * extract depth maps for all the images
+ depth_maps_lod0, depth_masks_lod0 = None, None
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ if self.prune_depth_filter:
+ depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
+ self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws,
+ sizeH // 4, sizeW // 4, near * 1.5, far)
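+                # depth maps are extracted at 1/4 resolution, then upsampled back to the full image size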
+ depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
+ align_corners=True)
+ depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
+
+ # *************** losses
+ loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None
+
+ if not self.if_fix_lod0_networks:
+
+ render_out = self.sdf_renderer_lod0.render(
+ rays_o, rays_d, near, far,
+ self.sdf_network_lod0,
+ self.rendering_network_lod0,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod0,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ )
+
+ loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays,
+ iter_step, lod=0)
+
+ # *********************** Lod==1 ***********************
+
+ loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+ # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
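+            # double the lod0 voxel indices so the sparse coordinates align with the finer (2x) lod1 grid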
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+            # ? when training gru_fusion, this part should probably be trainable too
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ # if not self.if_gru_fusion_lod1:
+ render_out_lod1 = self.sdf_renderer_lod1.render(
+ rays_o, rays_d, near, far,
+ self.sdf_network_lod1,
+ self.rendering_network_lod1,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod1,
+ # * related to conditional feature
+ lod=1,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod1,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps_lod1,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ bg_ratio=self.bg_ratio,
+ )
+ loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays,
+ iter_step, lod=1)
+
+ # print("Checker3:, compute losses")
+        # - extract mesh
+ if iter_step % self.val_mesh_freq == 0:
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_lod0,
+ self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='train_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat,
+ trans_mat=trans_mat)
+ torch.cuda.empty_cache()
+
+ if self.num_lods > 1:
+ self.validate_mesh(self.sdf_network_lod1,
+ self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1, lod=1,
+ # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
+ mode='train_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat,
+ trans_mat=trans_mat)
+ # import ipdb; ipdb.set_trace()
+ # print("Checker3.1:, after val mesh")
+ losses = {
+ # - lod 0
+ 'loss_lod0': loss_lod0,
+ 'losses_lod0': losses_lod0,
+ 'depth_statis_lod0': depth_statis_lod0,
+
+ # - lod 1
+ 'loss_lod1': loss_lod1,
+ 'losses_lod1': losses_lod1,
+ 'depth_statis_lod1': depth_statis_lod1,
+
+ }
+
+ return losses
+
+ def val_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=False,
+ ):
+ # * only support batch_size==1
+        # ! attention: lists of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+
+ # render_img_idx = sample['render_img_idx'][0]
+ # true_img = sample['images'][0][render_img_idx]
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ scale_factor = sample['scale_factor'][0].cpu().numpy()
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
+
+ # - obtain conditional features
+ with torch.no_grad():
+ # - obtain conditional features
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # import ipdb; ipdb.set_trace()
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ depth_maps_lod0, depth_masks_lod0 = None, None
+ if self.prune_depth_filter:
+ depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
+ self.sdf_network_lod0, sdf_volume_lod0,
+ intrinsics_l_4x, c2ws,
+                        sizeH // 4, sizeW // 4, near * 1.5, far)  # - near*1.5 is an empirically chosen value
+ depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
+ align_corners=True)
+ depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
+
+ #### visualize the depth_maps_lod0 for checking
+ colored_depth_maps_lod0 = []
+ for i in range(depth_maps_lod0.shape[0]):
+ colored_depth_maps_lod0.append(
+ visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0])
+
+ colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8)
+ os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0',
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ colored_depth_maps_lod0[:, :, ::-1])
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+ with torch.no_grad():
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ out_rgb_fine_lod1 = []
+ out_normal_fine_lod1 = []
+ out_depth_fine_lod1 = []
+
+ # out_depth_fine_explicit = []
+ if save_vis:
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+
+ # ****** lod 0 ****
+ render_out = self.sdf_renderer_lod0.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_lod0,
+ self.rendering_network_lod0,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod0,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_render_with_grad=False,
+ )
+
+ feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_fine'):
+ out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
+ if feasible('gradients') and feasible('weights'):
+ if render_out['inside_sphere'] is not None:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples_lod0 + self.n_importance_lod0,
+ None] * render_out['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples_lod0 + self.n_importance_lod0,
+ None]).sum(dim=1).detach().cpu().numpy())
+ del render_out
+
+ # ****************** lod 1 **************************
+ if self.num_lods > 1:
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+ render_out_lod1 = self.sdf_renderer_lod1.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_lod1,
+ self.rendering_network_lod1,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod1,
+ # * related to conditional feature
+ lod=1,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod1,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps_lod1,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_render_with_grad=False,
+ )
+
+ feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_fine'):
+ out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
+ if feasible('gradients') and feasible('weights'):
+ if render_out_lod1['inside_sphere'] is not None:
+ out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ :self.n_samples_lod1 + self.n_importance_lod1,
+ None] *
+ render_out_lod1['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ :self.n_samples_lod1 + self.n_importance_lod1,
+ None]).sum(
+ dim=1).detach().cpu().numpy())
+ del render_out_lod1
+
+ # - save visualization of lod 0
+
+ self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
+ query_w2c[0], out_rgb_fine, H, W,
+ depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor)
+
+ if self.num_lods > 1:
+ self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
+ query_w2c[0], out_rgb_fine_lod1, H, W,
+ depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor)
+
+ # - extract mesh
+ if (iter_step % self.val_mesh_freq == 0):
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_lod0,
+ self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+ torch.cuda.empty_cache()
+
+ if self.num_lods > 1:
+ self.validate_mesh(self.sdf_network_lod1,
+ self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1, lod=1,
+ # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ torch.cuda.empty_cache()
+
+
+
+ def export_mesh_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=False,
+ ):
+ # * only support batch_size==1
+        # ! attention: lists of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ # target_candidate_w2cs = sample['target_candidate_w2cs'][0]
+ proj_matrices = sample['affine_mats']
+
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ scale_factor = sample['scale_factor'][0].cpu().numpy()
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
+ # import time
+ # jha_begin1 = time.time()
+ # - obtain conditional features
+ with torch.no_grad():
+ # - obtain conditional features
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+ # jha_end1 = time.time()
+ # print("get_conditional_volume: ", jha_end1 - jha_begin1)
+ # jha_begin2 = time.time()
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ depth_maps_lod0, depth_masks_lod0 = None, None
+
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+ with torch.no_grad():
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ out_rgb_fine_lod1 = []
+ out_normal_fine_lod1 = []
+ out_depth_fine_lod1 = []
+
+ # jha_end2 = time.time()
+ # print("interval before starting mesh export: ", jha_end2 - jha_begin2)
+
+ # - extract mesh
+ if (iter_step % self.val_mesh_freq == 0):
+ torch.cuda.empty_cache()
+ # jha_begin3 = time.time()
+ self.validate_colored_mesh(
+ density_or_sdf_network=self.sdf_network_lod0,
+ func_extract_geometry=self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume = con_valid_mask_volume_lod0,
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ target_candidate_w2cs=None,
+ intrinsics=intrinsics,
+ rendering_network=self.rendering_network_lod0,
+ rendering_projector=self.sdf_renderer_lod0.rendering_projector,
+ lod=0,
+ threshold=0,
+ query_c2w=query_c2w,
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
+ )
+ torch.cuda.empty_cache()
+ # jha_end3 = time.time()
+ # print("validate_colored_mesh_test_time: ", jha_end3 - jha_begin3)
+
+ if self.num_lods > 1:
+ self.validate_colored_mesh(
+ density_or_sdf_network=self.sdf_network_lod1,
+ func_extract_geometry=self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume = con_valid_mask_volume_lod1,
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ target_candidate_w2cs=None,
+ intrinsics=intrinsics,
+ rendering_network=self.rendering_network_lod1,
+ rendering_projector=self.sdf_renderer_lod1.rendering_projector,
+ lod=1,
+ threshold=0,
+ query_c2w=query_c2w,
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
+ )
+ torch.cuda.empty_cache()
+
+
+
+ def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
+ depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0):
+ if len(out_color) > 0:
+ img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_color_mlp) > 0:
+ img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_normal) > 0:
+ normal_img = np.concatenate(out_normal, axis=0)
+ rot = w2cs[:3, :3].detach().cpu().numpy()
+ # - convert normal from world space to camera space
+ normal_img = (np.matmul(rot[None, :, :],
+ normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)
+ if len(out_depth) > 0:
+ pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
+ pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]
+
+ if len(out_depth) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True)
+ if true_colored_depth is not None:
+
+ if true_depth is not None:
+ depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor
+ # [256, 256, 1] -> [256, 256, 3]
+ depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3])
+ print("meta: ", meta)
+ print("scale_factor: ", scale_factor)
+ print("depth_error_mean: ", depth_error_map.mean())
+ # import ipdb; ipdb.set_trace()
+ depth_visualized = np.concatenate(
+ [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1]
+ # print("depth_visualized.shape: ", depth_visualized.shape)
+ # write depth error result text on img, the input is a numpy array of [256, 1024, 3]
+ # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+ else:
+ depth_visualized = np.concatenate(
+ [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized
+ )
+ else:
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [pred_depth_colored, true_img])[:, :, ::-1])
+ if len(out_color) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_fine, true_img])[:, :, ::-1]) # bgr2rgb
+ # compute psnr (image pixel lie in [0, 255])
+ mse_loss = np.mean((img_fine - true_img) ** 2)
+ psnr = 10 * np.log10(255 ** 2 / mse_loss)
+
+ print("PSNR: ", psnr)
+
+ if len(out_color_mlp) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb
+
+ if len(out_normal) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ normal_img[:, :, ::-1])
+
+ def forward(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ mode='train',
+ save_vis=False,
+ ):
+
+ if mode == 'train':
+ return self.train_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step
+ )
+ elif mode == 'val':
+ import time
+ begin = time.time()
+ result = self.val_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step,
+ save_vis=save_vis,
+ )
+ end = time.time()
+ print("val_step time: ", end - begin)
+ return result
+ elif mode == 'export_mesh':
+ import time
+ begin = time.time()
+ result = self.export_mesh_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step,
+ save_vis=save_vis,
+ )
+ end = time.time()
+ print("export mesh time: ", end - begin)
+ return result
+ def obtain_pyramid_feature_maps(self, imgs, lod=0):
+        """
+        Extract pyramid feature maps for all conditional images.
+        :param imgs: conditional images
+        :return: fused pyramid feature maps
+        """
+
+ if lod == 0:
+ extractor = self.pyramid_feature_network_geometry_lod0
+ elif lod >= 1:
+ extractor = self.pyramid_feature_network_geometry_lod1
+
+ pyramid_feature_maps = extractor(imgs)
+
+        # * the pyramid features are important; using only the coarsest level makes optimization hard
+ fused_feature_maps = torch.cat([
+ F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
+ F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
+ pyramid_feature_maps[2]
+ ], dim=1)
+
+ return fused_feature_maps
+
+ def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0):
+
+        # loss weight schedule; the regularization terms should only be enabled in the later training stages
+ def get_weight(iter_step, weight):
+            if lod == 1:
+                anneal_start = self.anneal_end_lod1
+                anneal_end = self.anneal_end_lod1 * 2
+            else:
+                anneal_start = self.anneal_start
+                anneal_end = self.anneal_end * 2
+
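+            # the weight ramps linearly from 0 at anneal_start to its full value at anneal_end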
+ if iter_step < 0:
+ return weight
+
+ if anneal_end == 0.0:
+ return weight
+ elif iter_step < anneal_start:
+ return 0.0
+ else:
+ return np.min(
+ [1.0,
+ (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight
+
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ true_rgb = sample_rays['rays_color'][0]
+
+ if 'rays_depth' in sample_rays.keys():
+ true_depth = sample_rays['rays_depth'][0]
+ else:
+ true_depth = None
+ mask = sample_rays['rays_mask'][0]
+
+ color_fine = render_out['color_fine']
+ color_fine_mask = render_out['color_fine_mask']
+ depth_pred = render_out['depth']
+
+ variance = render_out['variance']
+ cdf_fine = render_out['cdf_fine']
+ weight_sum = render_out['weights_sum']
+
+ gradient_error_fine = render_out['gradient_error_fine']
+
+ sdf = render_out['sdf']
+
+ # * color generated by mlp
+ color_mlp = render_out['color_mlp']
+ color_mlp_mask = render_out['color_mlp_mask']
+
+ if color_fine is not None:
+ # Color loss
+ color_mask = color_fine_mask if color_fine_mask is not None else mask
+ # import ipdb; ipdb.set_trace()
+ color_mask = color_mask[..., 0]
+ color_error = (color_fine[color_mask] - true_rgb[color_mask])
+ # print("Nan number", torch.isnan(color_error).sum())
+ # print("Color error shape", color_error.shape)
+ # import ipdb; ipdb.set_trace()
+ color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device),
+ reduction='mean')
+ # print(color_fine_loss)
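+            # PSNR computed in the [0, 1] color range over the masked pixels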
+ psnr = 20.0 * torch.log10(
+ 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt())
+ else:
+ color_fine_loss = 0.
+ psnr = 0.
+
+ if color_mlp is not None:
+ # Color loss
+ color_mlp_mask = color_mlp_mask[..., 0]
+ color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask])
+ color_mlp_loss = F.l1_loss(color_error_mlp,
+ torch.zeros_like(color_error_mlp).to(color_error_mlp.device),
+ reduction='mean')
+
+ psnr_mlp = 20.0 * torch.log10(
+ 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt())
+ else:
+ color_mlp_loss = 0.
+ psnr_mlp = 0.
+
+        # depth loss, computed only when ground-truth ray depth is available (added to the total loss below)
+ if true_depth is not None:
+ # depth_loss = self.depth_criterion(depth_pred, true_depth, mask)
+ depth_loss = self.depth_criterion(depth_pred, true_depth)
+
+ # # depth evaluation
+ # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy())
+ # depth_statis = numpy2tensor(depth_statis, device=rays_o.device)
+ depth_statis = None
+ else:
+ depth_loss = 0.
+ depth_statis = None
+
+ sparse_loss_1 = torch.exp(
+ -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal
+ sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
+ sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2
+
+ sdf_mean = torch.abs(sdf).mean()
+ sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
+ sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()
+
+ # Eikonal loss
+ gradient_error_loss = gradient_error_fine
+
+ # ! for the first 50k iterations, do not apply the background constraint
+ fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight)
+
+ # Mask loss, optional
+ # The images of the DTU dataset contain large black (zero-rgb) background regions,
+ # which can be used as a prior to keep the foreground clean
+ background_loss = 0.0
+ fg_bg_loss = 0.0
+ if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02:
+ weights_sum_fg = render_out['weights_sum_fg']
+ fg_bg_error = (weights_sum_fg - mask)[mask < 0.5]
+ fg_bg_loss = F.l1_loss(fg_bg_error,
+ torch.zeros_like(fg_bg_error).to(fg_bg_error.device),
+ reduction='mean')
+
+
+
+ loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \
+ sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \
+ fg_bg_loss * fg_bg_weight + \
+ gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss needs a mask
+
+ losses = {
+ "loss": loss,
+ "depth_loss": depth_loss,
+ "color_fine_loss": color_fine_loss,
+ "color_mlp_loss": color_mlp_loss,
+ "gradient_error_loss": gradient_error_loss,
+ "background_loss": background_loss,
+ "sparse_loss": sparse_loss,
+ "sparseness_1": sparseness_1,
+ "sparseness_2": sparseness_2,
+ "sdf_mean": sdf_mean,
+ "psnr": psnr,
+ "psnr_mlp": psnr_mlp,
+ "weights_sum": render_out['weights_sum'],
+ "weights_sum_fg": render_out['weights_sum_fg'],
+ "alpha_sum": render_out['alpha_sum'],
+ "variance": render_out['variance'],
+ "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight),
+ "fg_bg_weight": fg_bg_weight,
+ "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS
+ }
+ # print("[TEST]: weights_sum in trainner forward", losses['weights_sum'].mean())
+ losses = numpy2tensor(losses, device=rays_o.device)
+ return loss, losses, depth_statis
+
+ @torch.no_grad()
+ def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
+ threshold=0.0, mode='val',
+ # * 3d feature volume
+ conditional_volume=None, lod=None, occupancy_mask=None,
+ bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
+ trans_mat=None
+ ):
+
+ bound_min = torch.tensor(bound_min, dtype=torch.float32)
+ bound_max = torch.tensor(bound_max, dtype=torch.float32)
+
+ vertices, triangles, fields = func_extract_geometry(
+ density_or_sdf_network,
+ bound_min, bound_max, resolution=resolution,
+ threshold=threshold, device=conditional_volume.device,
+ # * 3d feature volume
+ conditional_volume=conditional_volume, lod=lod,
+ occupancy_mask=occupancy_mask
+ )
+
+
+ if scale_mat is not None:
+ scale_mat_np = scale_mat.cpu().numpy()
+ vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]
+
+ if trans_mat is not None: # w2c_ref_inv
+ trans_mat_np = trans_mat.cpu().numpy()
+ vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
+ vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]
+
+ mesh = trimesh.Trimesh(vertices, triangles)
+ os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True)
+ mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode,
+ 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
+
+
+
+ def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
+ threshold=0.0, mode='val',
+ # * 3d feature volume
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ feature_maps=None,
+ color_maps = None,
+ w2cs=None,
+ target_candidate_w2cs=None,
+ intrinsics=None,
+ rendering_network=None,
+ rendering_projector=None,
+ query_c2w=None,
+ lod=None, occupancy_mask=None,
+ bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
+ trans_mat=None
+ ):
+
+ bound_min = torch.tensor(bound_min, dtype=torch.float32)
+ bound_max = torch.tensor(bound_max, dtype=torch.float32)
+ # import time
+ # jha_begin4 = time.time()
+ vertices, triangles, fields = func_extract_geometry(
+ density_or_sdf_network,
+ bound_min, bound_max, resolution=resolution,
+ threshold=threshold, device=conditional_volume.device,
+ # * 3d feature volume
+ conditional_volume=conditional_volume, lod=lod,
+ occupancy_mask=occupancy_mask
+ )
+ # jha_end4 = time.time()
+ # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4)
+
+ # import time
+ # jha_begin5 = time.time()
+
+
+ with torch.no_grad():
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent(
+ torch.tensor(vertices).to(conditional_volume),
+ lod=lod, # JHA EDITED
+ # * 3d geometry feature volumes
+ geometryVolume=conditional_volume[0],
+ geometryVolumeMask=conditional_valid_mask_volume[0],
+ sdf_network=density_or_sdf_network,
+ # * 2d rendering feature maps
+ rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256]
+ color_maps=color_maps,
+ w2cs=w2cs,
+ target_candidate_w2cs=target_candidate_w2cs,
+ intrinsics=intrinsics,
+ img_wh=[256,256],
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=query_c2w,
+ )
+
+
+ vertices_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+
+ # jha_end5 = time.time()
+ # print("[TEST]: rendering_network time", jha_end5 - jha_begin5)
+
+ if scale_mat is not None:
+ scale_mat_np = scale_mat.cpu().numpy()
+ vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]
+
+ if trans_mat is not None: # w2c_ref_inv
+ trans_mat_np = trans_mat.cpu().numpy()
+ vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
+ vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]
+ # import ipdb; ipdb.set_trace()
+ vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8)
+ mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color)
+ os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True)
+ # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
+ # 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
+ # MODIFIED
+ mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
+ 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod)))
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py b/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a75f2c7fcaf613e1a4c5deeb9a8be15abd96d8d
--- /dev/null
+++ b/SparseNeuS_demo_v1/models/trainer_generic_normals_new.py
@@ -0,0 +1,1313 @@
+"""
+decouple the trainer from the renderer
+"""
+import os
+import cv2 as cv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+import logging
+import mcubes
+import trimesh
+from icecream import ic
+
+from utils.misc_utils import visualize_depth_numpy
+
+from utils.training_utils import numpy2tensor
+from loss.depth_metric import compute_depth_errors
+
+from loss.depth_loss import DepthLoss, DepthSmoothLoss
+
+from models.rays import gen_rays_between
+
+from models.sparse_neus_renderer_normals_new import SparseNeuSRenderer
+
+def safe_l2_normalize(x, dim=None, eps=1e-6):
+ return F.normalize(x, p=2, dim=dim, eps=eps)
+
+
+class GenericTrainer(nn.Module):
+ def __init__(self,
+ rendering_network_outside,
+ pyramid_feature_network_lod0,
+ pyramid_feature_network_lod1,
+ sdf_network_lod0,
+ sdf_network_lod1,
+ variance_network_lod0,
+ variance_network_lod1,
+ rendering_network_lod0,
+ rendering_network_lod1,
+ n_samples_lod0,
+ n_importance_lod0,
+ n_samples_lod1,
+ n_importance_lod1,
+ n_outside,
+ perturb,
+ alpha_type='div',
+ conf=None,
+ timestamp="",
+ mode='train',
+ base_exp_dir=None,
+ ):
+ super(GenericTrainer, self).__init__()
+
+ self.conf = conf
+ self.timestamp = timestamp
+
+
+ self.base_exp_dir = base_exp_dir
+
+
+ self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
+ self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
+ self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0)
+ self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0)
+
+ # network setups
+ self.rendering_network_outside = rendering_network_outside
+ self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry
+ self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use different networks for the two LODs
+
+ # when num_lods==2, this may consume too much memory
+ self.sdf_network_lod0 = sdf_network_lod0
+ self.sdf_network_lod1 = sdf_network_lod1
+
+ # - wrapped by ModuleList to support DataParallel
+ self.variance_network_lod0 = variance_network_lod0
+ self.variance_network_lod1 = variance_network_lod1
+
+ self.rendering_network_lod0 = rendering_network_lod0
+ self.rendering_network_lod1 = rendering_network_lod1
+
+ self.n_samples_lod0 = n_samples_lod0
+ self.n_importance_lod0 = n_importance_lod0
+ self.n_samples_lod1 = n_samples_lod1
+ self.n_importance_lod1 = n_importance_lod1
+ self.n_outside = n_outside
+ self.num_lods = conf.get_int('model.num_lods') # the number of octree lods
+ self.perturb = perturb
+ self.alpha_type = alpha_type
+
+ # - the two renderers
+ self.sdf_renderer_lod0 = SparseNeuSRenderer(
+ self.rendering_network_outside,
+ self.sdf_network_lod0,
+ self.variance_network_lod0,
+ self.rendering_network_lod0,
+ self.n_samples_lod0,
+ self.n_importance_lod0,
+ self.n_outside,
+ self.perturb,
+ alpha_type='div',
+ conf=self.conf)
+
+ self.sdf_renderer_lod1 = SparseNeuSRenderer(
+ self.rendering_network_outside,
+ self.sdf_network_lod1,
+ self.variance_network_lod1,
+ self.rendering_network_lod1,
+ self.n_samples_lod1,
+ self.n_importance_lod1,
+ self.n_outside,
+ self.perturb,
+ alpha_type='div',
+ conf=self.conf)
+
+ self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks')
+
+ # sdf network weights
+ self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
+ self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
+ self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
+ self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00)
+ self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0)
+
+ self.depth_criterion = DepthLoss()
+
+ # - DataParallel mode, cannot modify attributes in forward()
+ # self.iter_step = 0
+ self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
+
+ # - True for finetuning; False for general training
+ self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)
+
+ self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False)
+
+ def get_trainable_params(self):
+ # set trainable params
+
+ self.params_to_train = []
+
+ if not self.if_fix_lod0_networks:
+ # load pretrained featurenet
+ self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters())
+ self.params_to_train += list(self.sdf_network_lod0.parameters())
+ self.params_to_train += list(self.variance_network_lod0.parameters())
+
+ if self.rendering_network_lod0 is not None:
+ self.params_to_train += list(self.rendering_network_lod0.parameters())
+
+ if self.sdf_network_lod1 is not None:
+ # load pretrained featurenet
+ self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters())
+
+ self.params_to_train += list(self.sdf_network_lod1.parameters())
+ self.params_to_train += list(self.variance_network_lod1.parameters())
+ if self.rendering_network_lod1 is not None:
+ self.params_to_train += list(self.rendering_network_lod1.parameters())
+
+ return self.params_to_train
+
+ def train_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ ):
+ # * only support batch_size==1
+ # ! attention: a list of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:]
+
+ # the full-size ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
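+ # the first two rows of the intrinsics (fx, cx / fy, cy) are scaled by 1/4 so they match the
+ # quarter-resolution (sizeH // 4, sizeW // 4) depth maps extracted below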
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+ scale_mat = sample['scale_mat']
+ trans_mat = sample['trans_mat']
+
+ # *********************** Lod==0 ***********************
+ if not self.if_fix_lod0_networks:
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)
+
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ else:
+ with torch.no_grad():
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+ # print("Checker2:, construct cost volume")
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ # import ipdb; ipdb.set_trace()
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ # * extract depth maps for all the images
+ depth_maps_lod0, depth_masks_lod0 = None, None
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ if self.prune_depth_filter:
+ depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
+ self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws,
+ sizeH // 4, sizeW // 4, near * 1.5, far)
+ depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
+ align_corners=True)
+ depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
+
+ # *************** losses
+ loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None
+
+ if not self.if_fix_lod0_networks:
+
+ render_out = self.sdf_renderer_lod0.render(
+ rays_o, rays_d, near, far,
+ self.sdf_network_lod0,
+ self.rendering_network_lod0,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod0,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ if_general_rendering=True,
+ if_render_with_grad=True,
+ )
+
+ loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays,
+ iter_step, lod=0)
+
+ # *********************** Lod==1 ***********************
+
+ loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+ # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
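+ # the lod0 sparse voxel indices are doubled so they address the finer lod1 grid
+ # (each lod0 cell presumably maps to a 2x2x2 block of lod1 cells)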
+
+ # ? it seems that when training gru_fusion, this part should be trainable too
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices[:,1:],
+ # proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ # if not self.if_gru_fusion_lod1:
+ render_out_lod1 = self.sdf_renderer_lod1.render(
+ rays_o, rays_d, near, far,
+ self.sdf_network_lod1,
+ self.rendering_network_lod1,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod1,
+ # * related to conditional feature
+ lod=1,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod1,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps_lod1,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ bg_ratio=self.bg_ratio,
+ )
+ loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays,
+ iter_step, lod=1)
+
+ # print("Checker3:, compute losses")
+ # # - extract mesh
+ if iter_step % self.val_mesh_freq == 0:
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_lod0,
+ self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='train_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat,
+ trans_mat=trans_mat)
+ torch.cuda.empty_cache()
+
+ if self.num_lods > 1:
+ self.validate_mesh(self.sdf_network_lod1,
+ self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1, lod=1,
+ # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
+ mode='train_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat,
+ trans_mat=trans_mat)
+ # import ipdb; ipdb.set_trace()
+ # print("Checker3.1:, after val mesh")
+ losses = {
+ # - lod 0
+ 'loss_lod0': loss_lod0,
+ 'losses_lod0': losses_lod0,
+ 'depth_statis_lod0': depth_statis_lod0,
+
+ # - lod 1
+ 'loss_lod1': loss_lod1,
+ 'losses_lod1': losses_lod1,
+ 'depth_statis_lod1': depth_statis_lod1,
+
+ }
+
+ return losses
+
+ def val_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=False,
+ ):
+ # * only support batch_size==1
+ # ! attention: a list of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ proj_matrices = sample['affine_mats']
+
+ # render_img_idx = sample['render_img_idx'][0]
+ # true_img = sample['images'][0][render_img_idx]
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ scale_factor = sample['scale_factor'][0].cpu().numpy()
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
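+ # the query rays are processed in chunks of chunk_size to keep memory bounded during validation rendering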
+
+ # - obtain conditional features
+ with torch.no_grad():
+ # - obtain conditional features
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # import ipdb; ipdb.set_trace()
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ depth_maps_lod0, depth_masks_lod0 = None, None
+ if self.prune_depth_filter:
+ depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
+ self.sdf_network_lod0, sdf_volume_lod0,
+ intrinsics_l_4x, c2ws,
+ sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is an empirically chosen value
+ depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
+ align_corners=True)
+ depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')
+
+ #### visualize the depth_maps_lod0 for checking
+ colored_depth_maps_lod0 = []
+ for i in range(depth_maps_lod0.shape[0]):
+ colored_depth_maps_lod0.append(
+ visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0])
+
+ colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8)
+ os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0',
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ colored_depth_maps_lod0[:, :, ::-1])
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+ with torch.no_grad():
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ out_rgb_fine_lod1 = []
+ out_normal_fine_lod1 = []
+ out_depth_fine_lod1 = []
+
+ # out_depth_fine_explicit = []
+ if save_vis:
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+
+ # ****** lod 0 ****
+ render_out = self.sdf_renderer_lod0.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_lod0,
+ self.rendering_network_lod0,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod0,
+ # * related to conditional feature
+ lod=0,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_render_with_grad=False,
+ )
+
+ feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_fine'):
+ out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
+ if feasible('gradients') and feasible('weights'):
+ if render_out['inside_sphere'] is not None:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples_lod0 + self.n_importance_lod0,
+ None] * render_out['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ :self.n_samples_lod0 + self.n_importance_lod0,
+ None]).sum(dim=1).detach().cpu().numpy())
+ del render_out
+
+ # ****************** lod 1 **************************
+ if self.num_lods > 1:
+ for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+ render_out_lod1 = self.sdf_renderer_lod1.render(
+ rays_o_batch, rays_d_batch, near, far,
+ self.sdf_network_lod1,
+ self.rendering_network_lod1,
+ background_rgb=background_rgb,
+ alpha_inter_ratio=alpha_inter_ratio_lod1,
+ # * related to conditional feature
+ lod=1,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume=con_valid_mask_volume_lod1,
+ # * 2d feature maps
+ feature_maps=geometry_feature_maps_lod1,
+ color_maps=imgs,
+ w2cs=w2cs,
+ intrinsics=intrinsics,
+ img_wh=[sizeW, sizeH],
+ query_c2w=query_c2w,
+ if_render_with_grad=False,
+ )
+
+ feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))
+
+ if feasible('depth'):
+ out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())
+
+ # if render_out['color_coarse'] is not None:
+ if feasible('color_fine'):
+ out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
+ if feasible('gradients') and feasible('weights'):
+ if render_out_lod1['inside_sphere'] is not None:
+ out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ :self.n_samples_lod1 + self.n_importance_lod1,
+ None] *
+ render_out_lod1['inside_sphere'][
+ ..., None]).sum(dim=1).detach().cpu().numpy())
+ else:
+ out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ :self.n_samples_lod1 + self.n_importance_lod1,
+ None]).sum(
+ dim=1).detach().cpu().numpy())
+ del render_out_lod1
+
+ # - save visualization of lod 0
+
+ self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
+ query_w2c[0], out_rgb_fine, H, W,
+ depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor)
+
+ if self.num_lods > 1:
+ self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
+ query_w2c[0], out_rgb_fine_lod1, H, W,
+ depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor)
+
+ # - extract mesh
+ if (iter_step % self.val_mesh_freq == 0):
+ torch.cuda.empty_cache()
+ self.validate_mesh(self.sdf_network_lod0,
+ self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0, lod=0,
+ threshold=0,
+ # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+ torch.cuda.empty_cache()
+
+ if self.num_lods > 1:
+ self.validate_mesh(self.sdf_network_lod1,
+ self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1, lod=1,
+ # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ torch.cuda.empty_cache()
+
+
+
+ def export_mesh_step(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ chunk_size=512,
+ save_vis=False,
+ ):
+ # * only support batch_size==1
+ # ! attention: a list of strings cannot be split by DataParallel
+ batch_idx = sample['batch_idx'][0]
+ meta = sample['meta'][batch_idx] # the scan lighting ref_view info
+
+ sizeW = sample['img_wh'][0][0]
+ sizeH = sample['img_wh'][0][1]
+ H, W = sizeH, sizeW
+
+ partial_vol_origin = sample['partial_vol_origin'] # [B, 3]
+ near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]
+
+ # the ray variables
+ sample_rays = sample['rays']
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ rays_ndc_uv = sample_rays['rays_ndc_uv'][0]
+
+ imgs = sample['images'][0]
+ intrinsics = sample['intrinsics'][0]
+ intrinsics_l_4x = intrinsics.clone()
+ intrinsics_l_4x[:, :2] *= 0.25
+ w2cs = sample['w2cs'][0]
+ c2ws = sample['c2ws'][0]
+ # target_candidate_w2cs = sample['target_candidate_w2cs'][0]
+ proj_matrices = sample['affine_mats']
+
+
+ # - the image to render
+ scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale
+ trans_mat = sample['trans_mat']
+ query_c2w = sample['query_c2w'] # [1,4,4]
+ query_w2c = sample['query_w2c'] # [1,4,4]
+ true_img = sample['query_image'][0]
+ true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)
+
+ depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()
+
+ scale_factor = sample['scale_factor'][0].cpu().numpy()
+ true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
+ if true_depth is not None:
+ true_depth = true_depth[0].cpu().numpy()
+ true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
+ else:
+ true_depth_colored = None
+
+ rays_o = rays_o.reshape(-1, 3).split(chunk_size)
+ rays_d = rays_d.reshape(-1, 3).split(chunk_size)
+
+ # - obtain conditional features
+ with torch.no_grad():
+ # - obtain conditional features
+ geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
+ # - lod 0
+ conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
+ feature_maps=geometry_feature_maps[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ lod=0,
+ )
+
+ con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
+ con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
+ coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ]
+
+ if self.num_lods > 1:
+ sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
+ con_volume_lod0, con_valid_mask_volume_lod0,
+ coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ]
+
+ depth_maps_lod0, depth_masks_lod0 = None, None
+
+
+ if self.num_lods > 1:
+ geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
+
+ if self.prune_depth_filter:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
+ depth_maps_lod0, proj_matrices[0],
+ partial_vol_origin, self.sdf_network_lod0.voxel_size,
+ near, far, self.sdf_network_lod0.voxel_size, 12)
+ else:
+ pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
+ sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])
+
+ pre_coords[:, 1:] = pre_coords[:, 1:] * 2
+
+ with torch.no_grad():
+ conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
+ feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
+ partial_vol_origin=partial_vol_origin,
+ proj_mats=proj_matrices,
+ sizeH=sizeH,
+ sizeW=sizeW,
+ pre_coords=pre_coords,
+ pre_feats=pre_feats,
+ )
+
+ con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
+ con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']
+
+ out_rgb_fine = []
+ out_normal_fine = []
+ out_depth_fine = []
+
+ out_rgb_fine_lod1 = []
+ out_normal_fine_lod1 = []
+ out_depth_fine_lod1 = []
+
+ # # out_depth_fine_explicit = []
+ # if save_vis:
+ # for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+
+ # # ****** lod 0 ****
+ # render_out = self.sdf_renderer_lod0.render(
+ # rays_o_batch, rays_d_batch, near, far,
+ # self.sdf_network_lod0,
+ # self.rendering_network_lod0,
+ # background_rgb=background_rgb,
+ # alpha_inter_ratio=alpha_inter_ratio_lod0,
+ # # * related to conditional feature
+ # lod=0,
+ # conditional_volume=con_volume_lod0,
+ # conditional_valid_mask_volume=con_valid_mask_volume_lod0,
+ # # * 2d feature maps
+ # feature_maps=geometry_feature_maps,
+ # color_maps=imgs,
+ # w2cs=w2cs,
+ # intrinsics=intrinsics,
+ # img_wh=[sizeW, sizeH],
+ # query_c2w=query_c2w,
+ # if_render_with_grad=False,
+ # )
+
+ # feasible = lambda key: ((key in render_out) and (render_out[key] is not None))
+
+ # if feasible('depth'):
+ # out_depth_fine.append(render_out['depth'].detach().cpu().numpy())
+
+ # # if render_out['color_coarse'] is not None:
+ # if feasible('color_fine'):
+ # out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
+ # if feasible('gradients') and feasible('weights'):
+ # if render_out['inside_sphere'] is not None:
+ # out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ # :self.n_samples_lod0 + self.n_importance_lod0,
+ # None] * render_out['inside_sphere'][
+ # ..., None]).sum(dim=1).detach().cpu().numpy())
+ # else:
+ # out_normal_fine.append((render_out['gradients'] * render_out['weights'][:,
+ # :self.n_samples_lod0 + self.n_importance_lod0,
+ # None]).sum(dim=1).detach().cpu().numpy())
+ # del render_out
+
+ # # ****************** lod 1 **************************
+ # if self.num_lods > 1:
+ # for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
+ # render_out_lod1 = self.sdf_renderer_lod1.render(
+ # rays_o_batch, rays_d_batch, near, far,
+ # self.sdf_network_lod1,
+ # self.rendering_network_lod1,
+ # background_rgb=background_rgb,
+ # alpha_inter_ratio=alpha_inter_ratio_lod1,
+ # # * related to conditional feature
+ # lod=1,
+ # conditional_volume=con_volume_lod1,
+ # conditional_valid_mask_volume=con_valid_mask_volume_lod1,
+ # # * 2d feature maps
+ # feature_maps=geometry_feature_maps_lod1,
+ # color_maps=imgs,
+ # w2cs=w2cs,
+ # intrinsics=intrinsics,
+ # img_wh=[sizeW, sizeH],
+ # query_c2w=query_c2w,
+ # if_render_with_grad=False,
+ # )
+
+ # feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))
+
+ # if feasible('depth'):
+ # out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())
+
+ # # if render_out['color_coarse'] is not None:
+ # if feasible('color_fine'):
+ # out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
+ # if feasible('gradients') and feasible('weights'):
+ # if render_out_lod1['inside_sphere'] is not None:
+ # out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ # :self.n_samples_lod1 + self.n_importance_lod1,
+ # None] *
+ # render_out_lod1['inside_sphere'][
+ # ..., None]).sum(dim=1).detach().cpu().numpy())
+ # else:
+ # out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:,
+ # :self.n_samples_lod1 + self.n_importance_lod1,
+ # None]).sum(
+ # dim=1).detach().cpu().numpy())
+ # del render_out_lod1
+
+ # # - save visualization of lod 0
+
+ # self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
+ # query_w2c[0], out_rgb_fine, H, W,
+ # depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor)
+
+ # if self.num_lods > 1:
+ # self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
+ # query_w2c[0], out_rgb_fine_lod1, H, W,
+ # depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor)
+
+ # - extract mesh
+ if (iter_step % self.val_mesh_freq == 0):
+ torch.cuda.empty_cache()
+ self.validate_colored_mesh(
+ density_or_sdf_network=self.sdf_network_lod0,
+ func_extract_geometry=self.sdf_renderer_lod0.extract_geometry,
+ conditional_volume=con_volume_lod0,
+ conditional_valid_mask_volume = con_valid_mask_volume_lod0,
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ target_candidate_w2cs=None,
+ intrinsics=intrinsics,
+ rendering_network=self.rendering_network_lod0,
+ rendering_projector=self.sdf_renderer_lod0.rendering_projector,
+ lod=0,
+ threshold=0,
+ query_c2w=query_c2w,
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
+ )
+ torch.cuda.empty_cache()
+
+ if self.num_lods > 1:
+ self.validate_colored_mesh(
+ density_or_sdf_network=self.sdf_network_lod1,
+ func_extract_geometry=self.sdf_renderer_lod1.extract_geometry,
+ conditional_volume=con_volume_lod1,
+ conditional_valid_mask_volume = con_valid_mask_volume_lod1,
+ feature_maps=geometry_feature_maps,
+ color_maps=imgs,
+ w2cs=w2cs,
+ target_candidate_w2cs=None,
+ intrinsics=intrinsics,
+ rendering_network=self.rendering_network_lod1,
+ rendering_projector=self.sdf_renderer_lod1.rendering_projector,
+ lod=1,
+ threshold=0,
+ query_c2w=query_c2w,
+ mode='val_bg', meta=meta,
+ iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
+ )
+ torch.cuda.empty_cache()
+ # self.validate_mesh(self.sdf_network_lod1,
+ # self.sdf_renderer_lod1.extract_geometry,
+ # conditional_volume=con_volume_lod1, lod=1,
+ # # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
+ # mode='val_bg', meta=meta,
+ # iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
+
+ # torch.cuda.empty_cache()
+
+
+ def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
+ depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0):
+ if len(out_color) > 0:
+ img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_color_mlp) > 0:
+ img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)
+
+ if len(out_normal) > 0:
+ normal_img = np.concatenate(out_normal, axis=0)
+ rot = w2cs[:3, :3].detach().cpu().numpy()
+ # - convert normal from world space to camera space
+ normal_img = (np.matmul(rot[None, :, :],
+ normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)
+ if len(out_depth) > 0:
+ pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
+ pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]
+
+ if len(out_depth) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True)
+ if true_colored_depth is not None:
+
+ if true_depth is not None:
+ depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor
+ # [256, 256, 1] -> [256, 256, 3]
+ depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3])
+ print("meta: ", meta)
+ print("scale_factor: ", scale_factor)
+ print("depth_error_mean: ", depth_error_map.mean())
+ # import ipdb; ipdb.set_trace()
+ depth_visualized = np.concatenate(
+ [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1]
+ # print("depth_visualized.shape: ", depth_visualized.shape)
+ # write depth error result text on img, the input is a numpy array of [256, 1024, 3]
+ # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+ else:
+ depth_visualized = np.concatenate(
+ [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized
+ )
+ else:
+ cv.imwrite(
+ os.path.join(self.base_exp_dir, 'depths_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [pred_depth_colored, true_img])[:, :, ::-1])
+ if len(out_color) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_fine, true_img])[:, :, ::-1]) # bgr2rgb
+ # compute psnr (image pixel lie in [0, 255])
+ mse_loss = np.mean((img_fine - true_img) ** 2)
+ psnr = 10 * np.log10(255 ** 2 / mse_loss)
+
+ print("PSNR: ", psnr)
+
+ if len(out_color_mlp) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ np.concatenate(
+ [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb
+
+ if len(out_normal) > 0:
+ os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True)
+ cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment,
+ '{:0>8d}_{}.png'.format(iter_step, meta)),
+ normal_img[:, :, ::-1])
+
+ def forward(self, sample,
+ perturb_overwrite=-1,
+ background_rgb=None,
+ alpha_inter_ratio_lod0=0.0,
+ alpha_inter_ratio_lod1=0.0,
+ iter_step=0,
+ mode='train',
+ save_vis=False,
+ ):
+
+ if mode == 'train':
+ return self.train_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step
+ )
+ elif mode == 'val':
+ import time
+ begin = time.time()
+ result = self.val_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step,
+ save_vis=save_vis,
+ )
+ end = time.time()
+ print("val_step time: ", end - begin)
+ return result
+ elif mode == 'export_mesh':
+ import time
+ begin = time.time()
+ result = self.export_mesh_step(sample,
+ perturb_overwrite=perturb_overwrite,
+ background_rgb=background_rgb,
+ alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
+ alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
+ iter_step=iter_step,
+ save_vis=save_vis,
+ )
+ end = time.time()
+ print("export mesh time: ", end - begin)
+ return result
+ def obtain_pyramid_feature_maps(self, imgs, lod=0):
+ """
+ get feature maps of all conditional images
+ :param imgs: conditional input images
+ :return: fused multi-scale (pyramid) feature maps
+ """
+
+ if lod == 0:
+ extractor = self.pyramid_feature_network_geometry_lod0
+ elif lod >= 1:
+ extractor = self.pyramid_feature_network_geometry_lod1
+
+ pyramid_feature_maps = extractor(imgs)
+
+ # * the pyramid features are very important; using only the coarsest features makes optimization hard
+ fused_feature_maps = torch.cat([
+ F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
+ F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
+ pyramid_feature_maps[2]
+ ], dim=1)
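+ # the three upsampled pyramid levels are concatenated along the channel dimension;
+ # elsewhere in this file the fused maps are annotated as [n_view, 56, 256, 256]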
+
+ return fused_feature_maps
+
+ def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0):
+
+ # loss weight schedule; the regularization terms should only be added in the later training stages
+ def get_weight(iter_step, weight):
+ if lod == 1:
+ anneal_start = self.anneal_end_lod1
+ anneal_end = self.anneal_end_lod1 * 2
+ else:
+ anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1
+ anneal_end = (self.anneal_end if lod == 0 else self.anneal_end_lod1) * 2
+
+ if iter_step < 0:
+ return weight
+
+ if anneal_end == 0.0:
+ return weight
+ elif iter_step < anneal_start:
+ return 0.0
+ else:
+ return np.min(
+ [1.0,
+ (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight
+
+ rays_o = sample_rays['rays_o'][0]
+ rays_d = sample_rays['rays_v'][0]
+ true_rgb = sample_rays['rays_color'][0]
+
+ if 'rays_depth' in sample_rays.keys():
+ true_depth = sample_rays['rays_depth'][0]
+ else:
+ true_depth = None
+ mask = sample_rays['rays_mask'][0]
+
+ color_fine = render_out['color_fine']
+ color_fine_mask = render_out['color_fine_mask']
+ depth_pred = render_out['depth']
+
+ variance = render_out['variance']
+ cdf_fine = render_out['cdf_fine']
+ weight_sum = render_out['weights_sum']
+
+ gradient_error_fine = render_out['gradient_error_fine']
+
+ sdf = render_out['sdf']
+
+ # * color generated by mlp
+ color_mlp = render_out['color_mlp']
+ color_mlp_mask = render_out['color_mlp_mask']
+
+ if color_fine is not None:
+ # Color loss
+ color_mask = color_fine_mask if color_fine_mask is not None else mask
+ # import ipdb; ipdb.set_trace()
+ color_mask = color_mask[..., 0]
+ color_error = (color_fine[color_mask] - true_rgb[color_mask])
+ # print("Nan number", torch.isnan(color_error).sum())
+ # print("Color error shape", color_error.shape)
+ # import ipdb; ipdb.set_trace()
+ color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device),
+ reduction='mean')
+ # print(color_fine_loss)
+ psnr = 20.0 * torch.log10(
+ 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt())
+ else:
+ color_fine_loss = 0.
+ psnr = 0.
+
+ if color_mlp is not None:
+ # Color loss
+ color_mlp_mask = color_mlp_mask[..., 0]
+ color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask])
+ color_mlp_loss = F.l1_loss(color_error_mlp,
+ torch.zeros_like(color_error_mlp).to(color_error_mlp.device),
+ reduction='mean')
+
+ psnr_mlp = 20.0 * torch.log10(
+ 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt())
+ else:
+ color_mlp_loss = 0.
+ psnr_mlp = 0.
+
+ # depth loss, computed only when ground-truth depth rays are provided (note: it is added to the total loss below)
+ if true_depth is not None:
+ # depth_loss = self.depth_criterion(depth_pred, true_depth, mask)
+ depth_loss = self.depth_criterion(depth_pred, true_depth)
+
+ # # depth evaluation
+ # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy())
+ # depth_statis = numpy2tensor(depth_statis, device=rays_o.device)
+ depth_statis = None
+ else:
+ depth_loss = 0.
+ depth_statis = None
+
+ sparse_loss_1 = torch.exp(
+ -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal
+ sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
+ sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2
+
+ sdf_mean = torch.abs(sdf).mean()
+ sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
+ sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()
+
+ # Eikonal loss
+ gradient_error_loss = gradient_error_fine
+
+ # ! for the first 50k iterations, do not apply the background constraint
+ fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight)
+
+ # Mask loss, optional
+ # The images of the DTU dataset contain large black (zero-rgb) background regions,
+ # which can be used as a prior to keep the foreground clean
+ background_loss = 0.0
+ fg_bg_loss = 0.0
+ if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02:
+ weights_sum_fg = render_out['weights_sum_fg']
+ fg_bg_error = (weights_sum_fg - mask)[mask < 0.5]
+ fg_bg_loss = F.l1_loss(fg_bg_error,
+ torch.zeros_like(fg_bg_error).to(fg_bg_error.device),
+ reduction='mean')
+
+
+
+ loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \
+ sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \
+ fg_bg_loss * fg_bg_weight + \
+ gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss needs a mask
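+ # i.e. total = depth + volume-rendered color + mlp color + annealed sparse-sdf term
+ # + annealed fg/bg mask term + sdf_igr_weight * eikonal term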
+
+ losses = {
+ "loss": loss,
+ "depth_loss": depth_loss,
+ "color_fine_loss": color_fine_loss,
+ "color_mlp_loss": color_mlp_loss,
+ "gradient_error_loss": gradient_error_loss,
+ "background_loss": background_loss,
+ "sparse_loss": sparse_loss,
+ "sparseness_1": sparseness_1,
+ "sparseness_2": sparseness_2,
+ "sdf_mean": sdf_mean,
+ "psnr": psnr,
+ "psnr_mlp": psnr_mlp,
+ "weights_sum": render_out['weights_sum'],
+ "weights_sum_fg": render_out['weights_sum_fg'],
+ "alpha_sum": render_out['alpha_sum'],
+ "variance": render_out['variance'],
+ "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight),
+ "fg_bg_weight": fg_bg_weight,
+ "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS
+ }
+ # print("[TEST]: weights_sum in trainner forward", losses['weights_sum'].mean())
+ losses = numpy2tensor(losses, device=rays_o.device)
+ return loss, losses, depth_statis
+
+ @torch.no_grad()
+ def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
+ threshold=0.0, mode='val',
+ # * 3d feature volume
+ conditional_volume=None, lod=None, occupancy_mask=None,
+ bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
+ trans_mat=None
+ ):
+
+ bound_min = torch.tensor(bound_min, dtype=torch.float32)
+ bound_max = torch.tensor(bound_max, dtype=torch.float32)
+
+ vertices, triangles, fields = func_extract_geometry(
+ density_or_sdf_network,
+ bound_min, bound_max, resolution=resolution,
+ threshold=threshold, device=conditional_volume.device,
+ # * 3d feature volume
+ conditional_volume=conditional_volume, lod=lod,
+ occupancy_mask=occupancy_mask
+ )
+
+
+ if scale_mat is not None:
+ scale_mat_np = scale_mat.cpu().numpy()
+ vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]
+
+ if trans_mat is not None: # w2c_ref_inv
+ trans_mat_np = trans_mat.cpu().numpy()
+ vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
+ vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]
+
+ mesh = trimesh.Trimesh(vertices, triangles)
+ os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True)
+ mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode,
+ 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
+
+
+
+ def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
+ threshold=0.0, mode='val',
+ # * 3d feature volume
+ conditional_volume=None,
+ conditional_valid_mask_volume=None,
+ feature_maps=None,
+ color_maps = None,
+ w2cs=None,
+ target_candidate_w2cs=None,
+ intrinsics=None,
+ rendering_network=None,
+ rendering_projector=None,
+ query_c2w=None,
+ lod=None, occupancy_mask=None,
+ bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
+ trans_mat=None
+ ):
+
+ bound_min = torch.tensor(bound_min, dtype=torch.float32)
+ bound_max = torch.tensor(bound_max, dtype=torch.float32)
+
+ vertices, triangles, fields = func_extract_geometry(
+ density_or_sdf_network,
+ bound_min, bound_max, resolution=resolution,
+ threshold=threshold, device=conditional_volume.device,
+ # * 3d feature volume
+ conditional_volume=conditional_volume, lod=lod,
+ occupancy_mask=occupancy_mask
+ )
+
+
+ with torch.no_grad():
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent(
+ torch.tensor(vertices).to(conditional_volume),
+ lod=0,
+ # * 3d geometry feature volumes
+ geometryVolume=conditional_volume[0],
+ geometryVolumeMask=conditional_valid_mask_volume[0],
+ sdf_network=density_or_sdf_network,
+ # * 2d rendering feature maps
+ rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256]
+ color_maps=color_maps,
+ w2cs=w2cs,
+ target_candidate_w2cs=target_candidate_w2cs,
+ intrinsics=intrinsics,
+ img_wh=[256,256],
+ query_img_idx=0, # the index of the N_views dim for rendering
+ query_c2w=query_c2w,
+ )
+
+
+ vertices_color, rendering_valid_mask = rendering_network(
+ ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
+
+
+
+ if scale_mat is not None:
+ scale_mat_np = scale_mat.cpu().numpy()
+ vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]
+
+ if trans_mat is not None: # w2c_ref_inv
+ trans_mat_np = trans_mat.cpu().numpy()
+ vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
+ vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]
+ # import ipdb; ipdb.set_trace()
+ vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8)
+ mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color)
+ os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True)
+ mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
+ 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/ops/__init__.py b/SparseNeuS_demo_v1/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/ops/back_project.py b/SparseNeuS_demo_v1/ops/back_project.py
new file mode 100644
index 0000000000000000000000000000000000000000..5398f285f786a0e6c7a029138aa8a6554aae6e58
--- /dev/null
+++ b/SparseNeuS_demo_v1/ops/back_project.py
@@ -0,0 +1,175 @@
+import torch
+from torch.nn.functional import grid_sample
+
+
+def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
+ with_proj_z=False):
+ # - modified version from NeuRecon
+ '''
+ Unproject the image features to form a 3D (sparse) feature volume
+
+ :param coords: coordinates of voxels,
+ dim: (num of voxels, 4) (4 : batch ind, x, y, z)
+ :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
+ dim: (batch size, 3) (3: x, y, z)
+ :param voxel_size: floats specifying the size of a voxel
+ :param feats: image features
+ dim: (num of views, batch size, C, H, W)
+ :param KRcam: projection matrix
+ dim: (num of views, batch size, 4, 4)
+ :return: feature_volume_all: 3D feature volumes
+ dim: (num of voxels, num_of_views, c)
+ :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not
+ dim: (num of voxels, num_of_views)
+ '''
+ n_views, bs, c, h, w = feats.shape
+ device = feats.device
+
+ if sizeH is None:
+ sizeH, sizeW = h, w # - if the KRcam is not suitable for the current feats
+
+ feature_volume_all = torch.zeros(coords.shape[0], n_views, c).to(device)
+ mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32).to(device)
+ # import ipdb; ipdb.set_trace()
+ for batch in range(bs):
+ # import ipdb; ipdb.set_trace()
+ batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
+ coords_batch = coords[batch_ind][:, 1:]
+
+ coords_batch = coords_batch.view(-1, 3)
+ origin_batch = origin[batch].unsqueeze(0)
+ feats_batch = feats[:, batch]
+ proj_batch = KRcam[:, batch]
+
+ grid_batch = coords_batch * voxel_size + origin_batch.float()
+ rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
+ rs_grid = rs_grid.permute(0, 2, 1).contiguous()
+ nV = rs_grid.shape[-1]
+ rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1)
+
+ # Project grid
+ im_p = proj_batch @ rs_grid # - transform world pts to image UV space
+ im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
+
+ im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6)
+
+ im_x = im_x / im_z
+ im_y = im_y / im_z
+
+ im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
+ mask = im_grid.abs() <= 1
+ mask = (mask.sum(dim=-1) == 2) & (im_z > 0)
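+ # im_grid lies in the [-1, 1] range expected by grid_sample with align_corners=True;
+ # a voxel is kept only if both normalized coordinates fall inside the image and its depth is positive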
+
+ mask = mask.view(n_views, -1)
+ mask = mask.permute(1, 0).contiguous() # [num_pts, nviews]
+
+ mask_volume_all[batch_ind] = mask.to(torch.int32)
+
+ if only_mask:
+ return mask_volume_all
+
+ feats_batch = feats_batch.view(n_views, c, h, w)
+ im_grid = im_grid.view(n_views, 1, -1, 2)
+ features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)
+ # if features.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ features = features.view(n_views, c, -1)
+ features = features.permute(2, 0, 1).contiguous() # [num_pts, nviews, c]
+
+ feature_volume_all[batch_ind] = features
+
+ if with_proj_z:
+ im_z = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous() # [num_pts, nviews, 1]
+ return feature_volume_all, mask_volume_all, im_z
+ # if feature_volume_all.isnan().sum() > 0:
+ # import ipdb; ipdb.set_trace()
+ return feature_volume_all, mask_volume_all
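+
+# minimal usage sketch (hypothetical tensors; shapes follow the docstring above):
+#   coords: (N, 4) rows of (batch_idx, x, y, z), origin: (B, 3), feats: (V, B, C, H, W), KRcam: (V, B, 4, 4)
+#   feat_vol, mask_vol = back_project_sparse_type(coords, origin, voxel_size, feats, KRcam)
+#   feat_vol: (N, V, C) per-view sampled features, mask_vol: (N, V) int validity flags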
+
+
+def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
+ """Transform coordinates in the camera frame to the pixel frame.
+ Args:
+ cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W]
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 3]
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
+ Returns:
+ array of [-1,1] coordinates -- [B, H, W, 2]
+ """
+ b, _, h, w = cam_coords.size()
+ if sizeH is None:
+ sizeH = h
+ sizeW = w
+
+ cam_coords_flat = cam_coords.view(b, 3, -1) # [B, 3, H*W]
+ if proj_c2p_rot is not None:
+ pcoords = proj_c2p_rot.bmm(cam_coords_flat)
+ else:
+ pcoords = cam_coords_flat
+
+ if proj_c2p_tr is not None:
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
+ X = pcoords[:, 0]
+ Y = pcoords[:, 1]
+ Z = pcoords[:, 2].clamp(min=1e-3)
+
+ X_norm = 2 * (X / Z) / (sizeW - 1) - 1 # Normalized, -1 if on extreme left,
+ # 1 if on extreme right (x = w-1) [B, H*W]
+ Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1 # Idem [B, H*W]
+ if padding_mode == 'zeros':
+ X_mask = ((X_norm > 1) + (X_norm < -1)).detach()
+ X_norm[X_mask] = 2 # make sure that no point in the warped image is a combination of image and gray
+ Y_mask = ((Y_norm > 1) + (Y_norm < -1)).detach()
+ Y_norm[Y_mask] = 2
+
+ if with_depth:
+ pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2) # [B, H*W, 3]
+ return pixel_coords.view(b, h, w, 3)
+ else:
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
+ return pixel_coords.view(b, h, w, 2)
+
+
+# * already checked; still verify that proj_matrix matches the right coordinate system and resolution
+def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
+ '''
+ Unproject the image features to form a 3D (dense) feature volume
+
+ :param coords: coordinates of voxels,
+ dim: (batch, nviews, 3, X,Y,Z)
+ :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
+ dim: (batch size, 3) (3: x, y, z)
+ :param voxel_size: floats specifying the size of a voxel
+ :param feats: image features
+ dim: (batch size, num of views, C, H, W)
+ :param proj_matrix: projection matrix
+ dim: (batch size, num of views, 4, 4)
+ :return: feature_volume_all: 3D feature volumes
+ dim: (batch, nviews, C, X,Y,Z)
+ :return: count: number of times each voxel can be seen
+ dim: (batch, nviews, 1, X,Y,Z)
+ '''
+
+ batch, nviews, _, wX, wY, wZ = coords.shape
+
+ if sizeH is None:
+ sizeH, sizeW = feats.shape[-2:]
+ proj_matrix = proj_matrix.view(batch * nviews, *proj_matrix.shape[2:])
+
+ coords_wrd = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
+ coords_wrd = coords_wrd.view(batch * nviews, 3, wX * wY * wZ, 1) # (b*nviews,3,wX*wY*wZ, 1)
+
+ pixel_grids = cam2pixel(coords_wrd, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:],
+ 'zeros', sizeH=sizeH, sizeW=sizeW) # (b*nviews,wX*wY*wZ, 2)
+ pixel_grids = pixel_grids.view(batch * nviews, 1, wX * wY * wZ, 2)
+
+ feats = feats.view(batch * nviews, *feats.shape[2:]) # (b*nviews,c,h,w)
+
+ ones = torch.ones((batch * nviews, 1, *feats.shape[2:])).to(feats.dtype).to(feats.device)
+
+ features_volume = torch.nn.functional.grid_sample(feats, pixel_grids, padding_mode='zeros', align_corners=True)
+ counts_volume = torch.nn.functional.grid_sample(ones, pixel_grids, padding_mode='zeros', align_corners=True)
+
+ features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ) # (batch, nviews, C, X,Y,Z)
+ counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ)
+ return features_volume, counts_volume
+
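+
+if __name__ == "__main__":
+    # Illustrative smoke test only: the shapes follow the docstrings above, while the
+    # identity projection matrices and the voxel layout are placeholder assumptions.
+    n_views, bs, c, h, w = 2, 1, 8, 32, 32
+    coords = torch.cat([torch.zeros(100, 1), torch.randint(0, 96, (100, 3)).float()], dim=1)
+    origin = torch.tensor([[-1.0, -1.0, -1.0]])
+    feats = torch.rand(n_views, bs, c, h, w)
+    KRcam = torch.eye(4).repeat(n_views, bs, 1, 1)
+    feat_vol, mask_vol = back_project_sparse_type(coords, origin, 0.02, feats, KRcam)
+    print(feat_vol.shape, mask_vol.shape)  # (100, 2, 8) and (100, 2)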
diff --git a/SparseNeuS_demo_v1/ops/generate_grids.py b/SparseNeuS_demo_v1/ops/generate_grids.py
new file mode 100644
index 0000000000000000000000000000000000000000..884c37793131323c566c6d1a738f06d497bbd2fb
--- /dev/null
+++ b/SparseNeuS_demo_v1/ops/generate_grids.py
@@ -0,0 +1,33 @@
+import torch
+
+
+def generate_grid(n_vox, interval):
+ """
+ generate grid
+ if 3D volume, grid[:,:,x,y,z] = (x,y,z)
+ :param n_vox:
+ :param interval:
+ :return:
+ """
+ with torch.no_grad():
+ # Create voxel grid
+ grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)]
+ grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2])) # 3 dx dy dz
+ # ! don't create tensor on gpu; imbalanced gpu memory in ddp mode
+ grid = grid.unsqueeze(0).type(torch.float32) # 1 3 dx dy dz
+
+ return grid
+
+
+if __name__ == "__main__":
+ import torch.nn.functional as F
+ grid = generate_grid([5, 6, 8], 1)
+
+ pts = 2 * torch.tensor([1, 2, 3]) / (torch.tensor([5, 6, 8]) - 1) - 1
+ pts = pts.view(1, 1, 1, 1, 3)
+
+ pts = torch.flip(pts, dims=[-1])
+
+ sampled = F.grid_sample(grid, pts, mode='nearest')
+
+ print(sampled)
diff --git a/SparseNeuS_demo_v1/ops/grid_sampler.py b/SparseNeuS_demo_v1/ops/grid_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..44113faa705f0b98a5689c0e4fb9e7a95865d6c1
--- /dev/null
+++ b/SparseNeuS_demo_v1/ops/grid_sampler.py
@@ -0,0 +1,467 @@
+"""
+PyTorch's grid_sample does not support second-order derivatives,
+so this module implements custom differentiable versions.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+
+def grid_sample_2d(image, optical):
+ N, C, IH, IW = image.shape
+ _, H, W, _ = optical.shape
+
+ ix = optical[..., 0]
+ iy = optical[..., 1]
+
+ ix = ((ix + 1) / 2) * (IW - 1)
+ iy = ((iy + 1) / 2) * (IH - 1)
+ with torch.no_grad():
+ ix_nw = torch.floor(ix)
+ iy_nw = torch.floor(iy)
+ ix_ne = ix_nw + 1
+ iy_ne = iy_nw
+ ix_sw = ix_nw
+ iy_sw = iy_nw + 1
+ ix_se = ix_nw + 1
+ iy_se = iy_nw + 1
+
+ nw = (ix_se - ix) * (iy_se - iy)
+ ne = (ix - ix_sw) * (iy_sw - iy)
+ sw = (ix_ne - ix) * (iy - iy_ne)
+ se = (ix - ix_nw) * (iy - iy_nw)
+
+ with torch.no_grad():
+ torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
+ torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)
+
+ torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
+ torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)
+
+ torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
+ torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)
+
+ torch.clamp(ix_se, 0, IW - 1, out=ix_se)
+ torch.clamp(iy_se, 0, IH - 1, out=iy_se)
+
+ image = image.view(N, C, IH * IW)
+
+ nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
+ ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
+ sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
+ se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))
+
+ out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
+ ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
+ sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
+ se_val.view(N, C, H, W) * se.view(N, 1, H, W))
+
+ return out_val
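+
+# Quick sanity check (illustrative, not part of the training pipeline): for in-range
+# sampling grids, grid_sample_2d should agree with the built-in sampler, e.g.
+#   img = torch.rand(1, 3, 8, 8)
+#   grid = torch.rand(1, 4, 4, 2) * 2 - 1
+#   torch.allclose(grid_sample_2d(img, grid),
+#                  F.grid_sample(img, grid, padding_mode='border', align_corners=True),
+#                  atol=1e-5)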
+
+
+# - checked for correctness
+def grid_sample_3d(volume, optical):
+ """
+ trilinear sampling cannot guarantee a continuous first-order gradient at voxel borders;
+ this mimics the pytorch grid_sample function
+ The 8 corner points of a volume are noted as: 4 points (front view); 4 points (back view)
+ fnw (front north west) point
+ bse (back south east) point
+ :param volume: [B, C, X, Y, Z]
+ :param optical: [B, x, y, z, 3]
+ :return:
+ """
+ N, C, ID, IH, IW = volume.shape
+ _, D, H, W, _ = optical.shape
+
+ ix = optical[..., 0]
+ iy = optical[..., 1]
+ iz = optical[..., 2]
+
+ ix = ((ix + 1) / 2) * (IW - 1)
+ iy = ((iy + 1) / 2) * (IH - 1)
+ iz = ((iz + 1) / 2) * (ID - 1)
+
+ mask_x = (ix > 0) & (ix < IW)
+ mask_y = (iy > 0) & (iy < IH)
+ mask_z = (iz > 0) & (iz < ID)
+
+ mask = mask_x & mask_y & mask_z # [B, x, y, z]
+ mask = mask[:, None, :, :, :].repeat(1, C, 1, 1, 1) # [B, C, x, y, z]
+
+ with torch.no_grad():
+ # back north west
+ ix_bnw = torch.floor(ix)
+ iy_bnw = torch.floor(iy)
+ iz_bnw = torch.floor(iz)
+
+ ix_bne = ix_bnw + 1
+ iy_bne = iy_bnw
+ iz_bne = iz_bnw
+
+ ix_bsw = ix_bnw
+ iy_bsw = iy_bnw + 1
+ iz_bsw = iz_bnw
+
+ ix_bse = ix_bnw + 1
+ iy_bse = iy_bnw + 1
+ iz_bse = iz_bnw
+
+ # front view
+ ix_fnw = ix_bnw
+ iy_fnw = iy_bnw
+ iz_fnw = iz_bnw + 1
+
+ ix_fne = ix_bnw + 1
+ iy_fne = iy_bnw
+ iz_fne = iz_bnw + 1
+
+ ix_fsw = ix_bnw
+ iy_fsw = iy_bnw + 1
+ iz_fsw = iz_bnw + 1
+
+ ix_fse = ix_bnw + 1
+ iy_fse = iy_bnw + 1
+ iz_fse = iz_bnw + 1
+
+ # back view
+ bnw = (ix_fse - ix) * (iy_fse - iy) * (iz_fse - iz) # smaller volume, larger weight
+ bne = (ix - ix_fsw) * (iy_fsw - iy) * (iz_fsw - iz)
+ bsw = (ix_fne - ix) * (iy - iy_fne) * (iz_fne - iz)
+ bse = (ix - ix_fnw) * (iy - iy_fnw) * (iz_fnw - iz)
+
+ # front view
+ fnw = (ix_bse - ix) * (iy_bse - iy) * (iz - iz_bse) # smaller volume, larger weight
+ fne = (ix - ix_bsw) * (iy_bsw - iy) * (iz - iz_bsw)
+ fsw = (ix_bne - ix) * (iy - iy_bne) * (iz - iz_bne)
+ fse = (ix - ix_bnw) * (iy - iy_bnw) * (iz - iz_bnw)
+
+ with torch.no_grad():
+ # back view
+ torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw)
+ torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw)
+ torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw)
+
+ torch.clamp(ix_bne, 0, IW - 1, out=ix_bne)
+ torch.clamp(iy_bne, 0, IH - 1, out=iy_bne)
+ torch.clamp(iz_bne, 0, ID - 1, out=iz_bne)
+
+ torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw)
+ torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw)
+ torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw)
+
+ torch.clamp(ix_bse, 0, IW - 1, out=ix_bse)
+ torch.clamp(iy_bse, 0, IH - 1, out=iy_bse)
+ torch.clamp(iz_bse, 0, ID - 1, out=iz_bse)
+
+ # front view
+ torch.clamp(ix_fnw, 0, IW - 1, out=ix_fnw)
+ torch.clamp(iy_fnw, 0, IH - 1, out=iy_fnw)
+ torch.clamp(iz_fnw, 0, ID - 1, out=iz_fnw)
+
+ torch.clamp(ix_fne, 0, IW - 1, out=ix_fne)
+ torch.clamp(iy_fne, 0, IH - 1, out=iy_fne)
+ torch.clamp(iz_fne, 0, ID - 1, out=iz_fne)
+
+ torch.clamp(ix_fsw, 0, IW - 1, out=ix_fsw)
+ torch.clamp(iy_fsw, 0, IH - 1, out=iy_fsw)
+ torch.clamp(iz_fsw, 0, ID - 1, out=iz_fsw)
+
+ torch.clamp(ix_fse, 0, IW - 1, out=ix_fse)
+ torch.clamp(iy_fse, 0, IH - 1, out=iy_fse)
+ torch.clamp(iz_fse, 0, ID - 1, out=iz_fse)
+
+ # xxx = volume[:, :, iz_bnw.long(), iy_bnw.long(), ix_bnw.long()]
+ volume = volume.view(N, C, ID * IH * IW)
+ # yyy = volume[:, :, (iz_bnw * ID + iy_bnw * IW + ix_bnw).long()]
+
+ # back view
+ bnw_val = torch.gather(volume, 2,
+ (iz_bnw * ID ** 2 + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ bne_val = torch.gather(volume, 2,
+ (iz_bne * ID ** 2 + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ bsw_val = torch.gather(volume, 2,
+ (iz_bsw * ID ** 2 + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ bse_val = torch.gather(volume, 2,
+ (iz_bse * ID ** 2 + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1))
+
+ # front view
+ fnw_val = torch.gather(volume, 2,
+ (iz_fnw * ID ** 2 + iy_fnw * IW + ix_fnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ fne_val = torch.gather(volume, 2,
+ (iz_fne * ID ** 2 + iy_fne * IW + ix_fne).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ fsw_val = torch.gather(volume, 2,
+ (iz_fsw * ID ** 2 + iy_fsw * IW + ix_fsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
+ fse_val = torch.gather(volume, 2,
+ (iz_fse * ID ** 2 + iy_fse * IW + ix_fse).long().view(N, 1, D * H * W).repeat(1, C, 1))
+
+ out_val = (
+ # back
+ bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
+ bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
+ bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
+ bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) +
+ # front
+ fnw_val.view(N, C, D, H, W) * fnw.view(N, 1, D, H, W) +
+ fne_val.view(N, C, D, H, W) * fne.view(N, 1, D, H, W) +
+ fsw_val.view(N, C, D, H, W) * fsw.view(N, 1, D, H, W) +
+ fse_val.view(N, C, D, H, W) * fse.view(N, 1, D, H, W)
+
+ )
+
+ # * zero padding
+ out_val = torch.where(mask, out_val, torch.zeros_like(out_val).float().to(out_val.device))
+
+ return out_val
+
+
+# Interpolation kernel
+def get_weight(s, a=-0.5):
+ mask_0 = (torch.abs(s) >= 0) & (torch.abs(s) <= 1)
+ mask_1 = (torch.abs(s) > 1) & (torch.abs(s) <= 2)
+ mask_2 = torch.abs(s) > 2
+
+ weight = torch.zeros_like(s).to(s.device)
+ weight = torch.where(mask_0, (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1, weight)
+ weight = torch.where(mask_1,
+ a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a,
+ weight)
+
+ # if (torch.abs(s) >= 0) & (torch.abs(s) <= 1):
+ # return (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1
+ #
+ # elif (torch.abs(s) > 1) & (torch.abs(s) <= 2):
+ # return a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a
+ # return 0
+
+ return weight
+
+
+def cubic_interpolate(p, x):
+ """
+ one dimensional cubic interpolation
+ :param p: [N, 4] (4) should be in order
+ :param x: [N]
+ :return:
+ """
+ return p[:, 1] + 0.5 * x * (p[:, 2] - p[:, 0] + x * (
+ 2.0 * p[:, 0] - 5.0 * p[:, 1] + 4.0 * p[:, 2] - p[:, 3] + x * (
+ 3.0 * (p[:, 1] - p[:, 2]) + p[:, 3] - p[:, 0])))
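+
+# This is Catmull-Rom style cubic interpolation between p[:, 1] and p[:, 2], with
+# p[:, 0] and p[:, 3] acting as outer support points. Worked example (illustrative):
+# for p = [[0., 1., 2., 3.]] and x = 0.5 the result is 1.5, i.e. exact on linear data.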
+
+
+def bicubic_interpolate(p, x, y, if_batch=True):
+ """
+ two dimensional cubic interpolation
+ :param p: [N, 4, 4]
+ :param x: [N]
+ :param y: [N]
+ :return:
+ """
+ num = p.shape[0]
+
+ if not if_batch:
+ arr0 = cubic_interpolate(p[:, 0, :], x) # [N]
+ arr1 = cubic_interpolate(p[:, 1, :], x)
+ arr2 = cubic_interpolate(p[:, 2, :], x)
+ arr3 = cubic_interpolate(p[:, 3, :], x)
+ return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), y) # [N]
+ else:
+ x = x[:, None].repeat(1, 4).view(-1)
+ p = p.contiguous().view(num * 4, 4)
+ arr = cubic_interpolate(p, x)
+ arr = arr.view(num, 4)
+
+ return cubic_interpolate(arr, y)
+
+
+def tricubic_interpolate(p, x, y, z):
+ """
+ three dimensional cubic interpolation
+ :param p: [N,4,4,4]
+ :param x: [N]
+ :param y: [N]
+ :param z: [N]
+ :return:
+ """
+ num = p.shape[0]
+
+ arr0 = bicubic_interpolate(p[:, 0, :, :], x, y) # [N]
+ arr1 = bicubic_interpolate(p[:, 1, :, :], x, y)
+ arr2 = bicubic_interpolate(p[:, 2, :, :], x, y)
+ arr3 = bicubic_interpolate(p[:, 3, :, :], x, y)
+
+ return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), z) # [N]
+
+
+def cubic_interpolate_batch(p, x):
+ """
+ one dimensional cubic interpolation
+ :param p: [B, N, 4] (4) should be in order
+ :param x: [B, N]
+ :return:
+ """
+ return p[:, :, 1] + 0.5 * x * (p[:, :, 2] - p[:, :, 0] + x * (
+ 2.0 * p[:, :, 0] - 5.0 * p[:, :, 1] + 4.0 * p[:, :, 2] - p[:, :, 3] + x * (
+ 3.0 * (p[:, :, 1] - p[:, :, 2]) + p[:, :, 3] - p[:, :, 0])))
+
+
+def bicubic_interpolate_batch(p, x, y):
+ """
+ two dimensional cubic interpolation
+ :param p: [B, N, 4, 4]
+ :param x: [B, N]
+ :param y: [B, N]
+ :return:
+ """
+ B, N, _, _ = p.shape
+
+ x = x[:, :, None].repeat(1, 1, 4).view(B, N * 4) # [B, N*4]
+ arr = cubic_interpolate_batch(p.contiguous().view(B, N * 4, 4), x)
+ arr = arr.view(B, N, 4)
+ return cubic_interpolate_batch(arr, y) # [B, N]
+
+
+# * the batch version does not speed up training
+def tricubic_interpolate_batch(p, x, y, z):
+ """
+ three dimensional cubic interpolation
+ :param p: [N,4,4,4]
+ :param x: [N]
+ :param y: [N]
+ :param z: [N]
+ :return:
+ """
+ N = p.shape[0]
+
+ x = x[None, :].repeat(4, 1)
+ y = y[None, :].repeat(4, 1)
+
+ p = p.permute(1, 0, 2, 3).contiguous()
+
+ arr = bicubic_interpolate_batch(p[:, :, :, :], x, y) # [4, N]
+
+ arr = arr.permute(1, 0).contiguous() # [N, 4]
+
+ return cubic_interpolate(arr, z) # [N]
+
+
+def tricubic_sample_3d(volume, optical):
+ """
+ tricubic sampling; guarantees a continuous gradient across interpolation borders
+ :param volume: [B, C, ID, IH, IW]
+ :param optical: [B, D, H, W, 3]
+ :return:
+ """
+
+ @torch.no_grad()
+ def get_shifts(x):
+ x1 = -1 * (1 + x - torch.floor(x))
+ x2 = -1 * (x - torch.floor(x))
+ x3 = torch.floor(x) + 1 - x
+ x4 = torch.floor(x) + 2 - x
+
+ return torch.stack([x1, x2, x3, x4], dim=-1) # (B,d,h,w,4)
+
+ N, C, ID, IH, IW = volume.shape
+ _, D, H, W, _ = optical.shape
+
+ device = volume.device
+
+ ix = optical[..., 0]
+ iy = optical[..., 1]
+ iz = optical[..., 2]
+
+ ix = ((ix + 1) / 2) * (IW - 1) # (B,d,h,w)
+ iy = ((iy + 1) / 2) * (IH - 1)
+ iz = ((iz + 1) / 2) * (ID - 1)
+
+ ix = ix.view(-1)
+ iy = iy.view(-1)
+ iz = iz.view(-1)
+
+ with torch.no_grad():
+ shifts_x = get_shifts(ix).view(-1, 4) # (B*d*h*w,4)
+ shifts_y = get_shifts(iy).view(-1, 4)
+ shifts_z = get_shifts(iz).view(-1, 4)
+
+ perm_weights = torch.ones([N * D * H * W, 4 * 4 * 4]).long().to(device)
+ perm = torch.cumsum(perm_weights, dim=-1) - 1 # (B*d*h*w,64)
+
+ perm_z = perm // 16 # [N*D*H*W, num]
+ perm_y = (perm - perm_z * 16) // 4
+ perm_x = (perm - perm_z * 16 - perm_y * 4)
+
+ shifts_x = torch.gather(shifts_x, 1, perm_x) # [N*D*H*W, num]
+ shifts_y = torch.gather(shifts_y, 1, perm_y)
+ shifts_z = torch.gather(shifts_z, 1, perm_z)
+
+ ix_target = (ix[:, None] + shifts_x).long() # [N*D*H*W, num]
+ iy_target = (iy[:, None] + shifts_y).long()
+ iz_target = (iz[:, None] + shifts_z).long()
+
+ torch.clamp(ix_target, 0, IW - 1, out=ix_target)
+ torch.clamp(iy_target, 0, IH - 1, out=iy_target)
+ torch.clamp(iz_target, 0, ID - 1, out=iz_target)
+
+ local_dist_x = ix - ix_target[:, 1] # ! attention here is [:, 1]
+ local_dist_y = iy - iy_target[:, 1 + 4]
+ local_dist_z = iz - iz_target[:, 1 + 16]
+
+ local_dist_x = local_dist_x.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
+ local_dist_y = local_dist_y.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
+ local_dist_z = local_dist_z.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
+
+ # ! attention: IW is correct
+ idx_target = iz_target * ID ** 2 + iy_target * IW + ix_target # [N*D*H*W, num]
+
+ volume = volume.view(N, C, ID * IH * IW)
+
+ out = torch.gather(volume, 2,
+ idx_target.view(N, 1, D * H * W * 64).repeat(1, C, 1))
+ out = out.view(N * C * D * H * W, 4, 4, 4)
+
+ # - tricubic_interpolate() is a bit faster than tricubic_interpolate_batch()
+ final = tricubic_interpolate(out, local_dist_x, local_dist_y, local_dist_z).view(N, C, D, H, W) # [N,C,D,H,W]
+
+ return final
+
+
+
+if __name__ == "__main__":
+ # image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3)
+ #
+ # optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2)
+ #
+ # print(grid_sample_2d(image, optical))
+ #
+ # print(F.grid_sample(image, optical, padding_mode='border', align_corners=True))
+
+ from ops.generate_grids import generate_grid
+
+ p = torch.tensor([x for x in range(4)]).view(1, 4).float()
+
+ v = cubic_interpolate(p, torch.tensor([0.5]).view(1))
+ # v = bicubic_interpolate(p, torch.tensor([2/3]).view(1) , torch.tensor([2/3]).view(1))
+
+ vsize = 9
+ volume = generate_grid([vsize, vsize, vsize], 1) # [1,3,10,10,10]
+ # volume = torch.tensor([x for x in range(1000)]).view(1, 1, 10, 10, 10).float()
+ X, Y, Z = 0, 0, 6
+ x = 2 * X / (vsize - 1) - 1
+ y = 2 * Y / (vsize - 1) - 1
+ z = 2 * Z / (vsize - 1) - 1
+
+ # print(volume[:, :, Z, Y, X])
+
+ # volume = volume.view(1, 3, -1)
+ # xx = volume[:, :, Z * 9*9 + Y * 9 + X]
+
+ optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3)
+
+ print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True))
+ print(grid_sample_3d(volume, optical))
+ print(tricubic_sample_3d(volume, optical))
+ # target, relative_coords = implicit_sample_3d(volume, optical, 1)
+ # print(target)
diff --git a/SparseNeuS_demo_v1/requirements.txt b/SparseNeuS_demo_v1/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..06a26a213732e63398677064788c50e7d03bf95f
--- /dev/null
+++ b/SparseNeuS_demo_v1/requirements.txt
@@ -0,0 +1,11 @@
+git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
+opencv_python
+trimesh
+numpy
+pyhocon
+icecream
+tqdm
+scipy
+PyMCubes
+# sudo apt-get install libsparsehash-dev
+inplace_abn
\ No newline at end of file
diff --git a/SparseNeuS_demo_v1/tsparse/__init__.py b/SparseNeuS_demo_v1/tsparse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/tsparse/modules.py b/SparseNeuS_demo_v1/tsparse/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..520809144718d84b77708bbc7a582a64078958b4
--- /dev/null
+++ b/SparseNeuS_demo_v1/tsparse/modules.py
@@ -0,0 +1,326 @@
+import torch
+import torch.nn as nn
+import torchsparse
+import torchsparse.nn as spnn
+from torchsparse.tensor import PointTensor
+
+from tsparse.torchsparse_utils import *
+
+
+# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU']
+
+
+class ConvBnReLU(nn.Module):
+ def __init__(self, in_channels, out_channels,
+ kernel_size=3, stride=1, pad=1):
+ super(ConvBnReLU, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels,
+ kernel_size, stride=stride, padding=pad, bias=False)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.activation = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.activation(self.bn(self.conv(x)))
+
+
+class ConvBnReLU3D(nn.Module):
+ def __init__(self, in_channels, out_channels,
+ kernel_size=3, stride=1, pad=1):
+ super(ConvBnReLU3D, self).__init__()
+ self.conv = nn.Conv3d(in_channels, out_channels,
+ kernel_size, stride=stride, padding=pad, bias=False)
+ self.bn = nn.BatchNorm3d(out_channels)
+ self.activation = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.activation(self.bn(self.conv(x)))
+
+
+################################### feature net ######################################
+class FeatureNet(nn.Module):
+ """
+ output 3 levels of features using a FPN structure
+ """
+
+ def __init__(self):
+ super(FeatureNet, self).__init__()
+
+ self.conv0 = nn.Sequential(
+ ConvBnReLU(3, 8, 3, 1, 1),
+ ConvBnReLU(8, 8, 3, 1, 1))
+
+ self.conv1 = nn.Sequential(
+ ConvBnReLU(8, 16, 5, 2, 2),
+ ConvBnReLU(16, 16, 3, 1, 1),
+ ConvBnReLU(16, 16, 3, 1, 1))
+
+ self.conv2 = nn.Sequential(
+ ConvBnReLU(16, 32, 5, 2, 2),
+ ConvBnReLU(32, 32, 3, 1, 1),
+ ConvBnReLU(32, 32, 3, 1, 1))
+
+ self.toplayer = nn.Conv2d(32, 32, 1)
+ self.lat1 = nn.Conv2d(16, 32, 1)
+ self.lat0 = nn.Conv2d(8, 32, 1)
+
+ # to reduce channel size of the outputs from FPN
+ self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
+ self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
+
+ def _upsample_add(self, x, y):
+ return torch.nn.functional.interpolate(x, scale_factor=2,
+ mode="bilinear", align_corners=True) + y
+
+ def forward(self, x):
+ # x: (B, 3, H, W)
+ conv0 = self.conv0(x) # (B, 8, H, W)
+ conv1 = self.conv1(conv0) # (B, 16, H//2, W//2)
+ conv2 = self.conv2(conv1) # (B, 32, H//4, W//4)
+ feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4)
+ feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2)
+ feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W)
+
+ # reduce output channels
+ feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2)
+ feat0 = self.smooth0(feat0) # (B, 8, H, W)
+
+ # feats = {"level_0": feat0,
+ # "level_1": feat1,
+ # "level_2": feat2}
+
+ return [feat2, feat1, feat0] # coarser to finer features
+
+
+class BasicSparseConvolutionBlock(nn.Module):
+ def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ spnn.Conv3d(inc,
+ outc,
+ kernel_size=ks,
+ dilation=dilation,
+ stride=stride),
+ spnn.BatchNorm(outc),
+ spnn.ReLU(True))
+
+ def forward(self, x):
+ out = self.net(x)
+ return out
+
+
+class BasicSparseDeconvolutionBlock(nn.Module):
+ def __init__(self, inc, outc, ks=3, stride=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ spnn.Conv3d(inc,
+ outc,
+ kernel_size=ks,
+ stride=stride,
+ transposed=True),
+ spnn.BatchNorm(outc),
+ spnn.ReLU(True))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class SparseResidualBlock(nn.Module):
+ def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
+ super().__init__()
+ self.net = nn.Sequential(
+ spnn.Conv3d(inc,
+ outc,
+ kernel_size=ks,
+ dilation=dilation,
+ stride=stride), spnn.BatchNorm(outc),
+ spnn.ReLU(True),
+ spnn.Conv3d(outc,
+ outc,
+ kernel_size=ks,
+ dilation=dilation,
+ stride=1), spnn.BatchNorm(outc))
+
+ self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
+ nn.Sequential(
+ spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
+ spnn.BatchNorm(outc)
+ )
+
+ self.relu = spnn.ReLU(True)
+
+ def forward(self, x):
+ out = self.relu(self.net(x) + self.downsample(x))
+ return out
+
+
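+# SPVCNN (following SPVNAS) alternates between a sparse voxel branch and a point
+# branch: stage1/stage2 downsample the voxelized features, up1/up2 upsample them and
+# concatenate encoder features as skip connections, while point_transforms carry
+# per-point features across scales via additive shortcuts. forward() takes a
+# PointTensor z and returns its per-point features z3.F with cs[4] channels.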
+class SPVCNN(nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ self.dropout = kwargs['dropout']
+
+ cr = kwargs.get('cr', 1.0)
+ cs = [32, 64, 128, 96, 96]
+ cs = [int(cr * x) for x in cs]
+
+ if 'pres' in kwargs and 'vres' in kwargs:
+ self.pres = kwargs['pres']
+ self.vres = kwargs['vres']
+
+ self.stem = nn.Sequential(
+ spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1),
+ spnn.BatchNorm(cs[0]), spnn.ReLU(True)
+ )
+
+ self.stage1 = nn.Sequential(
+ BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
+ SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
+ SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
+ )
+
+ self.stage2 = nn.Sequential(
+ BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
+ SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
+ SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
+ )
+
+ self.up1 = nn.ModuleList([
+ BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2),
+ nn.Sequential(
+ SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1,
+ dilation=1),
+ SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
+ )
+ ])
+
+ self.up2 = nn.ModuleList([
+ BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2),
+ nn.Sequential(
+ SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1,
+ dilation=1),
+ SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
+ )
+ ])
+
+ self.point_transforms = nn.ModuleList([
+ nn.Sequential(
+ nn.Linear(cs[0], cs[2]),
+ nn.BatchNorm1d(cs[2]),
+ nn.ReLU(True),
+ ),
+ nn.Sequential(
+ nn.Linear(cs[2], cs[4]),
+ nn.BatchNorm1d(cs[4]),
+ nn.ReLU(True),
+ )
+ ])
+
+ self.weight_initialization()
+
+ if self.dropout:
+ self.dropout = nn.Dropout(0.3, True)
+
+ def weight_initialization(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, z):
+ # x: SparseTensor z: PointTensor
+ x0 = initial_voxelize(z, self.pres, self.vres)
+
+ x0 = self.stem(x0)
+ z0 = voxel_to_point(x0, z, nearest=False)
+ z0.F = z0.F
+
+ x1 = point_to_voxel(x0, z0)
+ x1 = self.stage1(x1)
+ x2 = self.stage2(x1)
+ z1 = voxel_to_point(x2, z0)
+ z1.F = z1.F + self.point_transforms[0](z0.F)
+
+ y3 = point_to_voxel(x2, z1)
+ if self.dropout:
+ y3.F = self.dropout(y3.F)
+ y3 = self.up1[0](y3)
+ y3 = torchsparse.cat([y3, x1])
+ y3 = self.up1[1](y3)
+
+ y4 = self.up2[0](y3)
+ y4 = torchsparse.cat([y4, x0])
+ y4 = self.up2[1](y4)
+ z3 = voxel_to_point(y4, z1)
+ z3.F = z3.F + self.point_transforms[1](z1.F)
+
+ return z3.F
+
+
+class SparseCostRegNet(nn.Module):
+ """
+ Sparse cost regularization network;
+ requires sparse tensors as input
+ """
+
+ def __init__(self, d_in, d_out=8):
+ super(SparseCostRegNet, self).__init__()
+ self.d_in = d_in
+ self.d_out = d_out
+
+ self.conv0 = BasicSparseConvolutionBlock(d_in, d_out)
+
+ self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2)
+ self.conv2 = BasicSparseConvolutionBlock(16, 16)
+
+ self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2)
+ self.conv4 = BasicSparseConvolutionBlock(32, 32)
+
+ self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2)
+ self.conv6 = BasicSparseConvolutionBlock(64, 64)
+
+ self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2)
+
+ self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2)
+
+ self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2)
+
+ def forward(self, x):
+ """
+
+ :param x: sparse tensor
+ :return: sparse tensor
+ """
+ conv0 = self.conv0(x)
+ conv2 = self.conv2(self.conv1(conv0))
+ conv4 = self.conv4(self.conv3(conv2))
+
+ x = self.conv6(self.conv5(conv4))
+ x = conv4 + self.conv7(x)
+ del conv4
+ x = conv2 + self.conv9(x)
+ del conv2
+ x = conv0 + self.conv11(x)
+ del conv0
+ return x.F
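+
+# Usage sketch (illustrative): given a torchsparse SparseTensor holding a d_in-channel
+# sparse cost volume, the network returns regularized per-voxel features as a dense
+# (num_voxels, d_out) matrix:
+#   reg_net = SparseCostRegNet(d_in=16, d_out=8)
+#   regularized = reg_net(sparse_cost_volume)  # -> torch.Tensor of shape (N, 8)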
+
+
+class SConv3d(nn.Module):
+ def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1):
+ super().__init__()
+ self.net = spnn.Conv3d(inc,
+ outc,
+ kernel_size=ks,
+ dilation=dilation,
+ stride=stride)
+ self.point_transforms = nn.Sequential(
+ nn.Linear(inc, outc),
+ )
+ self.pres = pres
+ self.vres = vres
+
+ def forward(self, z):
+ x = initial_voxelize(z, self.pres, self.vres)
+ x = self.net(x)
+ out = voxel_to_point(x, z, nearest=False)
+ out.F = out.F + self.point_transforms(z.F)
+ return out
diff --git a/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32f5b92ae5ef4bf9836b1e4c1dc17eaf3f7c93f9
--- /dev/null
+++ b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py
@@ -0,0 +1,137 @@
+"""
+Copied from:
+https://github.com/mit-han-lab/spvnas/blob/b24f50379ed888d3a0e784508a809d4e92e820c0/core/models/utils.py
+"""
+import torch
+import torchsparse.nn.functional as F
+from torchsparse import PointTensor, SparseTensor
+from torchsparse.nn.utils import get_kernel_offsets
+
+import numpy as np
+
+# __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point']
+
+
+# z: PointTensor
+# return: SparseTensor
+def initial_voxelize(z, init_res, after_res):
+ new_float_coord = torch.cat(
+ [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)
+
+ pc_hash = F.sphash(torch.floor(new_float_coord).int())
+ sparse_hash = torch.unique(pc_hash)
+ idx_query = F.sphashquery(pc_hash, sparse_hash)
+ counts = F.spcount(idx_query.int(), len(sparse_hash))
+
+ inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query,
+ counts)
+ inserted_coords = torch.round(inserted_coords).int()
+ inserted_feat = F.spvoxelize(z.F, idx_query, counts)
+
+ new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
+ new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
+ z.additional_features['idx_query'][1] = idx_query
+ z.additional_features['counts'][1] = counts
+ z.C = new_float_coord
+
+ return new_tensor
+
+
+# x: SparseTensor, z: PointTensor
+# return: SparseTensor
+def point_to_voxel(x, z):
+ if z.additional_features is None or z.additional_features.get('idx_query') is None \
+ or z.additional_features['idx_query'].get(x.s) is None:
+ # pc_hash = hash_gpu(torch.floor(z.C).int())
+ pc_hash = F.sphash(
+ torch.cat([
+ torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
+ z.C[:, -1].int().view(-1, 1)
+ ], 1))
+ sparse_hash = F.sphash(x.C)
+ idx_query = F.sphashquery(pc_hash, sparse_hash)
+ counts = F.spcount(idx_query.int(), x.C.shape[0])
+ z.additional_features['idx_query'][x.s] = idx_query
+ z.additional_features['counts'][x.s] = counts
+ else:
+ idx_query = z.additional_features['idx_query'][x.s]
+ counts = z.additional_features['counts'][x.s]
+
+ inserted_feat = F.spvoxelize(z.F, idx_query, counts)
+ new_tensor = SparseTensor(inserted_feat, x.C, x.s)
+ new_tensor.cmaps = x.cmaps
+ new_tensor.kmaps = x.kmaps
+
+ return new_tensor
+
+
+# x: SparseTensor, z: PointTensor
+# return: PointTensor
+def voxel_to_point(x, z, nearest=False):
+ if z.idx_query is None or z.weights is None or z.idx_query.get(
+ x.s) is None or z.weights.get(x.s) is None:
+ off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
+ # old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off)
+ old_hash = F.sphash(
+ torch.cat([
+ torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
+ z.C[:, -1].int().view(-1, 1)
+ ], 1), off)
+ mm = x.C.to(z.F.device)
+ pc_hash = F.sphash(x.C.to(z.F.device))
+ idx_query = F.sphashquery(old_hash, pc_hash)
+ weights = F.calc_ti_weights(z.C, idx_query,
+ scale=x.s[0]).transpose(0, 1).contiguous()
+ idx_query = idx_query.transpose(0, 1).contiguous()
+ if nearest:
+ weights[:, 1:] = 0.
+ idx_query[:, 1:] = -1
+ new_feat = F.spdevoxelize(x.F, idx_query, weights)
+ new_tensor = PointTensor(new_feat,
+ z.C,
+ idx_query=z.idx_query,
+ weights=z.weights)
+ new_tensor.additional_features = z.additional_features
+ new_tensor.idx_query[x.s] = idx_query
+ new_tensor.weights[x.s] = weights
+ z.idx_query[x.s] = idx_query
+ z.weights[x.s] = weights
+
+ else:
+ new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s),
+ z.weights.get(x.s)) # - sparse trilinear interpolation operation
+ new_tensor = PointTensor(new_feat,
+ z.C,
+ idx_query=z.idx_query,
+ weights=z.weights)
+ new_tensor.additional_features = z.additional_features
+
+ return new_tensor
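+
+# voxel_to_point performs sparse trilinear devoxelization: each point gathers the
+# features of its 8 surrounding voxels (kernel offsets) and blends them with
+# calc_ti_weights; the query indices and weights are cached on z per stride so that
+# repeated calls at the same resolution skip the hashing step. With nearest=True the
+# interpolation degenerates to a nearest-voxel lookup.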
+
+
+def sparse_to_dense_torch_batch(locs, values, dim, default_val):
+ dense = torch.full([dim[0], dim[1], dim[2], dim[3]], float(default_val), device=locs.device)
+ dense[locs[:, 0], locs[:, 1], locs[:, 2], locs[:, 3]] = values
+ return dense
+
+
+def sparse_to_dense_torch(locs, values, dim, default_val, device):
+ dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device)
+ if locs.shape[0] > 0:
+ dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
+ return dense
+
+
+def sparse_to_dense_channel(locs, values, dim, c, default_val, device):
+ locs = locs.to(torch.int64)
+ dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device)
+ if locs.shape[0] > 0:
+ dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
+ return dense
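+
+# Example (illustrative): scattering two voxels with 8-channel features into a 4^3 grid;
+# cells that receive no value keep the default.
+#   locs = torch.tensor([[0, 1, 2], [3, 3, 3]])
+#   vals = torch.rand(2, 8)
+#   dense = sparse_to_dense_channel(locs, vals, [4, 4, 4], 8, 0.0, 'cpu')
+#   dense.shape  # -> (4, 4, 4, 8)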
+
+
+def sparse_to_dense_np(locs, values, dim, default_val):
+ dense = np.zeros([dim[0], dim[1], dim[2]], dtype=values.dtype)
+ dense.fill(default_val)
+ dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
+ return dense
diff --git a/SparseNeuS_demo_v1/utils/__init__.py b/SparseNeuS_demo_v1/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SparseNeuS_demo_v1/utils/misc_utils.py b/SparseNeuS_demo_v1/utils/misc_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e80cf4e2bcf8bed0086e2b6c8a3bf3da40a056
--- /dev/null
+++ b/SparseNeuS_demo_v1/utils/misc_utils.py
@@ -0,0 +1,219 @@
+import os, torch, cv2, re
+import numpy as np
+
+from PIL import Image
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+# Misc
+img2mse = lambda x, y: torch.mean((x - y) ** 2)
+mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
+to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
+mse2psnr2 = lambda x: -10. * np.log(x) / np.log(10.)
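+# e.g. an MSE of 0.01 on [0, 1] images corresponds to a PSNR of about 20 dB (mse2psnr2(0.01) ~= 20)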
+
+
+def get_psnr(imgs_pred, imgs_gt):
+ psnrs = []
+ for (img, tar) in zip(imgs_pred, imgs_gt):
+ psnrs.append(mse2psnr2(np.mean((img - tar.cpu().numpy()) ** 2)))
+ return np.array(psnrs)
+
+
+def init_log(log, keys):
+ for key in keys:
+ log[key] = torch.tensor([0.0], dtype=float)
+ return log
+
+
+def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """
+ depth: (H, W)
+ """
+
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x > 0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi, ma = minmax
+
+ x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
+ x = (255 * x).astype(np.uint8)
+ x_ = cv2.applyColorMap(x, cmap)
+ return x_, [mi, ma]
+
+
+def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """
+ depth: (H, W)
+ """
+ if type(depth) is not np.ndarray:
+ depth = depth.cpu().numpy()
+
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x > 0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi, ma = minmax
+
+ x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
+ x = (255 * x).astype(np.uint8)
+ x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
+ x_ = T.ToTensor()(x_) # (3, H, W)
+ return x_, [mi, ma]
+
+
+def abs_error_numpy(depth_pred, depth_gt, mask):
+ depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
+ return np.abs(depth_pred - depth_gt)
+
+
+def abs_error(depth_pred, depth_gt, mask):
+ depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
+ err = depth_pred - depth_gt
+ return np.abs(err) if type(depth_pred) is np.ndarray else err.abs()
+
+
+def acc_threshold(depth_pred, depth_gt, mask, threshold):
+ """
+ returns a per-pixel indicator (as floats) of whether the depth error is below @threshold; its mean gives the accuracy
+ """
+ errors = abs_error(depth_pred, depth_gt, mask)
+ acc_mask = errors < threshold
+ return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float()
+
+
+def to_tensor_cuda(data, device, filter):
+ for item in data.keys():
+
+ if item in filter:
+ continue
+
+ if type(data[item]) is np.ndarray:
+ data[item] = torch.tensor(data[item], dtype=torch.float32, device=device)
+ else:
+ data[item] = data[item].float().to(device)
+ return data
+
+
+def to_cuda(data, device, filter):
+ for item in data.keys():
+ if item in filter:
+ continue
+
+ data[item] = data[item].float().to(device)
+ return data
+
+
+def tensor_unsqueeze(data, filter):
+ for item in data.keys():
+ if item in filter:
+ continue
+
+ data[item] = data[item][None]
+ return data
+
+
+def filter_keys(dict):
+ dict.pop('N_samples')
+ if 'ndc' in dict.keys():
+ dict.pop('ndc')
+ if 'lindisp' in dict.keys():
+ dict.pop('lindisp')
+ return dict
+
+
+def sub_selete_data(data_batch, device, idx, filtKey=[],
+ filtIndex=['view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt',
+ 'idx']):
+ data_sub_selete = {}
+ for item in data_batch.keys():
+ data_sub_selete[item] = data_batch[item][:, idx].float() if (
+ item not in filtIndex and torch.is_tensor(data_batch[item]) and data_batch[item].dim() > 2) else data_batch[item].float()
+ if not data_sub_selete[item].is_cuda:
+ data_sub_selete[item] = data_sub_selete[item].to(device)
+ return data_sub_selete
+
+
+def detach_data(dictionary):
+ dictionary_new = {}
+ for key in dictionary.keys():
+ dictionary_new[key] = dictionary[key].detach().clone()
+ return dictionary_new
+
+
+def read_pfm(filename):
+ file = open(filename, 'rb')
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().decode('utf-8').rstrip()
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ file.close()
+ return data, scale
+
+
+from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
+
+
+# from warmup_scheduler import GradualWarmupScheduler
+def get_scheduler(hparams, optimizer):
+ eps = 1e-8
+ if hparams.lr_scheduler == 'steplr':
+ scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
+ gamma=hparams.decay_gamma)
+ elif hparams.lr_scheduler == 'cosine':
+ scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
+
+ else:
+ raise ValueError('scheduler not recognized!')
+
+ # if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']:
+ # scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier,
+ # total_epoch=hparams.warmup_epochs, after_scheduler=scheduler)
+ return scheduler
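+
+# Expected hparams fields (as used above): hparams.lr_scheduler in {'steplr', 'cosine'};
+# 'steplr' additionally needs hparams.decay_step and hparams.decay_gamma, while
+# 'cosine' needs hparams.num_epochs.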
+
+
+#### pairing ####
+def get_nearest_pose_ids(tar_pose, ref_poses, num_select):
+ '''
+ Args:
+ tar_pose: target pose [N, 4, 4]
+ ref_poses: reference poses [M, 4, 4]
+ num_select: the number of nearest views to select
+ Returns: the selected indices
+ '''
+
+ dists = np.linalg.norm(tar_pose[:, None, :3, 3] - ref_poses[None, :, :3, 3], axis=-1)
+
+ sorted_ids = np.argsort(dists, axis=-1)
+ selected_ids = sorted_ids[:, :num_select]
+ return selected_ids
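+
+# Example (illustrative): ranking is by Euclidean distance between camera centers
+# (the translation columns of the 4x4 poses), so
+#   ids = get_nearest_pose_ids(tar_pose, ref_poses, num_select=2)  # ids.shape == (N, 2)
+# returns, for every target pose, the indices of the two closest reference cameras.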
diff --git a/SparseNeuS_demo_v1/utils/training_utils.py b/SparseNeuS_demo_v1/utils/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d128ba2beda39b708850bd4c17c4603a8a17848
--- /dev/null
+++ b/SparseNeuS_demo_v1/utils/training_utils.py
@@ -0,0 +1,129 @@
+import numpy as np
+import torchvision.utils as vutils
+import torch, random
+import torch.nn.functional as F
+
+
+# print arguments
+def print_args(args):
+ print("################################ args ################################")
+ for k, v in args.__dict__.items():
+ print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
+ print("########################################################################")
+
+
+# torch.no_grad warpper for functions
+def make_nograd_func(func):
+ def wrapper(*f_args, **f_kwargs):
+ with torch.no_grad():
+ ret = func(*f_args, **f_kwargs)
+ return ret
+
+ return wrapper
+
+
+# convert a function into recursive style to handle nested dict/list/tuple variables
+def make_recursive_func(func):
+ def wrapper(vars, device=None):
+ if isinstance(vars, list):
+ return [wrapper(x, device) for x in vars]
+ elif isinstance(vars, tuple):
+ return tuple([wrapper(x, device) for x in vars])
+ elif isinstance(vars, dict):
+ return {k: wrapper(v, device) for k, v in vars.items()}
+ else:
+ return func(vars, device)
+
+ return wrapper
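+
+# Example (illustrative): the decorated helpers below walk nested containers, e.g.
+#   tensor2float({'loss': torch.tensor(0.5), 'metrics': [torch.tensor(1.0)]})
+# returns {'loss': 0.5, 'metrics': [1.0]}.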
+
+
+@make_recursive_func
+def tensor2float(vars):
+ if isinstance(vars, float):
+ return vars
+ elif isinstance(vars, torch.Tensor):
+ return vars.data.item()
+ else:
+ raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))
+
+
+@make_recursive_func
+def tensor2numpy(vars):
+ if isinstance(vars, np.ndarray):
+ return vars
+ elif isinstance(vars, torch.Tensor):
+ return vars.detach().cpu().numpy().copy()
+ else:
+ raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))
+
+
+@make_recursive_func
+def numpy2tensor(vars, device='cpu'):
+ if not isinstance(vars, torch.Tensor) and vars is not None:
+ return torch.tensor(vars, device=device)
+ elif isinstance(vars, torch.Tensor):
+ return vars
+ elif vars is None:
+ return vars
+ else:
+ raise NotImplementedError("invalid input type {} for numpy2tensor".format(type(vars)))
+
+
+@make_recursive_func
+def tocuda(vars, device='cuda'):
+ if isinstance(vars, torch.Tensor):
+ return vars.to(device)
+ elif isinstance(vars, str):
+ return vars
+ else:
+ raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))
+
+
+import torch.distributed as dist
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def reduce_scalar_outputs(scalar_outputs):
+ world_size = get_world_size()
+ if world_size < 2:
+ return scalar_outputs
+ with torch.no_grad():
+ names = []
+ scalars = []
+ for k in sorted(scalar_outputs.keys()):
+ names.append(k)
+ if isinstance(scalar_outputs[k], torch.Tensor):
+ scalars.append(scalar_outputs[k])
+ else:
+ scalars.append(torch.tensor(scalar_outputs[k], device='cuda'))
+ scalars = torch.stack(scalars, dim=0)
+ dist.reduce(scalars, dst=0)
+ if dist.get_rank() == 0:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ scalars /= world_size
+ reduced_scalars = {k: v for k, v in zip(names, scalars)}
+
+ return reduced_scalars
diff --git a/SparseNeuS_demo_v1/val.ipynb b/SparseNeuS_demo_v1/val.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a39350692b1cc7e35754de19b4dae0277a959f2b
--- /dev/null
+++ b/SparseNeuS_demo_v1/val.ipynb
@@ -0,0 +1,951 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name gradio_tmp --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "\u001b[34mStore in: ../gradio_tmp\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:159 - __init__() ] Find checkpoint: ckpt_285000.pth\n",
+ "[exp_runner_generic_blender_val.py:500 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 285000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:579 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004603862762451172\n",
+ "export mesh time: 5.274312734603882\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import os \n",
+ "\n",
+ "dataset = 'gradio_tmp' # !!! the subfolder name in valpath for which you want to eval\n",
+ "# os.system('pwd')\n",
+ "\n",
+ "bash_script = f'CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name {dataset} --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue'\n",
+ "print(bash_script)\n",
+ "os.system(bash_script)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name eval_test/ebicycle2 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../eval_test/ebicycle2\n",
+ "\u001b[34mStore in: ../eval_test/ebicycle2\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_180000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 180000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004477500915527344\n",
+ "export mesh time: 5.30308723449707\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import os \n",
+ "\n",
+ "dataset = 'eval_test/ebicycle2' # !!! the subfolder name in valpath for which you want to eval\n",
+ "# os.system('pwd')\n",
+ "\n",
+ "bash_script = f'CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name {dataset} --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue'\n",
+ "print(bash_script)\n",
+ "os.system(bash_script)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 0%| | 0/20 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/bigmac --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/bigmac\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/bigmac\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00046324729919433594\n",
+ "export mesh time: 5.254480361938477\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 5%|▌ | 1/20 [00:09<03:05, 9.78s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/dinosaur --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/dinosaur\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/dinosaur\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004961490631103516\n",
+ "export mesh time: 5.2382142543792725\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 10%|█ | 2/20 [00:19<02:56, 9.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/goose_chef --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/goose_chef\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/goose_chef\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004622936248779297\n",
+ "export mesh time: 5.243659257888794\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 15%|█▌ | 3/20 [00:29<02:46, 9.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/stool2 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/stool2\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/stool2\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004367828369140625\n",
+ "export mesh time: 5.210400342941284\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 20%|██ | 4/20 [00:39<02:36, 9.78s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/nuts --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/nuts\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/nuts\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00047707557678222656\n",
+ "export mesh time: 5.257585763931274\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 25%|██▌ | 5/20 [00:48<02:27, 9.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/yellow_duck --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/yellow_duck\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/yellow_duck\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004456043243408203\n",
+ "export mesh time: 5.2359161376953125\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 30%|███ | 6/20 [00:58<02:17, 9.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/pineapple_bottle --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/pineapple_bottle\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/pineapple_bottle\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004482269287109375\n",
+ "export mesh time: 5.260643243789673\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 35%|███▌ | 7/20 [01:08<02:07, 9.82s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/pancake --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/pancake\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/pancake\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00047278404235839844\n",
+ "export mesh time: 5.221261739730835\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 40%|████ | 8/20 [01:18<01:57, 9.83s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/hydrant --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/hydrant\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/hydrant\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00044608116149902344\n",
+ "export mesh time: 5.220775604248047\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 45%|████▌ | 9/20 [01:28<01:47, 9.80s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/oreo_drink --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/oreo_drink\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/oreo_drink\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00043773651123046875\n",
+ "export mesh time: 5.22020149230957\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 50%|█████ | 10/20 [01:37<01:37, 9.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/scissor --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/scissor\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/scissor\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004305839538574219\n",
+ "export mesh time: 5.205048322677612\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 55%|█████▌ | 11/20 [01:47<01:28, 9.78s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/chocolatecake --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/chocolatecake\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/chocolatecake\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00043392181396484375\n",
+ "export mesh time: 5.239720344543457\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 60%|██████ | 12/20 [01:57<01:18, 9.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/mario --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/mario\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/mario\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004582405090332031\n",
+ "export mesh time: 5.255834102630615\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 65%|██████▌ | 13/20 [02:07<01:08, 9.79s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/broccoli --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/broccoli\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/broccoli\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0005385875701904297\n",
+ "export mesh time: 5.311108112335205\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 70%|███████ | 14/20 [02:17<00:58, 9.83s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/chair3 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/chair3\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/chair3\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004534721374511719\n",
+ "export mesh time: 5.227793216705322\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 75%|███████▌ | 15/20 [02:27<00:49, 9.82s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/lysol --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/lysol\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/lysol\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00045418739318847656\n",
+ "export mesh time: 5.2634806632995605\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 80%|████████ | 16/20 [02:36<00:39, 9.82s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/clock2 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/clock2\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/clock2\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00043654441833496094\n",
+ "export mesh time: 5.26728630065918\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 85%|████████▌ | 17/20 [02:46<00:29, 9.84s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/downy --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/downy\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/downy\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004367828369140625\n",
+ "export mesh time: 5.207308769226074\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 90%|█████████ | 18/20 [02:56<00:19, 9.82s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/ebicycle2 --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/ebicycle2\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/ebicycle2\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.00042724609375\n",
+ "export mesh time: 5.277446985244751\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 95%|█████████▌| 19/20 [03:06<00:09, 9.84s/it]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name mesh_pred/eval_real/strawberrycake --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue\n",
+ "\u001b[31mdetected 1 GPUs\u001b[0m\n",
+ "\u001b[33mbase_exp_dir: exp/lod0\u001b[0m\n",
+ "save mesh to: ../mesh_pred/eval_real/strawberrycake\n",
+ "\u001b[34mStore in: ../mesh_pred/eval_real/strawberrycake\u001b[0m\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/chao/anaconda3/envs/gradio/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)\n",
+ " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
+ "[exp_runner_generic_blender_val.py:160 - __init__() ] Find checkpoint: ckpt_190000.pth\n",
+ "[exp_runner_generic_blender_val.py:501 - load_checkpoint() ] End\n",
+ "ic| self.iter_step: 190000, idx: -1\n",
+ "[exp_runner_generic_blender_val.py:580 - export_mesh() ] Validate begin\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "time for getting data 0.0004417896270751953\n",
+ "export mesh time: 5.276712417602539\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 20/20 [03:16<00:00, 9.81s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os \n",
+ "from tqdm import tqdm\n",
+ "base_dir = \"mesh_pred/eval_real\"\n",
+ "for nickname in tqdm(os.listdir(\"../\"+base_dir)):\n",
+ " dataset = os.path.join(base_dir, nickname)\n",
+ " bash_script = f'CUDA_VISIBLE_DEVICES=3 python exp_runner_generic_blender_val.py --specific_dataset_name {dataset} --mode export_mesh --conf confs/one2345_lod0_val_demo.conf --is_continue'\n",
+ " print(bash_script)\n",
+ " os.system(bash_script)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 20/20 [00:00<00:00, 1310.90it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os \n",
+ "from tqdm import tqdm\n",
+ "from shutil import copyfile\n",
+ "base_dir = \"mesh_pred/eval_real\"\n",
+ "out_dir = f\"../{base_dir}_mesh\"\n",
+ "os.makedirs(out_dir, exist_ok=True)\n",
+ "for nickname in tqdm(os.listdir(\"../\"+base_dir)):\n",
+ " mesh_path = os.path.join(\"../\", base_dir, nickname, \"meshes_val_bg/lod0/mesh_00190000_gradio_lod0.ply\")\n",
+ " target_path = os.path.join(out_dir, nickname+\".ply\")\n",
+ " copyfile(mesh_path, target_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
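
The two notebook cells above shell out once per object with os.system and then copy each exported PLY into a flat folder. Below is a minimal equivalent sketch using subprocess.run, which fails loudly on a non-zero exit code instead of silently continuing; the directory layout (mesh_pred/eval_real, meshes_val_bg/lod0/mesh_00190000_gradio_lod0.ply) is taken from the cells above, everything else is an assumption:

import os
import subprocess
from shutil import copyfile

base_dir = "mesh_pred/eval_real"
out_dir = f"../{base_dir}_mesh"
os.makedirs(out_dir, exist_ok=True)

for nickname in sorted(os.listdir("../" + base_dir)):
    cmd = [
        "python", "exp_runner_generic_blender_val.py",
        "--specific_dataset_name", os.path.join(base_dir, nickname),
        "--mode", "export_mesh",
        "--conf", "confs/one2345_lod0_val_demo.conf",
        "--is_continue",
    ]
    # check=True raises CalledProcessError if the export fails for this object
    subprocess.run(cmd, env={**os.environ, "CUDA_VISIBLE_DEVICES": "3"}, check=True)
    mesh_path = os.path.join("..", base_dir, nickname,
                             "meshes_val_bg/lod0/mesh_00190000_gradio_lod0.ply")
    copyfile(mesh_path, os.path.join(out_dir, nickname + ".ply"))
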
diff --git a/SparseNeuS_demo_v1/weights/ckpt.pth b/SparseNeuS_demo_v1/weights/ckpt.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ea22ffa970c253e2f1d6cccbe195f703027264f6
--- /dev/null
+++ b/SparseNeuS_demo_v1/weights/ckpt.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee9a0027096b3f4f304e2801ebe41545241f974f7d812dc802ac70c8aeeab2b2
+size 6859767
diff --git a/configs/sd-objaverse-finetune-c_concat-256.yaml b/configs/sd-objaverse-finetune-c_concat-256.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..488dafa27fcd632215ab869f9ab15c8ed452b66a
--- /dev/null
+++ b/configs/sd-objaverse-finetune-c_concat-256.yaml
@@ -0,0 +1,117 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "image_target"
+ cond_stage_key: "image_cond"
+ image_size: 32
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: hybrid
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+
+    scheduler_config: # 100 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 100 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
+
+
+data:
+ target: ldm.data.simple.ObjaverseDataModuleFromConfig
+ params:
+ root_dir: 'views_whole_sphere'
+ batch_size: 192
+ num_workers: 16
+ total_view: 4
+ train:
+ validation: False
+ image_transforms:
+ size: 256
+
+ validation:
+ validation: True
+ image_transforms:
+ size: 256
+
+
+lightning:
+ find_unused_parameters: false
+ metrics_over_trainsteps_checkpoint: True
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+ callbacks:
+ image_logger:
+ target: main.ImageLogger
+ params:
+ batch_frequency: 500
+ max_images: 32
+ increase_log_steps: False
+ log_first_step: True
+ log_images_kwargs:
+ use_ema_scope: False
+ inpaint: False
+ plot_progressive_rows: False
+ plot_diffusion_rows: False
+ N: 32
+ unconditional_guidance_scale: 3.0
+ unconditional_guidance_label: [""]
+
+ trainer:
+ benchmark: True
+    val_check_interval: 5000000 # effectively disables periodic validation during finetuning
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
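
For reference, a minimal sketch of inspecting this config with OmegaConf; how the training script actually consumes it is not shown in this diff, so the loading path is an assumption:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sd-objaverse-finetune-c_concat-256.yaml")
unet = cfg.model.params.unet_config.params
# 8 UNet input channels: 4 noisy latent channels + 4 concatenated conditioning channels (conditioning_key: hybrid)
print(unet.in_channels, cfg.model.params.channels)              # 8 4
print(cfg.model.params.conditioning_key)                        # hybrid (concat + cross-attention)
print(cfg.data.params.batch_size, cfg.data.params.total_view)   # 192 4
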
diff --git a/one2345_elev_est/.idea/.gitignore b/one2345_elev_est/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b58b603fea78041071d125a30db58d79b3d49217
--- /dev/null
+++ b/one2345_elev_est/.idea/.gitignore
@@ -0,0 +1,5 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/one2345_elev_est/.idea/inspectionProfiles/Project_Default.xml b/one2345_elev_est/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..2be94a6daed0ec89aa028bac406adc51b1b53d89
--- /dev/null
+++ b/one2345_elev_est/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml b/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/one2345_elev_est/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/one2345_elev_est/.idea/misc.xml b/one2345_elev_est/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..d56657add3eb3c246989284ec6e6a8475603cf1d
--- /dev/null
+++ b/one2345_elev_est/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/one2345_elev_est/.idea/modules.xml b/one2345_elev_est/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c835d1a7ad8ef9f9ef336501b88a5c38eac2dd86
--- /dev/null
+++ b/one2345_elev_est/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/one2345_elev_est/.idea/one2345_elev_est.iml b/one2345_elev_est/.idea/one2345_elev_est.iml
new file mode 100644
index 0000000000000000000000000000000000000000..870ae2cd58bfdf63caf9d20b67f1b848bad7aabe
--- /dev/null
+++ b/one2345_elev_est/.idea/one2345_elev_est.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/one2345_elev_est/install.sh b/one2345_elev_est/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..27ad025b471b9368a059759d92501730d9f14cd2
--- /dev/null
+++ b/one2345_elev_est/install.sh
@@ -0,0 +1 @@
+python setup.py build develop
diff --git a/one2345_elev_est/oee/models/loftr/__init__.py b/one2345_elev_est/oee/models/loftr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d69b9c131cf41e95c5c6ee7d389b375267b22fa
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/__init__.py
@@ -0,0 +1,2 @@
+from .loftr import LoFTR
+from .utils.cvpr_ds_config import default_cfg
diff --git a/one2345_elev_est/oee/models/loftr/backbone/__init__.py b/one2345_elev_est/oee/models/loftr/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/backbone/__init__.py
@@ -0,0 +1,11 @@
+from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
+
+
+def build_backbone(config):
+ if config['backbone_type'] == 'ResNetFPN':
+ if config['resolution'] == (8, 2):
+ return ResNetFPN_8_2(config['resnetfpn'])
+ elif config['resolution'] == (16, 4):
+ return ResNetFPN_16_4(config['resnetfpn'])
+ else:
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
diff --git a/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py
@@ -0,0 +1,199 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution without padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ def __init__(self, in_planes, planes, stride=1):
+ super().__init__()
+ self.conv1 = conv3x3(in_planes, planes, stride)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(
+ conv1x1(in_planes, planes, stride=stride),
+ nn.BatchNorm2d(planes)
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.bn1(self.conv1(y)))
+ y = self.bn2(self.conv2(y))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ResNetFPN_8_2(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/8 and 1/2.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+
+ # 3. FPN upsample
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
+ self.layer1_outconv2 = nn.Sequential(
+ conv3x3(block_dims[1], block_dims[1]),
+ nn.BatchNorm2d(block_dims[1]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[1], block_dims[0]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+
+ # FPN
+ x3_out = self.layer3_outconv(x3)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x1_out = self.layer1_outconv(x1)
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
+
+ return [x3_out, x1_out]
+
+
+class ResNetFPN_16_4(nn.Module):
+ """
+    ResNet+FPN, output resolutions are 1/16 and 1/4.
+ Each block has 2 layers.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ # Config
+ block = BasicBlock
+ initial_dim = config['initial_dim']
+ block_dims = config['block_dims']
+
+ # Class Variable
+ self.in_planes = initial_dim
+
+ # Networks
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(initial_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
+
+ # 3. FPN upsample
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
+ self.layer3_outconv2 = nn.Sequential(
+ conv3x3(block_dims[3], block_dims[3]),
+ nn.BatchNorm2d(block_dims[3]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[3], block_dims[2]),
+ )
+
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
+ self.layer2_outconv2 = nn.Sequential(
+ conv3x3(block_dims[2], block_dims[2]),
+ nn.BatchNorm2d(block_dims[2]),
+ nn.LeakyReLU(),
+ conv3x3(block_dims[2], block_dims[1]),
+ )
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, dim, stride=1):
+ layer1 = block(self.in_planes, dim, stride=stride)
+ layer2 = block(dim, dim, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ # ResNet Backbone
+ x0 = self.relu(self.bn1(self.conv1(x)))
+ x1 = self.layer1(x0) # 1/2
+ x2 = self.layer2(x1) # 1/4
+ x3 = self.layer3(x2) # 1/8
+ x4 = self.layer4(x3) # 1/16
+
+ # FPN
+ x4_out = self.layer4_outconv(x4)
+
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x3_out = self.layer3_outconv(x3)
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
+
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
+ x2_out = self.layer2_outconv(x2)
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
+
+ return [x4_out, x2_out]
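
A small shape check for the 1/8–1/2 variant, with the dimensions used by cvpr_ds_config.py later in this diff; the import path assumes the oee package has been installed via install.sh above:

import torch
from oee.models.loftr.backbone.resnet_fpn import ResNetFPN_8_2

config = {'initial_dim': 128, 'block_dims': [128, 196, 256]}   # matches _CN.RESNETFPN below
net = ResNetFPN_8_2(config).eval()
x = torch.randn(1, 1, 480, 640)        # LoFTR takes single-channel (grayscale) images
with torch.no_grad():
    coarse, fine = net(x)
print(coarse.shape)                    # torch.Size([1, 256, 60, 80])   -- 1/8 resolution
print(fine.shape)                      # torch.Size([1, 128, 240, 320]) -- 1/2 resolution
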
diff --git a/one2345_elev_est/oee/models/loftr/loftr.py b/one2345_elev_est/oee/models/loftr/loftr.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c491ee47a4d67cb8b3fe493397349e0867accd
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/loftr.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+from .utils.position_encoding import PositionEncodingSine
+from .loftr_module import LocalFeatureTransformer, FinePreprocess
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+
+
+class LoFTR(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ # Misc
+ self.config = config
+
+ # Modules
+ self.backbone = build_backbone(config)
+ self.pos_encoding = PositionEncodingSine(
+ config['coarse']['d_model'],
+ temp_bug_fix=config['coarse']['temp_bug_fix'])
+ self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
+ self.fine_preprocess = FinePreprocess(config)
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
+ self.fine_matching = FineMatching()
+
+ def forward(self, data):
+ """
+ Update:
+ data (dict): {
+ 'image0': (torch.Tensor): (N, 1, H, W)
+ 'image1': (torch.Tensor): (N, 1, H, W)
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
+ }
+ """
+ # 1. Local Feature CNN
+ data.update({
+ 'bs': data['image0'].size(0),
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+ })
+
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
+ feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
+ else: # handle different input shapes
+ (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
+
+ data.update({
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+ })
+
+ # 2. coarse-level loftr module
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
+
+ mask_c0 = mask_c1 = None # mask is useful in training
+ if 'mask0' in data:
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+ # 3. match coarse-level
+ self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
+
+ # 4. fine-level refinement
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+
+ # 5. match fine-level
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+ def load_state_dict(self, state_dict, *args, **kwargs):
+ for k in list(state_dict.keys()):
+ if k.startswith('matcher.'):
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+ return super().load_state_dict(state_dict, *args, **kwargs)
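
Putting the pieces together, the matcher is used roughly as below; the checkpoint name is a placeholder, and the remaining utilities (position encoding, matching modules) are assumed to be importable as laid out elsewhere in this diff:

import torch
from oee.models.loftr import LoFTR, default_cfg

matcher = LoFTR(config=default_cfg)
# the load_state_dict override above strips a leading 'matcher.' from Lightning-style keys:
# matcher.load_state_dict(torch.load("weights/loftr_ckpt.ckpt")['state_dict'])
matcher.eval()

data = {'image0': torch.randn(1, 1, 480, 640),   # grayscale, [N, 1, H, W]
        'image1': torch.randn(1, 1, 480, 640)}
with torch.no_grad():
    matcher(data)                                # results are written back into `data`
print(data['mkpts0_c'].shape, data['mkpts1_c'].shape, data['mconf'].shape)
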
diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import LocalFeatureTransformer
+from .fine_preprocess import FinePreprocess
diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+
+class FinePreprocess(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.cat_c_feat = config['fine_concat_coarse_feat']
+ self.W = self.config['fine_window_size']
+
+ d_model_c = self.config['coarse']['d_model']
+ d_model_f = self.config['fine']['d_model']
+ self.d_model_f = d_model_f
+ if self.cat_c_feat:
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
+
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
+ W = self.W
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+ data.update({'W': W})
+ if data['b_ids'].shape[0] == 0:
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+ return feat0, feat1
+
+ # 1. unfold(crop) all local windows
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+ # 2. select only the predicted matches
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+ # option: use coarse-level loftr feature as context: concat and linear
+ if self.cat_c_feat:
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
+ feat_cf_win = self.merge_feat(torch.cat([
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
+ ], -1))
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+ return feat_f0_unfold, feat_f1_unfold
diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py
@@ -0,0 +1,81 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module, Dropout
+
+
+def elu_feature_map(x):
+ return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ # set padded position to zero
+ if q_mask is not None:
+ Q = Q * q_mask[:, :, None, None]
+ if kv_mask is not None:
+ K = K * kv_mask[:, :, None, None]
+ values = values * kv_mask[:, :, None, None]
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
+
+ # Compute the attention and the weighted average
+        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ return queried_values.contiguous()
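
Both attention variants take the same [N, L, H, D] tensors and return the same shape; the linear variant never materialises the L x S attention matrix, so memory grows linearly in sequence length. A stand-alone shape check with hypothetical sizes:

import torch
from oee.models.loftr.loftr_module.linear_attention import LinearAttention, FullAttention

q = torch.randn(1, 1024, 8, 32)            # [N, L, H, D]
k = torch.randn(1, 1024, 8, 32)            # [N, S, H, D]
v = torch.randn(1, 1024, 8, 32)
out_linear = LinearAttention()(q, k, v)    # O(L * D^2) per head
out_full = FullAttention()(q, k, v)        # O(L * S * D) per head, builds an L x S matrix
print(out_linear.shape, out_full.shape)    # torch.Size([1, 1024, 8, 32]) twice
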
diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d79390ca08953bbef44e98149e662a681a16e42e
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py
@@ -0,0 +1,101 @@
+import copy
+import torch
+import torch.nn as nn
+from .linear_attention import LinearAttention, FullAttention
+
+
+class LoFTREncoderLayer(nn.Module):
+ def __init__(self,
+ d_model,
+ nhead,
+ attention='linear'):
+ super(LoFTREncoderLayer, self).__init__()
+
+ self.dim = d_model // nhead
+ self.nhead = nhead
+
+ # multi-head attention
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+ self.merge = nn.Linear(d_model, d_model, bias=False)
+
+ # feed-forward network
+ self.mlp = nn.Sequential(
+ nn.Linear(d_model*2, d_model*2, bias=False),
+ nn.ReLU(True),
+ nn.Linear(d_model*2, d_model, bias=False),
+ )
+
+ # norm and dropout
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ def forward(self, x, source, x_mask=None, source_mask=None):
+ """
+ Args:
+ x (torch.Tensor): [N, L, C]
+ source (torch.Tensor): [N, S, C]
+ x_mask (torch.Tensor): [N, L] (optional)
+ source_mask (torch.Tensor): [N, S] (optional)
+ """
+ bs = x.size(0)
+ query, key, value = x, source, source
+
+ # multi-head attention
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
+ message = self.norm1(message)
+
+ # feed-forward network
+ message = self.mlp(torch.cat([x, message], dim=2))
+ message = self.norm2(message)
+
+ return x + message
+
+
+class LocalFeatureTransformer(nn.Module):
+ """A Local Feature Transformer (LoFTR) module."""
+
+ def __init__(self, config):
+ super(LocalFeatureTransformer, self).__init__()
+
+ self.config = config
+ self.d_model = config['d_model']
+ self.nhead = config['nhead']
+ self.layer_names = config['layer_names']
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ mask0 (torch.Tensor): [N, L] (optional)
+ mask1 (torch.Tensor): [N, S] (optional)
+ """
+
+        assert self.d_model == feat0.size(2), "the feature dimension of the inputs must equal the transformer d_model"
+
+ for layer, name in zip(self.layers, self.layer_names):
+ if name == 'self':
+ feat0 = layer(feat0, feat0, mask0, mask0)
+ feat1 = layer(feat1, feat1, mask1, mask1)
+ elif name == 'cross':
+ feat0 = layer(feat0, feat1, mask0, mask1)
+ feat1 = layer(feat1, feat0, mask1, mask0)
+ else:
+ raise KeyError
+
+ return feat0, feat1
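
A LocalFeatureTransformer configured like the coarse stage simply alternates self- and cross-attention over the two token sequences and keeps their shapes; a minimal sketch using the coarse settings from cvpr_ds_config.py below:

import torch
from oee.models.loftr.loftr_module import LocalFeatureTransformer

coarse_cfg = {'d_model': 256, 'nhead': 8,
              'layer_names': ['self', 'cross'] * 4, 'attention': 'linear'}
loftr_coarse = LocalFeatureTransformer(coarse_cfg).eval()
feat0 = torch.randn(1, 4800, 256)          # 60*80 coarse tokens of image 0
feat1 = torch.randn(1, 4800, 256)
with torch.no_grad():
    feat0, feat1 = loftr_coarse(feat0, feat1)
print(feat0.shape, feat1.shape)            # both stay [1, 4800, 256]
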
diff --git a/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..a97263339462dec3af9705d33d6ee634e2f46914
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py
@@ -0,0 +1,261 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+INF = 1e9
+
+def mask_border(m, b: int, v):
+ """ Mask borders with value
+ Args:
+ m (torch.Tensor): [N, H0, W0, H1, W1]
+ b (int)
+ v (m.dtype)
+ """
+ if b <= 0:
+ return
+
+ m[:, :b] = v
+ m[:, :, :b] = v
+ m[:, :, :, :b] = v
+ m[:, :, :, :, :b] = v
+ m[:, -b:] = v
+ m[:, :, -b:] = v
+ m[:, :, :, -b:] = v
+ m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+ if bd <= 0:
+ return
+
+ m[:, :bd] = v
+ m[:, :, :bd] = v
+ m[:, :, :, :bd] = v
+ m[:, :, :, :, :bd] = v
+
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+ m[b_idx, h0 - bd:] = v
+ m[b_idx, :, w0 - bd:] = v
+ m[b_idx, :, :, h1 - bd:] = v
+ m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+ """Compute the max candidates of all pairs within a batch
+
+ Args:
+ p_m0, p_m1 (torch.Tensor): padded masks
+ """
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+ max_cand = torch.sum(
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+ return max_cand
+
+
+class CoarseMatching(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ # general config
+ self.thr = config['thr']
+ self.border_rm = config['border_rm']
+        # -- # for training fine-level LoFTR
+ self.train_coarse_percent = config['train_coarse_percent']
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+ # we provide 2 options for differentiable matching
+ self.match_type = config['match_type']
+ if self.match_type == 'dual_softmax':
+ self.temperature = config['dsmax_temperature']
+ elif self.match_type == 'sinkhorn':
+ try:
+ from .superglue import log_optimal_transport
+ except ImportError:
+ raise ImportError("download superglue.py first!")
+ self.log_optimal_transport = log_optimal_transport
+ self.bin_score = nn.Parameter(
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
+ self.skh_iters = config['skh_iters']
+ self.skh_prefilter = config['skh_prefilter']
+ else:
+ raise NotImplementedError()
+
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
+ """
+ Args:
+ feat0 (torch.Tensor): [N, L, C]
+ feat1 (torch.Tensor): [N, S, C]
+ data (dict)
+ mask_c0 (torch.Tensor): [N, L] (optional)
+ mask_c1 (torch.Tensor): [N, S] (optional)
+ Update:
+ data (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ NOTE: M' != M during training.
+ """
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
+
+ # normalize
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+ [feat_c0, feat_c1])
+
+ if self.match_type == 'dual_softmax':
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+ feat_c1) / self.temperature
+ if mask_c0 is not None:
+ sim_matrix.masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF)
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+
+ elif self.match_type == 'sinkhorn':
+ # sinkhorn, dustbin included
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
+ if mask_c0 is not None:
+ sim_matrix[:, :L, :S].masked_fill_(
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+ -INF)
+
+ # build uniform prior & use sinkhorn
+ log_assign_matrix = self.log_optimal_transport(
+ sim_matrix, self.bin_score, self.skh_iters)
+ assign_matrix = log_assign_matrix.exp()
+ conf_matrix = assign_matrix[:, :-1, :-1]
+
+ # filter prediction with dustbin score (only in evaluation mode)
+ if not self.training and self.skh_prefilter:
+ filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
+ filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
+ conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
+ conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
+
+ if self.config['sparse_spvs']:
+ data.update({'conf_matrix_with_bin': assign_matrix.clone()})
+
+ data.update({'conf_matrix': conf_matrix})
+
+ # predict coarse matches from conf_matrix
+ data.update(**self.get_coarse_match(conf_matrix, data))
+
+ @torch.no_grad()
+ def get_coarse_match(self, conf_matrix, data):
+ """
+ Args:
+ conf_matrix (torch.Tensor): [N, L, S]
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+ Returns:
+ coarse_matches (dict): {
+ 'b_ids' (torch.Tensor): [M'],
+ 'i_ids' (torch.Tensor): [M'],
+ 'j_ids' (torch.Tensor): [M'],
+ 'gt_mask' (torch.Tensor): [M'],
+ 'm_bids' (torch.Tensor): [M],
+ 'mkpts0_c' (torch.Tensor): [M, 2],
+ 'mkpts1_c' (torch.Tensor): [M, 2],
+ 'mconf' (torch.Tensor): [M]}
+ """
+ axes_lengths = {
+ 'h0c': data['hw0_c'][0],
+ 'w0c': data['hw0_c'][1],
+ 'h1c': data['hw1_c'][0],
+ 'w1c': data['hw1_c'][1]
+ }
+ _device = conf_matrix.device
+ # 1. confidence thresholding
+ mask = conf_matrix > self.thr
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+ **axes_lengths)
+ if 'mask0' not in data:
+ mask_border(mask, self.border_rm, False)
+ else:
+ mask_border_with_padding(mask, self.border_rm, False,
+ data['mask0'], data['mask1'])
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+ **axes_lengths)
+
+ # 2. mutual nearest
+ mask = mask \
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+ # 3. find all valid coarse matches
+ # this only works when at most one `True` in each row
+ mask_v, all_j_ids = mask.max(dim=2)
+ b_ids, i_ids = torch.where(mask_v)
+ j_ids = all_j_ids[b_ids, i_ids]
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+ # 4. Random sampling of training samples for fine-level LoFTR
+ # (optional) pad samples with gt coarse-level matches
+ if self.training:
+ # NOTE:
+ # The sampling is performed across all pairs in a batch without manually balancing
+ # #samples for fine-level increases w.r.t. batch_size
+ if 'mask0' not in data:
+ num_candidates_max = mask.size(0) * max(
+ mask.size(1), mask.size(2))
+ else:
+ num_candidates_max = compute_max_candidates(
+ data['mask0'], data['mask1'])
+ num_matches_train = int(num_candidates_max *
+ self.train_coarse_percent)
+ num_matches_pred = len(b_ids)
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
+
+ # pred_indices is to select from prediction
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+ pred_indices = torch.arange(num_matches_pred, device=_device)
+ else:
+ pred_indices = torch.randint(
+ num_matches_pred,
+ (num_matches_train - self.train_pad_num_gt_min, ),
+ device=_device)
+
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+ gt_pad_indices = torch.randint(
+ len(data['spv_b_ids']),
+ (max(num_matches_train - num_matches_pred,
+ self.train_pad_num_gt_min), ),
+ device=_device)
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
+
+ b_ids, i_ids, j_ids, mconf = map(
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+ dim=0),
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+ # These matches select patches that feed into fine-level network
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+        # 5. Update with matches in original image resolution
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+ mkpts0_c = torch.stack(
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+ dim=1) * scale0
+ mkpts1_c = torch.stack(
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+ dim=1) * scale1
+
+        # These matches are the current prediction (for visualization)
+ coarse_matches.update({
+ 'gt_mask': mconf == 0,
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
+ 'mkpts0_c': mkpts0_c[mconf != 0],
+ 'mkpts1_c': mkpts1_c[mconf != 0],
+ 'mconf': mconf[mconf != 0]
+ })
+
+ return coarse_matches
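
The core of get_coarse_match is a confidence threshold plus a mutual-nearest-neighbour check on the confidence matrix; a toy illustration with hand-picked, hypothetical values, mirroring the masking expression above:

import torch

conf = torch.tensor([[[0.90, 0.05, 0.05],
                      [0.10, 0.70, 0.20],
                      [0.30, 0.30, 0.40]]])          # [N=1, L=3, S=3]
thr = 0.5
mask = (conf > thr) \
    * (conf == conf.max(dim=2, keepdim=True)[0]) \
    * (conf == conf.max(dim=1, keepdim=True)[0])     # threshold + mutual nearest neighbour
mask_v, all_j_ids = mask.max(dim=2)
b_ids, i_ids = torch.where(mask_v)
j_ids = all_j_ids[b_ids, i_ids]
print(i_ids.tolist(), j_ids.tolist())                # [0, 1] [0, 1] -- row 2 fails the threshold
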
diff --git a/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9ce70154d3a1b961d3b4f08897415720f451f8
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py
@@ -0,0 +1,50 @@
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+ if not isinstance(yacs_cfg, CN):
+ return yacs_cfg
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+_CN = CN()
+_CN.BACKBONE_TYPE = 'ResNetFPN'
+_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
+_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
+_CN.FINE_CONCAT_COARSE_FEAT = True
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.RESNETFPN = CN()
+_CN.RESNETFPN.INITIAL_DIM = 128
+_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.COARSE = CN()
+_CN.COARSE.D_MODEL = 256
+_CN.COARSE.D_FFN = 256
+_CN.COARSE.NHEAD = 8
+_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
+_CN.COARSE.TEMP_BUG_FIX = False
+
+# 3. Coarse-Matching config
+_CN.MATCH_COARSE = CN()
+_CN.MATCH_COARSE.THR = 0.2
+_CN.MATCH_COARSE.BORDER_RM = 2
+_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn']
+_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MATCH_COARSE.SKH_ITERS = 3
+_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
+_CN.MATCH_COARSE.SKH_PREFILTER = True
+_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
+_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
+
+# 4. LoFTR-fine module config
+_CN.FINE = CN()
+_CN.FINE.D_MODEL = 128
+_CN.FINE.D_FFN = 128
+_CN.FINE.NHEAD = 8
+_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
+_CN.FINE.ATTENTION = 'linear'
+
+default_cfg = lower_config(_CN)
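+# `lower_config` recursively converts the CfgNode into a plain nested dict with
+# lower-cased keys, so, for example, default_cfg['match_coarse']['thr'] holds the
+# 0.2 threshold set above.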
diff --git a/one2345_elev_est/oee/models/loftr/utils/fine_matching.py b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py
@@ -0,0 +1,74 @@
+import math
+import torch
+import torch.nn as nn
+
+from kornia.geometry.subpix import dsnt
+from kornia.utils.grid import create_meshgrid
+
+
+class FineMatching(nn.Module):
+ """FineMatching with s2d paradigm"""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, feat_f0, feat_f1, data):
+ """
+ Args:
+ feat0 (torch.Tensor): [M, WW, C]
+ feat1 (torch.Tensor): [M, WW, C]
+ data (dict)
+ Update:
+ data (dict):{
+ 'expec_f' (torch.Tensor): [M, 3],
+ 'mkpts0_f' (torch.Tensor): [M, 2],
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
+ """
+ M, WW, C = feat_f0.shape
+ W = int(math.sqrt(WW))
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+ # corner case: if no coarse matches found
+ if M == 0:
+ assert not self.training, "M is always > 0 during training, see coarse_matching.py"
+ # logger.warning('No matches found in coarse-level.')
+ data.update({
+ 'expec_f': torch.empty(0, 3, device=feat_f0.device),
+ 'mkpts0_f': data['mkpts0_c'],
+ 'mkpts1_f': data['mkpts1_c'],
+ })
+ return
+
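+ # s2d matching: take the center feature of window 0, correlate it with all WW
+ # features of window 1, and softmax the similarities into a W x W heatmap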
+ feat_f0_picked = feat_f0[:, WW//2, :]
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+ softmax_temp = 1. / C**.5
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
+
+ # compute coordinates from heatmap
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
+
+ # compute the std of the predicted coordinates over <x, y>
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
+
+ # for fine-level supervision
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+
+ # compute absolute kpt coords
+ self.get_fine_match(coords_normalized, data)
+
+ @torch.no_grad()
+ def get_fine_match(self, coords_normed, data):
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+ # mkpts0_f stays at the coarse location; only mkpts1_f is refined by the predicted offset
+ mkpts0_f = data['mkpts0_c']
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
+
+ data.update({
+ "mkpts0_f": mkpts0_f,
+ "mkpts1_f": mkpts1_f
+ })
diff --git a/one2345_elev_est/oee/models/loftr/utils/geometry.py b/one2345_elev_est/oee/models/loftr/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/geometry.py
@@ -0,0 +1,54 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if relative error < 0.2 (hard-coded).
+
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ kpts0_long = kpts0.round().long()
+
+ # Sample depth, get calculable_mask on depth != 0
+ kpts0_depth = torch.stack(
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+ ) # (N, L)
+ nonzero_mask = kpts0_depth != 0
+
+ # Unproject
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+ w_kpts0_long = w_kpts0.long()
+ w_kpts0_long[~covisible_mask, :] = 0
+
+ w_kpts0_depth = torch.stack(
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+ ) # (N, L)
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+ return valid_mask, w_kpts0
diff --git a/one2345_elev_est/oee/models/loftr/utils/position_encoding.py b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..732d28c814ef93bf48d338ba7554f6dcfc3b880e
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py
@@ -0,0 +1,42 @@
+import math
+import torch
+from torch import nn
+
+
+class PositionEncodingSine(nn.Module):
+ """
+ A sinusoidal position encoding generalized to 2-dimensional images.
+ """
+
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
+ """
+ Args:
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
+ on the final performance. For now, we keep both impls for backward compatibility.
+ We will remove the buggy impl after re-training all variants of our released models.
+ """
+ super().__init__()
+
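+ # interleave sin/cos encodings: channels 0::4 and 1::4 encode x, 2::4 and 3::4 encode y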
+ pe = torch.zeros((d_model, *max_shape))
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
+ if temp_bug_fix:
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+ else: # a buggy implementation (for backward compatibility only)
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
+
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
+
+ def forward(self, x):
+ """
+ Args:
+ x: [N, C, H, W]
+ """
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
diff --git a/one2345_elev_est/oee/models/loftr/utils/supervision.py b/one2345_elev_est/oee/models/loftr/utils/supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce6e79ec72b45fcb6b187e33bda93a47b168acd
--- /dev/null
+++ b/one2345_elev_est/oee/models/loftr/utils/supervision.py
@@ -0,0 +1,151 @@
+from math import log
+from loguru import logger
+
+import torch
+from einops import repeat
+from kornia.utils import create_meshgrid
+
+from .geometry import warp_kpts
+
+############## ↓ Coarse-Level supervision ↓ ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+ """For megadepth dataset, zero-padding exists in images"""
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+ grid_pt[~mask.bool()] = 0
+ return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+ """
+ Update:
+ data (dict): {
+ "conf_matrix_gt": [N, hw0, hw1],
+ 'spv_b_ids': [M]
+ 'spv_i_ids': [M]
+ 'spv_j_ids': [M]
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
+ }
+
+ NOTE:
+ - for the scannet dataset, there are 3 kinds of resolution {i, c, f}
+ - for the megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+ """
+ # 1. misc
+ device = data['image0'].device
+ N, _, H0, W0 = data['image0'].shape
+ _, _, H1, W1 = data['image1'].shape
+ scale = config['LOFTR']['RESOLUTION'][0]
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+ # 2. warp grids
+ # create kpts in meshgrid and resize them to image resolution
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
+ grid_pt0_i = scale0 * grid_pt0_c
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+ grid_pt1_i = scale1 * grid_pt1_c
+
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+ if 'mask0' in data:
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+ # warp kpts bi-directionally and resize them to coarse-level resolution
+ # (no depth consistency check, since it leads to worse results experimentally)
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+ w_pt0_c = w_pt0_i / scale1
+ w_pt1_c = w_pt1_i / scale0
+
+ # 3. check if mutual nearest neighbor
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
+
+ # corner case: out of boundary
+ def out_bound_mask(pt, w, h):
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
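+ # mutual nearest neighbor check: a coarse cell in image0 counts as a gt match only
+ # if warping 0->1 and then 1->0 lands back on the same cell index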
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+ correct_0to1[:, 0] = False # ignore the top-left corner
+
+ # 4. construct a gt conf_matrix
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
+ j_ids = nearest_index1[b_ids, i_ids]
+
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+ data.update({'conf_matrix_gt': conf_matrix_gt})
+
+ # 5. save coarse matches(gt) for training fine level
+ if len(b_ids) == 0:
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+ # this won't affect fine-level loss calculation
+ b_ids = torch.tensor([0], device=device)
+ i_ids = torch.tensor([0], device=device)
+ j_ids = torch.tensor([0], device=device)
+
+ data.update({
+ 'spv_b_ids': b_ids,
+ 'spv_i_ids': i_ids,
+ 'spv_j_ids': j_ids
+ })
+
+ # 6. save intermediate results (for fast fine-level computation)
+ data.update({
+ 'spv_w_pt0_i': w_pt0_i,
+ 'spv_pt1_i': grid_pt1_i
+ })
+
+
+def compute_supervision_coarse(data, config):
+ assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_coarse(data, config)
+ else:
+ raise ValueError(f'Unknown data source: {data_source}')
+
+
+############## ↓ Fine-Level supervision ↓ ##############
+
+@torch.no_grad()
+def spvs_fine(data, config):
+ """
+ Update:
+ data (dict):{
+ "expec_f_gt": [M, 2]}
+ """
+ # 1. misc
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
+ w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
+ scale = config['LOFTR']['RESOLUTION'][1]
+ radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
+
+ # 2. get coarse prediction
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+ # 3. compute gt
+ scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+ # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
+ expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
+ data.update({"expec_f_gt": expec_f_gt})
+
+
+def compute_supervision_fine(data, config):
+ data_source = data['dataset_name'][0]
+ if data_source.lower() in ['scannet', 'megadepth']:
+ spvs_fine(data, config)
+ else:
+ raise NotImplementedError
diff --git a/one2345_elev_est/oee/utils/elev_est_api.py b/one2345_elev_est/oee/utils/elev_est_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a82345c02eae79e19e450c7d4583467da153f501
--- /dev/null
+++ b/one2345_elev_est/oee/utils/elev_est_api.py
@@ -0,0 +1,206 @@
+import matplotlib.pyplot as plt
+import warnings
+
+import numpy as np
+import cv2
+import os
+import os.path as osp
+import imageio
+from copy import deepcopy
+
+import loguru
+import torch
+from oee.models.loftr import LoFTR, default_cfg
+import matplotlib.cm as cm
+
+from oee.utils import plt_utils
+from oee.utils.plotting import make_matching_figure
+from oee.utils.utils3d import rect_to_img, canonical_to_camera, calc_pose
+
+
+class ElevEstHelper:
+ _feature_matcher = None
+
+ @classmethod
+ def get_feature_matcher(cls):
+ if cls._feature_matcher is None:
+ loguru.logger.info("Loading feature matcher...")
+ _default_cfg = deepcopy(default_cfg)
+ _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt
+ matcher = LoFTR(config=_default_cfg)
+ ckpt_path = "weights/indoor_ds_new.ckpt"
+ if not osp.exists(ckpt_path):
+ loguru.logger.info("Downloading feature matcher...")
+ os.makedirs("weights", exist_ok=True)
+ import gdown
+ gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS",
+ path=ckpt_path)
+ matcher.load_state_dict(torch.load(ckpt_path)['state_dict'])
+ matcher = matcher.eval().cuda()
+ cls._feature_matcher = matcher
+ return cls._feature_matcher
+
+
+def mask_out_bkgd(img_path, dbg=False):
+ img = imageio.imread_v2(img_path)
+ if img.shape[-1] == 4:
+ fg_mask = img[:, :, 3] > 0 # use the alpha channel as the foreground mask
+ else:
+ loguru.logger.info("Image has no alpha channel, using thresholding to mask out background")
+ fg_mask = ~(img > 245).all(axis=-1)
+ if dbg:
+ plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0]))
+ plt.show()
+ return fg_mask
+
+
+def get_feature_matching(img_paths, dbg=False):
+ assert len(img_paths) == 4
+ matcher = ElevEstHelper.get_feature_matcher()
+ feature_matching = {}
+ masks = []
+ for i in range(4):
+ mask = mask_out_bkgd(img_paths[i], dbg=dbg)
+ masks.append(mask)
+ for i in range(0, 4):
+ for j in range(i + 1, 4):
+ img0_pth = img_paths[i]
+ img1_pth = img_paths[j]
+ mask0 = masks[i]
+ mask1 = masks[j]
+ img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
+ img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
+ original_shape = img0_raw.shape
+ img0_raw_resized = cv2.resize(img0_raw, (480, 480))
+ img1_raw_resized = cv2.resize(img1_raw, (480, 480))
+
+ img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255.
+ img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255.
+ batch = {'image0': img0, 'image1': img1}
+
+ # Inference with LoFTR and get prediction
+ with torch.no_grad():
+ matcher(batch)
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
+ mconf = batch['mconf'].cpu().numpy()
+ mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480
+ mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480
+ mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480
+ mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480
+ keep0 = mask0[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]
+ keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]
+ keep = np.logical_and(keep0, keep1)
+ mkpts0 = mkpts0[keep]
+ mkpts1 = mkpts1[keep]
+ mconf = mconf[keep]
+ if dbg:
+ # Draw visualization
+ color = cm.jet(mconf)
+ text = [
+ 'LoFTR',
+ 'Matches: {}'.format(len(mkpts0)),
+ ]
+ fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)
+ fig.show()
+ feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1)
+
+ return feature_matching
+
+
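+# Build 4 candidate camera poses (assumed to correspond to the 4 input views) for a
+# given elevation hypothesis: two views offset by -10/+10 deg in elevation at the
+# base azimuth of 30 deg, and two at the base elevation with azimuths of 20 and 40
+# deg; the second and third columns of each pose matrix are sign-flipped, presumably
+# to match the camera convention used elsewhere in the pipeline.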
+def gen_pose_hypothesis(center_elevation):
+ elevations = np.radians(
+ [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120
+ azimuths = np.radians([30, 30, 30, 20, 40])
+ input_poses = calc_pose(elevations, azimuths, len(azimuths))
+ input_poses = input_poses[1:]
+ input_poses[..., 1] *= -1
+ input_poses[..., 2] *= -1
+ return input_poses
+
+
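+# Bundle-adjustment-style error: triangulate 3D points from the first match set
+# using the first two hypothesized poses, then reproject them into the remaining
+# views and accumulate a confidence-weighted reprojection error.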
+def ba_error_general(K, matches, poses):
+ projmat0 = K @ poses[0].inverse()[:3, :4]
+ projmat1 = K @ poses[1].inverse()[:3, :4]
+ match_01 = matches[0]
+ pts0 = match_01[:, :2]
+ pts1 = match_01[:, 2:4]
+ Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(),
+ pts0.cpu().numpy().T, pts1.cpu().numpy().T)
+ Xref = Xref[:3] / Xref[3:]
+ Xref = Xref.T
+ Xref = torch.from_numpy(Xref).cuda().float()
+ reproj_error = 0
+ for match, cp in zip(matches[1:], poses[2:]):
+ dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1))
+ if dist.numel() > 0:
+ # print("dist.shape", dist.shape)
+ m0to2_index = dist.argmin(1)
+ keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1
+ if keep.sum() > 0:
+ xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse()))
+ reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1)
+ conf02 = match[m0to2_index][keep][:, -1]
+ reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum())
+
+ return reproj_error
+
+
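+# Grid search over candidate elevations: the hypothesis whose pose set yields the
+# smallest total reprojection error across all view pairs wins.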
+def find_optim_elev(elevs, nimgs, matches, K, dbg=False):
+ errs = []
+ for elev in elevs:
+ err = 0
+ cam_poses = gen_pose_hypothesis(elev)
+ for start in range(nimgs - 1):
+ batch_matches, batch_poses = [], []
+ for i in range(start, nimgs + start):
+ ci = i % nimgs
+ batch_poses.append(cam_poses[ci])
+ for j in range(nimgs - 1):
+ key = f"{start}_{(start + j + 1) % nimgs}"
+ match = matches[key]
+ batch_matches.append(match)
+ err += ba_error_general(K, batch_matches, batch_poses)
+ errs.append(err)
+ errs = torch.tensor(errs)
+ if dbg:
+ plt.plot(elevs, errs)
+ plt.show()
+ optim_elev = elevs[torch.argmin(errs)].item()
+ return optim_elev
+
+
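+# Coarse-to-fine elevation search: a 10-degree sweep over [min_elev, max_elev),
+# followed by a 1-degree sweep around the best coarse estimate.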
+def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False):
+ flag = True
+ matches = {}
+ for i in range(4):
+ for j in range(i + 1, 4):
+ match_ij = feature_matching[f"{i}_{j}"]
+ if len(match_ij) == 0:
+ flag = False
+ match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1)
+ matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda()
+ matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda()
+ if not flag:
+ loguru.logger.info("0 matches, could not estimate elevation")
+ return None
+ interval = 10
+ elevs = np.arange(min_elev, max_elev, interval)
+ optim_elev1 = find_optim_elev(elevs, 4, matches, K)
+
+ elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1)
+ optim_elev2 = find_optim_elev(elevs, 4, matches, K)
+
+ return optim_elev2
+
+
+def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False):
+ feature_matching = get_feature_matching(img_paths, dbg=dbg)
+ if K is None:
+ loguru.logger.warning("K is not provided, using default K")
+ K = np.array([[280.0, 0, 128.0],
+ [0, 280.0, 128.0],
+ [0, 0, 1]])
+ K = torch.from_numpy(K).cuda().float()
+ elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg)
+ return elev
diff --git a/one2345_elev_est/oee/utils/plotting.py b/one2345_elev_est/oee/utils/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7ac1de4b1fb6d0cbeda2f61eca81c68a9ba423
--- /dev/null
+++ b/one2345_elev_est/oee/utils/plotting.py
@@ -0,0 +1,154 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+
+def _compute_conf_thresh(data):
+ dataset_name = data['dataset_name'][0].lower()
+ if dataset_name == 'scannet':
+ thr = 5e-4
+ elif dataset_name == 'megadepth':
+ thr = 1e-4
+ else:
+ raise ValueError(f'Unknown dataset: {dataset_name}')
+ return thr
+
+
+# --- VISUALIZATION --- #
+
+def make_matching_figure(
+ img0, img1, mkpts0, mkpts1, color,
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+ # draw image pair
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+ axes[0].imshow(img0, cmap='gray')
+ axes[1].imshow(img1, cmap='gray')
+ for i in range(2): # clear all frames
+ axes[i].get_yaxis().set_ticks([])
+ axes[i].get_xaxis().set_ticks([])
+ for spine in axes[i].spines.values():
+ spine.set_visible(False)
+ plt.tight_layout(pad=1)
+
+ if kpts0 is not None:
+ assert kpts1 is not None
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+ # draw matches
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+ fig.canvas.draw()
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ transform=fig.transFigure, c=color[i], linewidth=1)
+ for i in range(len(mkpts0))]
+
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+ # put txts
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+ fig.text(
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+ fontsize=15, va='top', ha='left', color=txt_color)
+
+ # save or return figure
+ if path:
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+ plt.close()
+ else:
+ return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic'):
+ b_mask = data['m_bids'] == b_id
+ conf_thr = _compute_conf_thresh(data)
+
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+ # for megadepth, we visualize matches on the resized image
+ if 'scale0' in data:
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+ correct_mask = epi_errs < conf_thr
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+ n_correct = np.sum(correct_mask)
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+ # matching info
+ if alpha == 'dynamic':
+ alpha = dynamic_alpha(len(correct_mask))
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+ text = [
+ f'#Matches {len(kpts0)}',
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+ ]
+
+ # make the figure
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
+ color, text=text)
+ return figure
+
+def _make_confidence_figure(data, b_id):
+ # TODO: Implement confidence figure
+ raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode='evaluation'):
+ """ Make matching figures for a batch.
+
+ Args:
+ data (Dict): a batch updated by PL_LoFTR.
+ config (Dict): matcher config
+ Returns:
+ figures (Dict[str, List[plt.figure]])
+ """
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
+ figures = {mode: []}
+ for b_id in range(data['image0'].size(0)):
+ if mode == 'evaluation':
+ fig = _make_evaluation_figure(
+ data, b_id,
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
+ elif mode == 'confidence':
+ fig = _make_confidence_figure(data, b_id)
+ else:
+ raise ValueError(f'Unknown plot mode: {mode}')
+ figures[mode].append(fig)
+ return figures
+
+
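+# Linearly interpolate the line alpha between milestone match counts, so plots with
+# many matches use more transparent lines.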
+def dynamic_alpha(n_matches,
+ milestones=[0, 300, 1000, 2000],
+ alphas=[1.0, 0.8, 0.4, 0.2]):
+ if n_matches == 0:
+ return 1.0
+ ranges = list(zip(alphas, alphas[1:] + [None]))
+ loc = bisect.bisect_right(milestones, n_matches) - 1
+ _range = ranges[loc]
+ if _range[1] is None:
+ return _range[0]
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+
+
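+# Map epipolar errors to an RGBA green-to-red colormap: zero error is green, errors
+# around thr are yellow, and errors >= 2 * thr saturate to red.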
+def error_colormap(err, thr, alpha=1.0):
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
+ return np.clip(
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
diff --git a/one2345_elev_est/oee/utils/plt_utils.py b/one2345_elev_est/oee/utils/plt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..92353edab179de9f702633a01e123e94403bd83f
--- /dev/null
+++ b/one2345_elev_est/oee/utils/plt_utils.py
@@ -0,0 +1,318 @@
+import os.path as osp
+import os
+import matplotlib.pyplot as plt
+import torch
+import cv2
+import math
+
+import numpy as np
+import tqdm
+from cv2 import findContours
+from dl_ext.primitive import safe_zip
+from dl_ext.timer import EvalTime
+
+
+def plot_confidence(confidence):
+ n = len(confidence)
+ plt.plot(np.arange(n), confidence)
+ plt.show()
+
+
+def image_grid(
+ images,
+ rows=None,
+ cols=None,
+ fill: bool = True,
+ show_axes: bool = False,
+ rgb=None,
+ show=True,
+ label=None,
+ **kwargs
+):
+ """
+ A util function for plotting a grid of images.
+ Args:
+ images: (N, H, W, 4) array of RGBA images
+ rows: number of rows in the grid
+ cols: number of columns in the grid
+ fill: boolean indicating if the space between images should be filled
+ show_axes: boolean indicating if the axes of the plots should be visible
+ rgb: boolean, If True, only RGB channels are plotted.
+ If False, only the alpha channel is plotted.
+ Returns:
+ None
+ """
+ evaltime = EvalTime(disable=True)
+ evaltime('')
+ if isinstance(images, torch.Tensor):
+ images = images.detach().cpu()
+ if len(images[0].shape) == 2:
+ rgb = False
+ if images[0].shape[-1] == 2:
+ # flow
+ images = [flow_to_image(im) for im in images]
+ if (rows is None) != (cols is None):
+ raise ValueError("Specify either both rows and cols or neither.")
+
+ if rows is None:
+ rows = int(len(images) ** 0.5)
+ cols = math.ceil(len(images) / rows)
+
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
+ if len(images) < 50:
+ figsize = (10, 10)
+ else:
+ figsize = (15, 15)
+ evaltime('0.5')
+ plt.figure(figsize=figsize)
+ # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize)
+ if label:
+ # fig.suptitle(label, fontsize=30)
+ plt.suptitle(label, fontsize=30)
+ # bleed = 0
+ # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
+ evaltime('subplots')
+
+ # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))):
+ for i in range(len(images)):
+ # evaltime(f'{i} begin')
+ plt.subplot(rows, cols, i + 1)
+ if rgb:
+ # only render RGB channels
+ plt.imshow(images[i][..., :3], **kwargs)
+ # ax.imshow(im[..., :3], **kwargs)
+ else:
+ # only render Alpha channel
+ plt.imshow(images[i], **kwargs)
+ # ax.imshow(im, **kwargs)
+ if not show_axes:
+ plt.axis('off')
+ # ax.set_axis_off()
+ # ax.set_title(f'{i}')
+ plt.title(f'{i}')
+ # evaltime(f'{i} end')
+ evaltime('2')
+ if show:
+ plt.show()
+ # return fig
+
+
+def depth_grid(
+ depths,
+ rows=None,
+ cols=None,
+ fill: bool = True,
+ show_axes: bool = False,
+):
+ """
+ A util function for plotting a grid of images.
+ Args:
+ images: (N, H, W, 4) array of RGBA images
+ rows: number of rows in the grid
+ cols: number of columns in the grid
+ fill: boolean indicating if the space between images should be filled
+ show_axes: boolean indicating if the axes of the plots should be visible
+ rgb: boolean, If True, only RGB channels are plotted.
+ If False, only the alpha channel is plotted.
+ Returns:
+ None
+ """
+ if (rows is None) != (cols is None):
+ raise ValueError("Specify either both rows and cols or neither.")
+
+ if rows is None:
+ rows = len(depths)
+ cols = 1
+
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
+ fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
+ bleed = 0
+ fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
+
+ for ax, im in zip(axarr.ravel(), depths):
+ ax.imshow(im)
+ if not show_axes:
+ ax.set_axis_off()
+ plt.show()
+
+
+def hover_masks_on_imgs(images, masks):
+ masks = np.array(masks)
+ new_imgs = []
+ tids = list(range(1, masks.max() + 1))
+ colors = colormap(rgb=True, lighten=True)
+ for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)):
+ for tid in tids:
+ im = vis_mask(
+ im,
+ (mask == tid).astype(np.uint8),
+ color=colors[tid],
+ alpha=0.5,
+ border_alpha=0.5,
+ border_color=[255, 255, 255],
+ border_thick=3)
+ new_imgs.append(im)
+ return new_imgs
+
+
+def vis_mask(img,
+ mask,
+ color=[255, 255, 255],
+ alpha=0.4,
+ show_border=True,
+ border_alpha=0.5,
+ border_thick=1,
+ border_color=None):
+ """Visualizes a single binary mask."""
+ if isinstance(mask, torch.Tensor):
+ # convert a tensor mask directly instead of importing the external `anypose` helper
+ mask = (mask > 0).cpu().numpy().astype(np.uint8)
+ img = img.astype(np.float32)
+ idx = np.nonzero(mask)
+
+ img[idx[0], idx[1], :] *= 1.0 - alpha
+ img[idx[0], idx[1], :] += [alpha * x for x in color]
+
+ if show_border:
+ contours, _ = findContours(
+ mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ # contours = [c for c in contours if c.shape[0] > 10]
+ if border_color is None:
+ border_color = color
+ if not isinstance(border_color, list):
+ border_color = border_color.tolist()
+ if border_alpha < 1:
+ with_border = img.copy()
+ cv2.drawContours(with_border, contours, -1, border_color,
+ border_thick, cv2.LINE_AA)
+ img = (1 - border_alpha) * img + border_alpha * with_border
+ else:
+ cv2.drawContours(img, contours, -1, border_color, border_thick,
+ cv2.LINE_AA)
+
+ return img.astype(np.uint8)
+
+
+def colormap(rgb=False, lighten=True):
+ """Copied from Detectron codebase."""
+ color_list = np.array(
+ [
+ 0.000, 0.447, 0.741,
+ 0.850, 0.325, 0.098,
+ 0.929, 0.694, 0.125,
+ 0.494, 0.184, 0.556,
+ 0.466, 0.674, 0.188,
+ 0.301, 0.745, 0.933,
+ 0.635, 0.078, 0.184,
+ 0.300, 0.300, 0.300,
+ 0.600, 0.600, 0.600,
+ 1.000, 0.000, 0.000,
+ 1.000, 0.500, 0.000,
+ 0.749, 0.749, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.333, 0.333, 0.000,
+ 0.333, 0.667, 0.000,
+ 0.333, 1.000, 0.000,
+ 0.667, 0.333, 0.000,
+ 0.667, 0.667, 0.000,
+ 0.667, 1.000, 0.000,
+ 1.000, 0.333, 0.000,
+ 1.000, 0.667, 0.000,
+ 1.000, 1.000, 0.000,
+ 0.000, 0.333, 0.500,
+ 0.000, 0.667, 0.500,
+ 0.000, 1.000, 0.500,
+ 0.333, 0.000, 0.500,
+ 0.333, 0.333, 0.500,
+ 0.333, 0.667, 0.500,
+ 0.333, 1.000, 0.500,
+ 0.667, 0.000, 0.500,
+ 0.667, 0.333, 0.500,
+ 0.667, 0.667, 0.500,
+ 0.667, 1.000, 0.500,
+ 1.000, 0.000, 0.500,
+ 1.000, 0.333, 0.500,
+ 1.000, 0.667, 0.500,
+ 1.000, 1.000, 0.500,
+ 0.000, 0.333, 1.000,
+ 0.000, 0.667, 1.000,
+ 0.000, 1.000, 1.000,
+ 0.333, 0.000, 1.000,
+ 0.333, 0.333, 1.000,
+ 0.333, 0.667, 1.000,
+ 0.333, 1.000, 1.000,
+ 0.667, 0.000, 1.000,
+ 0.667, 0.333, 1.000,
+ 0.667, 0.667, 1.000,
+ 0.667, 1.000, 1.000,
+ 1.000, 0.000, 1.000,
+ 1.000, 0.333, 1.000,
+ 1.000, 0.667, 1.000,
+ 0.167, 0.000, 0.000,
+ 0.333, 0.000, 0.000,
+ 0.500, 0.000, 0.000,
+ 0.667, 0.000, 0.000,
+ 0.833, 0.000, 0.000,
+ 1.000, 0.000, 0.000,
+ 0.000, 0.167, 0.000,
+ 0.000, 0.333, 0.000,
+ 0.000, 0.500, 0.000,
+ 0.000, 0.667, 0.000,
+ 0.000, 0.833, 0.000,
+ 0.000, 1.000, 0.000,
+ 0.000, 0.000, 0.167,
+ 0.000, 0.000, 0.333,
+ 0.000, 0.000, 0.500,
+ 0.000, 0.000, 0.667,
+ 0.000, 0.000, 0.833,
+ 0.000, 0.000, 1.000,
+ 0.000, 0.000, 0.000,
+ 0.143, 0.143, 0.143,
+ 0.286, 0.286, 0.286,
+ 0.429, 0.429, 0.429,
+ 0.571, 0.571, 0.571,
+ 0.714, 0.714, 0.714,
+ 0.857, 0.857, 0.857,
+ 1.000, 1.000, 1.000
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3))
+ if not rgb:
+ color_list = color_list[:, ::-1]
+
+ if lighten:
+ # Make all the colors a little lighter / whiter. This is copied
+ # from the detectron visualization code (search for 'w_ratio').
+ w_ratio = 0.4
+ color_list = (color_list * (1 - w_ratio) + w_ratio)
+ return color_list * 255
+
+
+def vis_layer_mask(masks, save_path=None):
+ masks = torch.as_tensor(masks)
+ tids = masks.unique().tolist()
+ tids.remove(0)
+ for tid in tqdm.tqdm(tids):
+ show = save_path is None
+ image_grid(masks == tid, label=f'{tid}', show=show)
+ if save_path:
+ os.makedirs(osp.dirname(save_path), exist_ok=True)
+ plt.savefig(save_path % tid)
+ plt.close('all')
+
+
+def show(x, **kwargs):
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu()
+ plt.imshow(x, **kwargs)
+ plt.show()
+
+
+def vis_title(rgb, text, shift_y=30):
+ tmp = rgb.copy()
+ shift_x = rgb.shape[1] // 2
+ cv2.putText(tmp, text,
+ (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
+ return tmp
diff --git a/one2345_elev_est/oee/utils/utils3d.py b/one2345_elev_est/oee/utils/utils3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc92fbde4143a4ed5187c989e3f98a896e7caab
--- /dev/null
+++ b/one2345_elev_est/oee/utils/utils3d.py
@@ -0,0 +1,62 @@
+import numpy as np
+import torch
+
+
+def cart_to_hom(pts):
+ """
+ :param pts: (N, 3 or 2)
+ :return pts_hom: (N, 4 or 3)
+ """
+ if isinstance(pts, np.ndarray):
+ pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1)
+ else:
+ ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device)
+ pts_hom = torch.cat((pts, ones), dim=-1)
+ return pts_hom
+
+
+def hom_to_cart(pts):
+ return pts[..., :-1] / pts[..., -1:]
+
+
+def canonical_to_camera(pts, pose):
+ pts = cart_to_hom(pts)
+ pts = pts @ pose.transpose(-1, -2)
+ pts = hom_to_cart(pts)
+ return pts
+
+
+def rect_to_img(K, pts_rect):
+ from dl_ext.vision_ext.datasets.kitti.structures import Calibration
+ pts_2d_hom = pts_rect @ K.t()
+ pts_img = Calibration.hom_to_cart(pts_2d_hom)
+ return pts_img
+
+
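+# Build camera-to-world look-at poses on a sphere of the given radius: `phis` is
+# treated as the polar angle from the +z axis and `thetas` as the azimuth; the
+# rotation columns are the (right, up, forward) vectors of each camera.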
+def calc_pose(phis, thetas, size, radius=1.2):
+ import torch
+ def normalize(vectors):
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
+
+ device = torch.device('cuda')
+ thetas = torch.FloatTensor(thetas).to(device)
+ phis = torch.FloatTensor(phis).to(device)
+
+ centers = torch.stack([
+ radius * torch.sin(thetas) * torch.sin(phis),
+ -radius * torch.cos(thetas) * torch.sin(phis),
+ radius * torch.cos(phis),
+ ], dim=-1) # [B, 3]
+
+ # lookat
+ forward_vector = normalize(centers).squeeze(0)
+ up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1)
+ right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
+ if right_vector.pow(2).sum() < 0.01:
+ right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
+ up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
+
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+ poses[:, :3, 3] = centers
+ return poses
diff --git a/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO b/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..3487d5a2d125c78739dd3655fe057b9769a7c54a
--- /dev/null
+++ b/one2345_elev_est/one2345_elev_est.egg-info/PKG-INFO
@@ -0,0 +1,4 @@
+Metadata-Version: 2.1
+Name: one2345-elev-est
+Version: 0.1
+Author: chenlinghao
diff --git a/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt b/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..281c2ecbcdefaac7a27c6ef2a61b7958a3ca102a
--- /dev/null
+++ b/one2345_elev_est/one2345_elev_est.egg-info/SOURCES.txt
@@ -0,0 +1,5 @@
+setup.py
+one2345_elev_est.egg-info/PKG-INFO
+one2345_elev_est.egg-info/SOURCES.txt
+one2345_elev_est.egg-info/dependency_links.txt
+one2345_elev_est.egg-info/top_level.txt
\ No newline at end of file
diff --git a/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt b/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/one2345_elev_est/one2345_elev_est.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt b/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/one2345_elev_est/one2345_elev_est.egg-info/top_level.txt
@@ -0,0 +1 @@
+
diff --git a/one2345_elev_est/requirements.txt b/one2345_elev_est/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6459497e3f3aad592f73f06d0207167cb2b4ec5
--- /dev/null
+++ b/one2345_elev_est/requirements.txt
@@ -0,0 +1,42 @@
+dl_ext
+easydict
+glumpy
+gym
+h5py
+imageio
+loguru
+matplotlib
+mplib
+multipledispatch
+# numpy
+open3d
+packaging
+# pandas
+Pillow
+pycocotools
+# pyk4a
+motion-planning
+# pyrealsense2
+pyrender
+# pytorch3d
+PyYAML
+scikit_image
+scikit_learn
+scipy
+screeninfo
+# seaborn
+setuptools
+# skimage
+tensorboardX
+termcolor
+# torch
+# torchvision
+tqdm
+transforms3d
+trimesh
+yacs
+zarr
+sapien
+pyglet==1.5.27
+wis3d
+git+https://github.com/NVlabs/nvdiffrast.git
diff --git a/one2345_elev_est/setup.py b/one2345_elev_est/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..370ef83c85e5700ee440957117ef9382304f8321
--- /dev/null
+++ b/one2345_elev_est/setup.py
@@ -0,0 +1,9 @@
+from setuptools import find_packages
+from setuptools import setup
+
+setup(
+ name="one2345_elev_est",
+ version="0.1",
+ author="chenlinghao",
+ packages=find_packages(exclude=("configs", "tests",)),
+)
diff --git a/one2345_elev_est/tools/estimate_wild_imgs.py b/one2345_elev_est/tools/estimate_wild_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..47103d46e48d3758a2cc237f659dd6fe5fc183c7
--- /dev/null
+++ b/one2345_elev_est/tools/estimate_wild_imgs.py
@@ -0,0 +1,35 @@
+import tqdm
+import imageio
+import json
+import os.path as osp
+import os
+
+from oee.utils import plt_utils
+from oee.utils.elev_est_api import elev_est_api
+
+
+def visualize(img_paths, elev):
+ imgs = [imageio.imread_v2(img_path) for img_path in img_paths]
+ plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}")
+
+
+def estimate_elev(root_dir):
+ # root_dir = "/home/linghao/Datasets/objaverse-processed/zero12345_img/wild"
+ # dataset = "supp_fail"
+ # root_dir = "/home/chao/chao/OpenComplete/zero123/zero123/gradio_tmp/"
+ # obj_names = sorted(os.listdir(root_dir))
+ # results = {}
+ # for obj_name in tqdm.tqdm(obj_names):
+ img_dir = osp.join(root_dir, "stage2_8")
+ img_paths = []
+ for i in range(4):
+ img_paths.append(f"{img_dir}/0_{i}.png")
+ elev = elev_est_api(img_paths)
+ # visualize(img_paths, elev)
+ # results[obj_name] = elev
+ # json.dump(results, open(osp.join(root_dir, f"../{dataset}_elev.json"), "w"), indent=4)
+ return elev
+
+
+# if __name__ == '__main__':
+# main()
diff --git a/one2345_elev_est/tools/example.py b/one2345_elev_est/tools/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..065f31e3e64f21494b04ca2aed87e665ddc6d23d
--- /dev/null
+++ b/one2345_elev_est/tools/example.py
@@ -0,0 +1,38 @@
+import imageio
+import numpy as np
+
+from oee.utils import plt_utils
+from oee.utils.elev_est_api import elev_est_api
+import argparse
+
+
+def visualize(img_paths, elev):
+ imgs = [imageio.imread_v2(img_path) for img_path in img_paths]
+ plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_paths", type=str, nargs=4, help="image paths",
+ default=["assets/example_data/0_0.png",
+ "assets/example_data/0_1.png",
+ "assets/example_data/0_2.png",
+ "assets/example_data/0_3.png"])
+ parser.add_argument("--min_elev", type=float, default=30, help="min elevation")
+ parser.add_argument("--max_elev", type=float, default=150, help="max elevation")
+ parser.add_argument("--dbg", default=False, action="store_true", help="debug mode")
+ parser.add_argument("--K_path", type=str, default=None, help="path to K")
+ args = parser.parse_args()
+
+ if args.K_path is not None:
+ K = np.loadtxt(args.K_path)
+ else:
+ K = None
+
+ elev = elev_est_api(args.img_paths, args.min_elev, args.max_elev, K, args.dbg)
+
+ visualize(args.img_paths, elev)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/one2345_elev_est/tools/weights/indoor_ds_new.ckpt b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..ef68cc903b08e710d46a51c2aeb97407e047f3a5
--- /dev/null
+++ b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be9ff88b323ec27889114719f668ae41aff7034b56a4c4acbd46b8b180b87ed3
+size 46355053
diff --git a/taming-transformers b/taming-transformers
new file mode 160000
index 0000000000000000000000000000000000000000..3ba01b241669f5ade541ce990f7650a3b8f65318
--- /dev/null
+++ b/taming-transformers
@@ -0,0 +1 @@
+Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318