yhyang-myron committed on
Commit f89a9bf · 0 parent(s)
Files changed (4):
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +243 -0
  4. requirements.txt +25 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: HoloPart
+ emoji: 🔮
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.24.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,243 @@
+ import spaces
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ import trimesh
+ import random
+ from transformers import AutoModelForImageSegmentation
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download, snapshot_download, login
+ import subprocess
+ import shutil
+
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.float16
+
+ print("DEVICE: ", DEVICE)
+
+ DEFAULT_PART_FACE_NUMBER = 10000
+ MAX_SEED = np.iinfo(np.int32).max
+ HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart"
+ HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"
+
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
+ os.makedirs(TMP_DIR, exist_ok=True)
+
+ HOLOPART_CODE_DIR = "./holopart"
+ if not os.path.exists(HOLOPART_CODE_DIR):
+     os.system(f"git clone {HOLOPART_REPO_URL} {HOLOPART_CODE_DIR}")
+
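+ # Make the cloned HoloPart code importable: the pipeline package sits at the
+ # repo root and the inference helpers live under scripts/ (see imports below).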
+ import sys
+ sys.path.append(HOLOPART_CODE_DIR)
+ sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))
+
+ EXAMPLES = [
+     ["./holopart/assets/example_data/000.glb", "./holopart/assets/example_data/000.png"],
+     ["./holopart/assets/example_data/001.glb", "./holopart/assets/example_data/001.png"],
+     ["./holopart/assets/example_data/002.glb", "./holopart/assets/example_data/002.png"],
+     ["./holopart/assets/example_data/003.glb", "./holopart/assets/example_data/003.png"],
+     ["./holopart/assets/example_data/004.glb", "./holopart/assets/example_data/004.png"],
+ ]
+
+ HEADER = """
+ # 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).
+ ### Step 1: Prepare Your Segmented Mesh
+ Upload a mesh with part segmentation. We recommend one of these segmentation tools:
+ - [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
+ - [SAMesh](https://github.com/gtangg12/samesh)
+
+ For a mesh file `mesh.glb` and a corresponding face mask `mask.npy`, prepare your input with this Python snippet:
+ ```python
+ import trimesh
+ import numpy as np
+ mesh = trimesh.load("mesh.glb", force="mesh")
+ mask_npy = np.load("mask.npy")
+ mesh_parts = []
+ for part_id in np.unique(mask_npy):
+     mesh_part = mesh.submesh([mask_npy == part_id], append=True)
+     mesh_parts.append(mesh_part)
+ trimesh.Scene(mesh_parts).export("input_mesh.glb")
+ ```
+ The resulting **input_mesh.glb** is your prepared input for HoloPart.
+ ### Step 2: Click the **Decompose Parts** button to begin decomposition.
+ """
+
+ from inference_holopart import prepare_data, run_holopart
+ from holopart.pipelines.pipeline_holopart import HoloPartPipeline
+
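+ # Download the pretrained weights from the Hub once at startup and build the
+ # HoloPart pipeline in fp16 on the selected device.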
+ snapshot_download("VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
+ holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL).to(DEVICE, DTYPE)
+
+ def start_session(req: gr.Request):
+     save_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     os.makedirs(save_dir, exist_ok=True)
+     print("start session, mkdir", save_dir)
+
+ def end_session(req: gr.Request):
+     save_dir = os.path.join(TMP_DIR, str(req.session_hash))
+     shutil.rmtree(save_dir)
+
+ def get_random_hex():
+     random_bytes = os.urandom(8)
+     random_hex = random_bytes.hex()
+     return random_hex
+
+ def get_random_seed(randomize_seed, seed):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
+     center = mesh.centroid
+     exploded_mesh = trimesh.Scene()
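+     # Shift each part away from the scene centroid along the unit vector from
+     # the centroid to the part's own center, scaled by explode_factor.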
+     for geometry_name, geometry in mesh.geometry.items():
+         transform = mesh.graph[geometry_name][0]
+         vertices_global = trimesh.transformations.transform_points(
+             geometry.vertices, transform)
+         part_center = np.mean(vertices_global, axis=0)
+         direction = part_center - center
+         direction_length = np.linalg.norm(direction)
+         if direction_length > 0:
+             direction = direction / direction_length
+         displacement = direction * explode_factor
+         new_transform = np.copy(transform)
+         new_transform[:3, 3] += displacement
+         exploded_mesh.add_geometry(geometry, transform=new_transform, geom_name=geometry_name)
+     return exploded_mesh
+
+
+ @spaces.GPU(duration=600)
+ def run_full(data_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
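+     # Build per-part conditioning data from the uploaded pre-segmented mesh,
+     # then let HoloPart complete every part into a full 3D mesh.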
+     batch_size = 30
+     parts_data = prepare_data(data_path)
+
+     part_scene = run_holopart(
+         holopart_pipe,
+         batch=parts_data,
+         batch_size=batch_size,
+         seed=seed,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_chunks=1000000,
+     )
+     print("mesh extraction done")
+
+     save_dir = os.path.join(TMP_DIR, "examples")
+     os.makedirs(save_dir, exist_ok=True)
+     mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
+     part_scene.export(mesh_path)
+     print("save to ", mesh_path)
+     exploded_mesh = explode_mesh(part_scene, 0.7)
+     exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
+     exploded_mesh.export(exploded_mesh_path)
+
+     torch.cuda.empty_cache()
+
+     return mesh_path, exploded_mesh_path
+
+
+ @spaces.GPU(duration=600)
+ def run_example(data_path: str, example_image_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
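+     # The example image is only shown in the Examples gallery below; the
+     # decomposition itself runs on the mesh file alone.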
+     batch_size = 30
+     parts_data = prepare_data(data_path)
+
+     part_scene = run_holopart(
+         holopart_pipe,
+         batch=parts_data,
+         batch_size=batch_size,
+         seed=seed,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_chunks=1000000,
+     )
+     print("mesh extraction done")
+
+     save_dir = os.path.join(TMP_DIR, "examples")
+     os.makedirs(save_dir, exist_ok=True)
+     mesh_path = os.path.join(save_dir, f"holopart_{get_random_hex()}.glb")
+     part_scene.export(mesh_path)
+     print("save to ", mesh_path)
+     exploded_mesh = explode_mesh(part_scene, 0.5)
+     exploded_mesh_path = os.path.join(save_dir, f"holopart_exploded_{get_random_hex()}.glb")
+     exploded_mesh.export(exploded_mesh_path)
+
+     torch.cuda.empty_cache()
+
+     return mesh_path, exploded_mesh_path
+
+
+ with gr.Blocks(title="HoloPart") as demo:
+     gr.Markdown(HEADER)
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 input_mesh = gr.Model3D(label="Input Mesh")
+                 example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)
+                 # seg_image = gr.Image(
+                 #     label="Segmentation Result", type="pil", format="png", interactive=False
+                 # )
+
+             with gr.Accordion("Generation Settings", open=True):
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=8,
+                     maximum=50,
+                     step=1,
+                     value=25,
+                 )
+                 guidance_scale = gr.Slider(
+                     label="CFG scale",
+                     minimum=0.0,
+                     maximum=20.0,
+                     step=0.1,
+                     value=3.5,
+                 )
+
+             with gr.Row():
+                 reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)
+                 # target_face_num = gr.Slider(maximum=1000000, minimum=10000, value=DEFAULT_PART_FACE_NUMBER, label="Target Face Number")
+
+             gen_button = gr.Button("Decompose Parts", variant="primary")
+
+         with gr.Column():
+             model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
+             exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)
+
+     with gr.Row():
+         examples = gr.Examples(
+             examples=EXAMPLES,
+             fn=run_example,
+             inputs=[input_mesh, example_image],
+             outputs=[model_output, exploded_parts_output],
+             cache_examples=True,
+         )
+
+     gen_button.click(
+         run_full,
+         inputs=[
+             input_mesh,
+             seed,
+             num_inference_steps,
+             guidance_scale,
+         ],
+         outputs=[model_output, exploded_parts_output],
+     )
+
+     demo.load(start_session)
+     demo.unload(end_session)
+
+ demo.launch()
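
For reference, a minimal usage sketch of the `explode_mesh` helper above, assuming the function is in scope and a decomposed multi-part GLB exists at the hypothetical path `parts.glb`:

```python
import trimesh

# trimesh returns a Scene when a GLB file contains several geometries.
scene = trimesh.load("parts.glb")  # hypothetical decomposed multi-part GLB
assert isinstance(scene, trimesh.Scene)

# Push every part 0.7 units outward from the scene centroid, as run_full does,
# then save the exploded view alongside the original.
exploded = explode_mesh(scene, explode_factor=0.7)
exploded.export("parts_exploded.glb")
```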
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ torchvision
+ diffusers
+ transformers==4.49.0
+ einops
+ huggingface_hub
+ opencv-python
+ trimesh==4.5.3
+ omegaconf
+ scikit-image
+ numpy
+ peft
+ scipy==1.11.4
+ jaxtyping
+ typeguard
+ pymeshlab==2022.2.post4
+ open3d
+ timm
+ kornia
+ ninja
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
+ cvcuda_cu12
+ gltflib
+ https://huggingface.co/spaces/VAST-AI/TripoSG/resolve/main/diso-0.1.4-cp310-cp310-linux_x86_64.whl?download=true
+ --find-links https://data.pyg.org/whl/torch-2.6.0+cu124.html
+ torch-cluster