# ImgRoboAssetGen/asset3d_gen/data/backup/backproject_v2 copy.py
import argparse
import logging
import math
import os
import time
import cv2
import numpy as np
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as tF
import trimesh
import xatlas
from PIL import Image
from asset3d_gen.data.mesh_operator import MeshFixer
from asset3d_gen.data.utils import (
CameraSetting,
DiffrastRender,
get_images_from_grid,
init_kal_camera,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from asset3d_gen.models.delight_model import DelightingModel
from asset3d_gen.models.sr_model import ImageRealESRGAN
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = [
"TextureBacker",
]
def transform_vertices(
mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
"""Transform 3D vertices using a projection matrix."""
t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
if pos.size(-1) == 3:
pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
result = pos @ t_mtx.T
return result if keepdim else result.unsqueeze(0)
def _bilinear_interpolation_scattering(
image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
"""Bilinear interpolation scattering for grid-based value accumulation."""
device = values.device
dtype = values.dtype
C = values.shape[-1]
indices = coords * torch.tensor(
[image_h - 1, image_w - 1], dtype=dtype, device=device
)
i, j = indices.unbind(-1)
    i0, j0 = indices.floor().long().unbind(-1)
    # Clamp each axis independently so the +1 neighbor stays in bounds.
    i0 = i0.clamp(0, image_h - 2)
    j0 = j0.clamp(0, image_w - 2)
i1, j1 = i0 + 1, j0 + 1
w_i = i - i0.float()
w_j = j - j0.float()
weights = torch.stack(
[(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
dim=1,
)
indices_comb = torch.stack(
[
torch.stack([i0, j0], dim=1),
torch.stack([i0, j1], dim=1),
torch.stack([i1, j0], dim=1),
torch.stack([i1, j1], dim=1),
],
dim=1,
)
grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
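    # Scatter each value into its four neighboring texels with bilinear weights,
    # accumulating weighted values and total weight for later normalization.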
    stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
    for k in range(4):
        idx = indices_comb[:, k]
        w = weights[:, k].unsqueeze(-1)
        flat_idx = (idx * stride).sum(-1)
grid.view(-1, C).scatter_add_(
0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
)
cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
mask = cnt.squeeze(-1) > 0
    grid[mask] = grid[mask] / cnt[mask]
return grid
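# Minimal usage sketch for the scattering helper (illustrative values only):
#   coords = torch.tensor([[0.5, 0.5]])       # normalized (row, col) in [0, 1]
#   values = torch.tensor([[1.0, 0.0, 0.0]])  # one red sample
#   grid = _bilinear_interpolation_scattering(4, 4, coords, values)
#   # The sample is split across the four texels surrounding (1.5, 1.5).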
def _texture_inpaint_smooth(
texture: np.ndarray,
mask: np.ndarray,
vertices: np.ndarray,
faces: np.ndarray,
uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""Perform texture inpainting using vertex-based color propagation."""
image_h, image_w, C = texture.shape
N = vertices.shape[0]
# Initialize vertex data structures
vtx_mask = np.zeros(N, dtype=np.float32)
vtx_colors = np.zeros((N, C), dtype=np.float32)
unprocessed = []
adjacency = [[] for _ in range(N)]
# Build adjacency graph and initial color assignment
for face_idx in range(faces.shape[0]):
        for k in range(3):
            # Vertices and UVs share the same index after xatlas parametrization.
            v_idx = faces[face_idx, k]
            # Convert UV to pixel coordinates with boundary clamping
            u = np.clip(
                int(round(uv_map[v_idx, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[v_idx, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )
if mask[v, u]:
vtx_mask[v_idx] = 1.0
vtx_colors[v_idx] = texture[v, u]
elif v_idx not in unprocessed:
unprocessed.append(v_idx)
# Build undirected adjacency graph
neighbor = faces[face_idx, (k + 1) % 3]
if neighbor not in adjacency[v_idx]:
adjacency[v_idx].append(neighbor)
if v_idx not in adjacency[neighbor]:
adjacency[neighbor].append(v_idx)
# Color propagation with dynamic stopping
remaining_iters, prev_count = 2, 0
while remaining_iters > 0:
current_unprocessed = []
for v_idx in unprocessed:
valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
if not valid_neighbors:
current_unprocessed.append(v_idx)
continue
# Calculate inverse square distance weights
neighbors_pos = vertices[valid_neighbors]
dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
weights = 1 / np.maximum(dist_sq, 1e-8)
vtx_colors[v_idx] = np.average(
vtx_colors[valid_neighbors], weights=weights, axis=0
)
vtx_mask[v_idx] = 1.0
# Update iteration control
if len(current_unprocessed) == prev_count:
remaining_iters -= 1
else:
remaining_iters = min(remaining_iters + 1, 2)
prev_count = len(current_unprocessed)
unprocessed = current_unprocessed
# Generate output texture
inpainted_texture, updated_mask = texture.copy(), mask.copy()
for face_idx in range(faces.shape[0]):
for k in range(3):
v_idx = faces[face_idx, k]
if not vtx_mask[v_idx]:
continue
            # Convert UV to pixel coordinates (same indexing as above)
            u = np.clip(
                int(round(uv_map[v_idx, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[v_idx, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )
inpainted_texture[v, u] = vtx_colors[v_idx]
updated_mask[v, u] = 255
return inpainted_texture, updated_mask
def interp_tensors(
    tensors: list[torch.Tensor], target_wh: tuple[int, int]
) -> list[torch.Tensor]:
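    """Resize each (H, W, C) tensor so its spatial size matches target_wh = (W, H)."""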
for idx in range(len(tensors)):
tensor = tensors[idx].permute(2, 0, 1)
tensor = tF.resize(tensor, target_wh[::-1], antialias=True)
tensors[idx] = tensor.permute(1, 2, 0)
return tensors
class TextureBacker:
"""Texture baking pipeline for multi-view projection and fusion."""
def __init__(
self,
camera_params: CameraSetting,
view_weights: list[float],
render_wh: tuple[int, int] = (2048, 2048),
texture_wh: tuple[int, int] = (2048, 2048),
bake_angle_thresh: int = 75,
mask_thresh: float = 0.5,
):
camera = init_kal_camera(camera_params)
        mv = camera.view_matrix()  # (n, 4, 4) world-to-camera matrices.
        p = camera.intrinsics.projection_matrix()
        # NOTE: negate P[1, 1] because the y axis is flipped in `nvdiffrast` output. # noqa
        p[:, 1, 1] = -p[:, 1, 1]
self.renderer = DiffrastRender(
p_matrix=p,
mv_matrix=mv,
resolution_hw=camera_params.resolution_hw,
context=dr.RasterizeCudaContext(),
mask_thresh=mask_thresh,
grad_db=False,
device=camera_params.device,
antialias_mask=True,
)
self.camera = camera
self.view_weights = view_weights
self.device = camera_params.device
self.render_wh = render_wh
self.texture_wh = texture_wh
self.bake_angle_thresh = bake_angle_thresh
self.bake_unreliable_kernel_size = int(
(2 / 512) * max(self.render_wh[0], self.render_wh[1])
)
def load_mesh(self, mesh: trimesh.Trimesh) -> None:
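        """Normalize the mesh, UV-unwrap it with xatlas, and upload tensors to the device."""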
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
self.scale, self.center = scale, center
vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
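        # Flip V so the atlas matches the renderer's texture-coordinate convention.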
uvs[:, 1] = 1 - uvs[:, 1]
mesh.vertices = mesh.vertices[vmapping]
mesh.faces = indices
mesh.visual.uv = uvs
self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
def get_mesh_np_attrs(
self,
        scale: float | None = None,
        center: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
vertices = self.vertices.cpu().numpy()
faces = self.faces.cpu().numpy()
uv_map = self.uv_map.cpu().numpy()
uv_map[:, 1] = 1.0 - uv_map[:, 1]
if scale is not None:
vertices = vertices / scale
if center is not None:
vertices = vertices + center
return vertices, faces, uv_map
def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
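        """Detect depth discontinuities via Canny; returns a float edge map in [0, 1]."""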
depth_image_np = depth_image.cpu().numpy()
depth_image_np = (depth_image_np * 255).astype(np.uint8)
depth_edges = cv2.Canny(depth_image_np, 30, 80)
sketch_image = (
torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
)
sketch_image = sketch_image.unsqueeze(-1)
return sketch_image
def compute_enhanced_viewnormal(
self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
) -> torch.Tensor:
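        """Render camera-space normal maps of the mesh for each model-view matrix."""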
rast, _ = self.renderer.compute_dr_raster(vertices, faces)
rendered_view_normals = []
for idx in range(len(mv_mtx)):
pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
face_norm = F.normalize(
torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
)
vertex_norm = (
torch.from_numpy(
trimesh.geometry.mean_vertex_normals(
len(pos_cam), faces.cpu(), face_norm.cpu()
)
)
.to(vertices.device)
.contiguous()
)
im_base_normals, _ = dr.interpolate(
vertex_norm[None, ...].float(),
rast[idx : idx + 1],
faces.to(torch.int32),
)
rendered_view_normals.append(im_base_normals)
rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
return rendered_view_normals
def back_project(
self, image, vis_mask, depth, normal, uv
) -> tuple[torch.Tensor, torch.Tensor]:
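        """Back-project one view's image into UV space; returns (texture, cosine-confidence) maps."""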
image = np.array(image)
image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
if image.ndim == 2:
image = image.unsqueeze(-1)
image = image / 255
depth_inv = (1.0 - depth) * vis_mask
sketch_image = self._render_depth_edges(depth_inv)
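        # Confidence is the cosine similarity between each camera-space normal
        # and the reference axis [0, 0, 1].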
        cos = F.cosine_similarity(
            torch.tensor([[0.0, 0.0, 1.0]], device=self.device),
normal.view(-1, 3),
).view_as(normal[..., :1])
cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
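        # Erode the visibility mask and dilate depth edges so that pixels near
        # silhouettes and depth discontinuities are excluded from baking.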
k = self.bake_unreliable_kernel_size * 2 + 1
kernel = torch.ones((1, 1, k, k), device=self.device)
vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
vis_mask = F.conv2d(
1.0 - vis_mask,
kernel,
padding=k // 2,
)
vis_mask = 1.0 - (vis_mask > 0).float()
vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
sketch_image = (sketch_image > 0).float()
sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
vis_mask = vis_mask * (sketch_image < 0.5)
cos[vis_mask == 0] = 0
valid_pixels = (vis_mask != 0).view(-1)
return (
self._scatter_texture(uv, image, valid_pixels),
self._scatter_texture(uv, cos, valid_pixels),
)
def _scatter_texture(self, uv, data, mask):
def __filter_data(data, mask):
return data.view(-1, data.shape[-1])[mask]
return _bilinear_interpolation_scattering(
self.texture_wh[1],
self.texture_wh[0],
__filter_data(uv, mask)[..., [1, 0]],
__filter_data(data, mask),
)
@torch.no_grad()
def fast_bake_texture(
self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
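        """Fuse per-view UV textures into one by confidence-weighted averaging."""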
        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(
            self.texture_wh[1], self.texture_wh[0], channel, device=self.device
        )
        trust_map_merge = torch.zeros(
            self.texture_wh[1], self.texture_wh[0], 1, device=self.device
        )
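        # Merge views in order, skipping any view whose contributing texels are
        # already almost fully (>99%) covered by previously merged views.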
for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if view_sum == 0 or painted_sum / view_sum > 0.99:
                continue
texture_merge += texture * cos_map
trust_map_merge += cos_map
texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
return texture_merge, trust_map_merge > 1e-8
def uv_inpaint(
self, texture: torch.Tensor, mask: torch.Tensor
) -> np.ndarray:
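        """Fill unpainted texels by vertex-color propagation, then cv2 Navier-Stokes inpainting."""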
texture_np = texture.cpu().numpy()
mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
vertices, faces, uv_map = self.get_mesh_np_attrs()
texture_np, mask_np = _texture_inpaint_smooth(
texture_np, mask_np, vertices, faces, uv_map
)
texture_np = texture_np.clip(0, 1)
texture_np = cv2.inpaint(
(texture_np * 255).astype(np.uint8),
255 - mask_np,
3,
cv2.INPAINT_NS,
)
return texture_np
def __call__(
self,
colors: list[Image.Image],
mesh: trimesh.Trimesh,
output_path: str,
) -> trimesh.Trimesh:
        start = time.time()
        self.load_mesh(mesh)
        logger.info(f"load_mesh: {time.time() - start:.2f}s")
start = time.time()
rendered_depth, masks = self.renderer.render_depth(
self.vertices, self.faces
)
norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
render_uvs, _ = self.renderer.render_uv(
self.vertices, self.faces, self.uv_map
)
view_normals = self.compute_enhanced_viewnormal(
self.renderer.mv_mtx, self.vertices, self.faces
)
print("0", time.time() - start)
textures, weighted_cos_maps = [], []
start = time.time()
for color, mask, dep, normal, uv, weight in zip(
colors,
masks,
norm_deps,
view_normals,
render_uvs,
self.view_weights,
):
            mask, dep, normal, uv = interp_tensors(
                [mask, dep, normal, uv], self.render_wh
            )
texture, cos_map = self.back_project(color, mask, dep, normal, uv)
textures.append(texture)
weighted_cos_maps.append(weight * (cos_map**4))
print("1", time.time() - start)
start = time.time()
texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
print("2", time.time() - start)
start = time.time()
texture_np = self.uv_inpaint(texture, mask)
print("3", time.time() - start)
start = time.time()
texture_np = post_process_texture(texture_np)
vertices, faces, uv_map = self.get_mesh_np_attrs(
self.scale, self.center
)
textured_mesh = save_mesh_with_mtl(
vertices, faces, uv_map, texture_np, output_path
)
print("4", time.time() - start)
return textured_mesh
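# Example usage of TextureBacker (a minimal sketch; the paths and weights are
# placeholders mirroring `entrypoint` below):
#   camera_params = CameraSetting(
#       num_images=6, elevation=[20.0, -10.0], distance=5,
#       resolution_hw=(2048, 2048), fov=math.radians(30), device="cuda",
#   )
#   backer = TextureBacker(camera_params, view_weights=[1, 0.1, 0.02, 0.1, 1, 0.02])
#   mesh = backer(multiview_images, trimesh.load("mesh.obj"), "out/mesh.obj")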
def parse_args():
parser = argparse.ArgumentParser(description="Backproject texture")
parser.add_argument(
"--color_path",
type=str,
help="Multiview color image in 6x512x512 file path",
)
parser.add_argument(
"--mesh_path",
type=str,
help="Mesh path, .obj, .glb or .ply",
)
parser.add_argument(
"--output_path",
type=str,
help="Output mesh path with suffix",
)
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
nargs=2,
type=float,
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(2048, 2048),
help="Resolution of the mesh rendering",
)
parser.add_argument(
"--target_hw",
type=int,
nargs=2,
default=(2048, 2048),
help="Target rendering images resolution",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
help="Device to run on (default: `cuda`)",
)
parser.add_argument(
"--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
)
parser.add_argument(
"--texture_wh",
nargs=2,
type=int,
default=[2048, 2048],
help="Texture resolution width and height",
)
parser.add_argument(
"--mesh_sipmlify_ratio",
type=float,
default=0.9,
help="Mesh simplification ratio (default: 0.9)",
)
parser.add_argument(
"--delight", action="store_true", help="Use delighting model."
)
args = parser.parse_args()
return args
def entrypoint(
    delight_model: DelightingModel | None = None,
    imagesr_model: ImageRealESRGAN | None = None,
**kwargs,
) -> trimesh.Trimesh:
args = parse_args()
for k, v in kwargs.items():
if hasattr(args, k) and v is not None:
setattr(args, k, v)
# Setup camera parameters.
camera_params = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device=args.device,
)
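    # Per-view fusion weights, in the same order as the rendered views
    # (presumably front/back views weighted highest, oblique views lowest).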
view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
color_grid = Image.open(args.color_path)
if args.delight:
if delight_model is None:
delight_model = DelightingModel(
model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
)
save_dir = os.path.dirname(args.output_path)
os.makedirs(save_dir, exist_ok=True)
color_grid.save(f"{save_dir}/color_grid.png")
color_grid = delight_model(color_grid)
color_grid.save(f"{save_dir}/color_grid_delight.png")
multiviews = get_images_from_grid(color_grid, img_size=512)
# Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
if imagesr_model is None:
imagesr_model = ImageRealESRGAN(outscale=4)
multiviews = [imagesr_model(img.convert("RGB")) for img in multiviews]
multiviews = [img.resize(args.target_hw[::-1]) for img in multiviews]
mesh = trimesh.load(args.mesh_path)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
if not args.skip_fix_mesh:
mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
mesh.vertices, mesh.faces = mesh_fixer(
        filter_ratio=args.mesh_simplify_ratio,
max_hole_size=0.04,
resolution=1024,
num_views=1000,
norm_mesh_ratio=0.5,
)
# Restore scale.
mesh.vertices = mesh.vertices / scale
mesh.vertices = mesh.vertices + center
    # Bake the texture onto the mesh.
    start = time.time()
texture_backer = TextureBacker(
camera_params=camera_params,
view_weights=view_weights,
render_wh=args.target_hw,
texture_wh=args.texture_wh,
)
    logger.info(f"TextureBacker init time: {time.time() - start:.2f}s")
start = time.time()
textured_mesh = texture_backer(multiviews, mesh, args.output_path)
print(f"Texture backproject time: {time.time() - start:.2f}s")
return textured_mesh
if __name__ == "__main__":
entrypoint()