import math
import os
import json
import re
import random
from dataclasses import dataclass, field

import cv2
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

from step1x3d_geometry.utils.typing import *


@dataclass
class BaseDataModuleConfig:
    root_dir: Optional[str] = None
    batch_size: int = 4
    num_workers: int = 8

    ################################# General augmentation #################################
    random_flip: bool = (
        False  # whether to randomly flip the input point cloud and the input images
    )

    ################################# Geometry part #################################
    load_geometry: bool = True  # whether to load geometry data
    with_sharp_data: bool = False
    geo_data_type: str = "sdf"  # occupancy, sdf

    # for occupancy or sdf supervision
    n_samples: int = 4096  # number of points in input point cloud
    upsample_ratio: int = 1  # upsample ratio for input point cloud
    sampling_strategy: Optional[str] = (
        "random"  # sampling strategy for input point cloud
    )
    scale: float = 1.0  # scale of the input point cloud and target supervision
    noise_sigma: float = 0.0  # noise level of the input point cloud
    rotate_points: bool = (
        False  # whether to rotate the input point cloud and the supervision, for VAE aug.
    )
    load_geometry_supervision: bool = False  # whether to load supervision
    supervision_type: str = "sdf"  # occupancy, sdf, tsdf, tsdf_w_surface
    n_supervision: int = 10000  # number of points in supervision
    tsdf_threshold: float = (
        0.01  # threshold for truncating sdf values, used when input is sdf
    )

    ################################# Image part #################################
    load_image: bool = False  # whether to load images
    image_type: str = "rgb"  # rgb, normal, rgb_or_normal
    image_file_type: str = "png"  # png, jpeg
    image_type_ratio: float = (
        1.0  # ratio of rgb for each dataset when image_type is "rgb_or_normal"
    )
    crop_image: bool = True  # whether to crop the input image
    random_color_jitter: bool = (
        False  # whether to randomly color-jitter the input images
    )
    random_rotate: bool = (
        False  # whether to randomly rotate the input images, default [-10 deg, 10 deg]
    )
    random_mask: bool = False  # whether to add a random mask to the input image
    background_color: Tuple[int, int, int] = field(
        default_factory=lambda: (255, 255, 255)
    )
    idx: Optional[List[int]] = None  # indices of the candidate views to sample from
    n_views: int = 1  # number of views
    foreground_ratio: Optional[float] = 0.90

    ################################# Caption part #################################
    load_caption: bool = False  # whether to load captions
    load_label: bool = False  # whether to load labels


class BaseDataset(Dataset):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: BaseDataModuleConfig = cfg
        self.split = split
        with open(f"{cfg.root_dir}/{split}.json") as f:
            self.uids = json.load(f)
        print(f"Loaded {len(self.uids)} {split} uids")

        # add ColorJitter transform for input images
        if self.cfg.random_color_jitter:
            self.color_jitter = transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
            )

        # add RandomRotation transform for input images (by default +/- 10 deg)
        if self.cfg.random_rotate:
            self.rotate = transforms.RandomRotation(
                degrees=10, fill=(*self.cfg.background_color, 0.0)
            )

    def __len__(self):
        return len(self.uids)

    def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        if self.cfg.geo_data_type == "sdf":
            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
            # for input point cloud
            surface = data["surface"]
            if self.cfg.with_sharp_data:
                sharp_surface = data["sharp_surface"]
        else:
            raise NotImplementedError(
                f"Data type {self.cfg.geo_data_type} not implemented"
            )

        # random sampling of the input point cloud
        if self.cfg.sampling_strategy == "random":
            rng = np.random.default_rng()
            ind = rng.choice(
                surface.shape[0],
                self.cfg.upsample_ratio * self.cfg.n_samples,
                replace=True,
            )
            surface = surface[ind]
            if self.cfg.with_sharp_data:
                sharp_surface = sharp_surface[ind]
        elif self.cfg.sampling_strategy == "fps":
            import fpsample

            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
                surface[:, :3], self.cfg.n_samples, h=5
            )
            surface = surface[kdline_fps_samples_idx]
            if self.cfg.with_sharp_data:
                kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
                    sharp_surface[:, :3], self.cfg.n_samples, h=5
                )
                sharp_surface = sharp_surface[kdline_fps_samples_idx]
        else:
            raise NotImplementedError(
                f"Sampling strategy {self.cfg.sampling_strategy} not implemented"
            )

        # rescale data to the target scale
        surface[:, :3] = surface[:, :3] * self.cfg.scale
        if self.cfg.with_sharp_data:
            sharp_surface[:, :3] = sharp_surface[:, :3] * self.cfg.scale
            ret = {
                "uid": self.uids[index].split("/")[-1],
                "surface": surface.astype(np.float32),
                "sharp_surface": sharp_surface.astype(np.float32),
            }
        else:
            ret = {
                "uid": self.uids[index].split("/")[-1],
                "surface": surface.astype(np.float32),
            }

        return ret

    def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
        # for supervision
        ret = {}
        if self.cfg.geo_data_type == "sdf":
            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
            data = np.concatenate(
                [data["volume_rand_points"], data["near_surface_points"]], axis=0
            )
            rand_points, sdfs = data[:, :3], data[:, 3:]
        else:
            raise NotImplementedError(
                f"Data type {self.cfg.geo_data_type} not implemented"
            )

        # random sampling of the supervision points
        rng = np.random.default_rng()
        ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False)
        rand_points = rand_points[ind]
        rand_points = rand_points * self.cfg.scale
        ret["rand_points"] = rand_points.astype(np.float32)

        if self.cfg.geo_data_type == "sdf":
            if self.cfg.supervision_type == "sdf":
                ret["sdf"] = sdfs[ind].flatten().astype(np.float32)
            elif self.cfg.supervision_type == "occupancy":
                ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(
                    np.float32
                )
            elif self.cfg.supervision_type == "tsdf":
                ret["sdf"] = (
                    sdfs[ind]
                    .flatten()
                    .astype(np.float32)
                    .clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold)
                    / self.cfg.tsdf_threshold
                )
            else:
                raise NotImplementedError(
                    f"Supervision type {self.cfg.supervision_type} not implemented"
                )

        return ret

    def _load_image(self, index: int) -> Dict[str, Any]:
        def _process_img(image, background_color=(255, 255, 255), foreground_ratio=0.9):
            # composite onto the background color, crop to the foreground bbox,
            # shrink by foreground_ratio, then pad back onto a square canvas
            alpha = image.getchannel("A")
            background = Image.new("RGBA", image.size, (*background_color, 255))
            image = Image.alpha_composite(background, image)
            image = image.crop(alpha.getbbox())

            new_size = tuple(int(dim * foreground_ratio) for dim in image.size)
            resized_image = image.resize(new_size)
            padded_image = Image.new("RGBA", image.size, (*background_color, 255))
            paste_position = (
                (image.width - resized_image.width) // 2,
                (image.height - resized_image.height) // 2,
            )
            padded_image.paste(resized_image, paste_position)

            # expand image to 1:1
            max_dim = max(padded_image.size)
            image = Image.new("RGBA", (max_dim, max_dim), (*background_color, 255))
            paste_position = (
                (max_dim - padded_image.width) // 2,
                (max_dim - padded_image.height) // 2,
            )
            image.paste(padded_image, paste_position)
            image = image.resize((512, 512))
            return image.convert("RGB"), alpha

        ret = {}
        if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal":
            assert (
                self.cfg.n_views == 1
            ), "Only single view is supported for single image"
            sel_idx = random.choice(self.cfg.idx)
            ret["sel_image_idx"] = sel_idx
            if self.cfg.image_type == "rgb":
                img_path = (
                    f"{self.cfg.root_dir}/images/"
                    + "/".join(self.uids[index].split("/")[-2:])
                    + f"/{sel_idx:04d}_rgb.{self.cfg.image_file_type}"
                )
            elif self.cfg.image_type == "normal":
                img_path = (
                    f"{self.cfg.root_dir}/images/"
                    + "/".join(self.uids[index].split("/")[-2:])
                    + f"/{sel_idx:04d}_normal.{self.cfg.image_file_type}"
                )
            image = Image.open(img_path).copy()

            # add random color jitter (RGB only, keep the original alpha channel)
            if self.cfg.random_color_jitter:
                rgb = self.color_jitter(image.convert("RGB"))
                image = Image.merge("RGBA", (*rgb.split(), image.getchannel("A")))

            # add random rotation
            if self.cfg.random_rotate:
                image = self.rotate(image)

            # resolve the background color once (as a plain int tuple for PIL);
            # pick a random color when none is configured
            background_color = (
                tuple(torch.randint(0, 256, (3,)).tolist())
                if self.cfg.background_color is None
                else tuple(self.cfg.background_color)
            )

            # add crop
            if self.cfg.crop_image:
                image, alpha = _process_img(
                    image, background_color, self.cfg.foreground_ratio
                )
            else:
                alpha = image.getchannel("A")
                background = Image.new("RGBA", image.size, (*background_color, 255))
                image = Image.alpha_composite(background, image).convert("RGB")

            ret["image"] = torch.from_numpy(np.array(image) / 255.0)
            ret["mask"] = torch.from_numpy(np.array(alpha) / 255.0).unsqueeze(0)
        else:
            raise NotImplementedError(
                f"Image type {self.cfg.image_type} not implemented"
            )

        return ret

    def _get_data(self, index):
        ret = {"uid": self.uids[index]}

        # random flip
        flip = np.random.rand() < 0.5 if self.cfg.random_flip else False

        # load geometry
        if self.cfg.load_geometry:
            if self.cfg.geo_data_type == "occupancy" or self.cfg.geo_data_type == "sdf":
                # load shape
                ret = self._load_shape_from_occupancy_or_sdf(index)
                # load supervision for shape
                if self.cfg.load_geometry_supervision:
                    ret.update(self._load_shape_supervision_occupancy_or_sdf(index))
            else:
                raise NotImplementedError(
                    f"Geo data type {self.cfg.geo_data_type} not implemented"
                )

            if flip:
                # randomly flip the input point cloud and the supervision along x
                for key in ret.keys():
                    if key in ["surface", "sharp_surface"]:
                        # N x (xyz + normal): negate the x coordinate and the x normal
                        ret[key][:, 0] = -ret[key][:, 0]
                        ret[key][:, 3] = -ret[key][:, 3]
                    elif key in ["rand_points"]:
                        ret[key][:, 0] = -ret[key][:, 0]

        # load image
        if self.cfg.load_image:
            ret.update(self._load_image(index))
            if flip:
                # flip the image horizontally to match the flipped point cloud:
                # "image" is H x W x C (flip dim 1), "mask" is 1 x H x W (flip dim 2)
                for key in ret.keys():
                    if key == "image":
                        ret[key] = torch.flip(ret[key], [1])
                    elif key == "mask":
                        ret[key] = torch.flip(ret[key], [2])

        # load caption
        meta = None
        if self.cfg.load_caption:
            with open(f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r") as f:
                meta = json.load(f)
            ret.update({"caption": meta["caption"]})

        # load label
        if self.cfg.load_label:
            if meta is None:
                with open(
                    f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r"
                ) as f:
                    meta = json.load(f)
            ret.update({"label": [meta["label"]]})

        return ret

    def __getitem__(self, index):
        try:
            return self._get_data(index)
        except Exception as e:
            # fall back to a random sample so one corrupt item does not stop training
            print(f"Error in {self.uids[index]}: {e}")
            return self.__getitem__(np.random.randint(len(self)))

    def collate(self, batch):
        return torch.utils.data.default_collate(batch)
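

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module).
    # Assumptions: a dataset root laid out as
    #   <root_dir>/train.json                      -- JSON list of uids
    #   <root_dir>/surfaces/<uid>.npz              -- "surface", "volume_rand_points", "near_surface_points"
    #   <root_dir>/images/<category>/<uid>/0000_rgb.png
    # The path "/data/step1x3d" below is a placeholder, not a path shipped with the repo.
    cfg = BaseDataModuleConfig(
        root_dir="/data/step1x3d",
        load_geometry=True,
        load_geometry_supervision=True,
        load_image=True,
        idx=[0],  # candidate view indices; required when load_image is True
    )
    dataset = BaseDataset(cfg, split="train")
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
        collate_fn=dataset.collate,
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items() if torch.is_tensor(v)})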