import os
import os.path as osp
import sys

import h5py
import numpy as np
from tqdm import tqdm

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class BlendedMVS_Multi(BaseMultiViewDataset):
    """BlendedMVS multi-view dataset: samples a connected set of num_views
    co-visible images per scene, using precomputed pairwise overlap scores."""

    def __init__(self, *args, ROOT, split=None, **kwargs):
        self.ROOT = ROOT
        self.video = False
        self.is_metric = False
        super().__init__(*args, **kwargs)
        # assert split is None
        self._load_data()

    def _load_data(self):
        self.data_dict = self.read_h5_file(os.path.join(self.ROOT, "new_overlap.h5"))
        self.num_imgs = sum(
            len(self.data_dict[s]["basenames"]) for s in self.data_dict
        )
        self.num_scenes = len(self.data_dict)
        self.invalid_scenes = []
        # Per-scene memoization of "can enough images be reached from this start index?"
        self.is_reachable_cache = {scene: {} for scene in self.data_dict}

    def read_h5_file(self, h5_file_path):
        """Load per-scene image basenames and pairwise overlap scores.

        Each scene group stores a sparse (COO) overlap-score matrix, which is
        densified here and then converted into a weighted adjacency list.
        """
        data_dict = {}
        self.all_ref_imgs = []
        with h5py.File(h5_file_path, "r") as f:
            for scene_dir in tqdm(f.keys()):
                group = f[scene_dir]
                basenames = group["basenames"][:]
                indices = group["indices"][:]
                values = group["values"][:]
                shape = group.attrs["shape"]
                # Reconstruct the dense score matrix from its COO representation
                score_matrix = np.zeros(shape, dtype=np.float32)
                score_matrix[indices[0], indices[1]] = values
                data_dict[scene_dir] = {
                    "basenames": basenames,
                    "score_matrix": self.build_adjacency_list(score_matrix),
                }
                self.all_ref_imgs.extend(
                    [(scene_dir, b) for b in range(len(basenames))]
                )
        return data_dict
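
    # A minimal sketch of the HDF5 layout read_h5_file expects (the writer
    # below is hypothetical and shown only to document the schema):
    #
    #   with h5py.File("new_overlap.h5", "w") as f:
    #       g = f.create_group(scene_dir)                    # one group per scene
    #       g.create_dataset("basenames", data=np.array(names, dtype="S"))
    #       g.create_dataset("indices", data=coo_rows_cols)  # (2, nnz) int array
    #       g.create_dataset("values", data=coo_scores)      # (nnz,) float scores
    #       g.attrs["shape"] = (len(names), len(names))      # dense matrix shape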

    def build_adjacency_list(self, S, thresh=0.2):
        """Turn a dense score matrix into a weighted adjacency list.

        Scores are shifted down by `thresh` and clamped at zero, so only
        pairs whose overlap exceeds the threshold become edges; the residual
        score is kept as the edge weight.
        """
        adjacency_list = [[] for _ in range(len(S))]
        S = S - thresh
        S[S < 0] = 0
        rows, cols = np.nonzero(S)
        for i, j in zip(rows, cols):
            adjacency_list[i].append((j, S[i, j]))
        return adjacency_list
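
    # Illustration (hypothetical scores): with the default thresh=0.2,
    #   S = np.array([[0.0, 0.5, 0.1],
    #                 [0.5, 0.0, 0.3],
    #                 [0.1, 0.3, 0.0]], dtype=np.float32)
    # yields roughly [[(1, 0.3)], [(0, 0.3), (2, 0.1)], [(1, 0.1)]]; the 0.1
    # entries fall below the threshold and produce no edges.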

    def is_reachable(self, adjacency_list, start_index, k):
        """Return True if at least k distinct nodes are reachable from start_index."""
        visited = set()
        stack = [start_index]
        while stack and len(visited) < k:
            node = stack.pop()
            if node not in visited:
                visited.add(node)
                for neighbor in adjacency_list[node]:
                    if neighbor[0] not in visited:
                        stack.append(neighbor[0])
        return len(visited) >= k
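
    # Illustration: with the adjacency list sketched above,
    # self.is_reachable(adjacency_list, 0, 3) is True via 0 -> 1 -> 2. The DFS
    # stops as soon as k distinct nodes are visited, so the check stays cheap
    # even on large scenes.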

    def random_sequence_no_revisit_with_backtracking(
        self, adjacency_list, k, start_index, rng: np.random.Generator
    ):
        """Sample a path of k distinct nodes via weighted DFS with backtracking."""
        path = [start_index]
        visited = set([start_index])
        neighbor_iterators = []
        # Initialize the iterator for the start index: neighbors are visited
        # in a random order biased by their overlap weights
        neighbors = adjacency_list[start_index]
        neighbor_idxs = [n[0] for n in neighbors]
        neighbor_weights = [n[1] for n in neighbors]
        neighbor_idxs = rng.choice(
            neighbor_idxs,
            size=len(neighbor_idxs),
            replace=False,
            p=np.array(neighbor_weights) / np.sum(neighbor_weights),
        ).tolist()
        neighbor_iterators.append(iter(neighbor_idxs))
        while len(path) < k:
            if not neighbor_iterators:
                # Backtracked past the start node: no possible sequence
                return None
            current_iterator = neighbor_iterators[-1]
            try:
                next_index = next(current_iterator)
                if next_index not in visited:
                    path.append(next_index)
                    visited.add(next_index)
                    # Prepare a weighted-shuffled iterator for the next node
                    neighbors = adjacency_list[next_index]
                    neighbor_idxs = [n[0] for n in neighbors]
                    neighbor_weights = [n[1] for n in neighbors]
                    neighbor_idxs = rng.choice(
                        neighbor_idxs,
                        size=len(neighbor_idxs),
                        replace=False,
                        p=np.array(neighbor_weights) / np.sum(neighbor_weights),
                    ).tolist()
                    neighbor_iterators.append(iter(neighbor_idxs))
            except StopIteration:
                # No more neighbors to try at this node; backtrack
                neighbor_iterators.pop()
                visited.remove(path.pop())
        return path
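
    # Note: this sampler never revisits an image, so it can return None on
    # sparsely connected scenes. The variant below trades strict distinctness
    # for robustness by allowing repeats once unvisited neighbors run out.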

    def random_sequence_with_optional_repeats(
        self,
        adjacency_list,
        k,
        start_index,
        rng: np.random.Generator,
        max_k=None,
        max_attempts=100,
    ):
        """Weighted random walk collecting at least k distinct nodes, padded
        with repeats up to max_k if needed."""
        if max_k is None:
            max_k = k
        path = [start_index]
        visited = set([start_index])
        current_index = start_index
        attempts = 0
        while len(path) < max_k and attempts < max_attempts:
            attempts += 1
            neighbors = adjacency_list[current_index]
            neighbor_idxs = [n[0] for n in neighbors]
            neighbor_weights = [n[1] for n in neighbors]
            if not neighbor_idxs:
                # Dead end: no neighbors, cannot proceed further
                break
            # Prefer unvisited neighbors while any remain
            unvisited_neighbors = [
                (idx, wgt)
                for idx, wgt in zip(neighbor_idxs, neighbor_weights)
                if idx not in visited
            ]
            if unvisited_neighbors:
                # Select among unvisited neighbors, weighted by overlap score
                unvisited_idxs = [idx for idx, _ in unvisited_neighbors]
                unvisited_weights = [wgt for _, wgt in unvisited_neighbors]
                probabilities = np.array(unvisited_weights) / np.sum(unvisited_weights)
                next_index = rng.choice(unvisited_idxs, p=probabilities)
                visited.add(next_index)
            else:
                # All neighbors visited, but we still need to reach length
                # max_k, so revisiting nodes is allowed
                probabilities = np.array(neighbor_weights) / np.sum(neighbor_weights)
                next_index = rng.choice(neighbor_idxs, p=probabilities)
            path.append(next_index)
            current_index = next_index
        if len(set(path)) >= k:
            # If the walk is shorter than max_k, pad it by repeating nodes
            # already on the path
            while len(path) < max_k:
                next_index = rng.choice(path)
                path.append(next_index)
            return path
        else:
            # Could not collect k unique nodes
            return None
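
    # Example of the intended behavior (hypothetical numbers): with k=3 and
    # max_k=8, a walk that collects only the distinct nodes [5, 2, 9] still
    # succeeds and is padded with repeats to length 8, e.g.
    # [5, 2, 9, 2, 5, 9, 9, 2].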

    def __len__(self):
        return len(self.all_ref_imgs)

    def get_image_num(self):
        return self.num_imgs

    def get_stats(self):
        return f"{len(self)} imgs from {self.num_scenes} scenes"

    def generate_sequence(
        self, scene, adj_list, num_views, start_index, rng, allow_repeat=False
    ):
        # With repeats allowed, only a fraction of the views must be distinct
        # images; the walk is then padded with repeats up to num_views
        cutoff = num_views if not allow_repeat else max(num_views // 5, 3)
        if start_index not in self.is_reachable_cache[scene]:
            self.is_reachable_cache[scene][start_index] = self.is_reachable(
                adj_list, start_index, cutoff
            )
        if not self.is_reachable_cache[scene][start_index]:
            print(
                f"Cannot reach {num_views} unique elements from index {start_index}."
            )
            return None
        if not allow_repeat:
            sequence = self.random_sequence_no_revisit_with_backtracking(
                adj_list, cutoff, start_index, rng
            )
        else:
            sequence = self.random_sequence_with_optional_repeats(
                adj_list, cutoff, start_index, rng, max_k=num_views
            )
        if not sequence:
            self.is_reachable_cache[scene][start_index] = False
            print("Failed to generate a sequence without revisiting.")
        return sequence
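
    # Example: with num_views=24 and allow_repeat=True, only
    # max(24 // 5, 3) = 4 distinct images must be reachable; the sampler then
    # pads the walk with repeats up to the full 24 views.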

    def _get_views(self, idx, resolution, rng: np.random.Generator, num_views):
        scene_info, ref_img_idx = self.all_ref_imgs[idx]
        invalid_seq = True
        ordered_video = False
        while invalid_seq:
            basenames = self.data_dict[scene_info]["basenames"]
            # If too many start indices in this scene are already known to be
            # unreachable, give up on the scene and resample another one
            if (
                sum(
                    (1 - int(x))
                    for x in self.is_reachable_cache[scene_info].values()
                )
                > len(basenames) - self.num_views
            ):
                self.invalid_scenes.append(scene_info)
            while scene_info in self.invalid_scenes:
                idx = rng.integers(low=0, high=len(self.all_ref_imgs))
                scene_info, ref_img_idx = self.all_ref_imgs[idx]
                basenames = self.data_dict[scene_info]["basenames"]
            # Despite the key name, this is the adjacency list built from the
            # overlap-score matrix
            score_matrix = self.data_dict[scene_info]["score_matrix"]
            imgs_idxs = self.generate_sequence(
                scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat
            )
            if imgs_idxs is None:
                # Walk away from the failed start index in a random direction
                # until we find one that is not known to be unreachable
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(basenames)):
                    tentative_im_idx = (
                        ref_img_idx + random_direction * offset
                    ) % len(basenames)
                    if (
                        tentative_im_idx not in self.is_reachable_cache[scene_info]
                        or self.is_reachable_cache[scene_info][tentative_im_idx]
                    ):
                        ref_img_idx = tentative_im_idx
                        break
            else:
                invalid_seq = False
        views = []
        for view_idx in imgs_idxs:
            scene_dir = osp.join(self.ROOT, scene_info)
            impath = basenames[view_idx].decode("utf-8")
            image = imread_cv2(osp.join(scene_dir, impath + ".jpg"))
            depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr"))
            camera_params = np.load(osp.join(scene_dir, impath + ".npz"))
            intrinsics = np.float32(camera_params["intrinsics"])
            camera_pose = np.eye(4, dtype=np.float32)
            camera_pose[:3, :3] = camera_params["R_cam2world"]
            camera_pose[:3, 3] = camera_params["t_cam2world"]
            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath)
            )
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="BlendedMVS",
                    label=osp.relpath(scene_dir, self.ROOT),
                    is_metric=self.is_metric,
                    is_video=ordered_video,
                    instance=osp.join(scene_dir, impath + ".jpg"),
                    quantile=np.array(0.97, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views
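

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file): the ROOT
    # path and the constructor/_get_views arguments below are assumptions
    # that depend on BaseMultiViewDataset's signature in this repo.
    dataset = BlendedMVS_Multi(
        ROOT="data/blendedmvs_processed",  # hypothetical processed-data root
        split="train",
        num_views=8,
        resolution=(512, 384),
    )
    print(dataset.get_stats())
    rng = np.random.default_rng(seed=0)
    views = dataset._get_views(0, (512, 384), rng, num_views=8)
    print(f"sampled {len(views)} views from scene {views[0]['label']}")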