|
from typing import Dict
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from pathlib import Path
|
|
import json
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from einops import rearrange, repeat
|
|
from typing import Literal, Tuple, Optional, Any
|
|
import cv2
|
|
import random
|
|
|
|
import json
|
|
import os, sys
|
|
import math
|
|
|
|
from PIL import Image, ImageOps
|
|
from normal_utils import worldNormal2camNormal, plot_grid_images, img2normal, norm_normalize, deg2rad
|
|
|
|
import pdb
|
|
from icecream import ic
|
|
def shift_list(lst, n):
    """Return a copy of *lst* cyclically rotated right by *n* positions.

    *n* may be negative or exceed ``len(lst)``; it is reduced modulo the
    length. An empty list is returned as an (empty) copy — the original
    implementation raised ``ZeroDivisionError`` on ``n % 0``.
    """
    if not lst:
        return list(lst)
    n = n % len(lst)
    # n == 0 degenerates to lst[0:] + lst[:0], i.e. a plain copy.
    return lst[-n:] + lst[:-n]
|
|
|
|
|
|
class ObjaverseDataset(Dataset):
    """Multi-view Objaverse renders dataset.

    Each object directory under ``root_dir`` is expected to contain
    ``camera/NNN.npy`` extrinsics plus ``image/``, ``image_relit/`` and
    ``normal/`` renders (see ``__getitem_norm__``).
    """

    def __init__(self,
                 root_dir: str,                    # root folder with one sub-folder per object
                 azi_interval: float,              # azimuth step in degrees between fixed views
                 random_views: int,                # extra random renders stored after each fixed view
                 predict_relative_views: list,     # view offsets (in fixed-view units) to predict
                 bg_color: Any,                    # background spec; see get_bg_color()
                 object_list: list,                # json file paths listing object ids; None -> listdir
                 prompt_embeds_path: str,          # dir containing normal_embeds.pt / clr_embeds.pt
                 img_wh: Tuple[int, int],          # target (width, height) for all loaded images
                 validation: bool = False,
                 num_validation_samples: int = 64,
                 num_samples: Optional[int] = None,      # optional cap on dataset size
                 invalid_list: Optional[list] = None,    # txt/json file paths of object ids to drop
                 trans_norm_system: bool = True,
                 side_views_rate: float = 0.,      # probability mass for side conditioning views
                 read_normal: bool = True,
                 read_color: bool = False,
                 read_depth: bool = False,
                 mix_color_normal: bool = False,
                 random_view_and_domain: bool = False,
                 load_cache: bool = False,
                 exten: str = '.png',
                 elevation_list: Optional[str] = None,   # optional whitelist file of object ids
                 ) -> None:
        """Index the dataset: read object-id lists, cache fixed-view camera
        poses, filter invalid ids, split train/validation, and load the
        pre-computed prompt embeddings.
        """
        self.root_dir = root_dir
        # Number of evenly spaced azimuth ("fixed") views around the object.
        self.fixed_views = int(360 // azi_interval)
        self.bg_color = bg_color
        self.validation = validation
        self.num_samples = num_samples
        self.trans_norm_system = trans_norm_system

        self.invalid_list = invalid_list
        self.img_wh = img_wh
        self.read_normal = read_normal
        self.read_color = read_color
        self.read_depth = read_depth
        self.mix_color_normal = mix_color_normal
        self.random_view_and_domain = random_view_and_domain
        self.random_views = random_views
        self.load_cache = load_cache
        # Flat view-id layout: each fixed view is followed by its random views.
        self.total_views = int(self.fixed_views * (self.random_views + 1))
        self.predict_relative_views = predict_relative_views
        self.pred_view_nums = len(self.predict_relative_views)
        self.exten = exten
        self.side_views_rate = side_views_rate

        ic(self.total_views)
        ic(self.fixed_views)
        ic(self.predict_relative_views)

        # Collect object ids either from the provided json list files or by
        # scanning the root directory.
        self.objects = []
        if object_list is not None:
            for dataset_list in object_list:
                with open(dataset_list, 'r') as f:
                    objects = json.load(f)
                self.objects.extend(objects)
        else:
            self.objects = os.listdir(self.root_dir)

        # Change-of-basis matrix between OpenCV and OpenGL camera axes.
        self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
        # Cache fixed-view extrinsics once, read from the FIRST object only —
        # assumes camera rigs are identical across objects (TODO confirm).
        self.fix_cam_poses = []
        camera_path = os.path.join(self.root_dir, self.objects[0], 'camera')
        for vid in range(0, self.total_views, self.random_views+1):
            cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item()
            assert cam_info['camera'] == 'ortho', 'Only support predict ortho camera !!!'
            self.fix_cam_poses.append(cam_info['extrinsic'])
        random.shuffle(self.objects)

        # Gather ids to exclude; .txt files hold one id per line, any other
        # extension is parsed as json.
        invalid_objects = []
        if self.invalid_list is not None:
            for invalid_list in self.invalid_list:
                if invalid_list[-4:] == '.txt':
                    with open(invalid_list, 'r') as f:
                        sub_invalid = f.readlines()
                    invalid_objects.extend([o.strip() for o in sub_invalid])
                else:
                    with open(invalid_list) as f:
                        invalid_objects.extend(json.load(f))
        self.invalid_objects = invalid_objects
        ic(len(self.invalid_objects))

        # Optionally restrict to ids listed in elevation_list (one per line).
        if elevation_list:
            with open(elevation_list, 'r') as f:
                ele_list = [o.strip() for o in f.readlines()]
            self.objects = set(ele_list) & set(self.objects)

        # Drop invalid ids; note set ordering is arbitrary here.
        self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
        self.all_objects = list(self.all_objects)

        self.validation = validation
        # The last num_validation_samples objects are held out for validation.
        if not validation:
            self.all_objects = self.all_objects[:-num_validation_samples]
        else:
            self.all_objects = self.all_objects[-num_validation_samples:]

        if num_samples is not None:
            self.all_objects = self.all_objects[:num_samples]
        ic(len(self.all_objects))
        print("loading ", len(self.all_objects), " objects in the dataset")

        # Pre-computed text prompt embeddings shared by every sample.
        self.normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

        # Backup sample returned by __getitem__ when loading an object fails.
        # NOTE(review): __getitem_mix__ is not defined in this file — confirm
        # it exists elsewhere before enabling mix_color_normal.
        if self.mix_color_normal:
            self.backup_data = self.__getitem_mix__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
        else:
            self.backup_data = self.__getitem_norm__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
|
|
|
|
def trans_cv2gl(self, rt):
|
|
r, t = rt[:3, :3], rt[:3, -1]
|
|
r = np.matmul(self.trans_cv2gl_mat, r)
|
|
t = np.matmul(self.trans_cv2gl_mat, t)
|
|
return np.concatenate([r, t[:, None]], axis=-1)
|
|
|
|
def get_bg_color(self):
|
|
if self.bg_color == 'white':
|
|
bg_color = np.array([1., 1., 1.], dtype=np.float32)
|
|
elif self.bg_color == 'black':
|
|
bg_color = np.array([0., 0., 0.], dtype=np.float32)
|
|
elif self.bg_color == 'gray':
|
|
bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
|
|
elif self.bg_color == 'random':
|
|
bg_color = np.random.rand(3)
|
|
elif self.bg_color == 'three_choices':
|
|
white = np.array([1., 1., 1.], dtype=np.float32)
|
|
black = np.array([0., 0., 0.], dtype=np.float32)
|
|
gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
|
|
bg_color = random.choice([white, black, gray])
|
|
elif isinstance(self.bg_color, float):
|
|
bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
|
|
else:
|
|
raise NotImplementedError
|
|
return bg_color
|
|
|
|
|
|
def load_image(self, img_path, bg_color, alpha=None, return_type='np'):
|
|
|
|
|
|
|
|
|
|
rgba = np.array(Image.open(img_path).resize(self.img_wh))
|
|
rgba = rgba.astype(np.float32) / 255.
|
|
|
|
img = rgba[..., :3]
|
|
if alpha is None:
|
|
assert rgba.shape[-1] == 4
|
|
alpha = rgba[..., 3:4]
|
|
assert alpha.sum() > 1e-8, 'w/o foreground'
|
|
img = img[...,:3] * alpha + bg_color * (1 - alpha)
|
|
|
|
if return_type == "np":
|
|
pass
|
|
elif return_type == "pt":
|
|
img = torch.from_numpy(img)
|
|
alpha = torch.from_numpy(alpha)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
return img, alpha
|
|
|
|
def load_depth(self, img_path, bg_color, alpha, input_type='png', return_type='np'):
|
|
|
|
|
|
|
|
|
|
img = np.array(Image.open(img_path).resize(self.img_wh))
|
|
img = img.astype(np.float32) / 65535.
|
|
|
|
img[img > 0.4] = 0
|
|
img = img / 0.4
|
|
|
|
assert img.ndim == 2
|
|
img = np.stack([img]*3, axis=-1)
|
|
|
|
if alpha.shape[-1] != 1:
|
|
alpha = alpha[:, :, None]
|
|
|
|
|
|
|
|
img = img[...,:3] * alpha + bg_color * (1 - alpha)
|
|
|
|
if return_type == "np":
|
|
pass
|
|
elif return_type == "pt":
|
|
img = torch.from_numpy(img)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
return img
|
|
|
|
def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
|
|
normal_np = np.array(Image.open(img_path).resize(self.img_wh))[:, :, :3]
|
|
assert np.var(normal_np) > 1e-8, 'pure normal'
|
|
normal_cv = img2normal(normal_np)
|
|
|
|
normal_relative_cv = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv)
|
|
normal_relative_cv = norm_normalize(normal_relative_cv)
|
|
|
|
|
|
normal_relative_gl = normal_relative_cv
|
|
normal_relative_gl[..., 1:] = -normal_relative_gl[..., 1:]
|
|
|
|
img = (normal_relative_cv*0.5 + 0.5).astype(np.float32)
|
|
|
|
if alpha.shape[-1] != 1:
|
|
alpha = alpha[:, :, None]
|
|
|
|
|
|
img = img[...,:3] * alpha + bg_color * (1 - alpha)
|
|
|
|
if return_type == "np":
|
|
pass
|
|
elif return_type == "pt":
|
|
img = torch.from_numpy(img)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
return img
|
|
|
|
def __len__(self):
|
|
return len(self.all_objects)
|
|
|
|
    def __getitem_norm__(self, index, debug_object=None):
        """Assemble one sample: the conditioning image repeated per predicted
        view as input, and per-view color images plus camera-relative normal
        maps as targets.

        Returns a dict of tensors and prompt embeddings when
        ``self.load_cache`` is False. NOTE(review): when load_cache is True
        this method implicitly returns None — confirm that path is
        intentional/unused.
        """
        bg_color = self.get_bg_color()
        if debug_object is not None:
            object_name = debug_object
        else:
            # Wrap around so any index maps onto a valid object.
            object_name = self.all_objects[index % len(self.all_objects)]

        # Choose the fixed-ring view that conditions the sample. Validation
        # always uses index 12; training draws side views (8, 0) with
        # probability side_views_rate and diagonal views (10, 14) with an
        # additional 2*side_views_rate, otherwise 12.
        # NOTE(review): view-index semantics (front/side) inferred — confirm.
        if self.validation:
            cond_ele0_idx = 12
        else:
            rand = random.random()
            if rand < self.side_views_rate:
                cond_ele0_idx = random.sample([8, 0], 1)[0]
            elif rand < 3 * self.side_views_rate:
                cond_ele0_idx = random.sample([10, 14], 1)[0]
            else:
                cond_ele0_idx = 12
        # Pick the fixed view itself or one of its random variants.
        cond_random_idx = random.sample(range(self.random_views+1), 1)[0]

        # Flat view-id layout: each fixed view is followed by its random views.
        cond_ele0_vid = cond_ele0_idx * (self.random_views + 1)
        cond_vid = cond_ele0_vid + cond_random_idx
        cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx]
        cond_info = np.load(f'{self.root_dir}/{object_name}/camera/{cond_vid:03d}.npy', allow_pickle=True).item()
        cond_type = cond_info['camera']
        focal_len = cond_info['focal']

        # Conditioning elevation, converted to radians.
        cond_eles = np.array([deg2rad(cond_info['elevation'])])

        # The same conditioning image is repeated once per predicted view.
        img_tensors_in = [
            self.load_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
        ] * self.pred_view_nums

        # Target view ids, offset from the conditioning fixed view (modulo the
        # total ring size).
        pred_vids = [(cond_ele0_vid + i * (self.random_views+1)) % self.total_views for i in self.predict_relative_views]

        img_tensors_out = []
        normal_tensors_out = []
        for i, vid in enumerate(pred_vids):
            # Fall back to the relit renders when the plain image is missing
            # or unreadable.
            try:
                img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}", bg_color, return_type='pt')
            except:
                img_tensor, alpha_ = self.load_image(f"{self.root_dir}/{object_name}/image_relit/{vid:03d}{self.exten}", bg_color, return_type='pt')

            img_tensor = img_tensor.permute(2, 0, 1)  # HWC -> CHW
            img_tensors_out.append(img_tensor)

            # Normals are expressed relative to the conditioning camera.
            normal_tensor = self.load_normal(f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}", bg_color, alpha_.numpy(), RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1)
            normal_tensors_out.append(normal_tensor)

        # Stack per-view tensors into (num_views, C, H, W).
        img_tensors_in = torch.stack(img_tensors_in, dim=0).float()
        img_tensors_out = torch.stack(img_tensors_out, dim=0).float()
        normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float()

        elevations_cond = torch.as_tensor(cond_eles).float()
        # Focal embedding: 0 for orthographic conditioning, 24/focal otherwise
        # (24 presumably a reference focal length — TODO confirm).
        if cond_type == 'ortho':
            focal_embed = torch.tensor([0.])
        else:
            focal_embed = torch.tensor([24./focal_len])

        if not self.load_cache:
            return {
                'elevations_cond': elevations_cond,
                'focal_cond': focal_embed,
                'id': object_name,
                'vid': cond_vid,
                'imgs_in': img_tensors_in,
                'imgs_out': img_tensors_out,
                'normals_out': normal_tensors_out,
                'normal_prompt_embeddings': self.normal_prompt_embedding,
                'color_prompt_embeddings': self.color_prompt_embedding
            }
|
|
|
|
|
|
|
|
def __getitem__(self, index):
|
|
try:
|
|
return self.__getitem_norm__(index)
|
|
except:
|
|
print("load error ", self.all_objects[index%len(self.all_objects)] )
|
|
return self.backup_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|