# liguang0115's picture
# Add initial project structure with core files, configurations, and sample images
# 2df809d
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import numpy as np
from dust3r.datasets.base.batched_sampler import (
BatchedRandomSampler,
CustomRandomSampler,
)
import torch
class EasyDataset:
    """Base class for datasets that can be resized and combined with operators.

    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x
        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        # `a + b` wraps both operands in a CatDataset.
        return CatDataset([self, other])

    def __rmul__(self, factor):
        # `k * ds` repeats every element of `ds` k times.
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        # `n @ ds` resizes `ds` to exactly n elements.
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """Per-epoch hook; the base implementation does nothing."""
        return None

    def make_sampler(
        self,
        batch_size,
        shuffle=True,
        drop_last=True,
        world_size=1,
        rank=0,
        fixed_length=False,
    ):
        """Build a batched sampler over this dataset.

        Args:
            batch_size: number of samples per batch, forwarded to both samplers.
            shuffle: must be truthy; non-shuffled sampling is not implemented.
            drop_last: forwarded to both the inner and the batched sampler.
            world_size: forwarded to ``CustomRandomSampler``.
            rank: not used by this method.
            fixed_length: selects the 4th positional argument given to
                ``CustomRandomSampler``: ``self.num_views`` when True, otherwise
                the constant 4. NOTE(review): presumably a minimum number of
                views per sample — confirm against CustomRandomSampler.

        Returns:
            A ``BatchedRandomSampler`` wrapping a ``CustomRandomSampler``.

        Raises:
            NotImplementedError: if ``shuffle`` is falsy.
        """
        if shuffle:
            n_aspect_ratios = len(self._resolutions)
            n_views = self.num_views
            views_arg = n_views if fixed_length else 4
            inner_sampler = CustomRandomSampler(
                self,
                batch_size,
                n_aspect_ratios,
                views_arg,
                n_views,
                world_size,
                warmup=1,
                drop_last=drop_last,
            )
            return BatchedRandomSampler(inner_sampler, batch_size, drop_last)
        raise NotImplementedError()  # cannot deal yet
class MulDataset(EasyDataset):
    """Artificially augment the size of a dataset by repeating its elements.

    Index ``i`` of this dataset maps to index ``i // multiplicator`` of the
    wrapped dataset, so each underlying element is served ``multiplicator``
    times in a row.
    """

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset) * self.multiplicator

    def __repr__(self):
        return f"{self.multiplicator}*{self.dataset!r}"

    def __getitem__(self, idx):
        # Plain index: remap it and delegate.
        if not isinstance(idx, tuple):
            return self.dataset[idx // self.multiplicator]
        # Composite key (index, other, another): only the index is remapped,
        # the two extra fields are forwarded untouched.
        base_idx, other, another = idx
        return self.dataset[base_idx // self.multiplicator, other, another]

    @property
    def _resolutions(self):
        # Delegated to the wrapped dataset.
        return self.dataset._resolutions

    @property
    def num_views(self):
        # Delegated to the wrapped dataset.
        return self.dataset.num_views
class ResizedDataset(EasyDataset):
    """Artificially change the size of a dataset to exactly ``new_size`` items.

    The index mapping is built by :meth:`set_epoch`: a seeded permutation of
    the wrapped dataset's indices is repeated cyclically and truncated to
    ``new_size``. Elements are therefore subsampled when shrinking and
    duplicated when growing. ``set_epoch`` must be called before indexing.
    """

    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # Group digits with underscores, e.g. 1000000 -> "1_000_000".
        return f"{self.new_size:_} @ {self.dataset!r}"

    def set_epoch(self, epoch):
        """Rebuild the index mapping for ``epoch``.

        The shuffle depends only on ``epoch`` (seed ``epoch + 777``) and on the
        wrapped dataset's length.

        Raises:
            ValueError: if the wrapped dataset is empty, since there is nothing
                to sample from (previously this surfaced as ZeroDivisionError).
        """
        base_len = len(self.dataset)
        if base_len == 0:
            raise ValueError("ResizedDataset cannot sample from an empty dataset")
        # this random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch + 777)
        # shuffle all indices
        perm = rng.permutation(base_len)
        # rotary extension: repeat the permutation ceil(new_size / base_len)
        # times, then cut to the target size
        n_repeats = 1 + (self.new_size - 1) // base_len
        shuffled_idxs = np.concatenate([perm] * n_repeats)
        self._idxs_mapping = shuffled_idxs[: self.new_size]
        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(
            self, "_idxs_mapping"
        ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        # Composite keys are (index, other, another); only the index is remapped.
        if isinstance(idx, tuple):
            idx, other, another = idx
            return self.dataset[self._idxs_mapping[idx], other, another]
        return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        # Delegated to the wrapped dataset.
        return self.dataset._resolutions

    @property
    def num_views(self):
        # Delegated to the wrapped dataset.
        return self.dataset.num_views
class CatDataset(EasyDataset):
    """Concatenation of several datasets.

    A global index ``i`` is routed to the first sub-dataset whose cumulative
    size exceeds ``i``, and shifted so it is local to that sub-dataset.
    """

    def __init__(self, datasets):
        # Materialize once so an iterator/generator argument is not exhausted
        # by the validation loop before the sizes are computed.
        datasets = list(datasets)
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        # Return a plain Python int (not np.int64), and 0 when there are no
        # sub-datasets instead of failing on `_cum_sizes[-1]`.
        if len(self._cum_sizes) == 0:
            return 0
        return int(self._cum_sizes[-1])

    def __repr__(self):
        # remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        # Propagate the epoch to every sub-dataset.
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        """Fetch global index ``idx`` (or composite key ``(idx, other, another)``).

        The composite form is forwarded to the sub-dataset only when both
        ``other`` and ``another`` are not None; otherwise only the local
        index is passed.

        Raises:
            IndexError: if the index is outside ``[0, len(self))``.
        """
        # Initialize both extras so neither is ever unbound below.
        other = another = None
        if isinstance(idx, tuple):
            idx, other, another = idx
        if not (0 <= idx < len(self)):
            raise IndexError()
        # "right" so an index equal to a cumulative boundary goes to the next dataset.
        db_idx = int(np.searchsorted(self._cum_sizes, idx, "right"))
        dataset = self.datasets[db_idx]
        offset = self._cum_sizes[db_idx - 1] if db_idx > 0 else 0
        new_idx = int(idx - offset)
        if other is not None and another is not None:
            new_idx = (new_idx, other, another)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        # All sub-datasets must share the same resolutions.
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions

    @property
    def num_views(self):
        # All sub-datasets must share the same number of views.
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views
        return num_views