|
from typing import * |
|
import numpy as np |
|
import torch |
|
import utils3d |
|
import nvdiffrast.torch as dr |
|
from tqdm import tqdm |
|
import trimesh |
|
import trimesh.visual |
|
import xatlas |
|
import pyvista as pv |
|
from pymeshfix import _meshfix |
|
import igraph |
|
import cv2 |
|
from PIL import Image |
|
from .random_utils import sphere_hammersley_sequence |
|
from .render_utils import render_multiview |
|
from ..renderers import GaussianRenderer |
|
from ..representations import Strivec, Gaussian, MeshExtractResult |
|
|
|
|
|
@torch.no_grad()
def _fill_holes(
    verts,
    faces,
    max_hole_size=0.04,
    max_hole_nbe=32,
    resolution=128,
    num_views=500,
    debug=False,
    verbose=False,
):
    """
    Rasterize a mesh from multiple views and remove invisible faces.

    Also includes postprocessing to:
        1. Decide, per connected component, which faces are "outer" (visible often
           enough relative to the rest of the component).
        2. Mincut to remove faces at the inner side of the mesh connected to the
           outer side with a small hole.
    Finally, small boundary loops are closed with PyMeshFix.

    Args:
        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
        max_hole_size (float): Maximum area of a hole to fill. A mincut component
            is only removed if every new boundary loop it opens is at most this area.
        max_hole_nbe (int): Passed as ``nbe`` to PyMeshFix's
            ``fill_small_boundaries`` (maximum boundary edges of a hole to close).
        resolution (int): Resolution of the rasterization.
        num_views (int): Number of views to rasterize the mesh.
        debug (bool): Print extra diagnostics and write ``dbg_dual.ply`` (and,
            when any cut loop was measured, ``dbg_cut.ply``) to the working directory.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Vertices (float32) and faces (int32)
        on the CUDA device. If no face is invisible from every view, the input
        ``verts`` and ``faces`` are returned unchanged.
    """
    # ---- Camera rig: num_views cameras on a sphere of radius 2 looking at the origin ----
    yaws = []
    pitchs = []
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views)
        yaws.append(y)
        pitchs.append(p)
    yaws = torch.tensor(yaws).cuda()
    pitchs = torch.tensor(pitchs).cuda()
    radius = 2.0
    fov = torch.deg2rad(torch.tensor(40)).cuda()
    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
    # Loop-invariant look-at target and up vector (z-up).
    look_target = torch.tensor([0, 0, 0]).float().cuda()
    look_up = torch.tensor([0, 0, 1]).float().cuda()
    views = []
    for yaw, pitch in zip(yaws, pitchs):
        orig = (
            torch.tensor(
                [
                    torch.sin(yaw) * torch.cos(pitch),
                    torch.cos(yaw) * torch.cos(pitch),
                    torch.sin(pitch),
                ]
            )
            .cuda()
            .float()
            * radius
        )
        views.append(utils3d.torch.view_look_at(orig, look_target, look_up))
    views = torch.stack(views, dim=0)

    # ---- Per-face visibility: fraction of views in which each face is rasterized ----
    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
    rastctx = utils3d.torch.RastContext(backend="cuda")
    for i in tqdm(
        range(views.shape[0]),
        total=views.shape[0],
        disable=not verbose,
        desc="Rasterizing",
    ):
        buffers = utils3d.torch.rasterize_triangle_faces(
            rastctx,
            verts[None],
            faces,
            resolution,
            resolution,
            view=views[i],
            projection=projection,
        )
        # face_id is presumably 1-based with 0 for background, hence the -1.
        face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
        face_id = torch.unique(face_id).long()
        visibility[face_id] += 1
    visibility = visibility.float() / num_views

    # ---- Outer faces: per component, faces above a clamped 75th-percentile threshold ----
    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
    connected_components = utils3d.torch.compute_connected_components(
        faces, edges, face2edge
    )
    outer_face_indices = torch.zeros(
        faces.shape[0], dtype=torch.bool, device=faces.device
    )
    for cc in connected_components:
        cc_visibility = visibility[cc]
        threshold = min(max(cc_visibility.quantile(0.75).item(), 0.25), 0.5)
        outer_face_indices[cc] = cc_visibility > threshold
    outer_face_indices = outer_face_indices.nonzero().reshape(-1)

    # ---- Inner faces: never seen from any view ----
    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
    if verbose:
        tqdm.write(f"Found {inner_face_indices.shape[0]} invisible faces")
    if inner_face_indices.shape[0] == 0:
        return verts, faces

    # ---- Dual graph (faces as nodes), weighted by shared-edge length ----
    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
    dual_edge2edge = edges[dual_edge2edge]
    dual_edges_weights = torch.norm(
        verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1
    )
    if verbose:
        tqdm.write(f"Dual graph: {dual_edges.shape[0]} edges")

    g = igraph.Graph()
    g.add_vertices(faces.shape[0])
    g.add_edges(dual_edges.cpu().numpy())
    g.es["weight"] = dual_edges_weights.cpu().numpy()

    # Terminals: inner faces attach to source "s", outer faces to sink "t".
    g.add_vertex("s")
    g.add_vertex("t")
    # .tolist() transfers the indices once instead of creating one CUDA scalar per face.
    inner_ids = inner_face_indices.tolist()
    outer_ids = outer_face_indices.tolist()
    g.add_edges(
        [(f, "s") for f in inner_ids],
        attributes={"weight": np.ones(len(inner_ids), dtype=np.float32)},
    )
    g.add_edges(
        [(f, "t") for f in outer_ids],
        attributes={"weight": np.ones(len(outer_ids), dtype=np.float32)},
    )

    cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
    # Keep only real face ids (drop the two terminal vertices) from partition[0],
    # presumably the side containing "s" -- verify against igraph's Cut docs.
    remove_face_indices = torch.tensor(
        [v for v in cut.partition[0] if v < faces.shape[0]],
        dtype=torch.long,
        device=faces.device,
    )
    if verbose:
        tqdm.write("Mincut solved, start checking the cut")

    # ---- Validate each connected piece of the cut ----
    to_remove_cc = utils3d.torch.compute_connected_components(
        faces[remove_face_indices]
    )
    if debug:
        tqdm.write(f"Number of connected components of the cut: {len(to_remove_cc)}")
    valid_remove_cc = []
    cutting_edges = []
    for cc in to_remove_cc:
        cc_faces = remove_face_indices[cc]

        # Mostly visible pieces are real surface: keep them.
        visibility_median = visibility[cc_faces].median()
        if debug:
            tqdm.write(f"visblity_median: {visibility_median}")
        if visibility_median > 0.25:
            continue

        # Boundary edges of this piece that were not already mesh boundary edges,
        # i.e. the loops that removing the piece would open up.
        cc_edge_indices, cc_edges_degree = torch.unique(
            face2edge[cc_faces], return_counts=True
        )
        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
        cc_new_boundary_edge_indices = cc_boundary_edge_indices[
            ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
        ]
        if len(cc_new_boundary_edge_indices) > 0:
            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(
                edges[cc_new_boundary_edge_indices]
            )
            # Approximate each loop's area as a fan of triangles around its centroid.
            cc_new_boundary_edges_cc_area = []
            for edge_cc in cc_new_boundary_edge_cc:
                loop_edges = edges[cc_new_boundary_edge_indices[edge_cc]]
                center = verts[loop_edges].mean(dim=1).mean(dim=0)
                _e1 = verts[loop_edges[:, 0]] - center
                _e2 = verts[loop_edges[:, 1]] - center
                cc_new_boundary_edges_cc_area.append(
                    torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5
                )
            if debug:
                cutting_edges.append(cc_new_boundary_edge_indices)
                tqdm.write(f"Area of the cutting loop: {cc_new_boundary_edges_cc_area}")
            if any(area > max_hole_size for area in cc_new_boundary_edges_cc_area):
                continue

        valid_remove_cc.append(cc)

    if debug:
        face_v = verts[faces].mean(dim=1).cpu().numpy()
        vis_dual_edges = dual_edges.cpu().numpy()
        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
        if len(valid_remove_cc) > 0:
            vis_colors[
                remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()
            ] = [255, 0, 0]
        utils3d.io.write_ply(
            "dbg_dual.ply", face_v, edges=vis_dual_edges, vertex_colors=vis_colors
        )

        # torch.cat([]) raises, and cutting_edges stays empty when no piece
        # reached the loop-area check.
        if cutting_edges:
            vis_verts = verts.cpu().numpy()
            vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
            utils3d.io.write_ply("dbg_cut.ply", vis_verts, edges=vis_edges)

    # ---- Apply the validated removals ----
    if len(valid_remove_cc) > 0:
        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
        mask[remove_face_indices] = 0
        faces = faces[mask]
        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
        if verbose:
            tqdm.write(f"Removed {int((~mask).sum())} faces by mincut")
    else:
        if verbose:
            tqdm.write("Removed 0 faces by mincut")

    # ---- Close small holes with PyMeshFix ----
    mesh = _meshfix.PyTMesh()
    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
    verts, faces = mesh.return_arrays()
    verts = torch.tensor(verts, device="cuda", dtype=torch.float32)
    faces = torch.tensor(faces, device="cuda", dtype=torch.int32)

    return verts, faces
|
|
|
|
|
def postprocess_mesh(
    vertices: np.array,
    faces: np.array,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying it, removing invisible inner faces, and filling small holes.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
        simplify_ratio (float): Fraction of faces to REMOVE during simplification.
            It is passed to PyVista's ``decimate`` as its target reduction, so
            e.g. 0.9 keeps roughly 10% of the faces. Values <= 0 skip simplification.
        fill_holes (bool): Whether to remove invisible faces and fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Forwarded to the hole-filling step, which then prints extra
            diagnostics and may write debug PLY files.
        verbose (bool): Whether to print progress.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The processed vertices and faces. When both
        steps are disabled, the input arrays are returned as-is.
    """
    if verbose:
        tqdm.write(
            f"Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
        )

    # Quadric decimation with PyVista.
    if simplify and simplify_ratio > 0:
        # PyVista's face layout prefixes every face with its vertex count: [3, i0, i1, i2].
        padded_faces = np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)
        mesh = pv.PolyData(vertices, padded_faces)
        # decimate()'s first argument is the target reduction (fraction of faces removed).
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(
                f"After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
            )

    # Remove invisible inner faces and fill small holes (runs on the GPU).
    if fill_holes:
        vertices, faces = (
            torch.tensor(vertices).cuda(),
            torch.tensor(faces.astype(np.int32)).cuda(),
        )
        vertices, faces = _fill_holes(
            vertices,
            faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(
                f"After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces"
            )

    return vertices, faces
|
|
|
|
|
def parametrize_mesh(vertices: np.array, faces: np.array):
    """
    UV-unwrap a mesh with xatlas.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).

    Returns:
        tuple: ``(vertices, faces, uvs)``. ``vertices`` is the input gathered
        through xatlas' vertex mapping, so it presumably contains duplicates
        where xatlas split vertices along UV seams. ``faces`` are the index
        triples returned by xatlas, and ``uvs`` are its per-vertex UV coordinates.
    """
    vertex_map, new_faces, uv_coords = xatlas.parametrize(vertices, faces)
    return vertices[vertex_map], new_faces, uv_coords
|
|
|
|
|
def _bake_texture_fast(
    rastctx, vertices, faces, uvs, observations, masks, views, projections,
    texture_size, verbose,
):
    """Splat every masked observation pixel into its texel, average, then inpaint empty texels."""
    texture = torch.zeros(
        (texture_size * texture_size, 3), dtype=torch.float32
    ).cuda()
    texture_weights = torch.zeros(
        (texture_size * texture_size), dtype=torch.float32
    ).cuda()
    # Each view uses its own observation mask.
    for observation, obs_mask, view, projection in tqdm(
        zip(observations, masks, views, projections),
        total=len(observations),
        disable=not verbose,
        desc="Texture baking (fast)",
    ):
        with torch.no_grad():
            rast = utils3d.torch.rasterize_triangle_faces(
                rastctx,
                vertices[None],
                faces,
                observation.shape[1],
                observation.shape[0],
                uv=uvs[None],
                view=view,
                projection=projection,
            )
        # Only the UV buffer is flipped to image row order; the rasterizer mask is
        # used as-is, matching the pairing the original implementation relies on.
        uv_map = rast["uv"][0].detach().flip(0)
        mask = rast["mask"][0].detach().bool() & obs_mask

        # Integer texel coordinates. Clamp so that uv == 1.0 (or slightly outside
        # [0, 1]) cannot produce an out-of-range or negative scatter index.
        uv_map = (uv_map * texture_size).floor().long().clamp(0, texture_size - 1)
        obs = observation[mask]
        uv_map = uv_map[mask]
        # Texture rows run top-down, so the v coordinate is inverted.
        idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
        texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
        texture_weights = texture_weights.scatter_add(
            0,
            idx,
            torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device),
        )

    filled = texture_weights > 0
    texture[filled] /= texture_weights[filled][:, None]
    texture = np.clip(
        texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)

    # Inpaint texels that no observation reached.
    empty = (
        (texture_weights == 0)
        .cpu()
        .numpy()
        .astype(np.uint8)
        .reshape(texture_size, texture_size)
    )
    return cv2.inpaint(texture, empty, 3, cv2.INPAINT_TELEA)


def _bake_texture_opt(
    rastctx, vertices, faces, uvs, observations, masks, views, projections,
    texture_size, lambda_tv, verbose,
):
    """Optimize a texture with an L1 photometric loss (plus optional TV) via nvdiffrast sampling."""
    # Flip images to the rasterizer's row order (the result is flipped back at the end).
    observations = [obs.flip(0) for obs in observations]
    masks = [m.flip(0) for m in masks]

    # UVs and their screen-space derivatives are fixed, so rasterize each view once.
    uv_buffers = []
    uv_dr_buffers = []
    for observation, view, projection in tqdm(
        zip(observations, views, projections),
        total=len(views),
        disable=not verbose,
        desc="Texture baking (opt): UV",
    ):
        with torch.no_grad():
            rast = utils3d.torch.rasterize_triangle_faces(
                rastctx,
                vertices[None],
                faces,
                observation.shape[1],
                observation.shape[0],
                uv=uvs[None],
                view=view,
                projection=projection,
            )
        uv_buffers.append(rast["uv"].detach())
        uv_dr_buffers.append(rast["uv_dr"].detach())

    texture = torch.nn.Parameter(
        torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()
    )
    optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

    def cosine_anealing(step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (
            1 + np.cos(np.pi * step / total_steps)
        )

    def tv_loss(texture):
        # Anisotropic total variation along both image axes.
        return torch.nn.functional.l1_loss(
            texture[:, :-1, :, :], texture[:, 1:, :, :]
        ) + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

    total_steps = 2500
    with tqdm(
        total=total_steps,
        disable=not verbose,
        desc="Texture baking (opt): optimizing",
    ) as pbar:
        for step in range(total_steps):
            optimizer.zero_grad()
            selected = np.random.randint(0, len(views))
            uv, uv_dr, observation, mask = (
                uv_buffers[selected],
                uv_dr_buffers[selected],
                observations[selected],
                masks[selected],
            )
            render = dr.texture(texture, uv, uv_dr)[0]
            loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
            if lambda_tv > 0:
                loss += lambda_tv * tv_loss(texture)
            loss.backward()
            optimizer.step()
            optimizer.param_groups[0]["lr"] = cosine_anealing(
                step, total_steps, 1e-2, 1e-5
            )
            pbar.set_postfix({"loss": loss.item()})
            pbar.update()

    texture = np.clip(
        texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
    ).astype(np.uint8)
    # Texels not covered by any triangle in UV space get inpainted.
    mask = 1 - utils3d.torch.rasterize_triangle_faces(
        rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
    )["mask"][0].detach().cpu().numpy().astype(np.uint8)
    return cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)


def bake_texture(
    vertices: np.array,
    faces: np.array,
    uvs: np.array,
    observations: List[np.array],
    masks: List[np.array],
    extrinsics: List[np.array],
    intrinsics: List[np.array],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal["fast", "opt"] = "opt",
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.array]): One image per view, shape (H, W, 3),
            with values scaled by 1/255 internally (so presumably 0-255).
        masks (List[np.array]): One mask per view (same order as
            ``observations``), shape (H, W); pixels with value > 0 are used.
        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): ``'fast'`` averages observed colors per
            texel; ``'opt'`` optimizes the texture for 2500 Adam steps.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.

    Returns:
        np.ndarray: uint8 texture of shape (texture_size, texture_size, 3).

    Raises:
        ValueError: If ``mode`` is not ``'fast'`` or ``'opt'``. This is checked
            before any data is uploaded to the GPU.
    """
    if mode not in ("fast", "opt"):
        raise ValueError(f"Unknown mode: {mode}")

    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [
        utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda())
        for extr in extrinsics
    ]
    projections = [
        utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far)
        for intr in intrinsics
    ]

    rastctx = utils3d.torch.RastContext(backend="cuda")
    if mode == "fast":
        return _bake_texture_fast(
            rastctx, vertices, faces, uvs, observations, masks, views, projections,
            texture_size, verbose,
        )
    return _bake_texture_opt(
        rastctx, vertices, faces, uvs, observations, masks, views, projections,
        texture_size, lambda_tv, verbose,
    )
|
|
|
|
|
def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Turn a generated asset into a textured trimesh ready for GLB export.

    Steps: clean the extracted mesh (decimation + hole filling), UV-unwrap it,
    render the appearance representation from 100 views at 1024px, bake those
    renders into a texture (optimization mode), and wrap the result in a mesh
    with a PBR base-color material.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Forwarded to mesh postprocessing for debug output.
        verbose (bool): Whether to print progress.

    Returns:
        trimesh.Trimesh: Textured mesh whose vertices are remapped
        (x, y, z) -> (x, z, -y).
    """
    mesh_vertices = mesh.vertices.cpu().numpy()
    mesh_faces = mesh.faces.cpu().numpy()

    # Clean up geometry. The hole-edge budget shrinks as simplification grows.
    mesh_vertices, mesh_faces = postprocess_mesh(
        mesh_vertices,
        mesh_faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    mesh_vertices, mesh_faces, uvs = parametrize_mesh(mesh_vertices, mesh_faces)

    # Reference renders of the appearance representation for baking.
    observations, extrinsics, intrinsics = render_multiview(
        app_rep, resolution=1024, nviews=100
    )
    masks = [np.any(obs > 0, axis=-1) for obs in observations]
    extrinsics = [extr.cpu().numpy() for extr in extrinsics]
    intrinsics = [intr.cpu().numpy() for intr in intrinsics]
    baked = bake_texture(
        mesh_vertices,
        mesh_faces,
        uvs,
        observations,
        masks,
        extrinsics,
        intrinsics,
        texture_size=texture_size,
        mode="opt",
        lambda_tv=0.01,
        verbose=verbose,
    )
    texture_image = Image.fromarray(baked)

    # Row-vector product: (x, y, z) -> (x, z, -y).
    axis_swap = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    rotated_vertices = mesh_vertices @ axis_swap

    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture_image,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
    )
    visuals = trimesh.visual.TextureVisuals(uv=uvs, material=material)
    return trimesh.Trimesh(rotated_vertices, mesh_faces, visual=visuals)
|
|
|
|
|
def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation for the unsatisfactory performance.

    The input is rendered from 100 views. A copy is then optimized for 2500 steps
    against those renders with an extra penalty (weight 1e-7) that drives all but
    the top ``(1 - simplify)`` fraction of opacities toward zero; the auxiliary
    ``_zeta``/``_lambda`` targets are refreshed every 10 steps. Every 100 steps,
    Gaussians with opacity <= 0.05 are pruned together with their Adam state.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
            Values <= 0 return ``gs`` unchanged.
        verbose (bool): Whether to show a progress bar.

    Returns:
        Gaussian: A new Gaussian whose xyz/rotation/scaling/opacity are plain
        tensors (not ``nn.Parameter``), or ``gs`` itself if ``simplify <= 0``.
    """
    if simplify <= 0:
        return gs

    total_steps = 2500
    prune_interval = 100  # steps between opacity-based pruning
    _interval = 10  # steps between _zeta/_lambda updates
    _delta = 1e-7  # weight of the sparsity penalty

    # Reference renders of the original Gaussians.
    observations, extrinsics, intrinsics = render_multiview(
        gs, resolution=1024, nviews=100
    )
    observations = [
        torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1)
        for obs in observations
    ]

    renderer = GaussianRenderer(
        {
            "resolution": 1024,
            "near": 0.8,
            "far": 1.6,
            "ssaa": 1,
            "bg_color": (0, 0, 0),
        }
    )
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = (
        gs._features_rest.clone() if gs._features_rest is not None else None
    )
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    # Per-group learning rates, in the order xyz, rotation, scaling, opacity.
    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam(
        [
            {"params": new_gs._xyz, "lr": start_lr[0]},
            {"params": new_gs._rotation, "lr": start_lr[1]},
            {"params": new_gs._scaling, "lr": start_lr[2]},
            {"params": new_gs._opacity, "lr": start_lr[3]},
        ],
        lr=start_lr[0],
    )

    def cosine_anealing(step, total, start, end):
        return end + 0.5 * (start - end) * (1 + np.cos(np.pi * step / total))

    # reshape(-1) instead of squeeze(): keeps a 1-D vector even for a single Gaussian.
    _zeta = new_gs.get_opacity.clone().detach().reshape(-1)
    _lambda = torch.zeros_like(_zeta)
    num_target = int((1 - simplify) * _zeta.shape[0])

    with tqdm(total=total_steps, disable=not verbose, desc="Simplifying Gaussian") as pbar:
        for i in range(total_steps):
            # Periodically drop nearly transparent Gaussians.
            if i % prune_interval == 0:
                # reshape(-1): nonzero(...).squeeze() would become 0-d (and drop
                # a dimension when indexing) if exactly one Gaussian survived.
                keep = torch.nonzero(new_gs.get_opacity.reshape(-1) > 0.05).reshape(-1)
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[keep])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[keep])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[keep])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[keep])
                new_gs._features_dc = new_gs._features_dc[keep]
                new_gs._features_rest = (
                    new_gs._features_rest[keep]
                    if new_gs._features_rest is not None
                    else None
                )
                _zeta = _zeta[keep]
                _lambda = _lambda[keep]

                # Move the Adam moments over to the new (pruned) parameters.
                for param_group, new_param in zip(
                    optimizer.param_groups,
                    [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity],
                ):
                    old_param = param_group["params"][0]
                    stored_state = optimizer.state[old_param]
                    if "exp_avg" in stored_state:
                        stored_state["exp_avg"] = stored_state["exp_avg"][keep]
                        stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][keep]
                    del optimizer.state[old_param]
                    param_group["params"][0] = new_param
                    optimizer.state[new_param] = stored_state

            opacity = new_gs.get_opacity.reshape(-1)

            # Refresh the sparse target: keep only the num_target largest entries.
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)["color"]
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + _delta * torch.sum(
                torch.pow(_lambda + opacity - _zeta, 2)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for j, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = cosine_anealing(
                    i, total_steps, start_lr[j], end_lr[j]
                )

            pbar.set_postfix(
                {
                    "loss": rgb_loss.item(),
                    "num": opacity.shape[0],
                    "lambda": _lambda.mean().item(),
                }
            )
            pbar.update()

    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs
|
|