# LAM/vhap/data/nersemble_dataset.py
# (repository page header: author yuandong513, commit "feat: init", 17cd746)
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
import json
import numpy as np
import torch
from vhap.data.video_dataset import VideoDataset
from vhap.config.nersemble import NersembleDataConfig
from vhap.util import camera
from vhap.util.log import get_logger
logger = get_logger(__name__)
class NeRSembleDataset(VideoDataset):
    """Multi-view video dataset laid out in the NeRSemble directory structure.

    Specializes ``VideoDataset`` with NeRSemble-specific sequence matching,
    camera parameter loading, camera-subset selection and optional per-camera
    affine color correction.
    """

    def __init__(
        self,
        cfg: NersembleDataConfig,
        img_to_tensor: bool = False,
        batchify_all_views: bool = False,
    ):
        """
        Args:
            cfg: Dataset configuration. ``cfg.subject`` must be non-empty and
                ``cfg.root_folder`` must point to a dataset with the following
                directory layout
                <root_folder>/
                |---camera_params/
                |   |---<subject>/
                |       |---camera_params.json
                |
                |---color_correction/
                |   |---<subject>/
                |       |---<camera_id>.npy
                |
                |---<subject>/
                    |---<sequence>/
                        |---images/
                        |   |---cam_<camera_id>_<timestep_id>.jpg
                        |
                        |---alpha_maps/
                        |   |---cam_<camera_id>_<timestep_id>.png
                        |
                        |---landmark2d/
                            |---face-alignment/
                            |   |---<camera_id>.npz
                            |
                            |---STAR/
                                |---<camera_id>.npz
            img_to_tensor: Forwarded unchanged to ``VideoDataset.__init__``.
            batchify_all_views: Forwarded unchanged to ``VideoDataset.__init__``.
        """
        # Stored before calling the parent constructor so the config is already
        # available on the instance while the parent initializes.
        self.cfg = cfg
        assert cfg.subject != "", "Please specify the subject name"

        super().__init__(
            cfg=cfg,
            img_to_tensor=img_to_tensor,
            batchify_all_views=batchify_all_views,
        )

    def match_sequences(self):
        """Return the directories under ``<root_folder>/<subject>`` whose name
        starts with ``cfg.sequence``."""
        logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}")
        subject_dir = self.cfg.root_folder / self.cfg.subject
        return [p for p in subject_dir.glob(f"{self.cfg.sequence}*") if p.is_dir()]

    def define_properties(self):
        """Extend the parent's property table with the ``cam_`` file-name prefix
        used by NeRSemble image and alpha-map files."""
        super().define_properties()
        self.properties['rgb']['cam_id_prefix'] = "cam_"
        self.properties['alpha_map']['cam_id_prefix'] = "cam_"

    def load_camera_params(self):
        """Load intrinsics/extrinsics from ``camera_params.json``.

        Populates ``self.camera_ids`` (the keys of ``"world_2_cam"`` in file
        order) and ``self.camera_params``, a dict mapping each camera id to
        ``{"intrinsic": K, "extrinsic": E}``. The same ``K`` is stored for every
        camera. ``E`` is the ``[R | t]`` matrix in the direction selected by
        ``cfg.target_extrinsic_type`` (``"w2c"`` or ``"c2w"``).

        Raises:
            AssertionError: If the parameter file is missing, or if the file
                lacks ``"height"``/``"width"`` and
                ``cfg.image_size_during_calibration`` is ``None``.
            NotImplementedError: If ``cfg.target_extrinsic_type`` is neither
                ``"w2c"`` nor ``"c2w"``.
        """
        load_path = self.cfg.root_folder / "camera_params" / self.cfg.subject / "camera_params.json"
        assert load_path.exists(), f"Camera parameter file not found: {load_path}"

        # Use a context manager so the file handle is closed deterministically.
        with open(load_path) as f:
            param = json.load(f)

        K = torch.Tensor(param["intrinsics"])

        # Image size comes from the file when present, otherwise from the config.
        if "height" not in param or "width" not in param:
            assert self.cfg.image_size_during_calibration is not None, (
                "camera_params.json has no 'height'/'width'; "
                "cfg.image_size_during_calibration must be set"
            )
            H, W = self.cfg.image_size_during_calibration
        else:
            H, W = param["height"], param["width"]

        self.camera_ids = list(param["world_2_cam"].keys())
        w2c = torch.tensor([param["world_2_cam"][k] for k in self.camera_ids])  # (N, 4, 4)
        R = w2c[..., :3, :3]
        T = w2c[..., :3, 3]

        # Invert the rigid world-to-camera transform: R_c2w = R^T, t_c2w = -R^T t.
        orientation = R.transpose(-1, -2)  # (N, 3, 3)
        location = R.transpose(-1, -2) @ -T[..., None]  # (N, 3, 1)

        # adjust how cameras distribute in the space with a global rotation
        if self.cfg.align_cameras_to_axes:
            orientation, location = camera.align_cameras_to_axes(
                orientation, location, target_convention="opengl"
            )

        # modify the local orientation of cameras to fit in different camera conventions
        if self.cfg.camera_convention_conversion is not None:
            orientation, K = camera.convert_camera_convention(
                self.cfg.camera_convention_conversion, orientation, K, H, W
            )

        c2w = torch.cat([orientation, location], dim=-1)  # camera-to-world transformation

        if self.cfg.target_extrinsic_type == "w2c":
            # Re-invert after the (possibly adjusted) orientation/location.
            R = orientation.transpose(-1, -2)
            T = orientation.transpose(-1, -2) @ -location
            w2c = torch.cat([R, T], dim=-1)  # world-to-camera transformation
            extrinsic = w2c
        elif self.cfg.target_extrinsic_type == "c2w":
            extrinsic = c2w
        else:
            raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")

        self.camera_params = {
            camera_id: {"intrinsic": K, "extrinsic": extrinsic[i]}
            for i, camera_id in enumerate(self.camera_ids)
        }

    def filter_division(self, division):
        """Restrict ``self.camera_ids`` to the subset named by ``division``.

        Selection is by position in ``self.camera_ids``. Supported values are
        ``"train"``, ``"val"`` (all positions not used for training),
        ``"front-view"`` (position 8), ``"side-view"`` (position 0) and
        ``"six-view"`` (positions 0, 1, 7, 8, 14, 15). ``None`` leaves the
        camera list untouched.

        Raises:
            NotImplementedError: If ``division`` is not one of the values above.
            IndexError: For ``"six-view"`` when fewer than 16 cameras exist.
        """
        if division is None:
            return

        # Positions (not camera ids) of the cameras reserved for training.
        cam_for_train = {8, 7, 9, 4, 10, 5, 13, 2, 12, 1, 14, 0}

        if division == "train":
            self.camera_ids = [
                cam_id for i, cam_id in enumerate(self.camera_ids) if i in cam_for_train
            ]
        elif division == "val":
            self.camera_ids = [
                cam_id for i, cam_id in enumerate(self.camera_ids) if i not in cam_for_train
            ]
        elif division == "front-view":
            # Slicing (rather than indexing) yields an empty list instead of
            # raising when the position does not exist.
            self.camera_ids = self.camera_ids[8:9]
        elif division == "side-view":
            self.camera_ids = self.camera_ids[0:1]
        elif division == "six-view":
            self.camera_ids = [self.camera_ids[i] for i in [0, 1, 7, 8, 14, 15]]
        else:
            raise NotImplementedError(f"Unknown division type: {division}")

        logger.info(f"division: {division}")

    def apply_transforms(self, item):
        """Optionally color-correct ``item["rgb"]``, then run the parent's transforms.

        When ``cfg.use_color_correction`` is set, the per-camera affine color
        transform stored in ``color_correction/<subject>/<camera_id>.npy`` is
        applied to the image in the [0, 1] range and the result is written back
        as ``uint8``. ``item`` is modified in place and also returned.
        """
        if self.cfg.use_color_correction:
            color_correction_path = (
                self.cfg.root_folder
                / 'color_correction'
                / self.cfg.subject
                / f'{item["camera_id"]}.npy'
            )
            affine_color_transform = np.load(color_correction_path)

            # Linear part is the top-left 3x3, offset is the last column.
            rgb = item["rgb"] / 255
            rgb = rgb @ affine_color_transform[:3, :3] + affine_color_transform[np.newaxis, :3, 3]
            item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)

        super().apply_transforms(item)
        return item
if __name__ == "__main__":
    # Manual smoke test: build the dataset from CLI arguments, inspect the
    # first sample, then iterate once over every item to check loading works.
    import tyro
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    from vhap.config.base import import_module
    from vhap.config.nersemble import NersembleDataConfig

    config = tyro.cli(NersembleDataConfig)
    config.use_landmark = False

    # The dataset class is resolved from the config's target rather than
    # referenced directly.
    dataset_cls = import_module(config._target)
    nersemble_data = dataset_cls(
        cfg=config,
        img_to_tensor=False,
        batchify_all_views=True,
    )

    print(len(nersemble_data))

    first_sample = nersemble_data[0]
    print(first_sample.keys())
    print(first_sample["rgb"].shape)

    # batch_size=None disables automatic batching in the DataLoader.
    loader = DataLoader(nersemble_data, batch_size=None, shuffle=False, num_workers=1)
    for _ in tqdm(loader):
        pass