import json
import os
import random
from dataclasses import dataclass, field

# OpenCV may only honor this flag if it is set before cv2 is imported,
# so export it ahead of the cv2 import below.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from ..utils.config import parse_structured
from ..utils.geometry import (
    get_plucker_embeds_from_cameras,
    get_plucker_embeds_from_cameras_ortho,
    get_position_map_from_depth,
    get_position_map_from_depth_ortho,
)
from ..utils.typing import *


def _parse_scene_list_single(scene_list_path: str, root_data_dir: str):
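    """Read a scene list from a .json or .txt file and return absolute scene paths.

    Entries containing a "/" are joined to root_data_dir directly; bare scene ids
    are additionally sharded into a sub-directory named by their first two characters.
    """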
    all_scenes = []
    if scene_list_path.endswith(".json"):
        with open(scene_list_path) as f:
            for p in json.loads(f.read()):
                if "/" in p:
                    all_scenes.append(os.path.join(root_data_dir, p))
                else:
                    all_scenes.append(os.path.join(root_data_dir, p[:2], p))
    elif scene_list_path.endswith(".txt"):
        with open(scene_list_path) as f:
            for p in f.readlines():
                p = p.strip()
                if "/" in p:
                    all_scenes.append(os.path.join(root_data_dir, p))
                else:
                    all_scenes.append(os.path.join(root_data_dir, p[:2], p))
    else:
        raise NotImplementedError

    return all_scenes


def _parse_scene_list(
    scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]]
):
    all_scenes = []
    if isinstance(scene_list_path, str):
        scene_list_path = [scene_list_path]
    if isinstance(root_data_dir, str):
        root_data_dir = [root_data_dir]
    for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir):
        all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_)
    return all_scenes


def _parse_reference_scene_list(reference_scenes: List[str], all_scenes: List[str]):
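    """Align reference scenes with target scenes by scene id.

    Keeps only scenes present in both lists and returns the filtered target scenes
    together with a mapping from each target scene to its reference scene.
    """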
    all_ids = set(scene.split("/")[-1] for scene in all_scenes)
    ref_ids = set(scene.split("/")[-1] for scene in reference_scenes)
    common_ids = ref_ids.intersection(all_ids)
    all_scenes = [scene for scene in all_scenes if scene.split("/")[-1] in common_ids]
    all_ids = {scene.split("/")[-1]: idx for idx, scene in enumerate(all_scenes)}

    ref_scenes = [
        scene for scene in reference_scenes if scene.split("/")[-1] in all_ids
    ]
    sorted_ref_scenes = sorted(ref_scenes, key=lambda x: all_ids[x.split("/")[-1]])
    scene2ref = {
        scene: ref_scene for scene, ref_scene in zip(all_scenes, sorted_ref_scenes)
    }

    return all_scenes, scene2ref


@dataclass
class MultiviewDataModuleConfig:
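    """Configuration for MultiviewDataset / MultiviewDataModule."""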
    root_dir: Any = ""
    scene_list: Any = ""
    image_suffix: str = "webp"
    background_color: Union[str, float] = "gray"
    image_names: List[str] = field(default_factory=lambda: [])
    image_modality: str = "render"
    num_views: int = 1
    random_view_list: Optional[List[List[int]]] = None

    # prompt settings
    prompt_db_path: Optional[str] = None
    return_prompt: bool = False
    use_empty_prompt: bool = False
    prompt_prefix: Optional[Any] = None
    return_one_prompt: bool = True

    projection_type: str = "ORTHO"

    # source image modality settings
    source_image_modality: Any = "position"
    use_camera_space_normal: bool = False
    position_offset: float = 0.5
    position_scale: float = 1.0
    plucker_offset: float = 1.0
    plucker_scale: float = 2.0

    # reference image settings
    reference_root_dir: Optional[Any] = None
    reference_scene_list: Optional[Any] = None
    reference_image_modality: str = "render"
    reference_image_names: List[str] = field(default_factory=lambda: [])
    reference_augment_resolutions: Optional[List[int]] = None
    reference_mask_aug: bool = False

    repeat: int = 1

    train_indices: Optional[Tuple[Any, Any]] = None
    val_indices: Optional[Tuple[Any, Any]] = None
    test_indices: Optional[Tuple[Any, Any]] = None

    height: int = 768
    width: int = 768

    batch_size: int = 1
    eval_batch_size: int = 1

    num_workers: int = 16


class MultiviewDataset(Dataset):
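    """Multi-view dataset returning target images, cameras, source-modality maps
    (position / normal / plucker), and optionally prompts and a reference image.
    """
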
    def __init__(self, cfg: Any, split: str = "train") -> None:
        super().__init__()
        assert split in ["train", "val", "test"]
        self.cfg: MultiviewDataModuleConfig = cfg
        self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.root_dir)

        if (
            self.cfg.reference_root_dir is not None
            and self.cfg.reference_scene_list is not None
        ):
            reference_scenes = _parse_scene_list(
                self.cfg.reference_scene_list, self.cfg.reference_root_dir
            )
            self.all_scenes, self.reference_scenes = _parse_reference_scene_list(
                reference_scenes, self.all_scenes
            )
        else:
            self.reference_scenes = None

        self.split = split
        if self.split == "train" and self.cfg.train_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.train_indices[0] : self.cfg.train_indices[1]
            ]
            self.all_scenes = self.all_scenes * self.cfg.repeat
        elif self.split == "val" and self.cfg.val_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.val_indices[0] : self.cfg.val_indices[1]
            ]
        elif self.split == "test" and self.cfg.test_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.test_indices[0] : self.cfg.test_indices[1]
            ]

        if self.cfg.prompt_db_path is not None:
            with open(self.cfg.prompt_db_path) as f:
                self.prompt_db = json.load(f)
        else:
            self.prompt_db = None

    def __len__(self):
        return len(self.all_scenes)

    def get_bg_color(self, bg_color):
        if bg_color == "white":
            bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32)
        elif bg_color == "black":
            bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        elif bg_color == "gray":
            bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        elif bg_color == "random":
            bg_color = np.random.rand(3).astype(np.float32)
        elif bg_color == "random_gray":
            bg_color = random.uniform(0.3, 0.7)
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, float):
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, list) or isinstance(bg_color, tuple):
            bg_color = np.array(bg_color, dtype=np.float32)
        else:
            raise NotImplementedError
        return bg_color

    def load_image(
        self,
        image: Union[str, Image.Image],
        height: int,
        width: int,
        background_color: torch.Tensor,
        rescale: bool = False,
        mask_aug: bool = False,
    ):
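        """Load an RGBA image (path or PIL.Image), composite it over background_color,
        and return an (H, W, 3) float tensor in [0, 1] (or [-1, 1] if rescale is True).

        If mask_aug is True, a random rectangle around a foreground pixel is zeroed
        out in the alpha channel before compositing, acting as a simple occlusion
        augmentation.
        """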
        if isinstance(image, str):
            image = Image.open(image)

        image = image.resize((width, height))
        image = torch.from_numpy(np.array(image)).float() / 255.0

        if mask_aug:
            alpha = image[:, :, 3]
            h, w = alpha.shape
            y_indices, x_indices = torch.where(alpha > 0.5)
            if len(y_indices) > 0 and len(x_indices) > 0:
                idx = torch.randint(len(y_indices), (1,)).item()
                y_center = y_indices[idx].item()
                x_center = x_indices[idx].item()
                mask_h = random.randint(h // 8, h // 4)
                mask_w = random.randint(w // 8, w // 4)

                y1 = max(0, y_center - mask_h // 2)
                y2 = min(h, y_center + mask_h // 2)
                x1 = max(0, x_center - mask_w // 2)
                x2 = min(w, x_center + mask_w // 2)

                alpha[y1:y2, x1:x2] = 0.0
                image[:, :, 3] = alpha

        image = image[:, :, :3] * image[:, :, 3:4] + background_color * (
            1 - image[:, :, 3:4]
        )
        if rescale:
            image = image * 2.0 - 1.0
        return image

    def load_normal_image(
        self,
        path,
        height,
        width,
        background_color,
        camera_space: bool = False,
        c2w: Optional[torch.FloatTensor] = None,
    ):
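        """Load a normal map, optionally rotate world-space normals into camera space
        using the inverse of c2w, and composite over background_color.
        """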
        image = Image.open(path).resize((width, height), resample=Image.NEAREST)
        image = torch.from_numpy(np.array(image)).float() / 255.0
        alpha = image[:, :, 3:4]
        image = image[:, :, :3]
        if camera_space:
            w2c = torch.linalg.inv(c2w)[:3, :3]
            image = (
                F.normalize(((image * 2 - 1)[:, :, None, :] * w2c).sum(-1), dim=-1)
                * 0.5
                + 0.5
            )
        image = image * alpha + background_color * (1 - alpha)
        return image

    def load_depth(self, path, height, width):
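        """Load an EXR depth map and return (depth, mask); pixels with depth > 1000
        are treated as background and masked out.
        """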
        depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST)
        if depth.ndim == 2:
            # single-channel EXRs load as (H, W); add a channel axis for consistency
            depth = depth[..., None]
        depth = torch.from_numpy(depth[..., 0:1]).float()
        mask = torch.ones_like(depth)
        mask[depth > 1000.0] = 0.0
        depth[~(mask > 0.5)] = 0.0
        return depth, mask

    def retrieve_prompt(self, scene_dir):
        assert self.prompt_db is not None
        source_id = os.path.basename(scene_dir)
        return self.prompt_db.get(source_id, "")

    def __getitem__(self, index):
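        """Assemble one scene: returns "rgb", "c2w", and "source_rgb", plus "prompts"
        and "reference_rgb" when the corresponding options are enabled.
        """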
        background_color = torch.as_tensor(
            self.get_bg_color(self.cfg.background_color)
        )
        scene_dir = self.all_scenes[index]

        with open(os.path.join(scene_dir, "meta.json")) as f:
            meta = json.load(f)
        name2loc = {loc["index"]: loc for loc in meta["locations"]}

        image_paths = [
            os.path.join(
                scene_dir, f"{self.cfg.image_modality}_{f}.{self.cfg.image_suffix}"
            )
            for f in self.cfg.image_names
        ]
        images = [
            self.load_image(
                p,
                height=self.cfg.height,
                width=self.cfg.width,
                background_color=background_color,
            )
            for p in image_paths
        ]
        images = torch.stack(images, dim=0).permute(0, 3, 1, 2)

        c2w = [
            torch.as_tensor(name2loc[name]["transform_matrix"])
            for name in self.cfg.image_names
        ]
        c2w = torch.stack(c2w, dim=0)

        if self.cfg.projection_type == "PERSP":
            camera_angle_x = (
                meta.get("camera_angle_x", None)
                or meta["locations"][0]["camera_angle_x"]
            )
            focal_length = 0.5 * self.cfg.width / np.tan(0.5 * camera_angle_x)
            intrinsics = (
                torch.as_tensor(
                    [
                        [focal_length, 0.0, 0.5 * self.cfg.width],
                        [0.0, focal_length, 0.5 * self.cfg.height],
                        [0.0, 0.0, 1.0],
                    ]
                )
                .unsqueeze(0)
                .float()
                .repeat(len(self.cfg.image_names), 1, 1)
            )
        elif self.cfg.projection_type == "ORTHO":
            ortho_scale = (
                meta.get("ortho_scale", None) or meta["locations"][0]["ortho_scale"]
            )
        else:
            raise NotImplementedError

        source_image_modality = self.cfg.source_image_modality
        if isinstance(source_image_modality, str):
            source_image_modality = [source_image_modality]
        source_images = []
        for modality in source_image_modality:
            if modality == "position":
                depth_masks = [
                    self.load_depth(
                        os.path.join(scene_dir, f"depth_{f}.exr"),
                        self.cfg.height,
                        self.cfg.width,
                    )
                    for f in self.cfg.image_names
                ]
                depths = torch.stack([d for d, _ in depth_masks])
                masks = torch.stack([m for _, m in depth_masks])
                c2w_ = c2w.clone()
                c2w_[:, :, 1:3] *= -1

                if self.cfg.projection_type == "PERSP":
                    position_maps = get_position_map_from_depth(
                        depths,
                        masks,
                        intrinsics,
                        c2w_,
                        image_wh=(self.cfg.width, self.cfg.height),
                    )
                elif self.cfg.projection_type == "ORTHO":
                    position_maps = get_position_map_from_depth_ortho(
                        depths,
                        masks,
                        c2w_,
                        ortho_scale,
                        image_wh=(self.cfg.width, self.cfg.height),
                    )
                position_maps = (
                    (position_maps + self.cfg.position_offset) / self.cfg.position_scale
                ).clamp(0.0, 1.0)
                source_images.append(position_maps)
            elif modality == "normal":
                normal_maps = [
                    self.load_normal_image(
                        os.path.join(
                            scene_dir, f"{modality}_{f}.{self.cfg.image_suffix}"
                        ),
                        height=self.cfg.height,
                        width=self.cfg.width,
                        background_color=background_color,
                        camera_space=self.cfg.use_camera_space_normal,
                        c2w=c,
                    )
                    for c, f in zip(c2w, self.cfg.image_names)
                ]
                source_images.append(torch.stack(normal_maps, dim=0))
            elif modality == "plucker":
                if self.cfg.projection_type == "ORTHO":
                    plucker_embed = get_plucker_embeds_from_cameras_ortho(
                        c2w, [ortho_scale] * len(c2w), self.cfg.width
                    )
                elif self.cfg.projection_type == "PERSP":
                    plucker_embed = get_plucker_embeds_from_cameras(
                        c2w, [camera_angle_x] * len(c2w), self.cfg.width
                    )
                else:
                    raise NotImplementedError
                plucker_embed = plucker_embed.permute(0, 2, 3, 1)
                plucker_embed = (
                    (plucker_embed + self.cfg.plucker_offset) / self.cfg.plucker_scale
                ).clamp(0.0, 1.0)
                source_images.append(plucker_embed)
            else:
                raise NotImplementedError
        source_images = torch.cat(source_images, dim=-1).permute(0, 3, 1, 2)
        rv = {"rgb": images, "c2w": c2w, "source_rgb": source_images}

        num_images = len(self.cfg.image_names)

        if self.cfg.return_prompt:
            if self.cfg.use_empty_prompt:
                prompt = ""
            else:
                prompt = self.retrieve_prompt(scene_dir)
            prompts = [prompt] * num_images

            if self.cfg.prompt_prefix is not None:
                prompt_prefix = self.cfg.prompt_prefix
                if isinstance(prompt_prefix, str):
                    prompt_prefix = [prompt_prefix] * num_images

                for i, prompt in enumerate(prompts):
                    prompts[i] = f"{prompt_prefix[i]} {prompt}"

            if self.cfg.return_one_prompt:
                rv.update({"prompts": prompts[0]})
            else:
                rv.update({"prompts": prompts})

        if self.reference_scenes is not None:
            reference_scene_dir = self.reference_scenes[scene_dir]
            reference_image_paths = [
                os.path.join(
                    reference_scene_dir,
                    f"{self.cfg.reference_image_modality}_{f}.{self.cfg.image_suffix}",
                )
                for f in self.cfg.reference_image_names
            ]
            reference_image_path = random.choice(reference_image_paths)

            if self.cfg.reference_augment_resolutions is None:
                reference_image = self.load_image(
                    reference_image_path,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                    mask_aug=self.cfg.reference_mask_aug,
                ).permute(2, 0, 1)
                rv.update({"reference_rgb": reference_image})
            else:
                random_resolution = random.choice(
                    self.cfg.reference_augment_resolutions
                )
                reference_image_ = Image.open(reference_image_path).resize(
                    (random_resolution, random_resolution)
                )
                reference_image = self.load_image(
                    reference_image_,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                    mask_aug=self.cfg.reference_mask_aug,
                ).permute(2, 0, 1)
                rv.update({"reference_rgb": reference_image})

        return rv

    def collate(self, batch):
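        """Collate scenes into a batch: select a view subset (from random_view_list
        or the first num_views views), flatten the batch and view dimensions, and
        attach image-size conditioning metadata.
        """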
        batch = torch.utils.data.default_collate(batch)
        pack = lambda t: t.view(-1, *t.shape[2:])

        if self.cfg.random_view_list is not None:
            indices = random.choice(self.cfg.random_view_list)
        else:
            indices = list(range(self.cfg.num_views))
        num_views = len(indices)

        for k in batch.keys():
            if k in ["rgb", "source_rgb", "c2w"]:
                batch[k] = batch[k][:, indices]
                batch[k] = pack(batch[k])
        for k in ["prompts"]:
            if k in batch and not self.cfg.return_one_prompt:
                batch[k] = [item for pair in zip(*batch[k]) for item in pair]

        batch.update(
            {
                "num_views": num_views,
                "original_size": (self.cfg.height, self.cfg.width),
                "target_size": (self.cfg.height, self.cfg.width),
                "crops_coords_top_left": (0, 0),
            }
        )
        return batch


class MultiviewDataModule(pl.LightningDataModule):
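    """Lightning data module wrapping MultiviewDataset for train/val/test/predict."""
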
    cfg: MultiviewDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(MultiviewDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = MultiviewDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = MultiviewDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = MultiviewDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=True,
            collate_fn=self.train_dataset.collate,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.cfg.eval_batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=False,
            collate_fn=self.val_dataset.collate,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.cfg.eval_batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=False,
            collate_fn=self.test_dataset.collate,
        )

    def predict_dataloader(self) -> DataLoader:
        return self.test_dataloader()