import logging import os import struct from dataclasses import dataclass, field from typing import Optional, Union import cv2 import numpy as np import torch from gsplat.cuda._wrapper import spherical_harmonics from gsplat.rendering import rasterization from plyfile import PlyData from scipy.spatial.transform import Rotation from torch.nn import functional as F logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) __all__ = [ "RenderResult", "GaussianOperator", ] def quat_mult(q1, q2): # NOTE: # Q1 is the quaternion that rotates the vector from the original position to the final position # noqa # Q2 is the quaternion that been rotated w1, x1, y1, z1 = q1.T w2, x2, y2, z2 = q2.T w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 return torch.stack([w, x, y, z]).T def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor: """Convert quaternion to rotation matrix.""" quats = F.normalize(quats, p=2, dim=-1) if mode == "xyzw": x, y, z, w = torch.unbind(quats, dim=-1) elif mode == "wxyz": w, x, y, z = torch.unbind(quats, dim=-1) else: raise ValueError(f"Invalid mode: {mode}.") R = torch.stack( [ 1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y), 2 * (x * y + w * z), 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x), 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x**2 + y**2), ], dim=-1, ) return R.reshape(quats.shape[:-1] + (3, 3)) def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor: C0 = 0.28209479177387814 # Constant for normalization in spherical harmonics # noqa # Clip to the range [0.0, 1.0], apply gamma correction, and then un-clip back # noqa new_shs = torch.clip(shs * C0 + 0.5, 0.0, 1.0) new_shs = (torch.pow(new_shs, gamma) - 0.5) / C0 return new_shs @dataclass class RenderResult: rgb: np.ndarray depth: np.ndarray opacity: np.ndarray mask_threshold: float = 10 mask: Optional[np.ndarray] = None rgba: Optional[np.ndarray] = None def __post_init__(self): if isinstance(self.rgb, torch.Tensor): rgb = self.rgb.detach().cpu().numpy() rgb = (rgb * 255).astype(np.uint8) self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) if isinstance(self.depth, torch.Tensor): self.depth = self.depth.detach().cpu().numpy() if isinstance(self.opacity, torch.Tensor): opacity = self.opacity.detach().cpu().numpy() opacity = (opacity * 255).astype(np.uint8) self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB) mask = np.where(self.opacity > self.mask_threshold, 255, 0) self.mask = mask[..., 0:1].astype(np.uint8) self.rgba = np.concatenate([self.rgb, self.mask], axis=-1) @dataclass class GaussianBase: _opacities: torch.Tensor _means: torch.Tensor _scales: torch.Tensor _quats: torch.Tensor _rgbs: Optional[torch.Tensor] = None _features_dc: Optional[torch.Tensor] = None _features_rest: Optional[torch.Tensor] = None sh_degree: Optional[int] = 0 device: str = "cuda" def __post_init__(self): self.active_sh_degree: int = self.sh_degree self.to(self.device) def to(self, device: str) -> None: for k, v in self.__dict__.items(): if not isinstance(v, torch.Tensor): continue self.__dict__[k] = v.to(device) def get_numpy_data(self): data = {} for k, v in self.__dict__.items(): if not isinstance(v, torch.Tensor): continue data[k] = v.detach().cpu().numpy() return data def quat_norm(self, x: torch.Tensor) -> torch.Tensor: return x / x.norm(dim=-1, keepdim=True) @classmethod def load_from_ply( cls, path: str, gamma: float = 1.0, device: str = "cuda", ) -> "GaussianBase": plydata = PlyData.read(path) xyz = torch.stack( ( torch.tensor(plydata.elements[0]["x"], dtype=torch.float32), torch.tensor(plydata.elements[0]["y"], dtype=torch.float32), torch.tensor(plydata.elements[0]["z"], dtype=torch.float32), ), dim=1, ) opacities = torch.tensor( plydata.elements[0]["opacity"], dtype=torch.float32 ).unsqueeze(-1) features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32) features_dc[:, 0] = torch.tensor( plydata.elements[0]["f_dc_0"], dtype=torch.float32 ) features_dc[:, 1] = torch.tensor( plydata.elements[0]["f_dc_1"], dtype=torch.float32 ) features_dc[:, 2] = torch.tensor( plydata.elements[0]["f_dc_2"], dtype=torch.float32 ) scale_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("scale_") ] scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) scales = torch.zeros( (xyz.shape[0], len(scale_names)), dtype=torch.float32 ) for idx, attr_name in enumerate(scale_names): scales[:, idx] = torch.tensor( plydata.elements[0][attr_name], dtype=torch.float32 ) rot_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("rot_") ] rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32) for idx, attr_name in enumerate(rot_names): rots[:, idx] = torch.tensor( plydata.elements[0][attr_name], dtype=torch.float32 ) rots = rots / torch.norm(rots, dim=-1, keepdim=True) # extra features extra_f_names = [ p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_") ] extra_f_names = sorted( extra_f_names, key=lambda x: int(x.split("_")[-1]) ) max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1) if max_sh_degree != 0: features_extra = torch.zeros( (xyz.shape[0], len(extra_f_names)), dtype=torch.float32 ) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = torch.tensor( plydata.elements[0][attr_name], dtype=torch.float32 ) features_extra = features_extra.view( (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1) ) features_extra = features_extra.permute(0, 2, 1) if abs(gamma - 1.0) > 1e-3: features_dc = gamma_shs(features_dc, gamma) features_extra[..., :] = 0.0 opacities *= 0.8 shs = torch.cat( [ features_dc.reshape(-1, 3), features_extra.reshape(len(features_dc), -1), ], dim=-1, ) else: # sh_dim is 0, only dc features shs = features_dc features_extra = None return cls( sh_degree=max_sh_degree, _means=xyz, _opacities=opacities, _rgbs=shs, _scales=scales, _quats=rots, _features_dc=features_dc, _features_rest=features_extra, device=device, ) def save_to_ply( self, path: str, colors: torch.Tensor = None, enable_mask: bool = False ): os.makedirs(os.path.dirname(path), exist_ok=True) numpy_data = self.get_numpy_data() means = numpy_data["_means"] scales = numpy_data["_scales"] quats = numpy_data["_quats"] opacities = numpy_data["_opacities"] sh0 = numpy_data["_features_dc"] shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0))) shN = shN.reshape(means.shape[0], -1) # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa if enable_mask: invalid_mask = ( np.isnan(means).any(axis=1) | np.isinf(means).any(axis=1) | np.isnan(scales).any(axis=1) | np.isinf(scales).any(axis=1) | np.isnan(quats).any(axis=1) | np.isinf(quats).any(axis=1) | np.isnan(opacities).any(axis=0) | np.isinf(opacities).any(axis=0) | np.isnan(sh0).any(axis=1) | np.isinf(sh0).any(axis=1) | np.isnan(shN).any(axis=1) | np.isinf(shN).any(axis=1) ) # Filter out rows with NaNs or Infs from all data arrays means = means[~invalid_mask] scales = scales[~invalid_mask] quats = quats[~invalid_mask] opacities = opacities[~invalid_mask] sh0 = sh0[~invalid_mask] shN = shN[~invalid_mask] num_points = means.shape[0] with open(path, "wb") as f: # Write PLY header f.write(b"ply\n") f.write(b"format binary_little_endian 1.0\n") f.write(f"element vertex {num_points}\n".encode()) f.write(b"property float x\n") f.write(b"property float y\n") f.write(b"property float z\n") f.write(b"property float nx\n") f.write(b"property float ny\n") f.write(b"property float nz\n") if colors is not None: for j in range(colors.shape[1]): f.write(f"property float f_dc_{j}\n".encode()) else: for i, data in enumerate([sh0, shN]): prefix = "f_dc" if i == 0 else "f_rest" for j in range(data.shape[1]): f.write(f"property float {prefix}_{j}\n".encode()) f.write(b"property float opacity\n") for i in range(scales.shape[1]): f.write(f"property float scale_{i}\n".encode()) for i in range(quats.shape[1]): f.write(f"property float rot_{i}\n".encode()) f.write(b"end_header\n") # Write vertex data for i in range(num_points): f.write(struct.pack(" (x y z qw qx qy qz) instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]] cur_instances_quats = self.quat_norm(instance_pose[3:]) rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz") # update the means num_gs = means.shape[0] trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0) quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0) rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3) # update the means cur_means = ( torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1) + trans_per_pts ) # update the quats _quats = self.quat_norm(quats) cur_quats = quat_mult(quat_per_pts, _quats) return cur_means, cur_quats def get_gaussians( self, c2w: torch.Tensor = None, instance_pose: torch.Tensor = None, apply_activate: bool = False, ) -> "GaussianBase": """Get Gaussian data under the given instance_pose.""" if c2w is None: c2w = torch.eye(4).to(self.device) if instance_pose is not None: # compute the transformed gs means and quats world_means, world_quats = self._compute_transform( self._means, self._quats, instance_pose.float().to(self.device) ) else: world_means, world_quats = self._means, self._quats # get colors of gaussians if self._features_rest is not None: colors = torch.cat( (self._features_dc[:, None, :], self._features_rest), dim=1 ) else: colors = self._features_dc[:, None, :] if self.sh_degree > 0: viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3) viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True) rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors) rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0) else: rgbs = torch.sigmoid(colors[:, 0, :]) gs_dict = dict( _means=world_means, _opacities=( torch.sigmoid(self._opacities) if apply_activate else self._opacities ), _rgbs=rgbs, _scales=( torch.exp(self._scales) if apply_activate else self._scales ), _quats=self.quat_norm(world_quats), _features_dc=self._features_dc, _features_rest=self._features_rest, sh_degree=self.sh_degree, device=self.device, ) return GaussianOperator(**gs_dict) def rescale(self, scale: float): if scale != 1.0: self._means *= scale self._scales += torch.log(self._scales.new_tensor(scale)) def set_scale_by_height(self, real_height: float) -> None: def _ptp(tensor, dim): val = tensor.max(dim=dim).values - tensor.min(dim=dim).values return val.tolist() xyz_scale = max(_ptp(self._means, dim=0)) self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5] raw_height = _ptp(self._means, dim=0)[1] scale = real_height / raw_height self.rescale(scale) return @staticmethod def resave_ply( in_ply: str, out_ply: str, real_height: float = None, instance_pose: np.ndarray = None, sh_degree: int = 0, device: str = "cuda", ) -> None: gs_model = GaussianOperator.load_from_ply( in_ply, sh_degree, device=device ) if instance_pose is not None: gs_model = gs_model.get_gaussians(instance_pose=instance_pose) if real_height is not None: gs_model.set_scale_by_height(real_height) gs_model.save_to_ply(out_ply) return @staticmethod def trans_to_quatpose( rot_matrix: list[list[float]], trans_matrix: list[float] = [0, 0, 0], ) -> torch.Tensor: if isinstance(rot_matrix, list): rot_matrix = np.array(rot_matrix) rot = Rotation.from_matrix(rot_matrix) qx, qy, qz, qw = rot.as_quat() instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw]) return instance_pose def render( self, c2w: torch.Tensor, Ks: torch.Tensor, image_width: int, image_height: int, ) -> RenderResult: gs = self.get_gaussians(c2w, apply_activate=True) renders, alphas, _ = rasterization( means=gs._means, quats=gs._quats, scales=gs._scales, opacities=gs._opacities.squeeze(), colors=gs._rgbs, viewmats=torch.linalg.inv(c2w)[None, ...], Ks=Ks[None, ...], width=image_width, height=image_height, packed=False, absgrad=True, sparse_grad=False, # rasterize_mode="classic", rasterize_mode="antialiased", **{ "near_plane": 0.01, "far_plane": 1000000000, "radius_clip": 0.0, "render_mode": "RGB+ED", }, ) renders = renders[0] alphas = alphas[0].squeeze(-1) assert renders.shape[-1] == 4, f"Must render rgb, depth and alpha" rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1) return RenderResult( torch.clamp(rendered_rgb, min=0, max=1), rendered_depth, alphas[..., None], ) if __name__ == "__main__": input_gs = "outputs/test/debug.ply" output_gs = "./debug_v3.ply" gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs) # 绕 x 轴旋转 180° R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] instance_pose = gs_model.trans_to_quatpose(R_x) gs_model = gs_model.get_gaussians(instance_pose=instance_pose) gs_model.rescale(2) gs_model.set_scale_by_height(1.3) gs_model.save_to_ply(output_gs)