# LAM / vhap / data / video_dataset.py
# yuandong513
# feat: init
# 17cd746
import os
from pathlib import Path
from copy import deepcopy
from typing import Optional
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, default_collate
import json
from vhap.util.log import get_logger
from vhap.config.base import DataConfig
logger = get_logger(__name__)
class VideoDataset(Dataset):
    """Dataset over the frames of one video sequence stored on disk.

    Each item corresponds to one (timestep, camera) pair and carries the RGB
    image, camera intrinsics/extrinsics and, depending on the configuration,
    an alpha map and 2D landmarks. With ``batchify_all_views=True`` items are
    grouped per timestep instead (all cameras collated into one batch).
    """

    def __init__(
        self,
        cfg: "DataConfig",
        img_to_tensor: bool = False,
        batchify_all_views: bool = False,
    ):
        """
        Args:
            cfg: Data configuration. This class reads ``root_folder``,
                ``sequence``, ``division``, ``subset``, ``n_downsample_rgb``,
                ``scale_factor``, ``background_color``, ``use_alpha_map``,
                ``use_landmark`` and ``target_extrinsic_type`` from it.
            img_to_tensor: If True, ``rgb`` and ``alpha_map`` are converted
                with ``torchvision``'s ``to_tensor`` before being returned.
            batchify_all_views: If True, indexing is per timestep and every
                item is the collation of all cameras at that timestep.

        Expected layout of the matched sequence folder for the single-camera
        case (folders and suffixes are defined in ``define_properties``):
            <root_folder>/<sequence>*/
            |---images/  (or images_<n_downsample_rgb>/)
            |    |---<timestep_id>.png
            |
            |---alpha_maps/
            |    |---<timestep_id>.jpg
            |
            |---landmark2d/
                 |---landmarks.npz   (array "face_landmark_2d")
                 |---iris.json       (optional)
                 |---face-alignment.npz
                 |---STAR.npz
        With more than one camera, per-timestep files are named
        ``<camera_id>_<timestep_id>.<suffix>`` and per-sequence files become
        ``<folder>/<camera_id>.<suffix>`` (see ``get_property_path``).

        Raises:
            ValueError: If ``cfg.sequence`` matches no directory or more than
                one directory under ``cfg.root_folder``.
        """
        super().__init__()
        self.cfg = cfg
        self.img_to_tensor = img_to_tensor
        self.batchify_all_views = batchify_all_views

        sequence_paths = self.match_sequences()
        if len(sequence_paths) > 1:
            logger.info(f"Found multiple sequences: {sequence_paths}")
            raise ValueError(
                f"Found multiple sequences by '{cfg.sequence}': \n"
                + "\n\t".join([str(x) for x in sequence_paths])
            )
        elif len(sequence_paths) == 0:
            raise ValueError(f"Cannot find sequence: {cfg.sequence}")
        self.sequence_path = sequence_paths[0]
        logger.info(f"Initializing dataset from {self.sequence_path}")

        self.define_properties()
        self.load_camera_params()

        # timesteps: the ID is the last '_'-separated token of the file stem,
        # de-duplicated (several cameras may share a timestep) and sorted.
        rgb_property = self.properties['rgb']
        self.timestep_ids = sorted({
            f.split('.')[0].split('_')[-1]
            for f in os.listdir(self.sequence_path / rgb_property['folder'])
            if f.endswith(rgb_property['suffix'])
        })
        self.timestep_indices = list(range(len(self.timestep_ids)))

        self.filter_division(cfg.division)
        self.filter_subset(cfg.subset)

        logger.info(f"number of timesteps: {self.num_timesteps}, number of cameras: {self.num_cameras}")

        # collect: timestep-major order, i.e. all cameras of timestep 0 first.
        # getitem_by_timestep relies on this ordering.
        self.items = []
        for fi, timestep_index in enumerate(self.timestep_indices):
            for ci, camera_id in enumerate(self.camera_ids):
                self.items.append(
                    {
                        "timestep_index": fi,  # new index after filtering
                        "timestep_index_original": timestep_index,  # original index
                        "timestep_id": self.timestep_ids[timestep_index],
                        "camera_index": ci,
                        "camera_id": camera_id,
                    }
                )

    def match_sequences(self):
        """Return all directories in ``cfg.root_folder`` whose name starts with ``cfg.sequence``."""
        logger.info(f"Looking for sequence '{self.cfg.sequence}' at {self.cfg.root_folder}")
        return list(filter(lambda x: x.is_dir(), self.cfg.root_folder.glob(f"{self.cfg.sequence}*")))

    def define_properties(self):
        """Define where each data property lives inside the sequence folder.

        Every entry has a ``folder`` (relative to the sequence path), a
        ``suffix`` (file extension without the dot) and ``per_timestep``
        (True: one file per timestep, False: one file per camera/sequence).
        """
        self.properties = {
            "rgb": {
                "folder": f"images_{self.cfg.n_downsample_rgb}"
                if self.cfg.n_downsample_rgb
                else "images",
                "per_timestep": True,
                "suffix": "png",
            },
            "alpha_map": {
                "folder": "alpha_maps",
                "per_timestep": True,
                "suffix": "jpg",
            },
            "landmark2d/face-alignment": {
                "folder": "landmark2d/face-alignment",
                "per_timestep": False,
                "suffix": "npz",
            },
            "landmark2d/STAR": {
                "folder": "landmark2d/STAR",
                "per_timestep": False,
                "suffix": "npz",
            },
            "landmark2d/lms": {
                "folder": "landmark2d/landmarks",
                "per_timestep": False,
                "suffix": "npz",
            },
        }

    @staticmethod
    def get_number_after_prefix(string, prefix):
        """Parse the run of digits that directly follows ``prefix`` in ``string``.

        Only the first occurrence of ``prefix`` is considered.

        Returns:
            The parsed integer, or None if ``prefix`` does not occur.

        Raises:
            AssertionError: If ``prefix`` occurs but is not immediately
                followed by a digit.
        """
        prefix_begin = string.find(prefix)
        if prefix_begin == -1:
            return None

        number_begin = prefix_begin + len(prefix)
        assert number_begin < len(string), f"No number found behind prefix '{prefix}'"
        assert string[number_begin].isdigit(), f"No number found behind prefix '{prefix}'"

        # Advance to the first non-digit character (or the end of the string).
        number_end = number_begin
        while number_end < len(string) and string[number_end].isdigit():
            number_end += 1
        return int(string[number_begin:number_end])

    @staticmethod
    def _take_evenly(seq, n):
        """Return at most ``n`` elements of ``seq`` picked with a constant stride.

        The stride is ``len(seq) // n``. An empty selection is returned when
        ``n`` is not positive or ``seq`` is empty, which avoids a division by
        zero in the stride computation.
        """
        n = min(n, len(seq))
        if n <= 0:
            return seq[:0]
        return seq[::len(seq) // n][:n]

    def filter_division(self, division):
        """Hook for restricting the data to a division; a no-op in this class."""
        pass

    def filter_subset(self, subset):
        """Restrict ``timestep_indices`` and ``camera_ids`` according to ``subset``.

        ``subset`` is a string of ``<key><number>`` tokens (or None for no
        filtering). Timestep keys, first match wins:
            ti<N>        keep only timestep index N
            ti<N>tj<M>   keep timestep indices N..M (inclusive)
            tn<N>        keep N timesteps picked with a constant stride
            ts<N>        keep every N-th timestep
        Camera keys, first match wins:
            ci<N>        keep only camera index N
            cn<N>        keep N cameras picked with a constant stride
            cs<N>        keep every N-th camera
        """
        if subset is None:
            return

        parse = self.get_number_after_prefix

        if 'ti' in subset:
            ti = parse(subset, 'ti')
            if 'tj' in subset:
                tj = parse(subset, 'tj')
                self.timestep_indices = self.timestep_indices[ti:tj + 1]
            else:
                self.timestep_indices = self.timestep_indices[ti:ti + 1]
        elif 'tn' in subset:
            tn = parse(subset, 'tn')
            self.timestep_indices = self._take_evenly(self.timestep_indices, tn)
        elif 'ts' in subset:
            ts = parse(subset, 'ts')
            self.timestep_indices = self.timestep_indices[::ts]

        if 'ci' in subset:
            ci = parse(subset, 'ci')
            self.camera_ids = self.camera_ids[ci:ci + 1]
        elif 'cn' in subset:
            cn = parse(subset, 'cn')
            self.camera_ids = self._take_evenly(self.camera_ids, cn)
        elif 'cs' in subset:
            cs = parse(subset, 'cs')
            self.camera_ids = self.camera_ids[::cs]

    def load_camera_params(self):
        """Create placeholder parameters for a single camera with ID ``'0'``.

        Sets ``self.camera_ids`` and ``self.camera_params`` (camera ID ->
        ``{"intrinsic": (3, 3), "extrinsic": (3, 4)}``) and returns the latter.
        The extrinsic is world-to-camera or camera-to-world depending on
        ``cfg.target_extrinsic_type``.

        Raises:
            NotImplementedError: If ``cfg.target_extrinsic_type`` is neither
                ``"w2c"`` nor ``"c2w"``.
        """
        self.camera_ids = ['0']

        # Guessed focal length, height, width. Should be optimized or replaced by real values
        # NOTE(review): w and h are used directly as the principal point,
        # which presumably means they are half the image size — confirm.
        f, h, w = 512, 512, 512
        K = torch.Tensor([
            [f, 0, w],
            [0, f, h],
            [0, 0, 1]
        ])

        orientation = torch.eye(3)[None, ...]  # (1, 3, 3)
        location = torch.Tensor([0, 0, 1])[None, ..., None]  # (1, 3, 1)

        c2w = torch.cat([orientation, location], dim=-1)  # camera-to-world transformation

        if self.cfg.target_extrinsic_type == "w2c":
            R = orientation.transpose(-1, -2)
            T = orientation.transpose(-1, -2) @ -location
            w2c = torch.cat([R, T], dim=-1)  # world-to-camera transformation
            extrinsic = w2c
        elif self.cfg.target_extrinsic_type == "c2w":
            extrinsic = c2w
        else:
            raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")

        self.camera_params = {}
        for i, camera_id in enumerate(self.camera_ids):
            self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]}

        return self.camera_params

    def __len__(self):
        """Number of timesteps if views are batched, else number of (timestep, camera) items."""
        if self.batchify_all_views:
            return self.num_timesteps
        else:
            return len(self.items)

    def __getitem__(self, i):
        """Return the item for timestep ``i`` (batched views) or for flat item index ``i``."""
        if self.batchify_all_views:
            return self.getitem_by_timestep(i)
        else:
            return self.getitem_single_image(i)

    def getitem_single_image(self, i):
        """Load everything for flat item index ``i`` and apply the transforms.

        The returned dict contains the entries of ``self.items[i]`` plus
        ``rgb``, ``intrinsic``, ``extrinsic``, ``scale_factor`` and,
        depending on the configuration, ``alpha_map`` and ``lmk2d``.
        """
        item = deepcopy(self.items[i])

        rgb_path = self.get_property_path("rgb", i)
        item["rgb"] = np.array(Image.open(rgb_path))[:, :, :3]  # drop alpha channel if present

        # Clone so that in-place scaling of the intrinsics does not touch the shared tensors.
        camera_param = self.camera_params[item["camera_id"]]
        item["intrinsic"] = camera_param["intrinsic"].clone()
        item["extrinsic"] = camera_param["extrinsic"].clone()

        if self.cfg.use_alpha_map or self.cfg.background_color is not None:
            alpha_path = self.get_property_path("alpha_map", i)
            item["alpha_map"] = np.array(Image.open(alpha_path))

        if self.cfg.use_landmark:
            # NOTE(review): this is the post-filter timestep index, not
            # "timestep_index_original" — confirm the landmark files were
            # produced with the same subset.
            timestep_index = self.items[i]["timestep_index"]

            landmark_path = self.get_property_path("landmark2d/lms", i)
            # NpzFile keeps the archive open; close it as soon as the array is read.
            with np.load(landmark_path) as landmark_npz:
                lmk2d = landmark_npz["face_landmark_2d"][timestep_index]  # (num_points, 3)

            # Third column acts as a per-point flag: zeroed for the whole
            # frame if any coordinate is -1, otherwise set to one.
            if (lmk2d[:, :2] == -1).sum() > 0:
                lmk2d[:, 2:] = 0.0
            else:
                lmk2d[:, 2:] = 1.0

            lms_eyes_path = os.path.join(os.path.dirname(landmark_path), 'iris.json')
            if os.path.exists(lms_eyes_path):
                with open(lms_eyes_path, 'r') as f:
                    lms_eye = json.load(f)
                # Entries are taken in file order; each holds two 2D points.
                # NOTE(review): the division by 1024 presumably normalizes
                # pixel coordinates of 1024-px images — confirm.
                lms_eye = np.array(list(lms_eye.values())[timestep_index]).reshape((2, 2)) / 1024.
                # Append a flag column of ones and swap the order of the two points.
                lms_eye = np.concatenate([lms_eye, np.ones((2, 1))], axis=1)[(1, 0), :]
                lmk2d = np.concatenate([lmk2d, lms_eye], 0)
            else:
                lmk2d = np.concatenate([lmk2d], 0)  # plain copy
            item["lmk2d"] = lmk2d

        item = self.apply_transforms(item)
        return item

    def getitem_by_timestep(self, timestep_index):
        """Collate the items of all cameras at ``timestep_index`` into one batch.

        Relies on ``self.items`` being ordered timestep-major. Adds
        ``num_cameras`` to the collated dict.
        """
        begin = timestep_index * self.num_cameras
        indices = range(begin, begin + self.num_cameras)
        item = default_collate([self.getitem_single_image(i) for i in indices])

        item["num_cameras"] = self.num_cameras
        return item

    def apply_transforms(self, item):
        """Apply scaling, background compositing and tensor conversion, in that order."""
        item = self.apply_scale_factor(item)
        item = self.apply_background_color(item)
        item = self.apply_to_tensor(item)
        return item

    def apply_to_tensor(self, item):
        """Convert ``rgb`` and ``alpha_map`` with ``to_tensor`` if ``img_to_tensor`` is set."""
        if self.img_to_tensor:
            if "rgb" in item:
                item["rgb"] = F.to_tensor(item["rgb"])

            if "alpha_map" in item:
                item["alpha_map"] = F.to_tensor(item["alpha_map"])
        return item

    def apply_scale_factor(self, item):
        """Resize the image by ``cfg.scale_factor`` and rescale dependent properties.

        Landmarks and bounding boxes are multiplied by the resized image
        width/height. Intrinsics are scaled by
        ``cfg.scale_factor / n_downsample_rgb`` (stored in the item as
        ``scale_factor``), and the alpha map is resized to match ``rgb`` when
        that combined factor is below 1.
        """
        assert self.cfg.scale_factor <= 1.0

        if "rgb" in item:
            H, W, _ = item["rgb"].shape
            h, w = int(H * self.cfg.scale_factor), int(W * self.cfg.scale_factor)
            rgb = Image.fromarray(item["rgb"]).resize(
                (w, h), resample=Image.Resampling.BILINEAR
            )
            item["rgb"] = np.array(rgb)

        # properties that are defined based on image size
        # NOTE: h and w are only defined when "rgb" is present in the item.
        if "lmk2d" in item:
            item["lmk2d"][..., 0] *= w
            item["lmk2d"][..., 1] *= h

        if "lmk2d_iris" in item:
            item["lmk2d_iris"][..., 0] *= w
            item["lmk2d_iris"][..., 1] *= h

        if "bbox_2d" in item:
            item["bbox_2d"][[0, 2]] *= w
            item["bbox_2d"][[1, 3]] *= h

        # properties need to be scaled down when rgb is downsampled
        n_downsample_rgb = self.cfg.n_downsample_rgb if self.cfg.n_downsample_rgb else 1
        scale_factor = self.cfg.scale_factor / n_downsample_rgb
        item["scale_factor"] = scale_factor  # NOTE: not self.cfg.scale_factor
        if scale_factor < 1.0:
            if "intrinsic" in item:
                item["intrinsic"][:2] *= scale_factor
            if "alpha_map" in item:
                h, w = item["rgb"].shape[:2]
                alpha_map = Image.fromarray(item["alpha_map"]).resize(
                    (w, h), Image.Resampling.BILINEAR
                )
                item["alpha_map"] = np.array(alpha_map)
        return item

    def apply_background_color(self, item):
        """Composite ``rgb`` over a solid background using ``alpha_map``.

        Does nothing when ``cfg.background_color`` is None. The alpha map is
        interpreted as 0..255 and may be single-channel ``(H, W)`` or have a
        channel axis that broadcasts against ``rgb``.

        Raises:
            AssertionError: If a background color is set but the item has no
                ``alpha_map``.
            NotImplementedError: If the color is neither ``"white"`` nor
                ``"black"``.
        """
        if self.cfg.background_color is not None:
            assert (
                "alpha_map" in item
            ), "'alpha_map' is required to apply background color."
            fg = item["rgb"]
            if self.cfg.background_color == "white":
                bg = np.ones_like(fg) * 255
            elif self.cfg.background_color == "black":
                bg = np.zeros_like(fg)
            else:
                raise NotImplementedError(
                    f"Unknown background color: {self.cfg.background_color}."
                )

            w = item["alpha_map"] / 255
            if w.ndim == 2:
                # Single-channel alpha: add a channel axis so it broadcasts over RGB.
                w = w[..., None]
            img = (w * fg + (1 - w) * bg).astype(np.uint8)
            item["rgb"] = img
        return item

    def get_property_path(
        self,
        name,
        index: Optional[int] = None,
        timestep_id: Optional[str] = None,
        camera_id: Optional[str] = None,
    ):
        """Build the file path of property ``name`` for one item.

        Args:
            name: Key into ``self.properties``.
            index: Flat item index, used to look up the timestep/camera ID
                when they are not given explicitly.
            timestep_id: Timestep ID; only used for per-timestep properties.
            camera_id: Camera ID; ignored when the dataset has a single
                camera, in which case file names carry no camera part.

        Returns:
            ``<folder>/[<camera_id>_]<timestep_id>.<suffix>`` for per-timestep
            properties, otherwise ``<folder>/<camera_id>.<suffix>`` with
            several cameras or ``<folder>.<suffix>`` with a single one.
        """
        p = self.properties[name]
        folder = p["folder"] if "folder" in p else None
        per_timestep = p["per_timestep"]
        suffix = p["suffix"]

        path = self.sequence_path
        if folder is not None:
            path = path / folder

        if self.num_cameras > 1:
            if camera_id is None:
                assert (
                    index is not None), "index is required when camera_id is not provided."
                camera_id = self.items[index]["camera_id"]
            if "cam_id_prefix" in p:
                camera_id = p["cam_id_prefix"] + camera_id
        else:
            camera_id = ""

        if per_timestep:
            if timestep_id is None:
                assert index is not None, "index is required when timestep_id is not provided."
                timestep_id = self.items[index]["timestep_id"]
            if len(camera_id) > 0:
                path /= f"{camera_id}_{timestep_id}.{suffix}"
            else:
                path /= f"{timestep_id}.{suffix}"
        else:
            if len(camera_id) > 0:
                path /= f"{camera_id}.{suffix}"
            else:
                # Single camera: the "folder" itself names the file.
                path = Path(str(path) + f".{suffix}")

        return path

    def get_property_path_list(self, name):
        """Return the path of property ``name`` for every item, in item order."""
        return [self.get_property_path(name, i) for i in range(len(self.items))]

    @property
    def num_timesteps(self):
        """Number of timesteps remaining after filtering."""
        return len(self.timestep_indices)

    @property
    def num_cameras(self):
        """Number of cameras remaining after filtering."""
        return len(self.camera_ids)
if __name__ == "__main__":
    # Smoke test: build the dataset class named by the CLI-provided config,
    # print one sample's keys and image shape, then iterate the whole
    # dataset once through a DataLoader.
    import tyro
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from vhap.config.base import DataConfig, import_module

    cfg = tyro.cli(DataConfig)
    cfg.use_landmark = False  # landmark files are not needed for this check
    dataset = import_module(cfg._target)(
        cfg=cfg,
        img_to_tensor=False,
        batchify_all_views=True,
    )

    print(len(dataset))

    sample = dataset[0]
    print(sample.keys())
    print(sample["rgb"].shape)

    # batch_size=None disables automatic batching; each dataset item is
    # already a per-timestep batch over all cameras.
    dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
    for item in tqdm(dataloader):
        pass