import spaces
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
import trimesh
import random
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from huggingface_hub import hf_hub_download, snapshot_download, login
import subprocess
import shutil

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16

print("DEVICE: ", DEVICE)

DEFAULT_PART_FACE_NUMBER = 10000
MAX_SEED = np.iinfo(np.int32).max
HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart"
HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

HOLOPART_CODE_DIR = "./holopart"
# Clone the HoloPart repository once; check for the local checkout, not the remote URL.
if not os.path.exists(HOLOPART_CODE_DIR):
    os.system(f"git clone {HOLOPART_REPO_URL} {HOLOPART_CODE_DIR}")

import sys
sys.path.append(HOLOPART_CODE_DIR)
sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))

EXAMPLES = [
    ["./holopart/assets/example_data/000.glb", "./holopart/assets/example_data/000.png"],
    ["./holopart/assets/example_data/001.glb", "./holopart/assets/example_data/001.png"],
    ["./holopart/assets/example_data/002.glb", "./holopart/assets/example_data/002.png"],
    ["./holopart/assets/example_data/003.glb", "./holopart/assets/example_data/003.png"],
]

HEADER = """
# 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).

### Step 1: Prepare Your Segmented Mesh
Upload a mesh with part segmentation. We recommend these segmentation tools:
- [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
- [SAMesh](https://github.com/gtangg12/samesh)

For a mesh file `mesh.glb` and its corresponding face mask `mask.npy`, prepare your input with this Python code:

```python
import trimesh
import numpy as np

mesh = trimesh.load("mesh.glb", force="mesh")
mask_npy = np.load("mask.npy")

mesh_parts = []
for part_id in np.unique(mask_npy):
    mesh_part = mesh.submesh([mask_npy == part_id], append=True)
    mesh_parts.append(mesh_part)

trimesh.Scene(mesh_parts).export("input_mesh.glb")
```

The resulting **input_mesh.glb** is your prepared input for HoloPart.

### Step 2: Click the **Decompose Parts** button to begin the decomposition process.
"""

from inference_holopart import prepare_data, run_holopart
from holopart.pipelines.pipeline_holopart import HoloPartPipeline

# Download the pretrained weights and build the HoloPart pipeline once at startup.
snapshot_download("VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL).to(DEVICE, DTYPE)


def start_session(req: gr.Request):
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(save_dir, exist_ok=True)
    print("start session, mkdir", save_dir)


def end_session(req: gr.Request):
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir)


def get_random_hex():
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex


def get_random_seed(randomize_seed, seed):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
    # Push every part outward from the scene centroid to create an "exploded" view.
    center = mesh.centroid
    exploded_mesh = trimesh.Scene()
    for geometry_name, geometry in mesh.geometry.items():
        transform = mesh.graph[geometry_name][0]
        vertices_global = trimesh.transformations.transform_points(
            geometry.vertices, transform)
        part_center = np.mean(vertices_global, axis=0)
        direction = part_center - center
        direction_length = np.linalg.norm(direction)
        if direction_length > 0:
            direction = direction / direction_length
        displacement = direction * explode_factor
        new_transform = np.copy(transform)
        new_transform[:3, 3] += displacement
        exploded_mesh.add_geometry(geometry, transform=new_transform, geom_name=geometry_name)
    return exploded_mesh


@spaces.GPU(duration=600)
def run_full(data_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    batch_size = 30
    parts_data = prepare_data(data_path)
    part_scene = run_holopart(
        holopart_pipe,
        batch=parts_data,
        batch_size=batch_size,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_chunks=1000000,
    )
    print("mesh extraction done")

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
    part_scene.export(mesh_path)
    print("save to ", mesh_path)

    exploded_mesh = explode_mesh(part_scene, 0.7)
    exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
    exploded_mesh.export(exploded_mesh_path)

    torch.cuda.empty_cache()
    return mesh_path, exploded_mesh_path


@spaces.GPU(duration=600)
def run_example(data_path: str, example_image_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    # example_image_path is only displayed in the UI and is not used for inference.
    batch_size = 30
    parts_data = prepare_data(data_path)
    part_scene = run_holopart(
        holopart_pipe,
        batch=parts_data,
        batch_size=batch_size,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_chunks=1000000,
    )
    print("mesh extraction done")

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
    part_scene.export(mesh_path)
    print("save to ", mesh_path)

    exploded_mesh = explode_mesh(part_scene, 0.5)
    exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
    exploded_mesh.export(exploded_mesh_path)

    torch.cuda.empty_cache()
    return mesh_path, exploded_mesh_path

# Gradio UI: upload a segmented mesh, adjust generation settings, and view the decomposed parts.
with gr.Blocks(title="HoloPart") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input Mesh")
                example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)
            # seg_image = gr.Image(
            #     label="Segmentation Result", type="pil", format="png", interactive=False
            # )
            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.5,
                )
                with gr.Row():
                    reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)
                    # target_face_num = gr.Slider(maximum=1000000, minimum=10000, value=DEFAULT_PART_FACE_NUMBER, label="Target Face Number")

            gen_button = gr.Button("Decompose Parts", variant="primary")

        with gr.Column():
            model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
            exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)

            with gr.Row():
                examples = gr.Examples(
                    examples=EXAMPLES,
                    fn=run_example,
                    inputs=[input_mesh, example_image],
                    outputs=[model_output, exploded_parts_output],
                    cache_examples=True,
                )

    gen_button.click(
        run_full,
        inputs=[input_mesh, seed, num_inference_steps, guidance_scale],
        outputs=[model_output, exploded_parts_output],
    )

    demo.load(start_session)
    demo.unload(end_session)

demo.launch()
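
# Note: demo.launch() blocks the main thread while the app is served. Per-session temporary
# folders created under TMP_DIR by start_session are removed by end_session when a client leaves.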