# 
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 
# property and proprietary rights in and to this software and related documentation. 
# Any commercial use, reproduction, disclosure or distribution of this software and 
# related documentation without an express license agreement from Toyota Motor Europe NV/SA 
# is strictly prohibited.
#


from typing import Optional, Literal, List
from copy import deepcopy
import json
import tyro
from pathlib import Path
import shutil
import random
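
# Expected input layout (a sketch; the folder names below are hypothetical).
# Source folder names must start with the subject name followed by '_', since
# the subject is parsed as `name.split('_')[0]` below:
#
#   data/
#   ├── subjectA_seq1/          src_folder with transforms.json,
#   │                           canonical_flame_param.npz, images, ...
#   ├── subjectA_seq2/
#   └── subjectA_union/         tgt_folder, written by this script
#
# Fields of each transforms.json that this script relies on (inferred from the
# code below; the full schema contains more keys, e.g. camera intrinsics):
#   timestep_indices: list[int]   timestep indices present in the sequence
#   camera_indices:   list[int]   camera indices present in the sequence
#   frames:           list[dict]  per-frame entries with timestep_index,
#                                 camera_index, file_path, flame_param_path,
#                                 fg_mask_path, ...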


class NeRFDatasetAssembler:
    def __init__(self, src_folders: List[Path], tgt_folder: Path, division_mode: Literal['random_single', 'random_group', 'last'] = 'random_group'):
        self.src_folders = src_folders
        self.tgt_folder = tgt_folder
        self.num_timestep = 0

        # use the subject name as the random seed to sample the test sequence
        subjects = [sf.name.split('_')[0] for sf in src_folders]
        for s in subjects:
            assert s == subjects[0], f"Cannot combine datasets from different subjects: {subjects}"
        subject = subjects[0]
        random.seed(subject)

        # split the source folders into held-out test sequences and the
        # remaining training sequences, according to the division mode
        if division_mode == 'random_single':
            # hold out one randomly chosen sequence; randrange avoids the edge case
            # where random.uniform(0, 1) * len(...) rounds up to an out-of-range index
            self.src_folders_test = [self.src_folders.pop(random.randrange(len(self.src_folders)))]
        elif division_mode == 'random_group':
            # sample one sequence as the test sequence every `group_size` sequences
            self.src_folders_test = []
            num_all = len(self.src_folders)
            group_size = 10
            num_test = max(1, num_all // group_size)
            indices_test = []
            for gi in range(num_test):
                idx = min(num_all - 1, random.randint(0, group_size - 1) + gi * group_size)
                indices_test.append(idx)
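            # worked example: with 25 sequences, num_test = 2, so one index is
            # drawn from [0, 9] and one from [10, 19]; sequences in a trailing
            # partial group (here 20-24) are never selected for the test split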

            # pop in descending order so earlier pops don't shift the remaining
            # indices (popping in ascending order would grab the wrong sequences
            # and can even raise an IndexError for the last group); inserting at
            # the front restores ascending order
            for idx in reversed(indices_test):
                self.src_folders_test.insert(0, self.src_folders.pop(idx))
        elif division_mode == 'last':
            self.src_folders_test = [self.src_folders.pop(-1)]
        else:
            raise ValueError(f"Unknown division mode: {division_mode}")

        self.src_folders_train = self.src_folders

    def write(self):
        self.combine_dbs(self.src_folders_train, division='train')
        self.combine_dbs(self.src_folders_test, division='test')
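        # note: self.num_timestep keeps accumulating across both calls, so the
        # test sequences' timestep indices continue where the training ones end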

    def combine_dbs(self, src_folders, division: Optional[Literal['train', 'test']] = None):
        db = None
        for i, src_folder in enumerate(src_folders):
            dbi_path = src_folder / "transforms.json"
            assert dbi_path.exists(), f"Could not find {dbi_path}"
            # print(f"Loading database: {dbi_path}")
            with open(dbi_path, "r") as f:
                dbi = json.load(f)

            # offset this sequence's timestep indices by the number of timesteps
            # accumulated from the previous sequences
            dbi['timestep_indices'] = [t + self.num_timestep for t in dbi['timestep_indices']]
            self.num_timestep += len(dbi['timestep_indices'])
            for frame in dbi['frames']:
                # drop keys that are irrelevant for a combined dataset
                frame.pop('timestep_index_original')
                frame.pop('timestep_id')

                # map the frame's local timestep index to its global index
                frame['timestep_index'] = dbi['timestep_indices'][frame['timestep_index']]

                # prefix paths with '../<src_folder>' so they resolve from the target
                # folder, which main() guarantees is a sibling of every source folder
                frame['file_path'] = str(Path('..') / Path(src_folder.name) / frame['file_path'])
                frame['flame_param_path'] = str(Path('..') / Path(src_folder.name) / frame['flame_param_path'])
                frame['fg_mask_path'] = str(Path('..') / Path(src_folder.name) / frame['fg_mask_path'])
            
            if db is None:
                db = dbi
            else:
                db['frames'] += dbi['frames']
                db['timestep_indices'] += dbi['timestep_indices']
            
        self.tgt_folder.mkdir(parents=True, exist_ok=True)
        
        if division == 'train':
            # copy the canonical flame param
            cano_flame_param_path = src_folders[0] / "canonical_flame_param.npz"
            tgt_flame_param_path = self.tgt_folder / "canonical_flame_param.npz"
            print(f"Copying canonical flame param: {tgt_flame_param_path}")
            shutil.copy(cano_flame_param_path, tgt_flame_param_path)
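            # the canonical param comes from the first training sequence; all
            # sequences belong to one subject, so they presumably share the same
            # canonical shape (an assumption, not checked here)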

            # leave one camera for validation
            db_train = {k: v for k, v in db.items() if k not in ['frames', 'camera_indices']}
            db_train['frames'] = []
            db_val = deepcopy(db_train)

            if len(db['camera_indices']) > 1:
                # with multiple cameras, hold one out for validation (novel-view synthesis)
                if 8 in db['camera_indices']:
                    # use camera 8 for validation (front-view of the NeRSemble dataset)
                    db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
                    db_val['camera_indices'] = [8]
                else:
                    # use the last camera for validation
                    db_train['camera_indices'] = db['camera_indices'][:-1]
                    db_val['camera_indices'] = [db['camera_indices'][-1]]
            else:
                # with only one camera, create an empty validation set
                db_train['camera_indices'] = db['camera_indices']
                db_val['camera_indices'] = []

            for frame in db['frames']:
                if frame['camera_index'] in db_train['camera_indices']:
                    db_train['frames'].append(frame)
                elif frame['camera_index'] in db_val['camera_indices']:
                    db_val['frames'].append(frame)
                else:
                    raise ValueError(f"Unknown camera index: {frame['camera_index']}")
                
            write_json(db_train, self.tgt_folder, 'train')
            write_json(db_val, self.tgt_folder, 'val')

            with open(self.tgt_folder / 'sequences_trainval.txt', 'w') as f:
                for folder in src_folders:
                    f.write(folder.name + '\n')
        else:
            db['timestep_indices'] = sorted(db['timestep_indices'])
            write_json(db, self.tgt_folder, division)

            with open(self.tgt_folder / f'sequences_{division}.txt', 'w') as f:
                for folder in src_folders:
                    f.write(folder.name + '\n')

    
def write_json(db, tgt_folder, division=None):
    fname = "transforms.json" if division is None else f"transforms_{division}.json"
    json_path = tgt_folder / fname
    print(f"Writing database: {json_path}")
    with open(json_path, "w") as f:
        json.dump(db, f, indent=4)
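
# e.g. write_json(db, Path("data/subjectA_union"), "train") writes
# data/subjectA_union/transforms_train.json (the folder name is hypothetical)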
    
def main(
        src_folders: List[Path],
        tgt_folder: Path,
        division_mode: Literal['random_single', 'random_group', 'last'] = 'random_group',
    ):
    incomplete = False
    print("==== Begin assembling datasets ====")
    print(f"Division mode: {division_mode}")
    for src_folder in src_folders:
        # validate inputs explicitly rather than with assert, which is
        # silently skipped when Python runs with -O
        if not src_folder.exists():
            print(f"Error: could not find {src_folder}")
            incomplete = True
        elif src_folder.parent != tgt_folder.parent:
            print(f"Error: {src_folder} is not in the same parent folder as the target folder")
            incomplete = True

    if incomplete:
        return

    nerf_dataset_assembler = NeRFDatasetAssembler(src_folders, tgt_folder, division_mode)
    nerf_dataset_assembler.write()

    print("Done!")


if __name__ == "__main__":
    tyro.cli(main)
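
# Example invocation via tyro's CLI (paths are hypothetical), assuming this
# file is saved as combine_nerf_datasets.py:
#
#   python combine_nerf_datasets.py \
#       --src-folders data/subjectA_seq1 data/subjectA_seq2 data/subjectA_seq3 \
#       --tgt-folder data/subjectA_union \
#       --division-mode random_group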