# Hugging Face Space — Running on Zero (ZeroGPU hardware).
import sys
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
import xatlas
from PIL import Image

from utils.file_utils import save_tensor_to_file

from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
                           render_geo_views_tensor, render_views, setup_lights)
class Mesh:
    """
    Triangle mesh backed by torch tensors, loaded from an .obj or .glb file.

    On construction every part of the file is merged into a single mesh and,
    unless the file already provides one per-vertex UV layout, a new UV layout
    is generated with xatlas. Geometry is exposed through the ``v_pos``,
    ``t_pos_idx``, ``v_tex``, ``t_tex_idx`` and ``v_normal`` properties.
    """

    def __init__(self, mesh_path=None, uv_tool="xAtlas", device='cuda', progress=gr.Progress()):
        """
        Load a mesh file, merge its parts and make sure it has a UV layout.

        :param mesh_path: Path to the mesh file (.obj or .glb; the extension check is case-insensitive).
        :param uv_tool: UV parameterization tool used when (re)parameterization is needed.
            Only "xAtlas" is implemented.
        :param device: Torch device the mesh tensors are moved to.
        :param progress: Progress callback (a ``gr.Progress``) used to report loading stages.
        :raises ValueError: if ``mesh_path`` is None, the format is unsupported, a scene has
            no geometry, no part has faces, or ``uv_tool`` is unknown.
        :raises NotImplementedError: if ``uv_tool`` is "UVAtlas" and reparameterization is needed.
        """
        self._device = device
        if mesh_path is None:
            raise ValueError("Mesh path cannot be None.")

        # Loaded parts are collected here; _merge_parts_internal() merges them and clears the dict.
        self._parts = {}
        lowered_path = str(mesh_path).lower()
        if lowered_path.endswith('.obj'):
            progress(0., "Loading mesh in .obj format...")
            mesh_data = trimesh.load(mesh_path, process=False)
            if isinstance(mesh_data, list):
                # Multi-part obj returned as a plain list of geometries
                progress(0.1, "Handling part list...")
                for i, mesh_part in enumerate(mesh_data):
                    self._add_part_to_parts(f"part_{i}", mesh_part)
            elif isinstance(mesh_data, trimesh.Scene):
                progress(0.1, "Handling Scenes...")
                self._add_scene_parts(mesh_data)
            else:
                progress(0.1, "Handling single part...")
                self._add_part_to_parts("part_0", mesh_data)
        elif lowered_path.endswith('.glb'):
            progress(0., "Loading mesh in .glb format...")
            mesh_loaded = trimesh.load(mesh_path)
            if isinstance(mesh_loaded, trimesh.Scene):
                progress(0.1, "Handling Scenes...")
                self._add_scene_parts(mesh_loaded)
            else:
                progress(0.1, "Handling single part...")
                self._add_part_to_parts("part_0", mesh_loaded)
        else:
            raise ValueError(f"Unsupported file format: {mesh_path}")

        progress(0.2, "Merging if the mesh has multiple parts.")
        self._merge_parts_internal()
        self.to(self.device)
        # Set by vertex_transform_upsidedown(); export() reverses triangle winding when True.
        self._upside_down_applied = False

        # A merged multi-part mesh has its UVs dropped, so it always ends up here.
        if self.has_multi_parts or not self.has_uv:
            progress(0.4, f"Using {uv_tool} for UV parameterization. It may take quite a while (several minutes), if there are many faces. We STRONGLY recommend using a mesh with UV parameterization.")
            if uv_tool == "xAtlas":
                self.uv_xatlas_mapping()  # default parameters
            elif uv_tool == "UVAtlas":
                raise NotImplementedError("UVAtlas parameterization is not implemented yet.")
            else:
                raise ValueError("Unsupported UV parameterization tool.")
            print("UV parameterization completed.")
        else:
            progress(0.4, "The model has SINGLE UV parameterization, no need to reparameterize.")
            self._vmapping = None  # the original vertices already carry the UVs

    @property
    def device(self):
        """Torch device the mesh tensors currently live on."""
        return self._device

    def to(self, device):
        """
        Move all mesh tensors to ``device`` (in place).

        :param device: Target device, e.g. 'cuda' or 'cpu'.
        :return: ``self``, to allow chaining.
        """
        self._device = device
        self._v_pos = self._v_pos.to(device)
        self._t_pos_idx = self._t_pos_idx.to(device)
        if self._v_tex is not None:
            self._v_tex = self._v_tex.to(device)
            self._t_tex_idx = self._t_tex_idx.to(device)
        if getattr(self, '_vmapping', None) is not None:
            self._vmapping = self._vmapping.to(device)
        self._v_normal = self._v_normal.to(device)
        return self

    @property
    def has_multi_parts(self):
        """
        Whether more than one unmerged part is stored.

        Always False once the parts have been merged (``_parts`` is None),
        which ``__init__`` does before returning.
        """
        if self._parts is None:
            return False
        return len(self._parts) > 1

    @property
    def v_pos(self):
        """Vertex positions."""
        return self._v_pos

    @v_pos.setter
    def v_pos(self, value):
        self._v_pos = value

    @property
    def t_pos_idx(self):
        """Triangle indices into ``v_pos``."""
        return self._t_pos_idx

    @t_pos_idx.setter
    def t_pos_idx(self, value):
        self._t_pos_idx = value

    @property
    def v_tex(self):
        """UV coordinates, or None when the mesh has no UV layout."""
        return self._v_tex

    @v_tex.setter
    def v_tex(self, value):
        self._v_tex = value

    @property
    def t_tex_idx(self):
        """Triangle indices into ``v_tex``, or None when the mesh has no UV layout."""
        return self._t_tex_idx

    @t_tex_idx.setter
    def t_tex_idx(self, value):
        self._t_tex_idx = value

    @property
    def v_normal(self):
        """Per-vertex normals."""
        return self._v_normal

    @v_normal.setter
    def v_normal(self, value):
        self._v_normal = value

    @property
    def has_uv(self):
        """Whether the mesh currently has UV coordinates."""
        return self.v_tex is not None

    def uv_xatlas_mapping(self, xatlas_chart_options: "dict | None" = None, xatlas_pack_options: "dict | None" = None):
        """
        Generate a new UV layout for the merged mesh with xatlas.

        xatlas splits vertices along UV seams, so this stores new UV vertices
        (``v_tex``), UV triangles (``t_tex_idx``) and ``_vmapping``, which maps
        each UV vertex back to its original vertex in ``v_pos`` (used by export()).

        :param xatlas_chart_options: Attribute overrides applied to ``xatlas.ChartOptions``.
        :param xatlas_pack_options: Attribute overrides applied to ``xatlas.PackOptions``.
            Unless overridden, ``resolution`` defaults to 1024 and ``padding`` to 2.
        """
        chart_options = xatlas_chart_options or {}
        pack_options = xatlas_pack_options or {}

        atlas = xatlas.Atlas()
        atlas.add_mesh(self.v_pos.detach().cpu().numpy(), self.t_pos_idx.cpu().numpy())

        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        # Pack defaults intended to avoid chart overlap; explicit options override them below.
        if 'resolution' not in pack_options:
            po.resolution = 1024
        if 'padding' not in pack_options:
            po.padding = 2
        for name, value in chart_options.items():
            setattr(co, name, value)
        for name, value in pack_options.items():
            setattr(po, name, value)
        atlas.generate(co, po)

        # vmapping: UV vertex -> original vertex; indices: UV triangles; uvs: UV coordinates
        vmapping, indices, uvs = atlas.get_mesh(0)
        device = self.v_pos.device
        self.v_tex = torch.from_numpy(uvs).to(device).float()
        self.t_tex_idx = torch.from_numpy(indices.astype(np.int64)).to(device)
        self._vmapping = torch.from_numpy(vmapping.astype(np.int64)).to(device)

    def normalize(self):
        """
        Center the mesh on its bounding-box center and scale it uniformly so the
        longest bounding-box side spans 2, i.e. the mesh fits in [-1, 1].
        """
        vertices = self.v_pos
        bounding_box_max = vertices.max(0)[0]
        bounding_box_min = vertices.min(0)[0]
        mesh_scale = 2.0  # target extent: [-1, 1]
        # 1e-6 guards against division by zero for degenerate (flat/point) meshes
        scale = mesh_scale / ((bounding_box_max - bounding_box_min).max() + 1e-6)
        center_offset = (bounding_box_max + bounding_box_min) * 0.5
        self.v_pos = (vertices - center_offset) * scale

    def vertex_transform(self):
        """
        Rotate vertices and normals so that new y = -old z and new z = old y.

        export() applies the inverse of this mapping to every exported mesh.
        """
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 1] = -pre_normals[:, 2]  # y <- -z
        normals[:, 2] = pre_normals[:, 1]  # z <- y

        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 1] = -pre_vertices[:, 2]  # y <- -z
        vertices[:, 2] = pre_vertices[:, 1]  # z <- y

        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_y2x(self):
        """Rotate vertices and normals so that new y = -old x and new x = old y."""
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 1] = -pre_normals[:, 0]  # y <- -x
        normals[:, 0] = pre_normals[:, 1]  # x <- y

        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 1] = -pre_vertices[:, 0]  # y <- -x
        vertices[:, 0] = pre_vertices[:, 1]  # x <- y

        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_z2x(self):
        """Rotate vertices and normals so that new z = -old x and new x = old z."""
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 2] = -pre_normals[:, 0]  # z <- -x
        normals[:, 0] = pre_normals[:, 2]  # x <- z

        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 2] = -pre_vertices[:, 0]  # z <- -x
        vertices[:, 0] = pre_vertices[:, 2]  # x <- z

        self.v_normal = normals
        self.v_pos = vertices

    def vertex_transform_upsidedown(self):
        """
        Mirror vertices and normals along z (z -> -z).

        A mirror reverses triangle winding; ``_upside_down_applied`` is set so
        export() can swap face winding back.
        """
        pre_normals = self.v_normal
        normals = torch.clone(pre_normals)
        normals[:, 2] = -pre_normals[:, 2]

        pre_vertices = self.v_pos
        vertices = torch.clone(pre_vertices)
        vertices[:, 2] = -pre_vertices[:, 2]

        self.v_normal = normals
        self.v_pos = vertices
        self._upside_down_applied = True

    def _add_scene_parts(self, scene):
        """
        Add every mesh instance of a ``trimesh.Scene`` to ``_parts``, in scene coordinates.

        ``scene.geometry`` stores each geometry in its own local frame; where an
        instance is placed lives in ``scene.graph``. Applying each node's transform
        keeps multi-part assets assembled as authored (and keeps instanced geometry).

        :param scene: Loaded ``trimesh.Scene``.
        :raises ValueError: if the scene contains no geometry.
        """
        if len(scene.geometry) == 0:
            raise ValueError("Empty scene, no mesh data found.")
        node_names = scene.graph.nodes_geometry
        if len(node_names) == 0:
            # Geometry not referenced by any node: fall back to the local frames.
            for key, mesh_part in scene.geometry.items():
                self._add_part_to_parts(key, mesh_part)
            return
        for node_name in node_names:
            transform, geometry_name = scene.graph[node_name]
            mesh_part = scene.geometry[geometry_name]
            if not np.allclose(transform, np.eye(4)):
                # Copy so shared (instanced) geometry is not transformed in place.
                mesh_part = mesh_part.copy()
                mesh_part.apply_transform(transform)
            self._add_part_to_parts(node_name, mesh_part)

    def _add_part_to_parts(self, key, mesh_part):
        """
        Convert one trimesh geometry to tensors and store it in ``_parts[key]``.

        Geometries without vertices or faces (point clouds, paths, empty meshes)
        are skipped. UVs are kept only when there is exactly one UV row per
        vertex; the UV triangles then reuse the position triangles.

        :param key: Name of the part.
        :param mesh_part: trimesh geometry object.
        """
        vertices = getattr(mesh_part, 'vertices', None)
        faces = getattr(mesh_part, 'faces', None)
        if vertices is None or faces is None or len(vertices) == 0 or len(faces) == 0:
            return

        processed_v_tex = None
        processed_t_tex_idx = None
        raw_uv = getattr(getattr(mesh_part, 'visual', None), 'uv', None)
        if raw_uv is not None:
            uv = np.asarray(raw_uv)
            # Per-vertex UVs are required so that `faces` can index them as well;
            # anything else is treated as "no UV" and triggers reparameterization.
            if uv.ndim == 2 and uv.shape[0] == len(vertices):
                processed_v_tex = torch.tensor(uv, dtype=torch.float32)
                processed_t_tex_idx = torch.tensor(faces, dtype=torch.int32)

        self._parts[key] = {
            'v_pos': torch.tensor(vertices, dtype=torch.float32),
            't_pos_idx': torch.tensor(faces, dtype=torch.int32),
            'v_tex': processed_v_tex,
            't_tex_idx': processed_t_tex_idx,
            'v_normal': torch.tensor(mesh_part.vertex_normals, dtype=torch.float32),
        }

    def _merge_parts_internal(self):
        """
        Merge all entries of ``_parts`` into a single mesh and clear ``_parts``.

        A single part keeps its UVs. Several parts are concatenated (face indices
        offset by the preceding vertex counts) and lose their UVs, so they must be
        reparameterized.

        :raises ValueError: if no part was loaded.
        """
        if not self._parts:
            raise ValueError("No mesh parts.")
        parts = list(self._parts.values())
        self._parts = None  # release the per-part tensors
        self._vmapping = None  # no UV-vertex remapping until reparameterization

        if len(parts) == 1:
            part = parts[0]
            self._v_pos = part['v_pos']
            self._t_pos_idx = part['t_pos_idx']
            self._v_tex = part['v_tex']
            self._t_tex_idx = part['t_tex_idx']
            self._v_normal = part['v_normal']
            return

        vertices = []
        faces = []
        normals = []
        v_count = 0  # vertices emitted so far; offsets the next part's face indices
        for part in parts:
            vertices.append(part['v_pos'])
            faces.append(part['t_pos_idx'] + v_count)
            normals.append(part['v_normal'])
            v_count += part['v_pos'].shape[0]

        self._v_pos = torch.cat(vertices, dim=0)
        self._t_pos_idx = torch.cat(faces, dim=0)
        self._v_normal = torch.cat(normals, dim=0)
        self._v_tex = None  # multi-part meshes must be reparameterized
        self._t_tex_idx = None

    @classmethod
    def export(cls, mesh, save_path=None, texture_map: Image.Image = None):
        """
        Export a processed mesh to a GLB file.

        Vertices are converted back from the rendering frame by undoing
        vertex_transform() (y = z, z = -y). NOTE: this inverse is applied
        unconditionally, whether or not vertex_transform() was called.
        Vertex normals are not passed to the exported mesh.

        :param mesh: Mesh instance (merged, with a UV layout) to export.
        :param save_path: Output path. If None, a temporary .glb file is created.
        :param texture_map: Texture as a PIL image or numpy array. If None, a flat grey
            1024x1024 texture is used.
        :return: Path to the exported GLB file.
        """
        # A processed mesh is always merged and has UVs.
        assert not mesh.has_multi_parts, "Mesh should be processed and merged to single part"
        assert mesh.has_uv, "Mesh should have UV mapping after processing"

        if save_path is None:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as temp_file:
                save_path = temp_file.name

        if texture_map is not None:
            if isinstance(texture_map, np.ndarray):
                texture_map = Image.fromarray(texture_map)
            # isinstance (not an exact type check) so Image.open() results are accepted too
            assert isinstance(texture_map, Image.Image), "texture_map should be a PIL.Image"
            # NOTE(review): rows are flipped before embedding, presumably because the texture is
            # produced with the opposite vertical UV convention from the GLB; confirm with the producer.
            texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
            material = trimesh.visual.material.PBRMaterial(
                baseColorTexture=texture_map,
                baseColorFactor=[255, 255, 255, 255],  # white, so the factor does not tint the texture
                metallicFactor=0.0,
                roughnessFactor=1.0
            )
        else:
            default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
            material = trimesh.visual.texture.SimpleMaterial(image=default_texture)

        if getattr(mesh, '_vmapping', None) is not None:
            # xatlas layout: duplicate vertices along UV seams so each vertex has one UV.
            vertices = mesh.v_pos[mesh._vmapping].detach().cpu().numpy()
            faces = mesh.t_tex_idx.cpu().numpy()
        else:
            # Original per-vertex UVs: position triangles index the UVs directly.
            vertices = mesh.v_pos.detach().cpu().numpy()
            faces = mesh.t_pos_idx.cpu().numpy()
        uvs = mesh.v_tex.cpu().numpy()

        # The upside-down mirror reversed the winding; (0, 1, 2) -> (0, 2, 1) restores it.
        if getattr(mesh, '_upside_down_applied', False):
            faces = faces[:, [0, 2, 1]]

        # Inverse of vertex_transform (y = -z, z = y): y = z, z = -y.
        vertices_export = vertices.copy()
        vertices_export[:, 1] = vertices[:, 2]
        vertices_export[:, 2] = -vertices[:, 1]

        mesh_export = trimesh.Trimesh(vertices=vertices_export, faces=faces, process=False)
        mesh_export.visual = trimesh.visual.TextureVisuals(uv=uvs, material=material)
        mesh_export.export(file_obj=save_path, file_type='glb')
        return save_path

    # NOTE(review): `spaces` is imported by this module but not applied in this class; if this
    # Space runs on ZeroGPU, confirm where `@spaces.GPU` wraps the GPU-using entry point.
    @classmethod
    def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
        """
        Load, transform and normalize a mesh, then render its geometry views and maps.

        Loading merges multiple parts and creates a UV layout when needed.

        :param mesh_file: Path of the uploaded mesh file.
        :param uv_tool: UV parameterization tool, default "xAtlas".
        :param y2z: Apply vertex_transform().
        :param y2x: Apply vertex_transform_y2x().
        :param z2x: Apply vertex_transform_z2x().
        :param upside_down: Apply vertex_transform_upsidedown().
        :param img_size: Size of the rendered geometry views.
        :param uv_size: Currently unused (its only consumer was removed dead code); kept for compatibility.
        :param device: Torch device used for loading and rendering.
        :param progress: Progress callback (a ``gr.Progress``).
        :return: Tuple ``(position_map_path, normal_map_path, position_images_path,
            normal_images_path, mask_images_path, w2c_path, mesh, mvp_matrix_path, message)``
            where ``mesh`` is the processed Mesh moved to CPU.
        """
        mesh = cls(mesh_file, uv_tool, device, progress=progress)

        progress(0.7, "Handling transformation and normalization...")
        if y2z:
            mesh.vertex_transform()
        if y2x:
            mesh.vertex_transform_y2x()
        if z2x:
            mesh.vertex_transform_z2x()
        if upside_down:
            mesh.vertex_transform_upsidedown()
        mesh.normalize()

        mvp_matrix, w2c = get_mvp_matrix(mesh)
        mvp_matrix = mvp_matrix.to(device)
        w2c = w2c.to(device)

        progress(0.8, "Rendering clay model views...")
        print("Rendering geometry views...")
        position_images, normal_images, mask_images = render_geo_views_tensor(mesh, mvp_matrix, img_size)

        progress(0.9, "Rendering geometry maps...")
        print("Rendering geometry maps...")
        position_map, normal_map = render_geo_map(mesh)

        position_map_path = save_tensor_to_file(position_map, prefix="position_map")
        normal_map_path = save_tensor_to_file(normal_map, prefix="normal_map")
        position_images_path = save_tensor_to_file(position_images, prefix="position_images")
        normal_images_path = save_tensor_to_file(normal_images, prefix="normal_images")
        mask_images_path = save_tensor_to_file(mask_images.squeeze(-1), prefix="mask_images")
        w2c_path = save_tensor_to_file(w2c, prefix="w2c")
        mvp_matrix_path = save_tensor_to_file(mvp_matrix, prefix="mvp_matrix")
        # Report completion only after all outputs have been written.
        progress(1, "Mesh processing completed.")

        # mesh.to("cpu") moves the mesh in place and returns it
        return position_map_path, normal_map_path, position_images_path, normal_images_path, mask_images_path, w2c_path, mesh.to("cpu"), mvp_matrix_path, "Mesh processing completed."
if __name__ == '__main__':
    # Manual smoke test: python -m <package>.<this module> [mesh_path]
    # (this module uses a relative import, so it must be run as part of its package).
    glb_path = sys.argv[1] if len(sys.argv) > 1 else "/mnt/pfs/users/yuanze/projects/clean_seqtex/gradio/examples/multi_parts.glb"
    # process() returns nine values; the first six and the eighth come from save_tensor_to_file,
    # presumably file paths (TODO confirm), so they are printed rather than saved as images.
    (position_map_path, normal_map_path, position_images_path, normal_images_path,
     mask_images_path, w2c_path, processed_mesh, mvp_matrix_path, status) = Mesh.process(glb_path)
    print(status)
    for label, output in (
        ("position_map", position_map_path),
        ("normal_map", normal_map_path),
        ("position_images", position_images_path),
        ("normal_images", normal_images_path),
        ("mask_images", mask_images_path),
        ("w2c", w2c_path),
        ("mvp_matrix", mvp_matrix_path),
    ):
        print(f"{label}: {output}")