# SeqTex / utils/mesh_utils.py
# Last commit: 6d4bcdf ("init space 2") by yuanze1024
import tempfile
import gradio as gr
import numpy as np
import spaces
import torch
import trimesh
import xatlas
from PIL import Image
from .render_utils import (get_mvp_matrix, get_pure_texture, render_geo_map,
render_geo_views_tensor, render_views, setup_lights)
from utils.file_utils import save_tensor_to_file
class Mesh:
    """
    Triangle mesh loaded from an .obj or .glb file.

    On construction, all parts of the file are merged into a single mesh. The
    mesh is then UV-parameterized with xatlas, unless it already carries a
    single UV layout. Geometry is stored as torch tensors on ``device``:

    - ``v_pos`` / ``t_pos_idx``: vertex positions and position face indices.
    - ``v_tex`` / ``t_tex_idx``: UV coordinates and UV face indices.
    - ``v_normal``: per-vertex normals.
    - ``_vmapping``: set after xatlas re-parameterization. Maps each UV vertex
      to the position vertex it came from; ``None`` when the file's own UVs
      are kept.
    """

    def __init__(self, mesh_path=None, uv_tool="xAtlas", device='cuda', progress=gr.Progress()):
        """
        Load a mesh file, merge its parts, and ensure it has a UV layout.

        :param mesh_path: Path to the mesh file (.obj or .glb; the extension check is case-insensitive).
        :param uv_tool: UV parameterization tool, used when the mesh has no single UV
            layout. Only "xAtlas" is implemented.
        :param device: Torch device that the mesh tensors are moved to.
        :param progress: Gradio progress tracker used to report loading stages.
        :raises ValueError: If ``mesh_path`` is None, the format is unsupported, the
            file holds no usable mesh data, or ``uv_tool`` is unknown.
        :raises NotImplementedError: If ``uv_tool`` is "UVAtlas" and
            re-parameterization is required.
        """
        self._device = device
        if mesh_path is None:
            raise ValueError("Mesh path cannot be None.")

        # Collect every usable part of the file; they are merged right below.
        self._parts = {}
        lowered_path = str(mesh_path).lower()
        if lowered_path.endswith('.obj'):
            progress(0., "Loading mesh in .obj format...")
            mesh_data = trimesh.load(mesh_path, process=False)
        elif lowered_path.endswith('.glb'):
            progress(0., "Loading mesh in .glb format...")
            mesh_data = trimesh.load(mesh_path)
        else:
            raise ValueError(f"Unsupported file format: {mesh_path}")
        self._collect_parts(mesh_data, progress)

        # Automatically merge all parts during initialization.
        progress(0.2, "Merging if the mesh have multiple parts.")
        self._merge_parts_internal()
        self.to(self.device)

        # Set by vertex_transform_upsidedown(); export() reads it to fix face winding.
        self._upside_down_applied = False

        # Merged multi-part meshes and meshes without UVs need a fresh UV layout.
        if self.has_multi_parts or not self.has_uv:
            progress(0.4, f"Using {uv_tool} for UV parameterization. It may take quite a while (several minutes), if there are many faces. We STRONGLY recommend using a mesh with UV parameterization.")
            if uv_tool == "xAtlas":
                self.uv_xatlas_mapping()  # Use default parameters
            elif uv_tool == "UVAtlas":
                raise NotImplementedError("UVAtlas parameterization is not implemented yet.")
            else:
                raise ValueError("Unsupported UV parameterization tool.")
            print("UV parameterization completed.")
        else:
            progress(0.4, "The model has SINGLE UV parameterization, no need to reparameterize.")
            self._vmapping = None  # No vmapping needed when the original UVs are kept

    def _collect_parts(self, mesh_data, progress):
        """
        Register every part of a loaded trimesh object in ``self._parts``.

        :param mesh_data: Result of ``trimesh.load``: a list of meshes, a
            ``trimesh.Scene``, or a single geometry.
        :param progress: Gradio progress tracker.
        :raises ValueError: If ``mesh_data`` is a Scene without geometry.
        """
        if isinstance(mesh_data, list):
            # Multi-part result returned as a plain list (handled for .obj files).
            progress(0.1, "Handling part list...")
            for i, mesh_part in enumerate(mesh_data):
                self._add_part_to_parts(f"part_{i}", mesh_part)
        elif isinstance(mesh_data, trimesh.Scene):
            progress(0.1, "Handling Scenes...")
            geometry = mesh_data.geometry
            if len(geometry) == 0:
                raise ValueError("Empty scene, no mesh data found.")
            # NOTE(review): only Scene.geometry is read, so scene-graph node
            # transforms (mesh_data.graph) are not applied to the parts.
            # Confirm this is intended.
            for key, mesh_part in geometry.items():
                self._add_part_to_parts(key, mesh_part)
        else:
            progress(0.1, "Handling single part...")
            self._add_part_to_parts("part_0", mesh_data)

    @property
    def device(self):
        """Torch device the mesh tensors were last moved to."""
        return self._device

    def to(self, device):
        """
        Move the mesh data to the specified device.

        :param device: The target device (e.g., 'cuda' or 'cpu').
        :return: ``self``, to allow chaining.
        """
        self._device = device
        self._v_pos = self._v_pos.to(device)
        self._t_pos_idx = self._t_pos_idx.to(device)
        if self._v_tex is not None:
            self._v_tex = self._v_tex.to(device)
            self._t_tex_idx = self._t_tex_idx.to(device)
        # _vmapping only exists once a UV decision has been made (see __init__).
        if getattr(self, '_vmapping', None) is not None:
            self._vmapping = self._vmapping.to(device)
        self._v_normal = self._v_normal.to(device)
        return self

    @property
    def has_multi_parts(self):
        """
        Check if the mesh still has multiple unmerged parts.

        :return: True only if more than one part is waiting to be merged. After
            merging, ``_parts`` is None and this returns False.
        """
        return self._parts is not None and len(self._parts) > 1

    @property
    def v_pos(self):
        """Vertex positions property."""
        return self._v_pos

    @v_pos.setter
    def v_pos(self, value):
        self._v_pos = value

    @property
    def t_pos_idx(self):
        """Triangle position indices property."""
        return self._t_pos_idx

    @t_pos_idx.setter
    def t_pos_idx(self, value):
        self._t_pos_idx = value

    @property
    def v_tex(self):
        """Vertex texture coordinates property."""
        return self._v_tex

    @v_tex.setter
    def v_tex(self, value):
        self._v_tex = value

    @property
    def t_tex_idx(self):
        """Triangle texture indices property."""
        return self._t_tex_idx

    @t_tex_idx.setter
    def t_tex_idx(self, value):
        self._t_tex_idx = value

    @property
    def v_normal(self):
        """Vertex normals property."""
        return self._v_normal

    @v_normal.setter
    def v_normal(self, value):
        self._v_normal = value

    @property
    def has_uv(self):
        """
        Check if the mesh has a UV mapping.

        :return: True when ``v_tex`` is set.
        """
        return self.v_tex is not None

    def uv_xatlas_mapping(self, xatlas_chart_options: "dict | None" = None, xatlas_pack_options: "dict | None" = None):
        """
        Compute a new UV layout for the merged mesh with xatlas.

        :param xatlas_chart_options: Attribute overrides applied to ``xatlas.ChartOptions``.
        :param xatlas_pack_options: Attribute overrides applied to ``xatlas.PackOptions``.
            ``resolution`` defaults to 1024 and ``padding`` to 2 when not given.

        Sets ``v_tex`` (float UVs), ``t_tex_idx`` (int64 faces indexing ``v_tex``)
        and ``_vmapping`` (int64 index of the source vertex for each UV vertex).
        """
        # None defaults instead of shared mutable dicts.
        if xatlas_chart_options is None:
            xatlas_chart_options = {}
        if xatlas_pack_options is None:
            xatlas_pack_options = {}

        # The merged mesh is added to the atlas as a single mesh.
        atlas = xatlas.Atlas()
        atlas.add_mesh(self.v_pos.detach().cpu().numpy(), self.t_pos_idx.cpu().numpy())

        # Defaults chosen to avoid chart overlap; caller-supplied options override them.
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        if 'resolution' not in xatlas_pack_options:
            po.resolution = 1024
        if 'padding' not in xatlas_pack_options:
            po.padding = 2
        for k, v in xatlas_chart_options.items():
            setattr(co, k, v)
        for k, v in xatlas_pack_options.items():
            setattr(po, k, v)
        atlas.generate(co, po)

        # vmapping: new UV vertex -> original mesh vertex
        # indices:  new triangle faces (indexing the new UV vertices)
        # uvs:      new UV vertex coordinates
        vmapping, indices, uvs = atlas.get_mesh(0)
        device = self.v_pos.device
        self.v_tex = torch.from_numpy(uvs).to(device).float()
        self.t_tex_idx = torch.from_numpy(
            indices.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()
        # Kept so export() can rebuild vertices to match the UV layout.
        self._vmapping = torch.from_numpy(
            vmapping.astype(np.uint64, casting="same_kind").view(np.int64)).to(device).long()

    def normalize(self):
        """
        Center the mesh at the origin and scale it uniformly.

        The longest bounding-box side is scaled to span [-1, 1]. The aspect ratio
        is preserved, and 1e-6 guards against division by zero.
        """
        vertices = self.v_pos
        bounding_box_max = vertices.max(0)[0]
        bounding_box_min = vertices.min(0)[0]
        mesh_scale = 2.0  # Scale to [-1, 1]
        scale = mesh_scale / ((bounding_box_max - bounding_box_min).max() + 1e-6)
        center_offset = (bounding_box_max + bounding_box_min) * 0.5
        self.v_pos = (vertices - center_offset) * scale

    def vertex_transform(self):
        """Rotate positions and normals 90 degrees about x: (x, y, z) -> (x, -z, y)."""
        def remap(t):
            out = torch.clone(t)
            out[:, 1] = -t[:, 2]  # -z --> y
            out[:, 2] = t[:, 1]   # y --> z
            return out
        self.v_normal = remap(self.v_normal)
        self.v_pos = remap(self.v_pos)

    def vertex_transform_y2x(self):
        """Rotate positions and normals about z: (x, y, z) -> (y, -x, z)."""
        def remap(t):
            out = torch.clone(t)
            out[:, 1] = -t[:, 0]  # -x --> y
            out[:, 0] = t[:, 1]   # y --> x
            return out
        self.v_normal = remap(self.v_normal)
        self.v_pos = remap(self.v_pos)

    def vertex_transform_z2x(self):
        """Rotate positions and normals about y: (x, y, z) -> (z, y, -x)."""
        def remap(t):
            out = torch.clone(t)
            out[:, 2] = -t[:, 0]  # -x --> z
            out[:, 0] = t[:, 2]   # z --> x
            return out
        self.v_normal = remap(self.v_normal)
        self.v_pos = remap(self.v_pos)

    def vertex_transform_upsidedown(self):
        """
        Mirror positions and normals across the xy-plane: (x, y, z) -> (x, y, -z).

        A mirror reverses triangle orientation. This method therefore marks
        ``_upside_down_applied`` so that export() can restore the face winding.
        """
        def remap(t):
            out = torch.clone(t)
            out[:, 2] = -t[:, 2]
            return out
        self.v_normal = remap(self.v_normal)
        self.v_pos = remap(self.v_pos)
        self._upside_down_applied = True

    def _add_part_to_parts(self, key, mesh_part):
        """
        Add a single mesh part to the ``_parts`` dictionary.

        Parts without vertices/faces attributes (e.g. point clouds), or with no
        vertices or no faces, are skipped.

        :param key: Key under which the part is stored.
        :param mesh_part: trimesh geometry object.
        """
        has_geometry = (hasattr(mesh_part, 'vertices') and hasattr(mesh_part, 'faces')
                        and len(mesh_part.vertices) > 0 and len(mesh_part.faces) > 0)
        if not has_geometry:
            return

        processed_v_tex = None
        processed_t_tex_idx = None
        raw_uv = getattr(mesh_part.visual, 'uv', None)
        if raw_uv is not None:
            uv_array = np.asarray(raw_uv)
            # Only keep UVs that are non-empty and have one entry per vertex.
            # Otherwise the UV faces below would index out of range, so the part
            # is treated as UV-less and re-parameterized later.
            if (uv_array.ndim > 0 and uv_array.size > 0
                    and uv_array.shape[0] == len(mesh_part.vertices)):
                processed_v_tex = torch.tensor(uv_array, dtype=torch.float32)
                # trimesh provides per-vertex UVs, so the UV faces reuse the position face indices.
                processed_t_tex_idx = torch.tensor(mesh_part.faces, dtype=torch.int32)

        self._parts[key] = {
            'v_pos': torch.tensor(mesh_part.vertices, dtype=torch.float32),
            't_pos_idx': torch.tensor(mesh_part.faces, dtype=torch.int32),
            'v_tex': processed_v_tex,
            't_tex_idx': processed_t_tex_idx,
            'v_normal': torch.tensor(mesh_part.vertex_normals, dtype=torch.float32),
        }

    def _merge_parts_internal(self):
        """
        Merge all entries of ``_parts`` into a single mesh representation.

        Called automatically during initialization. A single part keeps its UVs.
        With several parts, positions, faces and normals are concatenated, and
        the UVs are dropped so that the mesh is re-parameterized. ``_parts`` is
        set to None afterwards.

        :raises ValueError: If there are no parts.
        """
        if not self._parts:
            raise ValueError("No mesh parts.")

        if len(self._parts) == 1:
            part = next(iter(self._parts.values()))
            self._v_pos = part['v_pos']
            self._t_pos_idx = part['t_pos_idx']
            self._v_tex = part['v_tex']
            self._t_tex_idx = part['t_tex_idx']
            self._v_normal = part['v_normal']
            self._vmapping = None
            self._parts = None  # Release the per-part copies
            return

        vertices = []
        faces = []
        normals = []
        v_count = 0  # Number of vertices already emitted by earlier parts
        for part in self._parts.values():
            vertices.append(part['v_pos'])
            # Offset face indices so they point into the concatenated vertex array.
            faces.append(part['t_pos_idx'] + v_count)
            normals.append(part['v_normal'])
            v_count += part['v_pos'].shape[0]
        self._parts = None  # Release the per-part copies

        self._v_pos = torch.cat(vertices, dim=0)
        self._t_pos_idx = torch.cat(faces, dim=0)
        self._v_normal = torch.cat(normals, dim=0)
        # A multi-part mesh must be re-parameterized.
        self._v_tex = None
        self._t_tex_idx = None
        self._vmapping = None

    @classmethod
    def export(cls, mesh, save_path=None, texture_map: Image.Image = None):
        """
        Export a processed mesh to a textured GLB file.

        :param mesh: Mesh instance to export. It must be a single part with a UV mapping.
        :param save_path: Optional path for the GLB file. If None, a temporary ``.glb``
            file is created and not deleted.
        :param texture_map: Optional texture, as a PIL image (any Image.Image subclass)
            or a numpy array accepted by ``Image.fromarray``. If None, a flat gray
            1024x1024 texture is used.
        :return: Path to the exported GLB file.
        """
        # The mesh is expected to come from process(): merged into one part and UV-mapped.
        assert not mesh.has_multi_parts, "Mesh should be processed and merged to single part"
        assert mesh.has_uv, "Mesh should have UV mapping after processing"

        if save_path is None:
            # delete=False keeps the file after the handle is closed.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as temp_file:
                save_path = temp_file.name

        # Build the material.
        if texture_map is not None:
            if isinstance(texture_map, np.ndarray):
                texture_map = Image.fromarray(texture_map)
            # isinstance rather than an exact type check: Image.open() returns
            # subclasses such as PngImageFile.
            assert isinstance(texture_map, Image.Image), "texture_map should be a PIL.Image"
            texture_map = texture_map.transpose(Image.FLIP_TOP_BOTTOM).convert("RGB")
            material = trimesh.visual.material.PBRMaterial(
                baseColorTexture=texture_map,
                baseColorFactor=[255, 255, 255, 255],  # White, so the texture colors are not tinted
                metallicFactor=0.0,
                roughnessFactor=1.0,
            )
        else:
            default_texture = Image.new("RGB", (1024, 1024), (200, 200, 200))
            material = trimesh.visual.texture.SimpleMaterial(image=default_texture)

        uvs = mesh.v_tex.detach().cpu().numpy()
        if getattr(mesh, '_vmapping', None) is not None:
            # xatlas layout: expand positions to one vertex per UV vertex via
            # _vmapping, and use the UV face indices.
            vertices = mesh.v_pos[mesh._vmapping].detach().cpu().numpy()
            faces = mesh.t_tex_idx.detach().cpu().numpy()
        else:
            # Original UV mapping: positions and UVs share the same indexing.
            vertices = mesh.v_pos.detach().cpu().numpy()
            faces = mesh.t_pos_idx.detach().cpu().numpy()

        if getattr(mesh, '_upside_down_applied', False):
            # The upside-down transform is a mirror, which reverses triangle
            # orientation. Swap two indices per face: (0, 1, 2) -> (0, 2, 1).
            faces = faces[:, [0, 2, 1]]

        # Inverse of vertex_transform ((x, y, z) -> (x, -z, y)) is (x, y, z) -> (x, z, -y).
        # This converts from the rendering coordinate system back to the GLB one.
        vertices_export = vertices.copy()
        vertices_export[:, 1] = vertices[:, 2]   # z -> y
        vertices_export[:, 2] = -vertices[:, 1]  # -y -> z

        mesh_export = trimesh.Trimesh(vertices=vertices_export, faces=faces, process=False)
        mesh_export.visual = trimesh.visual.TextureVisuals(uv=uvs, material=material)
        mesh_export.export(file_obj=save_path, file_type='glb')
        return save_path

    @classmethod
    @spaces.GPU(duration=30)
    def process(cls, mesh_file, uv_tool="xAtlas", y2z=True, y2x=False, z2x=False, upside_down=False, img_size=(512, 512), uv_size=(1024, 1024), device='cuda', progress=gr.Progress()):
        """
        Load and preprocess an uploaded mesh, then render its geometry buffers.

        Processing steps:
        - Load the mesh, merge its parts and UV-map it (see ``__init__``).
        - Apply the requested axis transforms, then normalize it to [-1, 1].
        - Render position/normal/mask images for the views from ``get_mvp_matrix``.
        - Render position/normal maps via ``render_geo_map``.

        :param mesh_file: Path of the uploaded mesh file.
        :param uv_tool: UV parameterization tool; default is "xAtlas".
        :param y2z: Apply ``vertex_transform`` ((x, y, z) -> (x, -z, y)).
        :param y2x: Apply ``vertex_transform_y2x``.
        :param z2x: Apply ``vertex_transform_z2x``.
        :param upside_down: Apply ``vertex_transform_upsidedown``.
        :param img_size: Size of the rendered view images.
        :param uv_size: Currently unused; it only fed a texture for a disabled clay
            render. Kept for API compatibility.
        :param device: Torch device used for processing.
        :param progress: Gradio progress tracker.
        :return: A 9-tuple: (position_map_path, normal_map_path, position_images_path,
            normal_images_path, mask_images_path, w2c_path, mesh moved to CPU,
            mvp_matrix_path, status message). Each ``*_path`` is the value returned
            by ``save_tensor_to_file``.
        """
        # Load the mesh (multiple parts are merged automatically).
        mesh: Mesh = cls(mesh_file, uv_tool, device, progress=progress)

        progress(0.7, "Handling transformation and normalization...")
        if y2z:
            mesh.vertex_transform()
        if y2x:
            mesh.vertex_transform_y2x()
        if z2x:
            mesh.vertex_transform_z2x()
        if upside_down:
            mesh.vertex_transform_upsidedown()
        mesh.normalize()

        # Camera matrices for the rendered views.
        mvp_matrix, w2c = get_mvp_matrix(mesh)
        mvp_matrix = mvp_matrix.to(device)
        w2c = w2c.to(device)

        progress(0.8, "Rendering clay model views...")
        print("Rendering geometry views...")
        position_images, normal_images, mask_images = render_geo_views_tensor(mesh, mvp_matrix, img_size)

        progress(0.9, "Rendering geometry maps...")
        print("Rendering geometry maps...")
        position_map, normal_map = render_geo_map(mesh)

        progress(1, "Mesh processing completed.")
        position_map_path = save_tensor_to_file(position_map, prefix="position_map")
        normal_map_path = save_tensor_to_file(normal_map, prefix="normal_map")
        position_images_path = save_tensor_to_file(position_images, prefix="position_images")
        normal_images_path = save_tensor_to_file(normal_images, prefix="normal_images")
        mask_images_path = save_tensor_to_file(mask_images.squeeze(-1), prefix="mask_images")
        w2c_path = save_tensor_to_file(w2c, prefix="w2c")
        mvp_matrix_path = save_tensor_to_file(mvp_matrix, prefix="mvp_matrix")
        # Return the mesh instance itself (moved to CPU) alongside the saved buffers.
        return position_map_path, normal_map_path, position_images_path, normal_images_path, mask_images_path, w2c_path, mesh.to("cpu"), mvp_matrix_path, "Mesh processing completed."
if __name__ == '__main__':
    import sys

    # Development example; pass a mesh path as the first command-line argument to override it.
    default_glb_path = "/mnt/pfs/users/yuanze/projects/clean_seqtex/gradio/examples/multi_parts.glb"
    glb_path = sys.argv[1] if len(sys.argv) > 1 else default_glb_path

    # Mesh.process returns nine values: the outputs of save_tensor_to_file for each
    # buffer (presumably file paths), the processed CPU mesh, and a status message.
    (position_map_path, normal_map_path, position_images_path, normal_images_path,
     mask_images_path, w2c_path, mesh, mvp_matrix_path, status) = Mesh.process(glb_path)

    print(status)
    for label, saved in (
        ("position_map", position_map_path),
        ("normal_map", normal_map_path),
        ("position_images", position_images_path),
        ("normal_images", normal_images_path),
        ("mask_images", mask_images_path),
        ("w2c", w2c_path),
        ("mvp_matrix", mvp_matrix_path),
    ):
        print(f"{label}: {saved}")