File size: 5,608 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from pathlib import Path
import json
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms.functional as F
from torch.utils.data import Dataset
from vhap.util.log import get_logger


# Module-level logger for this file (provided by vhap.util.log, see import above).
logger = get_logger(__name__)


class NeRFDataset(Dataset):
    """Dataset over a NeRF-style scene described by transforms*.json files.

    Each sample is one frame: camera intrinsics/extrinsics plus the image,
    and optionally a foreground mask and FLAME parameters.
    """

    # Division name -> transforms file; None selects the full, undivided set.
    _TRANSFORMS_BY_DIVISION = {
        None: "transforms.json",
        "train": "transforms_train.json",
        "val": "transforms_val.json",
        "test": "transforms_test.json",
    }

    def __init__(
        self,
        root_folder,
        division=None,
        camera_convention_conversion=None,
        target_extrinsic_type='w2c',
        use_fg_mask=False,
        use_flame_param=False,
        img_to_tensor=False,
    ):
        """
        Args:
            root_folder: Path to dataset with the following directory layout
                <root_folder>/
                |
                |---<images>/
                |   |---00000.jpg
                |   |...
                |
                |---<fg_masks>/
                |   |---00000.png
                |   |...
                |
                |---<flame_param>/
                |   |---00000.npz
                |   |...
                |
                |---transforms_backup.json          # backup of the original transforms.json
                |---transforms_backup_flame.json    # backup of the original transforms.json with flame_param
                |---transforms.json                 # the final transforms.json
                |---transforms_train.json   # the final transforms.json for training
                |---transforms_val.json     # the final transforms.json for validation
                |---transforms_test.json    # the final transforms.json for testing
            division: One of None, "train", "val", "test"; selects which
                transforms file to load.
            camera_convention_conversion: Stored but not applied in this class.
            target_extrinsic_type: 'w2c' (invert camera-to-world) or 'c2w'.
            use_fg_mask: If True, load the foreground mask for each frame.
            use_flame_param: If True, load the FLAME parameter .npz per frame.
            img_to_tensor: If True, apply_to_tensor() converts images to
                torch tensors. (Previously apply_to_tensor read
                self.img_to_tensor without it ever being assigned, which
                raised AttributeError; defaulting to False keeps the old
                no-conversion path working.)
        """
        super().__init__()
        self.root_folder = Path(root_folder)
        self.division = division
        self.camera_convention_conversion = camera_convention_conversion
        self.target_extrinsic_type = target_extrinsic_type
        self.use_fg_mask = use_fg_mask
        self.use_flame_param = use_flame_param
        self.img_to_tensor = img_to_tensor

        logger.info(f"Loading NeRF scene from: {root_folder}")

        # data division: pick the transforms file for the requested split
        if division not in self._TRANSFORMS_BY_DIVISION:
            raise NotImplementedError(f"Unknown division type: {division}")
        transform_path = self.root_folder / self._TRANSFORMS_BY_DIVISION[division]
        logger.info(f"division: {division}")

        # Use a context manager so the file handle is closed promptly
        # (the previous json.load(open(...)) leaked it).
        with open(transform_path, "r") as f:
            self.transforms = json.load(f)
        logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}")

        # Timestep indices must be contiguous 0..T-1 (no gaps).
        assert len(self.transforms['timestep_indices']) == max(self.transforms['timestep_indices']) + 1

    def __len__(self):
        """Number of frames (timestep x camera combinations) in this split."""
        return len(self.transforms['frames'])

    def __getitem__(self, i):
        """Load frame `i`: intrinsics, extrinsics, image, and optional extras."""
        frame = self.transforms['frames'][i]

        # Frame keys include: 'timestep_index', 'timestep_index_original', 'timestep_id', 'camera_index', 'camera_id', 'cx', 'cy', 'fl_x', 'fl_y', 'h', 'w', 'camera_angle_x', 'camera_angle_y', 'transform_matrix', 'file_path', 'fg_mask_path', 'flame_param_path'

        # Build the 3x3 pinhole intrinsic matrix:
        # [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
        K = torch.eye(3)
        K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor(
            [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]]
        )

        c2w = torch.tensor(frame['transform_matrix'])
        if self.target_extrinsic_type == "w2c":
            # world-to-camera is the inverse of the stored camera-to-world
            extrinsic = c2w.inverse()
        elif self.target_extrinsic_type == "c2w":
            extrinsic = c2w
        else:
            raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}")

        img_path = self.root_folder / frame['file_path']

        item = {
            'timestep_index': frame['timestep_index'],
            'camera_index': frame['camera_index'],
            'intrinsics': K,
            'extrinsics': extrinsic,
            'image_height': frame['h'],
            'image_width': frame['w'],
            'image': np.array(Image.open(img_path)),
            'image_path': img_path,
        }

        if self.use_fg_mask and 'fg_mask_path' in frame:
            fg_mask_path = self.root_folder / frame['fg_mask_path']
            item["fg_mask"] = np.array(Image.open(fg_mask_path))
            item["fg_mask_path"] = fg_mask_path

        if self.use_flame_param and 'flame_param_path' in frame:
            npz = np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True)
            item["flame_param"] = dict(npz)

        return item

    def apply_to_tensor(self, item):
        """Optionally convert image arrays in `item` to torch tensors in-place.

        No-op unless the dataset was constructed with img_to_tensor=True.

        NOTE(review): this looks for 'rgb'/'alpha_map' keys, but __getitem__
        emits 'image'/'fg_mask' — confirm which producer this is meant for.
        """
        if self.img_to_tensor:
            if "rgb" in item:
                item["rgb"] = F.to_tensor(item["rgb"])

            if "alpha_map" in item:
                item["alpha_map"] = F.to_tensor(item["alpha_map"])
        return item


if __name__ == "__main__":
    # Smoke test: load a scene from the CLI and iterate every sample once.
    from dataclasses import dataclass

    import tyro
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    @dataclass
    class Args:
        root_folder: str
        subject: str
        sequence: str
        use_landmark: bool = False
        batchify_all_views: bool = False

    cli = tyro.cli(Args)

    dataset = NeRFDataset(root_folder=cli.root_folder)
    print(len(dataset))

    first_sample = dataset[0]
    print(first_sample.keys())

    # batch_size=None disables collation: samples come through unbatched.
    loader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for _ in tqdm(loader):
        pass