# LAM/lam/models/rendering/gs_renderer.py
import os
from dataclasses import dataclass, field
from collections import defaultdict
try:
    # Prefer the WDA fork of the rasterizer when it is installed; fall back to the stock one.
    from diff_gaussian_rasterization_wda import GaussianRasterizationSettings, GaussianRasterizer
except ImportError:
    from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from plyfile import PlyData, PlyElement
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import copy
from diffusers.utils import is_torch_version
from lam.models.rendering.flame_model.flame import FlameHeadSubdivided
from lam.models.transformer import TransformerDecoder
from pytorch3d.transforms import matrix_to_quaternion
from lam.models.rendering.utils.typing import *
from lam.models.rendering.utils.utils import trunc_exp, MLP
from lam.models.rendering.gaussian_model import GaussianModel
from einops import rearrange, repeat
from pytorch3d.ops.points_normals import estimate_pointcloud_normals
# Use EGL for headless OpenGL rendering; this must be set before any GL context is created.
os.environ["PYOPENGL_PLATFORM"] = "egl"
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.renderer import (
AmbientLights,
PerspectiveCameras,
SoftSilhouetteShader,
SoftPhongShader,
RasterizationSettings,
MeshRenderer,
MeshRendererWithFragments,
MeshRasterizer,
TexturesVertex,
)
from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend
import lam.models.rendering.utils.mesh_utils as mesh_utils
from lam.models.rendering.utils.point_utils import depth_to_normal
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes
def inverse_sigmoid(x):
    return np.log(x / (1 - x))
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
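    """Build a world-to-view matrix from rotation R and translation t.

    The camera center can optionally be shifted by `translate` and scaled by
    `scale` in world space before re-inverting back to world-to-view.
    """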
Rt = np.zeros((4, 4))
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = t
Rt[3, 3] = 1.0
C2W = np.linalg.inv(Rt)
cam_center = C2W[:3, 3]
cam_center = (cam_center + translate) * scale
C2W[:3, 3] = cam_center
Rt = np.linalg.inv(C2W)
return np.float32(Rt)
def getProjectionMatrix(znear, zfar, fovX, fovY):
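    """Build an OpenGL-style perspective projection matrix from the near/far
    planes and the horizontal/vertical fields of view (in radians)."""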
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))
top = tanHalfFovY * znear
bottom = -top
right = tanHalfFovX * znear
left = -right
P = torch.zeros(4, 4)
z_sign = 1.0
P[0, 0] = 2.0 * znear / (right - left)
P[1, 1] = 2.0 * znear / (top - bottom)
P[0, 2] = (right + left) / (right - left)
P[1, 2] = (top + bottom) / (top - bottom)
P[3, 2] = z_sign
P[2, 2] = z_sign * zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
return P
def intrinsic_to_fov(intrinsic, w, h):
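    """Convert focal lengths of a pinhole intrinsic matrix to horizontal/vertical FoV in radians."""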
fx, fy = intrinsic[0, 0], intrinsic[1, 1]
fov_x = 2 * torch.arctan2(w, 2 * fx)
fov_y = 2 * torch.arctan2(h, 2 * fy)
return fov_x, fov_y
class Camera:
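    """Per-view camera holding the transforms used by the Gaussian rasterizer:
    world-to-view, projection, their product, and the camera center."""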
def __init__(self, w2c, intrinsic, FoVx, FoVy, height, width, trans=np.array([0.0, 0.0, 0.0]), scale=1.0) -> None:
self.FoVx = FoVx
self.FoVy = FoVy
self.height = int(height)
self.width = int(width)
self.world_view_transform = w2c.transpose(0, 1)
self.intrinsic = intrinsic
self.zfar = 100.0
self.znear = 0.01
self.trans = trans
self.scale = scale
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
@staticmethod
def from_c2w(c2w, intrinsic, height, width):
w2c = torch.inverse(c2w)
FoVx, FoVy = intrinsic_to_fov(intrinsic, w=torch.tensor(width, device=w2c.device), h=torch.tensor(height, device=w2c.device))
return Camera(w2c=w2c, intrinsic=intrinsic, FoVx=FoVx, FoVy=FoVy, height=height, width=width)
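# Minimal usage sketch for Camera (illustrative values only):
#   c2w = torch.eye(4)                      # camera-to-world pose
#   K = torch.tensor([[512., 0., 256.],
#                     [0., 512., 256.],
#                     [0., 0., 1.]])        # pinhole intrinsics
#   cam = Camera.from_c2w(c2w, K, height=512, width=512)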
class GSLayer(nn.Module):
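    """Per-point output heads that decode latent features into 3D Gaussian
    attributes (SH/RGB colors, scaling, xyz offset, opacity, rotation).

    Opacity and/or rotation can be fixed to constants, and xyz offsets can be
    restricted to +/- xyz_offset_max_step / 2 around the query points.
    """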
def __init__(self, in_channels, use_rgb,
clip_scaling=0.2,
init_scaling=-5.0,
scale_sphere=False,
init_density=0.1,
sh_degree=None,
xyz_offset=True,
restrict_offset=True,
xyz_offset_max_step=None,
fix_opacity=False,
fix_rotation=False,
use_fine_feat=False,
pred_res=False,
):
super().__init__()
self.clip_scaling = clip_scaling
self.use_rgb = use_rgb
self.restrict_offset = restrict_offset
self.xyz_offset = xyz_offset
self.xyz_offset_max_step = xyz_offset_max_step # 1.2 / 32
self.fix_opacity = fix_opacity
self.fix_rotation = fix_rotation
self.use_fine_feat = use_fine_feat
self.scale_sphere = scale_sphere
self.pred_res = pred_res
        self.attr_dict = {
"shs": (sh_degree + 1) ** 2 * 3,
"scaling": 3 if not scale_sphere else 1,
"xyz": 3,
"opacity": None,
"rotation": None
}
if not self.fix_opacity:
self.attr_dict["opacity"] = 1
if not self.fix_rotation:
self.attr_dict["rotation"] = 4
self.out_layers = nn.ModuleDict()
for key, out_ch in self.attr_dict.items():
if out_ch is None:
layer = nn.Identity()
else:
if key == "shs" and use_rgb:
out_ch = 3
if key == "shs":
shs_out_ch = out_ch
if pred_res:
layer = nn.Linear(in_channels+out_ch, out_ch)
else:
layer = nn.Linear(in_channels, out_ch)
# initialize
if not (key == "shs" and use_rgb):
if key == "opacity" and self.fix_opacity:
pass
elif key == "rotation" and self.fix_rotation:
pass
else:
nn.init.constant_(layer.weight, 0)
nn.init.constant_(layer.bias, 0)
if key == "scaling":
nn.init.constant_(layer.bias, init_scaling)
elif key == "rotation":
if not self.fix_rotation:
nn.init.constant_(layer.bias, 0)
nn.init.constant_(layer.bias[0], 1.0)
elif key == "opacity":
if not self.fix_opacity:
nn.init.constant_(layer.bias, inverse_sigmoid(init_density))
self.out_layers[key] = layer
if self.use_fine_feat:
fine_shs_layer = nn.Linear(in_channels, shs_out_ch)
nn.init.constant_(fine_shs_layer.weight, 0)
nn.init.constant_(fine_shs_layer.bias, 0)
self.out_layers["fine_shs"] = fine_shs_layer
def forward(self, x, pts, x_fine=None, gs_raw_attr=None, ret_raw=False, vtx_sym_idxs=None):
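        """Decode per-point features x [N, C] into a GaussianModel anchored at pts [N, 3].

        If vtx_sym_idxs is given, the listed attributes are decoded from the
        mirrored vertex's features; if ret_raw is set, the pre-activation
        attributes are also returned (used for residual refinement).
        """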
assert len(x.shape) == 2
ret = {}
if ret_raw:
raw_attr = {}
ori_x = x
for k in self.attr_dict:
            if vtx_sym_idxs is not None and k in ["shs", "scaling", "opacity", "rotation"]:
                # decode these attributes from the mirrored (symmetric) vertex's features
                x = ori_x[vtx_sym_idxs.to(x.device), :]
else:
x = ori_x
            layer = self.out_layers[k]
if self.pred_res and (not self.fix_opacity or k != "opacity") and (not self.fix_rotation or k != "rotation"):
v = layer(torch.cat([gs_raw_attr[k], x], dim=-1))
v = gs_raw_attr[k] + v
else:
v = layer(x)
if ret_raw:
raw_attr[k] = v
if k == "rotation":
if self.fix_rotation:
v = matrix_to_quaternion(torch.eye(3).type_as(x)[None,: , :].repeat(x.shape[0], 1, 1)) # constant rotation
else:
# assert len(x.shape) == 2
v = torch.nn.functional.normalize(v)
elif k == "scaling":
v = trunc_exp(v)
if self.scale_sphere:
assert v.shape[-1] == 1
v = torch.cat([v, v, v], dim=-1)
if self.clip_scaling is not None:
v = torch.clamp(v, min=0, max=self.clip_scaling)
elif k == "opacity":
if self.fix_opacity:
v = torch.ones_like(x)[..., 0:1]
else:
v = torch.sigmoid(v)
elif k == "shs":
if self.use_rgb:
v[..., :3] = torch.sigmoid(v[..., :3])
if self.use_fine_feat:
v_fine = self.out_layers["fine_shs"](x_fine)
v_fine = torch.tanh(v_fine)
v = v + v_fine
else:
if self.use_fine_feat:
v_fine = self.out_layers["fine_shs"](x_fine)
v = v + v_fine
v = torch.reshape(v, (v.shape[0], -1, 3))
elif k == "xyz":
# TODO check
if self.restrict_offset:
max_step = self.xyz_offset_max_step
v = (torch.sigmoid(v) - 0.5) * max_step
                if not self.xyz_offset:
                    raise NotImplementedError("only xyz_offset=True is supported")
ret["offset"] = v
v = pts + v
ret[k] = v
if ret_raw:
return GaussianModel(**ret), raw_attr
else:
return GaussianModel(**ret)
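# Minimal usage sketch for GSLayer (illustrative sizes; feats would come from the decoder):
#   gs_layer = GSLayer(in_channels=64, use_rgb=True, sh_degree=0, xyz_offset_max_step=1.2 / 32)
#   gaussians = gs_layer(feats, pts)  # feats: [N, 64], pts: [N, 3] -> GaussianModel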
class PointEmbed(nn.Module):
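    """Fourier positional embedding of 3D points, followed by a linear projection and LayerNorm."""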
def __init__(self, hidden_dim=48, dim=128):
super().__init__()
assert hidden_dim % 6 == 0
self.embedding_dim = hidden_dim
e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
e = torch.stack([
torch.cat([e, torch.zeros(self.embedding_dim // 6),
torch.zeros(self.embedding_dim // 6)]),
torch.cat([torch.zeros(self.embedding_dim // 6), e,
torch.zeros(self.embedding_dim // 6)]),
torch.cat([torch.zeros(self.embedding_dim // 6),
torch.zeros(self.embedding_dim // 6), e]),
])
        self.register_buffer('basis', e)  # [3, embedding_dim // 2]
self.mlp = nn.Linear(self.embedding_dim+3, dim)
self.norm = nn.LayerNorm(dim)
@staticmethod
def embed(input, basis):
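        # input: [B, N, 3], basis: [3, E/2] -> concatenated sin/cos features [B, N, E]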
projections = torch.einsum(
'bnd,de->bne', input, basis)
embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
return embeddings
def forward(self, input):
# input: B x N x 3
embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
embed = self.norm(embed)
return embed
class CrossAttnBlock(nn.Module):
"""
Transformer block that takes in a cross-attention condition.
Designed for SparseLRM architecture.
"""
# Block contains a cross-attention layer, a self-attention layer, and an MLP
def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float=None,
attn_drop: float = 0., attn_bias: bool = False,
mlp_ratio: float = 4., mlp_drop: float = 0., feedforward=False):
super().__init__()
        # TODO: check whether the inputs are already normalized upstream;
        # the query/key LayerNorms are disabled (Identity) for now.
        self.norm_q = nn.Identity()
        self.norm_k = nn.Identity()
self.cross_attn = nn.MultiheadAttention(
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
dropout=attn_drop, bias=attn_bias, batch_first=True)
self.mlp = None
if feedforward:
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
self.self_attn = nn.MultiheadAttention(
embed_dim=inner_dim, num_heads=num_heads,
dropout=attn_drop, bias=attn_bias, batch_first=True)
self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
self.mlp = nn.Sequential(
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
nn.GELU(),
nn.Dropout(mlp_drop),
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
nn.Dropout(mlp_drop),
)
def forward(self, x, cond):
# x: [N, L, D]
# cond: [N, L_cond, D_cond]
x = self.cross_attn(self.norm_q(x), self.norm_k(cond), cond, need_weights=False)[0]
if self.mlp is not None:
before_sa = self.norm2(x)
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
x = x + self.mlp(self.norm3(x))
return x
class DecoderCrossAttn(nn.Module):
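    """Cross-attention decoder that queries point embeddings against latent tokens and can
    optionally refine the result with image features (DINOv2 or ResNet18 encoders)."""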
def __init__(self, query_dim, context_dim, num_heads, mlp=False, decode_with_extra_info=None):
super().__init__()
self.query_dim = query_dim
self.context_dim = context_dim
self.cross_attn = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
num_heads=num_heads, feedforward=mlp,
eps=1e-5)
self.decode_with_extra_info = decode_with_extra_info
if decode_with_extra_info is not None:
if decode_with_extra_info["type"] == "dinov2p14_feat":
context_dim = decode_with_extra_info["cond_dim"]
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=context_dim,
num_heads=num_heads, feedforward=False, eps=1e-5)
elif decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
from lam.models.encoders.dinov2_wrapper import Dinov2Wrapper
self.encoder = Dinov2Wrapper(model_name='dinov2_vits14_reg', freeze=False, encoder_feat_dim=384)
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=384,
num_heads=num_heads, feedforward=False,
eps=1e-5)
elif decode_with_extra_info["type"] == "decoder_resnet18_feat":
from lam.models.encoders.xunet_wrapper import XnetWrapper
self.encoder = XnetWrapper(model_name='resnet18', freeze=False, encoder_feat_dim=64)
self.cross_attn_color = CrossAttnBlock(inner_dim=query_dim, cond_dim=64,
num_heads=num_heads, feedforward=False,
eps=1e-5)
def resize_image(self, image, multiply):
B, _, H, W = image.shape
new_h, new_w = math.ceil(H / multiply) * multiply, math.ceil(W / multiply) * multiply
image = F.interpolate(image, (new_h, new_w), align_corners=True, mode="bilinear")
return image
def forward(self, pcl_query, pcl_latent, extra_info=None):
out = self.cross_attn(pcl_query, pcl_latent)
if self.decode_with_extra_info is not None:
out_dict = {}
out_dict["coarse"] = out
if self.decode_with_extra_info["type"] == "dinov2p14_feat":
out = self.cross_attn_color(out, extra_info["image_feats"])
out_dict["fine"] = out
return out_dict
elif self.decode_with_extra_info["type"] == "decoder_dinov2p14_feat":
img_feat = self.encoder(extra_info["image"])
out = self.cross_attn_color(out, img_feat)
out_dict["fine"] = out
return out_dict
elif self.decode_with_extra_info["type"] == "decoder_resnet18_feat":
image = extra_info["image"]
image = self.resize_image(image, multiply=32)
img_feat = self.encoder(image)
out = self.cross_attn_color(out, img_feat)
out_dict["fine"] = out
return out_dict
return out
class GS3DRenderer(nn.Module):
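    """FLAME-rigged 3D Gaussian head renderer.

    Decodes latent features into per-vertex Gaussian attributes on a subdivided
    FLAME template, animates them with per-view FLAME parameters, and rasterizes
    each view with diff-gaussian-rasterization.
    """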
def __init__(self, human_model_path, subdivide_num, smpl_type, feat_dim, query_dim,
use_rgb, sh_degree, xyz_offset_max_step, mlp_network_config,
expr_param_dim, shape_param_dim,
clip_scaling=0.2,
scale_sphere=False,
skip_decoder=False,
fix_opacity=False,
fix_rotation=False,
decode_with_extra_info=None,
gradient_checkpointing=False,
add_teeth=True,
teeth_bs_flag=False,
oral_mesh_flag=False,
**kwargs,
):
super().__init__()
print(f"#########scale sphere:{scale_sphere}, add_teeth:{add_teeth}")
self.gradient_checkpointing = gradient_checkpointing
self.skip_decoder = skip_decoder
self.smpl_type = smpl_type
assert self.smpl_type == "flame"
self.sym_rend2 = True
self.teeth_bs_flag = teeth_bs_flag
self.oral_mesh_flag = oral_mesh_flag
self.render_rgb = kwargs.get("render_rgb", True)
print("==="*16*3, "\n Render rgb:", self.render_rgb, "\n"+"==="*16*3)
self.scaling_modifier = 1.0
self.sh_degree = sh_degree
if use_rgb:
self.sh_degree = 0
        self.use_rgb = use_rgb
self.flame_model = FlameHeadSubdivided(
300,
100,
add_teeth=add_teeth,
add_shoulder=False,
flame_model_path=f'{human_model_path}/flame_assets/flame/flame2023.pkl',
flame_lmk_embedding_path=f"{human_model_path}/flame_assets/flame/landmark_embedding_with_eyes.npy",
flame_template_mesh_path=f"{human_model_path}/flame_assets/flame/head_template_mesh.obj",
flame_parts_path=f"{human_model_path}/flame_assets/flame/FLAME_masks.pkl",
subdivide_num=subdivide_num,
teeth_bs_flag=teeth_bs_flag,
oral_mesh_flag=oral_mesh_flag
)
if not self.skip_decoder:
self.pcl_embed = PointEmbed(dim=query_dim)
self.mlp_network_config = mlp_network_config
if self.mlp_network_config is not None:
self.mlp_net = MLP(query_dim, query_dim, **self.mlp_network_config)
init_scaling = -5.0
self.gs_net = GSLayer(in_channels=query_dim,
use_rgb=use_rgb,
sh_degree=self.sh_degree,
clip_scaling=clip_scaling,
scale_sphere=scale_sphere,
init_scaling=init_scaling,
init_density=0.1,
xyz_offset=True,
restrict_offset=True,
xyz_offset_max_step=xyz_offset_max_step,
fix_opacity=fix_opacity,
fix_rotation=fix_rotation,
                              use_fine_feat=decode_with_extra_info is not None and decode_with_extra_info["type"] is not None,
)
def forward_single_view(self,
gs: GaussianModel,
viewpoint_camera: Camera,
background_color: Optional[Float[Tensor, "3"]],
):
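        """Rasterize a single GaussianModel from one camera; returns RGB, alpha mask, and depth maps."""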
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(gs.xyz, dtype=gs.xyz.dtype, requires_grad=True, device=self.device) + 0
try:
screenspace_points.retain_grad()
        except Exception:
pass
bg_color = background_color
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
GSRSettings = GaussianRasterizationSettings
GSR = GaussianRasterizer
raster_settings = GSRSettings(
image_height=int(viewpoint_camera.height),
image_width=int(viewpoint_camera.width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=self.scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform.float(),
sh_degree=self.sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=False
)
rasterizer = GSR(raster_settings=raster_settings)
means3D = gs.xyz
means2D = screenspace_points
opacity = gs.opacity
        # The 3D covariance is computed by the rasterizer from per-point scaling / rotation.
        scales = gs.scaling
        rotations = gs.rotation
        cov3D_precomp = None
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
if self.gs_net.use_rgb:
colors_precomp = gs.shs.squeeze(1)
else:
shs = gs.shs
# Rasterize visible Gaussians to image, obtain their radii (on screen).
with torch.autocast(device_type=self.device.type, dtype=torch.float32):
raster_ret = rasterizer(
means3D = means3D.float(),
means2D = means2D.float(),
shs = shs.float() if not self.gs_net.use_rgb else None,
colors_precomp = colors_precomp.float() if colors_precomp is not None else None,
opacities = opacity.float(),
scales = scales.float(),
rotations = rotations.float(),
cov3D_precomp = cov3D_precomp
)
rendered_image, radii, rendered_depth, rendered_alpha = raster_ret
ret = {
"comp_rgb": rendered_image.permute(1, 2, 0), # [H, W, 3]
"comp_rgb_bg": bg_color,
'comp_mask': rendered_alpha.permute(1, 2, 0),
'comp_depth': rendered_depth.permute(1, 2, 0),
}
return ret
def animate_gs_model(self, gs_attr: GaussianModel, query_points, flame_data, debug=False):
"""
query_points: [N, 3]
"""
device = gs_attr.xyz.device
if debug:
N = gs_attr.xyz.shape[0]
gs_attr.xyz = torch.ones_like(gs_attr.xyz) * 0.0
rotation = matrix_to_quaternion(torch.eye(3).float()[None, :, :].repeat(N, 1, 1)).to(device) # constant rotation
opacity = torch.ones((N, 1)).float().to(device) # constant opacity
gs_attr.opacity = opacity
gs_attr.rotation = rotation
# gs_attr.scaling = torch.ones_like(gs_attr.scaling) * 0.05
# print(gs_attr.shs.shape)
with torch.autocast(device_type=device.type, dtype=torch.float32):
# mean_3d = query_points + gs_attr.xyz # [N, 3]
mean_3d = gs_attr.xyz # [N, 3]
num_view = flame_data["expr"].shape[0] # [Nv, 100]
mean_3d = mean_3d.unsqueeze(0).repeat(num_view, 1, 1) # [Nv, N, 3]
query_points = query_points.unsqueeze(0).repeat(num_view, 1, 1)
if self.teeth_bs_flag:
expr = torch.cat([flame_data['expr'], flame_data['teeth_bs']], dim=-1)
else:
expr = flame_data["expr"]
ret = self.flame_model.animation_forward(v_cano=mean_3d,
shape=flame_data["betas"].repeat(num_view, 1),
expr=expr,
rotation=flame_data["rotation"],
neck=flame_data["neck_pose"],
jaw=flame_data["jaw_pose"],
eyes=flame_data["eyes_pose"],
translation=flame_data["translation"],
zero_centered_at_root_node=False,
return_landmarks=False,
return_verts_cano=False,
# static_offset=flame_data['static_offset'].to('cuda'),
static_offset=None,
)
mean_3d = ret["animated"]
gs_attr_list = []
for i in range(num_view):
gs_attr_copy = GaussianModel(xyz=mean_3d[i],
opacity=gs_attr.opacity,
rotation=gs_attr.rotation,
scaling=gs_attr.scaling,
shs=gs_attr.shs,
albedo=gs_attr.albedo,
lights=gs_attr.lights,
offset=gs_attr.offset) # [N, 3]
gs_attr_list.append(gs_attr_copy)
return gs_attr_list
def forward_gs_attr(self, x, query_points, flame_data, debug=False, x_fine=None, vtx_sym_idxs=None):
"""
x: [N, C] Float[Tensor, "Np Cp"],
query_points: [N, 3] Float[Tensor, "Np 3"]
"""
device = x.device
if self.mlp_network_config is not None:
x = self.mlp_net(x)
if x_fine is not None:
x_fine = self.mlp_net(x_fine)
gs_attr: GaussianModel = self.gs_net(x, query_points, x_fine, vtx_sym_idxs=vtx_sym_idxs)
return gs_attr
def get_query_points(self, flame_data, device):
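        """Return canonical FLAME vertices [B, N, 3] for flame_data["betas"], passing flame_data through."""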
with torch.no_grad():
with torch.autocast(device_type=device.type, dtype=torch.float32):
# print(flame_data["betas"].shape, flame_data["face_offset"].shape, flame_data["joint_offset"].shape)
# positions, _, transform_mat_neutral_pose = self.flame_model.get_query_points(flame_data, device=device) # [B, N, 3]
positions = self.flame_model.get_cano_verts(shape_params=flame_data["betas"]) # [B, N, 3]
# print(f"positions shape:{positions.shape}")
return positions, flame_data
def query_latent_feat(self,
positions: Float[Tensor, "*B N1 3"],
flame_data,
latent_feat: Float[Tensor, "*B N2 C"],
extra_info):
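        """Embed canonical query points into per-point latent features.

        If skip_decoder is set, the incoming latent features are used directly;
        otherwise the query points are derived from flame_data and embedded.
        """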
device = latent_feat.device
if self.skip_decoder:
gs_feats = latent_feat
assert positions is not None
else:
assert positions is None
if positions is None:
positions, flame_data = self.get_query_points(flame_data, device)
with torch.autocast(device_type=device.type, dtype=torch.float32):
pcl_embed = self.pcl_embed(positions)
gs_feats = pcl_embed
return gs_feats, positions, flame_data
def forward_single_batch(
self,
gs_list: list[GaussianModel],
c2ws: Float[Tensor, "Nv 4 4"],
intrinsics: Float[Tensor, "Nv 4 4"],
height: int,
width: int,
background_color: Optional[Float[Tensor, "Nv 3"]],
debug: bool=False,
):
out_list = []
self.device = gs_list[0].xyz.device
for v_idx, (c2w, intrinsic) in enumerate(zip(c2ws, intrinsics)):
out_list.append(self.forward_single_view(
gs_list[v_idx],
Camera.from_c2w(c2w, intrinsic, height, width),
background_color[v_idx],
))
out = defaultdict(list)
for out_ in out_list:
for k, v in out_.items():
out[k].append(v)
out = {k: torch.stack(v, dim=0) for k, v in out.items()}
out["3dgs"] = gs_list
return out
    def get_single_batch_smpl_data(self, smpl_data, bidx):
smpl_data_single_batch = {}
for k, v in smpl_data.items():
smpl_data_single_batch[k] = v[bidx] # e.g. body_pose: [B, N_v, 21, 3] -> [N_v, 21, 3]
if k == "betas" or (k == "joint_offset") or (k == "face_offset"):
smpl_data_single_batch[k] = v[bidx:bidx+1] # e.g. betas: [B, 100] -> [1, 100]
return smpl_data_single_batch
def get_single_view_smpl_data(self, smpl_data, vidx):
smpl_data_single_view = {}
for k, v in smpl_data.items():
assert v.shape[0] == 1
if k == "betas" or (k == "joint_offset") or (k == "face_offset") or (k == "transform_mat_neutral_pose"):
smpl_data_single_view[k] = v # e.g. betas: [1, 100] -> [1, 100]
else:
smpl_data_single_view[k] = v[:, vidx: vidx + 1] # e.g. body_pose: [1, N_v, 21, 3] -> [1, 1, 21, 3]
return smpl_data_single_view
def forward_gs(self,
gs_hidden_features: Float[Tensor, "B Np Cp"],
query_points: Float[Tensor, "B Np_q 3"],
flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
additional_features: Optional[dict] = None,
debug: bool = False,
**kwargs):
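        """Decode latent features into one GaussianModel per batch element in canonical (neutral) FLAME space."""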
batch_size = gs_hidden_features.shape[0]
query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
additional_features)
gs_model_list = []
all_query_points = []
for b in range(batch_size):
all_query_points.append(query_points[b:b+1, :])
if isinstance(query_gs_features, dict):
ret_gs = self.forward_gs_attr(query_gs_features["coarse"][b], query_points[b], None, debug,
x_fine=query_gs_features["fine"][b], vtx_sym_idxs=None)
else:
ret_gs = self.forward_gs_attr(query_gs_features[b], query_points[b], None, debug, vtx_sym_idxs=None)
ret_gs.update_albedo(ret_gs.shs.clone())
gs_model_list.append(ret_gs)
query_points = torch.cat(all_query_points, dim=0)
return gs_model_list, query_points, flame_data, query_gs_features
def forward_res_refine_gs(self,
gs_hidden_features: Float[Tensor, "B Np Cp"],
query_points: Float[Tensor, "B Np_q 3"],
flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
additional_features: Optional[dict] = None,
debug: bool = False,
gs_raw_attr_list: list = None,
**kwargs):
batch_size = gs_hidden_features.shape[0]
query_gs_features, query_points, flame_data = self.query_latent_feat(query_points, flame_data, gs_hidden_features,
additional_features)
gs_model_list = []
for b in range(batch_size):
gs_model = self.gs_refine_net(query_gs_features[b], query_points[b], x_fine=None, gs_raw_attr=gs_raw_attr_list[b])
gs_model_list.append(gs_model)
return gs_model_list, query_points, flame_data, query_gs_features
def forward_animate_gs(self, gs_model_list, query_points, flame_data, c2w, intrinsic, height, width,
background_color, debug=False):
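        """Animate canonical per-batch Gaussians with FLAME parameters and render every view."""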
batch_size = len(gs_model_list)
out_list = []
for b in range(batch_size):
gs_model = gs_model_list[b]
query_pt = query_points[b]
animatable_gs_model_list: list[GaussianModel] = self.animate_gs_model(gs_model,
query_pt,
                                                                              self.get_single_batch_smpl_data(flame_data, b),
debug=debug)
assert len(animatable_gs_model_list) == c2w.shape[1]
out_list.append(self.forward_single_batch(
animatable_gs_model_list,
c2w[b],
intrinsic[b],
height, width,
background_color[b] if background_color is not None else None,
debug=debug))
out = defaultdict(list)
for out_ in out_list:
for k, v in out_.items():
out[k].append(v)
for k, v in out.items():
if isinstance(v[0], torch.Tensor):
out[k] = torch.stack(v, dim=0)
else:
out[k] = v
render_keys = ["comp_rgb", "comp_mask", "comp_depth"]
for key in render_keys:
out[key] = rearrange(out[key], "b v h w c -> b v c h w")
return out
def project_single_view_feats(self, img_vtx_ids, feats, nv, inter_feat=True):
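        """Scatter per-pixel features onto mesh vertices via a vertex-id map and average them.

        Returns mean per-vertex features of shape [nv, c].
        """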
b, h, w, k = img_vtx_ids.shape
c, ih, iw = feats.shape
vtx_ids = img_vtx_ids
if h != ih or w != iw:
if inter_feat:
feats = torch.nn.functional.interpolate(
rearrange(feats, "(b c) h w -> b c h w", b=1).float(), (h, w)
).squeeze(0)
vtx_ids = rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).long().squeeze(1)
else:
vtx_ids = torch.nn.functional.interpolate(
rearrange(vtx_ids, "b (c h) w k -> (b k) c h w", c=1).float(), (ih, iw), mode="nearest"
).long().squeeze(1)
else:
vtx_ids = rearrange(vtx_ids, "b h w k -> (b k) h w", b=1).long()
vis_mask = vtx_ids > 0
vtx_ids = vtx_ids[vis_mask] # n
vtx_ids = repeat(vtx_ids, "n -> n c", c=c)
feats = repeat(feats, "c h w -> k h w c", k=k).to(vtx_ids.device)
feats = feats[vis_mask, :] # n, c
sums = torch.zeros((nv, c), dtype=feats.dtype, device=feats.device)
counts = torch.zeros((nv), dtype=torch.int64, device=feats.device)
sums.scatter_add_(0, vtx_ids, feats)
one_hot = torch.ones_like(vtx_ids[:, 0], dtype=torch.int64).to(feats.device)
counts.scatter_add_(0, vtx_ids[:, 0], one_hot)
clamp_counts = counts.clamp(min=1)
mean_feats = sums / clamp_counts.view(-1, 1)
return mean_feats
def forward(self,
gs_hidden_features: Float[Tensor, "B Np Cp"],
query_points: Float[Tensor, "B Np 3"],
flame_data, # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
c2w: Float[Tensor, "B Nv 4 4"],
intrinsic: Float[Tensor, "B Nv 4 4"],
height,
width,
additional_features: Optional[Float[Tensor, "B C H W"]] = None,
background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
debug: bool = False,
**kwargs):
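        """Full pipeline: decode Gaussians from latent features, animate with FLAME, and render.

        Returns a dict with "comp_rgb" / "comp_mask" / "comp_depth" tensors of shape
        [B, Nv, C, H, W], plus the per-batch Gaussian models under "3dgs".
        """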
        # flame_data must carry shape params ("betas") to derive query points and "transform_mat_neutral_pose"
gs_model_list, query_points, flame_data, query_gs_features = self.forward_gs(gs_hidden_features, query_points, flame_data=flame_data,
additional_features=additional_features, debug=debug)
out = self.forward_animate_gs(gs_model_list, query_points, flame_data, c2w, intrinsic, height, width, background_color, debug)
return out
def test_head():
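    """Smoke test: build a GS3DRenderer, render one VideoHeadDataset sample, and dump
    images/PLY files under ./debug_vis/gs_render (requires the local training data)."""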
import cv2
human_model_path = "./pretrained_models/human_model_files"
device = "cuda"
from accelerate.utils import set_seed
set_seed(1234)
from lam.datasets.video_head import VideoHeadDataset
root_dir = "./train_data/vfhq_vhap/export"
meta_path = "./train_data/vfhq_vhap/label/valid_id_list.json"
# root_dir = "./train_data/nersemble/export"
# meta_path = "./train_data/nersemble/label/valid_id_list1.json"
dataset = VideoHeadDataset(root_dirs=root_dir, meta_path=meta_path, sample_side_views=7,
render_image_res_low=512, render_image_res_high=512,
render_region_size=(512, 512), source_image_res=512,
enlarge_ratio=[0.8, 1.2],
debug=False)
data = dataset[0]
def get_flame_params(data):
flame_params = {}
flame_keys = ['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas',\
'rotation', 'neck_pose', 'eyes_pose', 'translation']
for k, v in data.items():
if k in flame_keys:
# print(k, v.shape)
flame_params[k] = data[k]
return flame_params
flame_data = get_flame_params(data)
flame_data_tmp = {}
for k, v in flame_data.items():
flame_data_tmp[k] = v.unsqueeze(0).to(device)
print(k, v.shape)
flame_data = flame_data_tmp
c2ws = data["c2ws"].unsqueeze(0).to(device)
intrs = data["intrs"].unsqueeze(0).to(device)
render_images = data["render_image"].numpy()
render_h = data["render_full_resolutions"][0, 0]
    render_w = data["render_full_resolutions"][0, 1]
render_bg_colors = data["render_bg_colors"].unsqueeze(0).to(device)
print("c2ws", c2ws.shape, "intrs", intrs.shape, intrs)
gs_render = GS3DRenderer(human_model_path=human_model_path, subdivide_num=2, smpl_type="flame",
feat_dim=64, query_dim=64, use_rgb=True, sh_degree=3, mlp_network_config=None,
xyz_offset_max_step=0.0001, expr_param_dim=10, shape_param_dim=10,
fix_opacity=True, fix_rotation=True, clip_scaling=0.001, add_teeth=False)
gs_render.to(device)
out = gs_render.forward(gs_hidden_features=torch.zeros((1, 2048, 64)).float().to(device),
query_points=None,
flame_data=flame_data,
c2w=c2ws,
intrinsic=intrs,
height=render_h,
width=render_w,
background_color=render_bg_colors,
debug=False)
os.makedirs("./debug_vis/gs_render", exist_ok=True)
for k, v in out.items():
if k == "comp_rgb_bg":
print("comp_rgb_bg", v)
continue
for b_idx in range(len(v)):
if k == "3dgs":
for v_idx in range(len(v[b_idx])):
v[b_idx][v_idx].save_ply(f"./debug_vis/gs_render/{b_idx}_{v_idx}.ply")
continue
for v_idx in range(v.shape[1]):
save_path = os.path.join("./debug_vis/gs_render", f"{b_idx}_{v_idx}_{k}.jpg")
if "normal" in k:
img = ((v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() + 1.0) / 2. * 255).astype(np.uint8)
else:
img = (v[b_idx, v_idx].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
print(v[b_idx, v_idx].shape, img.shape, save_path)
if "mask" in k:
render_img = render_images[v_idx].transpose(1, 2, 0) * 255
blend_img = (render_images[v_idx].transpose(1, 2, 0) * 255 * 0.5 + np.tile(img, (1, 1, 3)) * 0.5).clip(0, 255).astype(np.uint8)
cv2.imwrite(save_path, np.hstack([np.tile(img, (1, 1, 3)), render_img.astype(np.uint8), blend_img])[:, :, (2, 1, 0)])
else:
print(save_path, k)
cv2.imwrite(save_path, img)
if __name__ == "__main__":
test_head()