Hunyuan3D-2.1 / hy3dpaint /src /data /dataloader /objaverse_loader_forTexturePBR.py
Huiwenshi's picture
Upload folder using huggingface_hub
600759a verified
raw
history blame
6.06 kB
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import os
import time
import glob
import json
import random
import numpy as np
import torch
from .loader_util import BaseDataset
class TextureDataset(BaseDataset):
def __init__(
self, json_path, num_view=6, image_size=512, lighting_suffix_pool=["light_PL", "light_AL", "light_ENVMAP"]
):
self.data = list()
self.num_view = num_view
self.image_size = image_size
self.lighting_suffix_pool = lighting_suffix_pool
if isinstance(json_path, str):
json_path = [json_path]
for jp in json_path:
with open(jp) as f:
self.data.extend(json.load(f))
print("============= length of dataset %d =============" % len(self.data))
def __getitem__(self, index):
try_sleep_interval = 20
total_try_num = 100
cnt = try_sleep_interval * total_try_num
# try:
images_ref = list()
images_albedo = list()
images_mr = list()
images_normal = list()
images_position = list()
bg_white = [1.0, 1.0, 1.0]
bg_black = [0.0, 0.0, 0.0]
bg_gray = [127 / 255.0, 127 / 255.0, 127 / 255.0]
dirx = self.data[index]
condition_dict = {}
# 6view
fix_num_view = self.num_view
available_views = []
for ext in ["*_albedo.png", "*_albedo.jpg", "*_albedo.jpeg"]:
available_views.extend(glob.glob(os.path.join(dirx, "render_tex", ext)))
cond_images = (
glob.glob(os.path.join(dirx, "render_cond", "*.png"))
+ glob.glob(os.path.join(dirx, "render_cond", "*.jpg"))
+ glob.glob(os.path.join(dirx, "render_cond", "*.jpeg"))
)
# 确保有足够的样本
if len(available_views) < fix_num_view:
print(
f"Warning: Only {len(available_views)} views available, but {fix_num_view} requested."
"Using all available views."
)
images_gen = available_views
else:
images_gen = random.sample(available_views, fix_num_view)
if not cond_images:
raise ValueError(f"No condition images found in {os.path.join(dirx, 'render_cond')}")
ref_image_path = random.choice(cond_images)
light_suffix = None
for suffix in self.lighting_suffix_pool:
if suffix in ref_image_path:
light_suffix = suffix
break
if light_suffix is None:
raise ValueError(f"light suffix not found in {ref_image_path}")
ref_image_diff_light_path = random.choice(
[
ref_image_path.replace(light_suffix, tar_suffix)
for tar_suffix in self.lighting_suffix_pool
if tar_suffix != light_suffix
]
)
images_ref_paths = [ref_image_path, ref_image_diff_light_path]
# Data aug
bg_c_record = None
for i, image_ref in enumerate(images_ref_paths):
if random.random() < 0.6:
bg_c = bg_gray
else:
if random.random() < 0.5:
bg_c = bg_black
else:
bg_c = bg_white
if i == 0:
bg_c_record = bg_c
image, alpha = self.load_image(image_ref, bg_c_record)
image = self.augment_image(image, bg_c_record).float()
images_ref.append(image)
condition_dict["images_cond"] = torch.stack(images_ref, dim=0).float()
for i, image_gen in enumerate(images_gen):
images_albedo.append(self.augment_image(self.load_image(image_gen, bg_gray)[0], bg_gray))
images_mr.append(
self.augment_image(self.load_image(image_gen.replace("_albedo", "_mr"), bg_gray)[0], bg_gray)
)
images_normal.append(
self.augment_image(self.load_image(image_gen.replace("_albedo", "_normal"), bg_gray)[0], bg_gray)
)
images_position.append(
self.augment_image(self.load_image(image_gen.replace("_albedo", "_pos"), bg_gray)[0], bg_gray)
)
condition_dict["images_albedo"] = torch.stack(images_albedo, dim=0).float()
condition_dict["images_mr"] = torch.stack(images_mr, dim=0).float()
condition_dict["images_normal"] = torch.stack(images_normal, dim=0).float()
condition_dict["images_position"] = torch.stack(images_position, dim=0).float()
condition_dict["name"] = dirx # .replace('/', '_')
return condition_dict # (N, 3, H, W)
# except Exception as e:
# print(e, self.data[index])
# # exit()
if __name__ == "__main__":
dataset = TextureDataset(json_path=["../../../train_examples/examples.json"])
print("images_cond", dataset[0]["images_cond"].shape)
print("images_albedo", dataset[0]["images_albedo"].shape)
print("images_mr", dataset[0]["images_mr"].shape)
print("images_normal", dataset[0]["images_normal"].shape)
print("images_position", dataset[0]["images_position"].shape)
print("name", dataset[0]["name"])