#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
import json
import random
import shutil
from copy import deepcopy
from pathlib import Path
from typing import List, Literal, Optional

import tyro


class NeRFDatasetAssembler:
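    """Merge per-sequence NeRF datasets of a single subject into one combined
    dataset with train/val/test splits."""
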
def __init__(self, src_folders: List[Path], tgt_folder: Path, division_mode: Literal['random_single', 'random_group', 'last']='random_group'):
self.src_folders = src_folders
self.tgt_folder = tgt_folder
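        # running count of timesteps accumulated across all source sequences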
self.num_timestep = 0
# use the subject name as the random seed to sample the test sequence
subjects = [sf.name.split('_')[0] for sf in src_folders]
for s in subjects:
assert s == subjects[0], f"Cannot combine datasets from different subjects: {subjects}"
subject = subjects[0]
random.seed(subject)
        if division_mode == 'random_single':
            # hold out one uniformly sampled sequence for testing
            # (randrange avoids the off-by-one risk of int(random.uniform(0, 1) * n))
            self.src_folders_test = [self.src_folders.pop(random.randrange(len(self.src_folders)))]
elif division_mode == 'random_group':
# sample one sequence as the test sequence every `group_size` sequences
self.src_folders_test = []
num_all = len(self.src_folders)
group_size = 10
num_test = max(1, num_all // group_size)
indices_test = []
for gi in range(num_test):
idx = min(num_all - 1, random.randint(0, group_size - 1) + gi * group_size)
indices_test.append(idx)
            # pop in descending order so earlier pops do not shift the later indices
            for idx in sorted(indices_test, reverse=True):
                self.src_folders_test.append(self.src_folders.pop(idx))
            self.src_folders_test.reverse()  # restore the original sequence order
elif division_mode == 'last':
self.src_folders_test = [self.src_folders.pop(-1)]
else:
raise ValueError(f"Unknown division mode: {division_mode}")
self.src_folders_train = self.src_folders

    def write(self):
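        """Write the combined train(+val) and test databases to the target folder."""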
self.combine_dbs(self.src_folders_train, division='train')
self.combine_dbs(self.src_folders_test, division='test')

    def combine_dbs(self, src_folders, division: Optional[Literal['train', 'test']] = None):
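        """Concatenate the `transforms.json` databases of `src_folders` into one
        database and write it for the given division; 'train' additionally
        produces a camera-held-out 'val' split."""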
db = None
for i, src_folder in enumerate(src_folders):
dbi_path = src_folder / "transforms.json"
assert dbi_path.exists(), f"Could not find {dbi_path}"
# print(f"Loading database: {dbi_path}")
            with open(dbi_path, "r") as f:
                dbi = json.load(f)
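            # offset this sequence's timestep indices by the timesteps counted so far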
dbi['timestep_indices'] = [t + self.num_timestep for t in dbi['timestep_indices']]
self.num_timestep += len(dbi['timestep_indices'])
for frame in dbi['frames']:
# drop keys that are irrelevant for a combined dataset
frame.pop('timestep_index_original')
frame.pop('timestep_id')
# accumulate timestep indices
frame['timestep_index'] = dbi['timestep_indices'][frame['timestep_index']]
                # prepend the source sequence folder so paths stay valid relative to the target folder
frame['file_path'] = str(Path('..') / Path(src_folder.name) / frame['file_path'])
frame['flame_param_path'] = str(Path('..') / Path(src_folder.name) / frame['flame_param_path'])
frame['fg_mask_path'] = str(Path('..') / Path(src_folder.name) / frame['fg_mask_path'])
if db is None:
db = dbi
else:
db['frames'] += dbi['frames']
db['timestep_indices'] += dbi['timestep_indices']
if not self.tgt_folder.exists():
self.tgt_folder.mkdir(parents=True)
if division == 'train':
# copy the canonical flame param
cano_flame_param_path = src_folders[0] / "canonical_flame_param.npz"
            tgt_flame_param_path = self.tgt_folder / "canonical_flame_param.npz"
print(f"Copying canonical flame param: {tgt_flame_param_path}")
shutil.copy(cano_flame_param_path, tgt_flame_param_path)
# leave one camera for validation
db_train = {k: v for k, v in db.items() if k not in ['frames', 'camera_indices']}
db_train['frames'] = []
db_val = deepcopy(db_train)
if len(db['camera_indices']) > 1:
                # with multiple cameras, hold one out for validation (novel-view synthesis)
if 8 in db['camera_indices']:
# use camera 8 for validation (front-view of the NeRSemble dataset)
db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
db_val['camera_indices'] = [8]
else:
# use the last camera for validation
db_train['camera_indices'] = db['camera_indices'][:-1]
db_val['camera_indices'] = [db['camera_indices'][-1]]
else:
                # with a single camera, create an empty validation set
db_train['camera_indices'] = db['camera_indices']
db_val['camera_indices'] = []
for frame in db['frames']:
if frame['camera_index'] in db_train['camera_indices']:
db_train['frames'].append(frame)
elif frame['camera_index'] in db_val['camera_indices']:
db_val['frames'].append(frame)
else:
raise ValueError(f"Unknown camera index: {frame['camera_index']}")
write_json(db_train, self.tgt_folder, 'train')
write_json(db_val, self.tgt_folder, 'val')
with open(self.tgt_folder / 'sequences_trainval.txt', 'w') as f:
for folder in src_folders:
f.write(folder.name + '\n')
else:
db['timestep_indices'] = sorted(db['timestep_indices'])
write_json(db, self.tgt_folder, division)
with open(self.tgt_folder / f'sequences_{division}.txt', 'w') as f:
for folder in src_folders:
f.write(folder.name + '\n')


def write_json(db, tgt_folder, division=None):
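    """Dump `db` to `transforms.json` (or `transforms_{division}.json`) in `tgt_folder`."""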
fname = "transforms.json" if division is None else f"transforms_{division}.json"
json_path = tgt_folder / fname
print(f"Writing database: {json_path}")
with open(json_path, "w") as f:
json.dump(db, f, indent=4)


def main(
src_folders: List[Path],
tgt_folder: Path,
division_mode: Literal['random_single', 'random_group', 'last']='random_group',
):
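    """Assemble several single-sequence datasets into one multi-sequence NeRF dataset."""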
incomplete = False
print("==== Begin assembling datasets ====")
print(f"Division mode: {division_mode}")
for src_folder in src_folders:
try:
assert src_folder.exists(), f"Error: could not find {src_folder}"
assert src_folder.parent == tgt_folder.parent, "All source folders must be in the same parent folder as the target folder"
# print(src_folder)
except AssertionError as e:
print(e)
incomplete = True
if incomplete:
return
nerf_dataset_assembler = NeRFDatasetAssembler(src_folders, tgt_folder, division_mode)
nerf_dataset_assembler.write()
print("Done!")


if __name__ == "__main__":
tyro.cli(main)
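    # Example invocation (script name and paths are illustrative):
    #   python assemble_dataset.py \
    #       --src-folders export/306_EMO-1 export/306_EXP-2 \
    #       --tgt-folder export/306_UNION \
    #       --division-mode random_group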