import json
import logging
import os
import random
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)

__all__ = [
    "Asset3dGenDataset",
]

class Asset3dGenDataset(Dataset):
    """Multi-view dataset pairing color images with normal/position/mask
    controls and precomputed text embeddings, indexed by a JSON meta file."""

    def __init__(
        self,
        index_file: str,
        target_hw: Tuple[int, int],
        transform: Optional[Callable] = None,
        control_transform: Optional[Callable] = None,
        max_train_samples: Optional[int] = None,
        sub_idxs: Optional[List[List[int]]] = None,
        seed: int = 79,
    ) -> None:
        if not os.path.exists(index_file):
            raise FileNotFoundError(f"index_file not found: {index_file}")
        self.index_file = index_file
        self.target_hw = target_hw
        self.transform = transform
        self.control_transform = control_transform
        self.max_train_samples = max_train_samples
        self.meta_info = self.prepare_data_index(index_file)
        self.data_list = sorted(self.meta_info.keys())
        # Grid layout of view indices, e.g. [[0, 1, 2], [3, 4, 5]]; if None,
        # a single random view is drawn per sample in __getitem__.
        self.sub_idxs = sub_idxs
        self.image_num = 6  # Number of rendered views per asset (hardcoded for now).
        random.seed(seed)
        logger.info(f"Trainset: {len(self)} asset3d instances.")

    def __len__(self) -> int:
        return len(self.meta_info)

    def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
        with open(index_file, "r") as fin:
            meta_info = json.load(fin)

        # Keep only successfully processed assets, capping the number of
        # kept entries at max_train_samples when it is set.
        meta_info_filtered = dict()
        for uid in meta_info:
            if meta_info[uid].get("status") != "success":
                continue
            if (
                self.max_train_samples
                and len(meta_info_filtered) >= self.max_train_samples
            ):
                break
            meta_info_filtered[uid] = meta_info[uid]

        logger.info(
            f"Loaded {len(meta_info)} assets, kept {len(meta_info_filtered)} valid."  # noqa
        )
        return meta_info_filtered

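    # Illustrative index-file entry (a sketch inferred from the fields read
    # elsewhere in this class; uid, paths, and view count are hypothetical):
    # {
    #   "<uid>": {
    #     "status": "success",
    #     "image_color": ["<view0.png>", "...", "<view5.png>"],
    #     "image_view_normal": ["<view0.png>", "..."],
    #     "image_position": ["<view0.png>", "..."],
    #     "image_mask": ["<view0.png>", "..."],
    #     "text_feat": "<text_embeds.pt>",
    #     "image_hw": [1024, 1024]
    #   }
    # }
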
    def fetch_sample_images(
        self,
        uid: str,
        attrs: List[str],
        sub_index: Optional[int] = None,
        transform: Optional[Callable] = None,
    ) -> torch.Tensor:
        sample = self.meta_info[uid]
        images = []
        for attr in attrs:
            item = sample[attr]
            if sub_index is not None:
                item = item[sub_index]
            mode = "L" if attr == "image_mask" else "RGB"
            image = Image.open(item).convert(mode)
            if transform is not None:
                image = transform(image)
                if len(image.shape) == 2:
                    # Promote (h, w) to channel-first (1, h, w) so the
                    # channel concat below stays consistent.
                    image = image[None, ...]
            images.append(image)

        # Stack the attribute images along the channel dim: (sum_c, h, w).
        images = torch.cat(images, dim=0)

        return images

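    # Shape note (illustrative): with attrs=["image_view_normal",
    # "image_position", "image_mask"], the RGB normal (3) + RGB position (3)
    # + grayscale mask (1) channels concatenate into a 7-channel control
    # tensor, which visualize_item below slices back apart.
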
    def fetch_sample_grid_images(
        self,
        uid: str,
        attrs: List[str],
        sub_idxs: List[List[int]],
        transform: Optional[Callable] = None,
    ) -> torch.Tensor:
        assert transform is not None

        grid_image = []
        for row_idxs in sub_idxs:
            row_image = []
            for row_idx in row_idxs:
                image = self.fetch_sample_images(
                    uid, attrs, row_idx, transform
                )
                row_image.append(image)
            # Concatenate the views of one row along width: (c, h, w).
            row_image = torch.cat(row_image, dim=2)
            grid_image.append(row_image)
        # Stack the rows along height to form the final grid.
        grid_image = torch.cat(grid_image, dim=1)

        return grid_image

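    # Example (illustrative): with sub_idxs=[[0, 1, 2], [3, 4, 5]] and
    # per-view tensors of shape (c, h, w), each inner list concatenates
    # along width and the two rows stack along height, giving (c, 2*h, 3*w).
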
    def compute_text_embeddings(
        self, embed_path: str, original_size: Tuple[int, int]
    ) -> Dict[str, torch.Tensor]:
        data_dict = torch.load(embed_path)
        prompt_embeds = data_dict["prompt_embeds"][0]
        add_text_embeds = data_dict["pooled_prompt_embeds"][0]
        # Needs changing for random crop: set to crop_top_left [y1, x1];
        # center crop uses [0, 0].  # noqa
        crops_coords_top_left = (0, 0)
        add_time_ids = list(
            original_size + crops_coords_top_left + self.target_hw
        )
        add_time_ids = torch.tensor([add_time_ids])
        # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
        unet_added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        }
        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

    def visualize_item(
        self,
        control: torch.Tensor,
        color: torch.Tensor,
        save_dir: Optional[str] = None,
    ) -> Tuple[Image.Image, Image.Image, Image.Image, Image.Image]:
        to_pil = transforms.ToPILImage()

        # Undo Normalize([0.5], [0.5]): map [-1, 1] back to [0, 1].
        color = (color + 1) / 2
        color_pil = to_pil(color)
        normal_pil = to_pil(control[0:3])
        position_pil = to_pil(control[3:6])
        mask_pil = to_pil(control[6:])
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            color_pil.save(f"{save_dir}/rgb.jpg")
            normal_pil.save(f"{save_dir}/normal.jpg")
            position_pil.save(f"{save_dir}/position.jpg")
            mask_pil.save(f"{save_dir}/mask.jpg")
            logger.info(f"Visualization saved to {save_dir}")

        return normal_pil, position_pil, mask_pil, color_pil

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        uid = self.data_list[index]
        sub_idxs = self.sub_idxs
        if sub_idxs is None:
            # No fixed grid: draw a single random view as a 1x1 "grid".
            sub_idxs = [[random.randint(0, self.image_num - 1)]]

        input_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_view_normal", "image_position", "image_mask"],
            sub_idxs=sub_idxs,
            transform=self.control_transform,
        )
        assert input_image.shape[1:] == self.target_hw

        output_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_color"],
            sub_idxs=sub_idxs,
            transform=self.transform,
        )

        sample = self.meta_info[uid]
        text_feats = self.compute_text_embeddings(
            sample["text_feat"], tuple(sample["image_hw"])
        )

        data = dict(
            pixel_values=output_image,
            conditioning_pixel_values=input_image,
            prompt_embeds=text_feats["prompt_embeds"],
            text_embeds=text_feats["text_embeds"],
            time_ids=text_feats["time_ids"],
        )

        return data

if __name__ == "__main__":
    index_file = "/horizon-bucket/robot_lab/users/xinjie.wang/datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json"  # noqa
    target_hw = (512, 512)
    transform_list = [
        transforms.Resize(
            target_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    image_transform = transforms.Compose(transform_list)
    # Controls skip Normalize so masks/positions stay in [0, 1].
    control_transform = transforms.Compose(transform_list[:-1])

    sub_idxs = [[0, 1, 2], [3, 4, 5]]  # Set to None for a single random view.
    if sub_idxs is not None:
        # Grid output size: rows scale height, columns scale width.
        target_hw = (
            target_hw[0] * len(sub_idxs),
            target_hw[1] * len(sub_idxs[0]),
        )

    dataset = Asset3dGenDataset(
        index_file,
        target_hw,
        image_transform,
        control_transform,
        sub_idxs=sub_idxs,
    )
    data = dataset[0]
    dataset.visualize_item(
        data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
    )
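
    # A minimal DataLoader sketch (illustrative; batch size and worker count
    # are arbitrary, not values from this repo). Default collation suffices
    # because, for a fixed sub_idxs, every sample yields same-shape tensors.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    logger.info({k: tuple(v.shape) for k, v in batch.items()})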