import os from dataclasses import dataclass, field from collections import defaultdict try: from diff_gaussian_rasterization_wda import GaussianRasterizationSettings, GaussianRasterizer except: from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer from plyfile import PlyData, PlyElement import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math import copy from diffusers.utils import is_torch_version from lam.models.rendering.flame_model.flame import FlameHeadSubdivided from lam.models.transformer import TransformerDecoder from pytorch3d.transforms import matrix_to_quaternion from lam.models.rendering.utils.typing import * from lam.models.rendering.utils.utils import trunc_exp, MLP from lam.models.rendering.gaussian_model import GaussianModel from einops import rearrange, repeat from pytorch3d.ops.points_normals import estimate_pointcloud_normals os.environ["PYOPENGL_PLATFORM"] = "egl" from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.renderer import ( AmbientLights, PerspectiveCameras, SoftSilhouetteShader, SoftPhongShader, RasterizationSettings, MeshRenderer, MeshRendererWithFragments, MeshRasterizer, TexturesVertex, ) from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend import lam.models.rendering.utils.mesh_utils as mesh_utils from lam.models.rendering.utils.point_utils import depth_to_normal from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes inverse_sigmoid = lambda x: np.log(x / (1 - x)) def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): Rt = np.zeros((4, 4)) Rt[:3, :3] = R.transpose() Rt[:3, 3] = t Rt[3, 3] = 1.0 C2W = np.linalg.inv(Rt) cam_center = C2W[:3, 3] cam_center = (cam_center + translate) * scale C2W[:3, 3] = cam_center Rt = np.linalg.inv(C2W) return np.float32(Rt) def getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) top = tanHalfFovY * znear bottom = -top right = tanHalfFovX * znear left = -right P = torch.zeros(4, 4) z_sign = 1.0 P[0, 0] = 2.0 * znear / (right - left) P[1, 1] = 2.0 * znear / (top - bottom) P[0, 2] = (right + left) / (right - left) P[1, 2] = (top + bottom) / (top - bottom) P[3, 2] = z_sign P[2, 2] = z_sign * zfar / (zfar - znear) P[2, 3] = -(zfar * znear) / (zfar - znear) return P def intrinsic_to_fov(intrinsic, w, h): fx, fy = intrinsic[0, 0], intrinsic[1, 1] fov_x = 2 * torch.arctan2(w, 2 * fx) fov_y = 2 * torch.arctan2(h, 2 * fy) return fov_x, fov_y class Camera: def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None: self.FoVx = FoVx self.FoVy = FoVy self.height = int(height) self.width = int(width) self.world_view_transform = w2c.transpose(0, 1) self.intrinsic = intrinsic self.zfar = 100.0 self.znear = 0.01 self.trans = trans self.scale = scale self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device) self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] @staticmethod def from_c2w(c2w, intrinsic, height, width): w2c = torch.inverse(c2w) FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device)) return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width) class GSLayer(nn.Module): def __init__(self, in_channels, use_rgb, clip_scaling=0.2, init_scaling=-5.0, scale_sphere=False, init_density=0.1, sh_degree=None, xyz_offset=True, restrict_offset=True, xyz_offset_max_step=None, fix_opacity=False, fix_rotation=False, use_fine_feat=False, pred_res=False, ): super().__init__() self.clip_scaling = clip_scaling self.use_rgb = use_rgb self.restrict_offset = restrict_offset self.xyz_offset = xyz_offset self.xyz_offset_max_step = xyz_offset_max_step # 1.2 / 32 self.fix_opacity = fix_opacity self.fix_rotation = fix_rotation self.use_fine_feat = use_fine_feat self.scale_sphere = scale_sphere self.pred_res = pred_res self.attr_dict ={ "shs": (sh_degree + 1) ** 2 * 3, "scaling": 3 if not scale_sphere else 1, "xyz": 3, "opacity": None, "rotation": None } if not self.fix_opacity: self.attr_dict["opacity"] = 1 if not self.fix_rotation: self.attr_dict["rotation"] = 4 self.out_layers = nn.ModuleDict() for key, out_ch in self.attr_dict.items(): if out_ch is None: layer = nn.Identity() else: if key == "shs" and use_rgb: out_ch = 3 if key == "shs": shs_out_ch = out_ch if pred_res: layer = nn.Linear(in_channels+out_ch, out_ch) else: layer = nn.Linear(in_channels, out_ch) # initialize if not (key == "shs" and use_rgb): if key == "opacity" and self.fix_opacity: pass elif key == "rotation" and self.fix_rotation: pass else: nn.init.constant_(layer.weight, 0) nn.init.constant_(layer.bias, 0) if key == "scaling": nn.init.constant_(layer.bias, init_scaling) elif key == "rotation": if not self.fix_rotation: nn.init.constant_(layer.bias, 0) nn.init.constant_(layer.bias[0], 1.0) elif key == "opacity": if not self.fix_opacity: nn.init.constant_(layer.bias, inverse_sigmoid(init_density)) self.out_layers[key] = layer if self.use_fine_feat: fine_shs_layer = nn.Linear(in_channels, shs_out_ch) nn.init.constant_(fine_shs_layer.weight, 0) nn.init.constant_(fine_shs_layer.bias, 0) self.out_layers["fine_shs"] = fine_shs_layer def forward(self, x, pts, x_fine=None, gs_raw_attr=None, ret_raw=False, vtx_sym_idxs=None): assert len(x.shape) == 2 ret = {} if ret_raw: raw_attr = {} ori_x = x for k in self.attr_dict: # if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity"]: if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity", "rotation"]: # print("==="*16*3, "\n\n\n"+"use sym mean.", "\n"+"==="*16*3) # x = (x + x[vtx_sym_idxs.to(x.device), :]) / 2. x = ori_x[vtx_sym_idxs.to(x.device), :] else: x = ori_x layer =self.out_layers[k] if self.pred_res and (not self.fix_opacity or k != "opacity") and (not self.fix_rotation or k != "rotation"): v = layer(torch.cat([gs_raw_attr[k], x], dim=-1)) v = gs_raw_attr[k] + v else: v = layer(x) if ret_raw: raw_attr[k] = v if k == "rotation": if self.fix_rotation: v = matrix_to_quaternion(torch.eye(3).type_as(x)[None,: , :].repeat(x.shape[0], 1, 1)) # constant rotation else: # assert len(x.shape) == 2 v = torch.nn.functional.normalize(v) elif k == "scaling": v = trunc_exp(v) if self.scale_sphere: assert v.shape[-1] == 1 v = torch.cat([v, v, v], dim=-1) if self.clip_scaling is not None: v = torch.clamp(v, min=0, max=self.clip_scaling) elif k == "opacity": if self.fix_opacity: v = torch.ones_like(x)[..., 0:1] else: v = torch.sigmoid(v) elif k == "shs": if self.use_rgb: v[..., :3] = torch.sigmoid(v[..., :3]) if self.use_fine_feat: v_fine = self.out_layers["fine_shs"](x_fine) v_fine = torch.tanh(v_fine) v = v + v_fine else: if self.use_fine_feat: v_fine = self.out_layers["fine_shs"](x_fine) v = v + v_fine v = torch.reshape(v, (v.shape[0], -1, 3)) elif k == "xyz": # TODO check if self.restrict_offset: max_step = self.xyz_offset_max_step v = (torch.sigmoid(v) - 0.5) * max_step if self.xyz_offset: pass else: assert NotImplementedError ret["offset"] = v v = pts + v ret[k] = v if ret_raw: return GaussianModel(**ret), raw_attr else: return GaussianModel(**ret) class PointEmbed(nn.Module): def __init__(self, hidden_dim=48, dim=128): super().__init__() assert hidden_dim % 6 == 0 self.embedding_dim = hidden_dim e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi e = torch.stack([ torch.cat([e, torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6)]), torch.cat([torch.zeros(self.embedding_dim // 6), e, torch.zeros(self.embedding_dim // 6)]), torch.cat([torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6), e]), ]) self.register_buffer('basis', e) # 3 x 16 self.mlp = nn.Linear(self.embedding_dim+3, dim) self.norm = nn.LayerNorm(dim) @staticmethod def embed(input, basis): projections = torch.einsum( 'bnd,de->bne', input, basis) embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) return embeddings def forward(self, input): # input: B x N x 3 embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C embed = self.norm(embed) return embed class CrossAttnBlock(nn.Module): """ Transformer block that takes in a cross-attention condition. Designed for SparseLRM architecture. """ # Block contains a cross-attention layer, a self-attention layer, and an MLP def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float=None, attn_drop: float = 0., attn_bias: bool = False, mlp_ratio: float = 4., mlp_drop: float = 0., feedforward=False): super().__init__() # TODO check already apply normalization # self.norm_q = nn.LayerNorm(inner_dim, eps=eps) # self.norm_k = nn.LayerNorm(cond_dim, eps=eps) self.norm_q = nn.Identity() self.norm_k = nn.Identity() self.cross_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, dropout=attn_drop, bias=attn_bias, batch_first=True) self.mlp = None if feedforward: self.norm2 = nn.LayerNorm(inner_dim, eps=eps) self.self_attn = nn.MultiheadAttention( embed_dim=inner_dim, num_heads=num_heads, dropout=attn_drop, bias=attn_bias, batch_first=True) self.norm3 = nn.LayerNorm(inner_dim, eps=eps) self.mlp = nn.Sequential( nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), nn.GELU(), nn.Dropout(mlp_drop), nn.Linear(int(inner_dim * mlp_ratio), inner_dim), nn.Dropout(mlp_drop), ) def forward(self, x, cond): # x: [N, L, D] # cond: [N, L_cond, D_cond] x = self.cross_attn(self.norm_q(x), self.norm_k(cond), cond, need_weights=False)[0] if self.mlp is not None: before_sa = self.norm2(x) x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] x = x + self.mlp(self.norm3(x)) return x class DecoderCrossAttn(nn.Module): def __init__(self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None): super().__init__() self.query_dim = query_dim self.context_dim = context_dim self.cross_attn = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim, num_heads=num_heads, feedforward=mlp, eps=1e-5) self.decode_with_extra_info = decode_with_extra_info if decode_with_extra_info is not None: if decode_with_extra_info["type"] == "dinov2p14_feat": context_dim = decode_with_extra_info["cond_dim"] self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim, num_heads=num_heads, feedforward=False, eps=1e-5) elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat": from lam.models.encoders.dinov2_wrapper import Dinov2Wrapper self.encoder = Dinov2Wrapper(model_name='dinov2_vits14_reg', freeze=False, encoder_feat_dim=384) self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=384, num_heads=num_heads, feedforward=False, eps=1e-5) elif decode_with_extra_info["type"] == "decoder_resnet18_feat": from lam.models.encoders.xunet_wrapper import XnetWrapper self.encoder = XnetWrapper(model_name='resnet18', freeze=False, encoder_feat_dim=64) self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=64, num_heads=num_heads, feedforward=False, eps=1e-5) def resize_image(self, image, multiply): B, _, H, W = image.shape new_h, new_w = math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply image = F.interpolate(image, (new_h, new_w), align_corners=True, mode="bilinear") return image def forward(self, pcl_query, pcl_latent, extra_info=None): out = self.cross_attn(pcl_query, pcl_latent) if self.decode_with_extra_info is not None: out_dict = {} out_dict["coarse"] = out if self.decode_with_extra_info["type"] == "dinov2p14_feat": out = self.cross_attn_color(out, extra_info["image_feats"]) out_dict["fine"] = out return out_dict elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat": img_feat = self.encoder(extra_info["image"]) out = self.cross_attn_color(out, img_feat) out_dict["fine"] = out return out_dict elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat": image = extra_info["image"] image = self.resize_image(image, multiply=32) img_feat = self.encoder(image) out = self.cross_attn_color(out, img_feat) out_dict["fine"] = out return out_dict return out class GS3DRenderer(nn.Module): def __init__(self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim, use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config, expr_param_dim, shape_param_dim, clip_scaling=0.2, scale_sphere=False, skip_decoder=False, fix_opacity=False, fix_rotation=False, decode_with_extra_info=None, gradient_checkpointing=False, add_teeth=True, teeth_bs_flag=False, oral_mesh_flag=False, **kwargs, ): super().__init__() print(f"#########scale sphere:{scale_sphere}, add_teeth:{add_teeth}") self.gradient_checkpointing = gradient_checkpointing self.skip_decoder = skip_decoder self.smpl_type = smpl_type assert self.smpl_type == "flame" self.sym_rend2 = True self.teeth_bs_flag = teeth_bs_flag self.oral_mesh_flag = oral_mesh_flag self.render_rgb = kwargs.get("render_rgb", True) print("==="*16*3, "\n Render rgb:", self.render_rgb, "\n"+"==="*16*3) self.scaling_modifier = 1.0 self.sh_degree = sh_degree if use_rgb: self.sh_degree = 0 use_rgb = use_rgb self.flame_model = FlameHeadSubdivided( 300, 100, add_teeth=add_teeth, add_shoulder=False, flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl', flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy", flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj", flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl", subdivide_num=subdivide_num, teeth_bs_flag=teeth_bs_flag, oral_mesh_flag=oral_mesh_flag ) if not self.skip_decoder: self.pcl_embed = PointEmbed(dim=query_dim) self.mlp_network_config = mlp_network_config if self.mlp_network_config is not None: self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config) init_scaling = -5.0 self.gs_net = GSLayer(in_channels=query_dim, use_rgb=use_rgb, sh_degree=self.sh_degree, clip_scaling=clip_scaling, scale_sphere=scale_sphere, init_scaling=init_scaling, init_density=0.1, xyz_offset=True, restrict_offset=True, xyz_offset_max_step=xyz_offset_max_step, fix_opacity=fix_opacity, fix_rotation=fix_rotation, use_fine_feat=True if decode_with_extra_info is not None and decode_with_extra_info["type"] is not None else False, ) def forward_single_view(self, gs: GaussianModel, viewpoint_camera: Camera, background_color: Optional[Float[Tensor, "3"]], ): # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0 try: screenspace_points.retain_grad() except: pass bg_color = background_color # Set up rasterization configuration tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) GSRSettings = GaussianRasterizationSettings GSR = GaussianRasterizer raster_settings = GSRSettings( image_height=int(viewpoint_camera.height), image_width=int(viewpoint_camera.width), tanfovx=tanfovx, tanfovy=tanfovy, bg=bg_color, scale_modifier=self.scaling_modifier, viewmatrix=viewpoint_camera.world_view_transform, projmatrix=viewpoint_camera.full_proj_transform.float(), sh_degree=self.sh_degree, campos=viewpoint_camera.camera_center, prefiltered=False, debug=False ) rasterizer = GSR(raster_settings=raster_settings) means3D = gs.xyz means2D = screenspace_points opacity = gs.opacity # If precomputed 3d covariance is provided, use it. If not, then it will be computed from # scaling / rotation by the rasterizer. scales = None rotations = None cov3D_precomp = None scales = gs.scaling rotations = gs.rotation # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. shs = None colors_precomp = None if self.gs_net.use_rgb: colors_precomp = gs.shs.squeeze(1) else: shs = gs.shs # Rasterize visible Gaussians to image, obtain their radii (on screen). # torch.cuda.synchronize() # with boxx.timeit(): with torch.autocast(device_type=self.device.type, dtype=torch.float32): raster_ret = rasterizer( means3D = means3D.float(), means2D = means2D.float(), shs = shs.float() if not self.gs_net.use_rgb else None, colors_precomp = colors_precomp.float() if colors_precomp is not None else None, opacities = opacity.float(), scales = scales.float(), rotations = rotations.float(), cov3D_precomp = cov3D_precomp ) rendered_image, radii, rendered_depth, rendered_alpha = raster_ret ret = { "comp_rgb": rendered_image.permute(1, 2, 0), # [H, W, 3] "comp_rgb_bg": bg_color, 'comp_mask': rendered_alpha.permute(1, 2, 0), 'comp_depth': rendered_depth.permute(1, 2, 0), } return ret def animate_gs_model(self, gs_attr: GaussianModel, query_points, flame_data, debug=False): """ query_points: [N, 3] """ device = gs_attr.xyz.device if debug: N = gs_attr.xyz.shape[0] gs_attr.xyz = torch.ones_like(gs_attr.xyz) * 0.0 rotation = matrix_to_quaternion(torch.eye(3).float()[None, :, :].repeat(N, 1, 1)).to(device) # constant rotation opacity = torch.ones((N, 1)).float().to(device) # constant opacity gs_attr.opacity = opacity gs_attr.rotation = rotation # gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05 # print(gs_attr.shs.shape) with torch.autocast(device_type=device.type, dtype=torch.float32): # mean_3d = query_points + gs_attr.xyz # [N, 3] mean_3d = gs_attr.xyz # [N, 3] num_view = flame_data["expr"].shape[0] # [Nv, 100] mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) # [Nv, N, 3] query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1) if self.teeth_bs_flag: expr = torch.cat([flame_data['expr'], flame_data['teeth_bs']], dim=-1) else: expr = flame_data["expr"] ret = self.flame_model.animation_forward(v_cano=mean_3d, shape=flame_data["betas"].repeat(num_view, 1), expr=expr, rotation=flame_data["rotation"], neck=flame_data["neck_pose"], jaw=flame_data["jaw_pose"], eyes=flame_data["eyes_pose"], translation=flame_data["translation"], zero_centered_at_root_node=False, return_landmarks=False, return_verts_cano=False, # static_offset=flame_data['static_offset'].to('cuda'), static_offset=None, ) mean_3d = ret["animated"] gs_attr_list = [] for i in range(num_view): gs_attr_copy = GaussianModel(xyz=mean_3d[i], opacity=gs_attr.opacity, rotation=gs_attr.rotation, scaling=gs_attr.scaling, shs=gs_attr.shs, albedo=gs_attr.albedo, lights=gs_attr.lights, offset=gs_attr.offset) # [N, 3] gs_attr_list.append(gs_attr_copy) return gs_attr_list def forward_gs_attr(self, x, query_points, flame_data, debug=False, x_fine=None, vtx_sym_idxs=None): """ x: [N, C] Float[Tensor, "Np Cp"], query_points: [N, 3] Float[Tensor, "Np 3"] """ device = x.device if self.mlp_network_config is not None: x = self.mlp_net(x) if x_fine is not None: x_fine = self.mlp_net(x_fine) gs_attr: GaussianModel = self.gs_net(x, query_points, x_fine, vtx_sym_idxs=vtx_sym_idxs) return gs_attr def get_query_points(self, flame_data, device): with torch.no_grad(): with torch.autocast(device_type=device.type, dtype=torch.float32): # print(flame_data["betas"].shape, flame_data["face_offset"].shape, flame_data["joint_offset"].shape) # positions, _, transform_mat_neutral_pose = self.flame_model.get_query_points(flame_data, device=device) # [B, N, 3] positions = self.flame_model.get_cano_verts(shape_params=flame_data["betas"]) # [B, N, 3] # print(f"positions shape:{positions.shape}") return positions, flame_data def query_latent_feat(self, positions: Float[Tensor, "*B N1 3"], flame_data, latent_feat: Float[Tensor, "*B N2 C"], extra_info): device = latent_feat.device if self.skip_decoder: gs_feats = latent_feat assert positions is not None else: assert positions is None if positions is None: positions, flame_data = self.get_query_points(flame_data, device) with torch.autocast(device_type=device.type, dtype=torch.float32): pcl_embed = self.pcl_embed(positions) gs_feats = pcl_embed return gs_feats, positions, flame_data def forward_single_batch( self, gs_list: list[GaussianModel], c2ws: Float[Tensor, "Nv 4 4"], intrinsics: Float[Tensor, "Nv 4 4"], height: int, width: int, background_color: Optional[Float[Tensor, "Nv 3"]], debug: bool=False, ): out_list = [] self.device = gs_list[0].xyz.device for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)): out_list.append(self.forward_single_view( gs_list[v_idx], Camera.from_c2w(c2w, intrinsic, height, width), background_color[v_idx], )) out = defaultdict(list) for out_ in out_list: for k, v in out_.items(): out[k].append(v) out = {k: torch.stack(v, dim=0) for k, v in out.items()} out["3dgs"] = gs_list return out def get_sing_batch_smpl_data(self, smpl_data, bidx): smpl_data_single_batch = {} for k, v in smpl_data.items(): smpl_data_single_batch[k] = v[bidx] # e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3] if k == "betas" or (k == "joint_offset") or (k == "face_offset"): smpl_data_single_batch[k] = v[bidx:bidx+1] # e.g. betas: [B, 100] -> [1, 100] return smpl_data_single_batch def get_single_view_smpl_data(self, smpl_data, vidx): smpl_data_single_view = {} for k, v in smpl_data.items(): assert v.shape[0] == 1 if k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose"): smpl_data_single_view[k] = v # e.g. betas: [1, 100] -> [1, 100] else: smpl_data_single_view[k] = v[:, vidx: vidx + 1] # e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3] return smpl_data_single_view def forward_gs(self, gs_hidden_features: Float[Tensor, "B Np Cp"], query_points: Float[Tensor, "B Np_q 3"], flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100] additional_features: Optional[dict] = None, debug: bool = False, **kwargs): batch_size = gs_hidden_features.shape[0] query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features, additional_features) gs_model_list = [] all_query_points = [] for b in range(batch_size): all_query_points.append(query_points[b:b+1, :]) if isinstance(query_gs_features, dict): ret_gs = self.forward_gs_attr(query_gs_features["coarse"][b], query_points[b], None, debug, x_fine=query_gs_features["fine"][b], vtx_sym_idxs=None) else: ret_gs = self.forward_gs_attr(query_gs_features[b], query_points[b], None, debug, vtx_sym_idxs=None) ret_gs.update_albedo(ret_gs.shs.clone()) gs_model_list.append(ret_gs) query_points = torch.cat(all_query_points, dim=0) return gs_model_list, query_points, flame_data, query_gs_features def forward_res_refine_gs(self, gs_hidden_features: Float[Tensor, "B Np Cp"], query_points: Float[Tensor, "B Np_q 3"], flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100] additional_features: Optional[dict] = None, debug: bool = False, gs_raw_attr_list: list = None, **kwargs): batch_size = gs_hidden_features.shape[0] query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features, additional_features) gs_model_list = [] for b in range(batch_size): gs_model = self.gs_refine_net(query_gs_features[b], query_points[b], x_fine=None, gs_raw_attr=gs_raw_attr_list[b]) gs_model_list.append(gs_model) return gs_model_list, query_points, flame_data, query_gs_features def forward_animate_gs(self, gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug=False): batch_size = len(gs_model_list) out_list = [] for b in range(batch_size): gs_model = gs_model_list[b] query_pt = query_points[b] animatable_gs_model_list: list[GaussianModel] = self.animate_gs_model(gs_model, query_pt, self.get_sing_batch_smpl_data(flame_data, b), debug=debug) assert len(animatable_gs_model_list) == c2w.shape[1] out_list.append(self.forward_single_batch( animatable_gs_model_list, c2w[b], intrinsic[b], height, width, background_color[b] if background_color is not None else None, debug=debug)) out = defaultdict(list) for out_ in out_list: for k, v in out_.items(): out[k].append(v) for k, v in out.items(): if isinstance(v[0], torch.Tensor): out[k] = torch.stack(v, dim=0) else: out[k] = v render_keys = ["comp_rgb", "comp_mask", "comp_depth"] for key in render_keys: out[key] = rearrange(out[key], "b v h w c -> b v c h w") return out def project_single_view_feats(self, img_vtx_ids, feats, nv, inter_feat=True): b, h, w, k = img_vtx_ids.shape c, ih, iw = feats.shape vtx_ids = img_vtx_ids if h != ih or w != iw: if inter_feat: feats = torch.nn.functional.interpolate( rearrange(feats, "(b c) h w -> b c h w", b=1).float(), (h, w) ).squeeze(0) vtx_ids = rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).long().squeeze(1) else: vtx_ids = torch.nn.functional.interpolate( rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).float(), (ih, iw), mode="nearest" ).long().squeeze(1) else: vtx_ids = rearrange(vtx_ids, "b h w k -> (b k) h w", b=1).long() vis_mask = vtx_ids > 0 vtx_ids = vtx_ids[vis_mask] # n vtx_ids = repeat(vtx_ids, "n -> n c", c=c) feats = repeat(feats, "c h w -> k h w c", k=k).to(vtx_ids.device) feats = feats[vis_mask, :] # n, c sums = torch.zeros((nv, c), dtype=feats.dtype, device=feats.device) counts = torch.zeros((nv), dtype=torch.int64, device=feats.device) sums.scatter_add_(0, vtx_ids, feats) one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(feats.device) counts.scatter_add_(0, vtx_ids[:, 0], one_hot) clamp_counts = counts.clamp(min=1) mean_feats = sums / clamp_counts.view(-1, 1) return mean_feats def forward(self, gs_hidden_features: Float[Tensor, "B Np Cp"], query_points: Float[Tensor, "B Np 3"], flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100] c2w: Float[Tensor, "B Nv 4 4"], intrinsic: Float[Tensor, "B Nv 4 4"], height, width, additional_features: Optional[Float[Tensor, "B C H W"]] = None, background_color: Optional[Float[Tensor, "B Nv 3"]] = None, debug: bool = False, **kwargs): # need shape_params of flame_data to get querty points and get "transform_mat_neutral_pose" gs_model_list, query_points, flame_data, query_gs_features = self.forward_gs(gs_hidden_features, query_points, flame_data=flame_data, additional_features=additional_features, debug=debug) out = self.forward_animate_gs(gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug) return out def test_head(): import cv2 human_model_path = "./pretrained_models/human_model_files" device = "cuda" from accelerate.utils import set_seed set_seed(1234) from lam.datasets.video_head import VideoHeadDataset root_dir = "./train_data/vfhq_vhap/export" meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json" # root_dir = "./train_data/nersemble/export" # meta_path = "./train_data/nersemble/label/valid_id_list1.json" dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7, render_image_res_low=512, render_image_res_high=512, render_region_size=(512, 512), source_image_res=512, enlarge_ratio=[0.8, 1.2], debug=False) data = dataset[0] def get_flame_params(data): flame_params = {} flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',\ 'rotation', 'neck_pose', 'eyes_pose', 'translation'] for k, v in data.items(): if k in flame_keys: # print(k, v.shape) flame_params[k] = data[k] return flame_params flame_data = get_flame_params(data) flame_data_tmp = {} for k, v in flame_data.items(): flame_data_tmp[k] = v.unsqueeze(0).to(device) print(k, v.shape) flame_data = flame_data_tmp c2ws = data["c2ws"].unsqueeze(0).to(device) intrs = data["intrs"].unsqueeze(0).to(device) render_images = data["render_image"].numpy() render_h = data["render_full_resolutions"][0, 0] render_w= data["render_full_resolutions"][0, 1] render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device) print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs) gs_render = GS3DRenderer(human_model_path=human_model_path, subdivide_num=2, smpl_type="flame", feat_dim=64, query_dim=64, use_rgb=True, sh_degree=3, mlp_network_config=None, xyz_offset_max_step=0.0001, expr_param_dim=10, shape_param_dim=10, fix_opacity=True, fix_rotation=True, clip_scaling=0.001, add_teeth=False) gs_render.to(device) out = gs_render.forward(gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device), query_points=None, flame_data=flame_data, c2w=c2ws, intrinsic=intrs, height=render_h, width=render_w, background_color=render_bg_colors, debug=False) os.makedirs("./debug_vis/gs_render", exist_ok=True) for k, v in out.items(): if k == "comp_rgb_bg": print("comp_rgb_bg", v) continue for b_idx in range(len(v)): if k == "3dgs": for v_idx in range(len(v[b_idx])): v[b_idx][v_idx].save_ply(f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply") continue for v_idx in range(v.shape[1]): save_path = os.path.join("./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg") if "normal" in k: img = ((v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2. * 255).astype(np.uint8) else: img = (v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) print(v[b_idx, v_idx].shape, img.shape, save_path) if "mask" in k: render_img = render_images[v_idx].transpose(1, 2, 0) * 255 blend_img = (render_images[v_idx].transpose(1, 2, 0) * 255 * 0.5 + np.tile(img, (1, 1, 3)) * 0.5).clip(0, 255).astype(np.uint8) cv2.imwrite(save_path, np.hstack([np.tile(img, (1, 1, 3)), render_img.astype(np.uint8), blend_img])[:, :, (2, 1, 0)]) else: print(save_path, k) cv2.imwrite(save_path, img) if __name__ == "__main__": test_head()