|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import numpy as np |
|
import torch |
|
from vhap.data.video_dataset import VideoDataset |
|
from vhap.config.nersemble import NersembleDataConfig |
|
from vhap.util import camera |
|
from vhap.util.log import get_logger |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
class NeRSembleDataset(VideoDataset):
    """NeRSemble flavour of ``VideoDataset`` for one subject.

    Overrides the sequence lookup, file-name prefixes, camera calibration
    loading, camera-split selection and the per-item transform (optional
    per-camera color correction) of the base class.
    """

    def __init__(
        self,
        cfg: NersembleDataConfig,
        img_to_tensor: bool = False,
        batchify_all_views: bool = False,
    ):
        """
        Args:
            cfg: Dataset configuration. ``cfg.subject`` must be a non-empty
                string. ``cfg.root_folder`` is expected to have the following
                directory layout
                <root_folder>/
                |---camera_params/
                |   |---<subject>/
                |       |---camera_params.json
                |
                |---color_correction/
                |   |---<subject>/
                |       |---<camera_id>.npy
                |
                |---<subject>/
                    |---<sequence>/
                        |---images/
                        |   |---cam_<camera_id>_<timestep_id>.jpg
                        |
                        |---alpha_maps/
                        |   |---cam_<camera_id>_<timestep_id>.png
                        |
                        |---landmark2d/
                            |---face-alignment/
                            |   |---<camera_id>.npz
                            |
                            |---STAR/
                                |---<camera_id>.npz
            img_to_tensor: Forwarded unchanged to ``VideoDataset.__init__``.
            batchify_all_views: Forwarded unchanged to
                ``VideoDataset.__init__``.

        Raises:
            AssertionError: If ``cfg.subject`` is an empty string.
        """
        # Stored before the base-class constructor runs.
        # NOTE(review): presumably the base constructor invokes the hooks
        # below, which read self.cfg -- confirm against VideoDataset.
        self.cfg = cfg
        assert cfg.subject != "", "Please specify the subject name"

        super().__init__(
            cfg=cfg,
            img_to_tensor=img_to_tensor,
            batchify_all_views=batchify_all_views,
        )

    def match_sequences(self):
        """Return the directories under ``<root_folder>/<subject>`` whose
        names start with ``cfg.sequence``."""
        logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}")
        subject_folder = self.cfg.root_folder / self.cfg.subject
        return [
            path
            for path in subject_folder.glob(f"{self.cfg.sequence}*")
            if path.is_dir()
        ]

    def define_properties(self):
        """Extend the base-class properties with the ``cam_`` camera-id
        prefix for the ``rgb`` and ``alpha_map`` entries."""
        super().define_properties()
        self.properties['rgb']['cam_id_prefix'] = "cam_"
        self.properties['alpha_map']['cam_id_prefix'] = "cam_"

    def load_camera_params(self):
        """Load the subject's calibration and populate camera attributes.

        Reads ``<root_folder>/camera_params/<subject>/camera_params.json``,
        which must provide ``"intrinsics"`` and a ``"world_2_cam"`` mapping
        from camera id to matrix, and may provide ``"height"``/``"width"``.

        Sets:
            self.camera_ids: Keys of ``"world_2_cam"`` in file order.
            self.camera_params: ``{camera_id: {"intrinsic": K,
                "extrinsic": E}}``, where ``E`` is the camera's ``[R | t]``
                matrix in the convention selected by
                ``cfg.target_extrinsic_type`` (``"w2c"`` or ``"c2w"``).

        Raises:
            AssertionError: If the calibration file is missing, or if the
                image size is neither in the file nor in
                ``cfg.image_size_during_calibration``.
            NotImplementedError: If ``cfg.target_extrinsic_type`` is neither
                ``"w2c"`` nor ``"c2w"``.
        """
        load_path = self.cfg.root_folder / "camera_params" / self.cfg.subject / "camera_params.json"
        assert load_path.exists(), f"Camera parameter file not found: {load_path}"
        # Context manager so the file handle is closed deterministically
        # (the previous json.load(open(...)) left it open).
        with open(load_path) as f:
            param = json.load(f)

        K = torch.Tensor(param["intrinsics"])

        # Prefer the image size stored with the calibration; otherwise fall
        # back to the size given in the config.
        if "height" in param and "width" in param:
            H, W = param["height"], param["width"]
        else:
            assert self.cfg.image_size_during_calibration is not None, (
                "camera_params.json has no 'height'/'width'; "
                "please set cfg.image_size_during_calibration"
            )
            H, W = self.cfg.image_size_during_calibration

        self.camera_ids = list(param["world_2_cam"].keys())
        w2c = torch.tensor([param["world_2_cam"][k] for k in self.camera_ids])
        R = w2c[..., :3, :3]
        T = w2c[..., :3, 3]

        # Invert the rigid world-to-camera transform:
        # orientation = R^T, location = -R^T t (kept as a column vector).
        orientation = R.transpose(-1, -2)
        location = orientation @ -T[..., None]

        if self.cfg.align_cameras_to_axes:
            orientation, location = camera.align_cameras_to_axes(
                orientation, location, target_convention="opengl"
            )

        if self.cfg.camera_convention_conversion is not None:
            orientation, K = camera.convert_camera_convention(
                self.cfg.camera_convention_conversion, orientation, K, H, W
            )

        c2w = torch.cat([orientation, location], dim=-1)

        if self.cfg.target_extrinsic_type == "w2c":
            # Re-invert the (possibly aligned/converted) camera-to-world pose.
            R = orientation.transpose(-1, -2)
            T = R @ -location
            extrinsic = torch.cat([R, T], dim=-1)
        elif self.cfg.target_extrinsic_type == "c2w":
            extrinsic = c2w
        else:
            raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")

        # The same intrinsic tensor object K is stored for every camera.
        self.camera_params = {
            camera_id: {"intrinsic": K, "extrinsic": extrinsic[i]}
            for i, camera_id in enumerate(self.camera_ids)
        }

    def filter_division(self, division):
        """Restrict ``self.camera_ids`` to a named subset of cameras.

        Cameras are selected by their position in ``self.camera_ids``.

        Args:
            division: ``None`` (keep all cameras), ``"train"``, ``"val"``
                (all positions not used for training), ``"front-view"``
                (position 8), ``"side-view"`` (position 0) or ``"six-view"``
                (positions 0, 1, 7, 8, 14, 15).

        Raises:
            NotImplementedError: If ``division`` is any other value.
            IndexError: For ``"six-view"`` when fewer than 16 cameras exist.
        """
        if division is None:
            return

        # Positions (not camera ids) used for training; a set gives O(1)
        # membership tests while the comprehension keeps the original order.
        cam_for_train = {8, 7, 9, 4, 10, 5, 13, 2, 12, 1, 14, 0}
        if division == "train":
            self.camera_ids = [
                cam_id
                for i, cam_id in enumerate(self.camera_ids)
                if i in cam_for_train
            ]
        elif division == "val":
            self.camera_ids = [
                cam_id
                for i, cam_id in enumerate(self.camera_ids)
                if i not in cam_for_train
            ]
        elif division == "front-view":
            self.camera_ids = self.camera_ids[8:9]
        elif division == "side-view":
            self.camera_ids = self.camera_ids[0:1]
        elif division == "six-view":
            self.camera_ids = [self.camera_ids[i] for i in [0, 1, 7, 8, 14, 15]]
        else:
            raise NotImplementedError(f"Unknown division type: {division}")
        logger.info(f"division: {division}")

    def apply_transforms(self, item):
        """Optionally color-correct ``item["rgb"]``, then run the base-class
        transforms and return ``item``.

        When ``cfg.use_color_correction`` is set, the affine transform stored
        in ``<root_folder>/color_correction/<subject>/<camera_id>.npy`` is
        applied to the image scaled to [0, 1]; the result is clipped and
        written back as ``uint8``.
        """
        if self.cfg.use_color_correction:
            color_correction_path = self.cfg.root_folder / 'color_correction' / self.cfg.subject / f'{item["camera_id"]}.npy'
            affine_color_transform = np.load(color_correction_path)
            # assumes item["rgb"] is an array with 0-255 values and channels
            # on the last axis -- TODO confirm against VideoDataset
            rgb = item["rgb"] / 255
            # Right-multiply by the 3x3 block, then add the first three
            # entries of column 3 as a per-channel offset.
            rgb = rgb @ affine_color_transform[:3, :3] + affine_color_transform[np.newaxis, :3, 3]
            item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)

        # The parent's return value is ignored here.
        # NOTE(review): this relies on the parent updating `item` in place --
        # confirm against VideoDataset.apply_transforms.
        super().apply_transforms(item)
        return item
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the dataset from CLI arguments, print a few
    # facts about it, then iterate it once through a DataLoader.
    import tyro
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    from vhap.config.base import import_module
    from vhap.config.nersemble import NersembleDataConfig

    cfg = tyro.cli(NersembleDataConfig)
    cfg.use_landmark = False

    # presumably resolves the dataset class named by cfg._target -- verify
    # against vhap.config.base.import_module
    dataset_cls = import_module(cfg._target)
    dataset = dataset_cls(
        cfg=cfg,
        img_to_tensor=False,
        batchify_all_views=True,
    )

    print(len(dataset))

    first_sample = dataset[0]
    print(first_sample.keys())
    print(first_sample["rgb"].shape)

    # batch_size=None: the DataLoader yields dataset items without collating
    # them into batches.
    loader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for _ in tqdm(loader):
        pass
|
|