"""Mesh container used by the texturing demo.

Loads an .obj/.glb file with trimesh, merges multi-part meshes into a single
vertex/face buffer, (re)parameterizes UVs with xatlas when needed, and offers
helpers to transform, normalize, render and export the mesh.
"""
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
import xatlas
from PIL import Image

from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
                           render_geo_views_tensor, render_views,
                           setup_lights)
from utils.file_utils import save_tensor_to_file


class Mesh:
    """Single merged triangle mesh stored as torch tensors.

    After construction the instance exposes ``v_pos``, ``t_pos_idx``,
    ``v_tex``, ``t_tex_idx`` and ``v_normal``. ``_vmapping`` is set only when
    xatlas generated the UV layout; it maps each UV vertex back to the index
    of the original vertex it was split from.
    """

    def __init__(self, mesh_path=None, uv_tool="xAtlas", device='cuda', progress=gr.Progress()):
        """
        Initialize the Mesh object with a mesh file path.

        :param mesh_path: Path to the mesh file (``.obj`` or ``.glb``; the
            extension check is case-insensitive).
        :param uv_tool: UV parameterization tool; only ``"xAtlas"`` is implemented.
        :param device: Torch device the mesh tensors are moved to.
        :param progress: Callable ``progress(fraction, message)`` used for status updates.
        :raises ValueError: If ``mesh_path`` is None, the format is unsupported,
            the scene is empty, no usable part is found, or ``uv_tool`` is unknown.
        :raises NotImplementedError: If ``uv_tool`` is ``"UVAtlas"``.
        """
        self._device = device
        if mesh_path is None:
            raise ValueError("Mesh path cannot be None.")

        # Every loaded part is staged here until _merge_parts_internal() runs.
        self._parts = {}

        extension_probe = mesh_path.lower()
        if extension_probe.endswith('.obj'):
            progress(0., "Loading mesh in .obj format...")
            loaded = trimesh.load(mesh_path, process=False)
        elif extension_probe.endswith('.glb'):
            progress(0., "Loading mesh in .glb format...")
            loaded = trimesh.load(mesh_path)
        else:
            raise ValueError(f"Unsupported file format: {mesh_path}")
        self._collect_parts(loaded, progress)

        # Automatically merge all parts during initialization
        progress(0.2, "Merging if the mesh have multiple parts.")
        self._merge_parts_internal()

        self.to(self.device)  # Move to the specified device

        # Set by vertex_transform_upsidedown(); export() uses it to fix winding.
        self._upside_down_applied = False

        # UV parameterization. Merged multi-part meshes come out of
        # _merge_parts_internal() without UVs, so they always take this branch.
        if self.has_multi_parts or not self.has_uv:
            progress(0.4, f"Using {uv_tool} for UV parameterization. It may take quite a while (several minutes), if there are many faces. We STRONLY recommend using a mesh with UV parameterization.")
            if uv_tool == "xAtlas":
                self.uv_xatlas_mapping()  # Use default parameters
            elif uv_tool == "UVAtlas":
                raise NotImplementedError("UVAtlas parameterization is not implemented yet.")
            else:
                raise ValueError("Unsupported UV parameterization tool.")
            print("UV parameterization completed.")
        else:
            progress(0.4, "The model has SINGLE UV parameterization, no need to reparameterize.")
            self._vmapping = None  # No vmapping needed when not reparameterizing

    def _collect_parts(self, loaded, progress):
        """Register every usable geometry of a trimesh load result in ``_parts``.

        Handles the three shapes handled by the loader code: a list of meshes,
        a ``trimesh.Scene``, or a single mesh object.

        :param loaded: Object returned by ``trimesh.load``.
        :param progress: Callable ``progress(fraction, message)``.
        :raises ValueError: If ``loaded`` is a Scene without any geometry.
        """
        if isinstance(loaded, list):
            progress(0.1, "Handling part list...")
            for index, mesh_part in enumerate(loaded):
                self._add_part_to_parts(f"part_{index}", mesh_part)
        elif isinstance(loaded, trimesh.Scene):
            progress(0.1, "Handling Scenes...")
            geometry = loaded.geometry
            if len(geometry) == 0:
                raise ValueError("Empty scene, no mesh data found.")
            # NOTE(review): geometries are read as stored; scene-graph node
            # transforms do not appear to be applied here -- confirm intended.
            for key, mesh_part in geometry.items():
                self._add_part_to_parts(key, mesh_part)
        else:
            progress(0.1, "Handling single part...")
            self._add_part_to_parts("part_0", loaded)

    @property
    def device(self):
        """Device the mesh tensors were last moved to."""
        return self._device

    def to(self, device):
        """
        Move the mesh data to the specified device.

        :param device: The target device (e.g., 'cuda' or 'cpu').
        :return: ``self``, so calls can be chained.
        """
        self._device = device
        self._v_pos = self._v_pos.to(device)
        self._t_pos_idx = self._t_pos_idx.to(device)
        if self._v_tex is not None:
            self._v_tex = self._v_tex.to(device)
            self._t_tex_idx = self._t_tex_idx.to(device)
        # _vmapping may not exist yet when to() is called from __init__.
        if getattr(self, '_vmapping', None) is not None:
            self._vmapping = self._vmapping.to(device)
        self._v_normal = self._v_normal.to(device)
        return self

    @property
    def has_multi_parts(self):
        """
        Check if the mesh has multiple parts.

        ``_parts`` is set to None once the parts are merged, so this is False
        for every merged mesh.

        :return: Boolean indicating whether the mesh has multiple parts.
        """
        return self._parts is not None and len(self._parts) > 1

    @property
    def v_pos(self):
        """Vertex positions property."""
        return self._v_pos

    @v_pos.setter
    def v_pos(self, value):
        self._v_pos = value

    @property
    def t_pos_idx(self):
        """Triangle position indices property."""
        return self._t_pos_idx

    @t_pos_idx.setter
    def t_pos_idx(self, value):
        self._t_pos_idx = value

    @property
    def v_tex(self):
        """Vertex texture coordinates property."""
        return self._v_tex

    @v_tex.setter
    def v_tex(self, value):
        self._v_tex = value

    @property
    def t_tex_idx(self):
        """Triangle texture indices property."""
        return self._t_tex_idx

    @t_tex_idx.setter
    def t_tex_idx(self, value):
        self._t_tex_idx = value

    @property
    def v_normal(self):
        """Vertex normals property."""
        return self._v_normal

    @v_normal.setter
    def v_normal(self, value):
        self._v_normal = value

    @property
    def has_uv(self):
        """
        Check if the mesh has a valid UV mapping.

        :return: Boolean indicating whether the mesh has UV mapping.
        """
        return self.v_tex is not None

    def uv_xatlas_mapping(self, xatlas_chart_options: dict = None, xatlas_pack_options: dict = None):
        """
        Generate a UV layout for the merged mesh with xatlas.

        Stores the new UV vertices in ``v_tex``, the UV triangle indices in
        ``t_tex_idx`` and the UV-vertex -> original-vertex mapping in ``_vmapping``.

        :param xatlas_chart_options: Attribute overrides applied to ``xatlas.ChartOptions``.
        :param xatlas_pack_options: Attribute overrides applied to ``xatlas.PackOptions``.
            ``resolution`` defaults to 1024 and ``padding`` to 2 unless overridden.
        """
        chart_overrides = xatlas_chart_options or {}
        # Defaults first, caller overrides win.
        pack_settings = {'resolution': 1024, 'padding': 2, **(xatlas_pack_options or {})}

        # Merged mesh, directly add_mesh as a whole
        atlas = xatlas.Atlas()
        positions_np = self.v_pos.detach().cpu().numpy()
        triangles_np = self.t_pos_idx.cpu().numpy()
        atlas.add_mesh(positions_np, triangles_np)

        chart_opts = xatlas.ChartOptions()
        pack_opts = xatlas.PackOptions()
        for option_name, option_value in chart_overrides.items():
            setattr(chart_opts, option_name, option_value)
        for option_name, option_value in pack_settings.items():
            setattr(pack_opts, option_name, option_value)

        atlas.generate(chart_opts, pack_opts)

        # vmapping: new UV vertex -> original mesh vertex
        # indices: new triangle face indices (based on new UV vertices)
        # uvs: new UV vertex coordinates
        vmapping, indices, uvs = atlas.get_mesh(0)

        device = self.v_pos.device
        # Widen to uint64 and reinterpret as int64 so torch can consume the indices.
        vmapping = torch.from_numpy(vmapping.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()
        uvs = torch.from_numpy(uvs).to(device).float()
        indices = torch.from_numpy(indices.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()

        self.v_tex = uvs  # new UV vertices
        self.t_tex_idx = indices  # new triangle face indices (based on UV vertices)
        self._vmapping = vmapping  # save UV vertex to original vertex mapping for export

    def normalize(self):
        """
        Center the mesh and scale it so its longest bounding-box side spans [-1, 1].
        """
        vertices = self.v_pos
        bounding_box_max = vertices.max(0)[0]
        bounding_box_min = vertices.min(0)[0]
        mesh_scale = 2.0  # Scale to [-1, 1]
        # The epsilon guards against division by zero for degenerate meshes.
        scale = mesh_scale / ((bounding_box_max - bounding_box_min).max() + 1e-6)
        center_offset = (bounding_box_max + bounding_box_min) * 0.5
        self.v_pos = (vertices - center_offset) * scale

    @staticmethod
    def _signed_axis_swap(data, negated_axis, plain_axis):
        """Return a copy of ``data`` with two columns swapped, one of them negated.

        ``out[:, negated_axis] = -data[:, plain_axis]`` and
        ``out[:, plain_axis] = data[:, negated_axis]``; other columns are unchanged.
        """
        swapped = torch.clone(data)
        swapped[:, negated_axis] = -data[:, plain_axis]
        swapped[:, plain_axis] = data[:, negated_axis]
        return swapped

    def _apply_signed_axis_swap(self, negated_axis, plain_axis):
        """Apply the same signed axis swap to both normals and vertex positions."""
        normals = self._signed_axis_swap(self.v_normal, negated_axis, plain_axis)
        vertices = self._signed_axis_swap(self.v_pos, negated_axis, plain_axis)
        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform(self):
        """
        Apply the y/z coordinate transformation to vertices and normals:
        new y = -z, new z = y.
        """
        self._apply_signed_axis_swap(negated_axis=1, plain_axis=2)

    def vertex_transform_y2x(self):
        """
        Apply the x/y coordinate transformation to vertices and normals:
        new y = -x, new x = y.
        """
        self._apply_signed_axis_swap(negated_axis=1, plain_axis=0)

    def vertex_transform_z2x(self):
        """
        Apply the x/z coordinate transformation to vertices and normals:
        new z = -x, new x = z.
        """
        self._apply_signed_axis_swap(negated_axis=2, plain_axis=0)

    def vertex_transform_upsidedown(self):
        """
        Apply upside-down transformation (negate z) to mesh vertices and normals.

        Also records that the flip happened so that export() can reverse the
        triangle winding.
        """
        normals = torch.clone(self.v_normal)
        normals[:, 2] = -normals[:, 2]

        vertices = torch.clone(self.v_pos)
        vertices[:, 2] = -vertices[:, 2]

        self.v_normal = normals
        self.v_pos = vertices

        # Mark that the upside-down transformation has been applied.
        self._upside_down_applied = True

    def _add_part_to_parts(self, key, mesh_part):
        """
        Add a single mesh part to the ``_parts`` dictionary.

        Parts without ``vertices``/``faces`` attributes (e.g. point clouds) and
        empty parts are silently skipped.

        :param key: Key under which the part is stored.
        :param mesh_part: trimesh object holding the part.
        """
        is_usable = (
            hasattr(mesh_part, 'vertices')
            and hasattr(mesh_part, 'faces')
            and len(mesh_part.vertices) > 0
            and len(mesh_part.faces) > 0
        )
        if not is_usable:
            return

        raw_uv = getattr(mesh_part.visual, 'uv', None)
        processed_v_tex = None
        processed_t_tex_idx = None

        # Only process UV data when it exists and is non-empty.
        if raw_uv is not None and np.asarray(raw_uv).size > 0:
            processed_v_tex = torch.tensor(raw_uv, dtype=torch.float32)
            # The source provides per-vertex UVs here, so the UV triangles are
            # assumed to use the same indices as the position triangles.
            processed_t_tex_idx = torch.tensor(mesh_part.faces, dtype=torch.int32)

        self._parts[key] = {
            'v_pos': torch.tensor(mesh_part.vertices, dtype=torch.float32),
            't_pos_idx': torch.tensor(mesh_part.faces, dtype=torch.int32),
            'v_tex': processed_v_tex,
            't_tex_idx': processed_t_tex_idx,
            'v_normal': torch.tensor(mesh_part.vertex_normals, dtype=torch.float32)
        }

    def _merge_parts_internal(self):
        """
        Merge every entry of ``_parts`` into a single mesh representation.

        Called automatically during initialization. A single part keeps its
        UVs; when several parts are merged the UV data is dropped because the
        merged mesh must be reparameterized. ``_parts`` is set to None afterwards.

        :raises ValueError: If no part was registered.
        """
        if not self._parts:
            raise ValueError("No mesh parts.")

        # Whatever the path, no xatlas mapping exists at this point.
        self._vmapping = None

        if len(self._parts) == 1:
            (part,) = self._parts.values()
            self._v_pos = part['v_pos']
            self._t_pos_idx = part['t_pos_idx']
            self._v_tex = part['v_tex']
            self._t_tex_idx = part['t_tex_idx']
            self._v_normal = part['v_normal']
            self._parts = None  # Clear the staging dictionary to free memory
            return

        vertices = []
        faces = []
        normals = []

        # Running vertex count; each part's face indices are shifted by the
        # number of vertices that precede it in the merged buffer.
        vertex_offset = 0
        for part in self._parts.values():
            vertices.append(part['v_pos'])
            faces.append(part['t_pos_idx'] + vertex_offset)
            normals.append(part['v_normal'])
            vertex_offset += part['v_pos'].shape[0]

        self._parts = None  # Clear the staging dictionary to free memory

        self._v_pos = torch.cat(vertices, dim=0)
        self._t_pos_idx = torch.cat(faces, dim=0)
        self._v_normal = torch.cat(normals, dim=0)
        self._v_tex = None  # multi-parts mesh must be reparameterized
        self._t_tex_idx = None  # multi-parts mesh must be reparameterized

    @classmethod
    def export(cls, mesh, save_path=None, texture_map: Image.Image = None):
        """
        Exports the mesh to a GLB file.

        :param mesh: Mesh instance to export (must be merged and have UVs).
        :param save_path: Optional path to save the GLB file. If None, a temporary file will be created.
        :param texture_map: Optional PIL.Image (or numpy array convertible with
            ``Image.fromarray``) to use as the texture. If None, a plain grey texture is used.
        :return: Path to the exported GLB file.
        """
        # The mesh passed in is expected to be processed already: one part, with UVs.
        assert not mesh.has_multi_parts, "Mesh should be processed and merged to single part"
        assert mesh.has_uv, "Mesh should have UV mapping after processing"

        if save_path is None:
            # delete=False: only the name is needed, trimesh writes the file later.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as temp_file:
                save_path = temp_file.name

        # Build the material
        if texture_map is not None:
            if isinstance(texture_map, np.ndarray):
                texture_map = Image.fromarray(texture_map)
            # isinstance (not an exact type match) so that PIL subclasses such
            # as images returned by Image.open are accepted too.
            assert isinstance(texture_map, Image.Image), "texture_map should be a PIL.Image"
            texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
            material = trimesh.visual.material.PBRMaterial(
                baseColorTexture=texture_map,
                baseColorFactor=[255, 255, 255, 255],  # white, to avoid tinting the texture
                metallicFactor=0.0,
                roughnessFactor=1.0
            )
        else:
            default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
            material = trimesh.visual.texture.SimpleMaterial(image=default_texture)

        uvs = mesh.v_tex.cpu().numpy()
        if getattr(mesh, '_vmapping', None) is not None:
            # xatlas layout: duplicate vertices so each UV vertex has its own position.
            vertices = mesh.v_pos[mesh._vmapping].cpu().numpy()
            faces = mesh.t_tex_idx.cpu().numpy()
        else:
            # Original UV mapping, directly use original vertices and faces
            vertices = mesh.v_pos.cpu().numpy()
            faces = mesh.t_pos_idx.cpu().numpy()

        # The upside-down flip mirrors the mesh, so reverse the winding:
        # (0, 1, 2) -> (0, 2, 1).
        if getattr(mesh, '_upside_down_applied', False):
            faces = faces[:, [0, 2, 1]]

        # Inverse of vertex_transform (which maps y = -z, z = y), converting
        # the vertices back for export: y = z, z = -y.
        vertices_export = vertices.copy()
        vertices_export[:, 1] = vertices[:, 2]
        vertices_export[:, 2] = -vertices[:, 1]

        # Create Trimesh object and set texture
        mesh_export = trimesh.Trimesh(vertices=vertices_export, faces=faces, process=False)
        mesh_export.visual = trimesh.visual.TextureVisuals(uv=uvs, material=material)

        # Export GLB file
        mesh_export.export(file_obj=save_path, file_type='glb')
        return save_path

    @classmethod
    @spaces.GPU(duration=30)
    def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
        """
        Handle the mesh processing, which includes normalization, parts merging, and UV mapping.
        Then render geometry views and geometry maps of the untextured mesh.

        :param mesh_file: uploaded mesh file.
        :param uv_tool: the UV parameterization tool, default is "xAtlas".
        :param y2z: apply ``vertex_transform`` (y = -z, z = y).
        :param y2x: apply ``vertex_transform_y2x``.
        :param z2x: apply ``vertex_transform_z2x``.
        :param upside_down: apply ``vertex_transform_upsidedown``.
        :param img_size: size passed to ``render_geo_views_tensor``.
        :param uv_size: size passed to ``get_pure_texture``.
        :param device: torch device used for processing.
        :param progress: callable ``progress(fraction, message)``.
        :return: 9-tuple ``(position_map_path, normal_map_path,
            position_images_path, normal_images_path, mask_images_path,
            w2c_path, mesh, mvp_matrix_path, status_message)``; the ``*_path``
            entries are whatever ``save_tensor_to_file`` returns and ``mesh``
            is the processed Mesh moved to the CPU.
        """
        # load mesh (automatically merge multiple parts)
        mesh: Mesh = cls(mesh_file, uv_tool, device, progress=progress)

        progress(0.7, "Handling transformation and normalization...")
        if y2z:
            mesh.vertex_transform()  # transform vertices and normals
        if y2x:
            mesh.vertex_transform_y2x()
        if z2x:
            mesh.vertex_transform_z2x()
        if upside_down:
            mesh.vertex_transform_upsidedown()
        mesh.normalize()

        # render preparation
        # NOTE(review): `texture` is only needed by the disabled render_views
        # path; the call is kept in case get_pure_texture has side effects.
        texture = get_pure_texture(uv_size).to(device)  # tensor of shape (3, height, width)
        mvp_matrix, w2c = get_mvp_matrix(mesh)
        mvp_matrix = mvp_matrix.to(device)
        w2c = w2c.to(device)

        progress(0.8, "Rendering clay model views...")
        print("Rendering geometry views...")
        position_images, normal_images, mask_images = render_geo_views_tensor(mesh, mvp_matrix, img_size)  # torch.Tensor # [batch_size, height, width, 3]

        progress(0.9, "Rendering geometry maps...")
        print("Rendering geometry maps...")
        position_map, normal_map = render_geo_map(mesh)

        progress(1, "Mesh processing completed.")
        position_map_path = save_tensor_to_file(position_map, prefix="position_map")
        normal_map_path = save_tensor_to_file(normal_map, prefix="normal_map")
        position_images_path = save_tensor_to_file(position_images, prefix="position_images")
        normal_images_path = save_tensor_to_file(normal_images, prefix="normal_images")
        mask_images_path = save_tensor_to_file(mask_images.squeeze(-1), prefix="mask_images")
        w2c_path = save_tensor_to_file(w2c, prefix="w2c")
        mvp_matrix_path = save_tensor_to_file(mvp_matrix, prefix="mvp_matrix")

        # Return mesh instance as is
        return position_map_path, normal_map_path, position_images_path, normal_images_path, mask_images_path, w2c_path, mesh.to("cpu"), mvp_matrix_path, "Mesh processing completed."


if __name__ == '__main__':
    glb_path = "/mnt/pfs/users/yuanze/projects/clean_seqtex/gradio/examples/multi_parts.glb"
    # process() returns a 9-tuple of saved-tensor results, the mesh and a
    # status message (not PIL images), so unpack all of it and report.
    (position_map_path, normal_map_path, position_images_path,
     normal_images_path, mask_images_path, w2c_path, processed_mesh,
     mvp_matrix_path, status) = Mesh.process(glb_path)
    print(status)
    print(f"position map: {position_map_path}")
    print(f"normal map: {normal_map_path}")
    print(f"position images: {position_images_path}")
    print(f"normal images: {normal_images_path}")
    print(f"mask images: {mask_images_path}")
    print(f"w2c: {w2c_path}")
    print(f"mvp matrix: {mvp_matrix_path}")