import os
from dataclasses import dataclass, field
from collections import defaultdict
try:
    from diff_gaussian_rasterization_wda import GaussianRasterizationSettings, GaussianRasterizer
except ImportError:
    from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from plyfile import PlyData, PlyElement
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import copy
from diffusers.utils import is_torch_version
from lam.models.rendering.flame_model.flame import FlameHeadSubdivided
from lam.models.transformer import TransformerDecoder
from pytorch3d.transforms import matrix_to_quaternion
from lam.models.rendering.utils.typing import *
from lam.models.rendering.utils.utils import trunc_exp, MLP
from lam.models.rendering.gaussian_model import GaussianModel
from einops import rearrange, repeat
from pytorch3d.ops.points_normals import estimate_pointcloud_normals

os.environ["PYOPENGL_PLATFORM"] = "egl"

from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.renderer import (
    AmbientLights,
    PerspectiveCameras,
    SoftSilhouetteShader,
    SoftPhongShader,
    RasterizationSettings,
    MeshRenderer,
    MeshRendererWithFragments,
    MeshRasterizer,
    TexturesVertex,
)
from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend
import lam.models.rendering.utils.mesh_utils as mesh_utils
from lam.models.rendering.utils.point_utils import depth_to_normal
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes

inverse_sigmoid = lambda x: np.log(x / (1 - x))


def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)


def getProjectionMatrix(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan(fovY / 2)
    tanHalfFovX = math.tan(fovX / 2)

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)
    z_sign = 1.0
    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P
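
# Usage sketch (illustrative values only): this perspective matrix follows the 3D Gaussian
# Splatting convention; it is transposed and right-multiplied with the world-to-view transform
# inside Camera.__init__ before being handed to the rasterizer.
#   P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=math.radians(60), fovY=math.radians(60))
#   # P is a 4x4 torch tensor; P[3, 2] == 1.0 encodes the +z sign convention used above.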


def intrinsic_to_fov(intrinsic, w, h):
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]
    fov_x = 2 * torch.arctan2(w, 2 * fx)
    fov_y = 2 * torch.arctan2(h, 2 * fy)
    return fov_x, fov_y
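
# For a pinhole camera, fov = 2 * atan(size / (2 * focal)). For example, with fx = fy = 512 on a
# 512 x 512 image this gives 2 * atan(0.5) ≈ 53.13 degrees per axis.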


class Camera:
    def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.height = int(height)
        self.width = int(width)
        self.world_view_transform = w2c.transpose(0, 1)
        self.intrinsic = intrinsic

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).to(w2c.device)
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    @staticmethod
    def from_c2w(c2w, intrinsic, height, width):
        w2c = torch.inverse(c2w)
        FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
        return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
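
# Typical construction (shapes assumed: per-view 4x4 camera-to-world and an intrinsic matrix of
# which only fx, fy are read), e.g.
#   cam = Camera.from_c2w(c2w[0], intrinsic[0], height=512, width=512)
# which is how forward_single_batch below builds one Camera per rendered view.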


class GSLayer(nn.Module):
    def __init__(self, in_channels, use_rgb,
                 clip_scaling=0.2,
                 init_scaling=-5.0,
                 scale_sphere=False,
                 init_density=0.1,
                 sh_degree=None,
                 xyz_offset=True,
                 restrict_offset=True,
                 xyz_offset_max_step=None,
                 fix_opacity=False,
                 fix_rotation=False,
                 use_fine_feat=False,
                 pred_res=False,
                 ):
        super().__init__()
        self.clip_scaling = clip_scaling
        self.use_rgb = use_rgb
        self.restrict_offset = restrict_offset
        self.xyz_offset = xyz_offset
        self.xyz_offset_max_step = xyz_offset_max_step  # 1.2 / 32
        self.fix_opacity = fix_opacity
        self.fix_rotation = fix_rotation
        self.use_fine_feat = use_fine_feat
        self.scale_sphere = scale_sphere
        self.pred_res = pred_res

        self.attr_dict = {
            "shs": (sh_degree + 1) ** 2 * 3,
            "scaling": 3 if not scale_sphere else 1,
            "xyz": 3,
            "opacity": None,
            "rotation": None,
        }
        if not self.fix_opacity:
            self.attr_dict["opacity"] = 1
        if not self.fix_rotation:
            self.attr_dict["rotation"] = 4

        self.out_layers = nn.ModuleDict()
        for key, out_ch in self.attr_dict.items():
            if out_ch is None:
                layer = nn.Identity()
            else:
                if key == "shs" and use_rgb:
                    out_ch = 3
                if key == "shs":
                    shs_out_ch = out_ch
                if pred_res:
                    layer = nn.Linear(in_channels + out_ch, out_ch)
                else:
                    layer = nn.Linear(in_channels, out_ch)
            # initialize
            if not (key == "shs" and use_rgb):
                if key == "opacity" and self.fix_opacity:
                    pass
                elif key == "rotation" and self.fix_rotation:
                    pass
                else:
                    nn.init.constant_(layer.weight, 0)
                    nn.init.constant_(layer.bias, 0)
            if key == "scaling":
                nn.init.constant_(layer.bias, init_scaling)
            elif key == "rotation":
                if not self.fix_rotation:
                    nn.init.constant_(layer.bias, 0)
                    nn.init.constant_(layer.bias[0], 1.0)
            elif key == "opacity":
                if not self.fix_opacity:
                    nn.init.constant_(layer.bias, inverse_sigmoid(init_density))
            self.out_layers[key] = layer

        if self.use_fine_feat:
            fine_shs_layer = nn.Linear(in_channels, shs_out_ch)
            nn.init.constant_(fine_shs_layer.weight, 0)
            nn.init.constant_(fine_shs_layer.bias, 0)
            self.out_layers["fine_shs"] = fine_shs_layer

    def forward(self, x, pts, x_fine=None, gs_raw_attr=None, ret_raw=False, vtx_sym_idxs=None):
        assert len(x.shape) == 2
        ret = {}
        if ret_raw:
            raw_attr = {}
        ori_x = x
        for k in self.attr_dict:
            # if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity"]:
            if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity", "rotation"]:
                # x = (x + x[vtx_sym_idxs.to(x.device), :]) / 2.
                x = ori_x[vtx_sym_idxs.to(x.device), :]
            else:
                x = ori_x
            layer = self.out_layers[k]

            if self.pred_res and (not self.fix_opacity or k != "opacity") and (not self.fix_rotation or k != "rotation"):
                v = layer(torch.cat([gs_raw_attr[k], x], dim=-1))
                v = gs_raw_attr[k] + v
            else:
                v = layer(x)
            if ret_raw:
                raw_attr[k] = v

            if k == "rotation":
                if self.fix_rotation:
                    v = matrix_to_quaternion(torch.eye(3).type_as(x)[None, :, :].repeat(x.shape[0], 1, 1))  # constant rotation
                else:
                    # assert len(x.shape) == 2
                    v = torch.nn.functional.normalize(v)
            elif k == "scaling":
                v = trunc_exp(v)
                if self.scale_sphere:
                    assert v.shape[-1] == 1
                    v = torch.cat([v, v, v], dim=-1)
                if self.clip_scaling is not None:
                    v = torch.clamp(v, min=0, max=self.clip_scaling)
            elif k == "opacity":
                if self.fix_opacity:
                    v = torch.ones_like(x)[..., 0:1]
                else:
                    v = torch.sigmoid(v)
            elif k == "shs":
                if self.use_rgb:
                    v[..., :3] = torch.sigmoid(v[..., :3])
                    if self.use_fine_feat:
                        v_fine = self.out_layers["fine_shs"](x_fine)
                        v_fine = torch.tanh(v_fine)
                        v = v + v_fine
                else:
                    if self.use_fine_feat:
                        v_fine = self.out_layers["fine_shs"](x_fine)
                        v = v + v_fine
                v = torch.reshape(v, (v.shape[0], -1, 3))
            elif k == "xyz":
                # TODO check
                if self.restrict_offset:
                    max_step = self.xyz_offset_max_step
                    v = (torch.sigmoid(v) - 0.5) * max_step
                if self.xyz_offset:
                    pass
                else:
                    raise NotImplementedError
                ret["offset"] = v
                v = pts + v
            ret[k] = v

        if ret_raw:
            return GaussianModel(**ret), raw_attr
        else:
            return GaussianModel(**ret)
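
# Summary of the per-attribute activations applied in GSLayer.forward:
#   scaling  -> trunc_exp, optionally clamped to clip_scaling (broadcast to 3 dims if scale_sphere)
#   opacity  -> sigmoid (or constant 1.0 when fix_opacity)
#   rotation -> L2-normalized quaternion (or identity rotation when fix_rotation)
#   shs      -> sigmoid RGB when use_rgb, otherwise raw SH coefficients reshaped to (N, K, 3)
#   xyz      -> bounded offset (sigmoid(v) - 0.5) * xyz_offset_max_step added to the query points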


class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x (embedding_dim / 2)

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)
        self.norm = nn.LayerNorm(dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum('bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
        embed = self.norm(embed)
        return embed
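
# PointEmbed is a Fourier-feature point encoder: each 3D point is projected onto a fixed
# power-of-two frequency basis, sin/cos features are concatenated with the raw coordinates,
# and a linear layer plus LayerNorm maps them to `dim` channels. Rough shape walk-through with
# the defaults (hidden_dim=48, dim=128), batch B and N points:
#   pts (B, N, 3) -> projections (B, N, 24) -> sin/cos (B, N, 48) -> cat xyz (B, N, 51) -> (B, N, 128)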


class CrossAttnBlock(nn.Module):
    """
    Transformer block that takes in a cross-attention condition.
    Designed for SparseLRM architecture.
    """
    # Block contains a cross-attention layer, a self-attention layer, and an MLP
    def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float = None,
                 attn_drop: float = 0., attn_bias: bool = False,
                 mlp_ratio: float = 4., mlp_drop: float = 0., feedforward=False):
        super().__init__()
        # TODO check already apply normalization
        # self.norm_q = nn.LayerNorm(inner_dim, eps=eps)
        # self.norm_k = nn.LayerNorm(cond_dim, eps=eps)
        self.norm_q = nn.Identity()
        self.norm_k = nn.Identity()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
            dropout=attn_drop, bias=attn_bias, batch_first=True)
        self.mlp = None
        if feedforward:
            self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
            self.self_attn = nn.MultiheadAttention(
                embed_dim=inner_dim, num_heads=num_heads,
                dropout=attn_drop, bias=attn_bias, batch_first=True)
            self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
            self.mlp = nn.Sequential(
                nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
                nn.GELU(),
                nn.Dropout(mlp_drop),
                nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
                nn.Dropout(mlp_drop),
            )

    def forward(self, x, cond):
        # x: [N, L, D]
        # cond: [N, L_cond, D_cond]
        x = self.cross_attn(self.norm_q(x), self.norm_k(cond), cond, need_weights=False)[0]
        if self.mlp is not None:
            before_sa = self.norm2(x)
            x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
            x = x + self.mlp(self.norm3(x))
        return x
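
# Expected shapes for CrossAttnBlock.forward (batch_first MultiheadAttention):
#   x:    (B, L_query, inner_dim)   e.g. per-vertex query tokens
#   cond: (B, L_cond, cond_dim)     e.g. image or latent tokens
# The cross-attention output replaces x; the optional self-attention + MLP branch is residual.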


class DecoderCrossAttn(nn.Module):
    def __init__(self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None):
        super().__init__()
        self.query_dim = query_dim
        self.context_dim = context_dim
        self.cross_attn = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
                                         num_heads=num_heads, feedforward=mlp,
                                         eps=1e-5)
        self.decode_with_extra_info = decode_with_extra_info
        if decode_with_extra_info is not None:
            if decode_with_extra_info["type"] == "dinov2p14_feat":
                context_dim = decode_with_extra_info["cond_dim"]
                self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
                                                       num_heads=num_heads, feedforward=False, eps=1e-5)
            elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
                from lam.models.encoders.dinov2_wrapper import Dinov2Wrapper
                self.encoder = Dinov2Wrapper(model_name='dinov2_vits14_reg', freeze=False, encoder_feat_dim=384)
                self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=384,
                                                       num_heads=num_heads, feedforward=False,
                                                       eps=1e-5)
            elif decode_with_extra_info["type"] == "decoder_resnet18_feat":
                from lam.models.encoders.xunet_wrapper import XnetWrapper
                self.encoder = XnetWrapper(model_name='resnet18', freeze=False, encoder_feat_dim=64)
                self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=64,
                                                       num_heads=num_heads, feedforward=False,
                                                       eps=1e-5)

    def resize_image(self, image, multiply):
        B, _, H, W = image.shape
        new_h, new_w = math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply
        image = F.interpolate(image, (new_h, new_w), align_corners=True, mode="bilinear")
        return image

    def forward(self, pcl_query, pcl_latent, extra_info=None):
        out = self.cross_attn(pcl_query, pcl_latent)
        if self.decode_with_extra_info is not None:
            out_dict = {}
            out_dict["coarse"] = out
            if self.decode_with_extra_info["type"] == "dinov2p14_feat":
                out = self.cross_attn_color(out, extra_info["image_feats"])
                out_dict["fine"] = out
                return out_dict
            elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
                img_feat = self.encoder(extra_info["image"])
                out = self.cross_attn_color(out, img_feat)
                out_dict["fine"] = out
                return out_dict
            elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat":
                image = extra_info["image"]
                image = self.resize_image(image, multiply=32)
                img_feat = self.encoder(image)
                out = self.cross_attn_color(out, img_feat)
                out_dict["fine"] = out
                return out_dict
        return out
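
# When decode_with_extra_info is set, forward returns a dict with a "coarse" feature
# (cross-attended against the latent tokens only) and a "fine" feature that additionally
# attends to image features; downstream, GSLayer consumes the fine feature through its
# fine_shs head (see GS3DRenderer.forward_gs).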


class GS3DRenderer(nn.Module):
    def __init__(self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim,
                 use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config,
                 expr_param_dim, shape_param_dim,
                 clip_scaling=0.2,
                 scale_sphere=False,
                 skip_decoder=False,
                 fix_opacity=False,
                 fix_rotation=False,
                 decode_with_extra_info=None,
                 gradient_checkpointing=False,
                 add_teeth=True,
                 teeth_bs_flag=False,
                 oral_mesh_flag=False,
                 **kwargs,
                 ):
        super().__init__()
        print(f"#########scale sphere:{scale_sphere}, add_teeth:{add_teeth}")
        self.gradient_checkpointing = gradient_checkpointing
        self.skip_decoder = skip_decoder
        self.smpl_type = smpl_type
        assert self.smpl_type == "flame"
        self.sym_rend2 = True
        self.teeth_bs_flag = teeth_bs_flag
        self.oral_mesh_flag = oral_mesh_flag
        self.render_rgb = kwargs.get("render_rgb", True)
        print("===" * 16 * 3, "\n Render rgb:", self.render_rgb, "\n" + "===" * 16 * 3)
        self.scaling_modifier = 1.0
        self.sh_degree = sh_degree
        if use_rgb:
            self.sh_degree = 0
        self.use_rgb = use_rgb

        self.flame_model = FlameHeadSubdivided(
            300,
            100,
            add_teeth=add_teeth,
            add_shoulder=False,
            flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl',
            flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy",
            flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj",
            flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl",
            subdivide_num=subdivide_num,
            teeth_bs_flag=teeth_bs_flag,
            oral_mesh_flag=oral_mesh_flag,
        )

        if not self.skip_decoder:
            self.pcl_embed = PointEmbed(dim=query_dim)

        self.mlp_network_config = mlp_network_config
        if self.mlp_network_config is not None:
            self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config)

        init_scaling = -5.0
        self.gs_net = GSLayer(in_channels=query_dim,
                              use_rgb=use_rgb,
                              sh_degree=self.sh_degree,
                              clip_scaling=clip_scaling,
                              scale_sphere=scale_sphere,
                              init_scaling=init_scaling,
                              init_density=0.1,
                              xyz_offset=True,
                              restrict_offset=True,
                              xyz_offset_max_step=xyz_offset_max_step,
                              fix_opacity=fix_opacity,
                              fix_rotation=fix_rotation,
                              use_fine_feat=True if decode_with_extra_info is not None and decode_with_extra_info["type"] is not None else False,
                              )

    def forward_single_view(self,
                            gs: GaussianModel,
                            viewpoint_camera: Camera,
                            background_color: Optional[Float[Tensor, "3"]],
                            ):
        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
        try:
            screenspace_points.retain_grad()
        except Exception:
            pass

        bg_color = background_color
        # Set up rasterization configuration
        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        GSRSettings = GaussianRasterizationSettings
        GSR = GaussianRasterizer

        raster_settings = GSRSettings(
            image_height=int(viewpoint_camera.height),
            image_width=int(viewpoint_camera.width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=bg_color,
            scale_modifier=self.scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform.float(),
            sh_degree=self.sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False
        )
        rasterizer = GSR(raster_settings=raster_settings)

        means3D = gs.xyz
        means2D = screenspace_points
        opacity = gs.opacity

        # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
        # scaling / rotation by the rasterizer.
        scales = None
        rotations = None
        cov3D_precomp = None
        scales = gs.scaling
        rotations = gs.rotation

        # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
        # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
        shs = None
        colors_precomp = None
        if self.gs_net.use_rgb:
            colors_precomp = gs.shs.squeeze(1)
        else:
            shs = gs.shs

        # Rasterize visible Gaussians to image, obtain their radii (on screen).
        # torch.cuda.synchronize()
        # with boxx.timeit():
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            raster_ret = rasterizer(
                means3D=means3D.float(),
                means2D=means2D.float(),
                shs=shs.float() if not self.gs_net.use_rgb else None,
                colors_precomp=colors_precomp.float() if colors_precomp is not None else None,
                opacities=opacity.float(),
                scales=scales.float(),
                rotations=rotations.float(),
                cov3D_precomp=cov3D_precomp,
            )
        rendered_image, radii, rendered_depth, rendered_alpha = raster_ret

        ret = {
            "comp_rgb": rendered_image.permute(1, 2, 0),  # [H, W, 3]
            "comp_rgb_bg": bg_color,
            'comp_mask': rendered_alpha.permute(1, 2, 0),
            'comp_depth': rendered_depth.permute(1, 2, 0),
        }
        return ret
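
    # Note: the 4-tuple unpacking above (image, radii, depth, alpha) assumes the
    # diff_gaussian_rasterization_wda fork imported first at the top of this file; the stock
    # diff_gaussian_rasterization fallback typically returns only (image, radii), so the depth
    # and alpha outputs would need to be adapted if that fallback is actually used.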

    def animate_gs_model(self, gs_attr: GaussianModel, query_points, flame_data, debug=False):
        """
        query_points: [N, 3]
        """
        device = gs_attr.xyz.device
        if debug:
            N = gs_attr.xyz.shape[0]
            gs_attr.xyz = torch.ones_like(gs_attr.xyz) * 0.0

            rotation = matrix_to_quaternion(torch.eye(3).float()[None, :, :].repeat(N, 1, 1)).to(device)  # constant rotation
            opacity = torch.ones((N, 1)).float().to(device)  # constant opacity
            gs_attr.opacity = opacity
            gs_attr.rotation = rotation
            # gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05
            # print(gs_attr.shs.shape)

        with torch.autocast(device_type=device.type, dtype=torch.float32):
            # mean_3d = query_points + gs_attr.xyz  # [N, 3]
            mean_3d = gs_attr.xyz  # [N, 3]

            num_view = flame_data["expr"].shape[0]  # [Nv, 100]
            mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1)  # [Nv, N, 3]
            query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1)

            if self.teeth_bs_flag:
                expr = torch.cat([flame_data['expr'], flame_data['teeth_bs']], dim=-1)
            else:
                expr = flame_data["expr"]
            ret = self.flame_model.animation_forward(v_cano=mean_3d,
                                                     shape=flame_data["betas"].repeat(num_view, 1),
                                                     expr=expr,
                                                     rotation=flame_data["rotation"],
                                                     neck=flame_data["neck_pose"],
                                                     jaw=flame_data["jaw_pose"],
                                                     eyes=flame_data["eyes_pose"],
                                                     translation=flame_data["translation"],
                                                     zero_centered_at_root_node=False,
                                                     return_landmarks=False,
                                                     return_verts_cano=False,
                                                     # static_offset=flame_data['static_offset'].to('cuda'),
                                                     static_offset=None,
                                                     )
            mean_3d = ret["animated"]

        gs_attr_list = []
        for i in range(num_view):
            gs_attr_copy = GaussianModel(xyz=mean_3d[i],
                                         opacity=gs_attr.opacity,
                                         rotation=gs_attr.rotation,
                                         scaling=gs_attr.scaling,
                                         shs=gs_attr.shs,
                                         albedo=gs_attr.albedo,
                                         lights=gs_attr.lights,
                                         offset=gs_attr.offset)  # [N, 3]
            gs_attr_list.append(gs_attr_copy)

        return gs_attr_list
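
    # animate_gs_model keeps all Gaussian attributes in the canonical FLAME space and only
    # re-poses the means: the canonical xyz are pushed through FlameHeadSubdivided's expression /
    # pose animation once per target view, yielding one GaussianModel per view that shares
    # opacity, rotation, scaling and SH features.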

    def forward_gs_attr(self, x, query_points, flame_data, debug=False, x_fine=None, vtx_sym_idxs=None):
        """
        x: [N, C] Float[Tensor, "Np Cp"],
        query_points: [N, 3] Float[Tensor, "Np 3"]
        """
        device = x.device
        if self.mlp_network_config is not None:
            x = self.mlp_net(x)
            if x_fine is not None:
                x_fine = self.mlp_net(x_fine)
        gs_attr: GaussianModel = self.gs_net(x, query_points, x_fine, vtx_sym_idxs=vtx_sym_idxs)
        return gs_attr

    def get_query_points(self, flame_data, device):
        with torch.no_grad():
            with torch.autocast(device_type=device.type, dtype=torch.float32):
                # print(flame_data["betas"].shape, flame_data["face_offset"].shape, flame_data["joint_offset"].shape)
                # positions, _, transform_mat_neutral_pose = self.flame_model.get_query_points(flame_data, device=device)  # [B, N, 3]
                positions = self.flame_model.get_cano_verts(shape_params=flame_data["betas"])  # [B, N, 3]
                # print(f"positions shape:{positions.shape}")
        return positions, flame_data

    def query_latent_feat(self,
                          positions: Float[Tensor, "*B N1 3"],
                          flame_data,
                          latent_feat: Float[Tensor, "*B N2 C"],
                          extra_info):
        device = latent_feat.device
        if self.skip_decoder:
            gs_feats = latent_feat
            assert positions is not None
        else:
            assert positions is None
            if positions is None:
                positions, flame_data = self.get_query_points(flame_data, device)
            with torch.autocast(device_type=device.type, dtype=torch.float32):
                pcl_embed = self.pcl_embed(positions)
            gs_feats = pcl_embed
        return gs_feats, positions, flame_data

    def forward_single_batch(
        self,
        gs_list: list[GaussianModel],
        c2ws: Float[Tensor, "Nv 4 4"],
        intrinsics: Float[Tensor, "Nv 4 4"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "Nv 3"]],
        debug: bool = False,
    ):
        out_list = []
        self.device = gs_list[0].xyz.device

        for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)):
            out_list.append(self.forward_single_view(
                gs_list[v_idx],
                Camera.from_c2w(c2w, intrinsic, height, width),
                background_color[v_idx],
            ))

        out = defaultdict(list)
        for out_ in out_list:
            for k, v in out_.items():
                out[k].append(v)
        out = {k: torch.stack(v, dim=0) for k, v in out.items()}
        out["3dgs"] = gs_list

        return out

    def get_sing_batch_smpl_data(self, smpl_data, bidx):
        smpl_data_single_batch = {}
        for k, v in smpl_data.items():
            smpl_data_single_batch[k] = v[bidx]  # e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3]
            if k == "betas" or (k == "joint_offset") or (k == "face_offset"):
                smpl_data_single_batch[k] = v[bidx:bidx + 1]  # e.g. betas: [B, 100] -> [1, 100]
        return smpl_data_single_batch

    def get_single_view_smpl_data(self, smpl_data, vidx):
        smpl_data_single_view = {}
        for k, v in smpl_data.items():
            assert v.shape[0] == 1
            if k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose"):
                smpl_data_single_view[k] = v  # e.g. betas: [1, 100] -> [1, 100]
            else:
                smpl_data_single_view[k] = v[:, vidx: vidx + 1]  # e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3]
        return smpl_data_single_view

    def forward_gs(self,
                   gs_hidden_features: Float[Tensor, "B Np Cp"],
                   query_points: Float[Tensor, "B Np_q 3"],
                   flame_data,  # e.g., body_pose: [B, Nv, 21, 3], betas: [B, 100]
                   additional_features: Optional[dict] = None,
                   debug: bool = False,
                   **kwargs):
        batch_size = gs_hidden_features.shape[0]

        query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
                                                                             additional_features)

        gs_model_list = []
        all_query_points = []
        for b in range(batch_size):
            all_query_points.append(query_points[b:b + 1, :])
            if isinstance(query_gs_features, dict):
                ret_gs = self.forward_gs_attr(query_gs_features["coarse"][b], query_points[b], None, debug,
                                              x_fine=query_gs_features["fine"][b], vtx_sym_idxs=None)
            else:
                ret_gs = self.forward_gs_attr(query_gs_features[b], query_points[b], None, debug, vtx_sym_idxs=None)
            ret_gs.update_albedo(ret_gs.shs.clone())
            gs_model_list.append(ret_gs)

        query_points = torch.cat(all_query_points, dim=0)
        return gs_model_list, query_points, flame_data, query_gs_features

    def forward_res_refine_gs(self,
                              gs_hidden_features: Float[Tensor, "B Np Cp"],
                              query_points: Float[Tensor, "B Np_q 3"],
                              flame_data,  # e.g., body_pose: [B, Nv, 21, 3], betas: [B, 100]
                              additional_features: Optional[dict] = None,
                              debug: bool = False,
                              gs_raw_attr_list: list = None,
                              **kwargs):
        batch_size = gs_hidden_features.shape[0]

        query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
                                                                             additional_features)

        gs_model_list = []
        for b in range(batch_size):
            gs_model = self.gs_refine_net(query_gs_features[b], query_points[b], x_fine=None, gs_raw_attr=gs_raw_attr_list[b])
            gs_model_list.append(gs_model)
        return gs_model_list, query_points, flame_data, query_gs_features
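
    # forward_res_refine_gs assumes a residual refinement head (self.gs_refine_net, presumably a
    # GSLayer constructed with pred_res=True) is attached elsewhere; it is not created in this
    # module's __init__.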

    def forward_animate_gs(self, gs_model_list, query_points, flame_data, c2w, intrinsic, height, width,
                           background_color, debug=False):
        batch_size = len(gs_model_list)
        out_list = []

        for b in range(batch_size):
            gs_model = gs_model_list[b]
            query_pt = query_points[b]
            animatable_gs_model_list: list[GaussianModel] = self.animate_gs_model(gs_model,
                                                                                  query_pt,
                                                                                  self.get_sing_batch_smpl_data(flame_data, b),
                                                                                  debug=debug)
            assert len(animatable_gs_model_list) == c2w.shape[1]
            out_list.append(self.forward_single_batch(
                animatable_gs_model_list,
                c2w[b],
                intrinsic[b],
                height, width,
                background_color[b] if background_color is not None else None,
                debug=debug))

        out = defaultdict(list)
        for out_ in out_list:
            for k, v in out_.items():
                out[k].append(v)
        for k, v in out.items():
            if isinstance(v[0], torch.Tensor):
                out[k] = torch.stack(v, dim=0)
            else:
                out[k] = v

        render_keys = ["comp_rgb", "comp_mask", "comp_depth"]
        for key in render_keys:
            out[key] = rearrange(out[key], "b v h w c -> b v c h w")

        return out

    def project_single_view_feats(self, img_vtx_ids, feats, nv, inter_feat=True):
        b, h, w, k = img_vtx_ids.shape
        c, ih, iw = feats.shape
        vtx_ids = img_vtx_ids
        if h != ih or w != iw:
            if inter_feat:
                feats = torch.nn.functional.interpolate(
                    rearrange(feats, "(b c) h w -> b c h w", b=1).float(), (h, w)
                ).squeeze(0)
                vtx_ids = rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).long().squeeze(1)
            else:
                vtx_ids = torch.nn.functional.interpolate(
                    rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).float(), (ih, iw), mode="nearest"
                ).long().squeeze(1)
        else:
            vtx_ids = rearrange(vtx_ids, "b h w k -> (b k) h w", b=1).long()

        vis_mask = vtx_ids > 0
        vtx_ids = vtx_ids[vis_mask]  # n
        vtx_ids = repeat(vtx_ids, "n -> n c", c=c)

        feats = repeat(feats, "c h w -> k h w c", k=k).to(vtx_ids.device)
        feats = feats[vis_mask, :]  # n, c

        sums = torch.zeros((nv, c), dtype=feats.dtype, device=feats.device)
        counts = torch.zeros((nv), dtype=torch.int64, device=feats.device)
        sums.scatter_add_(0, vtx_ids, feats)

        one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(feats.device)
        counts.scatter_add_(0, vtx_ids[:, 0], one_hot)
        clamp_counts = counts.clamp(min=1)
        mean_feats = sums / clamp_counts.view(-1, 1)
        return mean_feats
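
    # project_single_view_feats is a scatter-mean: every visible pixel contributes its feature
    # vector to the vertex id it was rasterized from (vtx_ids), per-vertex sums and counts are
    # accumulated with scatter_add_, and the mean over the nv vertices is returned. Vertex id 0
    # is treated as background / not visible and is ignored via vis_mask.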

    def forward(self,
                gs_hidden_features: Float[Tensor, "B Np Cp"],
                query_points: Float[Tensor, "B Np 3"],
                flame_data,  # e.g., body_pose: [B, Nv, 21, 3], betas: [B, 100]
                c2w: Float[Tensor, "B Nv 4 4"],
                intrinsic: Float[Tensor, "B Nv 4 4"],
                height,
                width,
                additional_features: Optional[Float[Tensor, "B C H W"]] = None,
                background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
                debug: bool = False,
                **kwargs):
        # needs the shape params in flame_data to get the query points and "transform_mat_neutral_pose"
        gs_model_list, query_points, flame_data, query_gs_features = self.forward_gs(gs_hidden_features, query_points, flame_data=flame_data,
                                                                                     additional_features=additional_features, debug=debug)

        out = self.forward_animate_gs(gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug)
        return out


def test_head():
    import cv2
    human_model_path = "./pretrained_models/human_model_files"
    device = "cuda"

    from accelerate.utils import set_seed
    set_seed(1234)

    from lam.datasets.video_head import VideoHeadDataset
    root_dir = "./train_data/vfhq_vhap/export"
    meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
    # root_dir = "./train_data/nersemble/export"
    # meta_path = "./train_data/nersemble/label/valid_id_list1.json"
    dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7,
                               render_image_res_low=512, render_image_res_high=512,
                               render_region_size=(512, 512), source_image_res=512,
                               enlarge_ratio=[0.8, 1.2],
                               debug=False)
    data = dataset[0]

    def get_flame_params(data):
        flame_params = {}
        flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose',
                      'expr', 'trans', 'betas', 'rotation', 'neck_pose', 'eyes_pose', 'translation']
        for k, v in data.items():
            if k in flame_keys:
                # print(k, v.shape)
                flame_params[k] = data[k]
        return flame_params

    flame_data = get_flame_params(data)

    flame_data_tmp = {}
    for k, v in flame_data.items():
        flame_data_tmp[k] = v.unsqueeze(0).to(device)
        print(k, v.shape)
    flame_data = flame_data_tmp

    c2ws = data["c2ws"].unsqueeze(0).to(device)
    intrs = data["intrs"].unsqueeze(0).to(device)
    render_images = data["render_image"].numpy()
    render_h = data["render_full_resolutions"][0, 0]
    render_w = data["render_full_resolutions"][0, 1]
    render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device)
    print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs)

    gs_render = GS3DRenderer(human_model_path=human_model_path, subdivide_num=2, smpl_type="flame",
                             feat_dim=64, query_dim=64, use_rgb=True, sh_degree=3, mlp_network_config=None,
                             xyz_offset_max_step=0.0001, expr_param_dim=10, shape_param_dim=10,
                             fix_opacity=True, fix_rotation=True, clip_scaling=0.001, add_teeth=False)
    gs_render.to(device)

    out = gs_render.forward(gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device),
                            query_points=None,
                            flame_data=flame_data,
                            c2w=c2ws,
                            intrinsic=intrs,
                            height=render_h,
                            width=render_w,
                            background_color=render_bg_colors,
                            debug=False)

    os.makedirs("./debug_vis/gs_render", exist_ok=True)
    for k, v in out.items():
        if k == "comp_rgb_bg":
            print("comp_rgb_bg", v)
            continue
        for b_idx in range(len(v)):
            if k == "3dgs":
                for v_idx in range(len(v[b_idx])):
                    v[b_idx][v_idx].save_ply(f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply")
                continue
            for v_idx in range(v.shape[1]):
                save_path = os.path.join("./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg")
                if "normal" in k:
                    img = ((v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2. * 255).astype(np.uint8)
                else:
                    img = (v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
                print(v[b_idx, v_idx].shape, img.shape, save_path)
                if "mask" in k:
                    render_img = render_images[v_idx].transpose(1, 2, 0) * 255
                    blend_img = (render_images[v_idx].transpose(1, 2, 0) * 255 * 0.5 + np.tile(img, (1, 1, 3)) * 0.5).clip(0, 255).astype(np.uint8)
                    cv2.imwrite(save_path, np.hstack([np.tile(img, (1, 1, 3)), render_img.astype(np.uint8), blend_img])[:, :, (2, 1, 0)])
                else:
                    print(save_path, k)
                    cv2.imwrite(save_path, img)


if __name__ == "__main__":
    test_head()