liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
import os.path as osp
import numpy as np
import os
import sys
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
import h5py
from tqdm import tqdm
class BlendedMVS_Multi(BaseMultiViewDataset):
"""Dataset of outdoor street scenes, 5 images each time"""
def __init__(self, *args, ROOT, split=None, **kwargs):
self.ROOT = ROOT
self.video = False
self.is_metric = False
super().__init__(*args, **kwargs)
# assert split is None
self._load_data()
def _load_data(self):
self.data_dict = self.read_h5_file(os.path.join(self.ROOT, "new_overlap.h5"))
self.num_imgs = sum(
[len(self.data_dict[s]["basenames"]) for s in self.data_dict.keys()]
)
self.num_scenes = len(self.data_dict.keys())
self.invalid_scenes = []
self.is_reachable_cache = {scene: {} for scene in self.data_dict.keys()}
def read_h5_file(self, h5_file_path):
data_dict = {}
self.all_ref_imgs = []
with h5py.File(h5_file_path, "r") as f:
for scene_dir in tqdm(f.keys()):
group = f[scene_dir]
basenames = group["basenames"][:]
indices = group["indices"][:]
values = group["values"][:]
shape = group.attrs["shape"]
# Reconstruct the sparse matrix
score_matrix = np.zeros(shape, dtype=np.float32)
score_matrix[indices[0], indices[1]] = values
data_dict[scene_dir] = {
"basenames": basenames,
"score_matrix": self.build_adjacency_list(score_matrix),
}
self.all_ref_imgs.extend(
[(scene_dir, b) for b in range(len(basenames))]
)
return data_dict
@staticmethod
def build_adjacency_list(S, thresh=0.2):
adjacency_list = [[] for _ in range(len(S))]
S = S - thresh
S[S < 0] = 0
rows, cols = np.nonzero(S)
for i, j in zip(rows, cols):
adjacency_list[i].append((j, S[i][j]))
return adjacency_list
@staticmethod
def is_reachable(adjacency_list, start_index, k):
visited = set()
stack = [start_index]
while stack and len(visited) < k:
node = stack.pop()
if node not in visited:
visited.add(node)
for neighbor in adjacency_list[node]:
if neighbor[0] not in visited:
stack.append(neighbor[0])
return len(visited) >= k
@staticmethod
def random_sequence_no_revisit_with_backtracking(
adjacency_list, k, start_index, rng: np.random.Generator
):
path = [start_index]
visited = set([start_index])
neighbor_iterators = []
# Initialize the iterator for the start index
neighbors = adjacency_list[start_index]
neighbor_idxs = [n[0] for n in neighbors]
neighbor_weights = [n[1] for n in neighbors]
neighbor_idxs = rng.choice(
neighbor_idxs,
size=len(neighbor_idxs),
replace=False,
p=np.array(neighbor_weights) / np.sum(neighbor_weights),
).tolist()
neighbor_iterators.append(iter(neighbor_idxs))
while len(path) < k:
if not neighbor_iterators:
# No possible sequence
return None
current_iterator = neighbor_iterators[-1]
try:
next_index = next(current_iterator)
if next_index not in visited:
path.append(next_index)
visited.add(next_index)
# Prepare iterator for the next node
neighbors = adjacency_list[next_index]
neighbor_idxs = [n[0] for n in neighbors]
neighbor_weights = [n[1] for n in neighbors]
neighbor_idxs = rng.choice(
neighbor_idxs,
size=len(neighbor_idxs),
replace=False,
p=np.array(neighbor_weights) / np.sum(neighbor_weights),
).tolist()
neighbor_iterators.append(iter(neighbor_idxs))
except StopIteration:
# No more neighbors to try at this node, backtrack
neighbor_iterators.pop()
visited.remove(path.pop())
return path
@staticmethod
def random_sequence_with_optional_repeats(
adjacency_list,
k,
start_index,
rng: np.random.Generator,
max_k=None,
max_attempts=100,
):
if max_k is None:
max_k = k
path = [start_index]
visited = set([start_index])
current_index = start_index
attempts = 0
while len(path) < max_k and attempts < max_attempts:
attempts += 1
neighbors = adjacency_list[current_index]
neighbor_idxs = [n[0] for n in neighbors]
neighbor_weights = [n[1] for n in neighbors]
if not neighbor_idxs:
# No neighbors, cannot proceed further
break
# Try to find unvisited neighbors
unvisited_neighbors = [
(idx, wgt)
for idx, wgt in zip(neighbor_idxs, neighbor_weights)
if idx not in visited
]
if unvisited_neighbors:
# Select among unvisited neighbors
unvisited_idxs = [idx for idx, _ in unvisited_neighbors]
unvisited_weights = [wgt for _, wgt in unvisited_neighbors]
probabilities = np.array(unvisited_weights) / np.sum(unvisited_weights)
next_index = rng.choice(unvisited_idxs, p=probabilities)
visited.add(next_index)
else:
# All neighbors visited, but we need to reach length max_k
# So we can revisit nodes
probabilities = np.array(neighbor_weights) / np.sum(neighbor_weights)
next_index = rng.choice(neighbor_idxs, p=probabilities)
path.append(next_index)
current_index = next_index
if len(set(path)) >= k:
# If path is shorter than max_k, extend it by repeating existing elements
while len(path) < max_k:
# Randomly select nodes from the existing path to repeat
next_index = rng.choice(path)
path.append(next_index)
return path
else:
# Could not reach k unique nodes
return None
def __len__(self):
return len(self.all_ref_imgs)
def get_image_num(self):
return self.num_imgs
def get_stats(self):
return f"{len(self)} imgs from {self.num_scenes} scenes"
def generate_sequence(
self, scene, adj_list, num_views, start_index, rng, allow_repeat=False
):
cutoff = num_views if not allow_repeat else max(num_views // 5, 3)
if start_index in self.is_reachable_cache[scene]:
if not self.is_reachable_cache[scene][start_index]:
print(
f"Cannot reach {num_views} unique elements from index {start_index}."
)
return None
else:
self.is_reachable_cache[scene][start_index] = self.is_reachable(
adj_list, start_index, cutoff
)
if not self.is_reachable_cache[scene][start_index]:
print(
f"Cannot reach {num_views} unique elements from index {start_index}."
)
return None
if not allow_repeat:
sequence = self.random_sequence_no_revisit_with_backtracking(
adj_list, cutoff, start_index, rng
)
else:
sequence = self.random_sequence_with_optional_repeats(
adj_list, cutoff, start_index, rng, max_k=num_views
)
if not sequence:
self.is_reachable_cache[scene][start_index] = False
print("Failed to generate a sequence without revisiting.")
return sequence
def _get_views(self, idx, resolution, rng: np.random.Generator, num_views):
scene_info, ref_img_idx = self.all_ref_imgs[idx]
invalid_seq = True
ordered_video = False
while invalid_seq:
basenames = self.data_dict[scene_info]["basenames"]
if (
sum(
[
(1 - int(x))
for x in list(self.is_reachable_cache[scene_info].values())
]
)
> len(basenames) - self.num_views
):
self.invalid_scenes.append(scene_info)
while scene_info in self.invalid_scenes:
idx = rng.integers(low=0, high=len(self.all_ref_imgs))
scene_info, ref_img_idx = self.all_ref_imgs[idx]
basenames = self.data_dict[scene_info]["basenames"]
score_matrix = self.data_dict[scene_info]["score_matrix"]
imgs_idxs = self.generate_sequence(
scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat
)
if imgs_idxs is None:
random_direction = 2 * rng.choice(2) - 1
for offset in range(1, len(basenames)):
tentative_im_idx = (
ref_img_idx + (random_direction * offset)
) % len(basenames)
if (
tentative_im_idx not in self.is_reachable_cache[scene_info]
or self.is_reachable_cache[scene_info][tentative_im_idx]
):
ref_img_idx = tentative_im_idx
break
else:
invalid_seq = False
views = []
for view_idx in imgs_idxs:
scene_dir = osp.join(self.ROOT, scene_info)
impath = basenames[view_idx].decode("utf-8")
image = imread_cv2(osp.join(scene_dir, impath + ".jpg"))
depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr"))
camera_params = np.load(osp.join(scene_dir, impath + ".npz"))
intrinsics = np.float32(camera_params["intrinsics"])
camera_pose = np.eye(4, dtype=np.float32)
camera_pose[:3, :3] = camera_params["R_cam2world"]
camera_pose[:3, 3] = camera_params["t_cam2world"]
image, depthmap, intrinsics = self._crop_resize_if_necessary(
image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath)
)
views.append(
dict(
img=image,
depthmap=depthmap,
camera_pose=camera_pose, # cam2world
camera_intrinsics=intrinsics,
dataset="BlendedMVS",
label=osp.relpath(scene_dir, self.ROOT),
is_metric=self.is_metric,
is_video=ordered_video,
instance=osp.join(scene_dir, impath + ".jpg"),
quantile=np.array(0.97, dtype=np.float32),
img_mask=True,
ray_mask=False,
camera_only=False,
depth_only=False,
single_view=False,
reset=False,
)
)
assert len(views) == num_views
return views