import spaces
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
import trimesh
import random
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from huggingface_hub import hf_hub_download, snapshot_download, login
import subprocess
import shutil
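# Runtime configuration: prefer CUDA when available; the pipeline weights are loaded in float16.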
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16

print("DEVICE: ", DEVICE)

DEFAULT_PART_FACE_NUMBER = 10000
MAX_SEED = np.iinfo(np.int32).max

HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart"
HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"

TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)

HOLOPART_CODE_DIR = "./holopart"
# Clone the HoloPart repository once so its inference scripts can be imported.
if not os.path.exists(HOLOPART_CODE_DIR):
    os.system(f"git clone {HOLOPART_REPO_URL} {HOLOPART_CODE_DIR}")

import sys
sys.path.append(HOLOPART_CODE_DIR)
sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))
EXAMPLES = [
    ["./holopart/assets/example_data/000.glb", "./holopart/assets/example_data/000.png"],
    ["./holopart/assets/example_data/001.glb", "./holopart/assets/example_data/001.png"],
    ["./holopart/assets/example_data/002.glb", "./holopart/assets/example_data/002.png"],
    ["./holopart/assets/example_data/003.glb", "./holopart/assets/example_data/003.png"],
]
HEADER = """
# 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).

### Step 1: Prepare Your Segmented Mesh
Upload a mesh with part segmentation. We recommend using these segmentation tools:
- [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
- [SAMesh](https://github.com/gtangg12/samesh)

For a mesh file `mesh.glb` and its corresponding face mask `mask.npy`, prepare your input with this Python snippet:
```python
import trimesh
import numpy as np

mesh = trimesh.load("mesh.glb", force="mesh")
mask_npy = np.load("mask.npy")
mesh_parts = []
for part_id in np.unique(mask_npy):
    mesh_part = mesh.submesh([mask_npy == part_id], append=True)
    mesh_parts.append(mesh_part)
trimesh.Scene(mesh_parts).export("input_mesh.glb")
```
The resulting **input_mesh.glb** is your prepared input for HoloPart.

### Step 2: Click the **Decompose Parts** button to begin the decomposition.
"""
from inference_holopart import prepare_data, run_holopart
from holopart.pipelines.pipeline_holopart import HoloPartPipeline

snapshot_download("VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL).to(DEVICE, DTYPE)
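# Each Gradio session gets its own scratch directory under TMP_DIR, created on page
# load and removed when the user leaves.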
def start_session(req: gr.Request):
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(save_dir, exist_ok=True)
    print("start session, mkdir", save_dir)


def end_session(req: gr.Request):
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir)


def get_random_hex():
    random_bytes = os.urandom(8)
    random_hex = random_bytes.hex()
    return random_hex


def get_random_seed(randomize_seed, seed):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
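# Build an "exploded" view of the decomposed scene by translating every part away
# from the scene centroid along the direction of its own center.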
def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
    center = mesh.centroid
    exploded_mesh = trimesh.Scene()
    for geometry_name, geometry in mesh.geometry.items():
        transform = mesh.graph[geometry_name][0]
        vertices_global = trimesh.transformations.transform_points(
            geometry.vertices, transform)
        part_center = np.mean(vertices_global, axis=0)
        direction = part_center - center
        direction_length = np.linalg.norm(direction)
        if direction_length > 0:
            direction = direction / direction_length
        displacement = direction * explode_factor
        new_transform = np.copy(transform)
        new_transform[:3, 3] += displacement
        exploded_mesh.add_geometry(geometry, transform=new_transform, geom_name=geometry_name)
    return exploded_mesh
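# Full decomposition for user uploads: prepare per-part data, run HoloPart sampling,
# and export both the assembled scene and an exploded view as GLB files.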
@spaces.GPU  # allocate ZeroGPU hardware for the duration of this call
def run_full(data_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    batch_size = 30
    parts_data = prepare_data(data_path)
    part_scene = run_holopart(
        holopart_pipe,
        batch=parts_data,
        batch_size=batch_size,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_chunks=1000000,
    )
    print("mesh extraction done")

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
    part_scene.export(mesh_path)
    print("save to ", mesh_path)

    exploded_mesh = explode_mesh(part_scene, 0.7)
    exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
    exploded_mesh.export(exploded_mesh_path)

    torch.cuda.empty_cache()
    return mesh_path, exploded_mesh_path
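# Variant used for the cached Gradio examples; the example image is shown in the
# gallery but is not used for inference.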
@spaces.GPU  # allocate ZeroGPU hardware for the duration of this call
def run_example(data_path: str, example_image_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    batch_size = 30
    parts_data = prepare_data(data_path)
    part_scene = run_holopart(
        holopart_pipe,
        batch=parts_data,
        batch_size=batch_size,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_chunks=1000000,
    )
    print("mesh extraction done")

    save_dir = os.path.join(TMP_DIR, "examples")
    os.makedirs(save_dir, exist_ok=True)
    mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
    part_scene.export(mesh_path)
    print("save to ", mesh_path)

    exploded_mesh = explode_mesh(part_scene, 0.5)
    exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
    exploded_mesh.export(exploded_mesh_path)

    torch.cuda.empty_cache()
    return mesh_path, exploded_mesh_path
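# Gradio UI: mesh input and generation settings on the left, the decomposed and
# exploded results on the right, plus a cached example gallery below.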
with gr.Blocks(title="HoloPart") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input Mesh")
                example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)
                # seg_image = gr.Image(
                #     label="Segmentation Result", type="pil", format="png", interactive=False
                # )

            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.5,
                )
                with gr.Row():
                    reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)
                    # target_face_num = gr.Slider(maximum=1000000, minimum=10000, value=DEFAULT_FACE_NUMBER, label="Target Face Number")

            gen_button = gr.Button("Decompose Parts", variant="primary")

        with gr.Column():
            model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
            exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)

    with gr.Row():
        examples = gr.Examples(
            examples=EXAMPLES,
            fn=run_example,
            inputs=[input_mesh, example_image],
            outputs=[model_output, exploded_parts_output],
            cache_examples=True,
        )

    gen_button.click(
        run_full,
        inputs=[
            input_mesh,
            seed,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[model_output, exploded_parts_output],
    )

    demo.load(start_session)
    demo.unload(end_session)

demo.launch()