# Provenance (repo upload metadata, kept as a comment so the file parses):
# uploaded by liguang0115 — "Add initial project structure with core files,
# configurations, and sample images" (commit 2df809d)
import os.path as osp
import numpy as np
import cv2
import numpy as np
import itertools
import os
import sys
import pickle
import h5py
from tqdm import tqdm
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
class MapFree_Multi(BaseMultiViewDataset):
    """Multi-view dataset over the MapFree data layout.

    A sample ("group") is either:
      (a) a temporally ordered slice of one dense video sub-sequence
          (``dense<k>/rgb/frame_*.jpg``), or
      (b) an overlap-weighted image collection anchored at a reference
          frame, read from each scene's ``metadata.pkl``.

    Scanning the directory tree is expensive, so the flat index arrays are
    cached in an HDF5 file under ``ROOT`` and reloaded on later runs.
    """

    def __init__(self, ROOT, *args, **kwargs):
        # Root directory containing one sub-directory per scene.
        self.ROOT = ROOT
        self.video = True
        # Depth maps are metric (meters), so absolute scale is meaningful.
        self.is_metric = True
        # Maximum frame gap allowed when sampling a video slice.
        self.max_interval = 30
        super().__init__(*args, **kwargs)
        self._load_data()

    def imgid2path(self, img_id, scene):
        """Return the RGB image path for ``img_id = (seq_id, frame_id)``."""
        first_seq_id, first_frame_id = img_id
        return os.path.join(
            self.ROOT,
            scene,
            f"dense{first_seq_id}",
            "rgb",
            f"frame_{first_frame_id:05d}.jpg",
        )

    def path2imgid(self, subscene, filename):
        """Inverse of :meth:`imgid2path`.

        ``subscene`` looks like ``"dense<seq_id>"`` and ``filename`` like
        ``"frame_<frame_id>.jpg"``.  Returns ``[seq_id, frame_id]``.
        """
        first_seq_id = int(subscene[5:])  # strip the "dense" prefix
        first_frame_id = int(filename[6:-4])  # strip "frame_" and ".jpg"
        return [first_seq_id, first_frame_id]

    def _load_data(self):
        """Build (or load from cache) the flat index arrays used for sampling.

        Populates:
          - ``scenes``:      scene directory names
          - ``sceneids``:    per-group index into ``scenes``
          - ``scope``:       per-group start offset into the flat ``groups``
          - ``video_flags``: per-group bool (always False for collections)
          - ``groups``:      flat ``(seq_id, frame_id, overlap_score)`` rows
          - ``id_ranges``:   per-group ``[start, end)`` range into ``images``
                             covering the reference frame's sub-sequence
          - ``images``:      flat ``[seq_id, frame_id]`` rows for every frame
        """
        cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5"
        if os.path.exists(cache_file):
            print(f"Loading cached metadata from {cache_file}")
            with h5py.File(cache_file, "r") as hf:
                # Scene names are stored as utf-8 bytes; decode back to str.
                self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:]))
                self.sceneids = hf["sceneids"][:]
                self.scope = hf["scope"][:]
                self.video_flags = hf["video_flags"][:]
                self.groups = hf["groups"][:]
                self.id_ranges = hf["id_ranges"][:]
                self.images = hf["images"][:]
        else:
            scene_dirs = sorted(
                [
                    d
                    for d in os.listdir(self.ROOT)
                    if os.path.isdir(os.path.join(self.ROOT, d))
                ]
            )
            scenes = []
            sceneids = []
            groups = []
            scope = []
            images = []
            id_ranges = []
            is_video = []
            start = 0  # running offset into the flat `groups` array
            j = 0  # scene index
            offset = 0  # running offset into the flat `images` array
            for scene in tqdm(scene_dirs):
                scenes.append(scene)
                # video sequences: one "dense<k>" sub-directory per sub-sequence
                subscenes = sorted(
                    [
                        d
                        for d in os.listdir(os.path.join(self.ROOT, scene))
                        if d.startswith("dense")
                    ]
                )
                id_range_subscenes = []
                for subscene in subscenes:
                    rgb_paths = sorted(
                        [
                            d
                            for d in os.listdir(
                                os.path.join(self.ROOT, scene, subscene, "rgb")
                            )
                            if d.endswith(".jpg")
                        ]
                    )
                    assert (
                        len(rgb_paths) > 0
                    ), f"{os.path.join(self.ROOT, scene, subscene)} is empty."
                    num_imgs = len(rgb_paths)
                    images.extend(
                        [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths]
                    )
                    id_range_subscenes.append((offset, offset + num_imgs))
                    offset += num_imgs
                # image collections keyed by reference image.
                # NOTE(review): pickle is unsafe on untrusted input; this
                # assumes metadata.pkl is a trusted on-disk artifact.
                with open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") as f:
                    metadata = pickle.load(f)
                ref_imgs = list(metadata.keys())
                img_groups = []
                for ref_img in ref_imgs:
                    other_imgs = metadata[ref_img]
                    # Skip groups that cannot supply enough distinct views.
                    if len(other_imgs) + 1 < self.num_views:
                        continue
                    # Each row: (seq_id, frame_id, overlap_score); the
                    # reference image gets a fixed score of 1.
                    group = [(*other_img[0], other_img[1]) for other_img in other_imgs]
                    group.insert(0, (*ref_img, 1))
                    img_groups.append(np.array(group))
                    # ref_img[0] is presumably the sub-sequence id of the
                    # reference frame — TODO confirm against metadata writer.
                    id_ranges.append(id_range_subscenes[ref_img[0]])
                    scope.append(start)
                    start = start + len(group)
                num_groups = len(img_groups)
                sceneids.extend([j] * num_groups)
                groups.extend(img_groups)
                is_video.extend([False] * num_groups)
                j += 1

            self.scenes = np.array(scenes)
            self.sceneids = np.array(sceneids)
            self.scope = np.array(scope)
            self.video_flags = np.array(is_video)
            self.groups = np.concatenate(groups, 0)
            self.id_ranges = np.array(id_ranges)
            self.images = np.array(images)

            data = dict(
                scenes=self.scenes,
                sceneids=self.sceneids,
                scope=self.scope,
                video_flags=self.video_flags,
                groups=self.groups,
                id_ranges=self.id_ranges,
                images=self.images,
            )
            # Persist the index so later runs skip the directory scan.
            with h5py.File(cache_file, "w") as h5f:
                h5f.create_dataset(
                    "scenes",
                    data=data["scenes"].astype(object),
                    dtype=h5py.string_dtype(encoding="utf-8"),
                    compression="lzf",
                    chunks=True,
                )
                h5f.create_dataset(
                    "sceneids", data=data["sceneids"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "scope", data=data["scope"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "video_flags",
                    data=data["video_flags"],
                    compression="lzf",
                    chunks=True,
                )
                h5f.create_dataset(
                    "groups", data=data["groups"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "images", data=data["images"], compression="lzf", chunks=True
                )

    def __len__(self):
        """Number of sampleable groups."""
        return len(self.scope)

    def get_image_num(self):
        """Total number of frames indexed across all scenes."""
        return len(self.images)

    def get_stats(self):
        return f"{len(self)} groups of views"

    def _get_views(self, idx, resolution, rng, num_views):
        """Load ``num_views`` views for group ``idx``.

        With probability 0.6 a temporally ordered video slice is drawn from
        the group's sub-sequence; otherwise views are sampled from the
        group's image collection weighted by overlap score.  Returns a list
        of view dicts (image, depth, intrinsics, cam2world pose, masks).
        """
        scene = self.scenes[self.sceneids[idx]]
        if rng.random() < 0.6:
            # --- video branch: contiguous frame range of the sub-sequence ---
            ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1])
            # With repeats allowed, a shorter tail suffices as a start window.
            cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
            start_ids = ids[: len(ids) - cut_off + 1]
            start_id = rng.choice(start_ids)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                ids.tolist(),
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            ids = np.array(ids)[pos]
            image_idxs = self.images[ids]
        else:
            # --- collection branch: overlap-weighted sampling ---
            ordered_video = False
            seq_start_index = self.scope[idx]
            seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None
            image_idxs = (
                self.groups[seq_start_index:seq_end_index]
                if seq_end_index is not None
                else self.groups[seq_start_index:]
            )
            # Split rows into (seq_id, frame_id) ids and their overlap scores.
            image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2]
            # Sample with replacement when repeats are allowed or too few
            # positively-overlapping candidates exist.
            replace = (
                True
                if self.allow_repeat
                or len(overlap_scores[overlap_scores > 0]) < num_views
                else False
            )
            image_idxs = rng.choice(
                image_idxs,
                num_views,
                replace=replace,
                p=overlap_scores / np.sum(overlap_scores),
            )
            image_idxs = image_idxs.astype(np.int64)

        views = []
        for v, view_idx in enumerate(image_idxs):
            img_path = self.imgid2path(view_idx, scene)
            # Sibling modalities share the rgb path layout.
            depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy")
            cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz")
            sky_mask_path = img_path.replace("rgb", "sky_mask")
            image = imread_cv2(img_path)
            depthmap = np.load(depth_path)
            camera_params = np.load(cam_path)
            # NOTE(review): cv2.imread returns None for a missing file, which
            # would raise here — assumes the sky mask always exists.
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127
            intrinsics = camera_params["intrinsic"].astype(np.float32)
            camera_pose = camera_params["pose"].astype(np.float32)

            # Mark sky as -1 (special value), zero out implausible / invalid
            # depths, then clip the far 2% tail to suppress outliers.
            depthmap[sky_mask] = -1.0
            depthmap[depthmap > 400.0] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            threshold = (
                np.percentile(depthmap[depthmap > 0], 98)
                if depthmap[depthmap > 0].size > 0
                else 0
            )
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=img_path
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MapFree",
                    label=img_path,
                    is_metric=self.is_metric,
                    instance=img_path,
                    is_video=ordered_video,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views