import logging
import math
from typing import Union

import custom_rasterizer as cr
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image

from asset3d_gen.data.utils import (
    get_images_from_file,
    normalize_vertices_array,
    post_process_texture,
    save_mesh_with_mtl,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


__all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]


def get_perspective_projection(
    fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
) -> np.ndarray:
    """Compute the perspective projection matrix for 3D rendering."""
    fov_rad = math.radians(fov)
    tan_half_fov = math.tan(fov_rad / 2.0)

    return np.array(
        [
            [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
            [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
            [
                0.0,
                0.0,
                -(far + near) / (far - near),
                -(2.0 * far * near) / (far - near),
            ],
            [0.0, 0.0, -1.0, 0.0],
        ],
        dtype=np.float32,
    )
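
# Quick sanity check (illustrative sketch only, not used by the pipeline):
# for a square aspect ratio the [1, 1] entry is the focal term 1 / tan(fov/2),
# e.g. a 30 degree FOV gives roughly 3.732, and row 3 carries the -1 used for
# the perspective divide.
#
#   proj = get_perspective_projection(fov=30.0, aspect_wh=1.0)
#   assert abs(proj[1, 1] - 1.0 / math.tan(math.radians(15.0))) < 1e-5
#   assert proj[3, 2] == -1.0
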
def transform_vertices(
    mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
    """Transform 3D vertices using a projection matrix."""
    t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)

    # Promote xyz to homogeneous coordinates if needed.
    if pos.size(-1) == 3:
        pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)

    result = pos @ t_mtx.T

    return result if keepdim else result.unsqueeze(0)


def compute_w2c_matrix(
    elev_deg: float, azim_deg: float, cam_dist: float
) -> np.ndarray:
    """Compute w2c 4x4 transformation matrix from spherical coordinates."""
    elev_rad = math.radians(-elev_deg)
    azim_rad = math.radians(azim_deg)

    sin_elev = math.sin(elev_rad)
    cos_elev = math.cos(elev_rad)
    sin_azim = math.sin(azim_rad)
    cos_azim = math.cos(azim_rad)

    cam_pos = np.array(
        [
            cam_dist * cos_elev * cos_azim,
            cam_dist * cos_elev * sin_azim,
            cam_dist * sin_elev,
        ]
    )

    look_dir = -cam_pos / np.linalg.norm(cam_pos)
    right_dir = np.cross(look_dir, [0, 0, 1])
    right_dir /= np.linalg.norm(right_dir)
    up_dir = np.cross(right_dir, look_dir)

    c2w = np.eye(4)
    c2w[:3, 0] = right_dir
    c2w[:3, 1] = up_dir
    c2w[:3, 2] = -look_dir
    c2w[:3, 3] = cam_pos

    try:
        w2c = np.linalg.inv(c2w)
    except np.linalg.LinAlgError as e:
        raise ArithmeticError("Failed to invert camera-to-world matrix") from e

    return w2c.astype(np.float32)


def _bilinear_interpolation_scattering(
    image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
    """Bilinear interpolation scattering for grid-based value accumulation."""
    device = values.device
    dtype = values.dtype
    C = values.shape[-1]

    indices = coords * torch.tensor(
        [image_h - 1, image_w - 1], dtype=dtype, device=device
    )
    i, j = indices.unbind(-1)

    # Clamp row and column indices independently so non-square grids are
    # handled correctly.
    i0 = i.floor().long().clamp(0, image_h - 2)
    j0 = j.floor().long().clamp(0, image_w - 2)
    i1, j1 = i0 + 1, j0 + 1

    w_i = i - i0.float()
    w_j = j - j0.float()
    weights = torch.stack(
        [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
        dim=1,
    )

    indices_comb = torch.stack(
        [
            torch.stack([i0, j0], dim=1),
            torch.stack([i0, j1], dim=1),
            torch.stack([i1, j0], dim=1),
            torch.stack([i1, j1], dim=1),
        ],
        dim=1,
    )

    grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
    cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)

    for k in range(4):
        idx = indices_comb[:, k]
        w = weights[:, k].unsqueeze(-1)

        stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
        flat_idx = (idx * stride).sum(-1)

        grid.view(-1, C).scatter_add_(
            0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
        )
        cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)

    mask = cnt.squeeze(-1) > 0
    grid[mask] = grid[mask] / cnt[mask].repeat(1, C)

    return grid
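
# Illustrative usage sketch for the scattering helper (assumed shapes, not
# part of the pipeline): `coords` holds normalized (row, col) positions in
# [0, 1] and `values` holds one feature vector per coordinate; overlapping
# splats are averaged by the accumulated bilinear weights.
#
#   coords = torch.tensor([[0.25, 0.50], [0.75, 0.10]])
#   values = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
#   grid = _bilinear_interpolation_scattering(64, 64, coords, values)
#   # grid.shape == (64, 64, 3); cells that received no splat stay zero.
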
def _texture_inpaint_smooth(
    texture: np.ndarray,
    mask: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform texture inpainting using vertex-based color propagation."""
    image_h, image_w, C = texture.shape
    N = vertices.shape[0]

    # Initialize vertex data structures
    vtx_mask = np.zeros(N, dtype=np.float32)
    vtx_colors = np.zeros((N, C), dtype=np.float32)
    unprocessed = []
    adjacency = [[] for _ in range(N)]

    # Build adjacency graph and initial color assignment
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            uv_idx_k = faces[face_idx, k]
            v_idx = faces[face_idx, k]

            # Convert UV to pixel coordinates with boundary clamping
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            if mask[v, u]:
                vtx_mask[v_idx] = 1.0
                vtx_colors[v_idx] = texture[v, u]
            elif v_idx not in unprocessed:
                unprocessed.append(v_idx)

            # Build undirected adjacency graph
            neighbor = faces[face_idx, (k + 1) % 3]
            if neighbor not in adjacency[v_idx]:
                adjacency[v_idx].append(neighbor)
            if v_idx not in adjacency[neighbor]:
                adjacency[neighbor].append(v_idx)

    # Color propagation with dynamic stopping
    remaining_iters, prev_count = 2, 0
    while remaining_iters > 0:
        current_unprocessed = []

        for v_idx in unprocessed:
            valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
            if not valid_neighbors:
                current_unprocessed.append(v_idx)
                continue

            # Calculate inverse square distance weights
            neighbors_pos = vertices[valid_neighbors]
            dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
            weights = 1 / np.maximum(dist_sq, 1e-8)

            vtx_colors[v_idx] = np.average(
                vtx_colors[valid_neighbors], weights=weights, axis=0
            )
            vtx_mask[v_idx] = 1.0

        # Update iteration control
        if len(current_unprocessed) == prev_count:
            remaining_iters -= 1
        else:
            remaining_iters = min(remaining_iters + 1, 2)
        prev_count = len(current_unprocessed)
        unprocessed = current_unprocessed

    # Generate output texture
    inpainted_texture, updated_mask = texture.copy(), mask.copy()
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            v_idx = faces[face_idx, k]
            if not vtx_mask[v_idx]:
                continue

            # UV coordinate conversion
            uv_idx_k = faces[face_idx, k]
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            inpainted_texture[v, u] = vtx_colors[v_idx]
            updated_mask[v, u] = 255

    return inpainted_texture, updated_mask
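
# Worked example of the propagation weights above (illustrative only): an
# uncolored vertex with two already-colored neighbors at distances 1.0 and 2.0
# receives weights 1 / 1.0**2 = 1.0 and 1 / 2.0**2 = 0.25, so the closer
# neighbor contributes 80% of the propagated color (1.0 / 1.25).
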
class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion."""

    def __init__(
        self,
        camera_elevs: list[float],
        camera_azims: list[float],
        camera_distance: int,
        camera_fov: float,
        view_weights: list[float] = None,
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        use_antialias: bool = True,
        bake_angle_thres: int = 75,
        device="cuda",
    ):
        self.camera_elevs = camera_elevs
        self.camera_azims = camera_azims
        self.view_weights = (
            view_weights
            if view_weights is not None
            else [1] * len(camera_elevs)
        )
        self.device = device
        self.render_wh = render_wh
        self.texture_wh = texture_wh

        self.camera_distance = camera_distance
        self.use_antialias = use_antialias
        self.bake_angle_thres = bake_angle_thres
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )
        self.camera_proj_mat = get_perspective_projection(
            camera_fov,
            self.render_wh[1] / self.render_wh[0],
        )
        # Per-view counter used to index the debug image dumps below.
        self.cnt = 0

    def rasterize_mesh(
        self,
        vertex: torch.Tensor,
        face: torch.Tensor,
        resolution: tuple[int, int],
    ) -> torch.Tensor:
        vertex = vertex[None] if vertex.ndim == 2 else vertex
        indices, weights = cr.rasterize(vertex, face, resolution)

        return torch.cat(
            [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
        ).unsqueeze(0)

    def raster_interpolate(
        self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        barycentric = rast_out[0, ..., :-1]
        findices = rast_out[0, ..., -1]
        if uv.dim() == 2:
            uv = uv.unsqueeze(0)

        return cr.interpolate(uv, findices, barycentric, faces)[0]

    def load_mesh(self, mesh_path: str) -> None:
        mesh = trimesh.load(mesh_path)
        if isinstance(mesh, trimesh.Scene):
            mesh = mesh.dump(concatenate=True)

        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center

        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        self.vertices = (
            torch.from_numpy(mesh.vertices).to(self.device).float()
        )
        self.faces = (
            torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        )
        self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        # Transformation of coordinate system
        self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
        self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
        self.uv_map[:, 1] = 1 - self.uv_map[:, 1]

    def get_mesh_attrs(
        self,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        vertices = self.vertices.cpu().numpy()
        faces = self.faces.cpu().numpy()
        uv_map = self.uv_map.cpu().numpy()

        # Inverse transformation of coordinate system
        vertices[:, [1, 2]] = vertices[:, [2, 1]]
        vertices[:, [0, 1]] = -vertices[:, [0, 1]]
        uv_map[:, 1] = 1.0 - uv_map[:, 1]

        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        combined_edges = depth_edges
        sketch_image = (
            torch.from_numpy(combined_edges).to(depth_image.device).float()
            / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def back_project(
        self, image: Image.Image, elev: float, azim: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = torch.as_tensor(
            image, device=self.device, dtype=torch.float32
        )
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255.0

        view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
        pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
        pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
        pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]

        v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
        face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
        vertex_norm = (
            torch.from_numpy(
                trimesh.geometry.mean_vertex_normals(
                    len(pos_cam), self.faces.cpu(), face_norm.cpu()
                )
            )
            .to(self.device)
            .contiguous()
        )

        rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
        vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]

        interp_data = {
            "normal": self.raster_interpolate(
                vertex_norm[None], rast_out, self.faces
            ),
            "uv": self.raster_interpolate(
                self.uv_map[None], rast_out, self.faces
            ),
            "depth": self.raster_interpolate(
                pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
            ),
        }

        valid_depth = interp_data["depth"][vis_mask > 0]
        depth_norm = (interp_data["depth"] - valid_depth.min()) / (
            valid_depth.max() - valid_depth.min()
        )
        sketch_image = self._render_depth_edges(depth_norm * vis_mask)

        # Debug dumps of the per-view rasterization buffers.
        cv2.imwrite(
            f"v2_vis_mask{self.cnt}.png",
            (vis_mask.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_normal{self.cnt}.png",
            (interp_data["normal"].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_depth{self.cnt}.png",
            (depth_norm.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_uv{self.cnt}.png",
            (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_sketch{self.cnt}.png",
            (sketch_image.cpu().numpy() * 255).astype(np.uint8),
        )

        # View-angle confidence: cosine between interpolated normals and the
        # camera view direction, zeroed beyond the bake angle threshold.
        cos = F.cosine_similarity(
            torch.tensor([[0.0, 0.0, -1.0]], device=self.device),
            interp_data["normal"].view(-1, 3),
        ).view_as(interp_data["normal"][..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thres))] = 0

        cv2.imwrite(
            f"v2_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8)
        )

        # Erode the visibility mask and dilate the depth-edge sketch to drop
        # unreliable pixels near silhouettes and depth discontinuities.
        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        cv2.imwrite(
            f"v2_db_sketch{self.cnt}.png",
            (sketch_image.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_db_uv{self.cnt}.png",
            (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_db_uv2{self.cnt}.png",
            (interp_data["uv"][..., 1].cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_db_color{self.cnt}.png",
            (image.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_db_cos{self.cnt}.png",
            (cos.cpu().numpy() * 255).astype(np.uint8),
        )
        cv2.imwrite(
            f"v2_db_mask{self.cnt}.png",
            (vis_mask.cpu().numpy() * 255).astype(np.uint8),
        )
        # Increment once per view so all debug files of a view share an index.
        self.cnt += 1

        return (
            self._scatter_texture(interp_data["uv"], image, valid_pixels),
            self._scatter_texture(interp_data["uv"], cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        def __filter_data(data, mask):
            return data.view(-1, data.shape[-1])[mask]

        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            __filter_data(uv, mask)[..., [1, 0]],
            __filter_data(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self,
        textures: list[torch.Tensor],
        confidence_maps: list[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        channel = textures[0].shape[-1]
        texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
            self.device
        )
        trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            # Skip views that add almost no newly covered texels.
            if painted_sum / view_sum > 0.99:
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, texture: torch.Tensor, mask: torch.Tensor
    ) -> np.ndarray:
        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
        vertices, faces, uv_map = self.get_mesh_attrs()

        texture_np, mask_np = _texture_inpaint_smooth(
            texture_np, mask_np, vertices, faces, uv_map
        )
        texture_np = texture_np.clip(0, 1)
        texture_np = cv2.inpaint(
            (texture_np * 255).astype(np.uint8),
            255 - mask_np,
            3,
            cv2.INPAINT_NS,
        )

        return texture_np

    def __call__(
        self, colors: list[Image.Image], input_mesh: str, output_path: str
    ) -> trimesh.Trimesh:
        self.load_mesh(input_mesh)

        textures, weighted_cos_maps = [], []
        for color, cam_elev, cam_azim, weight in zip(
            colors, self.camera_elevs, self.camera_azims, self.view_weights
        ):
            texture, cos_map = self.back_project(color, cam_elev, cam_azim)
            cv2.imwrite(
                f"v2_texture{self.cnt}.png",
                (texture.cpu().numpy() * 255).astype(np.uint8),
            )
            cv2.imwrite(
                f"v2_texture_cos{self.cnt}.png",
                (cos_map.cpu().numpy() * 255).astype(np.uint8),
            )
            textures.append(texture)
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
        texture_np = self.uv_inpaint(texture, mask)
        texture_np = post_process_texture(texture_np)
        vertices, faces, uvs = self.get_mesh_attrs(self.scale, self.center)

        cv2.imwrite("v2_texture_np.png", texture_np)
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uvs, texture_np, output_path
        )

        return textured_mesh
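
# Fusion rule implemented by `__call__` + `fast_bake_texture`, summarized for
# reference: each view i contributes its back-projected texture T_i weighted
# by view_weight_i * cos_i**4, and the merged texel is
#   sum_i(T_i * w_i * cos_i**4) / sum_i(w_i * cos_i**4),
# so grazing-angle observations (small cos) are strongly down-weighted.
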
class Image_Super_Net:
    def __init__(self, device="cuda"):
        from diffusers import StableDiffusionUpscalePipeline

        self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
            "stabilityai/stable-diffusion-x4-upscaler",
            torch_dtype=torch.float16,
        ).to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)

    def __call__(self, image, prompt=""):
        with torch.no_grad():
            upscaled_image = self.up_pipeline_x4(
                prompt=[prompt],
                image=image,
                num_inference_steps=10,
            ).images[0]

        return upscaled_image


class Image_GANNet:
    def __init__(self, outscale: int):
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        self.outscale = outscale
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth",  # noqa
            model=model,
            pre_pad=0,
            half=True,
        )

    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        if isinstance(image, Image.Image):
            image = np.array(image)
        output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
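
# Minimal usage sketch for the upscalers (the input path is an assumption and
# the referenced weights must exist locally):
#
#   sr = Image_GANNet(outscale=4)            # Real-ESRGAN x4
#   hi_res = sr(Image.open("view.png"))      # PIL.Image in, PIL.Image out
#
#   # or, diffusion-based:
#   sr = Image_Super_Net(device="cuda")
#   hi_res = sr(Image.open("view.png"), prompt="")
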
if __name__ == "__main__":
    device = "cuda"

    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
    mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
    output_path = "robot_test_v2/robot.obj"
    target_image_size = (2048, 2048)

    super_model = Image_GANNet(outscale=4)
    multiviews = get_images_from_file(color_path, img_size=512)

    texture_backer = TextureBacker(
        camera_elevs=[20, 20, 20, -10, -10, -10],
        camera_azims=[-180, -60, 60, -120, 0, 120],
        view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
        camera_distance=5,
        camera_fov=30,
        render_wh=(2048, 2048),
        texture_wh=(2048, 2048),
    )

    multiviews = [super_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]
    textured_mesh = texture_backer(multiviews, mesh_path, output_path)