Commit 2e90551 (1 Parent(s): d2536d4), committed by xinjie.wang

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. app.py +490 -15
  2. asset3d_gen/data/backproject.py +503 -0
  3. asset3d_gen/data/backproject_v2.py +613 -0
  4. asset3d_gen/data/backup/backproject_v2 copy.py +650 -0
  5. asset3d_gen/data/backup/backproject_v2.py +700 -0
  6. asset3d_gen/data/backup/backproject_v3.py +866 -0
  7. asset3d_gen/data/backup/backprojectv2.py +835 -0
  8. asset3d_gen/data/backup/gpt_qwen.py +70 -0
  9. asset3d_gen/data/backup/quat.py +49 -0
  10. asset3d_gen/data/datasets.py +239 -0
  11. asset3d_gen/data/differentiable_render.py +520 -0
  12. asset3d_gen/data/mesh_operator.py +425 -0
  13. asset3d_gen/data/utils.py +943 -0
  14. asset3d_gen/models/delight.py +165 -0
  15. asset3d_gen/models/gs_model.py +540 -0
  16. asset3d_gen/models/segment.py +376 -0
  17. asset3d_gen/models/super_resolution.py +118 -0
  18. asset3d_gen/models/text_model.py +143 -0
  19. asset3d_gen/models/texture_model.py +91 -0
  20. asset3d_gen/scripts/render_gs.py +156 -0
  21. asset3d_gen/scripts/render_mv.py +185 -0
  22. asset3d_gen/scripts/text2image.py +145 -0
  23. asset3d_gen/utils/gpt_clients.py +190 -0
  24. asset3d_gen/utils/process_media.py +194 -0
  25. asset3d_gen/utils/tags.py +1 -0
  26. asset3d_gen/validators/aesthetic_predictor.py +136 -0
  27. asset3d_gen/validators/quality_checkers.py +195 -0
  28. asset3d_gen/validators/urdf_convertor.py +423 -0
  29. common.py +597 -0
  30. requirements.txt +7 -6
  31. thirdparty/TRELLIS/trellis/trellis/__init__.py +6 -0
  32. thirdparty/TRELLIS/trellis/trellis/models/__init__.py +70 -0
  33. thirdparty/TRELLIS/trellis/trellis/models/sparse_structure_flow.py +200 -0
  34. thirdparty/TRELLIS/trellis/trellis/models/sparse_structure_vae.py +306 -0
  35. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_flow.py +262 -0
  36. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/__init__.py +4 -0
  37. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/base.py +117 -0
  38. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_gs.py +122 -0
  39. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_mesh.py +167 -0
  40. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_rf.py +104 -0
  41. thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/encoder.py +72 -0
  42. thirdparty/TRELLIS/trellis/trellis/modules/attention/__init__.py +36 -0
  43. thirdparty/TRELLIS/trellis/trellis/modules/attention/full_attn.py +140 -0
  44. thirdparty/TRELLIS/trellis/trellis/modules/attention/modules.py +146 -0
  45. thirdparty/TRELLIS/trellis/trellis/modules/norm.py +25 -0
  46. thirdparty/TRELLIS/trellis/trellis/modules/sparse/__init__.py +102 -0
  47. thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/__init__.py +4 -0
  48. thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/full_attn.py +215 -0
  49. thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/modules.py +139 -0
  50. thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/serialized_attn.py +193 -0
app.py CHANGED
@@ -1,21 +1,496 @@
- import torch
  import gradio as gr
- import spaces
- import nvdiffrast.torch as dr

- @spaces.GPU
- def pinfo():
-     print("CUDA Version:", torch.version.cuda)
-     print(torch.version.cuda)
-     print(torch.cuda.is_available())
-     zero = torch.Tensor([0]).cuda()
-     print(zero.device)


- def my_function(input_text):
-     pinfo()
-     return f"Received: {input_text}"


- iface = gr.Interface(fn=my_function, inputs="text", outputs="text")
- iface.launch()
1
+ import os
2
+ import shutil
3
+ from functools import partial
4
+
5
  import gradio as gr
6
+ from common import (
7
+ MAX_SEED,
8
+ VERSION,
9
+ TrellisImageTo3DPipeline,
10
+ active_btn_by_content,
11
+ extract_3d_representations_v2,
12
+ extract_urdf,
13
+ get_seed,
14
+ image_to_3d,
15
+ preprocess_image_fn,
16
+ preprocess_sam_image_fn,
17
+ select_point,
18
+ )
19
+ from gradio.themes import Default
20
+ from gradio.themes.utils.colors import slate
21
+ from gradio_litmodel3d import LitModel3D
22
+ from asset3d_gen.models.delight import DelightingModel
23
+ from asset3d_gen.models.segment import RembgRemover, SAMPredictor
24
+ from asset3d_gen.models.super_resolution import ImageRealESRGAN
25
+ from asset3d_gen.utils.gpt_clients import GPT_CLIENT
26
+ from asset3d_gen.validators.quality_checkers import (
27
+ ImageAestheticChecker,
28
+ ImageSegChecker,
29
+ MeshGeoChecker,
30
+ )
31
+ from asset3d_gen.validators.urdf_convertor import URDFGenerator
32
+
33
+ TMP_DIR = os.path.join(
34
+ os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
35
+ )
36
+ os.makedirs(TMP_DIR, exist_ok=True)
37
+
38
+ RBG_REMOVER = RembgRemover()
39
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h")
40
+ DELIGHT = DelightingModel()
41
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
42
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
43
+ "JeffreyXiang/TRELLIS-image-large"
44
+ )
45
+ # PIPELINE.cuda()
46
+
47
+ IMAGE_BUFFER = {}
48
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
49
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
50
+ AESTHETIC_CHECKER = ImageAestheticChecker()
51
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
52
+ URDF_CONVERTOR = URDFGenerator(GPT_CLIENT, render_view_num=4)
53
+
54
+
55
+ def start_session(req: gr.Request) -> None:
56
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
57
+ os.makedirs(user_dir, exist_ok=True)
58
+
59
+
60
+ def end_session(req: gr.Request) -> None:
61
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
62
+ if os.path.exists(user_dir):
63
+ shutil.rmtree(user_dir)
64
+
65
+
66
+ with gr.Blocks(
67
+ delete_cache=(43200, 43200), theme=Default(primary_hue=slate)
68
+ ) as demo:
69
+ gr.Markdown(
70
+ f"""
71
+ ## Image to 3D Asset Pipeline \n
72
+ version: {VERSION} \n
73
+ The service is temporarily deployed on `dev015-10.34.8.82: CUDA 4`.
74
+ """
75
+ )
76
+ with gr.Row():
77
+ with gr.Column(scale=2):
78
+ with gr.Tabs() as input_tabs:
79
+ with gr.Tab(
80
+ label="Image(auto seg)", id=0
81
+ ) as single_image_input_tab:
82
+ image_prompt = gr.Image(
83
+ label="Input Image",
84
+ format="png",
85
+ image_mode="RGBA",
86
+ type="pil",
87
+ height=300,
88
+ )
89
+ gr.Markdown(
90
+ """
91
+ If you are not satisfied with the auto segmentation
92
+ result, please switch to the `Image(SAM seg)` tab."""
93
+ )
94
+ with gr.Tab(
95
+ label="Image(SAM seg)", id=1
96
+ ) as samimage_input_tab:
97
+ with gr.Row():
98
+ with gr.Column(scale=1):
99
+ image_prompt_sam = gr.Image(
100
+ label="Input Image", type="numpy", height=400
101
+ )
102
+ image_seg_sam = gr.Image(
103
+ label="SAM Seg Image",
104
+ image_mode="RGBA",
105
+ type="pil",
106
+ height=400,
107
+ visible=False,
108
+ )
109
+ with gr.Column(scale=1):
110
+ image_mask_sam = gr.AnnotatedImage()
111
+
112
+ fg_bg_radio = gr.Radio(
113
+ ["foreground_point", "background_point"],
114
+ label="Select foreground(green) or background(red) points, by default foreground", # noqa
115
+ value="foreground_point",
116
+ )
117
+ gr.Markdown(
118
+ """ Click the `Input Image` to select SAM points,
119
+ after getting a satisfactory segmentation, click the `Generate`
120
+ button to generate the 3D asset. \n
121
+ Note: If the segmented foreground is too small relative
122
+ to the entire image area, the generation will fail.
123
+ """
124
+ )
125
+
126
+ with gr.Accordion(label="Generation Settings", open=False):
127
+ with gr.Row():
128
+ seed = gr.Slider(
129
+ 0, MAX_SEED, label="Seed", value=0, step=1
130
+ )
131
+ with gr.Row():
132
+ randomize_seed = gr.Checkbox(
133
+ label="Randomize Seed", value=False
134
+ )
135
+ project_delight = gr.Checkbox(
136
+ label="Backproject delighting",
137
+ value=True,
138
+ )
139
+ gr.Markdown("Geo Structure Generation")
140
+ with gr.Row():
141
+ ss_guidance_strength = gr.Slider(
142
+ 0.0,
143
+ 10.0,
144
+ label="Guidance Strength",
145
+ value=7.5,
146
+ step=0.1,
147
+ )
148
+ ss_sampling_steps = gr.Slider(
149
+ 1, 50, label="Sampling Steps", value=12, step=1
150
+ )
151
+ gr.Markdown("Visual Appearance Generation")
152
+ with gr.Row():
153
+ slat_guidance_strength = gr.Slider(
154
+ 0.0,
155
+ 10.0,
156
+ label="Guidance Strength",
157
+ value=3.0,
158
+ step=0.1,
159
+ )
160
+ slat_sampling_steps = gr.Slider(
161
+ 1, 50, label="Sampling Steps", value=12, step=1
162
+ )
163
+
164
+ generate_btn = gr.Button(
165
+ "Generate(~0.5 mins)", variant="primary", interactive=False
166
+ )
167
+ model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
168
+ with gr.Row():
169
+ extract_rep3d_btn = gr.Button(
170
+ "Extract 3D Representation(~2 mins)",
171
+ variant="primary",
172
+ interactive=False,
173
+ )
174
+ with gr.Accordion(
175
+ label="Enter Asset Attributes(optional)", open=False
176
+ ):
177
+ asset_cat_text = gr.Textbox(
178
+ label="Enter Asset Category (e.g., chair)"
179
+ )
180
+ height_range_text = gr.Textbox(
181
+ label="Enter Height Range in meter (e.g., 0.5-0.6)"
182
+ )
183
+ mass_range_text = gr.Textbox(
184
+ label="Enter Mass Range in kg (e.g., 1.1-1.2)"
185
+ )
186
+ asset_version_text = gr.Textbox(
187
+ label=f"Enter version (e.g., {VERSION})"
188
+ )
189
+ with gr.Row():
190
+ extract_urdf_btn = gr.Button(
191
+ "Extract URDF(~1 min)",
192
+ variant="primary",
193
+ interactive=False,
194
+ )
195
+ with gr.Row():
196
+ gr.Markdown(
197
+ "#### Estimated Asset 3D Attributes(No input required)"
198
+ )
199
+ with gr.Row():
200
+ est_type_text = gr.Textbox(
201
+ label="Asset category", interactive=False
202
+ )
203
+ est_height_text = gr.Textbox(
204
+ label="Real height(.m)", interactive=False
205
+ )
206
+ est_mass_text = gr.Textbox(
207
+ label="Mass(.kg)", interactive=False
208
+ )
209
+ est_mu_text = gr.Textbox(
210
+ label="Friction coefficient", interactive=False
211
+ )
212
+ with gr.Row():
213
+ download_urdf = gr.DownloadButton(
214
+ label="Download URDF", variant="primary", interactive=False
215
+ )
216
+
217
+ gr.Markdown(
218
+ """ NOTE: If `Asset Attributes` are provided, the provided
219
+ properties will be used; otherwise, the GPT-preset properties
220
+ will be applied. \n
221
+ The `Download URDF` file is restored to the real scale and
222
+ has undergone quality inspection; open it with an editor to view details.
223
+ """
224
+ )
225
+
226
+ with gr.Row() as single_image_example:
227
+ examples = gr.Examples(
228
+ label="Image Gallery",
229
+ examples=[
230
+ [f"scripts/apps/assets/example_image/{image}"]
231
+ for image in os.listdir(
232
+ "scripts/apps/assets/example_image"
233
+ )
234
+ ],
235
+ inputs=[image_prompt],
236
+ fn=partial(
237
+ preprocess_image_fn,
238
+ model=RBG_REMOVER,
239
+ buffer=IMAGE_BUFFER,
240
+ ),
241
+ outputs=[image_prompt],
242
+ run_on_click=True,
243
+ examples_per_page=32,
244
+ )
245
+
246
+ with gr.Row(visible=False) as single_sam_image_example:
247
+ examples = gr.Examples(
248
+ label="Image Gallery",
249
+ examples=[
250
+ f"scripts/apps/assets/example_image/{image}"
251
+ for image in os.listdir(
252
+ "scripts/apps/assets/example_image"
253
+ )
254
+ ],
255
+ inputs=[image_prompt_sam],
256
+ fn=partial(
257
+ preprocess_sam_image_fn,
258
+ buffer=IMAGE_BUFFER,
259
+ model=SAM_PREDICTOR,
260
+ ),
261
+ outputs=[image_prompt_sam],
262
+ run_on_click=True,
263
+ examples_per_page=32,
264
+ )
265
+ with gr.Column(scale=1):
266
+ video_output = gr.Video(
267
+ label="Generated 3D Asset",
268
+ autoplay=True,
269
+ loop=True,
270
+ height=300,
271
+ )
272
+ model_output_gs = LitModel3D(
273
+ label="Gaussian Representation", height=300, interactive=False
274
+ )
275
+ aligned_gs = gr.Textbox(visible=False)
276
+ with gr.Row():
277
+ model_output_mesh = LitModel3D(
278
+ label="Mesh Representation",
279
+ exposure=10.0,
280
+ height=300,
281
+ interactive=False,
282
+ )
283
+ gr.Markdown(
284
+ """ The rendering of `Gaussian Representation` takes an additional 10s. """ # noqa
285
+ )
286
+
287
+ is_samimage = gr.State(False)
288
+ output_buf = gr.State()
289
+ selected_points = gr.State(value=[])
290
+
291
+ demo.load(start_session)
292
+ demo.unload(end_session)
293
+
294
+ single_image_input_tab.select(
295
+ lambda: tuple(
296
+ [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
297
+ ),
298
+ outputs=[is_samimage, single_image_example, single_sam_image_example],
299
+ )
300
+ samimage_input_tab.select(
301
+ lambda: tuple(
302
+ [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
303
+ ),
304
+ outputs=[is_samimage, single_sam_image_example, single_image_example],
305
+ )
306
+
307
+ image_prompt.upload(
308
+ partial(preprocess_image_fn, model=RBG_REMOVER, buffer=IMAGE_BUFFER),
309
+ inputs=[image_prompt],
310
+ outputs=[image_prompt],
311
+ )
312
+ image_prompt.change(
313
+ lambda: tuple(
314
+ [
315
+ gr.Button(interactive=False),
316
+ gr.Button(interactive=False),
317
+ gr.Button(interactive=False),
318
+ None,
319
+ "",
320
+ None,
321
+ None,
322
+ "",
323
+ "",
324
+ "",
325
+ "",
326
+ "",
327
+ "",
328
+ "",
329
+ "",
330
+ ]
331
+ ),
332
+ outputs=[
333
+ extract_rep3d_btn,
334
+ extract_urdf_btn,
335
+ download_urdf,
336
+ model_output_gs,
337
+ aligned_gs,
338
+ model_output_mesh,
339
+ video_output,
340
+ asset_cat_text,
341
+ height_range_text,
342
+ mass_range_text,
343
+ asset_version_text,
344
+ est_type_text,
345
+ est_height_text,
346
+ est_mass_text,
347
+ est_mu_text,
348
+ ],
349
+ )
350
+ image_prompt.change(
351
+ active_btn_by_content,
352
+ inputs=image_prompt,
353
+ outputs=generate_btn,
354
+ )
355
+
356
+ image_prompt_sam.upload(
357
+ partial(
358
+ preprocess_sam_image_fn, buffer=IMAGE_BUFFER, model=SAM_PREDICTOR
359
+ ),
360
+ inputs=[image_prompt_sam],
361
+ outputs=[image_prompt_sam],
362
+ )
363
+ image_prompt_sam.change(
364
+ lambda: tuple(
365
+ [
366
+ gr.Button(interactive=False),
367
+ gr.Button(interactive=False),
368
+ gr.Button(interactive=False),
369
+ None,
370
+ None,
371
+ None,
372
+ "",
373
+ "",
374
+ "",
375
+ "",
376
+ "",
377
+ "",
378
+ "",
379
+ "",
380
+ None,
381
+ [],
382
+ ]
383
+ ),
384
+ outputs=[
385
+ extract_rep3d_btn,
386
+ extract_urdf_btn,
387
+ download_urdf,
388
+ model_output_gs,
389
+ model_output_mesh,
390
+ video_output,
391
+ asset_cat_text,
392
+ height_range_text,
393
+ mass_range_text,
394
+ asset_version_text,
395
+ est_type_text,
396
+ est_height_text,
397
+ est_mass_text,
398
+ est_mu_text,
399
+ image_mask_sam,
400
+ selected_points,
401
+ ],
402
+ )
403
+
404
+ image_prompt_sam.select(
405
+ select_point,
406
+ [
407
+ image_prompt_sam,
408
+ selected_points,
409
+ fg_bg_radio,
410
+ gr.State(lambda: SAM_PREDICTOR),
411
+ ],
412
+ [image_mask_sam, image_seg_sam],
413
+ )
414
+ image_seg_sam.change(
415
+ active_btn_by_content,
416
+ inputs=image_seg_sam,
417
+ outputs=generate_btn,
418
+ )
419
 
420
+ generate_btn.click(
421
+ get_seed,
422
+ inputs=[randomize_seed, seed],
423
+ outputs=[seed],
424
+ ).success(
425
+ image_to_3d,
426
+ inputs=[
427
+ image_prompt,
428
+ seed,
429
+ ss_guidance_strength,
430
+ ss_sampling_steps,
431
+ slat_guidance_strength,
432
+ slat_sampling_steps,
433
+ gr.State(lambda: IMAGE_BUFFER),
434
+ gr.State(lambda: PIPELINE),
435
+ gr.State(lambda: TMP_DIR),
436
+ image_seg_sam,
437
+ is_samimage,
438
+ ],
439
+ outputs=[output_buf, video_output],
440
+ ).success(
441
+ lambda: gr.Button(interactive=True),
442
+ outputs=[extract_rep3d_btn],
443
+ )
444
 
445
+ extract_rep3d_btn.click(
446
+ extract_3d_representations_v2,
447
+ inputs=[
448
+ output_buf,
449
+ project_delight,
450
+ gr.State(lambda: TMP_DIR),
451
+ gr.State(lambda: DELIGHT),
452
+ gr.State(lambda: IMAGESR_MODEL),
453
+ ],
454
+ outputs=[
455
+ model_output_mesh,
456
+ model_output_gs,
457
+ model_output_obj,
458
+ aligned_gs,
459
+ ],
460
+ ).success(
461
+ lambda: gr.Button(interactive=True),
462
+ outputs=[extract_urdf_btn],
463
+ )
464
 
465
+ extract_urdf_btn.click(
466
+ extract_urdf,
467
+ inputs=[
468
+ aligned_gs,
469
+ model_output_obj,
470
+ asset_cat_text,
471
+ height_range_text,
472
+ mass_range_text,
473
+ asset_version_text,
474
+ gr.State(lambda: TMP_DIR),
475
+ gr.State(lambda: URDF_CONVERTOR),
476
+ gr.State(lambda: IMAGE_BUFFER),
477
+ gr.State(lambda: CHECKERS),
478
+ ],
479
+ outputs=[
480
+ download_urdf,
481
+ est_type_text,
482
+ est_height_text,
483
+ est_mass_text,
484
+ est_mu_text,
485
+ ],
486
+ queue=True,
487
+ show_progress="full",
488
+ ).success(
489
+ lambda: gr.Button(interactive=True),
490
+ outputs=[download_urdf],
491
+ )
492
 
493
 
494
+ if __name__ == "__main__":
495
+ demo.queue()
496
+ demo.launch(server_name="10.34.8.82", server_port=8084)
asset3d_gen/data/backproject.py ADDED
@@ -0,0 +1,503 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from typing import List, Literal, Tuple, Union
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import nvdiffrast.torch as dr
10
+ import torch
11
+ import trimesh
12
+ import utils3d
13
+ import xatlas
14
+ from tqdm import tqdm
15
+ from asset3d_gen.data.mesh_operator import MeshFixer
16
+ from asset3d_gen.data.utils import (
17
+ CameraSetting,
18
+ get_images_from_grid,
19
+ init_kal_camera,
20
+ normalize_vertices_array,
21
+ post_process_texture,
22
+ save_mesh_with_mtl,
23
+ )
24
+ from asset3d_gen.models.delight import DelightingModel
25
+
26
+ logging.basicConfig(
27
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
28
+ )
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class TextureBaker(object):
33
+ """Baking textures onto a mesh from multiple observations.
34
+
35
+ This class takes 3D mesh data, camera settings and texture baking parameters
36
+ to generate a texture map by projecting images onto the mesh from different views.
37
+ It supports both a fast texture baking approach and a more optimized method
38
+ with total variation regularization.
39
+
40
+ Attributes:
41
+ vertices (torch.Tensor): The vertices of the mesh.
42
+ faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
43
+ uvs (torch.Tensor): The UV coordinates of the mesh.
44
+ camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
45
+ device (str): The device to run computations on ("cpu" or "cuda").
46
+ w2cs (torch.Tensor): World-to-camera transformation matrices.
47
+ projections (torch.Tensor): Camera projection matrices.
48
+
49
+ Example:
50
+ >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
51
+ >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
52
+ >>> images = get_images_from_grid(args.input_image, image_size)
53
+ >>> texture = texture_backer.bake_texture(
54
+ ... images, texture_size=args.texture_size, mode=args.baker_mode
55
+ ... )
56
+ >>> texture = post_process_texture(texture)
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ vertices: np.ndarray,
62
+ faces: np.ndarray,
63
+ uvs: np.ndarray,
64
+ camera_params: CameraSetting,
65
+ device: str = "cuda",
66
+ ) -> None:
67
+ self.vertices = (
68
+ torch.tensor(vertices, device=device)
69
+ if isinstance(vertices, np.ndarray)
70
+ else vertices.to(device)
71
+ )
72
+ self.faces = (
73
+ torch.tensor(faces.astype(np.int32), device=device)
74
+ if isinstance(faces, np.ndarray)
75
+ else faces.to(device)
76
+ )
77
+ self.uvs = (
78
+ torch.tensor(uvs, device=device)
79
+ if isinstance(uvs, np.ndarray)
80
+ else uvs.to(device)
81
+ )
82
+ self.camera_params = camera_params
83
+ self.device = device
84
+
85
+ camera = init_kal_camera(camera_params)
86
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
87
+ matrix_mv = kaolin_to_opencv_view(matrix_mv)
88
+ matrix_p = (
89
+ camera.intrinsics.projection_matrix()
90
+ ) # (n_cam 4 4) cam2pixel
91
+ self.w2cs = matrix_mv.to(self.device)
92
+ self.projections = matrix_p.to(self.device)
93
+
94
+ @staticmethod
95
+ def parametrize_mesh(
96
+ vertices: np.array, faces: np.array
97
+ ) -> Union[np.array, np.array, np.array]:
98
+ vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
99
+
100
+ vertices = vertices[vmapping]
101
+ faces = indices
102
+
103
+ return vertices, faces, uvs
104
+
105
+ def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
106
+ texture = torch.zeros(
107
+ (texture_size * texture_size, 3), dtype=torch.float32
108
+ ).cuda()
109
+ texture_weights = torch.zeros(
110
+ (texture_size * texture_size), dtype=torch.float32
111
+ ).cuda()
112
+ rastctx = utils3d.torch.RastContext(backend="cuda")
113
+ for observation, w2c, projection in tqdm(
114
+ zip(observations, w2cs, projections),
115
+ total=len(observations),
116
+ desc="Texture baking (fast)",
117
+ ):
118
+ with torch.no_grad():
119
+ rast = utils3d.torch.rasterize_triangle_faces(
120
+ rastctx,
121
+ self.vertices[None],
122
+ self.faces,
123
+ observation.shape[1],
124
+ observation.shape[0],
125
+ uv=self.uvs[None],
126
+ view=w2c,
127
+ projection=projection,
128
+ )
129
+ uv_map = rast["uv"][0].detach().flip(0)
130
+ mask = rast["mask"][0].detach().bool() & masks[0]
131
+
132
+ # nearest neighbor interpolation
133
+ uv_map = (uv_map * texture_size).floor().long()
134
+ obs = observation[mask]
135
+ uv_map = uv_map[mask]
136
+ idx = (
137
+ uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
138
+ )
139
+ texture = texture.scatter_add(
140
+ 0, idx.view(-1, 1).expand(-1, 3), obs
141
+ )
142
+ texture_weights = texture_weights.scatter_add(
143
+ 0,
144
+ idx,
145
+ torch.ones(
146
+ (obs.shape[0]), dtype=torch.float32, device=texture.device
147
+ ),
148
+ )
149
+
150
+ mask = texture_weights > 0
151
+ texture[mask] /= texture_weights[mask][:, None]
152
+ texture = np.clip(
153
+ texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
154
+ 0,
155
+ 255,
156
+ ).astype(np.uint8)
157
+
158
+ # inpaint
159
+ mask = (
160
+ (texture_weights == 0)
161
+ .cpu()
162
+ .numpy()
163
+ .astype(np.uint8)
164
+ .reshape(texture_size, texture_size)
165
+ )
166
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
167
+
168
+ return texture
169
+
170
+ def _bake_opt(
171
+ self,
172
+ observations,
173
+ w2cs,
174
+ projections,
175
+ texture_size,
176
+ lambda_tv,
177
+ masks,
178
+ total_steps,
179
+ ):
180
+ rastctx = utils3d.torch.RastContext(backend="cuda")
181
+ observations = [observations.flip(0) for observations in observations]
182
+ masks = [m.flip(0) for m in masks]
183
+ _uv = []
184
+ _uv_dr = []
185
+ for observation, w2c, projection in tqdm(
186
+ zip(observations, w2cs, projections),
187
+ total=len(w2cs),
188
+ ):
189
+ with torch.no_grad():
190
+ rast = utils3d.torch.rasterize_triangle_faces(
191
+ rastctx,
192
+ self.vertices[None],
193
+ self.faces,
194
+ observation.shape[1],
195
+ observation.shape[0],
196
+ uv=self.uvs[None],
197
+ view=w2c,
198
+ projection=projection,
199
+ )
200
+ _uv.append(rast["uv"].detach())
201
+ _uv_dr.append(rast["uv_dr"].detach())
202
+
203
+ texture = torch.nn.Parameter(
204
+ torch.zeros(
205
+ (1, texture_size, texture_size, 3), dtype=torch.float32
206
+ ).cuda()
207
+ )
208
+ optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
209
+
210
+ def cosine_anealing(step, total_steps, start_lr, end_lr):
211
+ return end_lr + 0.5 * (start_lr - end_lr) * (
212
+ 1 + np.cos(np.pi * step / total_steps)
213
+ )
214
+
215
+ def tv_loss(texture):
216
+ return torch.nn.functional.l1_loss(
217
+ texture[:, :-1, :, :], texture[:, 1:, :, :]
218
+ ) + torch.nn.functional.l1_loss(
219
+ texture[:, :, :-1, :], texture[:, :, 1:, :]
220
+ )
221
+
222
+ with tqdm(total=total_steps, desc="Texture baking") as pbar:
223
+ for step in range(total_steps):
224
+ optimizer.zero_grad()
225
+ selected = np.random.randint(0, len(w2cs))
226
+ uv, uv_dr, observation, mask = (
227
+ _uv[selected],
228
+ _uv_dr[selected],
229
+ observations[selected],
230
+ masks[selected],
231
+ )
232
+ render = dr.texture(texture, uv, uv_dr)[0]
233
+ loss = torch.nn.functional.l1_loss(
234
+ render[mask], observation[mask]
235
+ )
236
+ if lambda_tv > 0:
237
+ loss += lambda_tv * tv_loss(texture)
238
+ loss.backward()
239
+ optimizer.step()
240
+
241
+ optimizer.param_groups[0]["lr"] = cosine_anealing(
242
+ step, total_steps, 1e-2, 1e-5
243
+ )
244
+ pbar.set_postfix({"loss": loss.item()})
245
+ pbar.update()
246
+ texture = np.clip(
247
+ texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
248
+ ).astype(np.uint8)
249
+ mask = 1 - utils3d.torch.rasterize_triangle_faces(
250
+ rastctx,
251
+ (self.uvs * 2 - 1)[None],
252
+ self.faces,
253
+ texture_size,
254
+ texture_size,
255
+ )["mask"][0].detach().cpu().numpy().astype(np.uint8)
256
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
257
+
258
+ return texture
259
+
260
+ def bake_texture(
261
+ self,
262
+ images: List[np.array],
263
+ texture_size: int = 1024,
264
+ mode: Literal["fast", "opt"] = "opt",
265
+ lambda_tv: float = 1e-2,
266
+ opt_step: int = 2000,
267
+ ):
268
+ masks = [np.any(img > 0, axis=-1) for img in images]
269
+ masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
270
+ images = [
271
+ torch.tensor(obs / 255.0).float().to(self.device) for obs in images
272
+ ]
273
+
274
+ if mode == "fast":
275
+ return self._bake_fast(
276
+ images, self.w2cs, self.projections, texture_size, masks
277
+ )
278
+ elif mode == "opt":
279
+ return self._bake_opt(
280
+ images,
281
+ self.w2cs,
282
+ self.projections,
283
+ texture_size,
284
+ lambda_tv,
285
+ masks,
286
+ opt_step,
287
+ )
288
+ else:
289
+ raise ValueError(f"Unknown mode: {mode}")
290
+
291
+
292
+ def kaolin_to_opencv_view(raw_matrix):
293
+ R_orig = raw_matrix[:, :3, :3]
294
+ t_orig = raw_matrix[:, :3, 3]
295
+
296
+ R_target = torch.zeros_like(R_orig)
297
+ R_target[:, :, 0] = R_orig[:, :, 2]
298
+ R_target[:, :, 1] = R_orig[:, :, 0]
299
+ R_target[:, :, 2] = R_orig[:, :, 1]
300
+
301
+ t_target = t_orig
302
+
303
+ target_matrix = (
304
+ torch.eye(4, device=raw_matrix.device)
305
+ .unsqueeze(0)
306
+ .repeat(raw_matrix.size(0), 1, 1)
307
+ )
308
+ target_matrix[:, :3, :3] = R_target
309
+ target_matrix[:, :3, 3] = t_target
310
+
311
+ return target_matrix
312
+
313
+
314
+ def parse_args():
315
+ parser = argparse.ArgumentParser(description="Render settings")
316
+
317
+ parser.add_argument(
318
+ "--mesh_path",
319
+ type=str,
320
+ nargs="+",
321
+ required=True,
322
+ help="Paths to the mesh files for rendering.",
323
+ )
324
+ parser.add_argument(
325
+ "--input_image",
326
+ type=str,
327
+ nargs="+",
328
+ required=True,
329
+ help="Paths to the input multi-view images for texture baking.",
330
+ )
331
+ parser.add_argument(
332
+ "--output_root",
333
+ type=str,
334
+ default="./outputs",
335
+ help="Root directory for output",
336
+ )
337
+ parser.add_argument(
338
+ "--uuid",
339
+ type=str,
340
+ nargs="+",
341
+ default=None,
342
+ help="uuid for rendering saving.",
343
+ )
344
+ parser.add_argument(
345
+ "--num_images", type=int, default=6, help="Number of images to render."
346
+ )
347
+ parser.add_argument(
348
+ "--elevation",
349
+ type=float,
350
+ nargs="+",
351
+ default=[20.0, -10.0],
352
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
353
+ )
354
+ parser.add_argument(
355
+ "--distance",
356
+ type=float,
357
+ default=5,
358
+ help="Camera distance (default: 5)",
359
+ )
360
+ parser.add_argument(
361
+ "--resolution_hw",
362
+ type=int,
363
+ nargs=2,
364
+ default=(512, 512),
365
+ help="Resolution of the output images (default: (512, 512))",
366
+ )
367
+ parser.add_argument(
368
+ "--fov",
369
+ type=float,
370
+ default=30,
371
+ help="Field of view in degrees (default: 30)",
372
+ )
373
+ parser.add_argument(
374
+ "--device",
375
+ type=str,
376
+ choices=["cpu", "cuda"],
377
+ default="cuda",
378
+ help="Device to run on (default: `cuda`)",
379
+ )
380
+ parser.add_argument(
381
+ "--texture_size",
382
+ type=int,
383
+ default=1024,
384
+ help="Texture size for texture baking (default: 1024)",
385
+ )
386
+ parser.add_argument(
387
+ "--baker_mode",
388
+ type=str,
389
+ default="opt",
390
+ help="Texture baking mode, `fast` or `opt` (default: opt)",
391
+ )
392
+ parser.add_argument(
393
+ "--opt_step",
394
+ type=int,
395
+ default=2500,
396
+ help="Optimization steps for texture baking (default: 2500)",
397
+ )
398
+ parser.add_argument(
399
+ "--mesh_sipmlify_ratio",
400
+ type=float,
401
+ default=0.9,
402
+ help="Mesh simplification ratio (default: 0.9)",
403
+ )
404
+ parser.add_argument(
405
+ "--no_coor_trans",
406
+ action="store_true",
407
+ help="Do not transform the asset coordinate system.",
408
+ )
409
+ parser.add_argument(
410
+ "--delight", action="store_true", help="Use delighting model."
411
+ )
412
+ parser.add_argument(
413
+ "--skip_fix_mesh", action="store_true", help="Skip mesh geometry fixing."
414
+ )
415
+
416
+ args = parser.parse_args()
417
+
418
+ if args.uuid is None:
419
+ args.uuid = []
420
+ for path in args.mesh_path:
421
+ uuid = os.path.basename(path).split(".")[0]
422
+ args.uuid.append(uuid)
423
+
424
+ return args
425
+
426
+
427
+ def entrypoint() -> None:
428
+ args = parse_args()
429
+ camera_params = CameraSetting(
430
+ num_images=args.num_images,
431
+ elevation=args.elevation,
432
+ distance=args.distance,
433
+ resolution_hw=args.resolution_hw,
434
+ fov=math.radians(args.fov),
435
+ device=args.device,
436
+ )
437
+
438
+ for mesh_path, uuid, img_path in zip(
439
+ args.mesh_path, args.uuid, args.input_image
440
+ ):
441
+ mesh = trimesh.load(mesh_path)
442
+ if isinstance(mesh, trimesh.Scene):
443
+ mesh = mesh.dump(concatenate=True)
444
+ vertices, scale, center = normalize_vertices_array(mesh.vertices)
445
+
446
+ if not args.no_coor_trans:
447
+ x_rot = torch.Tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
448
+ z_rot = torch.Tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
449
+ vertices = vertices @ x_rot
450
+ vertices = vertices @ z_rot
451
+
452
+ faces = mesh.faces.cpu().numpy().astype(np.int32)
453
+ vertices = vertices.cpu().numpy().astype(np.float32)
454
+
455
+ if not args.skip_fix_mesh:
456
+ mesh_fixer = MeshFixer(vertices, faces, args.device)
457
+ vertices, faces = mesh_fixer(
458
+ filter_ratio=args.mesh_sipmlify_ratio,
459
+ max_hole_size=0.04,
460
+ resolution=1024,
461
+ num_views=1000,
462
+ norm_mesh_ratio=0.5,
463
+ )
464
+
465
+ vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
466
+ texture_backer = TextureBaker(
467
+ vertices,
468
+ faces,
469
+ uvs,
470
+ camera_params,
471
+ )
472
+ images = get_images_from_grid(
473
+ img_path, img_size=camera_params.resolution_hw[0]
474
+ )
475
+ if args.delight:
476
+ delight_model = DelightingModel(
477
+ model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
478
+ )
479
+ delight_images = [delight_model(img) for img in images]
480
+ images = [np.array(img) for img in delight_images]
481
+
482
+ texture = texture_backer.bake_texture(
483
+ images=[img[..., :3] for img in images],
484
+ texture_size=args.texture_size,
485
+ mode=args.baker_mode,
486
+ opt_step=args.opt_step,
487
+ )
488
+ texture = post_process_texture(texture)
489
+
490
+ if not args.no_coor_trans:
491
+ vertices = vertices @ np.linalg.inv(z_rot)
492
+ vertices = vertices @ np.linalg.inv(x_rot)
493
+ vertices = vertices / scale
494
+ vertices = vertices + center
495
+
496
+ output_path = os.path.join(args.output_root, f"{uuid}.obj")
497
+ mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)
498
+
499
+ return
500
+
501
+
502
+ if __name__ == "__main__":
503
+ entrypoint()
asset3d_gen/data/backproject_v2.py ADDED
@@ -0,0 +1,613 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import nvdiffrast.torch as dr
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import trimesh
12
+ import xatlas
13
+ from PIL import Image
14
+ from asset3d_gen.data.mesh_operator import MeshFixer
15
+ from asset3d_gen.data.utils import (
16
+ CameraSetting,
17
+ DiffrastRender,
18
+ get_images_from_grid,
19
+ init_kal_camera,
20
+ normalize_vertices_array,
21
+ post_process_texture,
22
+ save_mesh_with_mtl,
23
+ )
24
+ from asset3d_gen.models.delight import DelightingModel
25
+ from asset3d_gen.models.super_resolution import ImageRealESRGAN
26
+
27
+ logging.basicConfig(
28
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ __all__ = [
34
+ "TextureBacker",
35
+ ]
36
+
37
+
38
+ def transform_vertices(
39
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
40
+ ) -> torch.Tensor:
41
+ """Transform 3D vertices using a projection matrix."""
42
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
43
+ if pos.size(-1) == 3:
44
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
45
+
46
+ result = pos @ t_mtx.T
47
+
48
+ return result if keepdim else result.unsqueeze(0)
49
+
50
+
51
+ def _bilinear_interpolation_scattering(
52
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
53
+ ) -> torch.Tensor:
54
+ """Bilinear interpolation scattering for grid-based value accumulation."""
55
+ device = values.device
56
+ dtype = values.dtype
57
+ C = values.shape[-1]
58
+
59
+ indices = coords * torch.tensor(
60
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
61
+ )
62
+ i, j = indices.unbind(-1)
63
+
64
+ i0, j0 = (
65
+ indices.floor()
66
+ .long()
67
+ .clamp(0, image_h - 2)
68
+ .clamp(0, image_w - 2)
69
+ .unbind(-1)
70
+ )
71
+ i1, j1 = i0 + 1, j0 + 1
72
+
73
+ w_i = i - i0.float()
74
+ w_j = j - j0.float()
75
+ weights = torch.stack(
76
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
77
+ dim=1,
78
+ )
79
+
80
+ indices_comb = torch.stack(
81
+ [
82
+ torch.stack([i0, j0], dim=1),
83
+ torch.stack([i0, j1], dim=1),
84
+ torch.stack([i1, j0], dim=1),
85
+ torch.stack([i1, j1], dim=1),
86
+ ],
87
+ dim=1,
88
+ )
89
+
90
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
91
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
92
+
93
+ for k in range(4):
94
+ idx = indices_comb[:, k]
95
+ w = weights[:, k].unsqueeze(-1)
96
+
97
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
98
+ flat_idx = (idx * stride).sum(-1)
99
+
100
+ grid.view(-1, C).scatter_add_(
101
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
102
+ )
103
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
104
+
105
+ mask = cnt.squeeze(-1) > 0
106
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
107
+
108
+ return grid
109
+
110
+
111
+ def _texture_inpaint_smooth(
112
+ texture: np.ndarray,
113
+ mask: np.ndarray,
114
+ vertices: np.ndarray,
115
+ faces: np.ndarray,
116
+ uv_map: np.ndarray,
117
+ ) -> tuple[np.ndarray, np.ndarray]:
118
+ """Perform texture inpainting using vertex-based color propagation."""
119
+ image_h, image_w, C = texture.shape
120
+ N = vertices.shape[0]
121
+
122
+ # Initialize vertex data structures
123
+ vtx_mask = np.zeros(N, dtype=np.float32)
124
+ vtx_colors = np.zeros((N, C), dtype=np.float32)
125
+ unprocessed = []
126
+ adjacency = [[] for _ in range(N)]
127
+
128
+ # Build adjacency graph and initial color assignment
129
+ for face_idx in range(faces.shape[0]):
130
+ for k in range(3):
131
+ uv_idx_k = faces[face_idx, k]
132
+ v_idx = faces[face_idx, k]
133
+
134
+ # Convert UV to pixel coordinates with boundary clamping
135
+ u = np.clip(
136
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
137
+ )
138
+ v = np.clip(
139
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
140
+ 0,
141
+ image_h - 1,
142
+ )
143
+
144
+ if mask[v, u]:
145
+ vtx_mask[v_idx] = 1.0
146
+ vtx_colors[v_idx] = texture[v, u]
147
+ elif v_idx not in unprocessed:
148
+ unprocessed.append(v_idx)
149
+
150
+ # Build undirected adjacency graph
151
+ neighbor = faces[face_idx, (k + 1) % 3]
152
+ if neighbor not in adjacency[v_idx]:
153
+ adjacency[v_idx].append(neighbor)
154
+ if v_idx not in adjacency[neighbor]:
155
+ adjacency[neighbor].append(v_idx)
156
+
157
+ # Color propagation with dynamic stopping
158
+ remaining_iters, prev_count = 2, 0
159
+ while remaining_iters > 0:
160
+ current_unprocessed = []
161
+
162
+ for v_idx in unprocessed:
163
+ valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
164
+ if not valid_neighbors:
165
+ current_unprocessed.append(v_idx)
166
+ continue
167
+
168
+ # Calculate inverse square distance weights
169
+ neighbors_pos = vertices[valid_neighbors]
170
+ dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
171
+ weights = 1 / np.maximum(dist_sq, 1e-8)
172
+
173
+ vtx_colors[v_idx] = np.average(
174
+ vtx_colors[valid_neighbors], weights=weights, axis=0
175
+ )
176
+ vtx_mask[v_idx] = 1.0
177
+
178
+ # Update iteration control
179
+ if len(current_unprocessed) == prev_count:
180
+ remaining_iters -= 1
181
+ else:
182
+ remaining_iters = min(remaining_iters + 1, 2)
183
+ prev_count = len(current_unprocessed)
184
+ unprocessed = current_unprocessed
185
+
186
+ # Generate output texture
187
+ inpainted_texture, updated_mask = texture.copy(), mask.copy()
188
+ for face_idx in range(faces.shape[0]):
189
+ for k in range(3):
190
+ v_idx = faces[face_idx, k]
191
+ if not vtx_mask[v_idx]:
192
+ continue
193
+
194
+ # UV coordinate conversion
195
+ uv_idx_k = faces[face_idx, k]
196
+ u = np.clip(
197
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
198
+ )
199
+ v = np.clip(
200
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
201
+ 0,
202
+ image_h - 1,
203
+ )
204
+
205
+ inpainted_texture[v, u] = vtx_colors[v_idx]
206
+ updated_mask[v, u] = 255
207
+
208
+ return inpainted_texture, updated_mask
209
+
210
+
211
+ class TextureBacker:
212
+ """Texture baking pipeline for multi-view projection and fusion."""
213
+
214
+ def __init__(
215
+ self,
216
+ camera_params: CameraSetting,
217
+ view_weights: list[float],
218
+ render_wh: tuple[int, int] = (2048, 2048),
219
+ texture_wh: tuple[int, int] = (2048, 2048),
220
+ bake_angle_thresh: int = 75,
221
+ mask_thresh: float = 0.5,
222
+ ):
223
+ camera = init_kal_camera(camera_params)
224
+ mv = camera.view_matrix() # (n 4 4) world2cam
225
+ p = camera.intrinsics.projection_matrix()
226
+ # NOTE: negate P[1, 1] as the y axis is flipped in `nvdiffrast` output. # noqa
227
+ p[:, 1, 1] = -p[:, 1, 1]
228
+ renderer = DiffrastRender(
229
+ p_matrix=p,
230
+ mv_matrix=mv,
231
+ resolution_hw=camera_params.resolution_hw,
232
+ context=dr.RasterizeCudaContext(),
233
+ mask_thresh=mask_thresh,
234
+ grad_db=False,
235
+ device=camera_params.device,
236
+ antialias_mask=True,
237
+ )
238
+ self.camera = camera
239
+ self.renderer = renderer
240
+ self.view_weights = view_weights
241
+ self.device = camera_params.device
242
+ self.render_wh = render_wh
243
+ self.texture_wh = texture_wh
244
+
245
+ self.bake_angle_thresh = bake_angle_thresh
246
+ self.bake_unreliable_kernel_size = int(
247
+ (2 / 512) * max(self.render_wh[0], self.render_wh[1])
248
+ )
249
+
250
+ def load_mesh(self, mesh: trimesh.Trimesh) -> None:
251
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
252
+ self.scale, self.center = scale, center
253
+
254
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
255
+ uvs[:, 1] = 1 - uvs[:, 1]
256
+ mesh.vertices = mesh.vertices[vmapping]
257
+ mesh.faces = indices
258
+ mesh.visual.uv = uvs
259
+
260
+ self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
261
+ self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
262
+ self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
263
+
264
+ def get_mesh_np_attrs(
265
+ self,
266
+ scale: float = None,
267
+ center: np.ndarray = None,
268
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
269
+ vertices = self.vertices.cpu().numpy()
270
+ faces = self.faces.cpu().numpy()
271
+ uv_map = self.uv_map.cpu().numpy()
272
+ uv_map[:, 1] = 1.0 - uv_map[:, 1]
273
+
274
+ if scale is not None:
275
+ vertices = vertices / scale
276
+ if center is not None:
277
+ vertices = vertices + center
278
+
279
+ return vertices, faces, uv_map
280
+
281
+ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
282
+ depth_image_np = depth_image.cpu().numpy()
283
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
284
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
285
+ sketch_image = (
286
+ torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
287
+ )
288
+ sketch_image = sketch_image.unsqueeze(-1)
289
+
290
+ return sketch_image
291
+
292
+ def compute_enhanced_viewnormal(
293
+ self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
294
+ ) -> torch.Tensor:
295
+ rast, _ = self.renderer.compute_dr_raster(vertices, faces)
296
+ rendered_view_normals = []
297
+ for idx in range(len(mv_mtx)):
298
+ pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
299
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
300
+ v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
301
+ face_norm = F.normalize(
302
+ torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
303
+ )
304
+ vertex_norm = (
305
+ torch.from_numpy(
306
+ trimesh.geometry.mean_vertex_normals(
307
+ len(pos_cam), faces.cpu(), face_norm.cpu()
308
+ )
309
+ )
310
+ .to(vertices.device)
311
+ .contiguous()
312
+ )
313
+ im_base_normals, _ = dr.interpolate(
314
+ vertex_norm[None, ...].float(),
315
+ rast[idx : idx + 1],
316
+ faces.to(torch.int32),
317
+ )
318
+ rendered_view_normals.append(im_base_normals)
319
+
320
+ rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
321
+
322
+ return rendered_view_normals
323
+
324
+ def back_project(
325
+ self, image, vis_mask, depth, normal, uv
326
+ ) -> tuple[torch.Tensor, torch.Tensor]:
327
+ image = np.array(image)
328
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
329
+ if image.ndim == 2:
330
+ image = image.unsqueeze(-1)
331
+ image = image / 255
332
+
333
+ depth_inv = (1.0 - depth) * vis_mask
334
+ sketch_image = self._render_depth_edges(depth_inv)
335
+
336
+ cos = F.cosine_similarity(
337
+ torch.tensor([[0, 0, 1]], device=self.device),
338
+ normal.view(-1, 3),
339
+ ).view_as(normal[..., :1])
340
+ cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
341
+
342
+ k = self.bake_unreliable_kernel_size * 2 + 1
343
+ kernel = torch.ones((1, 1, k, k), device=self.device)
344
+
345
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
346
+ vis_mask = F.conv2d(
347
+ 1.0 - vis_mask,
348
+ kernel,
349
+ padding=k // 2,
350
+ )
351
+ vis_mask = 1.0 - (vis_mask > 0).float()
352
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
353
+
354
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
355
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
356
+ sketch_image = (sketch_image > 0).float()
357
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
358
+ vis_mask = vis_mask * (sketch_image < 0.5)
359
+
360
+ cos[vis_mask == 0] = 0
361
+ valid_pixels = (vis_mask != 0).view(-1)
362
+
363
+ return (
364
+ self._scatter_texture(uv, image, valid_pixels),
365
+ self._scatter_texture(uv, cos, valid_pixels),
366
+ )
367
+
368
+ def _scatter_texture(self, uv, data, mask):
369
+ def __filter_data(data, mask):
370
+ return data.view(-1, data.shape[-1])[mask]
371
+
372
+ return _bilinear_interpolation_scattering(
373
+ self.texture_wh[1],
374
+ self.texture_wh[0],
375
+ __filter_data(uv, mask)[..., [1, 0]],
376
+ __filter_data(data, mask),
377
+ )
378
+
379
+ @torch.no_grad()
380
+ def fast_bake_texture(
381
+ self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
382
+ ) -> tuple[torch.Tensor, torch.Tensor]:
383
+ channel = textures[0].shape[-1]
384
+ texture_merge = torch.zeros(self.texture_wh + [channel]).to(
385
+ self.device
386
+ )
387
+ trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
388
+ for texture, cos_map in zip(textures, confidence_maps):
389
+ view_sum = (cos_map > 0).sum()
390
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
391
+ if painted_sum / view_sum > 0.99:
392
+ continue
393
+ texture_merge += texture * cos_map
394
+ trust_map_merge += cos_map
395
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
396
+
397
+ return texture_merge, trust_map_merge > 1e-8
398
+
399
+ def uv_inpaint(
400
+ self, texture: torch.Tensor, mask: torch.Tensor
401
+ ) -> np.ndarray:
402
+ texture_np = texture.cpu().numpy()
403
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
404
+ vertices, faces, uv_map = self.get_mesh_np_attrs()
405
+
406
+ texture_np, mask_np = _texture_inpaint_smooth(
407
+ texture_np, mask_np, vertices, faces, uv_map
408
+ )
409
+ texture_np = texture_np.clip(0, 1)
410
+ texture_np = cv2.inpaint(
411
+ (texture_np * 255).astype(np.uint8),
412
+ 255 - mask_np,
413
+ 3,
414
+ cv2.INPAINT_NS,
415
+ )
416
+
417
+ return texture_np
418
+
419
+ def __call__(
420
+ self,
421
+ colors: list[Image.Image],
422
+ mesh: trimesh.Trimesh,
423
+ output_path: str,
424
+ ) -> trimesh.Trimesh:
425
+ self.load_mesh(mesh)
426
+ rendered_depth, masks = self.renderer.render_depth(
427
+ self.vertices, self.faces
428
+ )
429
+ norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
430
+ render_uvs, _ = self.renderer.render_uv(
431
+ self.vertices, self.faces, self.uv_map
432
+ )
433
+ view_normals = self.compute_enhanced_viewnormal(
434
+ self.renderer.mv_mtx, self.vertices, self.faces
435
+ )
436
+
437
+ textures, weighted_cos_maps = [], []
438
+ for color, mask, dep, normal, uv, weight in zip(
439
+ colors,
440
+ masks,
441
+ norm_deps,
442
+ view_normals,
443
+ render_uvs,
444
+ self.view_weights,
445
+ ):
446
+ texture, cos_map = self.back_project(color, mask, dep, normal, uv)
447
+ textures.append(texture)
448
+ weighted_cos_maps.append(weight * (cos_map**4))
449
+
450
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
451
+ texture_np = self.uv_inpaint(texture, mask)
452
+ texture_np = post_process_texture(texture_np)
453
+ vertices, faces, uv_map = self.get_mesh_np_attrs(
454
+ self.scale, self.center
455
+ )
456
+
457
+ textured_mesh = save_mesh_with_mtl(
458
+ vertices, faces, uv_map, texture_np, output_path
459
+ )
460
+
461
+ return textured_mesh
462
+
463
+
464
+ def parse_args():
465
+ parser = argparse.ArgumentParser(description="Backproject texture")
466
+ parser.add_argument(
467
+ "--color_path",
468
+ type=str,
469
+ help="Multiview color image in 6x512x512 file path",
470
+ )
471
+ parser.add_argument(
472
+ "--mesh_path",
473
+ type=str,
474
+ help="Mesh path, .obj, .glb or .ply",
475
+ )
476
+ parser.add_argument(
477
+ "--output_path",
478
+ type=str,
479
+ help="Output mesh path with suffix",
480
+ )
481
+ parser.add_argument(
482
+ "--num_images", type=int, default=6, help="Number of images to render."
483
+ )
484
+ parser.add_argument(
485
+ "--elevation",
486
+ nargs=2,
487
+ type=float,
488
+ default=[20.0, -10.0],
489
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
490
+ )
491
+ parser.add_argument(
492
+ "--distance",
493
+ type=float,
494
+ default=5,
495
+ help="Camera distance (default: 5)",
496
+ )
497
+ parser.add_argument(
498
+ "--resolution_hw",
499
+ type=int,
500
+ nargs=2,
501
+ default=(2048, 2048),
502
+ help="Resolution of the output images (default: (2048, 2048))",
503
+ )
504
+ parser.add_argument(
505
+ "--fov",
506
+ type=float,
507
+ default=30,
508
+ help="Field of view in degrees (default: 30)",
509
+ )
510
+ parser.add_argument(
511
+ "--device",
512
+ type=str,
513
+ choices=["cpu", "cuda"],
514
+ default="cuda",
515
+ help="Device to run on (default: `cuda`)",
516
+ )
517
+ parser.add_argument(
518
+ "--skip_fix_mesh", action="store_true", help="Skip mesh geometry fixing."
519
+ )
520
+ parser.add_argument(
521
+ "--texture_wh",
522
+ nargs=2,
523
+ type=int,
524
+ default=[2048, 2048],
525
+ help="Texture resolution width and height",
526
+ )
527
+ parser.add_argument(
528
+ "--mesh_sipmlify_ratio",
529
+ type=float,
530
+ default=0.9,
531
+ help="Mesh simplification ratio (default: 0.9)",
532
+ )
533
+ parser.add_argument(
534
+ "--delight", action="store_true", help="Use delighting model."
535
+ )
536
+ args = parser.parse_args()
537
+
538
+ return args
539
+
540
+
541
+ def entrypoint(
542
+ delight_model: DelightingModel = None,
543
+ imagesr_model: ImageRealESRGAN = None,
544
+ **kwargs,
545
+ ) -> trimesh.Trimesh:
546
+ args = parse_args()
547
+ for k, v in kwargs.items():
548
+ if hasattr(args, k) and v is not None:
549
+ setattr(args, k, v)
550
+
551
+ # Setup camera parameters.
552
+ camera_params = CameraSetting(
553
+ num_images=args.num_images,
554
+ elevation=args.elevation,
555
+ distance=args.distance,
556
+ resolution_hw=args.resolution_hw,
557
+ fov=math.radians(args.fov),
558
+ device=args.device,
559
+ )
560
+ view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
561
+
562
+ color_grid = Image.open(args.color_path)
563
+ if args.delight:
564
+ if delight_model is None:
565
+ delight_model = DelightingModel(
566
+ model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
567
+ )
568
+ save_dir = os.path.dirname(args.output_path)
569
+ os.makedirs(save_dir, exist_ok=True)
570
+ color_grid.save(f"{save_dir}/color_grid.png")
571
+ color_grid = delight_model(color_grid)
572
+ color_grid.save(f"{save_dir}/color_grid_delight.png")
573
+
574
+ multiviews = get_images_from_grid(color_grid, img_size=512)
575
+
576
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
577
+ if imagesr_model is None:
578
+ imagesr_model = ImageRealESRGAN(outscale=4)
579
+ multiviews = [imagesr_model(img) for img in multiviews]
580
+ multiviews = [img.convert("RGB") for img in multiviews]
581
+ mesh = trimesh.load(args.mesh_path)
582
+ if isinstance(mesh, trimesh.Scene):
583
+ mesh = mesh.dump(concatenate=True)
584
+
585
+ if not args.skip_fix_mesh:
586
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
587
+ mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
588
+ mesh.vertices, mesh.faces = mesh_fixer(
589
+ filter_ratio=args.mesh_sipmlify_ratio,
590
+ max_hole_size=0.04,
591
+ resolution=1024,
592
+ num_views=1000,
593
+ norm_mesh_ratio=0.5,
594
+ )
595
+ # Restore scale.
596
+ mesh.vertices = mesh.vertices / scale
597
+ mesh.vertices = mesh.vertices + center
598
+
599
+ # Baking texture to mesh.
600
+ texture_backer = TextureBacker(
601
+ camera_params=camera_params,
602
+ view_weights=view_weights,
603
+ render_wh=camera_params.resolution_hw,
604
+ texture_wh=args.texture_wh,
605
+ )
606
+
607
+ textured_mesh = texture_backer(multiviews, mesh, args.output_path)
608
+
609
+ return textured_mesh
610
+
611
+
612
+ if __name__ == "__main__":
613
+ entrypoint()
asset3d_gen/data/backup/backproject_v2 copy.py ADDED
@@ -0,0 +1,650 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import nvdiffrast.torch as dr
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torchvision.transforms import functional as tF
12
+
13
+ import trimesh
14
+ import xatlas
15
+ from PIL import Image
16
+ from asset3d_gen.data.mesh_operator import MeshFixer
17
+ from asset3d_gen.data.utils import (
18
+ CameraSetting,
19
+ DiffrastRender,
20
+ get_images_from_grid,
21
+ init_kal_camera,
22
+ normalize_vertices_array,
23
+ post_process_texture,
24
+ save_mesh_with_mtl,
25
+ )
26
+ from asset3d_gen.models.delight import DelightingModel
27
+ from asset3d_gen.models.super_resolution import ImageRealESRGAN
28
+
29
+ logging.basicConfig(
30
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
31
+ )
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ __all__ = [
36
+ "TextureBacker",
37
+ ]
38
+
39
+
40
+ def transform_vertices(
41
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
42
+ ) -> torch.Tensor:
43
+ """Transform 3D vertices using a projection matrix."""
44
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
45
+ if pos.size(-1) == 3:
46
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
47
+
48
+ result = pos @ t_mtx.T
49
+
50
+ return result if keepdim else result.unsqueeze(0)
51
+
52
+
53
+ def _bilinear_interpolation_scattering(
54
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
55
+ ) -> torch.Tensor:
56
+ """Bilinear interpolation scattering for grid-based value accumulation."""
57
+ device = values.device
58
+ dtype = values.dtype
59
+ C = values.shape[-1]
60
+
61
+ indices = coords * torch.tensor(
62
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
63
+ )
64
+ i, j = indices.unbind(-1)
65
+
66
+ i0, j0 = (
67
+ indices.floor()
68
+ .long()
69
+ .clamp(0, image_h - 2)
70
+ .clamp(0, image_w - 2)
71
+ .unbind(-1)
72
+ )
73
+ i1, j1 = i0 + 1, j0 + 1
74
+
75
+ w_i = i - i0.float()
76
+ w_j = j - j0.float()
77
+ weights = torch.stack(
78
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
79
+ dim=1,
80
+ )
81
+
82
+ indices_comb = torch.stack(
83
+ [
84
+ torch.stack([i0, j0], dim=1),
85
+ torch.stack([i0, j1], dim=1),
86
+ torch.stack([i1, j0], dim=1),
87
+ torch.stack([i1, j1], dim=1),
88
+ ],
89
+ dim=1,
90
+ )
91
+
92
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
93
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
94
+
95
+ for k in range(4):
96
+ idx = indices_comb[:, k]
97
+ w = weights[:, k].unsqueeze(-1)
98
+
99
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
100
+ flat_idx = (idx * stride).sum(-1)
101
+
102
+ grid.view(-1, C).scatter_add_(
103
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
104
+ )
105
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
106
+
107
+ mask = cnt.squeeze(-1) > 0
108
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
109
+
110
+ return grid
111
+
112
+
113
+ def _texture_inpaint_smooth(
114
+ texture: np.ndarray,
115
+ mask: np.ndarray,
116
+ vertices: np.ndarray,
117
+ faces: np.ndarray,
118
+ uv_map: np.ndarray,
119
+ ) -> tuple[np.ndarray, np.ndarray]:
120
+ """Perform texture inpainting using vertex-based color propagation."""
121
+ image_h, image_w, C = texture.shape
122
+ N = vertices.shape[0]
123
+
124
+ # Initialize vertex data structures
125
+ vtx_mask = np.zeros(N, dtype=np.float32)
126
+ vtx_colors = np.zeros((N, C), dtype=np.float32)
127
+ unprocessed = []
128
+ adjacency = [[] for _ in range(N)]
129
+
130
+ # Build adjacency graph and initial color assignment
131
+ for face_idx in range(faces.shape[0]):
132
+ for k in range(3):
133
+ uv_idx_k = faces[face_idx, k]
134
+ v_idx = faces[face_idx, k]
135
+
136
+ # Convert UV to pixel coordinates with boundary clamping
137
+ u = np.clip(
138
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
139
+ )
140
+ v = np.clip(
141
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
142
+ 0,
143
+ image_h - 1,
144
+ )
145
+
146
+ if mask[v, u]:
147
+ vtx_mask[v_idx] = 1.0
148
+ vtx_colors[v_idx] = texture[v, u]
149
+ elif v_idx not in unprocessed:
150
+ unprocessed.append(v_idx)
151
+
152
+ # Build undirected adjacency graph
153
+ neighbor = faces[face_idx, (k + 1) % 3]
154
+ if neighbor not in adjacency[v_idx]:
155
+ adjacency[v_idx].append(neighbor)
156
+ if v_idx not in adjacency[neighbor]:
157
+ adjacency[neighbor].append(v_idx)
158
+
159
+ # Color propagation with dynamic stopping
160
+ remaining_iters, prev_count = 2, 0
161
+ while remaining_iters > 0:
162
+ current_unprocessed = []
163
+
164
+ for v_idx in unprocessed:
165
+ valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
166
+ if not valid_neighbors:
167
+ current_unprocessed.append(v_idx)
168
+ continue
169
+
170
+ # Calculate inverse square distance weights
171
+ neighbors_pos = vertices[valid_neighbors]
172
+ dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
173
+ weights = 1 / np.maximum(dist_sq, 1e-8)
174
+
175
+ vtx_colors[v_idx] = np.average(
176
+ vtx_colors[valid_neighbors], weights=weights, axis=0
177
+ )
178
+ vtx_mask[v_idx] = 1.0
179
+
180
+ # Update iteration control
181
+ if len(current_unprocessed) == prev_count:
182
+ remaining_iters -= 1
183
+ else:
184
+ remaining_iters = min(remaining_iters + 1, 2)
185
+ prev_count = len(current_unprocessed)
186
+ unprocessed = current_unprocessed
187
+
188
+ # Generate output texture
189
+ inpainted_texture, updated_mask = texture.copy(), mask.copy()
190
+ for face_idx in range(faces.shape[0]):
191
+ for k in range(3):
192
+ v_idx = faces[face_idx, k]
193
+ if not vtx_mask[v_idx]:
194
+ continue
195
+
196
+ # UV coordinate conversion
197
+ uv_idx_k = faces[face_idx, k]
198
+ u = np.clip(
199
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
200
+ )
201
+ v = np.clip(
202
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
203
+ 0,
204
+ image_h - 1,
205
+ )
206
+
207
+ inpainted_texture[v, u] = vtx_colors[v_idx]
208
+ updated_mask[v, u] = 255
209
+
210
+ return inpainted_texture, updated_mask
211
+
212
+
213
+ def interp_tensers(tensors: list[torch.Tensor], target_wh: tuple[int, int]) -> list[torch.Tensor]:
214
+ for idx in range(len(tensors)):
215
+ tensor = tensors[idx].permute(2, 0, 1)
216
+ tensor = tF.resize(tensor, target_wh[::-1], antialias=True)
217
+ tensors[idx] = tensor.permute(1, 2, 0)
218
+
219
+ return tensors
220
+
221
+
222
+ class TextureBacker:
223
+ """Texture baking pipeline for multi-view projection and fusion."""
224
+
225
+ def __init__(
226
+ self,
227
+ camera_params: CameraSetting,
228
+ view_weights: list[float],
229
+ render_wh: tuple[int, int] = (2048, 2048),
230
+ texture_wh: tuple[int, int] = (2048, 2048),
231
+ bake_angle_thresh: int = 75,
232
+ mask_thresh: float = 0.5,
233
+ ):
234
+ camera = init_kal_camera(camera_params)
235
+ mv = camera.view_matrix() # (n 4 4) world2cam
236
+ p = camera.intrinsics.projection_matrix()
237
+ # NOTE: negate P[1, 1] because the y axis is flipped in `nvdiffrast` output. # noqa
238
+ p[:, 1, 1] = -p[:, 1, 1]
239
+ self.renderer = DiffrastRender(
240
+ p_matrix=p,
241
+ mv_matrix=mv,
242
+ resolution_hw=camera_params.resolution_hw,
243
+ context=dr.RasterizeCudaContext(),
244
+ mask_thresh=mask_thresh,
245
+ grad_db=False,
246
+ device=camera_params.device,
247
+ antialias_mask=True,
248
+ )
249
+ self.camera = camera
250
+ self.view_weights = view_weights
251
+ self.device = camera_params.device
252
+ self.render_wh = render_wh
253
+ self.texture_wh = texture_wh
254
+
255
+ self.bake_angle_thresh = bake_angle_thresh
256
+ self.bake_unreliable_kernel_size = int(
257
+ (2 / 512) * max(self.render_wh[0], self.render_wh[1])
258
+ )
259
+
260
+ def load_mesh(self, mesh: trimesh.Trimesh) -> None:
261
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
262
+ self.scale, self.center = scale, center
263
+
264
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
265
+ uvs[:, 1] = 1 - uvs[:, 1]
266
+ mesh.vertices = mesh.vertices[vmapping]
267
+ mesh.faces = indices
268
+ mesh.visual.uv = uvs
269
+
270
+ self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
271
+ self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
272
+ self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
273
+
274
+ def get_mesh_np_attrs(
275
+ self,
276
+ scale: float = None,
277
+ center: np.ndarray = None,
278
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
279
+ vertices = self.vertices.cpu().numpy()
280
+ faces = self.faces.cpu().numpy()
281
+ uv_map = self.uv_map.cpu().numpy()
282
+ uv_map[:, 1] = 1.0 - uv_map[:, 1]
283
+
284
+ if scale is not None:
285
+ vertices = vertices / scale
286
+ if center is not None:
287
+ vertices = vertices + center
288
+
289
+ return vertices, faces, uv_map
290
+
291
+ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
292
+ depth_image_np = depth_image.cpu().numpy()
293
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
294
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
295
+ sketch_image = (
296
+ torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
297
+ )
298
+ sketch_image = sketch_image.unsqueeze(-1)
299
+
300
+ return sketch_image
301
+
302
+ def compute_enhanced_viewnormal(
303
+ self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
304
+ ) -> torch.Tensor:
305
+ rast, _ = self.renderer.compute_dr_raster(vertices, faces)
306
+ rendered_view_normals = []
307
+ for idx in range(len(mv_mtx)):
308
+ pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
309
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
310
+ v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
311
+ face_norm = F.normalize(
312
+ torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
313
+ )
314
+ vertex_norm = (
315
+ torch.from_numpy(
316
+ trimesh.geometry.mean_vertex_normals(
317
+ len(pos_cam), faces.cpu(), face_norm.cpu()
318
+ )
319
+ )
320
+ .to(vertices.device)
321
+ .contiguous()
322
+ )
323
+ im_base_normals, _ = dr.interpolate(
324
+ vertex_norm[None, ...].float(),
325
+ rast[idx : idx + 1],
326
+ faces.to(torch.int32),
327
+ )
328
+ rendered_view_normals.append(im_base_normals)
329
+
330
+ rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
331
+
332
+ return rendered_view_normals
333
+
334
+ def back_project(
335
+ self, image, vis_mask, depth, normal, uv
336
+ ) -> tuple[torch.Tensor, torch.Tensor]:
337
+ image = np.array(image)
338
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
339
+ if image.ndim == 2:
340
+ image = image.unsqueeze(-1)
341
+ image = image / 255
342
+
343
+ depth_inv = (1.0 - depth) * vis_mask
344
+ sketch_image = self._render_depth_edges(depth_inv)
345
+
346
+ cos = F.cosine_similarity(
347
+ torch.tensor([[0, 0, 1]], device=self.device),
348
+ normal.view(-1, 3),
349
+ ).view_as(normal[..., :1])
350
+ cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
351
+
352
+ k = self.bake_unreliable_kernel_size * 2 + 1
353
+ kernel = torch.ones((1, 1, k, k), device=self.device)
354
+
355
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
356
+ vis_mask = F.conv2d(
357
+ 1.0 - vis_mask,
358
+ kernel,
359
+ padding=k // 2,
360
+ )
361
+ vis_mask = 1.0 - (vis_mask > 0).float()
362
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
363
+
364
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
365
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
366
+ sketch_image = (sketch_image > 0).float()
367
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
368
+ vis_mask = vis_mask * (sketch_image < 0.5)
369
+
370
+ cos[vis_mask == 0] = 0
371
+ valid_pixels = (vis_mask != 0).view(-1)
372
+
373
+ return (
374
+ self._scatter_texture(uv, image, valid_pixels),
375
+ self._scatter_texture(uv, cos, valid_pixels),
376
+ )
377
+
378
+ def _scatter_texture(self, uv, data, mask):
379
+ def __filter_data(data, mask):
380
+ return data.view(-1, data.shape[-1])[mask]
381
+
382
+ return _bilinear_interpolation_scattering(
383
+ self.texture_wh[1],
384
+ self.texture_wh[0],
385
+ __filter_data(uv, mask)[..., [1, 0]],
386
+ __filter_data(data, mask),
387
+ )
388
+
389
+ @torch.no_grad()
390
+ def fast_bake_texture(
391
+ self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
392
+ ) -> tuple[torch.Tensor, torch.Tensor]:
393
+ channel = textures[0].shape[-1]
394
+ texture_merge = torch.zeros(self.texture_wh + [channel]).to(
395
+ self.device
396
+ )
397
+ trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
398
+ for texture, cos_map in zip(textures, confidence_maps):
399
+ view_sum = (cos_map > 0).sum()
400
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
401
+ if painted_sum / view_sum > 0.99:
402
+ continue
403
+ texture_merge += texture * cos_map
404
+ trust_map_merge += cos_map
405
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
406
+
407
+ return texture_merge, trust_map_merge > 1e-8
408
+
409
+ def uv_inpaint(
410
+ self, texture: torch.Tensor, mask: torch.Tensor
411
+ ) -> np.ndarray:
412
+ texture_np = texture.cpu().numpy()
413
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
414
+ vertices, faces, uv_map = self.get_mesh_np_attrs()
415
+
416
+ texture_np, mask_np = _texture_inpaint_smooth(
417
+ texture_np, mask_np, vertices, faces, uv_map
418
+ )
419
+ texture_np = texture_np.clip(0, 1)
420
+ texture_np = cv2.inpaint(
421
+ (texture_np * 255).astype(np.uint8),
422
+ 255 - mask_np,
423
+ 3,
424
+ cv2.INPAINT_NS,
425
+ )
426
+
427
+ return texture_np
428
+
429
+ def __call__(
430
+ self,
431
+ colors: list[Image.Image],
432
+ mesh: trimesh.Trimesh,
433
+ output_path: str,
434
+ ) -> trimesh.Trimesh:
435
+ import time
436
+ start = time.time()
437
+ self.load_mesh(mesh)
438
+ print("load_mesh", time.time() - start)
439
+
440
+ start = time.time()
441
+ rendered_depth, masks = self.renderer.render_depth(
442
+ self.vertices, self.faces
443
+ )
444
+ norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
445
+ render_uvs, _ = self.renderer.render_uv(
446
+ self.vertices, self.faces, self.uv_map
447
+ )
448
+ view_normals = self.compute_enhanced_viewnormal(
449
+ self.renderer.mv_mtx, self.vertices, self.faces
450
+ )
451
+ print("0", time.time() - start)
452
+
453
+ textures, weighted_cos_maps = [], []
454
+
455
+ start = time.time()
456
+ for color, mask, dep, normal, uv, weight in zip(
457
+ colors,
458
+ masks,
459
+ norm_deps,
460
+ view_normals,
461
+ render_uvs,
462
+ self.view_weights,
463
+ ):
464
+ mask, dep, normal, uv = interp_tensers([mask, dep, normal, uv], self.render_wh)
465
+ texture, cos_map = self.back_project(color, mask, dep, normal, uv)
466
+ textures.append(texture)
467
+ weighted_cos_maps.append(weight * (cos_map**4))
468
+ print("1", time.time() - start)
469
+ start = time.time()
470
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
471
+ print("2", time.time() - start)
472
+ start = time.time()
473
+ texture_np = self.uv_inpaint(texture, mask)
474
+ print("3", time.time() - start)
475
+ start = time.time()
476
+ texture_np = post_process_texture(texture_np)
477
+ vertices, faces, uv_map = self.get_mesh_np_attrs(
478
+ self.scale, self.center
479
+ )
480
+
481
+ textured_mesh = save_mesh_with_mtl(
482
+ vertices, faces, uv_map, texture_np, output_path
483
+ )
484
+ print("4", time.time() - start)
485
+
486
+ return textured_mesh
487
+
488
+
489
+ def parse_args():
490
+ parser = argparse.ArgumentParser(description="Backproject texture")
491
+ parser.add_argument(
492
+ "--color_path",
493
+ type=str,
494
+ help="Multiview color image in 6x512x512 file path",
495
+ )
496
+ parser.add_argument(
497
+ "--mesh_path",
498
+ type=str,
499
+ help="Mesh path, .obj, .glb or .ply",
500
+ )
501
+ parser.add_argument(
502
+ "--output_path",
503
+ type=str,
504
+ help="Output mesh path with suffix",
505
+ )
506
+ parser.add_argument(
507
+ "--num_images", type=int, default=6, help="Number of images to render."
508
+ )
509
+ parser.add_argument(
510
+ "--elevation",
511
+ nargs=2,
512
+ type=float,
513
+ default=[20.0, -10.0],
514
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
515
+ )
516
+ parser.add_argument(
517
+ "--distance",
518
+ type=float,
519
+ default=5,
520
+ help="Camera distance (default: 5)",
521
+ )
522
+ parser.add_argument(
523
+ "--resolution_hw",
524
+ type=int,
525
+ nargs=2,
526
+ default=(2048, 2048),
527
+ help="Resolution of the mesh rendering",
528
+ )
529
+ parser.add_argument(
530
+ "--target_hw",
531
+ type=int,
532
+ nargs=2,
533
+ default=(2048, 2048),
534
+ help="Target rendering images resolution",
535
+ )
536
+ parser.add_argument(
537
+ "--fov",
538
+ type=float,
539
+ default=30,
540
+ help="Field of view in degrees (default: 30)",
541
+ )
542
+ parser.add_argument(
543
+ "--device",
544
+ type=str,
545
+ choices=["cpu", "cuda"],
546
+ default="cuda",
547
+ help="Device to run on (default: `cuda`)",
548
+ )
549
+ parser.add_argument(
550
+ "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
551
+ )
552
+ parser.add_argument(
553
+ "--texture_wh",
554
+ nargs=2,
555
+ type=int,
556
+ default=[2048, 2048],
557
+ help="Texture resolution width and height",
558
+ )
559
+ parser.add_argument(
560
+ "--mesh_sipmlify_ratio",
561
+ type=float,
562
+ default=0.9,
563
+ help="Mesh simplification ratio (default: 0.9)",
564
+ )
565
+ parser.add_argument(
566
+ "--delight", action="store_true", help="Use delighting model."
567
+ )
568
+ args = parser.parse_args()
569
+
570
+ return args
571
+
572
+
573
+ def entrypoint(
574
+ delight_model: DelightingModel = None,
575
+ imagesr_model: ImageRealESRGAN = None,
576
+ **kwargs,
577
+ ) -> trimesh.Trimesh:
578
+ args = parse_args()
579
+ for k, v in kwargs.items():
580
+ if hasattr(args, k) and v is not None:
581
+ setattr(args, k, v)
582
+
583
+ # Setup camera parameters.
584
+ camera_params = CameraSetting(
585
+ num_images=args.num_images,
586
+ elevation=args.elevation,
587
+ distance=args.distance,
588
+ resolution_hw=args.resolution_hw,
589
+ fov=math.radians(args.fov),
590
+ device=args.device,
591
+ )
592
+ view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
593
+
594
+ color_grid = Image.open(args.color_path)
595
+ if args.delight:
596
+ if delight_model is None:
597
+ delight_model = DelightingModel(
598
+ model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
599
+ )
600
+ save_dir = os.path.dirname(args.output_path)
601
+ os.makedirs(save_dir, exist_ok=True)
602
+ color_grid.save(f"{save_dir}/color_grid.png")
603
+ color_grid = delight_model(color_grid)
604
+ color_grid.save(f"{save_dir}/color_grid_delight.png")
605
+
606
+ multiviews = get_images_from_grid(color_grid, img_size=512)
607
+
608
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
609
+ if imagesr_model is None:
610
+ imagesr_model = ImageRealESRGAN(outscale=4)
611
+ multiviews = [imagesr_model(img.convert("RGB")) for img in multiviews]
612
+ multiviews = [img.resize(args.target_hw[::-1]) for img in multiviews]
613
+
614
+ mesh = trimesh.load(args.mesh_path)
615
+ if isinstance(mesh, trimesh.Scene):
616
+ mesh = mesh.dump(concatenate=True)
617
+
618
+ if not args.skip_fix_mesh:
619
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
620
+ mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
621
+ mesh.vertices, mesh.faces = mesh_fixer(
622
+ filter_ratio=args.mesh_sipmlify_ratio,
623
+ max_hole_size=0.04,
624
+ resolution=1024,
625
+ num_views=1000,
626
+ norm_mesh_ratio=0.5,
627
+ )
628
+ # Restore scale.
629
+ mesh.vertices = mesh.vertices / scale
630
+ mesh.vertices = mesh.vertices + center
631
+
632
+ # Baking texture to mesh.
633
+ import time
634
+ start = time.time()
635
+ texture_backer = TextureBacker(
636
+ camera_params=camera_params,
637
+ view_weights=view_weights,
638
+ render_wh=args.target_hw,
639
+ texture_wh=args.texture_wh,
640
+ )
641
+ print(time.time()-start)
642
+ start = time.time()
643
+ textured_mesh = texture_backer(multiviews, mesh, args.output_path)
644
+ print(f"Texture backproject time: {time.time() - start:.2f}s")
645
+
646
+ return textured_mesh
647
+
648
+
649
+ if __name__ == "__main__":
650
+ entrypoint()
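
The fusion rule in `backproject_v2 copy.py` reduces to a confidence-weighted running average: `__call__` weights each view's cosine map by `view_weight * cos(angle) ** 4`, and `fast_bake_texture` accumulates the views while skipping any view whose visible texels are already more than 99% covered. The sketch below restates just that arithmetic in isolation, assuming per-view UV-space tensors as inputs; `fuse_textures_sketch` and the toy 4x4 tensors are illustrative stand-ins, not the pipeline's real buffers.

```python
import torch


def fuse_textures_sketch(textures, confidence_maps, eps=1e-8):
    """Minimal sketch of the confidence-weighted fusion in fast_bake_texture.

    Each entry is a per-view UV-space tensor: textures are (H, W, C) and
    confidence maps are (H, W, 1), assumed to already carry the
    `view_weight * cos(view_angle) ** 4` weighting applied in __call__.
    """
    merged = torch.zeros_like(textures[0])
    trust = torch.zeros_like(confidence_maps[0])
    for tex, conf in zip(textures, confidence_maps):
        visible = conf > 0
        already_painted = visible & (trust > 0)
        # Skip views whose visible texels are >99% already covered.
        if already_painted.sum() / visible.sum().clamp(min=1) > 0.99:
            continue
        merged += tex * conf
        trust += conf
    return merged / trust.clamp(min=eps), trust > eps


# Toy usage with two 4x4 single-channel "views".
texs = [torch.rand(4, 4, 1) for _ in range(2)]
confs = [torch.rand(4, 4, 1) for _ in range(2)]
baked, valid_mask = fuse_textures_sketch(texs, confs)
```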
asset3d_gen/data/backup/backproject_v2.py ADDED
@@ -0,0 +1,700 @@
1
+ import logging
2
+ import math
3
+ from typing import Union
4
+
5
+ import custom_rasterizer as cr
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ import xatlas
12
+ from PIL import Image
13
+ from asset3d_gen.data.utils import (
14
+ get_images_from_file,
15
+ normalize_vertices_array,
16
+ post_process_texture,
17
+ save_mesh_with_mtl,
18
+ )
19
+
20
+ logging.basicConfig(
21
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
+
28
+
29
+ import math
30
+ import numpy as np
31
+
32
+
33
+ def get_perspective_projection(
34
+ fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
+ ) -> np.ndarray:
36
+ """Compute the perspective projection matrix for 3D rendering."""
37
+ fov_rad = math.radians(fov)
38
+ tan_half_fov = math.tan(fov_rad / 2.0)
39
+
40
+ return np.array(
41
+ [
42
+ [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
+ [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
+ [
45
+ 0.0,
46
+ 0.0,
47
+ -(far + near) / (far - near),
48
+ -(2.0 * far * near) / (far - near),
49
+ ],
50
+ [0.0, 0.0, -1.0, 0.0],
51
+ ],
52
+ dtype=np.float32,
53
+ )
54
+
55
+
56
+ def transform_vertices(
57
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
+ ) -> torch.Tensor:
59
+ """Transform 3D vertices using a projection matrix."""
60
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
+ if pos.size(-1) == 3:
62
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
+
64
+ result = pos @ t_mtx.T
65
+
66
+ return result if keepdim else result.unsqueeze(0)
67
+
68
+
69
+ def compute_w2c_matrix(
70
+ elev_deg: float, azim_deg: float, cam_dist: float
71
+ ) -> np.ndarray:
72
+ """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
+
74
+ elev_rad = math.radians(-elev_deg)
75
+ azim_rad = math.radians(azim_deg)
76
+
77
+ sin_elev = math.sin(elev_rad)
78
+ cos_elev = math.cos(elev_rad)
79
+ sin_azim = math.sin(azim_rad)
80
+ cos_azim = math.cos(azim_rad)
81
+
82
+ cam_pos = np.array(
83
+ [
84
+ cam_dist * cos_elev * cos_azim,
85
+ cam_dist * cos_elev * sin_azim,
86
+ cam_dist * sin_elev,
87
+ ]
88
+ )
89
+
90
+ look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
+ right_dir = np.cross(look_dir, [0, 0, 1])
92
+ right_dir /= np.linalg.norm(right_dir)
93
+ up_dir = np.cross(right_dir, look_dir)
94
+
95
+ c2w = np.eye(4)
96
+ c2w[:3, 0] = right_dir
97
+ c2w[:3, 1] = up_dir
98
+ c2w[:3, 2] = -look_dir
99
+ c2w[:3, 3] = cam_pos
100
+
101
+ try:
102
+ w2c = np.linalg.inv(c2w)
103
+ except np.linalg.LinAlgError as e:
104
+ raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
+
106
+ return w2c.astype(np.float32)
107
+
108
+
109
+ def _bilinear_interpolation_scattering(
110
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
+ ) -> torch.Tensor:
112
+ """Bilinear interpolation scattering for grid-based value accumulation."""
113
+ device = values.device
114
+ dtype = values.dtype
115
+ C = values.shape[-1]
116
+
117
+ indices = coords * torch.tensor(
118
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
119
+ )
120
+ i, j = indices.unbind(-1)
121
+
122
+ i0, j0 = (
123
+ indices.floor()
124
+ .long()
125
+ .clamp(0, image_h - 2)
126
+ .clamp(0, image_w - 2)
127
+ .unbind(-1)
128
+ )
129
+ i1, j1 = i0 + 1, j0 + 1
130
+
131
+ w_i = i - i0.float()
132
+ w_j = j - j0.float()
133
+ weights = torch.stack(
134
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
+ dim=1,
136
+ )
137
+
138
+ indices_comb = torch.stack(
139
+ [
140
+ torch.stack([i0, j0], dim=1),
141
+ torch.stack([i0, j1], dim=1),
142
+ torch.stack([i1, j0], dim=1),
143
+ torch.stack([i1, j1], dim=1),
144
+ ],
145
+ dim=1,
146
+ )
147
+
148
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
+
151
+ for k in range(4):
152
+ idx = indices_comb[:, k]
153
+ w = weights[:, k].unsqueeze(-1)
154
+
155
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
+ flat_idx = (idx * stride).sum(-1)
157
+
158
+ grid.view(-1, C).scatter_add_(
159
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
+ )
161
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
+
163
+ mask = cnt.squeeze(-1) > 0
164
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
+
166
+ return grid
167
+
168
+
169
+ def _texture_inpaint_smooth(
170
+ texture: np.ndarray,
171
+ mask: np.ndarray,
172
+ vertices: np.ndarray,
173
+ faces: np.ndarray,
174
+ uv_map: np.ndarray,
175
+ ) -> tuple[np.ndarray, np.ndarray]:
176
+ """Perform texture inpainting using vertex-based color propagation."""
177
+ image_h, image_w, C = texture.shape
178
+ N = vertices.shape[0]
179
+
180
+ # Initialize vertex data structures
181
+ vtx_mask = np.zeros(N, dtype=np.float32)
182
+ vtx_colors = np.zeros((N, C), dtype=np.float32)
183
+ unprocessed = []
184
+ adjacency = [[] for _ in range(N)]
185
+
186
+ # Build adjacency graph and initial color assignment
187
+ for face_idx in range(faces.shape[0]):
188
+ for k in range(3):
189
+ uv_idx_k = faces[face_idx, k]
190
+ v_idx = faces[face_idx, k]
191
+
192
+ # Convert UV to pixel coordinates with boundary clamping
193
+ u = np.clip(
194
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
+ )
196
+ v = np.clip(
197
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
+ 0,
199
+ image_h - 1,
200
+ )
201
+
202
+ if mask[v, u]:
203
+ vtx_mask[v_idx] = 1.0
204
+ vtx_colors[v_idx] = texture[v, u]
205
+ elif v_idx not in unprocessed:
206
+ unprocessed.append(v_idx)
207
+
208
+ # Build undirected adjacency graph
209
+ neighbor = faces[face_idx, (k + 1) % 3]
210
+ if neighbor not in adjacency[v_idx]:
211
+ adjacency[v_idx].append(neighbor)
212
+ if v_idx not in adjacency[neighbor]:
213
+ adjacency[neighbor].append(v_idx)
214
+
215
+ # Color propagation with dynamic stopping
216
+ remaining_iters, prev_count = 2, 0
217
+ while remaining_iters > 0:
218
+ current_unprocessed = []
219
+
220
+ for v_idx in unprocessed:
221
+ valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
+ if not valid_neighbors:
223
+ current_unprocessed.append(v_idx)
224
+ continue
225
+
226
+ # Calculate inverse square distance weights
227
+ neighbors_pos = vertices[valid_neighbors]
228
+ dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
+ weights = 1 / np.maximum(dist_sq, 1e-8)
230
+
231
+ vtx_colors[v_idx] = np.average(
232
+ vtx_colors[valid_neighbors], weights=weights, axis=0
233
+ )
234
+ vtx_mask[v_idx] = 1.0
235
+
236
+ # Update iteration control
237
+ if len(current_unprocessed) == prev_count:
238
+ remaining_iters -= 1
239
+ else:
240
+ remaining_iters = min(remaining_iters + 1, 2)
241
+ prev_count = len(current_unprocessed)
242
+ unprocessed = current_unprocessed
243
+
244
+ # Generate output texture
245
+ inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
+ for face_idx in range(faces.shape[0]):
247
+ for k in range(3):
248
+ v_idx = faces[face_idx, k]
249
+ if not vtx_mask[v_idx]:
250
+ continue
251
+
252
+ # UV coordinate conversion
253
+ uv_idx_k = faces[face_idx, k]
254
+ u = np.clip(
255
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
+ )
257
+ v = np.clip(
258
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
+ 0,
260
+ image_h - 1,
261
+ )
262
+
263
+ inpainted_texture[v, u] = vtx_colors[v_idx]
264
+ updated_mask[v, u] = 255
265
+
266
+ return inpainted_texture, updated_mask
267
+
268
+
269
+ class TextureBacker:
270
+ """Texture baking pipeline for multi-view projection and fusion."""
271
+
272
+ def __init__(
273
+ self,
274
+ camera_elevs: list[float],
275
+ camera_azims: list[float],
276
+ camera_distance: int,
277
+ camera_fov: float,
278
+ view_weights: list[float] = None,
279
+ render_wh: tuple[int, int] = (2048, 2048),
280
+ texture_wh: tuple[int, int] = (2048, 2048),
281
+ use_antialias: bool = True,
282
+ bake_angle_thres: int = 75,
283
+ device="cuda",
284
+ ):
285
+ self.camera_elevs = camera_elevs
286
+ self.camera_azims = camera_azims
287
+ self.view_weights = (
288
+ view_weights
289
+ if view_weights is not None
290
+ else [1] * len(camera_elevs)
291
+ )
292
+ self.device = device
293
+ self.render_wh = render_wh
294
+ self.texture_wh = texture_wh
295
+
296
+ self.camera_distance = camera_distance
297
+ self.use_antialias = use_antialias
298
+
299
+ self.bake_angle_thres = bake_angle_thres
300
+ self.bake_unreliable_kernel_size = int(
301
+ (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
+ )
303
+
304
+ self.camera_proj_mat = get_perspective_projection(
305
+ camera_fov,
306
+ self.render_wh[1] / self.render_wh[0],
307
+ )
308
+ self.cnt = 0
309
+
310
+ def rasterize_mesh(
311
+ self,
312
+ vertex: torch.Tensor,
313
+ face: torch.Tensor,
314
+ resolution: tuple[int, int],
315
+ ) -> torch.Tensor:
316
+ vertex = vertex[None] if vertex.ndim == 2 else vertex
317
+ indices, weights = cr.rasterize(vertex, face, resolution)
318
+
319
+ return torch.cat(
320
+ [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
+ ).unsqueeze(0)
322
+
323
+ def raster_interpolate(
324
+ self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
+ ) -> torch.Tensor:
326
+ barycentric = rast_out[0, ..., :-1]
327
+ findices = rast_out[0, ..., -1]
328
+ if uv.dim() == 2:
329
+ uv = uv.unsqueeze(0)
330
+
331
+ return cr.interpolate(uv, findices, barycentric, faces)[0]
332
+
333
+ def load_mesh(self, mesh_path: str) -> None:
334
+ mesh = trimesh.load(mesh_path)
335
+ if isinstance(mesh, trimesh.Scene):
336
+ mesh = mesh.dump(concatenate=True)
337
+
338
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
+ self.scale, self.center = scale, center
340
+
341
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
+ mesh.vertices = mesh.vertices[vmapping]
343
+ mesh.faces = indices
344
+ mesh.visual.uv = uvs
345
+
346
+ self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
+ self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
+ self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
+
350
+ # Transformation of coordinate system
351
+ self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
+ self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
+ self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
+
355
+ def get_mesh_attrs(
356
+ self,
357
+ scale: float = None,
358
+ center: np.ndarray = None,
359
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
+ vertices = self.vertices.cpu().numpy()
361
+ faces = self.faces.cpu().numpy()
362
+ uv_map = self.uv_map.cpu().numpy()
363
+
364
+ # Inverse transformation of coordinate system
365
+ vertices[:, [1, 2]] = vertices[:, [2, 1]]
366
+ vertices[:, [0, 1]] = -vertices[:, [0, 1]]
367
+ uv_map[:, 1] = 1.0 - uv_map[:, 1]
368
+
369
+ if scale is not None:
370
+ vertices = vertices / scale
371
+ if center is not None:
372
+ vertices = vertices + center
373
+
374
+ return vertices, faces, uv_map
375
+
376
+ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
377
+ depth_image_np = depth_image.cpu().numpy()
378
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
379
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
380
+ combined_edges = depth_edges
381
+ sketch_image = (
382
+ torch.from_numpy(combined_edges).to(depth_image.device).float()
383
+ / 255
384
+ )
385
+ sketch_image = sketch_image.unsqueeze(-1)
386
+
387
+ return sketch_image
388
+
389
+ def back_project(
390
+ self, image: Image.Image, elev: float, azim: float
391
+ ) -> tuple[torch.Tensor, torch.Tensor]:
392
+ if isinstance(image, Image.Image):
393
+ image = np.array(image)
394
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
395
+ if image.ndim == 2:
396
+ image = image.unsqueeze(-1)
397
+ image = image / 255.0
398
+
399
+ view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
400
+ import pdb
401
+
402
+ pdb.set_trace()
403
+ pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
404
+ pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
405
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
406
+
407
+ v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
408
+ face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
409
+ vertex_norm = (
410
+ torch.from_numpy(
411
+ trimesh.geometry.mean_vertex_normals(
412
+ len(pos_cam), self.faces.cpu(), face_norm.cpu()
413
+ )
414
+ )
415
+ .to(self.device)
416
+ .contiguous()
417
+ )
418
+
419
+ rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
420
+ vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
421
+
422
+ interp_data = {
423
+ "normal": self.raster_interpolate(
424
+ vertex_norm[None], rast_out, self.faces
425
+ ),
426
+ "uv": self.raster_interpolate(
427
+ self.uv_map[None], rast_out, self.faces
428
+ ),
429
+ "depth": self.raster_interpolate(
430
+ pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
431
+ ),
432
+ }
433
+
434
+ valid_depth = interp_data["depth"][vis_mask > 0]
435
+ depth_norm = (interp_data["depth"] - valid_depth.min()) / (
436
+ valid_depth.max() - valid_depth.min()
437
+ )
438
+ # depth_norm[vis_mask <= 0] = 0
439
+ sketch_image = self._render_depth_edges(depth_norm * vis_mask)
440
+
441
+ # ddd = depth_norm * vis_mask
442
+ # cv2.imwrite(f"v2_depth_d{self.cnt}.png", (ddd.cpu().numpy() * 255).astype(np.uint8))
443
+
444
+ cv2.imwrite(
445
+ f"v2_vis_mask{self.cnt}.png",
446
+ (vis_mask.cpu().numpy() * 255).astype(np.uint8),
447
+ )
448
+ cv2.imwrite(
449
+ f"v2_normal{self.cnt}.png",
450
+ (interp_data["normal"].cpu().numpy() * 255).astype(np.uint8),
451
+ )
452
+ cv2.imwrite(
453
+ f"v2_depth{self.cnt}.png",
454
+ (depth_norm.cpu().numpy() * 255).astype(np.uint8),
455
+ )
456
+ cv2.imwrite(
457
+ f"v2_uv{self.cnt}.png",
458
+ (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
459
+ )
460
+ cv2.imwrite(
461
+ f"v2_sketch{self.cnt}.png",
462
+ (sketch_image.cpu().numpy() * 255).astype(np.uint8),
463
+ )
464
+
465
+ self.cnt += 1
466
+
467
+ cos = F.cosine_similarity(
468
+ torch.tensor([[0, 0, -1]], device=self.device),
469
+ interp_data["normal"].view(-1, 3),
470
+ ).view_as(interp_data["normal"][..., :1])
471
+ cos[cos < np.cos(np.radians(self.bake_angle_thres))] = 0
472
+
473
+ cv2.imwrite(
474
+ f"v2_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8)
475
+ )
476
+
477
+ k = self.bake_unreliable_kernel_size * 2 + 1
478
+ kernel = torch.ones((1, 1, k, k), device=self.device)
479
+
480
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
481
+ vis_mask = F.conv2d(
482
+ 1.0 - vis_mask,
483
+ kernel,
484
+ padding=k // 2,
485
+ )
486
+ vis_mask = 1.0 - (vis_mask > 0).float()
487
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
488
+
489
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
490
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
491
+ sketch_image = (sketch_image > 0).float()
492
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
493
+ vis_mask = vis_mask * (sketch_image < 0.5)
494
+
495
+ cos[vis_mask == 0] = 0
496
+
497
+ vis_mask = cv2.imread(
498
+ f"v3_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE
499
+ )
500
+ vis_mask = (
501
+ torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
502
+ )
503
+ # cos2 = cv2.imread(f"v3_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
504
+ # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
505
+ # cos = cos2
506
+
507
+ valid_pixels = (vis_mask != 0).view(-1)
508
+ # import pdb; pdb.set_trace()
509
+
510
+ cv2.imwrite(
511
+ f"v2_db_sketch{self.cnt}.png",
512
+ (sketch_image.cpu().numpy() * 255).astype(np.uint8),
513
+ )
514
+ cv2.imwrite(
515
+ f"v2_db_uv{self.cnt}.png",
516
+ (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
517
+ )
518
+ cv2.imwrite(
519
+ f"v2_db_uv2{self.cnt}.png",
520
+ (interp_data["uv"][..., 1].cpu().numpy() * 255).astype(np.uint8),
521
+ )
522
+ cv2.imwrite(
523
+ f"v2_db_color{self.cnt}.png",
524
+ (image.cpu().numpy() * 255).astype(np.uint8),
525
+ )
526
+ cv2.imwrite(
527
+ f"v2_db_cos{self.cnt}.png",
528
+ (cos.cpu().numpy() * 255).astype(np.uint8),
529
+ )
530
+ cv2.imwrite(
531
+ f"v2_db_mask{self.cnt}.png",
532
+ (vis_mask.cpu().numpy() * 255).astype(np.uint8),
533
+ )
534
+ # import pdb; pdb.set_trace()
535
+ return (
536
+ self._scatter_texture(interp_data["uv"], image, valid_pixels),
537
+ self._scatter_texture(interp_data["uv"], cos, valid_pixels),
538
+ )
539
+
540
+ def _scatter_texture(self, uv, data, mask):
541
+ def __filter_data(data, mask):
542
+ return data.view(-1, data.shape[-1])[mask]
543
+
544
+ return _bilinear_interpolation_scattering(
545
+ self.texture_wh[1],
546
+ self.texture_wh[0],
547
+ __filter_data(uv, mask)[..., [1, 0]],
548
+ __filter_data(data, mask),
549
+ )
550
+
551
+ @torch.no_grad()
552
+ def fast_bake_texture(
553
+ self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
554
+ ) -> tuple[torch.Tensor, torch.Tensor]:
555
+ channel = textures[0].shape[-1]
556
+ texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
557
+ self.device
558
+ )
559
+ trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
560
+ for texture, cos_map in zip(textures, confidence_maps):
561
+ view_sum = (cos_map > 0).sum()
562
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
563
+ if painted_sum / view_sum > 0.99:
564
+ continue
565
+ texture_merge += texture * cos_map
566
+ trust_map_merge += cos_map
567
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
568
+
569
+ return texture_merge, trust_map_merge > 1e-8
570
+
571
+ def uv_inpaint(
572
+ self, texture: torch.Tensor, mask: torch.Tensor
573
+ ) -> np.ndarray:
574
+ texture_np = texture.cpu().numpy()
575
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
576
+ vertices, faces, uv_map = self.get_mesh_attrs()
577
+ # import pdb; pdb.set_trace()
578
+ texture_np, mask_np = _texture_inpaint_smooth(
579
+ texture_np, mask_np, vertices, faces, uv_map
580
+ )
581
+ texture_np = texture_np.clip(0, 1)
582
+ texture_np = cv2.inpaint(
583
+ (texture_np * 255).astype(np.uint8),
584
+ 255 - mask_np,
585
+ 3,
586
+ cv2.INPAINT_NS,
587
+ )
588
+
589
+ return texture_np
590
+
591
+ def __call__(
592
+ self, colors: list[Image.Image], input_mesh: str, output_path: str
593
+ ) -> trimesh.Trimesh:
594
+ self.load_mesh(input_mesh)
595
+
596
+ textures, weighted_cos_maps = [], []
597
+ for color, cam_elev, cam_azim, weight in zip(
598
+ colors, self.camera_elevs, self.camera_azims, self.view_weights
599
+ ):
600
+ texture, cos_map = self.back_project(color, cam_elev, cam_azim)
601
+ cv2.imwrite(
602
+ f"v2_texture{self.cnt}.png",
603
+ (texture.cpu().numpy() * 255).astype(np.uint8),
604
+ )
605
+ cv2.imwrite(
606
+ f"v2_texture_cos{self.cnt}.png",
607
+ (cos_map.cpu().numpy() * 255).astype(np.uint8),
608
+ )
609
+ # import pdb; pdb.set_trace()
610
+ textures.append(texture)
611
+ weighted_cos_maps.append(weight * (cos_map**4))
612
+
613
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
614
+ texture_np = self.uv_inpaint(texture, mask)
615
+ texture_np = post_process_texture(texture_np)
616
+ vertices, faces, uvs = self.get_mesh_attrs(self.scale, self.center)
617
+ # import pdb; pdb.set_trace()
618
+ cv2.imwrite("v2_texture_np.png", texture_np)
619
+
620
+ textured_mesh = save_mesh_with_mtl(
621
+ vertices, faces, uvs, texture_np, output_path
622
+ )
623
+
624
+ return textured_mesh
625
+
626
+
627
+ class Image_Super_Net:
628
+ def __init__(self, device="cuda"):
629
+ from diffusers import StableDiffusionUpscalePipeline
630
+
631
+ self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
632
+ "stabilityai/stable-diffusion-x4-upscaler",
633
+ torch_dtype=torch.float16,
634
+ ).to(device)
635
+ self.up_pipeline_x4.set_progress_bar_config(disable=True)
636
+
637
+ def __call__(self, image, prompt=""):
638
+ with torch.no_grad():
639
+ upscaled_image = self.up_pipeline_x4(
640
+ prompt=[prompt],
641
+ image=image,
642
+ num_inference_steps=10,
643
+ ).images[0]
644
+
645
+ return upscaled_image
646
+
647
+
648
+ class Image_GANNet:
649
+ def __init__(self, outscale: int):
650
+ from basicsr.archs.rrdbnet_arch import RRDBNet
651
+ from realesrgan import RealESRGANer
652
+
653
+ self.outscale = outscale
654
+ model = RRDBNet(
655
+ num_in_ch=3,
656
+ num_out_ch=3,
657
+ num_feat=64,
658
+ num_block=23,
659
+ num_grow_ch=32,
660
+ scale=4,
661
+ )
662
+ self.upsampler = RealESRGANer(
663
+ scale=4,
664
+ model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
665
+ model=model,
666
+ pre_pad=0,
667
+ half=True,
668
+ )
669
+
670
+ def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
671
+ if isinstance(image, Image.Image):
672
+ image = np.array(image)
673
+ output, _ = self.upsampler.enhance(image, outscale=self.outscale)
674
+
675
+ return Image.fromarray(output)
676
+
677
+
678
+ if __name__ == "__main__":
679
+ device = "cuda"
680
+ color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
681
+ mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
682
+ output_path = "robot_test_v2/robot.obj"
683
+ target_image_size = (2048, 2048)
684
+
685
+ super_model = Image_GANNet(outscale=4)
686
+ multiviews = get_images_from_file(color_path, img_size=512)
687
+
688
+ texture_backer = TextureBacker(
689
+ camera_elevs=[20, 20, 20, -10, -10, -10],
690
+ camera_azims=[-180, -60, 60, -120, 0, 120],
691
+ view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
692
+ camera_distance=5,
693
+ camera_fov=30,
694
+ render_wh=(2048, 2048),
695
+ texture_wh=(2048, 2048),
696
+ )
697
+
698
+ multiviews = [super_model(img) for img in multiviews]
699
+ multiviews = [img.convert("RGB") for img in multiviews]
700
+ textured_mesh = texture_backer(multiviews, mesh_path, output_path)
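
The camera model in `backproject_v2.py` comes from the two helpers `get_perspective_projection` and `compute_w2c_matrix`. A quick way to convince yourself the look-at construction is consistent: `right` and `up` are both orthogonal to the viewing ray, so the world origin must land on the camera's -Z axis at exactly the camera distance. The sketch below restates that construction compactly and asserts the property; `look_at_w2c_sketch` is an illustrative name, not code from the repository.

```python
import math

import numpy as np


def look_at_w2c_sketch(elev_deg: float, azim_deg: float, cam_dist: float) -> np.ndarray:
    """Compact restatement of compute_w2c_matrix's look-at construction (illustrative only)."""
    elev, azim = math.radians(-elev_deg), math.radians(azim_deg)
    cam_pos = cam_dist * np.array(
        [
            math.cos(elev) * math.cos(azim),
            math.cos(elev) * math.sin(azim),
            math.sin(elev),
        ]
    )
    look = -cam_pos / np.linalg.norm(cam_pos)
    right = np.cross(look, [0.0, 0.0, 1.0])
    right /= np.linalg.norm(right)
    up = np.cross(right, look)
    c2w = np.eye(4)
    c2w[:3, 0] = right
    c2w[:3, 1] = up
    c2w[:3, 2] = -look  # camera +Z points away from the target (OpenGL style)
    c2w[:3, 3] = cam_pos
    return np.linalg.inv(c2w)


# Sanity check: the world origin maps to (0, 0, -cam_dist) in camera space,
# since `right` and `up` are both orthogonal to the view ray.
w2c = look_at_w2c_sketch(20.0, 60.0, 5.0)
origin_cam = w2c @ np.array([0.0, 0.0, 0.0, 1.0])
assert np.allclose(origin_cam[:3], [0.0, 0.0, -5.0])
```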
asset3d_gen/data/backup/backproject_v3.py ADDED
@@ -0,0 +1,866 @@
1
+ import logging
2
+ import math
3
+ from typing import Union
4
+
5
+ import custom_rasterizer as cr
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import trimesh
11
+ import xatlas
12
+ from PIL import Image
13
+ from asset3d_gen.data.utils import (
14
+ get_images_from_file,
15
+ normalize_vertices_array,
16
+ post_process_texture,
17
+ save_mesh_with_mtl,
18
+ )
19
+
20
+ logging.basicConfig(
21
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
+
28
+
29
+ import math
30
+ import numpy as np
31
+
32
+
33
+ def get_perspective_projection(
34
+ fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
+ ) -> np.ndarray:
36
+ """Compute the perspective projection matrix for 3D rendering."""
37
+ fov_rad = math.radians(fov)
38
+ tan_half_fov = math.tan(fov_rad / 2.0)
39
+
40
+ return np.array(
41
+ [
42
+ [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
+ [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
+ [
45
+ 0.0,
46
+ 0.0,
47
+ -(far + near) / (far - near),
48
+ -(2.0 * far * near) / (far - near),
49
+ ],
50
+ [0.0, 0.0, -1.0, 0.0],
51
+ ],
52
+ dtype=np.float32,
53
+ )
54
+
55
+
56
+ def transform_vertices(
57
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
+ ) -> torch.Tensor:
59
+ """Transform 3D vertices using a projection matrix."""
60
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
+ if pos.size(-1) == 3:
62
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
+
64
+ result = pos @ t_mtx.T
65
+
66
+ return result if keepdim else result.unsqueeze(0)
67
+
68
+
69
+ def compute_w2c_matrix(
70
+ elev_deg: float, azim_deg: float, cam_dist: float
71
+ ) -> np.ndarray:
72
+ """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
+
74
+ elev_rad = math.radians(-elev_deg)
75
+ azim_rad = math.radians(azim_deg)
76
+
77
+ sin_elev = math.sin(elev_rad)
78
+ cos_elev = math.cos(elev_rad)
79
+ sin_azim = math.sin(azim_rad)
80
+ cos_azim = math.cos(azim_rad)
81
+
82
+ cam_pos = np.array(
83
+ [
84
+ cam_dist * cos_elev * cos_azim,
85
+ cam_dist * cos_elev * sin_azim,
86
+ cam_dist * sin_elev,
87
+ ]
88
+ )
89
+
90
+ look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
+ right_dir = np.cross(look_dir, [0, 0, 1])
92
+ right_dir /= np.linalg.norm(right_dir)
93
+ up_dir = np.cross(right_dir, look_dir)
94
+
95
+ c2w = np.eye(4)
96
+ c2w[:3, 0] = right_dir
97
+ c2w[:3, 1] = up_dir
98
+ c2w[:3, 2] = -look_dir
99
+ c2w[:3, 3] = cam_pos
100
+
101
+ try:
102
+ w2c = np.linalg.inv(c2w)
103
+ except np.linalg.LinAlgError as e:
104
+ raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
+
106
+ return w2c.astype(np.float32)
107
+
108
+
109
+ def _bilinear_interpolation_scattering(
110
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
+ ) -> torch.Tensor:
112
+ """Bilinear interpolation scattering for grid-based value accumulation."""
113
+ device = values.device
114
+ dtype = values.dtype
115
+ C = values.shape[-1]
116
+
117
+ indices = coords * torch.tensor(
118
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
119
+ )
120
+ i, j = indices.unbind(-1)
121
+
122
+ i0, j0 = (
123
+ indices.floor()
124
+ .long()
125
+ .clamp(0, image_h - 2)
126
+ .clamp(0, image_w - 2)
127
+ .unbind(-1)
128
+ )
129
+ i1, j1 = i0 + 1, j0 + 1
130
+
131
+ w_i = i - i0.float()
132
+ w_j = j - j0.float()
133
+ weights = torch.stack(
134
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
+ dim=1,
136
+ )
137
+
138
+ indices_comb = torch.stack(
139
+ [
140
+ torch.stack([i0, j0], dim=1),
141
+ torch.stack([i0, j1], dim=1),
142
+ torch.stack([i1, j0], dim=1),
143
+ torch.stack([i1, j1], dim=1),
144
+ ],
145
+ dim=1,
146
+ )
147
+
148
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
+
151
+ for k in range(4):
152
+ idx = indices_comb[:, k]
153
+ w = weights[:, k].unsqueeze(-1)
154
+
155
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
+ flat_idx = (idx * stride).sum(-1)
157
+
158
+ grid.view(-1, C).scatter_add_(
159
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
+ )
161
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
+
163
+ mask = cnt.squeeze(-1) > 0
164
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
+
166
+ return grid
167
+
168
+
169
+ def _texture_inpaint_smooth(
170
+ texture: np.ndarray,
171
+ mask: np.ndarray,
172
+ vertices: np.ndarray,
173
+ faces: np.ndarray,
174
+ uv_map: np.ndarray,
175
+ ) -> tuple[np.ndarray, np.ndarray]:
176
+ """Perform texture inpainting using vertex-based color propagation."""
177
+ image_h, image_w, C = texture.shape
178
+ N = vertices.shape[0]
179
+
180
+ # Initialize vertex data structures
181
+ vtx_mask = np.zeros(N, dtype=np.float32)
182
+ vtx_colors = np.zeros((N, C), dtype=np.float32)
183
+ unprocessed = []
184
+ adjacency = [[] for _ in range(N)]
185
+
186
+ # Build adjacency graph and initial color assignment
187
+ for face_idx in range(faces.shape[0]):
188
+ for k in range(3):
189
+ uv_idx_k = faces[face_idx, k]
190
+ v_idx = faces[face_idx, k]
191
+
192
+ # Convert UV to pixel coordinates with boundary clamping
193
+ u = np.clip(
194
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
+ )
196
+ v = np.clip(
197
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
+ 0,
199
+ image_h - 1,
200
+ )
201
+
202
+ if mask[v, u]:
203
+ vtx_mask[v_idx] = 1.0
204
+ vtx_colors[v_idx] = texture[v, u]
205
+ elif v_idx not in unprocessed:
206
+ unprocessed.append(v_idx)
207
+
208
+ # Build undirected adjacency graph
209
+ neighbor = faces[face_idx, (k + 1) % 3]
210
+ if neighbor not in adjacency[v_idx]:
211
+ adjacency[v_idx].append(neighbor)
212
+ if v_idx not in adjacency[neighbor]:
213
+ adjacency[neighbor].append(v_idx)
214
+
215
+ # Color propagation with dynamic stopping
216
+ remaining_iters, prev_count = 2, 0
217
+ while remaining_iters > 0:
218
+ current_unprocessed = []
219
+
220
+ for v_idx in unprocessed:
221
+ valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
+ if not valid_neighbors:
223
+ current_unprocessed.append(v_idx)
224
+ continue
225
+
226
+ # Calculate inverse square distance weights
227
+ neighbors_pos = vertices[valid_neighbors]
228
+ dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
+ weights = 1 / np.maximum(dist_sq, 1e-8)
230
+
231
+ vtx_colors[v_idx] = np.average(
232
+ vtx_colors[valid_neighbors], weights=weights, axis=0
233
+ )
234
+ vtx_mask[v_idx] = 1.0
235
+
236
+ # Update iteration control
237
+ if len(current_unprocessed) == prev_count:
238
+ remaining_iters -= 1
239
+ else:
240
+ remaining_iters = min(remaining_iters + 1, 2)
241
+ prev_count = len(current_unprocessed)
242
+ unprocessed = current_unprocessed
243
+
244
+ # Generate output texture
245
+ inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
+ for face_idx in range(faces.shape[0]):
247
+ for k in range(3):
248
+ v_idx = faces[face_idx, k]
249
+ if not vtx_mask[v_idx]:
250
+ continue
251
+
252
+ # UV coordinate conversion
253
+ uv_idx_k = faces[face_idx, k]
254
+ u = np.clip(
255
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
+ )
257
+ v = np.clip(
258
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
+ 0,
260
+ image_h - 1,
261
+ )
262
+
263
+ inpainted_texture[v, u] = vtx_colors[v_idx]
264
+ updated_mask[v, u] = 255
265
+
266
+ return inpainted_texture, updated_mask
267
+
268
+
269
+ class TextureBacker:
270
+ """Texture baking pipeline for multi-view projection and fusion."""
271
+
272
+ def __init__(
273
+ self,
274
+ camera_elevs: list[float],
275
+ camera_azims: list[float],
276
+ camera_distance: int,
277
+ camera_fov: float,
278
+ view_weights: list[float] = None,
279
+ render_wh: tuple[int, int] = (2048, 2048),
280
+ texture_wh: tuple[int, int] = (2048, 2048),
281
+ use_antialias: bool = True,
282
+ bake_angle_thresh: int = 75,
283
+ device="cuda",
284
+ ):
285
+ self.camera_elevs = camera_elevs
286
+ self.camera_azims = camera_azims
287
+ self.view_weights = (
288
+ view_weights
289
+ if view_weights is not None
290
+ else [1] * len(camera_elevs)
291
+ )
292
+ self.device = device
293
+ self.render_wh = render_wh
294
+ self.texture_wh = texture_wh
295
+
296
+ self.camera_distance = camera_distance
297
+ self.use_antialias = use_antialias
298
+
299
+ self.bake_angle_thresh = bake_angle_thresh
300
+ self.bake_unreliable_kernel_size = int(
301
+ (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
+ )
303
+
304
+ self.camera_proj_mat = get_perspective_projection(
305
+ camera_fov,
306
+ self.render_wh[1] / self.render_wh[0],
307
+ )
308
+ self.cnt = 0
309
+
310
+ def rasterize_mesh(
311
+ self,
312
+ vertex: torch.Tensor,
313
+ face: torch.Tensor,
314
+ resolution: tuple[int, int],
315
+ ) -> torch.Tensor:
316
+ vertex = vertex[None] if vertex.ndim == 2 else vertex
317
+ indices, weights = cr.rasterize(vertex, face, resolution)
318
+
319
+ return torch.cat(
320
+ [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
+ ).unsqueeze(0)
322
+
323
+ def raster_interpolate(
324
+ self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
+ ) -> torch.Tensor:
326
+ barycentric = rast_out[0, ..., :-1]
327
+ findices = rast_out[0, ..., -1]
328
+ if uv.dim() == 2:
329
+ uv = uv.unsqueeze(0)
330
+
331
+ return cr.interpolate(uv, findices, barycentric, faces)[0]
332
+
333
+ def load_mesh(self, mesh_path: str) -> None:
334
+ mesh = trimesh.load(mesh_path)
335
+ if isinstance(mesh, trimesh.Scene):
336
+ mesh = mesh.dump(concatenate=True)
337
+
338
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
+ self.scale, self.center = scale, center
340
+
341
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
+ mesh.vertices = mesh.vertices[vmapping]
343
+ mesh.faces = indices
344
+ mesh.visual.uv = uvs
345
+
346
+ self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
+ self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
+ self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
+
350
+ # Transformation of coordinate system
351
+ self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
+ self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
+ self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
+
355
+ def get_mesh_attrs(
356
+ self,
357
+ scale: float = None,
358
+ center: np.ndarray = None,
359
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
+ vertices = self.vertices.cpu().numpy()
361
+ faces = self.faces.cpu().numpy()
362
+ uv_map = self.uv_map.cpu().numpy()
363
+
364
+ if scale is not None:
365
+ vertices = vertices / scale
366
+ if center is not None:
367
+ vertices = vertices + center
368
+
369
+ return vertices, faces, uv_map
370
+
371
+ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
372
+ depth_image_np = depth_image.cpu().numpy()
373
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
374
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
375
+ sketch_image = (
376
+ torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
377
+ )
378
+ sketch_image = sketch_image.unsqueeze(-1)
379
+
380
+ return sketch_image
381
+
382
+ def back_project(
383
+ self, image: Image.Image, elev: float, azim: float
384
+ ) -> tuple[torch.Tensor, torch.Tensor]:
385
+ if isinstance(image, Image.Image):
386
+ image = np.array(image)
387
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
388
+ if image.ndim == 2:
389
+ image = image.unsqueeze(-1)
390
+ image = image / 255.0
391
+
392
+ view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
393
+ pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
394
+ pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
395
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
396
+
397
+ v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
398
+ face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
399
+ vertex_norm = (
400
+ torch.from_numpy(
401
+ trimesh.geometry.mean_vertex_normals(
402
+ len(pos_cam), self.faces.cpu(), face_norm.cpu()
403
+ )
404
+ )
405
+ .to(self.device)
406
+ .contiguous()
407
+ )
408
+
409
+ rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
410
+ vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
411
+
412
+ interp_data = {
413
+ "normal": self.raster_interpolate(
414
+ vertex_norm[None], rast_out, self.faces
415
+ ),
416
+ "uv": self.raster_interpolate(
417
+ self.uv_map[None], rast_out, self.faces
418
+ ),
419
+ "depth": self.raster_interpolate(
420
+ pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
421
+ ),
422
+ }
423
+
424
+ valid_depth = interp_data["depth"][vis_mask > 0]
425
+ depth_norm = (interp_data["depth"] - valid_depth.min()) / (
426
+ valid_depth.max() - valid_depth.min()
427
+ )
428
+ depth_norm[vis_mask <= 0] = 0
429
+ sketch_image = self._render_depth_edges(depth_norm * vis_mask)
430
+
431
+ # cv2.imwrite("vis_mask.png", (vis_mask.cpu().numpy() * 255).astype(np.uint8))
432
+ # cv2.imwrite("normal.png", (interp_data['normal'].cpu().numpy() * 255).astype(np.uint8))
433
+ # cv2.imwrite("depth.png", (depth_norm.cpu().numpy() * 255).astype(np.uint8))
434
+ # cv2.imwrite("uv.png", (interp_data['uv'][..., 0].cpu().numpy() * 255).astype(np.uint8))
435
+ # import pdb; pdb.set_trace()
436
+
437
+ cos = F.cosine_similarity(
438
+ torch.tensor([[0, 0, -1]], device=self.device),
439
+ interp_data["normal"].view(-1, 3),
440
+ ).view_as(interp_data["normal"][..., :1])
441
+ cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
442
+
443
+ k = self.bake_unreliable_kernel_size * 2 + 1
444
+ kernel = torch.ones((1, 1, k, k), device=self.device)
445
+
446
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
447
+ vis_mask = F.conv2d(
448
+ 1.0 - vis_mask,
449
+ kernel,
450
+ padding=k // 2,
451
+ )
452
+ vis_mask = 1.0 - (vis_mask > 0).float()
453
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
454
+
455
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
456
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
457
+ sketch_image = (sketch_image > 0).float()
458
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
459
+ vis_mask = vis_mask * (sketch_image < 0.5)
460
+
461
+ cos[vis_mask == 0] = 0
462
+ valid_pixels = (vis_mask != 0).view(-1)
463
+
464
+ return (
465
+ self._scatter_texture(interp_data["uv"], image, valid_pixels),
466
+ self._scatter_texture(interp_data["uv"], cos, valid_pixels),
467
+ )
468
+
469
+ def back_project2(
470
+ self, image, vis_mask, depth, normal, uv
471
+ ) -> tuple[torch.Tensor, torch.Tensor]:
472
+ if isinstance(image, Image.Image):
473
+ image = np.array(image)
474
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
475
+ if image.ndim == 2:
476
+ image = image.unsqueeze(-1)
477
+ image = image / 255.0
478
+
479
+ depth_inv = (1.0 - depth) * vis_mask
480
+ sketch_image = self._render_depth_edges(depth_inv)
481
+
482
+ cv2.imwrite(
483
+ f"v3_depth_inv{self.cnt}.png",
484
+ (depth_inv.cpu().numpy() * 255).astype(np.uint8),
485
+ )
486
+
487
+ cos = F.cosine_similarity(
488
+ torch.tensor([[0, 0, 1]], device=self.device),
489
+ normal.view(-1, 3),
490
+ ).view_as(normal[..., :1])
491
+ cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
492
+ # import pdb; pdb.set_trace()
493
+ # cv2.imwrite(f"v3_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8))
494
+ # cv2.imwrite(f"v3_sketch{self.cnt}.png", (sketch_image.cpu().numpy() * 255).astype(np.uint8))
495
+
496
+ # cos2 = cv2.imread(f"v2_cos{self.cnt+1}.png", cv2.IMREAD_GRAYSCALE)
497
+ # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
498
+ # cos = cos2
499
+
500
+ self.cnt += 1
501
+
502
+ k = self.bake_unreliable_kernel_size * 2 + 1
503
+ kernel = torch.ones((1, 1, k, k), device=self.device)
504
+
505
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
506
+ vis_mask = F.conv2d(
507
+ 1.0 - vis_mask,
508
+ kernel,
509
+ padding=k // 2,
510
+ )
511
+ vis_mask = 1.0 - (vis_mask > 0).float()
512
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
513
+
514
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
515
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
516
+ sketch_image = (sketch_image > 0).float()
517
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
518
+ vis_mask = vis_mask * (sketch_image < 0.5)
519
+ # import pdb; pdb.set_trace()
520
+ cv2.imwrite(
521
+ f"v3_db_sketch{self.cnt}.png",
522
+ (sketch_image.cpu().numpy() * 255).astype(np.uint8),
523
+ )
524
+
525
+ cos[vis_mask == 0] = 0
526
+ # import pdb; pdb.set_trace()
527
+ # vis_mask = cv2.imread(f"v2_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
528
+ # vis_mask = torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
529
+ # cos2 = cv2.imread(f"v2_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
530
+ # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
531
+ # cos = cos2
532
+
533
+ valid_pixels = (vis_mask != 0).view(-1)
534
+ # import pdb; pdb.set_trace()
535
+ cv2.imwrite(
536
+ f"v3_db_uv{self.cnt}.png",
537
+ (uv[..., 0].cpu().numpy() * 255).astype(np.uint8),
538
+ )
539
+ cv2.imwrite(
540
+ f"v3_db_uv2{self.cnt}.png",
541
+ (uv[..., 1].cpu().numpy() * 255).astype(np.uint8),
542
+ )
543
+ cv2.imwrite(
544
+ f"v3_db_color{self.cnt}.png",
545
+ (image.cpu().numpy() * 255).astype(np.uint8),
546
+ )
547
+ cv2.imwrite(
548
+ f"v3_db_cos{self.cnt}.png",
549
+ (cos.cpu().numpy() * 255).astype(np.uint8),
550
+ )
551
+ cv2.imwrite(
552
+ f"v3_db_mask{self.cnt}.png",
553
+ (vis_mask.cpu().numpy() * 255).astype(np.uint8),
554
+ )
555
+
556
+ return (
557
+ self._scatter_texture(uv, image, valid_pixels),
558
+ self._scatter_texture(uv, cos, valid_pixels),
559
+ )
560
+
561
+ def _scatter_texture(self, uv, data, mask):
562
+ def __filter_data(data, mask):
563
+ return data.view(-1, data.shape[-1])[mask]
564
+
565
+ return _bilinear_interpolation_scattering(
566
+ self.texture_wh[1],
567
+ self.texture_wh[0],
568
+ __filter_data(uv, mask)[..., [1, 0]],
569
+ __filter_data(data, mask),
570
+ )
571
+
572
+ @torch.no_grad()
573
+ def fast_bake_texture(
574
+ self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
575
+ ) -> tuple[torch.Tensor, torch.Tensor]:
576
+ channel = textures[0].shape[-1]
577
+ texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
578
+ self.device
579
+ )
580
+ trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
581
+ for texture, cos_map in zip(textures, confidence_maps):
582
+ view_sum = (cos_map > 0).sum()
583
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
584
+ if painted_sum / view_sum > 0.99:
585
+ continue
586
+ texture_merge += texture * cos_map
587
+ trust_map_merge += cos_map
588
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
589
+
590
+ return texture_merge, trust_map_merge > 1e-8
591
+
592
+ def uv_inpaint(
593
+ self, texture: torch.Tensor, mask: torch.Tensor
594
+ ) -> np.ndarray:
595
+ texture_np = texture.cpu().numpy()
596
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
597
+ vertices, faces, uv_map = self.get_mesh_attrs()
598
+ # import pdb; pdb.set_trace()
599
+ texture_np, mask_np = _texture_inpaint_smooth(
600
+ texture_np, mask_np, vertices, faces, uv_map
601
+ )
602
+ texture_np = texture_np.clip(0, 1)
603
+ texture_np = cv2.inpaint(
604
+ (texture_np * 255).astype(np.uint8),
605
+ 255 - mask_np,
606
+ 3,
607
+ cv2.INPAINT_NS,
608
+ )
609
+
610
+ return texture_np
611
+
612
+ def __call__(
613
+ self, colors: list[Image.Image], input_mesh: str, output_path: str
614
+ ) -> trimesh.Trimesh:
615
+ self.load_mesh(input_mesh)
616
+
617
+ textures, weighted_cos_maps = [], []
618
+ for color, cam_elev, cam_azim, weight in zip(
619
+ colors, self.camera_elevs, self.camera_azims, self.view_weights
620
+ ):
621
+ texture, cos_map = self.back_project(color, cam_elev, cam_azim)
622
+ textures.append(texture)
623
+ weighted_cos_maps.append(weight * (cos_map**4))
624
+
625
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
626
+ texture_np = self.uv_inpaint(texture, mask)
627
+ texture_np = post_process_texture(texture_np)
628
+ vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
629
+ # import pdb; pdb.set_trace()
630
+ textured_mesh = save_mesh_with_mtl(
631
+ vertices, faces, uv_map, texture_np, output_path
632
+ )
633
+
634
+ return textured_mesh
635
+
636
+ def forward(
637
+ self,
638
+ colors: list[Image.Image],
639
+ masks,
640
+ depths,
641
+ normals,
642
+ uvs,
643
+ ) -> trimesh.Trimesh:
644
+ textures, weighted_cos_maps = [], []
645
+ for color, mask, depth, normal, uv, weight in zip(
646
+ colors, masks, depths, normals, uvs, self.view_weights
647
+ ):
648
+ texture, cos_map = self.back_project2(
649
+ color, mask, depth, normal, uv
650
+ )
651
+ cv2.imwrite(
652
+ f"v3_texture{self.cnt}.png",
653
+ (texture.cpu().numpy() * 255).astype(np.uint8),
654
+ )
655
+ cv2.imwrite(
656
+ f"v3_texture_cos{self.cnt}.png",
657
+ (cos_map.cpu().numpy() * 255).astype(np.uint8),
658
+ )
659
+
660
+ textures.append(texture)
661
+ weighted_cos_maps.append(weight * (cos_map**4))
662
+
663
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
664
+ texture_np = self.uv_inpaint(texture, mask)
665
+ texture_np = post_process_texture(texture_np)
666
+ vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
667
+ # import pdb; pdb.set_trace()
668
+ cv2.imwrite("v3_texture_np.png", texture_np)
669
+ textured_mesh = save_mesh_with_mtl(
670
+ vertices, faces, uv_map, texture_np, output_path
671
+ )
672
+
673
+ return textured_mesh
674
+
675
+
676
+ class Image_Super_Net:
677
+ def __init__(self, device="cuda"):
678
+ from diffusers import StableDiffusionUpscalePipeline
679
+
680
+ self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
681
+ "stabilityai/stable-diffusion-x4-upscaler",
682
+ torch_dtype=torch.float16,
683
+ ).to(device)
684
+ self.up_pipeline_x4.set_progress_bar_config(disable=True)
685
+
686
+ def __call__(self, image, prompt=""):
687
+ with torch.no_grad():
688
+ upscaled_image = self.up_pipeline_x4(
689
+ prompt=[prompt],
690
+ image=image,
691
+ num_inference_steps=10,
692
+ ).images[0]
693
+
694
+ return upscaled_image
695
+
696
+
697
+ class Image_GANNet:
698
+ def __init__(self, outscale: int):
699
+ from basicsr.archs.rrdbnet_arch import RRDBNet
700
+ from realesrgan import RealESRGANer
701
+
702
+ self.outscale = outscale
703
+ model = RRDBNet(
704
+ num_in_ch=3,
705
+ num_out_ch=3,
706
+ num_feat=64,
707
+ num_block=23,
708
+ num_grow_ch=32,
709
+ scale=4,
710
+ )
711
+ self.upsampler = RealESRGANer(
712
+ scale=4,
713
+ model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
714
+ model=model,
715
+ pre_pad=0,
716
+ half=True,
717
+ )
718
+
719
+ def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
720
+ if isinstance(image, Image.Image):
721
+ image = np.array(image)
722
+ output, _ = self.upsampler.enhance(image, outscale=self.outscale)
723
+
724
+ return Image.fromarray(output)
725
+
726
+
727
+ if __name__ == "__main__":
728
+ device = "cuda"
729
+ color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
730
+ mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
731
+ output_path = "robot_test_v6/robot.obj"
732
+ target_image_size = (2048, 2048)
733
+
734
+ super_model = Image_GANNet(outscale=4)
735
+ multiviews = get_images_from_file(color_path, img_size=512)
736
+ multiviews = [super_model(img) for img in multiviews]
737
+ multiviews = [img.convert("RGB") for img in multiviews]
738
+
739
+ from asset3d_gen.data.utils import (
740
+ CameraSetting,
741
+ init_kal_camera,
742
+ DiffrastRender,
743
+ )
744
+ import nvdiffrast.torch as dr
745
+
746
+ camera_params = CameraSetting(
747
+ num_images=6,
748
+ elevation=[20.0, -10.0],
749
+ distance=5,
750
+ resolution_hw=(2048, 2048),
751
+ fov=math.radians(30),
752
+ device="cuda",
753
+ )
754
+ camera = init_kal_camera(camera_params)
755
+ mv = camera.view_matrix() # (n 4 4) world2cam
756
+ p = camera.intrinsics.projection_matrix()
757
+ # NOTE: negate P[1, 1] because the y axis is flipped in `nvdiffrast` output. # noqa
758
+ p[:, 1, 1] = -p[:, 1, 1]
759
+ renderer = DiffrastRender(
760
+ p_matrix=p,
761
+ mv_matrix=mv,
762
+ resolution_hw=camera_params.resolution_hw,
763
+ context=dr.RasterizeCudaContext(),
764
+ mask_thresh=0.5,
765
+ grad_db=False,
766
+ device=camera_params.device,
767
+ antialias_mask=True,
768
+ )
769
+
770
+ mesh = trimesh.load(mesh_path)
771
+ if isinstance(mesh, trimesh.Scene):
772
+ mesh = mesh.dump(concatenate=True)
773
+
774
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
775
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
776
+ uvs[:, 1] = 1 - uvs[:, 1]
777
+ mesh.vertices = mesh.vertices[vmapping]
778
+ mesh.faces = indices
779
+ mesh.visual.uv = uvs
780
+
781
+ vertices = torch.from_numpy(mesh.vertices).to(camera_params.device).float()
782
+ faces = (
783
+ torch.from_numpy(mesh.faces).to(camera_params.device).to(torch.int64)
784
+ )
785
+ uvs = torch.from_numpy(mesh.visual.uv).to(camera_params.device).float()
786
+
787
+ rendered_view_normals = []
788
+ rast, vertices_clip = renderer.compute_dr_raster(vertices, faces)
789
+ for idx in range(len(mv)):
790
+ pos_cam = transform_vertices(mv[idx], vertices, keepdim=True)
791
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
792
+ v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
793
+ face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
794
+ vertex_norm = (
795
+ torch.from_numpy(
796
+ trimesh.geometry.mean_vertex_normals(
797
+ len(pos_cam), faces.cpu(), face_norm.cpu()
798
+ )
799
+ )
800
+ .to(camera_params.device)
801
+ .contiguous()
802
+ )
803
+ im_base_normals, _ = dr.interpolate(
804
+ vertex_norm[None, ...].float(),
805
+ rast[idx : idx + 1],
806
+ faces.to(torch.int32),
807
+ )
808
+ rendered_view_normals.append(im_base_normals)
809
+
810
+ rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
811
+
812
+ rendered_depth, masks = renderer.render_depth(vertices, faces)
813
+ norm_depths = []
814
+ for idx in range(len(rendered_depth)):
815
+ norm_depth = renderer.normalize_map_by_mask(
816
+ rendered_depth[idx : idx + 1], masks[idx : idx + 1]
817
+ )
818
+ norm_depths.append(norm_depth)
819
+ norm_depths = torch.cat(norm_depths, dim=0)
820
+ render_uvs, _ = renderer.render_uv(vertices, faces, uvs)
821
+
822
+ for index in range(6):
823
+ cv2.imwrite(
824
+ f"v3_mask{index}.png",
825
+ (masks[index] * 255).cpu().numpy().astype(np.uint8),
826
+ )
827
+ cv2.imwrite(
828
+ f"v3_normalv2{index}.png",
829
+ (rendered_view_normals[index] * 255)
830
+ .cpu()
831
+ .numpy()
832
+ .astype(np.uint8)[..., ::-1],
833
+ )
834
+ cv2.imwrite(
835
+ f"v3_depth{index}.png",
836
+ (norm_depths[index] * 255).cpu().numpy().astype(np.uint8),
837
+ )
838
+ cv2.imwrite(
839
+ f"v3_uv{index}.png",
840
+ (render_uvs[index, ..., 0] * 255).cpu().numpy().astype(np.uint8),
841
+ )
842
+ multiviews[index].save(f"v3_color{index}.png")
843
+
844
+ texture_backer = TextureBacker(
845
+ camera_elevs=[20, 20, 20, -10, -10, -10],
846
+ camera_azims=[-180, -60, 60, -120, 0, 120],
847
+ view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
848
+ camera_distance=5,
849
+ camera_fov=30,
850
+ render_wh=(2048, 2048),
851
+ texture_wh=(2048, 2048),
852
+ )
853
+ texture_backer.vertices = vertices
854
+ texture_backer.faces = faces
855
+ uvs[:, 1] = 1.0 - uvs[:, 1]
856
+ texture_backer.uv_map = uvs
857
+ texture_backer.center = center
858
+ texture_backer.scale = scale
859
+
860
+ textured_mesh = texture_backer.forward(
861
+ multiviews, masks, norm_depths, rendered_view_normals, render_uvs
862
+ )
863
+
864
+ # multiviews = [super_model(img) for img in multiviews]
865
+ # multiviews = [img.convert("RGB") for img in multiviews]
866
+ # textured_mesh = texture_backer(multiviews, mesh_path, output_path)
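The weighted fusion in `fast_bake_texture` above reduces to a confidence-weighted average: each view contributes `texture * cos_map`, the confidences accumulate, and the sum is normalized by the accumulated confidence. A minimal, self-contained sketch of that blend with toy tensors (it ignores the 99%-overlap early skip and is not the project's API):

import torch

def blend_views(textures, cos_maps, eps=1e-8):
    # textures: list of (H, W, C) tensors; cos_maps: list of (H, W, 1) confidence maps
    acc = torch.zeros_like(textures[0])
    weight = torch.zeros_like(cos_maps[0])
    for tex, cos in zip(textures, cos_maps):
        acc += tex * cos      # confidence-weighted contribution
        weight += cos         # accumulated confidence
    return acc / weight.clamp(min=eps), weight > eps  # fused texture, coverage mask

# toy check: the second view is twice as confident on the right half
t1, t2 = torch.full((4, 4, 3), 0.2), torch.full((4, 4, 3), 0.8)
c1 = torch.ones(4, 4, 1)
c2 = torch.zeros(4, 4, 1)
c2[:, 2:] = 2.0
merged, mask = blend_views([t1, t2], [c1, c2])
print(merged[0, 0, 0].item(), merged[0, 3, 0].item())  # 0.2 (left), 0.6 (right)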
asset3d_gen/data/backup/backprojectv2.py ADDED
@@ -0,0 +1,835 @@
1
+ from PIL import Image
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import math
6
+ import trimesh
7
+ import cv2
8
+ import xatlas
9
+ from typing import Union
10
+
11
+
12
+ def get_perspective_projection_matrix(fovy, aspect_wh, near, far):
13
+ fovy_rad = math.radians(fovy)
14
+ return np.array(
15
+ [
16
+ [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0],
17
+ [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0],
18
+ [
19
+ 0,
20
+ 0,
21
+ -(far + near) / (far - near),
22
+ -2.0 * far * near / (far - near),
23
+ ],
24
+ [0, 0, -1, 0],
25
+ ]
26
+ ).astype(np.float32)
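# A hedged sanity sketch for the projection matrix above; it assumes
# get_perspective_projection_matrix is in scope (e.g. run inside this module).
# Camera-space depth -near should map to NDC -1 and -far to NDC +1 after the
# perspective divide.
import numpy as np

def _ndc_depth(P, z_cam):
    clip = P @ np.array([0.0, 0.0, z_cam, 1.0], dtype=np.float32)
    return clip[2] / clip[3]

_near, _far = 0.01, 100.0
_P = get_perspective_projection_matrix(30, 1.0, _near, _far)
assert np.isclose(_ndc_depth(_P, -_near), -1.0, atol=1e-5)
assert np.isclose(_ndc_depth(_P, -_far), 1.0, atol=1e-4)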
27
+
28
+
29
+ def load_mesh(mesh):
30
+ vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None
31
+ pos_idx = mesh.faces if hasattr(mesh, "faces") else None
32
+
33
+ vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None
34
+ uv_idx = mesh.faces if hasattr(mesh, "faces") else None
35
+
36
+ texture_data = None
37
+
38
+ return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data
39
+
40
+
41
+ def save_mesh(mesh, texture_data):
42
+ material = trimesh.visual.texture.SimpleMaterial(
43
+ image=texture_data, diffuse=(255, 255, 255)
44
+ )
45
+ texture_visuals = trimesh.visual.TextureVisuals(
46
+ uv=mesh.visual.uv, image=texture_data, material=material
47
+ )
48
+ mesh.visual = texture_visuals
49
+ return mesh
50
+
51
+
52
+ def transform_pos(mtx, pos, keepdim=False):
53
+ t_mtx = (
54
+ torch.from_numpy(mtx).to(pos.device)
55
+ if isinstance(mtx, np.ndarray)
56
+ else mtx
57
+ )
58
+ if pos.shape[-1] == 3:
59
+ posw = torch.cat(
60
+ [pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1
61
+ )
62
+ else:
63
+ posw = pos
64
+
65
+ if keepdim:
66
+ return torch.matmul(posw, t_mtx.t())[...]
67
+ else:
68
+ return torch.matmul(posw, t_mtx.t())[None, ...]
69
+
70
+
71
+ def get_mv_matrix(elev, azim, camera_distance, center=None):
72
+ elev = -elev
73
+
74
+ elev_rad = math.radians(elev)
75
+ azim_rad = math.radians(azim)
76
+
77
+ camera_position = np.array(
78
+ [
79
+ camera_distance * math.cos(elev_rad) * math.cos(azim_rad),
80
+ camera_distance * math.cos(elev_rad) * math.sin(azim_rad),
81
+ camera_distance * math.sin(elev_rad),
82
+ ]
83
+ )
84
+
85
+ if center is None:
86
+ center = np.array([0, 0, 0])
87
+ else:
88
+ center = np.array(center)
89
+
90
+ lookat = center - camera_position
91
+ lookat = lookat / np.linalg.norm(lookat)
92
+
93
+ up = np.array([0, 0, 1.0])
94
+ right = np.cross(lookat, up)
95
+ right = right / np.linalg.norm(right)
96
+ up = np.cross(right, lookat)
97
+ up = up / np.linalg.norm(up)
98
+
99
+ c2w = np.concatenate(
100
+ [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]],
101
+ axis=-1,
102
+ )
103
+
104
+ w2c = np.zeros((4, 4))
105
+ w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0))
106
+ w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:])
107
+ w2c[3, 3] = 1.0
108
+
109
+ return w2c.astype(np.float32)
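# A hedged sanity sketch for the look-at matrix above (assumes get_mv_matrix is in
# scope): the rotation block must be orthonormal and the recovered camera center must
# sit at camera_distance from the origin.
import numpy as np

_w2c = get_mv_matrix(elev=20, azim=-60, camera_distance=5)
_R, _t = _w2c[:3, :3], _w2c[:3, 3]
assert np.allclose(_R @ _R.T, np.eye(3), atol=1e-5)  # orthonormal rotation
_cam_center = -_R.T @ _t                             # invert t = -R @ c
assert np.isclose(np.linalg.norm(_cam_center), 5.0, atol=1e-4)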
110
+
111
+
112
+ def stride_from_shape(shape):
113
+ stride = [1]
114
+ for x in reversed(shape[1:]):
115
+ stride.append(stride[-1] * x)
116
+ return list(reversed(stride))
117
+
118
+
119
+ def scatter_add_nd_with_count(input, count, indices, values, weights=None):
120
+ # input: [..., C], D dimension + C channel
121
+ # count: [..., 1], D dimension
122
+ # indices: [N, D], long
123
+ # values: [N, C]
124
+
125
+ D = indices.shape[-1]
126
+ C = input.shape[-1]
127
+ size = input.shape[:-1]
128
+ stride = stride_from_shape(size)
129
+
130
+ assert len(size) == D
131
+
132
+ input = input.view(-1, C) # [HW, C]
133
+ count = count.view(-1, 1)
134
+
135
+ flatten_indices = (
136
+ indices * torch.tensor(stride, dtype=torch.long, device=indices.device)
137
+ ).sum(
138
+ -1
139
+ ) # [N]
140
+
141
+ if weights is None:
142
+ weights = torch.ones_like(values[..., :1])
143
+
144
+ input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
145
+ count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
146
+
147
+ return input.view(*size, C), count.view(*size, 1)
148
+
149
+
150
+ def linear_grid_put_2d(H, W, coords, values, return_count=False):
151
+ # coords: [N, 2], float in [0, 1]
152
+ # values: [N, C]
153
+
154
+ C = values.shape[-1]
155
+
156
+ indices = coords * torch.tensor(
157
+ [H - 1, W - 1], dtype=torch.float32, device=coords.device
158
+ )
159
+ indices_00 = indices.floor().long() # [N, 2]
160
+ indices_00[:, 0].clamp_(0, H - 2)
161
+ indices_00[:, 1].clamp_(0, W - 2)
162
+ indices_01 = indices_00 + torch.tensor(
163
+ [0, 1], dtype=torch.long, device=indices.device
164
+ )
165
+ indices_10 = indices_00 + torch.tensor(
166
+ [1, 0], dtype=torch.long, device=indices.device
167
+ )
168
+ indices_11 = indices_00 + torch.tensor(
169
+ [1, 1], dtype=torch.long, device=indices.device
170
+ )
171
+
172
+ h = indices[..., 0] - indices_00[..., 0].float()
173
+ w = indices[..., 1] - indices_00[..., 1].float()
174
+ w_00 = (1 - h) * (1 - w)
175
+ w_01 = (1 - h) * w
176
+ w_10 = h * (1 - w)
177
+ w_11 = h * w
178
+
179
+ result = torch.zeros(
180
+ H, W, C, device=values.device, dtype=values.dtype
181
+ ) # [H, W, C]
182
+ count = torch.zeros(
183
+ H, W, 1, device=values.device, dtype=values.dtype
184
+ ) # [H, W, 1]
185
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
186
+
187
+ result, count = scatter_add_nd_with_count(
188
+ result,
189
+ count,
190
+ indices_00,
191
+ values * w_00.unsqueeze(1),
192
+ weights * w_00.unsqueeze(1),
193
+ )
194
+ result, count = scatter_add_nd_with_count(
195
+ result,
196
+ count,
197
+ indices_01,
198
+ values * w_01.unsqueeze(1),
199
+ weights * w_01.unsqueeze(1),
200
+ )
201
+ result, count = scatter_add_nd_with_count(
202
+ result,
203
+ count,
204
+ indices_10,
205
+ values * w_10.unsqueeze(1),
206
+ weights * w_10.unsqueeze(1),
207
+ )
208
+ result, count = scatter_add_nd_with_count(
209
+ result,
210
+ count,
211
+ indices_11,
212
+ values * w_11.unsqueeze(1),
213
+ weights * w_11.unsqueeze(1),
214
+ )
215
+
216
+ if return_count:
217
+ return result, count
218
+
219
+ mask = count.squeeze(-1) > 0
220
+ result[mask] = result[mask] / count[mask].repeat(1, C)
221
+
222
+ return result
223
+
224
+
225
+ def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx):
226
+ texture_height, texture_width, texture_channel = texture.shape
227
+ vtx_num = vtx_pos.shape[0]
228
+
229
+ vtx_mask = np.zeros(vtx_num, dtype=np.float32)
230
+ vtx_color = [
231
+ np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num)
232
+ ]
233
+ uncolored_vtxs = []
234
+ G = [[] for _ in range(vtx_num)]
235
+
236
+ for i in range(uv_idx.shape[0]):
237
+ for k in range(3):
238
+ vtx_uv_idx = uv_idx[i, k]
239
+ vtx_idx = pos_idx[i, k]
240
+ uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
241
+ uv_u = int(
242
+ round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
243
+ )
244
+ if mask[uv_u, uv_v] > 0:
245
+ vtx_mask[vtx_idx] = 1.0
246
+ vtx_color[vtx_idx] = texture[uv_u, uv_v]
247
+ else:
248
+ uncolored_vtxs.append(vtx_idx)
249
+ G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3])
250
+
251
+ smooth_count = 2
252
+ last_uncolored_vtx_count = 0
253
+ while smooth_count > 0:
254
+ uncolored_vtx_count = 0
255
+ for vtx_idx in uncolored_vtxs:
256
+ sum_color = np.zeros(texture_channel, dtype=np.float32)
257
+ total_weight = 0.0
258
+ vtx_0 = vtx_pos[vtx_idx]
259
+ for connected_idx in G[vtx_idx]:
260
+ if vtx_mask[connected_idx] > 0:
261
+ vtx1 = vtx_pos[connected_idx]
262
+ dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2))
263
+ dist_weight = 1.0 / max(dist, 1e-4)
264
+ dist_weight *= dist_weight
265
+ sum_color += vtx_color[connected_idx] * dist_weight
266
+ total_weight += dist_weight
267
+ if total_weight > 0:
268
+ vtx_color[vtx_idx] = sum_color / total_weight
269
+ vtx_mask[vtx_idx] = 1.0
270
+ else:
271
+ uncolored_vtx_count += 1
272
+
273
+ if last_uncolored_vtx_count == uncolored_vtx_count:
274
+ smooth_count -= 1
275
+ else:
276
+ smooth_count += 1
277
+ last_uncolored_vtx_count = uncolored_vtx_count
278
+
279
+ new_texture = texture.copy()
280
+ new_mask = mask.copy()
281
+ for face_idx in range(uv_idx.shape[0]):
282
+ for k in range(3):
283
+ vtx_uv_idx = uv_idx[face_idx, k]
284
+ vtx_idx = pos_idx[face_idx, k]
285
+ if vtx_mask[vtx_idx] == 1.0:
286
+ uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
287
+ uv_u = int(
288
+ round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
289
+ )
290
+ new_texture[uv_u, uv_v] = vtx_color[vtx_idx]
291
+ new_mask[uv_u, uv_v] = 255
292
+
293
+ return new_texture, new_mask
294
+
295
+
296
+ def mesh_uv_wrap(mesh):
297
+ if isinstance(mesh, trimesh.Scene):
298
+ mesh = mesh.dump(concatenate=True)
299
+
300
+ if len(mesh.faces) > 500000000:
301
+ raise ValueError(
302
+ "The mesh has more than 500,000,000 faces, which is not supported."
303
+ )
304
+
305
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
306
+
307
+ mesh.vertices = mesh.vertices[vmapping]
308
+ mesh.faces = indices
309
+ mesh.visual.uv = uvs
310
+
311
+ return mesh
312
+
313
+
314
+ class MeshRender:
315
+ def __init__(
316
+ self,
317
+ camera_distance=1.45,
318
+ default_resolution=1024,
319
+ texture_size=1024,
320
+ use_antialias=True,
321
+ max_mip_level=None,
322
+ filter_mode="linear",
323
+ bake_mode="linear",
324
+ raster_mode="cr",
325
+ device="cuda",
326
+ ):
327
+
328
+ self.device = device
329
+
330
+ self.set_default_render_resolution(default_resolution)
331
+ self.set_default_texture_resolution(texture_size)
332
+
333
+ self.camera_distance = camera_distance
334
+ self.use_antialias = use_antialias
335
+ self.max_mip_level = max_mip_level
336
+ self.filter_mode = filter_mode
337
+
338
+ self.bake_angle_thres = 75
339
+ self.bake_unreliable_kernel_size = int(
340
+ (2 / 512)
341
+ * max(self.default_resolution[0], self.default_resolution[1])
342
+ )
343
+ self.bake_mode = bake_mode
344
+
345
+ self.raster_mode = raster_mode
346
+ if self.raster_mode == "cr":
347
+ import custom_rasterizer as cr
348
+
349
+ self.raster = cr
350
+ else:
351
+ raise ValueError(f"No raster named {self.raster_mode}")
352
+
353
+ fov = 30
354
+ self.camera_proj_mat = get_perspective_projection_matrix(
355
+ fov,
356
+ self.default_resolution[1] / self.default_resolution[0],
357
+ 0.01,
358
+ 100.0,
359
+ )
360
+
361
+ def raster_rasterize(
362
+ self, pos, tri, resolution, ranges=None, grad_db=True
363
+ ):
364
+
365
+ if self.raster_mode == "cr":
366
+ rast_out_db = None
367
+ if pos.dim() == 2:
368
+ pos = pos.unsqueeze(0)
369
+ findices, barycentric = self.raster.rasterize(pos, tri, resolution)
370
+ rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
371
+ rast_out = rast_out.unsqueeze(0)
372
+ else:
373
+ raise ValueError(f"No raster named {self.raster_mode}")
374
+
375
+ return rast_out, rast_out_db
376
+
377
+ def raster_interpolate(
378
+ self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None
379
+ ):
380
+
381
+ if self.raster_mode == "cr":
382
+ textd = None
383
+ barycentric = rast_out[0, ..., :-1]
384
+ findices = rast_out[0, ..., -1]
385
+ if uv.dim() == 2:
386
+ uv = uv.unsqueeze(0)
387
+ textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
388
+ else:
389
+ raise ValueError(f"No raster named {self.raster_mode}")
390
+
391
+ return textc, textd
392
+
393
+ def load_mesh(
394
+ self,
395
+ mesh,
396
+ ):
397
+ vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
398
+ self.mesh_copy = mesh
399
+ self.set_mesh(
400
+ vtx_pos,
401
+ pos_idx,
402
+ vtx_uv=vtx_uv,
403
+ uv_idx=uv_idx,
404
+ )
405
+ if texture_data is not None:
406
+ self.set_texture(texture_data)
407
+
408
+ def save_mesh(self):
409
+ texture_data = self.get_texture()
410
+ texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
411
+ return save_mesh(self.mesh_copy, texture_data)
412
+
413
+ def set_mesh(
414
+ self,
415
+ vtx_pos,
416
+ pos_idx,
417
+ vtx_uv=None,
418
+ uv_idx=None,
419
+ ):
420
+
421
+ self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
422
+ self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
423
+ if (vtx_uv is not None) and (uv_idx is not None):
424
+ self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
425
+ self.uv_idx = (
426
+ torch.from_numpy(uv_idx).to(self.device).to(torch.int)
427
+ )
428
+ else:
429
+ self.vtx_uv = None
430
+ self.uv_idx = None
431
+
432
+ self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
433
+ self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
434
+ if (vtx_uv is not None) and (uv_idx is not None):
435
+ self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]
436
+
437
+ def set_texture(self, tex):
438
+ if isinstance(tex, np.ndarray):
439
+ tex = Image.fromarray((tex * 255).astype(np.uint8))
440
+ elif isinstance(tex, torch.Tensor):
441
+ tex = tex.cpu().numpy()
442
+ tex = Image.fromarray((tex * 255).astype(np.uint8))
443
+
444
+ tex = tex.resize(self.texture_size).convert("RGB")
445
+ tex = np.array(tex) / 255.0
446
+ self.tex = torch.from_numpy(tex).to(self.device)
447
+ self.tex = self.tex.float()
448
+
449
+ def set_default_render_resolution(self, default_resolution):
450
+ if isinstance(default_resolution, int):
451
+ default_resolution = (default_resolution, default_resolution)
452
+ self.default_resolution = default_resolution
453
+
454
+ def set_default_texture_resolution(self, texture_size):
455
+ if isinstance(texture_size, int):
456
+ texture_size = (texture_size, texture_size)
457
+ self.texture_size = texture_size
458
+
459
+ def get_mesh(self):
460
+ vtx_pos = self.vtx_pos.cpu().numpy()
461
+ pos_idx = self.pos_idx.cpu().numpy()
462
+ vtx_uv = self.vtx_uv.cpu().numpy()
463
+ uv_idx = self.uv_idx.cpu().numpy()
464
+
465
+ # Inverse of the coordinate-system transform applied in set_mesh
466
+ vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
467
+ vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]
468
+
469
+ vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
470
+ return vtx_pos, pos_idx, vtx_uv, uv_idx
471
+
472
+ def get_texture(self):
473
+ return self.tex.cpu().numpy()
474
+
475
+ def render_sketch_from_depth(self, depth_image):
476
+ depth_image_np = depth_image.cpu().numpy()
477
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
478
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
479
+ combined_edges = depth_edges
480
+ sketch_image = (
481
+ torch.from_numpy(combined_edges).to(depth_image.device).float()
482
+ / 255.0
483
+ )
484
+ sketch_image = sketch_image.unsqueeze(-1)
485
+ return sketch_image
486
+
487
+ def back_project(
488
+ self, image, elev, azim, camera_distance=None, center=None, method=None
489
+ ):
490
+ if isinstance(image, Image.Image):
491
+ image = torch.tensor(np.array(image) / 255.0)
492
+ elif isinstance(image, np.ndarray):
493
+ image = torch.tensor(image)
494
+ if image.dim() == 2:
495
+ image = image.unsqueeze(-1)
496
+ image = image.float().to(self.device)
497
+ resolution = image.shape[:2]
498
+ channel = image.shape[-1]
499
+ texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
500
+ cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)
501
+
502
+ proj = self.camera_proj_mat
503
+ r_mv = get_mv_matrix(
504
+ elev=elev,
505
+ azim=azim,
506
+ camera_distance=(
507
+ self.camera_distance
508
+ if camera_distance is None
509
+ else camera_distance
510
+ ),
511
+ center=center,
512
+ )
513
+ pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
514
+ pos_clip = transform_pos(proj, pos_camera)
515
+ pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
516
+ v0 = pos_camera[self.pos_idx[:, 0], :]
517
+ v1 = pos_camera[self.pos_idx[:, 1], :]
518
+ v2 = pos_camera[self.pos_idx[:, 2], :]
519
+ face_normals = F.normalize(
520
+ torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
521
+ )
522
+ vertex_normals = trimesh.geometry.mean_vertex_normals(
523
+ vertex_count=self.vtx_pos.shape[0],
524
+ faces=self.pos_idx.cpu(),
525
+ face_normals=face_normals.cpu(),
526
+ )
527
+ vertex_normals = (
528
+ torch.from_numpy(vertex_normals)
529
+ .float()
530
+ .to(self.device)
531
+ .contiguous()
532
+ )
533
+ tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
534
+ rast_out, rast_out_db = self.raster_rasterize(
535
+ pos_clip, self.pos_idx, resolution=resolution
536
+ )
537
+ visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
538
+
539
+ normal, _ = self.raster_interpolate(
540
+ vertex_normals[None, ...], rast_out, self.pos_idx
541
+ )
542
+ normal = normal[0, ...]
543
+
544
+ uv, _ = self.raster_interpolate(
545
+ self.vtx_uv[None, ...], rast_out, self.uv_idx
546
+ )
547
+ depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
548
+ depth = depth[0, ...]
549
+
550
+ depth_max, depth_min = (
551
+ depth[visible_mask > 0].max(),
552
+ depth[visible_mask > 0].min(),
553
+ )
554
+ depth_normalized = (depth - depth_min) / (depth_max - depth_min)
555
+ depth_image = depth_normalized * visible_mask # Mask out background.
556
+
557
+ sketch_image = self.render_sketch_from_depth(depth_image)
558
+
559
+ cv2.imwrite("d_depth.png", depth_image.cpu().numpy() * 255)
560
+ cv2.imwrite("d_normal.png", normal.cpu().numpy() * 255)
561
+ cv2.imwrite(
562
+ "d_image.png", image.cpu().numpy()[..., :3][..., ::-1] * 255
563
+ )
564
+ cv2.imwrite("d_sketch_image.png", sketch_image.cpu().numpy() * 255)
565
+ cv2.imwrite("d_uv1.png", uv.cpu().numpy()[0, ..., 0] * 255)
566
+ cv2.imwrite("d_uv2.png", uv.cpu().numpy()[0, ..., 1] * 255)
567
+ # p uv[0,...,0].mean(axis=0)
568
+ # import pdb; pdb.set_trace()
569
+
570
+ # depth_image = None
571
+ # normal = None
572
+ # image = None
573
+
574
+ sketch_image = self.render_sketch_from_depth(depth_image)
575
+ channel = image.shape[-1]
576
+
577
+ lookat = torch.tensor([[0, 0, -1]], device=self.device)
578
+ cos_image = torch.nn.functional.cosine_similarity(
579
+ lookat, normal.view(-1, 3)
580
+ )
581
+ cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)
582
+
583
+ cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
584
+ cos_image[cos_image < cos_thres] = 0
585
+
586
+ # shrink
587
+ kernel_size = self.bake_unreliable_kernel_size * 2 + 1
588
+ kernel = torch.ones(
589
+ (1, 1, kernel_size, kernel_size), dtype=torch.float32
590
+ ).to(sketch_image.device)
591
+
592
+ visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
593
+ visible_mask = F.conv2d(
594
+ 1.0 - visible_mask, kernel, padding=kernel_size // 2
595
+ )
596
+ visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
597
+ visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)
598
+
599
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
600
+ sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
601
+ sketch_image = (sketch_image > 0).float()  # binarize
602
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
603
+ visible_mask = visible_mask * (sketch_image < 0.5)
604
+
605
+ cos_image[visible_mask == 0] = 0
606
+ proj_mask = (visible_mask != 0).view(-1)
607
+ uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
608
+ image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
609
+ cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
610
+ sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]
611
+ # import pdb
+ # pdb.set_trace()  # debug breakpoint; disabled so back_project runs non-interactively
614
+ texture = linear_grid_put_2d(
615
+ self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image
616
+ )
617
+ cos_map = linear_grid_put_2d(
618
+ self.texture_size[1],
619
+ self.texture_size[0],
620
+ uv[..., [1, 0]],
621
+ cos_image,
622
+ )
623
+ boundary_map = linear_grid_put_2d(
624
+ self.texture_size[1],
625
+ self.texture_size[0],
626
+ uv[..., [1, 0]],
627
+ sketch_image,
628
+ )
629
+
630
+ return texture, cos_map, boundary_map
631
+
632
+ @torch.no_grad()
633
+ def fast_bake_texture(self, textures, cos_maps):
634
+
635
+ channel = textures[0].shape[-1]
636
+ texture_merge = torch.zeros(self.texture_size + (channel,)).to(
637
+ self.device
638
+ )
639
+ trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
640
+ for texture, cos_map in zip(textures, cos_maps):
641
+ view_sum = (cos_map > 0).sum()
642
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
643
+ if painted_sum / view_sum > 0.99:
644
+ continue
645
+ texture_merge += texture * cos_map
646
+ trust_map_merge += cos_map
647
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
648
+
649
+ return texture_merge, trust_map_merge > 1e-8
650
+
651
+ def uv_inpaint(self, texture, mask):
652
+
653
+ if isinstance(texture, torch.Tensor):
654
+ texture_np = texture.cpu().numpy()
655
+ elif isinstance(texture, np.ndarray):
656
+ texture_np = texture
657
+ elif isinstance(texture, Image.Image):
658
+ texture_np = np.array(texture) / 255.0
659
+
660
+ vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()
661
+
662
+ texture_np, mask = meshVerticeInpaint_smooth(
663
+ texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx
664
+ )
665
+
666
+ texture_np = cv2.inpaint(
667
+ (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS
668
+ )
669
+
670
+ return texture_np
671
+
672
+
673
+ def get_images_from_file(img_path: str, img_size: int) -> list[np.array]:
674
+ input_image = Image.open(img_path)
675
+ view_images = np.array(input_image)
676
+ view_images = np.concatenate(
677
+ [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
678
+ )
679
+ images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
680
+
681
+ return images
682
+
683
+
684
+ def bake_from_multiview(
685
+ render, views, camera_elevs, camera_azims, view_weights, method="fast"
686
+ ):
687
+ project_textures, project_weighted_cos_maps = [], []
688
+ project_boundary_maps = []
689
+ for view, camera_elev, camera_azim, weight in zip(
690
+ views, camera_elevs, camera_azims, view_weights
691
+ ):
692
+ project_texture, project_cos_map, project_boundary_map = (
693
+ render.back_project(view, camera_elev, camera_azim)
694
+ )
695
+ project_cos_map = weight * (project_cos_map**4)
696
+ project_textures.append(project_texture)
697
+ project_weighted_cos_maps.append(project_cos_map)
698
+ project_boundary_maps.append(project_boundary_map)
699
+
700
+ if method == "fast":
701
+ texture, ori_trust_map = render.fast_bake_texture(
702
+ project_textures, project_weighted_cos_maps
703
+ )
704
+ else:
705
+ raise ValueError(f"no method {method}")
706
+
707
+ return texture, ori_trust_map > 1e-8
708
+
709
+
710
+ def post_process(texture: np.ndarray, iter: int = 2) -> np.ndarray:
711
+ for _ in range(iter):
712
+ texture = cv2.fastNlMeansDenoisingColored(texture, None, 11, 11, 9, 25)
713
+ texture = cv2.bilateralFilter(
714
+ texture, d=7, sigmaColor=80, sigmaSpace=80
715
+ )
716
+
717
+ return texture
718
+
719
+
720
+ class Image_Super_Net:
721
+ def __init__(self, device="cuda"):
722
+ from diffusers import StableDiffusionUpscalePipeline
723
+
724
+ self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
725
+ "stabilityai/stable-diffusion-x4-upscaler",
726
+ torch_dtype=torch.float16,
727
+ ).to(device)
728
+ self.up_pipeline_x4.set_progress_bar_config(disable=True)
729
+
730
+ def __call__(self, image, prompt=""):
731
+ with torch.no_grad():
732
+ upscaled_image = self.up_pipeline_x4(
733
+ prompt=[prompt],
734
+ image=image,
735
+ num_inference_steps=10,
736
+ ).images[0]
737
+
738
+ return upscaled_image
739
+
740
+
741
+ class Image_GANNet:
742
+ def __init__(self, outscale: int):
743
+ from realesrgan import RealESRGANer
744
+ from basicsr.archs.rrdbnet_arch import RRDBNet
745
+
746
+ self.outscale = outscale
747
+ model = RRDBNet(
748
+ num_in_ch=3,
749
+ num_out_ch=3,
750
+ num_feat=64,
751
+ num_block=23,
752
+ num_grow_ch=32,
753
+ scale=4,
754
+ )
755
+ self.upsampler = RealESRGANer(
756
+ scale=4,
757
+ model_path="/home/users/xinjie.wang/xinjie/Real-ESRGAN/weights/RealESRGAN_x4plus.pth",
758
+ model=model,
759
+ pre_pad=0,
760
+ half=True,
761
+ )
762
+
763
+ def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
764
+ if isinstance(image, Image.Image):
765
+ image = np.array(image)
766
+ output, _ = self.upsampler.enhance(image, outscale=self.outscale)
767
+
768
+ return Image.fromarray(output)
769
+
770
+
771
+ if __name__ == "__main__":
772
+ device = "cuda"
773
+
774
+ # super_model = Image_Super_Net(device)
775
+ super_model = Image_GANNet(outscale=4)
776
+
777
+ selected_camera_elevs = [20, 20, 20, -10, -10, -10]
778
+ selected_camera_azims = [-180, -60, 60, -120, 0, 120]
779
+ selected_view_weights = [1, 0.2, 0.2, 0.2, 1, 0.2]
780
+ # selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
781
+
782
+ multiviews = get_images_from_file(
783
+ "scripts/apps/texture_sessions/mfq4e7u4ko/multi_view/color_sample1.png",
784
+ 512,
785
+ )
786
+ target_image_size = (2048, 2048)
787
+
788
+ render = MeshRender(
789
+ camera_distance=5,
790
+ default_resolution=2048,
791
+ texture_size=2048,
792
+ )
793
+
794
+ mesh = trimesh.load("scripts/apps/assets/example_texture/meshes/robot.obj")
795
+ from asset3d_gen.data.utils import normalize_vertices_array
796
+
797
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
798
+ mesh = mesh_uv_wrap(mesh)
799
+ render.load_mesh(mesh)
800
+
801
+ # multiviews = [Image.fromarray(img) for img in multiviews]
802
+ # multiviews = [Image.fromarray(img).convert("RGB") for img in multiviews]
803
+ # for idx, img in enumerate(multiviews):
804
+ # img.save(f"robot/raw/res_{idx}.png")
805
+
806
+ multiviews = [super_model(img) for img in multiviews]
807
+ multiviews = [img.convert("RGB") for img in multiviews]
808
+ for idx, img in enumerate(multiviews):
809
+ img.save(f"robot/super_gan_res_{idx}.png")
810
+
811
+ texture, mask = bake_from_multiview(
812
+ render,
813
+ multiviews,
814
+ selected_camera_elevs,
815
+ selected_camera_azims,
816
+ selected_view_weights,
817
+ )
818
+
819
+ texture_np = (texture.cpu().numpy() * 255).astype(np.uint8)[..., :3][
820
+ ..., ::-1
821
+ ]
822
+ cv2.imwrite("robot/raw_texture.png", texture_np)
823
+ print("texture done.")
824
+
825
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
826
+ texture_np = render.uv_inpaint(texture, mask_np)
827
+ cv2.imwrite("robot/inpaint_texture.png", texture_np[..., ::-1])
828
+ # texture_np = post_process(texture_np, 2)
829
+ # cv2.imwrite("robot/inpaint_conv_texture.png", texture_np[..., ::-1])
830
+ print("inpaint done.")
831
+
832
+ texture = torch.tensor(texture_np / 255).float().to(texture.device)
833
+ render.set_texture(texture)
834
+ textured_mesh = render.save_mesh()
835
+ _ = textured_mesh.export("robot/robot.obj")
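`linear_grid_put_2d` above splats scattered UV samples into the texture with bilinear weights and then normalizes by the accumulated weight. A minimal NumPy sketch of the same splatting idea (toy data, not the project's API):

import numpy as np

def splat_bilinear(h, w, coords, values):
    # coords: (N, 2) floats in [0, 1] as (row, col); values: (N, C)
    acc = np.zeros((h, w, values.shape[1]))
    cnt = np.zeros((h, w, 1))
    pix = coords * np.array([h - 1, w - 1], dtype=np.float64)
    base = np.floor(pix).astype(int)
    base[:, 0] = base[:, 0].clip(0, h - 2)
    base[:, 1] = base[:, 1].clip(0, w - 2)
    fr, fc = (pix - base).T
    corners = ((0, 0, (1 - fr) * (1 - fc)), (0, 1, (1 - fr) * fc),
               (1, 0, fr * (1 - fc)), (1, 1, fr * fc))
    for dr, dc, wgt in corners:
        r, c = base[:, 0] + dr, base[:, 1] + dc
        np.add.at(acc, (r, c), values * wgt[:, None])  # accumulate weighted samples
        np.add.at(cnt, (r, c), wgt[:, None])           # accumulate weights
    hit = cnt[..., 0] > 0
    acc[hit] /= cnt[hit]
    return acc

tex = splat_bilinear(4, 4, np.array([[0.5, 0.5]]), np.array([[1.0]]))
print(tex[..., 0])  # the four pixels around (1.5, 1.5) each recover the value 1.0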
asset3d_gen/data/backup/gpt_qwen.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3
+ from qwen_vl_utils import process_vision_info
4
+ import os
5
+ os.environ["https_proxy"] = "10.9.0.31:8838"
6
+
7
+
8
+ # # default: Load the model on the available device(s)
9
+ # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
10
+ # "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
11
+ # )
12
+
13
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
14
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
15
+ "Qwen/Qwen2.5-VL-7B-Instruct",
16
+ torch_dtype=torch.bfloat16,
17
+ attn_implementation="flash_attention_2",
18
+ device_map="auto",
19
+ )
20
+
21
+
22
+ # default processer
23
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
24
+
25
+ # The default range for the number of visual tokens per image in the model is 4-16384.
26
+ # You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
27
+ # min_pixels = 256*28*28
28
+ # max_pixels = 1280*28*28
29
+ # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
30
+
31
+ messages = [
32
+ {
33
+ "role": "user",
34
+ "content": [
35
+ {
36
+ "type": "image",
37
+ "image": "outputs/text2image/demo_objects/bed/sample_0.jpg",
38
+ },
39
+ {
40
+ "type": "image",
41
+ "image": "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",
42
+ },
43
+ {"type": "text", "text": "Describe the second image."},
44
+ ],
45
+ }
46
+ ]
47
+
48
+ # Preparation for inference
49
+ text = processor.apply_chat_template(
50
+ messages, tokenize=False, add_generation_prompt=True
51
+ )
52
+ image_inputs, video_inputs = process_vision_info(messages)
53
+ inputs = processor(
54
+ text=[text],
55
+ images=image_inputs,
56
+ videos=video_inputs,
57
+ padding=True,
58
+ return_tensors="pt",
59
+ )
60
+ inputs = inputs.to("cuda")
61
+
62
+ # Inference: Generation of the output
63
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
64
+ generated_ids_trimmed = [
65
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
66
+ ]
67
+ output_text = processor.batch_decode(
68
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
69
+ )
70
+ print(output_text)
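As the commented-out lines above note, the per-image visual token budget can be bounded to balance quality and cost; a minimal sketch of that variant (same checkpoint, values taken from the comment):

from transformers import AutoProcessor

# Bound the number of visual tokens per image (default range is 4-16384 tokens).
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)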
asset3d_gen/data/backup/quat.py ADDED
@@ -0,0 +1,49 @@
1
+ import numpy as np
2
+
3
+ def quaternion_rotation_x_counterclockwise(angle_degrees):
4
+ angle_radians = np.radians(angle_degrees)
5
+ w = np.cos(angle_radians / 2)
6
+ x = np.sin(angle_radians / 2)
7
+ y, z = 0.0, 0.0
8
+ return np.array([x, y, z, w]).round(4).tolist()
9
+
10
+
11
+ def quaternion_rotation_y_counterclockwise(angle_degrees):
12
+ angle_radians = np.radians(angle_degrees)
13
+ w = np.cos(angle_radians / 2)
14
+ y = np.sin(angle_radians / 2)
15
+ x, z = 0.0, 0.0
16
+ return np.array([x, y, z, w]).round(4).tolist()
17
+
18
+
19
+ def quaternion_rotation_z_counterclockwise(angle_degrees):
20
+ angle_radians = np.radians(angle_degrees)
21
+ w = np.cos(angle_radians / 2)
22
+ z = np.sin(angle_radians / 2)
23
+ x, y = 0.0, 0.0
24
+ return np.array([x, y, z, w]).round(4).tolist()
25
+
26
+
27
+ def quaternion_multiply(q1, q2):
28
+ x1, y1, z1, w1 = q1
29
+ x2, y2, z2, w2 = q2
30
+ w = w1*w2 - x1*x2 - y1*y2 - z1*z2
31
+ x = w1*x2 + x1*w2 + y1*z2 - z1*y2
32
+ y = w1*y2 - x1*z2 + y1*w2 + z1*x2
33
+ z = w1*z2 + x1*y2 - y1*x2 + z1*w2
34
+ return np.array([w, x, y, z])  # NOTE: result packed as (w, x, y, z), unlike the (x, y, z, w) inputs above
35
+
36
+
37
+
38
+ angle = 180
39
+
40
+ print(f"X-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_x_counterclockwise(angle)}")
41
+ print(f"Y-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_y_counterclockwise(angle)}")
42
+ print(f"Z-axis counterclockwise rotation by {angle} degrees: {quaternion_rotation_z_counterclockwise(angle)}")
43
+
44
+
45
+ q_1 = np.array([1.0, 0.0, 0.0, 0.0])
46
+ q_2 = np.array([0.0, 0.0, 1.0, 0.0])
47
+
48
+ q_total = quaternion_multiply(q_2, q_1)
49
+ print(q_total.round(4).tolist())
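A quick way to cross-check the axis-rotation helpers above is SciPy, whose Rotation.as_quat also uses the (x, y, z, w) ordering; this sketch assumes scipy is installed and simply repeats the same half-angle construction:

import numpy as np
from scipy.spatial.transform import Rotation as R

angle = 30.0
half = np.radians(angle) / 2
q_manual = np.array([np.sin(half), 0.0, 0.0, np.cos(half)])  # mirrors quaternion_rotation_x_counterclockwise
q_scipy = R.from_euler("x", angle, degrees=True).as_quat()   # scipy returns (x, y, z, w)
assert np.allclose(q_manual, q_scipy)
print(q_manual.round(4).tolist())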
asset3d_gen/data/datasets.py ADDED
@@ -0,0 +1,239 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import random
5
+ from typing import Any, Callable, Dict, List, Tuple, Union
6
+
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ from PIL import Image
10
+ from torch import nn
11
+ from torch.utils.data import Dataset
12
+ from torchvision import transforms
13
+
14
+ logging.basicConfig(
15
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
16
+ )
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ __all__ = [
21
+ "Asset3dGenDataset",
22
+ ]
23
+
24
+
25
+ class Asset3dGenDataset(Dataset):
26
+ def __init__(
27
+ self,
28
+ index_file: str,
29
+ target_hw: Tuple[int, int],
30
+ transform: Callable = None,
31
+ control_transform: Callable = None,
32
+ max_train_samples: int = None,
33
+ sub_idxs: List[List[int]] = None,
34
+ seed: int = 79,
35
+ ) -> None:
36
+ if not os.path.exists(index_file):
37
+ raise FileNotFoundError(f"{index_file} index_file not found.")
38
+
39
+ self.index_file = index_file
40
+ self.target_hw = target_hw
41
+ self.transform = transform
42
+ self.control_transform = control_transform
43
+ self.max_train_samples = max_train_samples
44
+ self.meta_info = self.prepare_data_index(index_file)
45
+ self.data_list = sorted(self.meta_info.keys())
46
+ self.sub_idxs = sub_idxs # sub_idxs [[0,1,2], [3,4,5], [...], ...]
47
+ self.image_num = 6 # hardcode temp.
48
+ random.seed(seed)
49
+ logger.info(f"Trainset: {len(self)} asset3d instances.")
50
+
51
+ def __len__(self) -> int:
52
+ return len(self.meta_info)
53
+
54
+ def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
55
+ with open(index_file, "r") as fin:
56
+ meta_info = json.load(fin)
57
+
58
+ meta_info_filtered = dict()
59
+ for idx, uid in enumerate(meta_info):
60
+ if "status" not in meta_info[uid]:
61
+ continue
62
+ if meta_info[uid]["status"] != "success":
63
+ continue
64
+ if self.max_train_samples and idx >= self.max_train_samples:
65
+ break
66
+
67
+ meta_info_filtered[uid] = meta_info[uid]
68
+
69
+ logger.info(
70
+ f"Load {len(meta_info)} assets, keep {len(meta_info_filtered)} valids." # noqa
71
+ )
72
+
73
+ return meta_info_filtered
74
+
75
+ def fetch_sample_images(
76
+ self,
77
+ uid: str,
78
+ attrs: List[str],
79
+ sub_index: int = None,
80
+ transform: Callable = None,
81
+ ) -> torch.Tensor:
82
+ sample = self.meta_info[uid]
83
+ images = []
84
+ for attr in attrs:
85
+ item = sample[attr]
86
+ if sub_index is not None:
87
+ item = item[sub_index]
88
+ mode = "L" if attr == "image_mask" else "RGB"
89
+ image = Image.open(item).convert(mode)
90
+ if transform is not None:
91
+ image = transform(image)
92
+ if len(image.shape) == 2:
93
+ image = image[..., None]
94
+ images.append(image)
95
+
96
+ images = torch.cat(images, dim=0)
97
+
98
+ return images
99
+
100
+ def fetch_sample_grid_images(
101
+ self,
102
+ uid: str,
103
+ attrs: List[str],
104
+ sub_idxs: List[List[int]],
105
+ transform: Callable = None,
106
+ ) -> torch.Tensor:
107
+ assert transform is not None
108
+
109
+ grid_image = []
110
+ for row_idxs in sub_idxs:
111
+ row_image = []
112
+ for row_idx in row_idxs:
113
+ image = self.fetch_sample_images(
114
+ uid, attrs, row_idx, transform
115
+ )
116
+ row_image.append(image)
117
+ row_image = torch.cat(row_image, dim=2) # (c h w)
118
+ grid_image.append(row_image)
119
+
120
+ grid_image = torch.cat(grid_image, dim=1)
121
+
122
+ return grid_image
123
+
124
+ def compute_text_embeddings(
125
+ self, embed_path: str, original_size: Tuple[int, int]
126
+ ) -> Dict[str, nn.Module]:
127
+ data_dict = torch.load(embed_path)
128
+ prompt_embeds = data_dict["prompt_embeds"][0]
129
+ add_text_embeds = data_dict["pooled_prompt_embeds"][0]
130
+
131
+ # Need changed if random crop, set as crop_top_left [y1, x1], center crop as [0, 0]. # noqa
132
+ crops_coords_top_left = (0, 0)
133
+ add_time_ids = list(
134
+ original_size + crops_coords_top_left + self.target_hw
135
+ )
136
+ add_time_ids = torch.tensor([add_time_ids])
137
+ # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
138
+
139
+ unet_added_cond_kwargs = {
140
+ "text_embeds": add_text_embeds,
141
+ "time_ids": add_time_ids,
142
+ }
143
+
144
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
145
+
146
+ def visualize_item(
147
+ self,
148
+ control: torch.Tensor,
149
+ color: torch.Tensor,
150
+ save_dir: str = None,
151
+ ) -> List[Image.Image]:
152
+ to_pil = transforms.ToPILImage()
153
+
154
+ color = (color + 1) / 2
155
+ color_pil = to_pil(color)
156
+ normal_pil = to_pil(control[0:3])
157
+ position_pil = to_pil(control[3:6])
158
+ mask_pil = to_pil(control[6:])
159
+
160
+ if save_dir is not None:
161
+ os.makedirs(save_dir, exist_ok=True)
162
+ color_pil.save(f"{save_dir}/rgb.jpg")
163
+ normal_pil.save(f"{save_dir}/normal.jpg")
164
+ position_pil.save(f"{save_dir}/position.jpg")
165
+ mask_pil.save(f"{save_dir}/mask.jpg")
166
+ logger.info(f"Visualization in {save_dir}")
167
+
168
+ return normal_pil, position_pil, mask_pil, color_pil
169
+
170
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
171
+ uid = self.data_list[index]
172
+
173
+ sub_idxs = self.sub_idxs
174
+ if sub_idxs is None:
175
+ sub_idxs = [[random.randint(0, self.image_num - 1)]]
176
+
177
+ input_image = self.fetch_sample_grid_images(
178
+ uid,
179
+ attrs=["image_view_normal", "image_position", "image_mask"],
180
+ sub_idxs=sub_idxs,
181
+ transform=self.control_transform,
182
+ )
183
+ assert input_image.shape[1:] == self.target_hw
184
+
185
+ output_image = self.fetch_sample_grid_images(
186
+ uid,
187
+ attrs=["image_color"],
188
+ sub_idxs=sub_idxs,
189
+ transform=self.transform,
190
+ )
191
+
192
+ sample = self.meta_info[uid]
193
+ text_feats = self.compute_text_embeddings(
194
+ sample["text_feat"], tuple(sample["image_hw"])
195
+ )
196
+
197
+ data = dict(
198
+ pixel_values=output_image,
199
+ conditioning_pixel_values=input_image,
200
+ prompt_embeds=text_feats["prompt_embeds"],
201
+ text_embeds=text_feats["text_embeds"],
202
+ time_ids=text_feats["time_ids"],
203
+ )
204
+
205
+ return data
206
+
207
+
208
+ if __name__ == "__main__":
209
+ index_file = "/horizon-bucket/robot_lab/users/xinjie.wang/datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
210
+ target_hw = (512, 512)
211
+ transform_list = [
212
+ transforms.Resize(
213
+ target_hw, interpolation=transforms.InterpolationMode.BILINEAR
214
+ ),
215
+ transforms.CenterCrop(target_hw),
216
+ transforms.ToTensor(),
217
+ transforms.Normalize([0.5], [0.5]),
218
+ ]
219
+ image_transform = transforms.Compose(transform_list)
220
+ control_transform = transforms.Compose(transform_list[:-1])
221
+
222
+ sub_idxs = [[0, 1, 2], [3, 4, 5]] # None
223
+ if sub_idxs is not None:
224
+ target_hw = (
225
+ target_hw[0] * len(sub_idxs),
226
+ target_hw[1] * len(sub_idxs[0]),
227
+ )
228
+
229
+ dataset = Asset3dGenDataset(
230
+ index_file,
231
+ target_hw,
232
+ image_transform,
233
+ control_transform,
234
+ sub_idxs=sub_idxs,
235
+ )
236
+ data = dataset[0]
237
+ dataset.visualize_item(
238
+ data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
239
+ )
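For reference, the grid assembly in fetch_sample_grid_images concatenates per-view tensors along width within each row and then stacks the rows along height, which is why the __main__ block scales target_hw by the sub_idxs layout. A minimal sketch with toy tensors (hypothetical 512x512 views):

import torch

views = [torch.full((3, 512, 512), float(i)) for i in range(6)]  # stand-ins for per-view images
sub_idxs = [[0, 1, 2], [3, 4, 5]]  # 2 rows x 3 columns, as in the example above

rows = [torch.cat([views[i] for i in row], dim=2) for row in sub_idxs]  # concat along width
grid = torch.cat(rows, dim=1)                                           # stack rows along height
print(grid.shape)  # torch.Size([3, 1024, 1536])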
asset3d_gen/data/differentiable_render.py ADDED
@@ -0,0 +1,520 @@
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ from collections import defaultdict
7
+ from typing import List, Union
8
+
9
+ import cv2
10
+ import imageio
11
+ import numpy as np
12
+ import nvdiffrast.torch as dr
13
+ import torch
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ from asset3d_gen.data.utils import (
17
+ CameraSetting,
18
+ DiffrastRender,
19
+ RenderItems,
20
+ as_list,
21
+ calc_vertex_normals,
22
+ import_kaolin_mesh,
23
+ init_kal_camera,
24
+ normalize_vertices_array,
25
+ render_pbr,
26
+ save_images,
27
+ )
28
+
29
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
30
+ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
31
+ "~/.cache/torch_extensions"
32
+ )
33
+ logging.basicConfig(
34
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
35
+ )
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def create_gif_from_images(images, output_path, fps=10):
40
+ pil_images = []
41
+ for image in images:
42
+ image = image.clip(min=0, max=1)
43
+ image = (255.0 * image).astype(np.uint8)
44
+ image = Image.fromarray(image, mode="RGBA")
45
+ pil_images.append(image.convert("RGB"))
46
+
47
+ duration = 1000 // fps
48
+ pil_images[0].save(
49
+ output_path,
50
+ save_all=True,
51
+ append_images=pil_images[1:],
52
+ duration=duration,
53
+ loop=0,
54
+ )
55
+
56
+ logger.info(f"GIF saved to {output_path}")
57
+
58
+
59
+ def create_mp4_from_images(images, output_path, fps=10, prompt=None):
60
+ font = cv2.FONT_HERSHEY_SIMPLEX # font face
62
+ font_scale = 0.5 # font size
63
+ font_thickness = 1 # font thickness
64
+ color = (255, 255, 255) # text color (white)
65
+ position = (20, 25) # top-left corner (x, y)
65
+
66
+ with imageio.get_writer(output_path, fps=fps) as writer:
67
+ for image in images:
68
+ image = image.clip(min=0, max=1)
69
+ image = (255.0 * image).astype(np.uint8)
70
+ image = image[..., :3]
71
+ if prompt is not None:
72
+ cv2.putText(
73
+ image,
74
+ prompt,
75
+ position,
76
+ font,
77
+ font_scale,
78
+ color,
79
+ font_thickness,
80
+ )
81
+
82
+ writer.append_data(image)
83
+
84
+ logger.info(f"MP4 video saved to {output_path}")
85
+
86
+
87
+ class ImageRender(object):
88
+ def __init__(
89
+ self,
90
+ render_items: list[RenderItems],
91
+ camera_params: CameraSetting,
92
+ recompute_vtx_normal: bool = True,
93
+ device: str = "cuda",
94
+ with_mtl: bool = False,
95
+ gen_color_gif: bool = False,
96
+ gen_color_mp4: bool = False,
97
+ gen_viewnormal_mp4: bool = False,
98
+ gen_glonormal_mp4: bool = False,
99
+ no_index_file: bool = False,
100
+ light_factor: float = 1.0,
101
+ ) -> None:
102
+ camera_params.device = device
103
+ camera = init_kal_camera(camera_params)
104
+ self.camera = camera
105
+
106
+ # Setup MVP matrix and renderer.
107
+ mv = camera.view_matrix() # (n 4 4) world2cam
108
+ p = camera.intrinsics.projection_matrix()
109
+ # NOTE: negate P[1, 1] as the y axis is flipped in `nvdiffrast` output. # noqa
110
+ p[:, 1, 1] = -p[:, 1, 1]
111
+ # mvp = torch.bmm(p, mv) # camera.view_projection_matrix()
112
+ self.mv = mv
113
+ self.p = p
114
+
115
+ renderer = DiffrastRender(
116
+ p_matrix=p,
117
+ mv_matrix=mv,
118
+ resolution_hw=camera_params.resolution_hw,
119
+ context=dr.RasterizeCudaContext(),
120
+ mask_thresh=0.5,
121
+ grad_db=False,
122
+ device=camera_params.device,
123
+ antialias_mask=True,
124
+ )
125
+ self.renderer = renderer
126
+ self.recompute_vtx_normal = recompute_vtx_normal
127
+ self.render_items = render_items
128
+ self.device = device
129
+ self.with_mtl = with_mtl
130
+ self.gen_color_gif = gen_color_gif
131
+ self.gen_color_mp4 = gen_color_mp4
132
+ self.gen_viewnormal_mp4 = gen_viewnormal_mp4
133
+ self.gen_glonormal_mp4 = gen_glonormal_mp4
134
+ self.light_factor = light_factor
135
+ self.no_index_file = no_index_file
136
+
137
+ def render_mesh(
138
+ self,
139
+ mesh_path: Union[str, List[str]],
140
+ output_root: str,
141
+ uuid: Union[str, List[str]] = None,
142
+ prompts: List[str] = None,
143
+ ) -> None:
144
+ mesh_path = as_list(mesh_path)
145
+ if uuid is None:
146
+ uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
147
+ uuid = as_list(uuid)
148
+ assert len(mesh_path) == len(uuid)
149
+ os.makedirs(output_root, exist_ok=True)
150
+
151
+ meta_info = dict()
152
+ for idx, (path, uid) in tqdm(
153
+ enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
154
+ ):
155
+ output_dir = os.path.join(output_root, uid)
156
+ os.makedirs(output_dir, exist_ok=True)
157
+ prompt = prompts[idx] if prompts else None
158
+ data_dict = self(path, output_dir, prompt)
159
+ meta_info[uid] = data_dict
160
+
161
+ if self.no_index_file:
162
+ return
163
+
164
+ index_file = os.path.join(output_root, "index.json")
165
+ with open(index_file, "w") as fout:
166
+ json.dump(meta_info, fout)
167
+
168
+ logger.info(f"Rendering meta info logged in {index_file}")
169
+
170
+ def __call__(
171
+ self, mesh_path: str, output_dir: str, prompt: str = None
172
+ ) -> dict[str, str]:
173
+ try:
174
+ mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
175
+ except Exception as e:
176
+ logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
177
+ return
178
+
179
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
180
+ if self.recompute_vtx_normal:
181
+ mesh.vertex_normals = calc_vertex_normals(
182
+ mesh.vertices, mesh.faces
183
+ )
184
+
185
+ mesh = mesh.to(self.device)
186
+ vertices, faces, vertex_normals = (
187
+ mesh.vertices,
188
+ mesh.faces,
189
+ mesh.vertex_normals,
190
+ )
191
+
192
+ # Perform rendering.
193
+ data_dict = defaultdict(list)
194
+ if RenderItems.ALPHA.value in self.render_items:
195
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
196
+ render_paths = save_images(
197
+ masks, f"{output_dir}/{RenderItems.ALPHA}"
198
+ )
199
+ data_dict[RenderItems.ALPHA.value] = render_paths
200
+
201
+ if RenderItems.GLOBAL_NORMAL.value in self.render_items:
202
+ rendered_normals, masks = self.renderer.render_global_normal(
203
+ vertices, faces, vertex_normals
204
+ )
205
+ if self.gen_glonormal_mp4:
206
+ if isinstance(rendered_normals, torch.Tensor):
207
+ rendered_normals = rendered_normals.detach().cpu().numpy()
208
+ create_mp4_from_images(
209
+ rendered_normals,
210
+ output_path=f"{output_dir}/normal.mp4",
211
+ fps=15,
212
+ prompt=prompt,
213
+ )
214
+ else:
215
+ render_paths = save_images(
216
+ rendered_normals,
217
+ f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
218
+ cvt_color=cv2.COLOR_BGR2RGB,
219
+ )
220
+ data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
221
+
222
+ if RenderItems.VIEW_NORMAL.value in self.render_items:
223
+ assert (
224
+ RenderItems.GLOBAL_NORMAL in self.render_items
225
+ ), f"Must render global normal firstly, got render_items: {self.render_items}." # noqa
226
+ rendered_view_normals = self.renderer.transform_normal(
227
+ rendered_normals, self.mv, masks, to_view=True
228
+ )
229
+ # rendered_inv_view_normals = renderer.transform_normal(rendered_view_normals, torch.linalg.inv(mv), masks, to_view=False) # noqa
230
+ if self.gen_viewnormal_mp4:
231
+ create_mp4_from_images(
232
+ rendered_view_normals,
233
+ output_path=f"{output_dir}/view_normal.mp4",
234
+ fps=15,
235
+ prompt=prompt,
236
+ )
237
+ else:
238
+ render_paths = save_images(
239
+ rendered_view_normals,
240
+ f"{output_dir}/{RenderItems.VIEW_NORMAL}",
241
+ cvt_color=cv2.COLOR_BGR2RGB,
242
+ )
243
+ data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
244
+
245
+ if RenderItems.POSITION_MAP.value in self.render_items:
246
+ rendered_position, masks = self.renderer.render_position(
247
+ vertices, faces
248
+ )
249
+ norm_position = self.renderer.normalize_map_by_mask(
250
+ rendered_position, masks
251
+ )
252
+ render_paths = save_images(
253
+ norm_position,
254
+ f"{output_dir}/{RenderItems.POSITION_MAP}",
255
+ cvt_color=cv2.COLOR_BGR2RGB,
256
+ )
257
+ data_dict[RenderItems.POSITION_MAP.value] = render_paths
258
+
259
+ if RenderItems.DEPTH.value in self.render_items:
260
+ rendered_depth, masks = self.renderer.render_depth(vertices, faces)
261
+ norm_depth = self.renderer.normalize_map_by_mask(
262
+ rendered_depth, masks
263
+ )
264
+ render_paths = save_images(
265
+ norm_depth,
266
+ f"{output_dir}/{RenderItems.DEPTH}",
267
+ )
268
+ data_dict[RenderItems.DEPTH.value] = render_paths
269
+
270
+ render_paths = save_images(
271
+ rendered_depth,
272
+ f"{output_dir}/{RenderItems.DEPTH}_exr",
273
+ to_uint8=False,
274
+ format=".exr",
275
+ )
276
+ data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
277
+
278
+ if RenderItems.IMAGE.value in self.render_items:
279
+ images = []
280
+ albedos = []
281
+ diffuses = []
282
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
283
+ try:
284
+ for idx, cam in enumerate(self.camera):
285
+ image, albedo, diffuse, _ = render_pbr(
286
+ mesh, cam, light_factor=self.light_factor
287
+ )
288
+ image = torch.cat([image[0], masks[idx]], axis=-1)
289
+ images.append(image.detach().cpu().numpy())
290
+
291
+ if RenderItems.ALBEDO.value in self.render_items:
292
+ albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
293
+ albedos.append(albedo.detach().cpu().numpy())
294
+
295
+ if RenderItems.DIFFUSE.value in self.render_items:
296
+ diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
297
+ diffuses.append(diffuse.detach().cpu().numpy())
298
+
299
+ except Exception as e:
300
+ logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
301
+ return
302
+
303
+ if self.gen_color_gif:
304
+ create_gif_from_images(
305
+ images,
306
+ output_path=f"{output_dir}/color.gif",
307
+ fps=15,
308
+ )
309
+
310
+ if self.gen_color_mp4:
311
+ create_mp4_from_images(
312
+ images,
313
+ output_path=f"{output_dir}/color.mp4",
314
+ fps=15,
315
+ prompt=prompt,
316
+ )
317
+
318
+ if self.gen_color_mp4 or self.gen_color_gif:
319
+ return data_dict
320
+
321
+ render_paths = save_images(
322
+ images,
323
+ f"{output_dir}/{RenderItems.IMAGE}",
324
+ cvt_color=cv2.COLOR_BGRA2RGBA,
325
+ )
326
+ data_dict[RenderItems.IMAGE.value] = render_paths
327
+
328
+ render_paths = save_images(
329
+ albedos,
330
+ f"{output_dir}/{RenderItems.ALBEDO}",
331
+ cvt_color=cv2.COLOR_BGRA2RGBA,
332
+ )
333
+ data_dict[RenderItems.ALBEDO.value] = render_paths
334
+
335
+ render_paths = save_images(
336
+ diffuses,
337
+ f"{output_dir}/{RenderItems.DIFFUSE}",
338
+ cvt_color=cv2.COLOR_BGRA2RGBA,
339
+ )
340
+ data_dict[RenderItems.DIFFUSE.value] = render_paths
341
+
342
+ data_dict["status"] = "success"
343
+
344
+ logger.info(f"Finish rendering in {output_dir}")
345
+
346
+ return data_dict
347
+
348
+
349
+ def parse_args():
350
+ parser = argparse.ArgumentParser(description="Render settings")
351
+
352
+ parser.add_argument(
353
+ "--mesh_path",
354
+ type=str,
355
+ nargs="+",
356
+ required=True,
357
+ help="Paths to the mesh files for rendering.",
358
+ )
359
+ parser.add_argument(
360
+ "--output_root",
361
+ type=str,
362
+ required=True,
363
+ help="Root directory for output",
364
+ )
365
+ parser.add_argument(
366
+ "--uuid",
367
+ type=str,
368
+ nargs="+",
369
+ default=None,
370
+ help="uuid for rendering saving.",
371
+ )
372
+ parser.add_argument(
373
+ "--num_images", type=int, default=6, help="Number of images to render."
374
+ )
375
+ parser.add_argument(
376
+ "--elevation",
377
+ type=float,
378
+ nargs="+",
379
+ default=[20.0, -10.0],
380
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
381
+ )
382
+ parser.add_argument(
383
+ "--distance",
384
+ type=float,
385
+ default=5,
386
+ help="Camera distance (default: 5)",
387
+ )
388
+ parser.add_argument(
389
+ "--resolution_hw",
390
+ type=int,
391
+ nargs=2,
392
+ default=(512, 512),
393
+ help="Resolution of the output images (default: (512, 512))",
394
+ )
395
+ parser.add_argument(
396
+ "--fov",
397
+ type=float,
398
+ default=30,
399
+ help="Field of view in degrees (default: 30)",
400
+ )
401
+ parser.add_argument(
402
+ "--pbr_light_factor",
403
+ type=float,
404
+ default=1.0,
405
+ help="Light factor for mesh PBR rendering (default: 2.)",
406
+ )
407
+ parser.add_argument(
408
+ "--device",
409
+ type=str,
410
+ choices=["cpu", "cuda"],
411
+ default="cuda",
412
+ help="Device to run on (default: 'cuda')",
413
+ )
414
+ parser.add_argument(
415
+ "--with_mtl",
416
+ action="store_true",
417
+ help="Whether to render with mesh material.",
418
+ )
419
+ parser.add_argument(
420
+ "--gen_color_gif",
421
+ action="store_true",
422
+ help="Whether to generate color .gif rendering file.",
423
+ )
424
+ parser.add_argument(
425
+ "--gen_color_mp4",
426
+ action="store_true",
427
+ help="Whether to generate color .mp4 rendering file.",
428
+ )
429
+ parser.add_argument(
430
+ "--gen_viewnormal_mp4",
431
+ action="store_true",
432
+ help="Whether to generate view normal .mp4 rendering file.",
433
+ )
434
+ parser.add_argument(
435
+ "--gen_glonormal_mp4",
436
+ action="store_true",
437
+ help="Whether to generate global normal .mp4 rendering file.",
438
+ )
439
+ parser.add_argument(
440
+ "--prompts",
441
+ type=str,
442
+ nargs="+",
443
+ default=None,
444
+ help="Text prompts for the rendering.",
445
+ )
446
+
447
+ args = parser.parse_args()
448
+
449
+ if args.uuid is None:
450
+ args.uuid = []
451
+ for path in args.mesh_path:
452
+ uuid = os.path.basename(path).split(".")[0]
453
+ args.uuid.append(uuid)
454
+
455
+ return args
456
+
457
+
458
+ def entrypoint() -> None:
459
+ args = parse_args()
460
+
461
+ camera_settings = CameraSetting(
462
+ num_images=args.num_images,
463
+ elevation=args.elevation,
464
+ distance=args.distance,
465
+ resolution_hw=args.resolution_hw,
466
+ fov=math.radians(args.fov),
467
+ device=args.device,
468
+ )
469
+
470
+ render_items = [
471
+ RenderItems.ALPHA.value,
472
+ RenderItems.GLOBAL_NORMAL.value,
473
+ RenderItems.VIEW_NORMAL.value,
474
+ RenderItems.POSITION_MAP.value,
475
+ RenderItems.IMAGE.value,
476
+ RenderItems.DEPTH.value,
477
+ # RenderItems.ALBEDO.value,
478
+ # RenderItems.DIFFUSE.value,
479
+ ]
480
+
481
+ gen_video = (
482
+ args.gen_color_gif
483
+ or args.gen_color_mp4
484
+ or args.gen_viewnormal_mp4
485
+ or args.gen_glonormal_mp4
486
+ )
487
+ if gen_video:
488
+ render_items = []
489
+ if args.gen_color_gif or args.gen_color_mp4:
490
+ render_items.append(RenderItems.IMAGE.value)
491
+ if args.gen_glonormal_mp4:
492
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
493
+ if args.gen_viewnormal_mp4:
494
+ render_items.append(RenderItems.VIEW_NORMAL.value)
495
+ if RenderItems.GLOBAL_NORMAL.value not in render_items:
496
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
497
+
498
+ image_render = ImageRender(
499
+ render_items=render_items,
500
+ camera_params=camera_settings,
501
+ with_mtl=args.with_mtl,
502
+ gen_color_gif=args.gen_color_gif,
503
+ gen_color_mp4=args.gen_color_mp4,
504
+ gen_viewnormal_mp4=args.gen_viewnormal_mp4,
505
+ gen_glonormal_mp4=args.gen_glonormal_mp4,
506
+ light_factor=args.pbr_light_factor,
507
+ no_index_file=gen_video,
508
+ )
509
+ image_render.render_mesh(
510
+ mesh_path=args.mesh_path,
511
+ output_root=args.output_root,
512
+ uuid=args.uuid,
513
+ prompts=args.prompts,
514
+ )
515
+
516
+ return
517
+
518
+
519
+ if __name__ == "__main__":
520
+ entrypoint()
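For reference, a minimal programmatic equivalent of entrypoint() could look like the sketch below (illustrative only, not part of the commit): the import paths follow this repository layout, the mesh and output paths are placeholders, and the values simply restate the argparse defaults defined above. The same pipeline can also be driven from the command line through parse_args()/entrypoint().

    import math

    from asset3d_gen.data.differentiable_render import ImageRender
    from asset3d_gen.data.utils import CameraSetting, RenderItems

    # Camera ring: 6 views split over two elevations, 512x512, 30 deg fov.
    camera = CameraSetting(
        num_images=6,
        elevation=[20.0, -10.0],
        distance=5,
        resolution_hw=(512, 512),
        fov=math.radians(30),
        device="cuda",
    )
    render = ImageRender(
        render_items=[RenderItems.IMAGE.value, RenderItems.ALPHA.value],
        camera_params=camera,
        with_mtl=True,
    )
    render.render_mesh(
        mesh_path="/path/to/mesh.obj",   # placeholder path
        output_root="/path/to/renders",  # placeholder path
    )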
asset3d_gen/data/mesh_operator.py ADDED
@@ -0,0 +1,425 @@
1
+ import logging
2
+ from typing import Tuple, Union
3
+
4
+ import igraph
5
+ import numpy as np
6
+ import pyvista as pv
7
+ import torch
8
+ import utils3d
9
+ from pymeshfix import _meshfix
10
+ from tqdm import tqdm
11
+
12
+ logging.basicConfig(
13
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
14
+ )
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ __all__ = ["MeshFixer"]
19
+
20
+
21
+ def radical_inverse(base, n):
22
+ val = 0
23
+ inv_base = 1.0 / base
24
+ inv_base_n = inv_base
25
+ while n > 0:
26
+ digit = n % base
27
+ val += digit * inv_base_n
28
+ n //= base
29
+ inv_base_n *= inv_base
30
+ return val
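As a quick sanity check of radical_inverse (not part of the commit): for base 2 and n = 5 (binary 101), the digits are mirrored about the radix point, giving 0.101 in base 2 = 0.5 + 0 + 0.125 = 0.625.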
31
+
32
+
33
+ def halton_sequence(dim, n):
34
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
35
+ return [radical_inverse(PRIMES[d], n) for d in range(dim)]
36
+
37
+
38
+ def hammersley_sequence(dim, n, num_samples):
39
+ return [n / num_samples] + halton_sequence(dim - 1, n)
40
+
41
+
42
+ def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
43
+ """Generate a point on a unit sphere using the Hammersley sequence.
44
+
45
+ Args:
46
+ n (int): The index of the sample.
47
+ num_samples (int): The total number of samples.
48
+ offset (tuple, optional): Offset for the u and v coordinates.
49
+ remap (bool, optional): Whether to remap the u coordinate.
50
+
51
+ Returns:
52
+ list: A list containing the spherical coordinates [phi, theta].
53
+ """
54
+ u, v = hammersley_sequence(2, n, num_samples)
55
+ u += offset[0] / num_samples
56
+ v += offset[1]
57
+
58
+ if remap:
59
+ u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
60
+
61
+ theta = np.arccos(1 - 2 * u) - np.pi / 2
62
+ phi = v * 2 * np.pi
63
+ return [phi, theta]
64
+
65
+
66
+ class MeshFixer(object):
67
+ """Reduce and postprocess 3D meshes, simplifying and filling holes."""
68
+
69
+ def __init__(
70
+ self,
71
+ vertices: Union[torch.Tensor, np.ndarray],
72
+ faces: Union[torch.Tensor, np.ndarray],
73
+ device: str = "cuda",
74
+ ) -> None:
75
+ self.device = device
76
+ self.vertices = (
77
+ torch.tensor(vertices, device=device)
78
+ if isinstance(vertices, np.ndarray)
79
+ else vertices.to(device)
80
+ )
81
+ self.faces = (
82
+ torch.tensor(faces.astype(np.int32), device=device)
83
+ if isinstance(faces, np.ndarray)
84
+ else faces.to(device)
85
+ )
86
+
87
+ @staticmethod
88
+ def log_mesh_changes(method):
89
+ def wrapper(self, *args, **kwargs):
90
+ logger.info(
91
+ f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
92
+ )
93
+ result = method(self, *args, **kwargs)
94
+ logger.info(
95
+ f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
96
+ )
97
+ return result
98
+
99
+ return wrapper
100
+
101
+ @log_mesh_changes
102
+ def fill_holes(
103
+ self,
104
+ max_hole_size: float,
105
+ max_hole_nbe: int,
106
+ resolution: int,
107
+ num_views: int,
108
+ norm_mesh_ratio: float = 1.0,
109
+ ) -> None:
110
+ self.vertices = self.vertices * norm_mesh_ratio
111
+ vertices, self.faces = self._fill_holes(
112
+ self.vertices,
113
+ self.faces,
114
+ max_hole_size,
115
+ max_hole_nbe,
116
+ resolution,
117
+ num_views,
118
+ )
119
+ self.vertices = vertices / norm_mesh_ratio
120
+
121
+ @staticmethod
122
+ @torch.no_grad()
123
+ def _fill_holes(
124
+ vertices: torch.Tensor,
125
+ faces: torch.Tensor,
126
+ max_hole_size: float,
127
+ max_hole_nbe: int,
128
+ resolution: int,
129
+ num_views: int,
130
+ ) -> Union[torch.Tensor, torch.Tensor]:
131
+ yaws, pitchs = [], []
132
+ for i in range(num_views):
133
+ y, p = sphere_hammersley_sequence(i, num_views)
134
+ yaws.append(y)
135
+ pitchs.append(p)
136
+
137
+ yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
138
+ pitchs
139
+ ).to(vertices)
140
+ radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
141
+ projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
142
+
143
+ views = []
144
+ for yaw, pitch in zip(yaws, pitchs):
145
+ orig = (
146
+ torch.tensor(
147
+ [
148
+ torch.sin(yaw) * torch.cos(pitch),
149
+ torch.cos(yaw) * torch.cos(pitch),
150
+ torch.sin(pitch),
151
+ ]
152
+ ).to(vertices)
153
+ * radius
154
+ )
155
+ view = utils3d.torch.view_look_at(
156
+ orig,
157
+ torch.tensor([0, 0, 0]).to(vertices),
158
+ torch.tensor([0, 0, 1]).to(vertices),
159
+ )
160
+ views.append(view)
161
+ views = torch.stack(views, dim=0)
162
+
163
+ # Rasterize the mesh
164
+ visibility = torch.zeros(
165
+ faces.shape[0], dtype=torch.int32, device=faces.device
166
+ )
167
+ rastctx = utils3d.torch.RastContext(backend="cuda")
168
+
169
+ for i in tqdm(
170
+ range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
171
+ ):
172
+ view = views[i]
173
+ buffers = utils3d.torch.rasterize_triangle_faces(
174
+ rastctx,
175
+ vertices[None],
176
+ faces,
177
+ resolution,
178
+ resolution,
179
+ view=view,
180
+ projection=projection,
181
+ )
182
+ face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
183
+ face_id = torch.unique(face_id).long()
184
+ visibility[face_id] += 1
185
+
186
+ # Normalize visibility by the number of views
187
+ visibility = visibility.float() / num_views
188
+
189
+ # Mincut: Identify outer and inner faces
190
+ edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
191
+ boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
192
+ connected_components = utils3d.torch.compute_connected_components(
193
+ faces, edges, face2edge
194
+ )
195
+
196
+ outer_face_indices = torch.zeros(
197
+ faces.shape[0], dtype=torch.bool, device=faces.device
198
+ )
199
+ for i in range(len(connected_components)):
200
+ outer_face_indices[connected_components[i]] = visibility[
201
+ connected_components[i]
202
+ ] > min(
203
+ max(
204
+ visibility[connected_components[i]].quantile(0.75).item(),
205
+ 0.25,
206
+ ),
207
+ 0.5,
208
+ )
209
+
210
+ outer_face_indices = outer_face_indices.nonzero().reshape(-1)
211
+ inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
212
+
213
+ if inner_face_indices.shape[0] == 0:
214
+ return vertices, faces
215
+
216
+ # Construct dual graph (faces as nodes, edges as edges)
217
+ dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
218
+ face2edge
219
+ )
220
+ dual_edge2edge = edges[dual_edge2edge]
221
+ dual_edges_weights = torch.norm(
222
+ vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
223
+ dim=1,
224
+ )
225
+
226
+ # Mincut: Construct main graph and solve the mincut problem
227
+ g = igraph.Graph()
228
+ g.add_vertices(faces.shape[0])
229
+ g.add_edges(dual_edges.cpu().numpy())
230
+ g.es["weight"] = dual_edges_weights.cpu().numpy()
231
+
232
+ g.add_vertex("s") # source
233
+ g.add_vertex("t") # target
234
+
235
+ g.add_edges(
236
+ [(f, "s") for f in inner_face_indices],
237
+ attributes={
238
+ "weight": torch.ones(
239
+ inner_face_indices.shape[0], dtype=torch.float32
240
+ )
241
+ .cpu()
242
+ .numpy()
243
+ },
244
+ )
245
+ g.add_edges(
246
+ [(f, "t") for f in outer_face_indices],
247
+ attributes={
248
+ "weight": torch.ones(
249
+ outer_face_indices.shape[0], dtype=torch.float32
250
+ )
251
+ .cpu()
252
+ .numpy()
253
+ },
254
+ )
255
+
256
+ cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
257
+ remove_face_indices = torch.tensor(
258
+ [v for v in cut.partition[0] if v < faces.shape[0]],
259
+ dtype=torch.long,
260
+ device=faces.device,
261
+ )
262
+
263
+ # Check if the cut is valid with each connected component
264
+ to_remove_cc = utils3d.torch.compute_connected_components(
265
+ faces[remove_face_indices]
266
+ )
267
+ valid_remove_cc = []
268
+ cutting_edges = []
269
+ for cc in to_remove_cc:
270
+ # Check visibility median for connected component
271
+ visibility_median = visibility[remove_face_indices[cc]].median()
272
+ if visibility_median > 0.25:
273
+ continue
274
+
275
+ # Check if the cutting loop is small enough
276
+ cc_edge_indices, cc_edges_degree = torch.unique(
277
+ face2edge[remove_face_indices[cc]], return_counts=True
278
+ )
279
+ cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
280
+ cc_new_boundary_edge_indices = cc_boundary_edge_indices[
281
+ ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
282
+ ]
283
+ if len(cc_new_boundary_edge_indices) > 0:
284
+ cc_new_boundary_edge_cc = (
285
+ utils3d.torch.compute_edge_connected_components(
286
+ edges[cc_new_boundary_edge_indices]
287
+ )
288
+ )
289
+ cc_new_boundary_edges_cc_center = [
290
+ vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
291
+ .mean(dim=1)
292
+ .mean(dim=0)
293
+ for edge_cc in cc_new_boundary_edge_cc
294
+ ]
295
+ cc_new_boundary_edges_cc_area = []
296
+ for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
297
+ _e1 = (
298
+ vertices[
299
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
300
+ ]
301
+ - cc_new_boundary_edges_cc_center[i]
302
+ )
303
+ _e2 = (
304
+ vertices[
305
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
306
+ ]
307
+ - cc_new_boundary_edges_cc_center[i]
308
+ )
309
+ cc_new_boundary_edges_cc_area.append(
310
+ torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
311
+ * 0.5
312
+ )
313
+ cutting_edges.append(cc_new_boundary_edge_indices)
314
+ if any(
315
+ [
316
+ _l > max_hole_size
317
+ for _l in cc_new_boundary_edges_cc_area
318
+ ]
319
+ ):
320
+ continue
321
+
322
+ valid_remove_cc.append(cc)
323
+
324
+ if len(valid_remove_cc) > 0:
325
+ remove_face_indices = remove_face_indices[
326
+ torch.cat(valid_remove_cc)
327
+ ]
328
+ mask = torch.ones(
329
+ faces.shape[0], dtype=torch.bool, device=faces.device
330
+ )
331
+ mask[remove_face_indices] = 0
332
+ faces = faces[mask]
333
+ faces, vertices = utils3d.torch.remove_unreferenced_vertices(
334
+ faces, vertices
335
+ )
336
+
337
+ tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
338
+ else:
339
+ tqdm.write(f"Removed 0 faces by mincut")
340
+
341
+ # Fill small boundaries (holes)
342
+ mesh = _meshfix.PyTMesh()
343
+ mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
344
+ mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
345
+
346
+ _vertices, _faces = mesh.return_arrays()
347
+ vertices = torch.tensor(_vertices).to(vertices)
348
+ faces = torch.tensor(_faces).to(faces)
349
+
350
+ return vertices, faces
351
+
352
+ @property
353
+ def vertices_np(self) -> np.ndarray:
354
+ return self.vertices.cpu().numpy()
355
+
356
+ @property
357
+ def faces_np(self) -> np.ndarray:
358
+ return self.faces.cpu().numpy()
359
+
360
+ @log_mesh_changes
361
+ def simplify(self, ratio: float) -> None:
362
+ """Simplify the mesh using quadric edge collapse decimation.
363
+
364
+ Args:
365
+ ratio (float): Ratio of faces to filter out.
366
+ """
367
+ if ratio <= 0 or ratio >= 1:
368
+ raise ValueError("Simplify ratio must be between 0 and 1.")
369
+
370
+ # Convert to PyVista format for simplification
371
+ mesh = pv.PolyData(
372
+ self.vertices_np,
373
+ np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
374
+ )
375
+ mesh = mesh.decimate(ratio, progress_bar=True)
376
+
377
+ # Update vertices and faces
378
+ self.vertices = torch.tensor(
379
+ mesh.points, device=self.device, dtype=torch.float32
380
+ )
381
+ self.faces = torch.tensor(
382
+ mesh.faces.reshape(-1, 4)[:, 1:],
383
+ device=self.device,
384
+ dtype=torch.int32,
385
+ )
386
+
387
+ def __call__(
388
+ self,
389
+ filter_ratio: float,
390
+ max_hole_size: float,
391
+ resolution: int,
392
+ num_views: int,
393
+ norm_mesh_ratio: float = 1.0,
394
+ ) -> Tuple[np.ndarray, np.ndarray]:
395
+ """Post-process the mesh by simplifying and filling holes.
396
+
397
+ This method performs a two-step process:
398
+ 1. Simplifies mesh by reducing faces using quadric edge decimation.
399
+ 2. Fills holes by removing invisible faces, repairing small boundaries.
400
+
401
+ Args:
402
+ filter_ratio (float): Ratio of faces to simplify out.
403
+ Must be in the range (0, 1).
404
+ max_hole_size (float): Maximum area of a hole to fill. Connected
405
+ components of holes larger than this size will not be repaired.
406
+ resolution (int): Resolution of the rasterization buffer.
407
+ num_views (int): Number of viewpoints to sample for rasterization.
408
+ norm_mesh_ratio (float, optional): A scaling factor applied to the
409
+ vertices of the mesh during processing.
410
+
411
+ Returns:
412
+ Tuple[np.ndarray, np.ndarray]:
413
+ - vertices: Simplified and repaired vertex array of (V, 3).
414
+ - faces: Simplified and repaired face array of (F, 3).
415
+ """
416
+ self.simplify(ratio=filter_ratio)
417
+ self.fill_holes(
418
+ max_hole_size=max_hole_size,
419
+ max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
420
+ resolution=resolution,
421
+ num_views=num_views,
422
+ norm_mesh_ratio=norm_mesh_ratio,
423
+ )
424
+
425
+ return self.vertices_np, self.faces_np
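A minimal usage sketch of MeshFixer (illustrative only, not part of the commit; the paths and parameter values are placeholders):

    import trimesh

    from asset3d_gen.data.mesh_operator import MeshFixer

    mesh = trimesh.load("/path/to/mesh.obj", force="mesh")
    fixer = MeshFixer(mesh.vertices, mesh.faces, device="cuda")
    vertices, faces = fixer(
        filter_ratio=0.5,    # decimate roughly half of the faces first
        max_hole_size=0.04,  # holes above this boundary-loop area stay open
        resolution=1024,     # rasterization buffer for the visibility pass
        num_views=100,       # Hammersley viewpoints for the visibility pass
    )
    trimesh.Trimesh(vertices, faces, process=False).export("/path/to/mesh_fixed.obj")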
asset3d_gen/data/utils.py ADDED
@@ -0,0 +1,943 @@
1
+ import math
2
+ import os
3
+ import random
4
+ from glob import glob
5
+ from typing import List, Tuple, Union
6
+
7
+ import cv2
8
+ import kaolin as kal
9
+ import numpy as np
10
+ import nvdiffrast.torch as dr
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from PIL import Image
14
+
15
+ try:
16
+ from kolors.models.modeling_chatglm import ChatGLMModel
17
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
18
+ except ImportError:
19
+ ChatGLMTokenizer = None
20
+ ChatGLMModel = None
21
+ import logging
22
+ from dataclasses import dataclass, field
23
+ from enum import Enum
24
+
25
+ import trimesh
26
+ from kaolin.render.camera import Camera
27
+ from torch import nn
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ __all__ = [
33
+ "center_points",
34
+ "get_points_stat",
35
+ "DiffrastRender",
36
+ "compute_cam_pts_by_az_el",
37
+ "compute_cam_pts_by_views",
38
+ "save_images",
39
+ "render_pbr",
40
+ "load_llm_models",
41
+ "prelabel_text_feature",
42
+ "calc_vertex_normals",
43
+ "normalize_vertices_array",
44
+ "load_mesh_to_unit_cube",
45
+ "as_list",
46
+ "CameraSetting",
47
+ "RenderItems",
48
+ "import_kaolin_mesh",
49
+ "save_mesh_with_mtl",
50
+ "get_images_from_grid",
51
+ "post_process_texture",
52
+ ]
53
+
54
+
55
+ def get_points_stat(
56
+ points: torch.FloatTensor, eps: float = 1e-6
57
+ ) -> torch.FloatTensor:
58
+ assert (
59
+ len(points.shape) == 3
60
+ ), f"Points have unexpected shape {points.shape}"
61
+
62
+ vmin = points.min(dim=1, keepdim=True)[0]
63
+ vmax = points.max(dim=1, keepdim=True)[0]
64
+ pts_center = (vmin + vmax) / 2
65
+
66
+ pts_dim = (vmax - vmin).max(dim=-1, keepdim=True)[0].clip(min=eps)
67
+
68
+ return pts_center, pts_dim
69
+
70
+
71
+ def center_points(
72
+ points: torch.FloatTensor, normalize: bool = False, eps: float = 1e-6
73
+ ) -> torch.FloatTensor:
74
+ vmid, den = get_points_stat(points)
75
+
76
+ res = points - vmid
77
+
78
+ if normalize:
79
+ res = res / den
80
+
81
+ return res
82
+
83
+
84
+ class DiffrastRender(object):
85
+ """A class to handle differentiable rendering using nvdiffrast.
86
+
87
+ This class provides methods to render position, depth, and normal maps
88
+ with optional anti-aliasing and gradient disabling for rasterization.
89
+
90
+ Attributes:
91
+ p_mtx (torch.Tensor): Projection matrix.
92
+ mv_mtx (torch.Tensor): Model-view matrix.
93
+ mvp_mtx (torch.Tensor): Model-view-projection matrix, calculated as
94
+ p_mtx @ mv_mtx if not provided.
95
+ resolution_hw (Tuple[int, int]): Height and width of the rendering resolution. # noqa
96
+ _ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): Rasterization context. # noqa
97
+ mask_thresh (float): Threshold for mask creation.
98
+ grad_db (bool): Whether to disable gradients during rasterization.
99
+ antialias_mask (bool): Whether to apply anti-aliasing to the mask.
100
+ device (str): Device used for rendering ('cuda' or 'cpu').
101
+
102
+ Methods:
103
+ _warmup(glctx): Warmup rasterization by rendering a simple triangle.
104
+ compute_dr_raster(vertices, faces): Rasterizes the mesh and returns
105
+ rasterized outputs and transformed vertices.
106
+ transform_vertices(vertices, matrix): Transforms the vertices using
107
+ the provided transformation matrix.
108
+ normalize_map_by_mask_separately(map, mask): Normalizes each map in
109
+ the batch separately using the mask.
110
+ normalize_map_by_mask(map, mask): Normalizes the entire map using the
111
+ mask, keeping the output in the range [0, 1].
112
+ render_position(vertices, faces): Renders the position map and
113
+ alpha mask from the given vertices and faces.
114
+ render_depth(vertices, faces): Renders the depth map and alpha
115
+ mask from the given vertices and faces.
116
+ _compute_mask(rast, vertices_clip, faces): Computes the mask from the
117
+ rasterization output.
118
+ render_global_normal(vertices, faces, vertice_normals): Renders the
119
+ normal map and alpha mask from the given vertices, faces, and
120
+ vertex normals.
121
+ transform_normal_to_view(normals, mat_w2c, masks): Transforms the normals
122
+ to the view space using the world-to-camera matrix.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ p_matrix: torch.Tensor,
128
+ mv_matrix: torch.Tensor,
129
+ resolution_hw: Tuple[int, int],
130
+ context: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None,
131
+ mvp_matrix: torch.Tensor = None,
132
+ mask_thresh: float = 0.5,
133
+ grad_db: bool = False,
134
+ antialias_mask: bool = True,
135
+ align_coordinate: bool = True,
136
+ device: str = "cuda",
137
+ ) -> None:
138
+ self.p_mtx = p_matrix
139
+ self.mv_mtx = mv_matrix
140
+ if mvp_matrix is None:
141
+ self.mvp_mtx = torch.bmm(p_matrix, mv_matrix)
142
+
143
+ self.resolution_hw = resolution_hw
144
+ if context is None:
145
+ context = dr.RasterizeCudaContext(device=device)
146
+ self._ctx = context
147
+ self.mask_thresh = mask_thresh
148
+ self.grad_db = grad_db
149
+ self.antialias_mask = antialias_mask
150
+ self.align_coordinate = align_coordinate
151
+ self.device = device
152
+ # self._warmup(self._ctx)
153
+
154
+ def _warmup(self, glctx):
155
+ # Seem solved. https://github.com/NVlabs/nvdiffrast/issues/59
156
+ def tensor(*args, **kwargs):
157
+ return torch.tensor(*args, device=self.device, **kwargs)
158
+
159
+ pos = tensor(
160
+ [[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]],
161
+ dtype=torch.float32,
162
+ )
163
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
164
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
165
+
166
+ def compute_dr_raster(
167
+ self,
168
+ vertices: torch.Tensor,
169
+ faces: torch.Tensor,
170
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
171
+ vertices_clip = self.transform_vertices(vertices, matrix=self.mvp_mtx)
172
+ rast, _ = dr.rasterize(
173
+ self._ctx,
174
+ vertices_clip,
175
+ faces.int(),
176
+ resolution=self.resolution_hw,
177
+ grad_db=self.grad_db,
178
+ )
179
+
180
+ return rast, vertices_clip
181
+
182
+ def transform_vertices(
183
+ self,
184
+ vertices: torch.Tensor,
185
+ matrix: torch.Tensor,
186
+ ) -> torch.Tensor:
187
+ verts_ones = torch.ones((len(vertices), 1)).to(vertices)
188
+ verts_homo = torch.cat([vertices, verts_ones], dim=-1)
189
+ trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))
190
+
191
+ return trans_vertices
192
+
193
+ def normalize_map_by_mask_separately(
194
+ self, map: torch.Tensor, mask: torch.Tensor
195
+ ) -> torch.Tensor:
196
+ # Normalize each map separately by mask, normalized map in [0, 1].
197
+ normalized_maps = []
198
+ for map_item, mask_item in zip(map, mask):
199
+ normalized_map = self.normalize_map_by_mask(map_item, mask_item)
200
+ normalized_maps.append(normalized_map)
201
+
202
+ normalized_maps = torch.stack(normalized_maps, dim=0)
203
+
204
+ return normalized_maps
205
+
206
+ def normalize_map_by_mask(
207
+ self, map: torch.Tensor, mask: torch.Tensor
208
+ ) -> torch.Tensor:
209
+ # Normalize all maps in total by mask, normalized map in [0, 1].
210
+ foreground = (mask == 1).squeeze(dim=-1)
211
+ foreground_elements = map[foreground]
212
+ if len(foreground_elements) == 0:
213
+ return map
214
+
215
+ min_val, _ = foreground_elements.min(dim=0)
216
+ max_val, _ = foreground_elements.max(dim=0)
217
+ val_range = (max_val - min_val).clip(min=1e-6)
218
+
219
+ normalized_map = (map - min_val) / val_range
220
+ normalized_map = torch.lerp(
221
+ torch.zeros_like(normalized_map), normalized_map, mask
222
+ )
223
+ normalized_map[normalized_map < 0] = 0
224
+
225
+ return normalized_map
226
+
227
+ def _compute_mask(
228
+ self,
229
+ rast: torch.Tensor,
230
+ vertices_clip: torch.Tensor,
231
+ faces: torch.Tensor,
232
+ ) -> torch.Tensor:
233
+ mask = (rast[..., 3:] > 0).float()
234
+ mask = mask.clip(min=0, max=1)
235
+
236
+ if self.antialias_mask is True:
237
+ mask = dr.antialias(mask, rast, vertices_clip, faces)
238
+ else:
239
+ foreground = mask > self.mask_thresh
240
+ mask[foreground] = 1
241
+ mask[~foreground] = 0
242
+
243
+ return mask
244
+
245
+ def render_rast_alpha(
246
+ self,
247
+ vertices: torch.Tensor,
248
+ faces: torch.Tensor,
249
+ ):
250
+ faces = faces.to(torch.int32)
251
+ rast, vertices_clip = self.compute_dr_raster(vertices, faces)
252
+ mask = self._compute_mask(rast, vertices_clip, faces)
253
+
254
+ return mask, rast
255
+
256
+ def render_position(
257
+ self,
258
+ vertices: torch.Tensor,
259
+ faces: torch.Tensor,
260
+ ) -> Union[torch.Tensor, torch.Tensor]:
261
+ # Vertices in model coordinate system, real position coordinate number.
262
+ faces = faces.to(torch.int32)
263
+ mask, rast = self.render_rast_alpha(vertices, faces)
264
+
265
+ vertices_model = vertices[None, ...].contiguous().float()
266
+ position_map, _ = dr.interpolate(vertices_model, rast, faces)
267
+ # Align with blender.
268
+ if self.align_coordinate:
269
+ position_map = position_map[..., [0, 2, 1]]
270
+ position_map[..., 1] = -position_map[..., 1]
271
+
272
+ position_map = torch.lerp(
273
+ torch.zeros_like(position_map), position_map, mask
274
+ )
275
+
276
+ return position_map, mask
277
+
278
+ def render_uv(
279
+ self,
280
+ vertices: torch.Tensor,
281
+ faces: torch.Tensor,
282
+ vtx_uv: torch.Tensor,
283
+ ) -> Union[torch.Tensor, torch.Tensor]:
284
+ faces = faces.to(torch.int32)
285
+ mask, rast = self.render_rast_alpha(vertices, faces)
286
+ uv_map, _ = dr.interpolate(vtx_uv, rast, faces)
287
+ uv_map = torch.lerp(torch.zeros_like(uv_map), uv_map, mask)
288
+
289
+ return uv_map, mask
290
+
291
+ def render_depth(
292
+ self,
293
+ vertices: torch.Tensor,
294
+ faces: torch.Tensor,
295
+ ) -> Union[torch.Tensor, torch.Tensor]:
296
+ # Vertices in model coordinate system, real depth coordinate number.
297
+ faces = faces.to(torch.int32)
298
+ mask, rast = self.render_rast_alpha(vertices, faces)
299
+
300
+ vertices_camera = self.transform_vertices(vertices, matrix=self.mv_mtx)
301
+ vertices_camera = vertices_camera[..., 2:3].contiguous().float()
302
+ depth_map, _ = dr.interpolate(vertices_camera, rast, faces)
303
+ # Change camera depth minus to positive.
304
+ if self.align_coordinate:
305
+ depth_map = -depth_map
306
+ depth_map = torch.lerp(torch.zeros_like(depth_map), depth_map, mask)
307
+
308
+ return depth_map, mask
309
+
310
+ def render_global_normal(
311
+ self,
312
+ vertices: torch.Tensor,
313
+ faces: torch.Tensor,
314
+ vertice_normals: torch.Tensor,
315
+ ) -> Union[torch.Tensor, torch.Tensor]:
316
+ # NOTE: vertice_normals in [-1, 1], return normal in [0, 1].
317
+ # vertices / vertice_normals in model coordinate system.
318
+ faces = faces.to(torch.int32)
319
+ mask, rast = self.render_rast_alpha(vertices, faces)
320
+ im_base_normals, _ = dr.interpolate(
321
+ vertice_normals[None, ...].float(), rast, faces
322
+ )
323
+
324
+ if im_base_normals is not None:
325
+ faces = faces.to(torch.int64)
326
+ vertices_cam = self.transform_vertices(
327
+ vertices, matrix=self.mv_mtx
328
+ )
329
+ face_vertices_ndc = kal.ops.mesh.index_vertices_by_faces(
330
+ vertices_cam[..., :3], faces
331
+ )
332
+ face_normal_sign = kal.ops.mesh.face_normals(face_vertices_ndc)[
333
+ ..., 2
334
+ ]
335
+ for idx in range(len(im_base_normals)):
336
+ face_idx = (rast[idx, ..., -1].long() - 1).contiguous()
337
+ im_normal_sign = torch.sign(face_normal_sign[idx, face_idx])
338
+ im_normal_sign[face_idx == -1] = 0
339
+ im_base_normals[idx] *= im_normal_sign.unsqueeze(-1)
340
+
341
+ normal = (im_base_normals + 1) / 2
342
+ normal = normal.clip(min=0, max=1)
343
+ normal = torch.lerp(torch.zeros_like(normal), normal, mask)
344
+
345
+ return normal, mask
346
+
347
+ def transform_normal(
348
+ self,
349
+ normals: torch.Tensor,
350
+ trans_matrix: torch.Tensor,
351
+ masks: torch.Tensor,
352
+ to_view: bool,
353
+ ) -> torch.Tensor:
354
+ # NOTE: input normals in [0, 1], output normals in [0, 1].
355
+ normals = normals.clone()
356
+ assert len(normals) == len(trans_matrix)
357
+
358
+ if not to_view:
359
+ # Flip the sign on the x-axis to match inv bae system for global transformation. # noqa
360
+ normals[..., 0] = 1 - normals[..., 0]
361
+
362
+ normals = 2 * normals - 1
363
+ b, h, w, c = normals.shape
364
+
365
+ transformed_normals = []
366
+ for normal, matrix in zip(normals, trans_matrix):
367
+ # Transform normals using the transformation matrix (4x4).
368
+ reshaped_normals = normal.view(-1, c) # (h w 3) -> (hw 3)
369
+ padded_vectors = torch.nn.functional.pad(
370
+ reshaped_normals, pad=(0, 1), mode="constant", value=0.0
371
+ )
372
+ transformed_normal = torch.matmul(
373
+ padded_vectors, matrix.transpose(0, 1)
374
+ )[..., :3]
375
+
376
+ # Normalize and clip the normals to [0, 1] range.
377
+ transformed_normal = F.normalize(transformed_normal, p=2, dim=-1)
378
+ transformed_normal = (transformed_normal + 1) / 2
379
+
380
+ if to_view:
381
+ # Flip the sign on the x-axis to match bae system for view transformation. # noqa
382
+ transformed_normal[..., 0] = 1 - transformed_normal[..., 0]
383
+
384
+ transformed_normals.append(transformed_normal.view(h, w, c))
385
+
386
+ transformed_normals = torch.stack(transformed_normals, dim=0)
387
+
388
+ if masks is not None:
389
+ transformed_normals = torch.lerp(
390
+ torch.zeros_like(transformed_normals),
391
+ transformed_normals,
392
+ masks,
393
+ )
394
+
395
+ return transformed_normals
396
+
397
+
398
+ def az_el_to_points(
399
+ azimuths: np.ndarray, elevations: np.ndarray
400
+ ) -> np.ndarray:
401
+ x = np.cos(azimuths) * np.cos(elevations)
402
+ y = np.sin(azimuths) * np.cos(elevations)
403
+ z = np.sin(elevations)
404
+
405
+ return np.stack([x, y, z], axis=-1)
406
+
407
+
408
+ def compute_az_el_by_views(
409
+ num_view: int, el: float
410
+ ) -> Tuple[np.ndarray, np.ndarray]:
411
+ azimuths = np.arange(num_view) / num_view * np.pi * 2
412
+ elevations = np.deg2rad(np.array([el] * num_view))
413
+
414
+ return azimuths, elevations
415
+
416
+
417
+ def compute_cam_pts_by_az_el(
418
+ azs: np.ndarray,
419
+ els: np.ndarray,
420
+ distance: float,
421
+ extra_pts: np.ndarray = None,
422
+ ) -> np.ndarray:
423
+ distances = np.array([distance for _ in range(len(azs))])
424
+ cam_pts = az_el_to_points(azs, els) * distances[:, None]
425
+
426
+ if extra_pts is not None:
427
+ cam_pts = np.concatenate([cam_pts, extra_pts], axis=0)
428
+
429
+ # Align coordinate system.
430
+ cam_pts = cam_pts[:, [0, 2, 1]] # xyz -> xzy
431
+ cam_pts[..., 2] = -cam_pts[..., 2]
432
+
433
+ return cam_pts
434
+
435
+
436
+ def compute_cam_pts_by_views(
437
+ num_view: int, el: float, distance: float, extra_pts: np.ndarray = None
438
+ ) -> torch.Tensor:
439
+ """Computes object-center camera points for a given number of views.
440
+
441
+ Args:
442
+ num_view (int): The number of views (camera positions) to compute.
443
+ el (float): The elevation angle in degrees.
444
+ distance (float): The distance from the origin to the camera.
445
+ extra_pts (np.ndarray): Extra camera point positions.
446
+
447
+ Returns:
448
+ torch.Tensor: A tensor containing the camera points for each view, with shape `(num_view, 3)`. # noqa
449
+ """
450
+ azimuths, elevations = compute_az_el_by_views(num_view, el)
451
+ cam_pts = compute_cam_pts_by_az_el(
452
+ azimuths, elevations, distance, extra_pts
453
+ )
454
+
455
+ return cam_pts
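A small illustration of the helper above (values are arbitrary):

    # Six object-centric camera positions on a ring at 20 deg elevation and
    # radius 5; the result has shape (6, 3) after the internal xyz -> xzy
    # swap and z negation that aligns with the rendering coordinate system.
    cam_pts = compute_cam_pts_by_views(num_view=6, el=20.0, distance=5.0)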
456
+
457
+
458
+ def save_images(
459
+ images: Union[list[np.ndarray], list[torch.Tensor]],
460
+ output_dir: str,
461
+ cvt_color: str = None,
462
+ format: str = ".png",
463
+ to_uint8: bool = True,
464
+ verbose: bool = False,
465
+ ) -> List[str]:
466
+ # NOTE: images in [0, 1]
467
+ os.makedirs(output_dir, exist_ok=True)
468
+ save_paths = []
469
+ for idx, image in enumerate(images):
470
+ if isinstance(image, torch.Tensor):
471
+ image = image.detach().cpu().numpy()
472
+ if to_uint8:
473
+ image = image.clip(min=0, max=1)
474
+ image = (255.0 * image).astype(np.uint8)
475
+ if cvt_color is not None:
476
+ image = cv2.cvtColor(image, cvt_color)
477
+ save_path = os.path.join(output_dir, f"{idx:04d}{format}")
478
+ save_paths.append(save_path)
479
+
480
+ cv2.imwrite(save_path, image)
481
+
482
+ if verbose:
483
+ logger.info(f"Images saved in {output_dir}")
484
+
485
+ return save_paths
486
+
487
+
488
+ def current_lighting(
489
+ azimuths: List[float],
490
+ elevations: List[float],
491
+ light_factor: float = 1.0,
492
+ device: str = "cuda",
493
+ ):
494
+ # azimuths, elevations in degrees.
495
+ directions = []
496
+ for az, el in zip(azimuths, elevations):
497
+ az, el = math.radians(az), math.radians(el)
498
+ direction = kal.render.lighting.sg_direction_from_azimuth_elevation(
499
+ az, el
500
+ )
501
+ directions.append(direction)
502
+ directions = torch.cat(directions, dim=0)
503
+
504
+ amplitude = torch.ones_like(directions) * light_factor
505
+ light_condition = kal.render.lighting.SgLightingParameters(
506
+ amplitude=amplitude,
507
+ direction=directions,
508
+ sharpness=3,
509
+ ).to(device)
510
+
511
+ # light_condition = kal.render.lighting.SgLightingParameters.from_sun(
512
+ # directions, strength=1, angle=90, color=None
513
+ # ).to(device)
514
+
515
+ return light_condition
516
+
517
+
518
+ def render_pbr(
519
+ mesh,
520
+ camera,
521
+ device="cuda",
522
+ cxt=None,
523
+ custom_materials=None,
524
+ light_factor=1.0,
525
+ ):
526
+ if cxt is None:
527
+ cxt = dr.RasterizeCudaContext()
528
+
529
+ light_condition = current_lighting(
530
+ azimuths=[0, 90, 180, 270],
531
+ elevations=[90, 60, 30, 20],
532
+ light_factor=light_factor,
533
+ device=device,
534
+ )
535
+ render_res = kal.render.easy_render.render_mesh(
536
+ camera,
537
+ mesh,
538
+ lighting=light_condition,
539
+ nvdiffrast_context=cxt,
540
+ custom_materials=custom_materials,
541
+ )
542
+
543
+ image = render_res[kal.render.easy_render.RenderPass.render]
544
+ image = image.clip(0, 1)
545
+
546
+ albedo = render_res[kal.render.easy_render.RenderPass.albedo]
547
+ albedo = albedo.clip(0, 1)
548
+
549
+ diffuse = render_res[kal.render.easy_render.RenderPass.diffuse]
550
+ diffuse = diffuse.clip(0, 1)
551
+
552
+ normal = render_res[kal.render.easy_render.RenderPass.normals]
553
+ normal = normal.clip(-1, 1)
554
+
555
+ return image, albedo, diffuse, normal
556
+
557
+
558
+ def load_saved_normal(path: str) -> np.ndarray:
559
+ image_paths = glob(os.path.join(path, "*.jpg"))
560
+ images = []
561
+ for path in sorted(image_paths):
562
+ image = cv2.imread(path)
563
+ image = image[..., ::-1] # bgr -> rgb
564
+ images.append(image)
565
+ images = np.stack(images, axis=0)
566
+
567
+ return images
568
+
569
+
570
+ def _move_to_target_device(data, device: str):
571
+ if isinstance(data, dict):
572
+ for key, value in data.items():
573
+ data[key] = _move_to_target_device(value, device)
574
+ elif isinstance(data, torch.Tensor):
575
+ return data.to(device)
576
+
577
+ return data
578
+
579
+
580
+ def _encode_prompt(
581
+ prompt_batch,
582
+ text_encoders,
583
+ tokenizers,
584
+ proportion_empty_prompts=0,
585
+ is_train=True,
586
+ ):
587
+ prompt_embeds_list = []
588
+
589
+ captions = []
590
+ for caption in prompt_batch:
591
+ if random.random() < proportion_empty_prompts:
592
+ captions.append("")
593
+ elif isinstance(caption, str):
594
+ captions.append(caption)
595
+ elif isinstance(caption, (list, np.ndarray)):
596
+ captions.append(random.choice(caption) if is_train else caption[0])
597
+
598
+ with torch.no_grad():
599
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
600
+ text_inputs = tokenizer(
601
+ captions,
602
+ padding="max_length",
603
+ max_length=256,
604
+ truncation=True,
605
+ return_tensors="pt",
606
+ ).to(text_encoder.device)
607
+
608
+ output = text_encoder(
609
+ input_ids=text_inputs.input_ids,
610
+ attention_mask=text_inputs.attention_mask,
611
+ position_ids=text_inputs.position_ids,
612
+ output_hidden_states=True,
613
+ )
614
+
615
+ # We are only interested in the pooled output of the text encoder.
616
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
617
+ pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
618
+ bs_embed, seq_len, _ = prompt_embeds.shape
619
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
620
+ prompt_embeds_list.append(prompt_embeds)
621
+
622
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
623
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
624
+
625
+ return prompt_embeds, pooled_prompt_embeds
626
+
627
+
628
+ def load_llm_models(pretrained_model_name_or_path: str, device: str):
629
+ tokenizer = ChatGLMTokenizer.from_pretrained(
630
+ pretrained_model_name_or_path,
631
+ subfolder="text_encoder",
632
+ )
633
+ text_encoder = ChatGLMModel.from_pretrained(
634
+ pretrained_model_name_or_path,
635
+ subfolder="text_encoder",
636
+ ).to(device)
637
+
638
+ text_encoders = [
639
+ text_encoder,
640
+ ]
641
+ tokenizers = [
642
+ tokenizer,
643
+ ]
644
+
645
+ logger.info(f"Load model from {pretrained_model_name_or_path} done.")
646
+
647
+ return tokenizers, text_encoders
648
+
649
+
650
+ def prelabel_text_feature(
651
+ prompt_batch: List[str],
652
+ output_dir: str,
653
+ tokenizers: nn.Module,
654
+ text_encoders: nn.Module,
655
+ ) -> List[str]:
656
+ os.makedirs(output_dir, exist_ok=True)
657
+
658
+ # prompt_batch ["text..."]
659
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt(
660
+ prompt_batch, text_encoders, tokenizers
661
+ )
662
+
663
+ prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
664
+ pooled_prompt_embeds = _move_to_target_device(
665
+ pooled_prompt_embeds, device="cpu"
666
+ )
667
+
668
+ data_dict = dict(
669
+ prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
670
+ )
671
+
672
+ save_path = os.path.join(output_dir, "text_feat.pth")
673
+ torch.save(data_dict, save_path)
674
+
675
+ return save_path
676
+
677
+
678
+ def calc_face_normals(
679
+ vertices: torch.Tensor, # V,3 first vertex may be unreferenced
680
+ faces: torch.Tensor, # F,3 long, first face may be all zero
681
+ normalize: bool = False,
682
+ ) -> torch.Tensor: # F,3
683
+ full_vertices = vertices[faces] # F,C=3,3
684
+ v0, v1, v2 = full_vertices.unbind(dim=1) # F,3
685
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=1) # F,3
686
+ if normalize:
687
+ face_normals = F.normalize(
688
+ face_normals, eps=1e-6, dim=1
689
+ ) # TODO inplace?
690
+ return face_normals # F,3
691
+
692
+
693
+ def calc_vertex_normals(
694
+ vertices: torch.Tensor, # V,3 first vertex may be unreferenced
695
+ faces: torch.Tensor, # F,3 long, first face may be all zero
696
+ face_normals: torch.Tensor = None, # F,3, not normalized
697
+ ) -> torch.Tensor: # F,3
698
+ _F = faces.shape[0]
699
+
700
+ if face_normals is None:
701
+ face_normals = calc_face_normals(vertices, faces)
702
+
703
+ vertex_normals = torch.zeros(
704
+ (vertices.shape[0], 3, 3), dtype=vertices.dtype, device=vertices.device
705
+ ) # V,C=3,3
706
+ vertex_normals.scatter_add_(
707
+ dim=0,
708
+ index=faces[:, :, None].expand(_F, 3, 3),
709
+ src=face_normals[:, None, :].expand(_F, 3, 3),
710
+ )
711
+ vertex_normals = vertex_normals.sum(dim=1) # V,3
712
+ return F.normalize(vertex_normals, eps=1e-6, dim=1)
713
+
714
+
715
+ def normalize_vertices_array(
716
+ vertices: Union[torch.Tensor, np.ndarray],
717
+ mesh_scale: float = 1.0,
718
+ exec_norm: bool = True,
719
+ ):
720
+ if isinstance(vertices, torch.Tensor):
721
+ bbmin, bbmax = vertices.min(0)[0], vertices.max(0)[0]
722
+ else:
723
+ bbmin, bbmax = vertices.min(0), vertices.max(0) # (3,)
724
+ center = (bbmin + bbmax) * 0.5
725
+ bbsize = bbmax - bbmin
726
+ scale = 2 * mesh_scale / bbsize.max()
727
+ if exec_norm:
728
+ vertices = (vertices - center) * scale
729
+
730
+ return vertices, scale, center
731
+
732
+
733
+ def load_mesh_to_unit_cube(
734
+ mesh_file: str,
735
+ mesh_scale: float = 1.0,
736
+ ) -> tuple[trimesh.Trimesh, float, list[float]]:
737
+ if not os.path.exists(mesh_file):
738
+ raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
739
+
740
+ mesh = trimesh.load(mesh_file)
741
+ if isinstance(mesh, trimesh.Scene):
742
+ mesh = trimesh.utils.concatenate(mesh)
743
+
744
+ vertices, scale, center = normalize_vertices_array(
745
+ mesh.vertices, mesh_scale
746
+ )
747
+ mesh.vertices = vertices
748
+
749
+ return mesh, scale, center
750
+
751
+
752
+ def as_list(obj):
753
+ if isinstance(obj, (list, tuple)):
754
+ return obj
755
+ elif isinstance(obj, set):
756
+ return list(obj)
757
+ else:
758
+ return [obj]
759
+
760
+
761
+ @dataclass
762
+ class CameraSetting:
763
+ """Camera settings for images rendering."""
764
+
765
+ num_images: int
766
+ elevation: list[float]
767
+ distance: float
768
+ resolution_hw: tuple[int, int]
769
+ fov: float
770
+ at: tuple[float, float, float] = field(
771
+ default_factory=lambda: (0.0, 0.0, 0.0)
772
+ )
773
+ up: tuple[float, float, float] = field(
774
+ default_factory=lambda: (0.0, 1.0, 0.0)
775
+ )
776
+ device: str = "cuda"
777
+ near: float = 1e-2
778
+ far: float = 1e2
779
+
780
+ def __post_init__(
781
+ self,
782
+ ):
783
+ h = self.resolution_hw[0]
784
+ f = (h / 2) / math.tan(self.fov / 2)
785
+ cx = self.resolution_hw[1] / 2
786
+ cy = self.resolution_hw[0] / 2
787
+ Ks = [
788
+ [f, 0, cx],
789
+ [0, f, cy],
790
+ [0, 0, 1],
791
+ ]
792
+
793
+ self.Ks = Ks
794
+
795
+
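Editor's usage sketch (not part of the commit): the pinhole intrinsics that __post_init__ derives for a 512 x 512 render with a 30-degree vertical FOV; the numbers in the comment are illustrative only.

import math
from asset3d_gen.data.utils import CameraSetting

cam = CameraSetting(
    num_images=6,
    elevation=[20.0, -10.0],
    distance=5.0,
    resolution_hw=(512, 512),
    fov=math.radians(30),
)
# Focal length f = (512 / 2) / tan(15 deg) ~= 955.3, principal point (256, 256).
print(cam.Ks)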
797
+ class RenderItems(str, Enum):
798
+ IMAGE = "image_color"
799
+ ALPHA = "image_mask"
800
+ VIEW_NORMAL = "image_view_normal"
801
+ GLOBAL_NORMAL = "image_global_normal"
802
+ POSITION_MAP = "image_position"
803
+ DEPTH = "image_depth"
804
+ ALBEDO = "image_albedo"
805
+ DIFFUSE = "image_diffuse"
806
+
807
+
808
+ def compute_az_el_by_camera_params(
809
+ camera_params: CameraSetting, flip_az: bool = False
810
+ ):
811
+ num_view = camera_params.num_images // len(camera_params.elevation)
812
+ view_interval = 2 * np.pi / num_view / 2
813
+ azimuths = []
814
+ elevations = []
815
+ for idx, el in enumerate(camera_params.elevation):
816
+ azs = np.arange(num_view) / num_view * np.pi * 2 + idx * view_interval
817
+ if flip_az:
818
+ azs *= -1
819
+ els = np.deg2rad(np.array([el] * num_view))
820
+ azimuths.append(azs)
821
+ elevations.append(els)
822
+
823
+ azimuths = np.concatenate(azimuths, axis=0)
824
+ elevations = np.concatenate(elevations, axis=0)
825
+
826
+ return azimuths, elevations
827
+
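Editor's usage sketch (not part of the commit): with 6 views over two elevation rings, each ring gets 3 azimuths spaced 120 degrees apart, and the second ring is shifted by the half-step offset computed above.

import math
import numpy as np
from asset3d_gen.data.utils import CameraSetting, compute_az_el_by_camera_params

cam = CameraSetting(
    num_images=6,
    elevation=[20.0, -10.0],
    distance=5.0,
    resolution_hw=(512, 512),
    fov=math.radians(30),
)
azimuths, elevations = compute_az_el_by_camera_params(cam)
print(np.rad2deg(azimuths))    # [0, 120, 240, 60, 180, 300]
print(np.rad2deg(elevations))  # [20, 20, 20, -10, -10, -10]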
828
+
829
+ def init_kal_camera(camera_params: CameraSetting) -> Camera:
830
+ azimuths, elevations = compute_az_el_by_camera_params(camera_params)
831
+ cam_pts = compute_cam_pts_by_az_el(
832
+ azimuths, elevations, camera_params.distance
833
+ )
834
+
835
+ up = torch.cat(
836
+ [
837
+ torch.tensor(camera_params.up).repeat(camera_params.num_images, 1),
838
+ ],
839
+ dim=0,
840
+ )
841
+
842
+ camera = Camera.from_args(
843
+ eye=torch.tensor(cam_pts),
844
+ at=torch.tensor(camera_params.at),
845
+ up=up,
846
+ fov=camera_params.fov,
847
+ height=camera_params.resolution_hw[0],
848
+ width=camera_params.resolution_hw[1],
849
+ near=camera_params.near,
850
+ far=camera_params.far,
851
+ device=camera_params.device,
852
+ )
853
+
854
+ return camera
855
+
856
+
857
+ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
858
+ if mesh_path.endswith(".glb"):
859
+ mesh = kal.io.gltf.import_mesh(mesh_path)
860
+ elif mesh_path.endswith(".obj"):
861
+ with_material = True if with_mtl else False
862
+ mesh = kal.io.obj.import_mesh(mesh_path, with_materials=with_material)
863
+ if with_mtl and mesh.materials and len(mesh.materials) > 0:
864
+ material = kal.render.materials.PBRMaterial()
865
+ assert (
866
+ "map_Kd" in mesh.materials[0]
867
+ ), "'map_Kd' not found in materials."
868
+ material.diffuse_texture = mesh.materials[0]["map_Kd"] / 255.0
869
+ mesh.materials = [material]
870
+ elif mesh_path.endswith(".ply"):
871
+ mesh = trimesh.load(mesh_path)
872
+ mesh_path = mesh_path.replace(".ply", ".obj")
873
+ mesh.export(mesh_path)
874
+ mesh = kal.io.obj.import_mesh(mesh_path)
875
+ elif mesh_path.endswith(".off"):
876
+ mesh = kal.io.off.import_mesh(mesh_path)
877
+ else:
878
+ raise RuntimeError(
879
+ f"{mesh_path} mesh type not supported, "
880
+ "supported mesh type `.glb`, `.obj`, `.ply`, `.off`."
881
+ )
882
+
883
+ return mesh
884
+
885
+
886
+ def save_mesh_with_mtl(
887
+ vertices: np.ndarray,
888
+ faces: np.ndarray,
889
+ uvs: np.ndarray,
890
+ texture: Union[Image.Image, np.ndarray],
891
+ output_path: str,
892
+ material_base=(250, 250, 250, 255),
893
+ ) -> trimesh.Trimesh:
894
+ if isinstance(texture, np.ndarray):
895
+ texture = Image.fromarray(texture)
896
+
897
+ mesh = trimesh.Trimesh(
898
+ vertices,
899
+ faces,
900
+ visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
901
+ )
902
+ mesh.visual.material = trimesh.visual.material.SimpleMaterial(
903
+ image=texture,
904
+ diffuse=material_base,
905
+ ambient=material_base,
906
+ specular=material_base,
907
+ )
908
+
909
+ dir_name = os.path.dirname(output_path)
910
+ os.makedirs(dir_name, exist_ok=True)
911
+
912
+ _ = mesh.export(output_path)
913
+ # texture.save(os.path.join(dir_name, f"{file_name}_texture.png"))
914
+
915
+ logger.info(f"Saved mesh with texture to {output_path}")
916
+
917
+ return mesh
918
+
919
+
920
+ def get_images_from_grid(
921
+ image: Union[str, Image.Image], img_size: int
922
+ ) -> list[Image.Image]:
923
+ if isinstance(image, str):
924
+ image = Image.open(image)
925
+
926
+ view_images = np.array(image)
927
+ view_images = np.concatenate(
928
+ [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
929
+ )
930
+ images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
931
+ images = [Image.fromarray(img) for img in images]
932
+
933
+ return images
934
+
935
+
936
+ def post_process_texture(texture: np.ndarray, iter: int = 2) -> np.ndarray:
937
+ for _ in range(iter):
938
+ texture = cv2.fastNlMeansDenoisingColored(texture, None, 13, 13, 9, 27)
939
+ texture = cv2.bilateralFilter(
940
+ texture, d=9, sigmaColor=80, sigmaSpace=80
941
+ )
942
+
943
+ return texture
asset3d_gen/models/delight.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ from typing import Union
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import (
8
+ EulerAncestralDiscreteScheduler,
9
+ StableDiffusionInstructPix2PixPipeline,
10
+ )
11
+ from huggingface_hub import snapshot_download
12
+ from PIL import Image
13
+ from asset3d_gen.models.segment import RembgRemover
14
+
15
+ os.environ["https_proxy"] = "http://10.9.0.31:8838"
16
+
17
+
18
+ __all__ = [
19
+ "DelightingModel",
20
+ ]
21
+
22
+
23
+ class DelightingModel(object):
24
+ def __init__(
25
+ self,
26
+ model_path: str = None,
27
+ num_infer_step: int = 50,
28
+ mask_erosion_size: int = 3,
29
+ image_guide_scale: float = 1.5,
30
+ text_guide_scale: float = 1.0,
31
+ device: str = "cuda",
32
+ seed: int = 0,
33
+ ) -> None:
34
+ self.image_guide_scale = image_guide_scale
35
+ self.text_guide_scale = text_guide_scale
36
+ self.num_infer_step = num_infer_step
37
+ self.mask_erosion_size = mask_erosion_size
38
+ self.kernel = np.ones(
39
+ (self.mask_erosion_size, self.mask_erosion_size), np.uint8
40
+ )
41
+ self.seed = seed
42
+ self.device = device
43
+ self.bg_remover = RembgRemover()
44
+
45
+ if model_path is None:
46
+ suffix = "hunyuan3d-delight-v2-0"
47
+ model_path = snapshot_download(
48
+ repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
49
+ )
50
+ model_path = os.path.join(model_path, suffix)
51
+
52
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
53
+ model_path,
54
+ torch_dtype=torch.float16,
55
+ safety_checker=None,
56
+ )
57
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
58
+ pipeline.scheduler.config
59
+ )
60
+ pipeline.set_progress_bar_config(disable=True)
61
+
62
+ pipeline.to(self.device, torch.float16)
63
+ pipeline.enable_model_cpu_offload()
64
+ pipeline.enable_xformers_memory_efficient_attention()
65
+ self.pipeline = pipeline
66
+
67
+ def recenter_image(
68
+ self, image: Image.Image, border_ratio: float = 0.2
69
+ ) -> Image.Image:
70
+ if image.mode == "RGB":
71
+ return image
72
+ elif image.mode == "L":
73
+ image = image.convert("RGB")
74
+ return image
75
+
76
+ alpha_channel = np.array(image)[:, :, 3]
77
+ non_zero_indices = np.argwhere(alpha_channel > 0)
78
+ if non_zero_indices.size == 0:
79
+ raise ValueError("Image is fully transparent")
80
+
81
+ min_row, min_col = non_zero_indices.min(axis=0)
82
+ max_row, max_col = non_zero_indices.max(axis=0)
83
+
84
+ cropped_image = image.crop(
85
+ (min_col, min_row, max_col + 1, max_row + 1)
86
+ )
87
+
88
+ width, height = cropped_image.size
89
+ border_width = int(width * border_ratio)
90
+ border_height = int(height * border_ratio)
91
+
92
+ new_width = width + 2 * border_width
93
+ new_height = height + 2 * border_height
94
+
95
+ square_size = max(new_width, new_height)
96
+
97
+ new_image = Image.new(
98
+ "RGBA", (square_size, square_size), (255, 255, 255, 0)
99
+ )
100
+
101
+ paste_x = (square_size - new_width) // 2 + border_width
102
+ paste_y = (square_size - new_height) // 2 + border_height
103
+
104
+ new_image.paste(cropped_image, (paste_x, paste_y))
105
+
106
+ return new_image
107
+
108
+ @torch.no_grad()
109
+ def __call__(
110
+ self,
111
+ image: Union[str, np.ndarray, Image.Image],
112
+ preprocess: bool = False,
113
+ target_wh: tuple[int, int] = None,
114
+ ) -> Image.Image:
115
+ if isinstance(image, str):
116
+ image = Image.open(image)
117
+ elif isinstance(image, np.ndarray):
118
+ image = Image.fromarray(image)
119
+
120
+ if preprocess:
121
+ image = self.bg_remover(image)
122
+ image = self.recenter_image(image)
123
+
124
+ if target_wh is not None:
125
+ image = image.resize(target_wh)
126
+ else:
127
+ target_wh = image.size
128
+
129
+ image_array = np.array(image)
130
+ assert image_array.shape[-1] == 4, "Image must have alpha channel"
131
+
132
+ raw_alpha_channel = image_array[:, :, 3]
133
+ alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
134
+ image_array[alpha_channel == 0, :3] = 255 # must be white background
135
+ image_array[:, :, 3] = alpha_channel
136
+
137
+ image = self.pipeline(
138
+ prompt="",
139
+ image=Image.fromarray(image_array).convert("RGB"),
140
+ generator=torch.manual_seed(self.seed),
141
+ num_inference_steps=self.num_infer_step,
142
+ image_guidance_scale=self.image_guide_scale,
143
+ guidance_scale=self.text_guide_scale,
144
+ ).images[0]
145
+
146
+ alpha_channel = Image.fromarray(alpha_channel)
147
+ rgba_image = image.convert("RGBA").resize(target_wh)
148
+ rgba_image.putalpha(alpha_channel)
149
+
150
+ return rgba_image
151
+
152
+
153
+ if __name__ == "__main__":
154
+ delighting_model = DelightingModel(
155
+ # model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
156
+ )
157
+ image_path = "scripts/apps/assets/example_image/room_bottle_002.jpeg"
158
+ image = delighting_model(
159
+ image_path, preprocess=True, target_wh=(512, 512)
160
+ ) # noqa
161
+ image.save("delight.png")
162
+
163
+ # image_path = "asset3d_gen/scripts/test_robot.png"
164
+ # image = delighting_model(image_path)
165
+ # image.save("delighting_image_a2.png")
asset3d_gen/models/gs_model.py ADDED
@@ -0,0 +1,540 @@
1
+ import logging
2
+ import os
3
+ import struct
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional, Union
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from gsplat.cuda._wrapper import spherical_harmonics
11
+ from gsplat.rendering import rasterization
12
+ from plyfile import PlyData
13
+ from scipy.spatial.transform import Rotation
14
+ from torch.nn import functional as F
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ __all__ = [
21
+ "RenderResult",
22
+ "GaussianOperator",
23
+ ]
24
+
25
+
26
+ def quat_mult(q1, q2):
27
+ # NOTE:
28
+ # Q1 is the quaternion that rotates the vector from the original position to the final position # noqa
29
+ # Q2 is the quaternion being rotated
30
+ w1, x1, y1, z1 = q1.T
31
+ w2, x2, y2, z2 = q2.T
32
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
33
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
34
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
35
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
36
+ return torch.stack([w, x, y, z]).T
37
+
38
+
39
+ def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor:
40
+ """Convert quaternion to rotation matrix."""
41
+ quats = F.normalize(quats, p=2, dim=-1)
42
+
43
+ if mode == "xyzw":
44
+ x, y, z, w = torch.unbind(quats, dim=-1)
45
+ elif mode == "wxyz":
46
+ w, x, y, z = torch.unbind(quats, dim=-1)
47
+ else:
48
+ raise ValueError(f"Invalid mode: {mode}.")
49
+
50
+ R = torch.stack(
51
+ [
52
+ 1 - 2 * (y**2 + z**2),
53
+ 2 * (x * y - w * z),
54
+ 2 * (x * z + w * y),
55
+ 2 * (x * y + w * z),
56
+ 1 - 2 * (x**2 + z**2),
57
+ 2 * (y * z - w * x),
58
+ 2 * (x * z - w * y),
59
+ 2 * (y * z + w * x),
60
+ 1 - 2 * (x**2 + y**2),
61
+ ],
62
+ dim=-1,
63
+ )
64
+
65
+ return R.reshape(quats.shape[:-1] + (3, 3))
66
+
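Editor's usage sketch (not part of the commit): a sanity check of the two quaternion helpers above. They are module-level functions in asset3d_gen/models/gs_model.py but not listed in __all__, so the import below is an assumption.

import math
import torch
from asset3d_gen.models.gs_model import quat_mult, quat_to_rotmat

s = math.sqrt(0.5)
q_z90 = torch.tensor([[s, 0.0, 0.0, s]])  # wxyz: 90 degrees about z
q_z180 = quat_mult(q_z90, q_z90)          # ~[0, 0, 0, 1]
R = quat_to_rotmat(q_z180, mode="wxyz")   # ~diag(-1, -1, 1)
print(q_z180)
print(R)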
67
+
68
+ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
69
+ C0 = 0.28209479177387814 # Constant for normalization in spherical harmonics # noqa
70
+ # Clip to the range [0.0, 1.0], apply gamma correction, and then un-clip back # noqa
71
+ new_shs = torch.clip(shs * C0 + 0.5, 0.0, 1.0)
72
+ new_shs = (torch.pow(new_shs, gamma) - 0.5) / C0
73
+ return new_shs
74
+
75
+
76
+ @dataclass
77
+ class RenderResult:
78
+ rgb: np.ndarray
79
+ depth: np.ndarray
80
+ opacity: np.ndarray
81
+ mask_threshold: float = 10
82
+ mask: Optional[np.ndarray] = None
83
+ rgba: Optional[np.ndarray] = None
84
+
85
+ def __post_init__(self):
86
+ if isinstance(self.rgb, torch.Tensor):
87
+ rgb = self.rgb.detach().cpu().numpy()
88
+ rgb = (rgb * 255).astype(np.uint8)
89
+ self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
90
+ if isinstance(self.depth, torch.Tensor):
91
+ self.depth = self.depth.detach().cpu().numpy()
92
+ if isinstance(self.opacity, torch.Tensor):
93
+ opacity = self.opacity.detach().cpu().numpy()
94
+ opacity = (opacity * 255).astype(np.uint8)
95
+ self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB)
96
+ mask = np.where(self.opacity > self.mask_threshold, 255, 0)
97
+ self.mask = mask[..., 0:1].astype(np.uint8)
98
+ self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
99
+
100
+
101
+ @dataclass
102
+ class GaussianBase:
103
+ _opacities: torch.Tensor
104
+ _means: torch.Tensor
105
+ _scales: torch.Tensor
106
+ _quats: torch.Tensor
107
+ _rgbs: Optional[torch.Tensor] = None
108
+ _features_dc: Optional[torch.Tensor] = None
109
+ _features_rest: Optional[torch.Tensor] = None
110
+ sh_degree: Optional[int] = 0
111
+ device: str = "cuda"
112
+
113
+ def __post_init__(self):
114
+ self.active_sh_degree: int = self.sh_degree
115
+ self.to(self.device)
116
+
117
+ def to(self, device: str) -> None:
118
+ for k, v in self.__dict__.items():
119
+ if not isinstance(v, torch.Tensor):
120
+ continue
121
+ self.__dict__[k] = v.to(device)
122
+
123
+ def get_numpy_data(self):
124
+ data = {}
125
+ for k, v in self.__dict__.items():
126
+ if not isinstance(v, torch.Tensor):
127
+ continue
128
+ data[k] = v.detach().cpu().numpy()
129
+
130
+ return data
131
+
132
+ def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
133
+ return x / x.norm(dim=-1, keepdim=True)
134
+
135
+ @classmethod
136
+ def load_from_ply(
137
+ cls,
138
+ path: str,
139
+ gamma: float = 1.0,
140
+ ) -> "GaussianBase":
141
+ plydata = PlyData.read(path)
142
+ xyz = torch.stack(
143
+ (
144
+ torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
145
+ torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
146
+ torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
147
+ ),
148
+ dim=1,
149
+ )
150
+
151
+ opacities = torch.tensor(
152
+ plydata.elements[0]["opacity"], dtype=torch.float32
153
+ ).unsqueeze(-1)
154
+ features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
155
+ features_dc[:, 0] = torch.tensor(
156
+ plydata.elements[0]["f_dc_0"], dtype=torch.float32
157
+ )
158
+ features_dc[:, 1] = torch.tensor(
159
+ plydata.elements[0]["f_dc_1"], dtype=torch.float32
160
+ )
161
+ features_dc[:, 2] = torch.tensor(
162
+ plydata.elements[0]["f_dc_2"], dtype=torch.float32
163
+ )
164
+
165
+ scale_names = [
166
+ p.name
167
+ for p in plydata.elements[0].properties
168
+ if p.name.startswith("scale_")
169
+ ]
170
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
171
+ scales = torch.zeros(
172
+ (xyz.shape[0], len(scale_names)), dtype=torch.float32
173
+ )
174
+ for idx, attr_name in enumerate(scale_names):
175
+ scales[:, idx] = torch.tensor(
176
+ plydata.elements[0][attr_name], dtype=torch.float32
177
+ )
178
+
179
+ rot_names = [
180
+ p.name
181
+ for p in plydata.elements[0].properties
182
+ if p.name.startswith("rot_")
183
+ ]
184
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
185
+ rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
186
+ for idx, attr_name in enumerate(rot_names):
187
+ rots[:, idx] = torch.tensor(
188
+ plydata.elements[0][attr_name], dtype=torch.float32
189
+ )
190
+
191
+ rots = rots / torch.norm(rots, dim=-1, keepdim=True)
192
+
193
+ # extra features
194
+ extra_f_names = [
195
+ p.name
196
+ for p in plydata.elements[0].properties
197
+ if p.name.startswith("f_rest_")
198
+ ]
199
+ extra_f_names = sorted(
200
+ extra_f_names, key=lambda x: int(x.split("_")[-1])
201
+ )
202
+
203
+ max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
204
+ if max_sh_degree != 0:
205
+ features_extra = torch.zeros(
206
+ (xyz.shape[0], len(extra_f_names)), dtype=torch.float32
207
+ )
208
+ for idx, attr_name in enumerate(extra_f_names):
209
+ features_extra[:, idx] = torch.tensor(
210
+ plydata.elements[0][attr_name], dtype=torch.float32
211
+ )
212
+
213
+ features_extra = features_extra.view(
214
+ (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
215
+ )
216
+ features_extra = features_extra.permute(0, 2, 1)
217
+
218
+ if abs(gamma - 1.0) > 1e-3:
219
+ features_dc = gamma_shs(features_dc, gamma)
220
+ features_extra[..., :] = 0.0
221
+ opacities *= 0.8
222
+
223
+ shs = torch.cat(
224
+ [
225
+ features_dc.reshape(-1, 3),
226
+ features_extra.reshape(len(features_dc), -1),
227
+ ],
228
+ dim=-1,
229
+ )
230
+ else:
231
+ # sh_dim is 0, only dc features
232
+ shs = features_dc
233
+ features_extra = None
234
+
235
+ return cls(
236
+ sh_degree=max_sh_degree,
237
+ _means=xyz,
238
+ _opacities=opacities,
239
+ _rgbs=shs,
240
+ _scales=scales,
241
+ _quats=rots,
242
+ _features_dc=features_dc,
243
+ _features_rest=features_extra,
244
+ )
245
+
246
+ def save_to_ply(
247
+ self, path: str, colors: torch.Tensor = None, enable_mask: bool = False
248
+ ):
249
+ os.makedirs(os.path.dirname(path), exist_ok=True)
250
+ numpy_data = self.get_numpy_data()
251
+ means = numpy_data["_means"]
252
+ scales = numpy_data["_scales"]
253
+ quats = numpy_data["_quats"]
254
+ opacities = numpy_data["_opacities"]
255
+ sh0 = numpy_data["_features_dc"]
256
+ shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
257
+
258
+ # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa
259
+ if enable_mask:
260
+ invalid_mask = (
261
+ np.isnan(means).any(axis=1)
262
+ | np.isinf(means).any(axis=1)
263
+ | np.isnan(scales).any(axis=1)
264
+ | np.isinf(scales).any(axis=1)
265
+ | np.isnan(quats).any(axis=1)
266
+ | np.isinf(quats).any(axis=1)
267
+ | np.isnan(opacities).any(axis=1)
268
+ | np.isinf(opacities).any(axis=1)
269
+ | np.isnan(sh0).any(axis=1)
270
+ | np.isinf(sh0).any(axis=1)
271
+ | np.isnan(shN).any(axis=1)
272
+ | np.isinf(shN).any(axis=1)
273
+ )
274
+
275
+ # Filter out rows with NaNs or Infs from all data arrays
276
+ means = means[~invalid_mask]
277
+ scales = scales[~invalid_mask]
278
+ quats = quats[~invalid_mask]
279
+ opacities = opacities[~invalid_mask]
280
+ sh0 = sh0[~invalid_mask]
281
+ shN = shN[~invalid_mask]
282
+
283
+ num_points = means.shape[0]
284
+
285
+ with open(path, "wb") as f:
286
+ # Write PLY header
287
+ f.write(b"ply\n")
288
+ f.write(b"format binary_little_endian 1.0\n")
289
+ f.write(f"element vertex {num_points}\n".encode())
290
+ f.write(b"property float x\n")
291
+ f.write(b"property float y\n")
292
+ f.write(b"property float z\n")
293
+ f.write(b"property float nx\n")
294
+ f.write(b"property float ny\n")
295
+ f.write(b"property float nz\n")
296
+
297
+ if colors is not None:
298
+ for j in range(colors.shape[1]):
299
+ f.write(f"property float f_dc_{j}\n".encode())
300
+ else:
301
+ for i, data in enumerate([sh0, shN]):
302
+ prefix = "f_dc" if i == 0 else "f_rest"
303
+ for j in range(data.shape[1]):
304
+ f.write(f"property float {prefix}_{j}\n".encode())
305
+
306
+ f.write(b"property float opacity\n")
307
+
308
+ for i in range(scales.shape[1]):
309
+ f.write(f"property float scale_{i}\n".encode())
310
+ for i in range(quats.shape[1]):
311
+ f.write(f"property float rot_{i}\n".encode())
312
+
313
+ f.write(b"end_header\n")
314
+
315
+ # Write vertex data
316
+ for i in range(num_points):
317
+ f.write(struct.pack("<fff", *means[i])) # x, y, z
318
+ f.write(struct.pack("<fff", 0, 0, 0)) # nx, ny, nz (zeros)
319
+
320
+ if colors is not None:
321
+ color = colors.detach().cpu().numpy()
322
+ for j in range(color.shape[1]):
323
+ f_dc = (color[i, j] - 0.5) / 0.2820947917738781
324
+ f.write(struct.pack("<f", f_dc))
325
+ else:
326
+ for data in [sh0, shN]:
327
+ for j in range(data.shape[1]):
328
+ f.write(struct.pack("<f", data[i, j]))
329
+
330
+ f.write(struct.pack("<f", opacities[i])) # opacity
331
+
332
+ for data in [scales, quats]:
333
+ for j in range(data.shape[1]):
334
+ f.write(struct.pack("<f", data[i, j]))
335
+
336
+
337
+ @dataclass
338
+ class GaussianOperator(GaussianBase):
339
+
340
+ def _compute_transform(
341
+ self,
342
+ means: torch.Tensor,
343
+ quats: torch.Tensor,
344
+ instance_pose: torch.Tensor,
345
+ ):
346
+ """Compute the transform of the GS models.
347
+
348
+ Args:
349
+ means: tensor of gs means.
350
+ quats: tensor of gs quaternions.
351
+ instance_pose: instances poses in [x y z qx qy qz qw] format.
352
+
353
+ """
354
+ # (x y z qx qy qz qw) -> (x y z qw qx qy qz)
355
+ instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
356
+ cur_instances_quats = self.quat_norm(instance_pose[3:])
357
+ rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
358
+
359
+ # update the means
360
+ num_gs = means.shape[0]
361
+ trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
362
+ quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
363
+ rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3)
364
+
365
+ # update the means
366
+ cur_means = (
367
+ torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
368
+ + trans_per_pts
369
+ )
370
+
371
+ # update the quats
372
+ _quats = self.quat_norm(quats)
373
+ cur_quats = quat_mult(quat_per_pts, _quats)
374
+
375
+ return cur_means, cur_quats
376
+
377
+ def get_gaussians(
378
+ self,
379
+ c2w: torch.Tensor = None,
380
+ instance_pose: torch.Tensor = None,
381
+ apply_activate: bool = False,
382
+ ) -> "GaussianBase":
383
+ """Get Gaussian data under the given instance_pose."""
384
+ if c2w is None:
385
+ c2w = torch.eye(4).to(self.device)
386
+
387
+ if instance_pose is not None:
388
+ # compute the transformed gs means and quats
389
+ world_means, world_quats = self._compute_transform(
390
+ self._means, self._quats, instance_pose.float().to(self.device)
391
+ )
392
+ else:
393
+ world_means, world_quats = self._means, self._quats
394
+
395
+ # get colors of gaussians
396
+ if self._features_rest is not None:
397
+ colors = torch.cat(
398
+ (self._features_dc[:, None, :], self._features_rest), dim=1
399
+ )
400
+ else:
401
+ colors = self._features_dc[:, None, :]
402
+
403
+ if self.sh_degree > 0:
404
+ viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3)
405
+ viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
406
+ rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
407
+ rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
408
+ else:
409
+ rgbs = torch.sigmoid(colors[:, 0, :])
410
+
411
+ gs_dict = dict(
412
+ _means=world_means,
413
+ _opacities=(
414
+ torch.sigmoid(self._opacities)
415
+ if apply_activate
416
+ else self._opacities
417
+ ),
418
+ _rgbs=rgbs,
419
+ _scales=(
420
+ torch.exp(self._scales) if apply_activate else self._scales
421
+ ),
422
+ _quats=self.quat_norm(world_quats),
423
+ _features_dc=self._features_dc,
424
+ _features_rest=self._features_rest,
425
+ sh_degree=self.sh_degree,
426
+ )
427
+
428
+ return GaussianOperator(**gs_dict)
429
+
430
+ def rescale(self, scale: float):
431
+ if scale != 1.0:
432
+ self._means *= scale
433
+ self._scales += torch.log(self._scales.new_tensor(scale))
434
+
435
+ def set_scale_by_height(self, real_height: float) -> None:
436
+ def _ptp(tensor, dim):
437
+ val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
438
+ return val.tolist()
439
+
440
+ xyz_scale = max(_ptp(self._means, dim=0))
441
+ self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5]
442
+ raw_height = _ptp(self._means, dim=0)[1]
443
+ scale = real_height / raw_height
444
+
445
+ self.rescale(scale)
446
+
447
+ return
448
+
449
+ @staticmethod
450
+ def resave_ply(
451
+ in_ply: str,
452
+ out_ply: str,
453
+ real_height: float = None,
454
+ instance_pose: np.ndarray = None,
455
+ sh_degree: int = 0,
456
+ ) -> None:
457
+ gs_model = GaussianOperator.load_from_ply(in_ply, sh_degree)
458
+
459
+ if instance_pose is not None:
460
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
461
+
462
+ if real_height is not None:
463
+ gs_model.set_scale_by_height(real_height)
464
+
465
+ gs_model.save_to_ply(out_ply)
466
+
467
+ return
468
+
469
+ @staticmethod
470
+ def trans_to_quatpose(
471
+ rot_matrix: list[list[float]],
472
+ trans_matrix: list[float] = [0, 0, 0],
473
+ ) -> torch.Tensor:
474
+ if isinstance(rot_matrix, list):
475
+ rot_matrix = np.array(rot_matrix)
476
+
477
+ rot = Rotation.from_matrix(rot_matrix)
478
+ qx, qy, qz, qw = rot.as_quat()
479
+ instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
480
+
481
+ return instance_pose
482
+
483
+ def render(
484
+ self,
485
+ c2w: torch.Tensor,
486
+ Ks: torch.Tensor,
487
+ image_width: int,
488
+ image_height: int,
489
+ ) -> RenderResult:
490
+ gs = self.get_gaussians(c2w, apply_activate=True)
491
+ renders, alphas, _ = rasterization(
492
+ means=gs._means,
493
+ quats=gs._quats,
494
+ scales=gs._scales,
495
+ opacities=gs._opacities.squeeze(),
496
+ colors=gs._rgbs,
497
+ viewmats=torch.linalg.inv(c2w)[None, ...],
498
+ Ks=Ks[None, ...],
499
+ width=image_width,
500
+ height=image_height,
501
+ packed=False,
502
+ absgrad=True,
503
+ sparse_grad=False,
504
+ # rasterize_mode="classic",
505
+ rasterize_mode="antialiased",
506
+ **{
507
+ "near_plane": 0.01,
508
+ "far_plane": 1000000000,
509
+ "radius_clip": 0.0,
510
+ "render_mode": "RGB+ED",
511
+ },
512
+ )
513
+ renders = renders[0]
514
+ alphas = alphas[0].squeeze(-1)
515
+
516
+ assert renders.shape[-1] == 4, "Must render rgb, depth and alpha."
517
+ rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
518
+
519
+ return RenderResult(
520
+ torch.clamp(rendered_rgb, min=0, max=1),
521
+ rendered_depth,
522
+ alphas[..., None],
523
+ )
524
+
525
+
526
+ if __name__ == "__main__":
527
+ input_gs = "outputs/test/debug.ply"
528
+ output_gs = "./debug_v3.ply"
529
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
530
+
531
+ # Rotate 180 degrees around the x-axis.
532
+ R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
533
+ instance_pose = gs_model.trans_to_quatpose(R_x)
534
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
535
+
536
+ gs_model.rescale(2)
537
+
538
+ gs_model.set_scale_by_height(1.3)
539
+
540
+ gs_model.save_to_ply(output_gs)
asset3d_gen/models/segment.py ADDED
@@ -0,0 +1,376 @@
1
+ import logging
2
+ import os
3
+ from typing import Literal, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from huggingface_hub import snapshot_download
10
+ from PIL import Image
11
+ from segment_anything import (
12
+ SamAutomaticMaskGenerator,
13
+ SamPredictor,
14
+ sam_model_registry,
15
+ )
16
+ from asset3d_gen.utils.process_media import filter_small_connected_components
17
+ from asset3d_gen.validators.quality_checkers import ImageSegChecker
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ os.environ["https_proxy"] = "http://10.9.0.31:8838"
23
+
24
+ __all__ = [
25
+ "resize_pil",
26
+ "trellis_preprocess",
27
+ "SAMRemover",
28
+ "SAMPredictor",
29
+ "RembgRemover",
30
+ "get_segmented_image",
31
+ ]
32
+
33
+
34
+ def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
35
+ largest_side = max(image.size)
36
+ scale = min(1, max_size / largest_side)
37
+ if scale < 1:
38
+ new_size = (int(image.width * scale), int(image.height * scale))
39
+ image = image.resize(new_size, Image.Resampling.LANCZOS)
40
+
41
+ return image
42
+
43
+
44
+ def trellis_preprocess(image: Image.Image) -> Image.Image:
45
+ """Process the input image as trellis done."""
46
+ image_np = np.array(image)
47
+ alpha = image_np[:, :, 3]
48
+ bbox = np.argwhere(alpha > 0.8 * 255)
49
+ bbox = (
50
+ np.min(bbox[:, 1]),
51
+ np.min(bbox[:, 0]),
52
+ np.max(bbox[:, 1]),
53
+ np.max(bbox[:, 0]),
54
+ )
55
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
56
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
57
+ size = int(size * 1.2)
58
+ bbox = (
59
+ center[0] - size // 2,
60
+ center[1] - size // 2,
61
+ center[0] + size // 2,
62
+ center[1] + size // 2,
63
+ )
64
+ image = image.crop(bbox)
65
+ image = image.resize((518, 518), Image.Resampling.LANCZOS)
66
+ image = np.array(image).astype(np.float32) / 255
67
+ image = image[:, :, :3] * image[:, :, 3:4]
68
+ image = Image.fromarray((image * 255).astype(np.uint8))
69
+
70
+ return image
71
+
72
+
73
+ class SAMRemover(object):
74
+ """Loading SAM models and performing background removal on images.
75
+
76
+ Attributes:
77
+ checkpoint (str): Path to the model checkpoint.
78
+ model_type (str): Type of the SAM model to load (default: "vit_h").
79
+ area_ratio (float): Area ratio filtering small connected components.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ checkpoint: str = None,
85
+ model_type: str = "vit_h",
86
+ area_ratio: float = 15,
87
+ ):
88
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
89
+ self.model_type = model_type
90
+ self.area_ratio = area_ratio
91
+
92
+ if checkpoint is None:
93
+ suffix = "sam"
94
+ model_path = snapshot_download(
95
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
96
+ )
97
+ checkpoint = os.path.join(
98
+ model_path, suffix, "sam_vit_h_4b8939.pth"
99
+ )
100
+
101
+ self.mask_generator = self._load_sam_model(checkpoint)
102
+
103
+ def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
104
+ sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
105
+ sam.to(device=self.device)
106
+
107
+ return SamAutomaticMaskGenerator(sam)
108
+
109
+ def __call__(
110
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
111
+ ) -> Image.Image:
112
+ """Removes the background from an image using the SAM model.
113
+
114
+ Args:
115
+ image (Union[str, Image.Image, np.ndarray]): Input image,
116
+ can be a file path, PIL Image, or numpy array.
117
+ save_path (str): Path to save the output image (default: None).
118
+
119
+ Returns:
120
+ Image.Image: The image with background removed,
121
+ including an alpha channel.
122
+ """
123
+ # Convert input to numpy array
124
+ if isinstance(image, str):
125
+ image = Image.open(image)
126
+ elif isinstance(image, np.ndarray):
127
+ image = Image.fromarray(image).convert("RGB")
128
+ image = resize_pil(image)
129
+ image = np.array(image.convert("RGB"))
130
+
131
+ # Generate masks
132
+ masks = self.mask_generator.generate(image)
133
+ masks = sorted(masks, key=lambda x: x["area"], reverse=True)
134
+
135
+ if not masks:
136
+ logger.warning(
137
+ "Segmentation failed: No mask generated, return raw image."
138
+ )
139
+ output_image = Image.fromarray(image, mode="RGB")
140
+ else:
141
+ # Use the largest mask
142
+ best_mask = masks[0]["segmentation"]
143
+ mask = (best_mask * 255).astype(np.uint8)
144
+ mask = filter_small_connected_components(
145
+ mask, area_ratio=self.area_ratio
146
+ )
147
+ # Apply the mask to remove the background
148
+ background_removed = cv2.bitwise_and(image, image, mask=mask)
149
+ output_image = np.dstack((background_removed, mask))
150
+ output_image = Image.fromarray(output_image, mode="RGBA")
151
+
152
+ if save_path is not None:
153
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
154
+ output_image.save(save_path)
155
+
156
+ return output_image
157
+
158
+
159
+ class SAMPredictor(object):
160
+ def __init__(
161
+ self,
162
+ checkpoint: str = None,
163
+ model_type: str = "vit_h",
164
+ binary_thresh: float = 0.1,
165
+ device: str = "cuda"
166
+ ):
167
+ self.device = device
168
+ self.model_type = model_type
169
+
170
+ if checkpoint is None:
171
+ suffix = "sam"
172
+ model_path = snapshot_download(
173
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
174
+ )
175
+ checkpoint = os.path.join(
176
+ model_path, suffix, "sam_vit_h_4b8939.pth"
177
+ )
178
+
179
+ self.predictor = self._load_sam_model(checkpoint)
180
+ self.binary_thresh = binary_thresh
181
+
182
+ def _load_sam_model(self, checkpoint: str) -> SamPredictor:
183
+ sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
184
+ sam.to(device=self.device)
185
+
186
+ return SamPredictor(sam)
187
+
188
+ def preprocess_image(self, image: Image.Image) -> np.ndarray:
189
+ if isinstance(image, str):
190
+ image = Image.open(image)
191
+ elif isinstance(image, np.ndarray):
192
+ image = Image.fromarray(image).convert("RGB")
193
+
194
+ image = resize_pil(image)
195
+ image = np.array(image.convert("RGB"))
196
+
197
+ return image
198
+
199
+ def generate_masks(
200
+ self,
201
+ image: np.ndarray,
202
+ selected_points: list[list[int]],
203
+ ) -> np.ndarray:
204
+ if len(selected_points) == 0:
205
+ return []
206
+
207
+ points = (
208
+ torch.Tensor([p for p, _ in selected_points])
209
+ .to(self.predictor.device)
210
+ .unsqueeze(1)
211
+ )
212
+
213
+ labels = (
214
+ torch.Tensor([int(l) for _, l in selected_points])
215
+ .to(self.predictor.device)
216
+ .unsqueeze(1)
217
+ )
218
+
219
+ transformed_points = self.predictor.transform.apply_coords_torch(
220
+ points, image.shape[:2]
221
+ )
222
+
223
+ masks, scores, _ = self.predictor.predict_torch(
224
+ point_coords=transformed_points,
225
+ point_labels=labels,
226
+ multimask_output=True,
227
+ )
228
+ valid_mask = masks[:, torch.argmax(scores, dim=1)]
229
+ masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
230
+ masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
231
+ if len(masks_neg) == 0:
232
+ masks_neg = np.zeros_like(masks_pos)
233
+ if len(masks_pos) == 0:
234
+ masks_pos = np.zeros_like(masks_neg)
235
+ masks_neg = masks_neg.max(axis=0, keepdims=True)
236
+ masks_pos = masks_pos.max(axis=0, keepdims=True)
237
+ valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
238
+
239
+ binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
240
+
241
+ return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
242
+
243
+ def get_segmented_image(
244
+ self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
245
+ ) -> Image.Image:
246
+ seg_image = Image.fromarray(image, mode="RGB")
247
+ alpha_channel = np.zeros(
248
+ (seg_image.height, seg_image.width), dtype=np.uint8
249
+ )
250
+ for mask, _ in masks:
251
+ # Use the maximum to combine multiple masks
252
+ alpha_channel = np.maximum(alpha_channel, mask)
253
+
254
+ alpha_channel = np.clip(alpha_channel, 0, 1)
255
+ alpha_channel = (alpha_channel * 255).astype(np.uint8)
256
+ alpha_image = Image.fromarray(alpha_channel, mode="L")
257
+ r, g, b = seg_image.split()
258
+ seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
259
+
260
+ return seg_image
261
+
262
+ def __call__(
263
+ self,
264
+ image: Union[str, Image.Image, np.ndarray],
265
+ selected_points: list[list[int]],
266
+ ) -> Image.Image:
267
+ image = self.preprocess_image(image)
268
+ self.predictor.set_image(image)
269
+ masks = self.generate_masks(image, selected_points)
270
+
271
+ return self.get_segmented_image(image, masks)
272
+
273
+
274
+ class RembgRemover(object):
275
+ def __init__(self):
276
+ self.rembg_session = rembg.new_session("u2net")
277
+
278
+ def __call__(
279
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
280
+ ) -> Image.Image:
281
+ if isinstance(image, str):
282
+ image = Image.open(image)
283
+ elif isinstance(image, np.ndarray):
284
+ image = Image.fromarray(image)
285
+
286
+ image = resize_pil(image)
287
+ output_image = rembg.remove(image, session=self.rembg_session)
288
+
289
+ if save_path is not None:
290
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
291
+ output_image.save(save_path)
292
+
293
+ return output_image
294
+
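Editor's usage sketch (not part of the commit): background removal with the rembg-based remover defined above; the input path is a placeholder.

from asset3d_gen.models.segment import RembgRemover

remover = RembgRemover()
rgba = remover("examples/bottle.jpg", save_path="outputs/bottle_seg.png")
print(rgba.mode)  # "RGBA": the alpha channel carries the foreground mask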
295
+
296
+ def invert_rgba_pil(
297
+ image: Image.Image, mask: Image.Image, save_path: str = None
298
+ ) -> Image.Image:
299
+ mask = (255 - np.array(mask))[..., None]
300
+ image_array = np.concatenate([np.array(image), mask], axis=-1)
301
+ inverted_image = Image.fromarray(image_array, "RGBA")
302
+
303
+ if save_path is not None:
304
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
305
+ inverted_image.save(save_path)
306
+
307
+ return inverted_image
308
+
309
+
310
+ def get_segmented_image(
311
+ image: Image.Image,
312
+ sam_remover: SAMRemover,
313
+ rbg_remover: RembgRemover,
314
+ seg_checker: ImageSegChecker = None,
315
+ save_path: str = None,
316
+ mode: Literal["loose", "strict"] = "loose",
317
+ ) -> Image.Image:
318
+ def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
319
+ if seg_checker is None:
320
+ return True
321
+ return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]
322
+
323
+ out_sam = f"{save_path}_sam.png" if save_path else None
324
+ out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
325
+ out_rbg = f"{save_path}_rbg.png" if save_path else None
326
+
327
+ seg_image = sam_remover(image, out_sam)
328
+ seg_image = seg_image.convert("RGBA")
329
+ _, _, _, alpha = seg_image.split()
330
+ seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
331
+ seg_image_rbg = rbg_remover(image, out_rbg)
332
+
333
+ final_image = None
334
+ if _is_valid_seg(image, seg_image):
335
+ final_image = seg_image
336
+ elif _is_valid_seg(image, seg_image_inv):
337
+ final_image = seg_image_inv
338
+ elif _is_valid_seg(image, seg_image_rbg):
339
+ logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.")
340
+ final_image = seg_image_rbg
341
+ else:
342
+ if mode == "strict":
343
+ raise RuntimeError(
344
+ f"Failed to segment by `SAM` or `rembg`, abort."
345
+ )
346
+ logger.warning("Failed to segment by SAM or rembg, use raw image.")
347
+ final_image = image.convert("RGBA")
348
+
349
+ if save_path:
350
+ final_image.save(save_path)
351
+
352
+ final_image = trellis_preprocess(final_image)
353
+
354
+ return final_image
355
+
356
+
357
+ if __name__ == "__main__":
358
+ input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
359
+ output_image = "sample_0_seg2.png"
360
+
361
+ # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
362
+ # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
363
+
364
+ # input_image = "outputs/text2image/tmp/bucket.jpeg"
365
+ # output_image = "outputs/text2image/tmp/bucket_seg.png"
366
+
367
+ remover = SAMRemover(
368
+ # checkpoint="/horizon-bucket/robot_lab/users/xinjie.wang/weights/sam/sam_vit_h_4b8939.pth", # noqa
369
+ model_type="vit_h",
370
+ )
371
+ remover = RembgRemover()
372
+ # clean_image = remover(input_image)
373
+ # clean_image.save(output_image)
374
+ get_segmented_image(
375
+ Image.open(input_image), remover, remover, None, "./test_seg.png"
376
+ )
asset3d_gen/models/super_resolution.py ADDED
@@ -0,0 +1,118 @@
1
+ import logging
2
+ import os
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from PIL import Image
9
+ from asset3d_gen.data.utils import get_images_from_grid
10
+
11
+ os.environ["https_proxy"] = "http://10.9.0.31:8838"
12
+
13
+ logging.basicConfig(
14
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ __all__ = [
20
+ "ImageStableSR",
21
+ "ImageRealESRGAN",
22
+ ]
23
+
24
+
25
+ class ImageStableSR:
26
+ def __init__(
27
+ self,
28
+ model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
29
+ device="cuda",
30
+ ) -> None:
31
+ from diffusers import StableDiffusionUpscalePipeline
32
+
33
+ self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
34
+ model_path,
35
+ torch_dtype=torch.float16,
36
+ ).to(device)
37
+ self.up_pipeline_x4.set_progress_bar_config(disable=True)
38
+ self.up_pipeline_x4.enable_model_cpu_offload()
39
+
40
+ def __call__(
41
+ self,
42
+ image: Union[Image.Image, np.ndarray],
43
+ prompt: str = "",
44
+ infer_step: int = 20,
45
+ ) -> Image.Image:
46
+ if isinstance(image, np.ndarray):
47
+ image = Image.fromarray(image)
48
+
49
+ image = image.convert("RGB")
50
+
51
+ with torch.no_grad():
52
+ upscaled_image = self.up_pipeline_x4(
53
+ image=image,
54
+ prompt=[prompt],
55
+ num_inference_steps=infer_step,
56
+ ).images[0]
57
+
58
+ return upscaled_image
59
+
60
+
61
+ class ImageRealESRGAN:
62
+ def __init__(self, outscale: int, model_path: str = None) -> None:
63
+ from basicsr.archs.rrdbnet_arch import RRDBNet
64
+ from realesrgan import RealESRGANer
65
+
66
+ self.outscale = outscale
67
+ model = RRDBNet(
68
+ num_in_ch=3,
69
+ num_out_ch=3,
70
+ num_feat=64,
71
+ num_block=23,
72
+ num_grow_ch=32,
73
+ scale=4,
74
+ )
75
+ if model_path is None:
76
+ suffix = "super_resolution"
77
+ model_path = snapshot_download(
78
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
79
+ )
80
+ model_path = os.path.join(
81
+ model_path, suffix, "RealESRGAN_x4plus.pth"
82
+ )
83
+
84
+ self.upsampler = RealESRGANer(
85
+ scale=4,
86
+ model_path=model_path,
87
+ model=model,
88
+ pre_pad=0,
89
+ half=True,
90
+ )
91
+
92
+ def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
93
+ if isinstance(image, Image.Image):
94
+ image = np.array(image)
95
+
96
+ with torch.no_grad():
97
+ output, _ = self.upsampler.enhance(image, outscale=self.outscale)
98
+
99
+ return Image.fromarray(output)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
104
+
105
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
106
+ # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth" # noqa
107
+ super_model = ImageRealESRGAN(outscale=4)
108
+ multiviews = get_images_from_grid(color_path, img_size=512)
109
+ multiviews = [super_model(img.convert("RGB")) for img in multiviews]
110
+ for idx, img in enumerate(multiviews):
111
+ img.save(f"sr{idx}.png")
112
+
113
+ # # Use stable diffusion for x4 (512->2048) image super resolution.
114
+ # super_model = ImageStableSR()
115
+ # multiviews = get_images_from_grid(color_path, img_size=512)
116
+ # multiviews = [super_model(img) for img in multiviews]
117
+ # for idx, img in enumerate(multiviews):
118
+ # img.save(f"sr_stable{idx}.png")
asset3d_gen/models/text_model.py ADDED
@@ -0,0 +1,143 @@
1
+ import logging
2
+
3
+ import torch
4
+ from diffusers import (
5
+ AutoencoderKL,
6
+ EulerDiscreteScheduler,
7
+ UNet2DConditionModel,
8
+ )
9
+ from kolors.models.modeling_chatglm import ChatGLMModel
10
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
11
+ from kolors.models.unet_2d_condition import (
12
+ UNet2DConditionModel as UNet2DConditionModelIP,
13
+ )
14
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
15
+ StableDiffusionXLPipeline,
16
+ )
17
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
18
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
19
+ )
20
+ from PIL import Image
21
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
22
+
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ __all__ = [
28
+ "build_text2img_ip_pipeline",
29
+ "build_text2img_pipeline",
30
+ "text2img_gen",
31
+ ]
32
+
33
+
34
+ def build_text2img_ip_pipeline(
35
+ ckpt_dir: str,
36
+ ref_scale: float,
37
+ device: str = "cuda",
38
+ ) -> StableDiffusionXLPipelineIP:
39
+ text_encoder = ChatGLMModel.from_pretrained(
40
+ f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
41
+ ).half()
42
+ tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
43
+ vae = AutoencoderKL.from_pretrained(
44
+ f"{ckpt_dir}/vae", revision=None
45
+ ).half()
46
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
47
+ unet = UNet2DConditionModelIP.from_pretrained(
48
+ f"{ckpt_dir}/unet", revision=None
49
+ ).half()
50
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
51
+ f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
52
+ ignore_mismatched_sizes=True,
53
+ ).to(dtype=torch.float16)
54
+ clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
55
+
56
+ pipe = StableDiffusionXLPipelineIP(
57
+ vae=vae,
58
+ text_encoder=text_encoder,
59
+ tokenizer=tokenizer,
60
+ unet=unet,
61
+ scheduler=scheduler,
62
+ image_encoder=image_encoder,
63
+ feature_extractor=clip_image_processor,
64
+ force_zeros_for_empty_prompt=False,
65
+ )
66
+
67
+ if hasattr(pipe.unet, "encoder_hid_proj"):
68
+ pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
69
+
70
+ pipe.load_ip_adapter(
71
+ f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
72
+ subfolder="",
73
+ weight_name=["ip_adapter_plus_general.bin"],
74
+ )
75
+ pipe.set_ip_adapter_scale([ref_scale])
76
+
77
+ pipe = pipe.to(device)
78
+ pipe.enable_model_cpu_offload()
79
+ # pipe.enable_xformers_memory_efficient_attention()
80
+ # pipe.enable_vae_slicing()
81
+
82
+ return pipe
83
+
84
+
85
+ def build_text2img_pipeline(
86
+ ckpt_dir: str,
87
+ device: str = "cuda",
88
+ ) -> StableDiffusionXLPipeline:
89
+ text_encoder = ChatGLMModel.from_pretrained(
90
+ f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
91
+ ).half()
92
+ tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
93
+ vae = AutoencoderKL.from_pretrained(
94
+ f"{ckpt_dir}/vae", revision=None
95
+ ).half()
96
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
97
+ unet = UNet2DConditionModel.from_pretrained(
98
+ f"{ckpt_dir}/unet", revision=None
99
+ ).half()
100
+ pipe = StableDiffusionXLPipeline(
101
+ vae=vae,
102
+ text_encoder=text_encoder,
103
+ tokenizer=tokenizer,
104
+ unet=unet,
105
+ scheduler=scheduler,
106
+ force_zeros_for_empty_prompt=False,
107
+ )
108
+ pipe = pipe.to(device)
109
+ pipe.enable_model_cpu_offload()
110
+ pipe.enable_xformers_memory_efficient_attention()
111
+
112
+ return pipe
113
+
114
+
115
+ def text2img_gen(
116
+ prompt: str,
117
+ n_sample: int,
118
+ guidance_scale: float,
119
+ pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
120
+ ip_image: Image.Image | str = None,
121
+ image_wh: tuple[int, int] = [1024, 1024],
122
+ infer_step: int = 50,
123
+ ip_image_size: int = 512,
124
+ ) -> list[Image.Image]:
125
+ prompt = "Single " + prompt + ", in the center of the image"
126
+ prompt += ", high quality, high resolution, best quality, white background, 3D style," # noqa
127
+ logger.info(f"Processing prompt: {prompt}")
128
+
129
+ kwargs = dict(
130
+ prompt=prompt,
131
+ height=image_wh[1],
132
+ width=image_wh[0],
133
+ num_inference_steps=infer_step,
134
+ guidance_scale=guidance_scale,
135
+ num_images_per_prompt=n_sample,
136
+ )
137
+ if ip_image is not None:
138
+ if isinstance(ip_image, str):
139
+ ip_image = Image.open(ip_image)
140
+ ip_image = ip_image.resize((ip_image_size, ip_image_size))
141
+ kwargs.update(ip_adapter_image=[ip_image])
142
+
143
+ return pipeline(**kwargs).images
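Editor's usage sketch (not part of the commit): building the plain Kolors pipeline and sampling a prompt. The checkpoint directory is a placeholder and must contain the Kolors weights in the layout the builder expects.

from asset3d_gen.models.text_model import build_text2img_pipeline, text2img_gen

pipe = build_text2img_pipeline(ckpt_dir="weights/Kolors")
images = text2img_gen(
    prompt="wooden chair",
    n_sample=2,
    guidance_scale=7.5,
    pipeline=pipe,
    image_wh=(1024, 1024),
    infer_step=30,
)
for idx, image in enumerate(images):
    image.save(f"sample_{idx}.png")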
asset3d_gen/models/texture_model.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
5
+ from huggingface_hub import snapshot_download
6
+ from kolors.models.controlnet import ControlNetModel
7
+ from kolors.models.modeling_chatglm import ChatGLMModel
8
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
10
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
11
+ StableDiffusionXLControlNetImg2ImgPipeline,
12
+ )
13
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
14
+
15
+ __all__ = [
16
+ "build_texture_gen_pipe",
17
+ ]
18
+
19
+
20
+ def build_texture_gen_pipe(
21
+ base_ckpt_dir: str,
22
+ controlnet_ckpt: str = None,
23
+ ip_adapt_scale: float = 0,
24
+ device: str = "cuda",
25
+ ) -> DiffusionPipeline:
26
+ tokenizer = ChatGLMTokenizer.from_pretrained(
27
+ f"{base_ckpt_dir}/Kolors/text_encoder"
28
+ )
29
+ text_encoder = ChatGLMModel.from_pretrained(
30
+ f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
31
+ ).half()
32
+ vae = AutoencoderKL.from_pretrained(
33
+ f"{base_ckpt_dir}/Kolors/vae", revision=None
34
+ ).half()
35
+ unet = UNet2DConditionModel.from_pretrained(
36
+ f"{base_ckpt_dir}/Kolors/unet", revision=None
37
+ ).half()
38
+ scheduler = EulerDiscreteScheduler.from_pretrained(
39
+ f"{base_ckpt_dir}/Kolors/scheduler"
40
+ )
41
+
42
+ if controlnet_ckpt is None:
43
+ suffix = "geo_cond_mv"
44
+ model_path = snapshot_download(
45
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
46
+ )
47
+ controlnet_ckpt = os.path.join(model_path, suffix)
48
+
49
+ controlnet = ControlNetModel.from_pretrained(
50
+ controlnet_ckpt, use_safetensors=True
51
+ ).half()
52
+
53
+ # IP-Adapter model
54
+ image_encoder = None
55
+ clip_image_processor = None
56
+ if ip_adapt_scale > 0:
57
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
58
+ f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
59
+ # ignore_mismatched_sizes=True,
60
+ ).to(dtype=torch.float16)
61
+ ip_img_size = 336
62
+ clip_image_processor = CLIPImageProcessor(
63
+ size=ip_img_size, crop_size=ip_img_size
64
+ )
65
+
66
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline(
67
+ vae=vae,
68
+ controlnet=controlnet,
69
+ text_encoder=text_encoder,
70
+ tokenizer=tokenizer,
71
+ unet=unet,
72
+ scheduler=scheduler,
73
+ image_encoder=image_encoder,
74
+ feature_extractor=clip_image_processor,
75
+ force_zeros_for_empty_prompt=False,
76
+ )
77
+
78
+ if ip_adapt_scale > 0:
79
+ if hasattr(pipe.unet, "encoder_hid_proj"):
80
+ pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
81
+ pipe.load_ip_adapter(
82
+ f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
83
+ subfolder="",
84
+ weight_name=["ip_adapter_plus_general.bin"],
85
+ )
86
+ pipe.set_ip_adapter_scale([ip_adapt_scale])
87
+
88
+ pipe = pipe.to(device)
89
+ pipe.enable_model_cpu_offload()
90
+
91
+ return pipe
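Editor's usage sketch (not part of the commit): the base checkpoint directory is a placeholder expected to contain Kolors/ (and Kolors-IP-Adapter-Plus/ when ip_adapt_scale > 0); leaving controlnet_ckpt as None downloads the geo_cond_mv weights from the hub, as in the builder above.

from asset3d_gen.models.texture_model import build_texture_gen_pipe

pipe = build_texture_gen_pipe(
    base_ckpt_dir="weights",
    controlnet_ckpt=None,
    ip_adapt_scale=0,
    device="cuda",
)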
asset3d_gen/scripts/render_gs.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+ from asset3d_gen.data.utils import (
11
+ CameraSetting,
12
+ init_kal_camera,
13
+ normalize_vertices_array,
14
+ )
15
+ from asset3d_gen.models.gs_model import GaussianOperator
16
+
17
+ logging.basicConfig(
18
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
19
+ )
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def parse_args():
24
+ parser = argparse.ArgumentParser(description="Render GS color images")
25
+
26
+ parser.add_argument(
27
+ "--input_gs", type=str, help="Input render GS.ply path."
28
+ )
29
+ parser.add_argument(
30
+ "--output_path",
31
+ type=str,
32
+ help="Output grid image path for rendered GS color images.",
33
+ )
34
+ parser.add_argument(
35
+ "--num_images", type=int, default=6, help="Number of images to render."
36
+ )
37
+ parser.add_argument(
38
+ "--elevation",
39
+ type=float,
40
+ nargs="+",
41
+ default=[20.0, -10.0],
42
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
43
+ )
44
+ parser.add_argument(
45
+ "--distance",
46
+ type=float,
47
+ default=5,
48
+ help="Camera distance (default: 5)",
49
+ )
50
+ parser.add_argument(
51
+ "--resolution_hw",
52
+ type=int,
53
+ nargs=2,
54
+ default=(512, 512),
55
+ help="Resolution of the output images (default: (512, 512))",
56
+ )
57
+ parser.add_argument(
58
+ "--fov",
59
+ type=float,
60
+ default=30,
61
+ help="Field of view in degrees (default: 30)",
62
+ )
63
+ parser.add_argument(
64
+ "--device",
65
+ type=str,
66
+ choices=["cpu", "cuda"],
67
+ default="cuda",
68
+ help="Device to run on (default: `cuda`)",
69
+ )
70
+ parser.add_argument(
71
+ "--image_size",
72
+ type=int,
73
+ default=512,
74
+ help="Output image size for single view in color grid (default: 512)",
75
+ )
76
+
77
+ args = parser.parse_args()
78
+
79
+ return args
80
+
81
+
82
+ def load_gs_model(
83
+ input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
84
+ ) -> GaussianOperator:
85
+ gs_model = GaussianOperator.load_from_ply(input_gs)
86
+ # Normalize vertices to [-1, 1], center to (0, 0, 0).
87
+ _, scale, center = normalize_vertices_array(gs_model._means)
88
+ scale, center = float(scale), center.tolist()
89
+ transpose = [*[-v for v in center], *pre_quat]
90
+ instance_pose = torch.tensor(transpose).to(gs_model.device)
91
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
92
+ gs_model.rescale(scale)
93
+
94
+ return gs_model
95
+
96
+
97
+ def entrypoint(input_gs: str = None, output_path: str = None) -> None:
98
+ args = parse_args()
99
+ if isinstance(input_gs, str):
100
+ args.input_gs = input_gs
101
+ if isinstance(output_path, str):
102
+ args.output_path = output_path
103
+
104
+ # Setup camera parameters
105
+ camera_params = CameraSetting(
106
+ num_images=args.num_images,
107
+ elevation=args.elevation,
108
+ distance=args.distance,
109
+ resolution_hw=args.resolution_hw,
110
+ fov=math.radians(args.fov),
111
+ device=args.device,
112
+ )
113
+ camera = init_kal_camera(camera_params)
114
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
115
+ matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
116
+ w2cs = matrix_mv.to(camera_params.device)
117
+ c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
118
+ Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
119
+
120
+ # Load GS model and normalize.
121
+ gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])
122
+
123
+ # Render GS color images.
124
+ images = []
125
+ for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
126
+ result = gs_model.render(
127
+ c2ws[idx],
128
+ Ks=Ks,
129
+ image_width=camera_params.resolution_hw[1],
130
+ image_height=camera_params.resolution_hw[0],
131
+ )
132
+ color = cv2.resize(
133
+ result.rgba,
134
+ (args.image_size, args.image_size),
135
+ interpolation=cv2.INTER_AREA,
136
+ )
137
+ images.append(color)
138
+
139
+ # Cat color images into grid image and save.
140
+ select_idxs = [[0, 2, 1], [5, 4, 3]] # fix order for 6 views
141
+ grid_image = []
142
+ for row_idxs in select_idxs:
143
+ row_image = []
144
+ for row_idx in row_idxs:
145
+ row_image.append(images[row_idx])
146
+ row_image = np.concatenate(row_image, axis=1)
147
+ grid_image.append(row_image)
148
+
149
+ grid_image = np.concatenate(grid_image, axis=0)
150
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
151
+ cv2.imwrite(args.output_path, grid_image)
152
+ logger.info(f"Saved grid image to {args.output_path}")
153
+
154
+
155
+ if __name__ == "__main__":
156
+ entrypoint()
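Usage note (illustrative, not part of the commit): a minimal sketch of driving this renderer from Python instead of the CLI, assuming no conflicting command-line flags are present; the input/output paths are placeholders and every other setting falls back to the argparse defaults (6 views, 512x512, 30-degree FOV).

    from asset3d_gen.scripts.render_gs import entrypoint

    # Placeholder paths; any Gaussian-splat .ply produced by the pipeline should work.
    entrypoint(
        input_gs="outputs/sample/gs_model.ply",
        output_path="outputs/sample/gs_color_grid.jpg",
    )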
asset3d_gen/scripts/render_mv.py ADDED
@@ -0,0 +1,185 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ from typing import List, Tuple
5
+
6
+ import fire
7
+ import numpy as np
8
+ import torch
9
+ from diffusers.utils import make_image_grid
10
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
11
+ StableDiffusionXLControlNetImg2ImgPipeline,
12
+ )
13
+ from PIL import Image, ImageEnhance, ImageFilter
14
+ from torchvision import transforms
15
+ from asset3d_gen.data.datasets import Asset3dGenDataset
16
+ from asset3d_gen.models.texture_model import build_texture_gen_pipe
17
+
18
+ os.environ["https_proxy"] = "http://10.9.0.31:8838"
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def get_init_noise_image(image: Image.Image) -> Image.Image:
25
+ blurred_image = image.convert("L").filter(
26
+ ImageFilter.GaussianBlur(radius=3)
27
+ )
28
+
29
+ enhancer = ImageEnhance.Contrast(blurred_image)
30
+ image_decreased_contrast = enhancer.enhance(factor=0.5)
31
+
32
+ return image_decreased_contrast
33
+
34
+
35
+ def infer_pipe(
36
+ index_file: str,
37
+ controlnet_ckpt: str = None,
38
+ uid: str = None,
39
+ prompt: str = None,
40
+ controlnet_cond_scale: float = 0.4,
41
+ control_guidance_end: float = 0.9,
42
+ strength: float = 1.0,
43
+ num_inference_steps: int = 50,
44
+ guidance_scale: float = 10,
45
+ ip_adapt_scale: float = 0,
46
+ ip_img_path: str = None,
47
+ sub_idxs: List[List[int]] = None,
48
+ num_images_per_prompt: int = 3,  # increase if you want more similar images.
49
+ device: str = "cuda",
50
+ save_dir: str = "infer_vis",
51
+ seed: int = None,
52
+ target_hw: tuple[int, int] = (512, 512),
53
+ pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
54
+ ) -> str:
55
+ # sub_idxs = [[0, 1, 2], [3, 4, 5]] # None for single image.
56
+ if sub_idxs is None:
57
+ sub_idxs = [[random.randint(0, 5)]] # 6 views.
58
+ target_hw = [2 * size for size in target_hw]
59
+
60
+ transform_list = [
61
+ transforms.Resize(
62
+ target_hw, interpolation=transforms.InterpolationMode.BILINEAR
63
+ ),
64
+ transforms.CenterCrop(target_hw),
65
+ transforms.ToTensor(),
66
+ transforms.Normalize([0.5], [0.5]),
67
+ ]
68
+ image_transform = transforms.Compose(transform_list)
69
+ control_transform = transforms.Compose(transform_list[:-1])
70
+
71
+ grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
72
+ dataset = Asset3dGenDataset(
73
+ index_file, target_hw=grid_hw, sub_idxs=sub_idxs
74
+ )
75
+
76
+ if uid is None:
77
+ uid = random.choice(list(dataset.meta_info.keys()))
78
+ if prompt is None:
79
+ prompt = dataset.meta_info[uid]["capture"]
80
+ if isinstance(prompt, List) or isinstance(prompt, Tuple):
81
+ prompt = ", ".join(map(str, prompt))
82
+ # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
83
+ # prompt += "高品质,清晰,细节"  # i.e. "high quality, clear, detailed"
84
+ prompt += ", high quality, high resolution, best quality"
85
+ # prompt += ", with diffuse lighting, showing no reflections."
86
+ logger.info(f"Inference with prompt: {prompt}")
87
+
88
+ negative_prompt = (
89
+ # EN gloss: "nsfw, facial shadows, low resolution, jpeg artifacts, blurry, bad quality, dark face, neon lights, highlights, specular reflections"  # noqa
+ "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯,高光,镜面反射"
90
+ )
91
+
92
+ control_image = dataset.fetch_sample_grid_images(
93
+ uid,
94
+ attrs=["image_view_normal", "image_position", "image_mask"],
95
+ sub_idxs=sub_idxs,
96
+ transform=control_transform,
97
+ )
98
+
99
+ color_image = dataset.fetch_sample_grid_images(
100
+ uid,
101
+ attrs=["image_color"],
102
+ sub_idxs=sub_idxs,
103
+ transform=image_transform,
104
+ )
105
+
106
+ normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
107
+ control_image,
108
+ color_image,
109
+ save_dir=save_dir,
110
+ )
111
+
112
+ if pipeline is None:
113
+ pipeline = build_texture_gen_pipe(
114
+ base_ckpt_dir="./weights",
115
+ controlnet_ckpt=controlnet_ckpt,
116
+ ip_adapt_scale=ip_adapt_scale,
117
+ device=device,
118
+ )
119
+
120
+ if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
121
+ ip_image = Image.open(ip_img_path).convert("RGB")
122
+ ip_image = ip_image.resize(target_hw[::-1])
123
+ ip_image = [ip_image]
124
+ pipeline.set_ip_adapter_scale([ip_adapt_scale])
125
+ else:
126
+ ip_image = None
127
+
128
+ generator = None
129
+ if seed is not None:
130
+ generator = torch.Generator(device).manual_seed(seed)
131
+ torch.manual_seed(seed)
132
+ np.random.seed(seed)
133
+ random.seed(seed)
134
+
135
+ init_image = get_init_noise_image(normal_pil)
136
+ # init_image = get_init_noise_image(color_pil)
137
+
138
+ images = []
139
+ row_num, col_num = 2, 3
140
+ img_save_paths = []
141
+ while len(images) < col_num:
142
+ image = pipeline(
143
+ prompt=prompt,
144
+ image=init_image,
145
+ controlnet_conditioning_scale=controlnet_cond_scale,
146
+ control_guidance_end=control_guidance_end,
147
+ strength=strength,
148
+ control_image=control_image[None, ...],
149
+ negative_prompt=negative_prompt,
150
+ num_inference_steps=num_inference_steps,
151
+ guidance_scale=guidance_scale,
152
+ num_images_per_prompt=num_images_per_prompt,
153
+ ip_adapter_image=ip_image,
154
+ generator=generator,
155
+ ).images
156
+ images.extend(image)
157
+
158
+ grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
159
+ # save_dir = os.path.join(save_dir, uid)
160
+ os.makedirs(save_dir, exist_ok=True)
161
+
162
+ for idx in range(col_num):
163
+ rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
164
+ img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
165
+ rgba_image.save(img_save_path)
166
+ img_save_paths.append(img_save_path)
167
+
168
+ sub_idxs = "_".join(
169
+ [str(item) for sublist in sub_idxs for item in sublist]
170
+ )
171
+ save_path = os.path.join(
172
+ save_dir, f"sample_idx{str(sub_idxs)}_ip{ip_adapt_scale}.jpg"
173
+ )
174
+ make_image_grid(grid_image, row_num, col_num).save(save_path)
175
+ logger.info(f"Visualize in {save_path}")
176
+
177
+ return img_save_paths
178
+
179
+
180
+ def entrypoint() -> None:
181
+ fire.Fire(infer_pipe)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ entrypoint()
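Usage note (illustrative, not part of the commit): `infer_pipe` is normally invoked through `fire` on the command line, but it can also be called directly; the index file and seed below are placeholders, and leaving `controlnet_ckpt`/`pipeline` unset assumes the default texture-generation weights resolve under ./weights.

    from asset3d_gen.scripts.render_mv import infer_pipe

    img_save_paths = infer_pipe(
        index_file="data/index.json",        # placeholder dataset index
        sub_idxs=[[0, 1, 2], [3, 4, 5]],     # 2x3 grid covering all six views
        num_inference_steps=30,
        save_dir="infer_vis",
        seed=0,
    )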
asset3d_gen/scripts/text2image.py ADDED
@@ -0,0 +1,145 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+
5
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
6
+ StableDiffusionXLPipeline,
7
+ )
8
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
9
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
10
+ )
11
+ from tqdm import tqdm
12
+ from asset3d_gen.models.text_model import (
13
+ build_text2img_ip_pipeline,
14
+ build_text2img_pipeline,
15
+ text2img_gen,
16
+ )
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def parse_args():
23
+ parser = argparse.ArgumentParser(description="Text to Image.")
24
+ parser.add_argument(
25
+ "--prompts",
26
+ type=str,
27
+ nargs="+",
28
+ help="List of prompts (space-separated).",
29
+ )
30
+ parser.add_argument(
31
+ "--ref_image",
32
+ type=str,
33
+ nargs="+",
34
+ help="List of ref_image paths (space-separated).",
35
+ )
36
+ parser.add_argument(
37
+ "--output_root",
38
+ type=str,
39
+ help="Root directory for saving outputs.",
40
+ )
41
+ parser.add_argument(
42
+ "--guidance_scale",
43
+ type=float,
44
+ default=12.0,
45
+ help="Guidance scale for the diffusion model.",
46
+ )
47
+ parser.add_argument(
48
+ "--ref_scale",
49
+ type=float,
50
+ default=0.3,
51
+ help="Reference image scale for the IP adapter.",
52
+ )
53
+ parser.add_argument(
54
+ "--n_sample",
55
+ type=int,
56
+ default=1,
57
+ )
58
+ parser.add_argument(
59
+ "--resolution",
60
+ type=int,
61
+ default=1024,
62
+ )
63
+ parser.add_argument(
64
+ "--infer_step",
65
+ type=int,
66
+ default=50,
67
+ )
68
+ args = parser.parse_args()
69
+
70
+ return args
71
+
72
+
73
+ def entrypoint(
74
+ pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
75
+ **kwargs,
76
+ ) -> list[str]:
77
+ args = parse_args()
78
+ for k, v in kwargs.items():
79
+ if hasattr(args, k) and v is not None:
80
+ setattr(args, k, v)
81
+
82
+ prompts = args.prompts
83
+ if len(prompts) == 1 and prompts[0].endswith(".txt"):
84
+ with open(prompts[0], "r") as f:
85
+ prompts = f.readlines()
86
+ prompts = [
87
+ prompt.strip() for prompt in prompts if prompt.strip() != ""
88
+ ]
89
+
90
+ os.makedirs(args.output_root, exist_ok=True)
91
+
92
+ ip_img_paths = args.ref_image
93
+ if ip_img_paths is None or len(ip_img_paths) == 0:
94
+ args.ref_scale = 0
95
+ ip_img_paths = [None] * len(prompts)
96
+ elif isinstance(ip_img_paths, str):
97
+ ip_img_paths = [ip_img_paths] * len(prompts)
98
+ elif isinstance(ip_img_paths, list):
99
+ if len(ip_img_paths) == 1:
100
+ ip_img_paths = ip_img_paths * len(prompts)
101
+ else:
102
+ raise ValueError("Invalid ref_image paths.")
103
+ assert len(ip_img_paths) == len(
104
+ prompts
105
+ ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}" # noqa
106
+
107
+ if pipeline is None:
108
+ if args.ref_scale > 0:
109
+ pipeline = build_text2img_ip_pipeline(
110
+ "weights/Kolors",
111
+ ref_scale=args.ref_scale,
112
+ )
113
+ else:
114
+ pipeline = build_text2img_pipeline("weights/Kolors")
115
+
116
+ for idx, (prompt, ip_img_path) in tqdm(
117
+ enumerate(zip(prompts, ip_img_paths)),
118
+ desc="Generating images",
119
+ total=len(prompts),
120
+ ):
121
+ images = text2img_gen(
122
+ prompt=prompt,
123
+ n_sample=args.n_sample,
124
+ guidance_scale=args.guidance_scale,
125
+ pipeline=pipeline,
126
+ ip_image=ip_img_path,
127
+ image_wh=[args.resolution, args.resolution],
128
+ infer_step=args.infer_step,
129
+ )
130
+
131
+ save_paths = []
132
+ for sub_idx, image in enumerate(images):
133
+ save_path = (
134
+ f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
135
+ )
136
+ image.save(save_path)
137
+ save_paths.append(save_path)
138
+
139
+ logger.info(f"Images saved to {args.output_root}")
140
+
141
+ return save_paths
142
+
143
+
144
+ if __name__ == "__main__":
145
+ entrypoint()
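Usage note (illustrative, not part of the commit): the same entrypoint can be called from Python, where keyword arguments override the parsed CLI defaults; the prompt and output directory are placeholders, and with no `ref_image` the plain text-to-image pipeline is built.

    from asset3d_gen.scripts.text2image import entrypoint

    save_paths = entrypoint(
        prompts=["a shiny golden cup with a floral design"],
        output_root="outputs/text2image/demo",  # placeholder output directory
        n_sample=2,
    )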
asset3d_gen/utils/gpt_clients.py ADDED
@@ -0,0 +1,190 @@
1
+ import base64
2
+ import logging
3
+ import os
4
+ from io import BytesIO
5
+ from typing import Optional
6
+
7
+ from openai import AzureOpenAI, OpenAI # pip install openai
8
+ from PIL import Image
9
+ from tenacity import (
10
+ retry,
11
+ stop_after_attempt,
12
+ stop_after_delay,
13
+ wait_random_exponential,
14
+ )
15
+ from asset3d_gen.utils.process_media import combine_images_to_base64
16
+
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+ os.environ["https_proxy"] = "10.9.0.31:8838"
20
+
21
+
22
+ class GPTclient:
23
+ """A client to interact with the GPT model via OpenAI or Azure API."""
24
+
25
+ def __init__(
26
+ self,
27
+ endpoint: str,
28
+ api_key: str,
29
+ model_name: str = "yfb-gpt-4o",
30
+ api_version: str = None,
31
+ verbose: bool = False,
32
+ ):
33
+ if api_version is not None:
34
+ self.client = AzureOpenAI(
35
+ azure_endpoint=endpoint,
36
+ api_key=api_key,
37
+ api_version=api_version,
38
+ )
39
+ else:
40
+ self.client = OpenAI(
41
+ base_url=endpoint,
42
+ api_key=api_key,
43
+ )
44
+
45
+ self.endpoint = endpoint
46
+ self.model_name = model_name
47
+ self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
48
+ self.verbose = verbose
49
+
50
+ @retry(
51
+ wait=wait_random_exponential(min=1, max=20),
52
+ stop=(stop_after_attempt(10) | stop_after_delay(30)),
53
+ )
54
+ def completion_with_backoff(self, **kwargs):
55
+ return self.client.chat.completions.create(**kwargs)
56
+
57
+ def query(
58
+ self,
59
+ text_prompt: str,
60
+ image_base64: Optional[list[str | Image.Image]] = None,
61
+ system_role: Optional[str] = None,
62
+ ) -> Optional[str]:
63
+ """Queries the GPT model with a text and optional image prompts.
64
+
65
+ Args:
66
+ text_prompt (str): The main text input that the model responds to.
67
+ image_base64 (Optional[List[str]]): A list of image base64 strings
68
+ or local image paths or PIL.Image to accompany the text prompt.
69
+ system_role (Optional[str]): Optional system-level instructions
70
+ that specify the behavior of the assistant.
71
+
72
+ Returns:
73
+ Optional[str]: The response content generated by the model based on
74
+ the prompt. Returns `None` if an error occurs.
75
+ """
76
+ if system_role is None:
77
+ system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
78
+
79
+ content_user = [
80
+ {
81
+ "type": "text",
82
+ "text": text_prompt,
83
+ },
84
+ ]
85
+
86
+ # Process images if provided
87
+ if image_base64 is not None:
88
+ image_base64 = (
89
+ image_base64
90
+ if isinstance(image_base64, list)
91
+ else [image_base64]
92
+ )
93
+ for img in image_base64:
94
+ if isinstance(img, Image.Image):
95
+ buffer = BytesIO()
96
+ img.save(buffer, format=img.format or "PNG")
97
+ buffer.seek(0)
98
+ image_binary = buffer.read()
99
+ img = base64.b64encode(image_binary).decode("utf-8")
100
+ elif (
101
+ len(os.path.splitext(img)) > 1
102
+ and os.path.splitext(img)[-1].lower() in self.image_formats
103
+ ):
104
+ if not os.path.exists(img):
105
+ raise FileNotFoundError(f"Image file not found: {img}")
106
+ with open(img, "rb") as f:
107
+ img = base64.b64encode(f.read()).decode("utf-8")
108
+
109
+ content_user.append(
110
+ {
111
+ "type": "image_url",
112
+ "image_url": {"url": f"data:image/png;base64,{img}"},
113
+ }
114
+ )
115
+
116
+ payload = {
117
+ "messages": [
118
+ {"role": "system", "content": system_role},
119
+ {"role": "user", "content": content_user},
120
+ ],
121
+ "temperature": 0.1,
122
+ "max_tokens": 500,
123
+ "top_p": 0.1,
124
+ "frequency_penalty": 0,
125
+ "presence_penalty": 0,
126
+ "stop": None,
127
+ }
128
+ payload.update({"model": self.model_name})
129
+
130
+ response = None
131
+ try:
132
+ response = self.completion_with_backoff(**payload)
133
+ response = response.choices[0].message.content
134
+ except Exception as e:
135
+ logger.error(f"GPTclient API call to {self.endpoint} failed: {e}")
136
+ response = None
137
+
138
+ if self.verbose:
139
+ logger.info(f"Prompt: {text_prompt}")
140
+ logger.info(f"Response: {response}")
141
+
142
+ return response
143
+
144
+
145
+ endpoint = os.environ.get("endpoint", None)
146
+ api_key = os.environ.get("api_key", None)
147
+ api_version = os.environ.get("api_version", None)
148
+ if endpoint and api_key and api_version:
149
+ GPT_CLIENT = GPTclient(
150
+ endpoint=endpoint,
151
+ api_key=api_key,
152
+ api_version=api_version,
153
+ model_name="yfb-gpt-4o-sweden" if "sweden" in endpoint else None,
154
+ )
155
+ else:
156
+ GPT_CLIENT = GPTclient(
157
+ endpoint="https://openrouter.ai/api/v1",
158
+ api_key="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", # noqa
159
+ model_name="qwen/qwen2.5-vl-72b-instruct:free",
160
+ )
161
+
162
+
163
+ if __name__ == "__main__":
164
+ if "openrouter" in GPT_CLIENT.endpoint:
165
+ response = GPT_CLIENT.query(
166
+ text_prompt="What is the content in each image?",
167
+ image_base64=combine_images_to_base64(
168
+ [
169
+ "outputs/text2image/demo_objects/bed/sample_0.jpg",
170
+ "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa
171
+ "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
172
+ ]
173
+ ), # input raw image_path if only one image
174
+ )
175
+ print(response)
176
+ else:
177
+ response = GPT_CLIENT.query(
178
+ text_prompt="What is the content in the images?",
179
+ image_base64=[
180
+ Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
181
+ Image.open(
182
+ "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png" # noqa
183
+ ),
184
+ ],
185
+ )
186
+ print(response)
187
+
188
+ # test2: text prompt
189
+ response = GPT_CLIENT.query(text_prompt="What is the capital of China?")
190
+ print(response)
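Usage note (illustrative, not part of the commit): `query` also accepts a custom `system_role`; a minimal sketch, with a placeholder image path.

    response = GPT_CLIENT.query(
        text_prompt="Estimate a rough mass range (kg) for the object in the image.",
        image_base64=["outputs/text2image/demo_objects/bed/sample_0.jpg"],  # placeholder
        system_role="You are a careful assistant for physical property estimation.",
    )
    print(response)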
asset3d_gen/utils/process_media.py ADDED
@@ -0,0 +1,194 @@
1
+ import base64
2
+ import logging
3
+ import math
4
+ import os
5
+ import subprocess
6
+ from glob import glob
7
+ from io import BytesIO
8
+ from typing import Union
9
+
10
+ import cv2
11
+ import imageio
12
+ import numpy as np
13
+ import PIL.Image as Image
14
+ from moviepy.editor import VideoFileClip, clips_array
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ __all__ = [
21
+ "render_asset3d",
22
+ "merge_images_video",
23
+ "filter_small_connected_components",
24
+ "filter_image_small_connected_components",
25
+ "combine_images_to_base64",
26
+ ]
27
+
28
+
29
+ def render_asset3d(
30
+ mesh_path: str,
31
+ output_root: str,
32
+ distance: float = 5.0,
33
+ num_images: int = 1,
34
+ elevation: list[float] = (0.0,),
35
+ pbr_light_factor: float = 1.5,
36
+ return_key: str = "image_color/*",
37
+ output_subdir: str = "renders",
38
+ gen_color_mp4: bool = False,
39
+ gen_viewnormal_mp4: bool = False,
40
+ gen_glonormal_mp4: bool = False,
41
+ device: str = "cpu",
42
+ ) -> list[str]:
43
+ command = [
44
+ "python3",
45
+ "asset3d_gen/data/differentiable_render.py",
46
+ "--mesh_path",
47
+ mesh_path,
48
+ "--output_root",
49
+ output_root,
50
+ "--uuid",
51
+ output_subdir,
52
+ "--distance",
53
+ str(distance),
54
+ "--num_images",
55
+ str(num_images),
56
+ "--elevation",
57
+ *map(str, elevation),
58
+ "--pbr_light_factor",
59
+ str(pbr_light_factor),
60
+ "--with_mtl",
61
+ "--device",
62
+ device,
63
+ ]
64
+ if gen_color_mp4:
65
+ command.append("--gen_color_mp4")
66
+ if gen_viewnormal_mp4:
67
+ command.append("--gen_viewnormal_mp4")
68
+ if gen_glonormal_mp4:
69
+ command.append("--gen_glonormal_mp4")
70
+ try:
71
+ subprocess.run(command, check=True)
72
+ except subprocess.CalledProcessError as e:
73
+ logger.error(f"Error occurred during rendering: {e}.")
74
+
75
+ dst_paths = glob(os.path.join(output_root, output_subdir, return_key))
76
+
77
+ return dst_paths
78
+
79
+
80
+ def merge_images_video(color_images, normal_images, output_path) -> None:
81
+ width = color_images[0].shape[1]
82
+ combined_video = [
83
+ np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
84
+ for rgb_img, normal_img in zip(color_images, normal_images)
85
+ ]
86
+ imageio.mimsave(output_path, combined_video, fps=50)
87
+
88
+ return
89
+
90
+
91
+ def merge_video_video(
92
+ video_path1: str, video_path2: str, output_path: str
93
+ ) -> None:
94
+ """Merge two videos by the left half and the right half of the videos."""
95
+ clip1 = VideoFileClip(video_path1)
96
+ clip2 = VideoFileClip(video_path2)
97
+
98
+ if clip1.size != clip2.size:
99
+ raise ValueError("The resolutions of the two videos do not match.")
100
+
101
+ width, height = clip1.size
102
+ clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
103
+ clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
104
+ final_clip = clips_array([[clip1_half, clip2_half]])
105
+ final_clip.write_videofile(output_path, codec="libx264")
106
+
107
+
108
+ def filter_small_connected_components(
109
+ mask: Union[Image.Image, np.ndarray],
110
+ area_ratio: float,
111
+ connectivity: int = 8,
112
+ ) -> np.ndarray:
113
+ if isinstance(mask, Image.Image):
114
+ mask = np.array(mask)
115
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
116
+ mask,
117
+ connectivity=connectivity,
118
+ )
119
+
120
+ small_components = np.zeros_like(mask, dtype=np.uint8)
121
+ mask_area = (mask != 0).sum()
122
+ min_area = mask_area // area_ratio
123
+ for label in range(1, num_labels):
124
+ area = stats[label, cv2.CC_STAT_AREA]
125
+ if area < min_area:
126
+ small_components[labels == label] = 255
127
+
128
+ mask = cv2.bitwise_and(mask, cv2.bitwise_not(small_components))
129
+
130
+ return mask
131
+
132
+
133
+ def filter_image_small_connected_components(
134
+ image: Union[Image.Image, np.ndarray],
135
+ area_ratio: float = 10,
136
+ connectivity: int = 8,
137
+ ) -> np.ndarray:
138
+ if isinstance(image, Image.Image):
139
+ image = image.convert("RGBA")
140
+ image = np.array(image)
141
+
142
+ mask = image[..., 3]
143
+ mask = filter_small_connected_components(mask, area_ratio, connectivity)
144
+ image[..., 3] = mask
145
+
146
+ return image
147
+
148
+
149
+ def combine_images_to_base64(
150
+ images: list[str | Image.Image],
151
+ cat_row_col: tuple[int, int] = None,
152
+ target_wh: tuple[int, int] = (512, 512),
153
+ ) -> str:
154
+ n_images = len(images)
155
+ if cat_row_col is None:
156
+ n_col = math.ceil(math.sqrt(n_images))
157
+ n_row = math.ceil(n_images / n_col)
158
+ else:
159
+ n_row, n_col = cat_row_col
160
+
161
+ images = [
162
+ Image.open(p).convert("RGB") if isinstance(p, str) else p
163
+ for p in images[: n_row * n_col]
164
+ ]
165
+ images = [img.resize(target_wh) for img in images]
166
+
167
+ grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
168
+ grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
169
+
170
+ for idx, img in enumerate(images):
171
+ row, col = divmod(idx, n_col)
172
+ grid.paste(img, (col * target_wh[0], row * target_wh[1]))
173
+
174
+ buffer = BytesIO()
175
+ grid.save(buffer, format="PNG")
176
+
177
+ return base64.b64encode(buffer.getvalue()).decode("utf-8")
178
+
179
+
180
+ if __name__ == "__main__":
181
+ # Example usage:
182
+ merge_video_video(
183
+ "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
184
+ "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
185
+ "merge.mp4",
186
+ )
187
+
188
+ image_base64 = combine_images_to_base64(
189
+ [
190
+ "outputs/text2image/demo_objects/bed/sample_0.jpg",
191
+ "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa
192
+ "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
193
+ ]
194
+ )
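Usage note (illustrative, not part of the commit): `combine_images_to_base64` returns a base64-encoded PNG, so the grid produced in the block above can be decoded back into a PIL image as a quick sanity check; with three inputs and the default `target_wh` of (512, 512) the grid is 2x2 tiles, i.e. 1024x1024 pixels.

    import base64
    from io import BytesIO
    from PIL import Image

    grid = Image.open(BytesIO(base64.b64decode(image_base64)))
    print(grid.size)  # (1024, 1024) for three 512x512 tiles placed on a 2x2 grid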
asset3d_gen/utils/tags.py ADDED
@@ -0,0 +1 @@
1
+ VERSION = "v0.0.2"
asset3d_gen/validators/aesthetic_predictor.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+
3
+ import clip
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ import torch.nn as nn
8
+ from huggingface_hub import snapshot_download
9
+ from PIL import Image
10
+
11
+ os.environ["https_proxy"] = "http://10.9.0.31:8838"
12
+
13
+
14
+ class AestheticPredictor:
15
+ """Aesthetic Score Predictor.
16
+
17
+ Args:
18
+ clip_model_dir (str): Path to the directory of the CLIP model.
19
+ sac_model_path (str): Path to the pre-trained SAC model.
20
+ device (str): Device to use for computation ("cuda" or "cpu").
21
+ """
22
+
23
+ def __init__(self, clip_model_dir=None, sac_model_path=None, device=None):
24
+
25
+ self.device = device or (
26
+ "cuda" if torch.cuda.is_available() else "cpu"
27
+ )
28
+
29
+ if clip_model_dir is None:
+ suffix = "aesthetic"
+ model_path = snapshot_download(
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
+ )
+ clip_model_dir = os.path.join(model_path, suffix)
+
+ if sac_model_path is None:
+ suffix = "aesthetic"
+ model_path = snapshot_download(
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
+ )
+ sac_model_path = os.path.join(
+ model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth"
+ )
50
+
51
+ self.clip_model, self.preprocess = self._load_clip_model(
52
+ clip_model_dir
53
+ )
54
+ self.sac_model = self._load_sac_model(sac_model_path, input_size=768)
55
+
56
+ class MLP(pl.LightningModule): # noqa
57
+ def __init__(self, input_size):
58
+ super().__init__()
59
+ self.layers = nn.Sequential(
60
+ nn.Linear(input_size, 1024),
61
+ nn.Dropout(0.2),
62
+ nn.Linear(1024, 128),
63
+ nn.Dropout(0.2),
64
+ nn.Linear(128, 64),
65
+ nn.Dropout(0.1),
66
+ nn.Linear(64, 16),
67
+ nn.Linear(16, 1),
68
+ )
69
+
70
+ def forward(self, x):
71
+ return self.layers(x)
72
+
73
+ @staticmethod
74
+ def normalized(a, axis=-1, order=2):
75
+ """Normalize the array to unit norm."""
76
+ l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
77
+ l2[l2 == 0] = 1
78
+ return a / np.expand_dims(l2, axis)
79
+
80
+ def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
81
+ """Load the CLIP model."""
82
+ model, preprocess = clip.load(
83
+ model_name, download_root=model_dir, device=self.device
84
+ )
85
+ return model, preprocess
86
+
87
+ def _load_sac_model(self, model_path, input_size):
88
+ """Load the SAC model."""
89
+ model = self.MLP(input_size)
90
+ ckpt = torch.load(model_path)
91
+ model.load_state_dict(ckpt)
92
+ model.to(self.device)
93
+ model.eval()
94
+ return model
95
+
96
+ def predict(self, image_path):
97
+ """Predict the aesthetic score for a given image.
98
+
99
+ Args:
100
+ image_path (str): Path to the image file.
101
+
102
+ Returns:
103
+ float: Predicted aesthetic score.
104
+ """
105
+ pil_image = Image.open(image_path)
106
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
107
+
108
+ with torch.no_grad():
109
+ # Extract CLIP features
110
+ image_features = self.clip_model.encode_image(image)
111
+ # Normalize features
112
+ normalized_features = self.normalized(
113
+ image_features.cpu().detach().numpy()
114
+ )
115
+ # Predict score
116
+ prediction = self.sac_model(
117
+ torch.from_numpy(normalized_features)
118
+ .type(torch.FloatTensor)
119
+ .to(self.device)
120
+ )
121
+
122
+ return prediction.item()
123
+
124
+
125
+ if __name__ == "__main__":
126
+ # Configuration
127
+ img_path = "/home/users/xinjie.wang/xinjie/asset3d-gen/outputs/imageto3d/demo_objects/bed/sample_0/sample_0_raw.png" # noqa
128
+ # clip_model_dir = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/clip" # noqa
129
+ # sac_model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/sac/sac+logos+ava1-l14-linearMSE.pth" # noqa
130
+
131
+ # Initialize the predictor
132
+ predictor = AestheticPredictor()
133
+
134
+ # Predict the aesthetic score
135
+ score = predictor.predict(img_path)
136
+ print("Aesthetic score predicted by the model:", score)
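Usage note (illustrative, not part of the commit): a sketch of scoring a batch of renders and averaging them, which is essentially what `ImageAestheticChecker` in quality_checkers.py does; the render paths are placeholders.

    predictor = AestheticPredictor()  # weights are fetched from the hub on first use
    render_paths = ["renders/image_color/000.png", "renders/image_color/001.png"]  # placeholders
    scores = [predictor.predict(p) for p in render_paths]
    print("mean aesthetic score:", sum(scores) / len(scores))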
asset3d_gen/validators/quality_checkers.py ADDED
@@ -0,0 +1,195 @@
1
+ import logging
2
+ import os
3
+
4
+ from tqdm import tqdm
5
+ from asset3d_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
6
+ from asset3d_gen.utils.process_media import render_asset3d
7
+ from asset3d_gen.validators.aesthetic_predictor import AestheticPredictor
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class BaseChecker:
14
+ def __init__(self, prompt: str = None, verbose: bool = False) -> None:
15
+ self.prompt = prompt
16
+ self.verbose = verbose
17
+
18
+ def query(self, *args, **kwargs):
19
+ raise NotImplementedError(
20
+ "Subclasses must implement the query method."
21
+ )
22
+
23
+ def __call__(self, *args, **kwargs) -> bool:
24
+ response = self.query(*args, **kwargs)
25
+ if response is None:
26
+ response = "Error when calling gpt api."
27
+
28
+ if self.verbose and response != "YES":
29
+ logger.info(response)
30
+
31
+ flag = "YES" in response
32
+ response = "YES" if flag else response
33
+
34
+ return flag, response
35
+
36
+ @staticmethod
37
+ def validate(
38
+ checkers: list["BaseChecker"], images_list: list[list[str]]
39
+ ) -> list:
40
+ assert len(checkers) == len(images_list)
41
+ results = []
42
+ overall_result = True
43
+ for checker, images in zip(checkers, images_list):
44
+ qa_flag, qa_info = checker(images)
45
+ if isinstance(qa_info, str):
46
+ qa_info = qa_info.replace("\n", ".")
47
+ results.append([checker.__class__.__name__, qa_info])
48
+ if qa_flag is False:
49
+ overall_result = False
50
+
51
+ results.append(["overall", "YES" if overall_result else "NO"])
52
+
53
+ return results
54
+
55
+
56
+ class MeshGeoChecker(BaseChecker):
57
+ def __init__(
58
+ self,
59
+ gpt_client: GPTclient,
60
+ prompt: str = None,
61
+ verbose: bool = False,
62
+ ) -> None:
63
+ super().__init__(prompt, verbose)
64
+ self.gpt_client = gpt_client
65
+ if self.prompt is None:
66
+ self.prompt = """
67
+ Refer to the provided multi-view rendering images to evaluate
68
+ whether the geometry of the 3D object asset is complete and
69
+ whether the asset can be placed stably on the ground.
70
+ Return "YES" only if it meets the requirements,
71
+ otherwise "NO" and explain the reason very briefly.
72
+ """
73
+
74
+ def query(self, image_paths: str) -> str:
75
+ # Hardcode tmp because of the openrouter can't input multi images.
76
+ if "openrouter" in self.gpt_client.endpoint:
77
+ from asset3d_gen.utils.process_media import (
78
+ combine_images_to_base64,
79
+ )
80
+
81
+ image_paths = combine_images_to_base64(image_paths)
82
+
83
+ return self.gpt_client.query(
84
+ text_prompt=self.prompt,
85
+ image_base64=image_paths,
86
+ )
87
+
88
+
89
+ class ImageSegChecker(BaseChecker):
90
+ def __init__(
91
+ self,
92
+ gpt_client: GPTclient,
93
+ prompt: str = None,
94
+ verbose: bool = False,
95
+ ) -> None:
96
+ super().__init__(prompt, verbose)
97
+ self.gpt_client = gpt_client
98
+ if self.prompt is None:
99
+ self.prompt = """
100
+ The first image is the original, and the second image is the
101
+ result after segmenting the main object. Evaluate the segmentation
102
+ quality to ensure the main object is clearly segmented without
103
+ significant truncation. Note that the foreground of the object
104
+ needs to be extracted instead of the background.
105
+ Minor imperfections can be ignored. If segmentation is acceptable,
106
+ return "YES" only; otherwise, return "NO" with
107
+ very brief explanation.
108
+ """
109
+
110
+ def query(self, image_paths: list[str]) -> str:
111
+ if len(image_paths) != 2:
112
+ raise ValueError(
113
+ "ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
114
+ )
115
+ # Hardcode tmp because of the openrouter can't input multi images.
116
+ if "openrouter" in self.gpt_client.endpoint:
117
+ from asset3d_gen.utils.process_media import (
118
+ combine_images_to_base64,
119
+ )
120
+
121
+ image_paths = combine_images_to_base64(image_paths)
122
+
123
+ return self.gpt_client.query(
124
+ text_prompt=self.prompt,
125
+ image_base64=image_paths,
126
+ )
127
+
128
+
129
+ class ImageAestheticChecker(BaseChecker):
130
+ def __init__(
131
+ self,
132
+ clip_model_dir: str = None,
133
+ sac_model_path: str = None,
134
+ thresh: float = 4.50,
135
+ verbose: bool = False,
136
+ ) -> None:
137
+ super().__init__(verbose=verbose)
138
+ self.clip_model_dir = clip_model_dir
139
+ self.sac_model_path = sac_model_path
140
+ self.thresh = thresh
141
+ self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)
142
+
143
+ def query(self, image_paths: list[str]) -> float:
144
+ scores = [self.predictor.predict(img_path) for img_path in image_paths]
145
+ return sum(scores) / len(scores)
146
+
147
+ def __call__(self, image_paths: list[str], **kwargs) -> bool:
148
+ avg_score = self.query(image_paths)
149
+ if self.verbose:
150
+ logger.info(f"Average aesthetic score: {avg_score}")
151
+ return avg_score > self.thresh, avg_score
152
+
153
+
154
+ if __name__ == "__main__":
155
+ geo_checker = MeshGeoChecker(GPT_CLIENT)
156
+ seg_checker = ImageSegChecker(GPT_CLIENT)
157
+ aesthetic_checker = ImageAestheticChecker(
158
+ "/horizon-bucket/robot_lab/users/xinjie.wang/weights/clip",
159
+ "/horizon-bucket/robot_lab/users/xinjie.wang/weights/sac/sac+logos+ava1-l14-linearMSE.pth", # noqa
160
+ )
161
+
162
+ checkers = [geo_checker, seg_checker, aesthetic_checker]
163
+
164
+ output_root = "outputs/test_gpt"
165
+
166
+ fails = []
167
+ for idx in tqdm(range(150)):
168
+ mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
169
+ if not os.path.exists(mesh_path):
170
+ continue
171
+ image_paths = render_asset3d(
172
+ mesh_path,
173
+ f"{output_root}/{idx}",
174
+ num_images=8,
175
+ elevation=(30, -30),
176
+ distance=5.5,
177
+ )
178
+
179
+ for cid, checker in enumerate(checkers):
180
+ if isinstance(checker, ImageSegChecker):
181
+ images = [
182
+ f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
183
+ f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
184
+ ]
185
+ else:
186
+ images = image_paths
187
+ result, info = checker(images)
188
+ logger.info(
189
+ f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
190
+ )
191
+
192
+ if result is False:
193
+ fails.append((idx, cid, info))
194
+
195
+ break
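Usage note (illustrative, not part of the commit): `BaseChecker.validate` runs several checkers in one pass and appends an overall verdict; a sketch using the checkers built in the block above, where `raw_image_path` and `seg_image_path` are placeholders for the pair expected by `ImageSegChecker`.

    results = BaseChecker.validate(
        checkers=[geo_checker, seg_checker, aesthetic_checker],
        images_list=[
            image_paths,                        # multi-view renders for the geometry check
            [raw_image_path, seg_image_path],   # placeholder raw/segmented image pair
            image_paths,                        # renders scored for aesthetics
        ],
    )
    # results: [[checker_class_name, info], ..., ["overall", "YES" or "NO"]]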
asset3d_gen/validators/urdf_convertor.py ADDED
@@ -0,0 +1,423 @@
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import xml.etree.ElementTree as ET
5
+ import zipfile
6
+ from datetime import datetime
7
+ from xml.dom.minidom import parseString
8
+
9
+ import numpy as np
10
+ import trimesh
11
+ from asset3d_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
12
+ from asset3d_gen.utils.process_media import render_asset3d
13
+ from asset3d_gen.utils.tags import VERSION
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ __all__ = ["URDFGenerator"]
20
+
21
+
22
+ URDF_TEMPLATE = """
23
+ <robot name="template_robot">
24
+ <link name="template_link">
25
+ <visual>
26
+ <geometry>
27
+ <mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
28
+ </geometry>
29
+ </visual>
30
+ <collision>
31
+ <geometry>
32
+ <mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
33
+ </geometry>
34
+ <gazebo>
35
+ <mu1>0.8</mu1> <!-- primary friction coefficient -->
36
+ <mu2>0.6</mu2> <!-- secondary friction coefficient -->
37
+ </gazebo>
38
+ </collision>
39
+ <inertial>
40
+ <mass value="1.0"/>
41
+ <origin xyz="0 0 0"/>
42
+ <inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
43
+ </inertial>
44
+ <extra_info>
45
+ <scale>1.0</scale>
46
+ <version>"0.0.0"</version>
47
+ <category>"unknown"</category>
48
+ <description>"unknown"</description>
49
+ <min_height>0.0</min_height>
50
+ <max_height>0.0</max_height>
51
+ <real_height>0.0</real_height>
52
+ <min_mass>0.0</min_mass>
53
+ <max_mass>0.0</max_mass>
54
+ <generate_time>"-1"</generate_time>
55
+ <gs_model>""</gs_model>
56
+ </extra_info>
57
+ </link>
58
+ </robot>
59
+ """
60
+
61
+
62
+ def zip_files(input_paths: list[str], output_zip: str) -> str:
63
+ with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
64
+ for input_path in input_paths:
65
+ if not os.path.exists(input_path):
66
+ raise FileNotFoundError(f"File not found: {input_path}")
67
+
68
+ if os.path.isdir(input_path):
69
+ for root, _, files in os.walk(input_path):
70
+ for file in files:
71
+ file_path = os.path.join(root, file)
72
+ arcname = os.path.relpath(
73
+ file_path, start=os.path.commonpath(input_paths)
74
+ )
75
+ zipf.write(file_path, arcname=arcname)
76
+ else:
77
+ arcname = os.path.relpath(
78
+ input_path, start=os.path.commonpath(input_paths)
79
+ )
80
+ zipf.write(input_path, arcname=arcname)
81
+
82
+ return output_zip
83
+
84
+
85
+ class URDFGenerator(object):
86
+ def __init__(
87
+ self,
88
+ gpt_client: GPTclient,
89
+ mesh_file_list: list[str] = ["material_0.png", "material.mtl"],
90
+ prompt_template: str = None,
91
+ attrs_name: list[str] = None,
92
+ render_dir: str = "urdf_renders",
93
+ render_view_num: int = 4,
94
+ ) -> None:
95
+ if mesh_file_list is None:
96
+ mesh_file_list = []
97
+ self.mesh_file_list = mesh_file_list
98
+ self.output_mesh_dir = "mesh"
99
+ self.output_render_dir = render_dir
100
+ self.gpt_client = gpt_client
101
+ self.render_view_num = render_view_num
102
+ if render_view_num == 4:
103
+ view_desc = "This is orthographic projection showing the front, left, right and back views " # noqa
104
+ else:
105
+ view_desc = "This is the rendered views "
106
+
107
+ if prompt_template is None:
108
+ prompt_template = (
109
+ view_desc
110
+ + """of the 3D object asset,
111
+ category: {category}.
112
+ Give the category of this object asset (within 3 words),
113
+ (if category is already provided, use it directly),
114
+ accurately describe this 3D object asset (within 15 words),
115
+ and give the recommended geometric height range (unit: meter),
116
+ weight range (unit: kilogram), the average static friction
117
+ coefficient of the object relative to rubber and the average
118
+ dynamic friction coefficient of the object relative to rubber.
119
+ Return response format as shown in Example.
120
+
121
+ Example:
122
+ Category: cup
123
+ Description: shiny golden cup with floral design
124
+ Height: 0.1-0.15 m
125
+ Weight: 0.3-0.6 kg
126
+ Static friction coefficient: 1.1
127
+ Dynamic friction coefficient: 0.9
128
+ """
129
+ )
130
+
131
+ self.prompt_template = prompt_template
132
+ if attrs_name is None:
133
+ attrs_name = [
134
+ "category",
135
+ "description",
136
+ "min_height",
137
+ "max_height",
138
+ "real_height",
139
+ "min_mass",
140
+ "max_mass",
141
+ "version",
142
+ "generate_time",
143
+ "gs_model",
144
+ ]
145
+ self.attrs_name = attrs_name
146
+
147
+ def parse_response(self, response: str) -> dict[str, any]:
148
+ lines = response.split("\n")
149
+ lines = [line.strip() for line in lines if line]
150
+ category = lines[0].split(": ")[1]
151
+ description = lines[1].split(": ")[1]
152
+ min_height, max_height = map(
153
+ lambda x: float(x.strip().replace(",", "").split()[0]),
154
+ lines[2].split(": ")[1].split("-"),
155
+ )
156
+ min_mass, max_mass = map(
157
+ lambda x: float(x.strip().replace(",", "").split()[0]),
158
+ lines[3].split(": ")[1].split("-"),
159
+ )
160
+ mu1 = float(lines[4].split(": ")[1].replace(",", ""))
161
+ mu2 = float(lines[5].split(": ")[1].replace(",", ""))
162
+
163
+ return {
164
+ "category": category.lower(),
165
+ "description": description.lower(),
166
+ "min_height": round(min_height, 4),
167
+ "max_height": round(max_height, 4),
168
+ "real_height": round((min_height + max_height) / 2, 4),
169
+ "min_mass": round(min_mass, 4),
170
+ "max_mass": round(max_mass, 4),
171
+ "mu1": round(mu1, 2),
172
+ "mu2": round(mu2, 2),
173
+ "version": VERSION,
174
+ "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
175
+ }
176
+
177
+ def generate_urdf(
178
+ self,
179
+ input_mesh: str,
180
+ output_dir: str,
181
+ attr_dict: dict,
182
+ output_name: str = None,
183
+ ) -> str:
184
+ """Generate a URDF file for a given mesh with specified attributes.
185
+
186
+ Args:
187
+ input_mesh (str): Path to the input mesh file.
188
+ output_dir (str): Directory to store the generated URDF
189
+ and processed mesh.
190
+ attr_dict (dict): Dictionary containing attributes like height,
191
+ mass, and friction coefficients.
192
+ output_name (str, optional): Name for the generated URDF and robot.
193
+
194
+ Returns:
195
+ str: Path to the generated URDF file.
196
+ """
197
+
198
+ # 1. Load and normalize the mesh
199
+ mesh = trimesh.load(input_mesh)
200
+ mesh_scale = np.ptp(mesh.vertices, axis=0).max()
201
+ mesh.vertices /= mesh_scale # Normalize to [-0.5, 0.5]
202
+ raw_height = np.ptp(mesh.vertices, axis=0)[1]
203
+
204
+ # 2. Scale the mesh to real height
205
+ real_height = attr_dict["real_height"]
206
+ scale = round(real_height / raw_height, 6)
207
+ mesh = mesh.apply_scale(scale)
208
+
209
+ # 3. Prepare output directories and save scaled mesh
210
+ mesh_folder = os.path.join(output_dir, self.output_mesh_dir)
211
+ os.makedirs(mesh_folder, exist_ok=True)
212
+
213
+ obj_name = os.path.basename(input_mesh)
214
+ mesh_output_path = os.path.join(mesh_folder, obj_name)
215
+ mesh.export(mesh_output_path)
216
+
217
+ # 4. Copy additional mesh files, if any
218
+ input_dir = os.path.dirname(input_mesh)
219
+ for file in self.mesh_file_list:
220
+ src_file = os.path.join(input_dir, file)
221
+ dest_file = os.path.join(mesh_folder, file)
222
+ if os.path.isfile(src_file):
223
+ shutil.copy(src_file, dest_file)
224
+
225
+ # 5. Determine output name
226
+ if output_name is None:
227
+ output_name = os.path.splitext(obj_name)[0]
228
+
229
+ # 6. Load URDF template and update attributes
230
+ robot = ET.fromstring(URDF_TEMPLATE)
231
+ robot.set("name", output_name)
232
+
233
+ link = robot.find("link")
234
+ if link is None:
235
+ raise ValueError("URDF template is missing 'link' element.")
236
+ link.set("name", output_name)
237
+
238
+ # Update visual geometry
239
+ visual = link.find("visual/geometry/mesh")
240
+ if visual is not None:
241
+ visual.set(
242
+ "filename", os.path.join(self.output_mesh_dir, obj_name)
243
+ )
244
+ visual.set("scale", "1.0 1.0 1.0")
245
+
246
+ # Update collision geometry
247
+ collision = link.find("collision/geometry/mesh")
248
+ if collision is not None:
249
+ collision.set(
250
+ "filename", os.path.join(self.output_mesh_dir, obj_name)
251
+ )
252
+ collision.set("scale", "1.0 1.0 1.0")
253
+
254
+ # Update friction coefficients
255
+ gazebo = link.find("collision/gazebo")
256
+ if gazebo is not None:
257
+ for param, key in zip(["mu1", "mu2"], ["mu1", "mu2"]):
258
+ element = gazebo.find(param)
259
+ if element is not None:
260
+ element.text = f"{attr_dict[key]:.2f}"
261
+
262
+ # Update mass
263
+ inertial = link.find("inertial/mass")
264
+ if inertial is not None:
265
+ mass_value = (attr_dict["min_mass"] + attr_dict["max_mass"]) / 2
266
+ inertial.set("value", f"{mass_value:.4f}")
267
+
268
+ # Update the extra_info fields of the link
269
+ extra_info = link.find("extra_info/scale")
270
+ if extra_info is not None:
271
+ extra_info.text = f"{scale:.6f}"
272
+
273
+ for key in self.attrs_name:
274
+ extra_info = link.find(f"extra_info/{key}")
275
+ if extra_info is not None and key in attr_dict:
276
+ extra_info.text = f"{attr_dict[key]}"
277
+
278
+ # 7. Write URDF to file
279
+ os.makedirs(output_dir, exist_ok=True)
280
+ urdf_path = os.path.join(output_dir, f"{output_name}.urdf")
281
+ tree = ET.ElementTree(robot)
282
+ tree.write(urdf_path, encoding="utf-8", xml_declaration=True)
283
+
284
+ logger.info(f"URDF file saved to {urdf_path}")
285
+
286
+ return urdf_path
287
+
288
+ @staticmethod
289
+ def get_attr_from_urdf(
290
+ urdf_path: str,
291
+ attr_root: str = ".//link/extra_info",
292
+ attr_name: str = "scale",
293
+ ) -> float:
294
+ if not os.path.exists(urdf_path):
295
+ raise FileNotFoundError(f"URDF file not found: {urdf_path}")
296
+
297
+ mesh_scale = 1.0
298
+ tree = ET.parse(urdf_path)
299
+ root = tree.getroot()
300
+ extra_info = root.find(attr_root)
301
+ if extra_info is not None:
302
+ scale_element = extra_info.find(attr_name)
303
+ if scale_element is not None:
304
+ mesh_scale = float(scale_element.text)
305
+
306
+ return mesh_scale
307
+
308
+ @staticmethod
309
+ def add_quality_tag(
310
+ urdf_path: str, results, output_path: str = None
311
+ ) -> None:
312
+ if output_path is None:
313
+ output_path = urdf_path
314
+
315
+ tree = ET.parse(urdf_path)
316
+ root = tree.getroot()
317
+ custom_data = ET.SubElement(root, "custom_data")
318
+ quality = ET.SubElement(custom_data, "quality")
319
+ for key, value in results:
320
+ checker_tag = ET.SubElement(quality, key)
321
+ checker_tag.text = str(value)
322
+
323
+ rough_string = ET.tostring(root, encoding="utf-8")
324
+ formatted_string = parseString(rough_string).toprettyxml(indent=" ")
325
+ cleaned_string = "\n".join(
326
+ [line for line in formatted_string.splitlines() if line.strip()]
327
+ )
328
+
329
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
330
+ with open(output_path, "w", encoding="utf-8") as f:
331
+ f.write(cleaned_string)
332
+
333
+ logger.info(f"URDF files saved to {output_path}")
334
+
335
+ def get_estimated_attributes(self, asset_attrs: dict):
336
+ estimated_attrs = {
337
+ "height": round(
338
+ (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
339
+ ),
340
+ "mass": round(
341
+ (asset_attrs["min_mass"] + asset_attrs["max_mass"]) / 2, 4
342
+ ),
343
+ "mu": round((asset_attrs["mu1"] + asset_attrs["mu2"]) / 2, 4),
344
+ "category": asset_attrs["category"],
345
+ }
346
+
347
+ return estimated_attrs
348
+
349
+ def __call__(
350
+ self,
351
+ mesh_path: str,
352
+ output_root: str,
353
+ text_prompt: str = None,
354
+ category: str = "unknown",
355
+ **kwargs,
356
+ ):
357
+ if text_prompt is None or len(text_prompt) == 0:
358
+ text_prompt = self.prompt_template
359
+ text_prompt = text_prompt.format(category=category.lower())
360
+
361
+ image_path = render_asset3d(
362
+ mesh_path,
363
+ output_root,
364
+ num_images=self.render_view_num,
365
+ output_subdir=self.output_render_dir,
366
+ )
367
+
368
+ # Hardcode tmp because of the openrouter can't input multi images.
369
+ if "openrouter" in self.gpt_client.endpoint:
370
+ from asset3d_gen.utils.process_media import (
371
+ combine_images_to_base64,
372
+ )
373
+
374
+ image_path = combine_images_to_base64(image_path)
375
+
376
+ response = self.gpt_client.query(text_prompt, image_path)
377
+ if response is None:
378
+ asset_attrs = {
379
+ "category": "unknown",
380
+ "description": "unknown",
381
+ "min_height": 1,
382
+ "max_height": 1,
383
+ "real_height": 1,
384
+ "min_mass": 1,
385
+ "max_mass": 1,
386
+ "mu1": 0.8,
387
+ "mu2": 0.6,
388
+ "version": VERSION,
389
+ "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
390
+ }
391
+ else:
392
+ asset_attrs = self.parse_response(response)
393
+ for key in self.attrs_name:
394
+ if key in kwargs:
395
+ asset_attrs[key] = kwargs[key]
396
+
397
+ self.estimated_attrs = self.get_estimated_attributes(asset_attrs)
398
+
399
+ urdf_path = self.generate_urdf(mesh_path, output_root, asset_attrs)
400
+
401
+ logger.info(f"response: {response}")
402
+
403
+ return urdf_path
404
+
405
+
406
+ if __name__ == "__main__":
407
+ urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
408
+ urdf_path = urdf_gen(
409
+ mesh_path="scripts/apps/assets/example_texture/meshes/robot.obj",
410
+ output_root="outputs/test_urdf",
411
+ # category="coffee machine",
412
+ # min_height=1.0,
413
+ # max_height=1.2,
414
+ version=VERSION,
415
+ )
416
+
417
+ # zip_files(
418
+ # input_paths=[
419
+ # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
420
+ # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/sample.urdf"
421
+ # ],
422
+ # output_zip="zip.zip"
423
+ # )
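Usage note (illustrative, not part of the commit): `parse_response` expects the reply format shown in the prompt template above; a sketch feeding it the template's own example, with values rounded exactly as in the code.

    example_response = (
        "Category: cup\n"
        "Description: shiny golden cup with floral design\n"
        "Height: 0.1-0.15 m\n"
        "Weight: 0.3-0.6 kg\n"
        "Static friction coefficient: 1.1\n"
        "Dynamic friction coefficient: 0.9"
    )
    attrs = URDFGenerator(GPT_CLIENT).parse_response(example_response)
    # attrs["category"] == "cup", attrs["real_height"] == 0.125,
    # attrs["min_mass"] == 0.3, attrs["mu1"] == 1.1, attrs["mu2"] == 0.9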
common.py ADDED
@@ -0,0 +1,597 @@
1
+ import gc
2
+ import logging
3
+ import os
4
+ import sys
5
+ from glob import glob
6
+ from typing import Union
7
+
8
+ import cv2
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+ import trimesh
13
+ from easydict import EasyDict as edict
14
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
15
+ StableDiffusionXLPipeline,
16
+ )
17
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
18
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
19
+ )
20
+ from PIL import Image
21
+ from tqdm import tqdm
22
+ from asset3d_gen.data.backproject_v2 import entrypoint as backproject_api
23
+ from asset3d_gen.models.delight import DelightingModel
24
+ from asset3d_gen.models.gs_model import GaussianOperator
25
+ from asset3d_gen.models.segment import (
26
+ RembgRemover,
27
+ SAMPredictor,
28
+ trellis_preprocess,
29
+ )
30
+ from asset3d_gen.models.super_resolution import ImageRealESRGAN, ImageStableSR
31
+ from asset3d_gen.scripts.render_gs import entrypoint as render_gs_api
32
+ from asset3d_gen.scripts.text2image import text2img_gen
33
+ from asset3d_gen.utils.process_media import (
34
+ filter_image_small_connected_components,
35
+ merge_images_video,
36
+ render_asset3d,
37
+ )
38
+ from asset3d_gen.utils.tags import VERSION
39
+ from asset3d_gen.validators.quality_checkers import (
40
+ BaseChecker,
41
+ ImageAestheticChecker,
42
+ ImageSegChecker,
43
+ MeshGeoChecker,
44
+ )
45
+ from asset3d_gen.validators.urdf_convertor import URDFGenerator, zip_files
46
+
47
+ current_file_path = os.path.abspath(__file__)
48
+ current_dir = os.path.dirname(current_file_path)
49
+ sys.path.append(os.path.join(current_dir, "../.."))
50
+ from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
51
+ from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
52
+ from thirdparty.TRELLIS.trellis.representations import (
53
+ Gaussian,
54
+ MeshExtractResult,
55
+ )
56
+ from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
57
+ from thirdparty.TRELLIS.trellis.utils.render_utils import (
58
+ render_frames,
59
+ yaw_pitch_r_fov_to_extrinsics_intrinsics,
60
+ )
61
+ import spaces
62
+
63
+
64
+ logging.basicConfig(
65
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
66
+ )
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ MAX_SEED = 100000
71
+
72
+
73
+ @spaces.GPU
74
+ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
75
+ renderer = MeshRenderer()
76
+ renderer.rendering_options.resolution = options.get("resolution", 512)
77
+ renderer.rendering_options.near = options.get("near", 1)
78
+ renderer.rendering_options.far = options.get("far", 100)
79
+ renderer.rendering_options.ssaa = options.get("ssaa", 4)
80
+ rets = {}
81
+ for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
82
+ res = renderer.render(sample, extr, intr)
83
+ if "normal" not in rets:
84
+ rets["normal"] = []
85
+ normal = torch.lerp(
86
+ torch.zeros_like(res["normal"]), res["normal"], res["mask"]
87
+ )
88
+ normal = np.clip(
89
+ normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
90
+ ).astype(np.uint8)
91
+ rets["normal"].append(normal)
92
+
93
+ return rets
94
+
95
+
96
+ @spaces.GPU
97
+ def render_video(
98
+ sample,
99
+ resolution=512,
100
+ bg_color=(0, 0, 0),
101
+ num_frames=300,
102
+ r=2,
103
+ fov=40,
104
+ **kwargs,
105
+ ):
106
+ yaws = torch.linspace(0, 2 * 3.1415, num_frames)
107
+ yaws = yaws.tolist()
108
+ pitch = [0.5] * num_frames
109
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
110
+ yaws, pitch, r, fov
111
+ )
112
+ render_fn = (
113
+ render_mesh if isinstance(sample, MeshExtractResult) else render_frames
114
+ )
115
+ result = render_fn(
116
+ sample,
117
+ extrinsics,
118
+ intrinsics,
119
+ {"resolution": resolution, "bg_color": bg_color},
120
+ **kwargs,
121
+ )
122
+
123
+ return result
124
+
125
+
126
+ @spaces.GPU
127
+ def preprocess_image_fn(
128
+ image: str | np.ndarray | Image.Image,
129
+ model: DelightingModel | RembgRemover,
130
+ buffer: dict = None,
131
+ ) -> Image.Image:
132
+ if isinstance(image, str):
133
+ image = Image.open(image)
134
+ elif isinstance(image, np.ndarray):
135
+ image = Image.fromarray(image)
136
+
137
+ if buffer is not None:
138
+ buffer["raw_image"] = image
139
+
140
+ if isinstance(model, DelightingModel):
141
+ image = model(image, preprocess=True, target_wh=(512, 512))
142
+ elif isinstance(model, RembgRemover):
143
+ image = model(image)
144
+ image = trellis_preprocess(image)
145
+
146
+ return image
147
+
148
+
149
+ @spaces.GPU
150
+ def preprocess_sam_image_fn(
151
+ image: Image.Image, buffer: dict, model: SAMPredictor
152
+ ) -> Image.Image:
153
+ if isinstance(image, np.ndarray):
154
+ image = Image.fromarray(image)
155
+
156
+ buffer["raw_image"] = image
157
+ sam_image = model.preprocess_image(image)
158
+ model.predictor.set_image(sam_image)
159
+
160
+ return sam_image
161
+
162
+
163
+ def active_btn_by_content(content: gr.Image) -> gr.Button:
164
+ interactive = content is not None
165
+
166
+ return gr.Button(interactive=interactive)
167
+
168
+
169
+ def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
170
+ if content is not None and len(content) > 0:
171
+ interactive = True
172
+ else:
173
+ interactive = False
174
+
175
+ return gr.Button(interactive=interactive)
176
+
177
+
178
+ def get_selected_image(
179
+ choice: str, sample1: str, sample2: str, sample3: str
180
+ ) -> str:
181
+ if choice == "sample1":
182
+ return sample1
183
+ elif choice == "sample2":
184
+ return sample2
185
+ elif choice == "sample3":
186
+ return sample3
187
+ else:
188
+ raise ValueError(f"Invalid choice: {choice}")
189
+
190
+
191
+ @spaces.GPU
192
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
193
+ return {
194
+ "gaussian": {
195
+ **gs.init_params,
196
+ "_xyz": gs._xyz.cpu().numpy(),
197
+ "_features_dc": gs._features_dc.cpu().numpy(),
198
+ "_scaling": gs._scaling.cpu().numpy(),
199
+ "_rotation": gs._rotation.cpu().numpy(),
200
+ "_opacity": gs._opacity.cpu().numpy(),
201
+ },
202
+ "mesh": {
203
+ "vertices": mesh.vertices.cpu().numpy(),
204
+ "faces": mesh.faces.cpu().numpy(),
205
+ },
206
+ }
207
+
208
+
209
+ @spaces.GPU
210
+ def unpack_state(state: dict) -> tuple[Gaussian, edict]:
211
+ gs = Gaussian(
212
+ aabb=state["gaussian"]["aabb"],
213
+ sh_degree=state["gaussian"]["sh_degree"],
214
+ mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
215
+ scaling_bias=state["gaussian"]["scaling_bias"],
216
+ opacity_bias=state["gaussian"]["opacity_bias"],
217
+ scaling_activation=state["gaussian"]["scaling_activation"],
218
+ )
219
+ gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device="cuda")
220
+ gs._features_dc = torch.tensor(
221
+ state["gaussian"]["_features_dc"], device="cuda"
222
+ )
223
+ gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device="cuda")
224
+ gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device="cuda")
225
+ gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device="cuda")
226
+
227
+ mesh = edict(
228
+ vertices=torch.tensor(state["mesh"]["vertices"], device="cuda"),
229
+ faces=torch.tensor(state["mesh"]["faces"], device="cuda"),
230
+ )
231
+
232
+ return gs, mesh
233
+
234
+
235
+ def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
236
+ return np.random.randint(0, max_seed) if randomize_seed else seed
237
+
238
+
239
+ @spaces.GPU
240
+ def select_point(
241
+ image: np.ndarray,
242
+ sel_pix: list,
243
+ point_type: str,
244
+ model: SAMPredictor,
245
+ evt: gr.SelectData,
246
+ ):
247
+ if point_type == "foreground_point":
248
+ sel_pix.append((evt.index, 1)) # append the foreground_point
249
+ elif point_type == "background_point":
250
+ sel_pix.append((evt.index, 0)) # append the background_point
251
+ else:
252
+ sel_pix.append((evt.index, 1)) # default foreground_point
253
+
254
+ masks = model.generate_masks(image, sel_pix)
255
+ seg_image = model.get_segmented_image(image, masks)
256
+
257
+ for point, label in sel_pix:
258
+ color = (255, 0, 0) if label == 0 else (0, 255, 0)
259
+ marker_type = 1 if label == 0 else 5
260
+ cv2.drawMarker(
261
+ image,
262
+ point,
263
+ color,
264
+ markerType=marker_type,
265
+ markerSize=15,
266
+ thickness=10,
267
+ )
268
+
269
+ torch.cuda.empty_cache()
270
+
271
+ return (image, masks), seg_image
272
+
273
+
274
+ @spaces.GPU
275
+ def image_to_3d(
276
+ image: Image.Image,
277
+ seed: int,
278
+ ss_guidance_strength: float,
279
+ ss_sampling_steps: int,
280
+ slat_guidance_strength: float,
281
+ slat_sampling_steps: int,
282
+ buffer: dict,
283
+ pipeline: TrellisImageTo3DPipeline,
284
+ output_root: str,
285
+ sam_image: Image.Image = None,
286
+ is_sam_image: bool = False,
287
+ req: gr.Request = None,
288
+ ) -> tuple[dict, str]:
289
+ if is_sam_image:
290
+ seg_image = filter_image_small_connected_components(sam_image)
291
+ seg_image = Image.fromarray(seg_image, mode="RGBA")
292
+ seg_image = trellis_preprocess(seg_image)
293
+ # seg_image.save(f"{TMP_DIR}/seg_image_sam.png")
294
+ else:
295
+ seg_image = image
296
+
297
+ if isinstance(seg_image, np.ndarray):
298
+ seg_image = Image.fromarray(seg_image)
299
+ buffer["seg_image"] = seg_image
300
+
301
+ pipeline.cuda()
302
+ outputs = pipeline.run(
303
+ seg_image,
304
+ seed=seed,
305
+ formats=["gaussian", "mesh"],
306
+ preprocess_image=False,
307
+ sparse_structure_sampler_params={
308
+ "steps": ss_sampling_steps,
309
+ "cfg_strength": ss_guidance_strength,
310
+ },
311
+ slat_sampler_params={
312
+ "steps": slat_sampling_steps,
313
+ "cfg_strength": slat_guidance_strength,
314
+ },
315
+ )
316
+ # Move the pipeline back to CPU to free GPU memory.
317
+ pipeline.cpu()
318
+
319
+ gs_model = outputs["gaussian"][0]
320
+ mesh_model = outputs["mesh"][0]
321
+ color_images = render_video(gs_model)["color"]
322
+ normal_images = render_video(mesh_model)["normal"]
323
+ if req is not None:
324
+ output_root = os.path.join(output_root, str(req.session_hash))
325
+ video_path = os.path.join(output_root, "gs_mesh.mp4")
326
+ merge_images_video(color_images, normal_images, video_path)
327
+ state = pack_state(gs_model, mesh_model)
328
+
329
+ gc.collect()
330
+ torch.cuda.empty_cache()
331
+
332
+ return state, video_path
333
+
334
+
335
+ @spaces.GPU
336
+ def extract_3d_representations(
337
+ state: dict, enable_delight: bool, output_root: str, req: gr.Request
338
+ ):
339
+ user_dir = os.path.join(output_root, str(req.session_hash))
340
+ gs_model, mesh_model = unpack_state(state)
341
+
342
+ mesh = postprocessing_utils.to_glb(
343
+ gs_model,
344
+ mesh_model,
345
+ simplify=0.9,
346
+ texture_size=1024,
347
+ verbose=True,
348
+ )
349
+ filename = "sample"
350
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
351
+ gs_model.save_ply(gs_path)
352
+
353
+ # Rotate mesh and GS by 90 degrees around the Y-axis.
354
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
355
+ # Additional rotation for GS to align with the mesh.
356
+ gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
357
+ rot_matrix
358
+ )
359
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
360
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
361
+ GaussianOperator.resave_ply(
362
+ in_ply=gs_path,
363
+ out_ply=aligned_gs_path,
364
+ instance_pose=pose,
365
+ )
366
+
367
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
368
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
369
+ mesh.export(mesh_obj_path)
370
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
371
+ mesh.export(mesh_glb_path)
372
+
373
+ torch.cuda.empty_cache()
374
+
375
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
376
+
377
+
378
+ @spaces.GPU
379
+ def extract_3d_representations_v2(
380
+ state: dict,
381
+ enable_delight: bool,
382
+ output_root: str,
383
+ delight_model: DelightingModel,
384
+ sr_model: Union[ImageRealESRGAN, ImageStableSR],
385
+ req: gr.Request,
386
+ ):
387
+ user_dir = os.path.join(output_root, str(req.session_hash))
388
+ gs_model, mesh_model = unpack_state(state)
389
+
390
+ filename = "sample"
391
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
392
+ gs_model.save_ply(gs_path)
393
+
394
+ # Rotate mesh and GS by 90 degrees around the Y-axis.
395
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
396
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
397
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
398
+
399
+ # Additional rotation for GS to align with the mesh.
400
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
401
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
402
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
403
+ GaussianOperator.resave_ply(
404
+ in_ply=gs_path,
405
+ out_ply=aligned_gs_path,
406
+ instance_pose=pose,
407
+ )
408
+ color_path = os.path.join(user_dir, "color.png")
409
+ render_gs_api(aligned_gs_path, color_path)
410
+
411
+ mesh = trimesh.Trimesh(
412
+ vertices=mesh_model.vertices.cpu().numpy(),
413
+ faces=mesh_model.faces.cpu().numpy(),
414
+ )
415
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
416
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
417
+
418
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
419
+ mesh.export(mesh_obj_path)
420
+
421
+ mesh = backproject_api(
422
+ delight_model=delight_model,
423
+ imagesr_model=sr_model,
424
+ color_path=color_path,
425
+ mesh_path=mesh_obj_path,
426
+ output_path=mesh_obj_path,
427
+ skip_fix_mesh=False,
428
+ delight=enable_delight,
429
+ )
430
+
431
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
432
+ mesh.export(mesh_glb_path)
433
+
434
+ torch.cuda.empty_cache()
435
+
436
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
437
+
438
+
439
+ @spaces.GPU
440
+ def extract_urdf(
441
+ gs_path: str,
442
+ mesh_obj_path: str,
443
+ asset_cat_text: str,
444
+ height_range_text: str,
445
+ mass_range_text: str,
446
+ asset_version_text: str,
447
+ output_root: str,
448
+ urdf_convertor: URDFGenerator,
449
+ buffer: dict,
450
+ checkers: list[BaseChecker],
451
+ req: gr.Request = None,
452
+ ):
453
+ if req is not None:
454
+ output_root = os.path.join(output_root, str(req.session_hash))
455
+ # Convert to URDF and recover asset attributes via GPT-4o.
456
+ filename = "sample"
457
+ asset_attrs = {
458
+ "version": VERSION,
459
+ "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
460
+ }
461
+ if asset_version_text:
462
+ asset_attrs["version"] = asset_version_text
463
+ if asset_cat_text:
464
+ asset_attrs["category"] = asset_cat_text.lower()
465
+ if height_range_text:
466
+ try:
467
+ min_height, max_height = map(float, height_range_text.split("-"))
468
+ asset_attrs["min_height"] = min_height
469
+ asset_attrs["max_height"] = max_height
470
+ except ValueError:
471
+ return "Invalid height input format. Use the format: min-max."
472
+ if mass_range_text:
473
+ try:
474
+ min_mass, max_mass = map(float, mass_range_text.split("-"))
475
+ asset_attrs["min_mass"] = min_mass
476
+ asset_attrs["max_mass"] = max_mass
477
+ except ValueError:
478
+ return "Invalid mass input format. Use the format: min-max."
479
+
480
+ urdf_path = urdf_convertor(
481
+ mesh_path=mesh_obj_path,
482
+ output_root=f"{output_root}/URDF_{filename}",
483
+ **asset_attrs,
484
+ )
485
+
486
+ # Rescale GS and save to URDF/mesh folder.
487
+ real_height = urdf_convertor.get_attr_from_urdf(
488
+ urdf_path, attr_name="real_height"
489
+ )
490
+ out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
491
+ GaussianOperator.resave_ply(
492
+ in_ply=gs_path,
493
+ out_ply=out_gs,
494
+ real_height=real_height,
495
+ )
496
+
497
+ # Quality check and update .urdf file.
498
+ mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
499
+ trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
500
+ # image_paths = render_asset3d(
501
+ # mesh_path=mesh_out,
502
+ # output_root=f"{output_root}/URDF_{filename}",
503
+ # output_subdir="qa_renders",
504
+ # num_images=8,
505
+ # elevation=(30, -30),
506
+ # distance=5.5,
507
+ # )
508
+
509
+ image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
510
+ image_paths = glob(f"{image_dir}/*.png")
511
+ images_list = []
512
+ for checker in checkers:
513
+ images = image_paths
514
+ if isinstance(checker, ImageSegChecker):
515
+ images = [buffer["raw_image"], buffer["seg_image"]]
516
+ images_list.append(images)
517
+
518
+ results = BaseChecker.validate(checkers, images_list)
519
+ urdf_convertor.add_quality_tag(urdf_path, results)
520
+
521
+ # Zip urdf files
522
+ urdf_zip = zip_files(
523
+ input_paths=[
524
+ f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
525
+ f"{output_root}/URDF_{filename}/{filename}.urdf",
526
+ ],
527
+ output_zip=f"{output_root}/urdf_{filename}.zip",
528
+ )
529
+
530
+ torch.cuda.empty_cache()
531
+
532
+ estimated_type = urdf_convertor.estimated_attrs["category"]
533
+ estimated_height = urdf_convertor.estimated_attrs["height"]
534
+ estimated_mass = urdf_convertor.estimated_attrs["mass"]
535
+ estimated_mu = urdf_convertor.estimated_attrs["mu"]
536
+
537
+ return (
538
+ urdf_zip,
539
+ estimated_type,
540
+ estimated_height,
541
+ estimated_mass,
542
+ estimated_mu,
543
+ )
544
+
545
+
546
+ @spaces.GPU
547
+ def text2image_fn(
548
+ prompt: str,
549
+ output_root: str,
550
+ guidance_scale: float,
551
+ model_ip: StableDiffusionXLPipelineIP,
552
+ model_img: StableDiffusionXLPipeline,
553
+ bg_model: RembgRemover,
554
+ infer_step: int = 50,
555
+ ip_image: Image.Image | str = None,
556
+ ip_adapt_scale: float = 0.3,
557
+ image_wh: int | tuple[int, int] = (1024, 1024),
558
+ n_sample: int = 3,
559
+ postprocess: bool = True,
560
+ req: gr.Request = None,
561
+ ):
562
+ if isinstance(image_wh, int):
563
+ image_wh = (image_wh, image_wh)
564
+ if req is not None:
565
+ output_root = os.path.join(output_root, str(req.session_hash))
566
+ os.makedirs(output_root, exist_ok=True)
567
+
568
+ pipeline = model_img if ip_image is None else model_ip
569
+ if ip_image is not None:
570
+ pipeline.set_ip_adapter_scale([ip_adapt_scale])
571
+
572
+ images = text2img_gen(
573
+ prompt=prompt,
574
+ n_sample=n_sample,
575
+ guidance_scale=guidance_scale,
576
+ pipeline=pipeline,
577
+ ip_image=ip_image,
578
+ image_wh=image_wh,
579
+ infer_step=infer_step,
580
+ )
581
+ if postprocess:
582
+ for idx in range(len(images)):
583
+ image = images[idx]
584
+ images[idx] = preprocess_image_fn(image, bg_model)
585
+
586
+ save_paths = []
587
+ for idx, image in enumerate(images):
588
+ save_path = f"{output_root}/sample_{idx}.png"
589
+ image.save(save_path)
590
+ save_paths.append(save_path)
591
+
592
+ logger.info(f"Images saved to {output_root}")
593
+
594
+ gc.collect()
595
+ torch.cuda.empty_cache()
596
+
597
+ return save_paths + save_paths
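Note (added for review, not part of the commit): the handlers above are wired together by Gradio callbacks elsewhere in app.py. The sketch below shows the intended call order for the image-to-URDF path under stated assumptions: the heavy models (TRELLIS pipeline, background remover, delighting, super-resolution, URDF generator, quality checkers) are assumed to be constructed at start-up, the sampler settings are example values only, and a SimpleNamespace stands in for the gr.Request object that Gradio normally injects.

# Hypothetical wiring sketch of the handlers defined above (not part of the commit).
from types import SimpleNamespace

def run_image_to_urdf_sketch(
    image_path: str,
    pipeline,        # assumed TrellisImageTo3DPipeline instance
    bg_model,        # assumed RembgRemover instance
    delight_model,   # assumed DelightingModel instance
    sr_model,        # assumed ImageRealESRGAN or ImageStableSR instance
    urdf_convertor,  # assumed URDFGenerator instance
    checkers,        # assumed list of BaseChecker instances
    output_root: str = "outputs",
):
    buffer = {}
    req = SimpleNamespace(session_hash="local")  # stand-in for gr.Request
    seg_image = preprocess_image_fn(image_path, bg_model, buffer)
    state, _video = image_to_3d(
        seg_image,
        seed=0,
        ss_guidance_strength=7.5,   # example values, not tuned defaults
        ss_sampling_steps=12,
        slat_guidance_strength=3.0,
        slat_sampling_steps=12,
        buffer=buffer,
        pipeline=pipeline,
        output_root=output_root,
        req=req,
    )
    _glb, _gs, mesh_obj, aligned_gs = extract_3d_representations_v2(
        state, True, output_root, delight_model, sr_model, req
    )
    return extract_urdf(
        aligned_gs, mesh_obj, "", "", "", "", output_root,
        urdf_convertor, buffer, checkers, req=req,
    )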
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu118
2
 
3
- torch==2.1.0
4
- torchaudio==2.1.0
5
- torchvision==0.16.0
6
- xformers==0.0.22.post7
7
  dataclasses_json
8
  easydict
9
  opencv-python>4.5
@@ -21,7 +21,7 @@ openai==1.58.1
21
  spconv-cu118==2.3.8
22
  transformers==4.42.4
23
  gradio_litmodel3d==0.0.1
24
- # gradio==5.12.0
25
  sentencepiece==0.2.0
26
  diffusers==0.31.0
27
  xatlas==0.0.9
@@ -33,6 +33,7 @@ basicsr==1.4.2
33
  realesrgan==0.3.0
34
  pydantic==2.9.2
35
  vtk==9.3.1
 
36
  utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
37
  clip@git+https://github.com/openai/CLIP.git
38
  kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
 
1
+ # --extra-index-url https://download.pytorch.org/whl/cu118
2
 
3
+ torch==2.1.0+cu118
4
+ torchaudio==2.1.0+cu118
5
+ torchvision==0.16.0+cu118
6
+ xformers==0.0.22.post7+cu118
7
  dataclasses_json
8
  easydict
9
  opencv-python>4.5
 
21
  spconv-cu118==2.3.8
22
  transformers==4.42.4
23
  gradio_litmodel3d==0.0.1
24
+ gradio==5.12.0
25
  sentencepiece==0.2.0
26
  diffusers==0.31.0
27
  xatlas==0.0.9
 
33
  realesrgan==0.3.0
34
  pydantic==2.9.2
35
  vtk==9.3.1
36
+ spaces
37
  utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
38
  clip@git+https://github.com/openai/CLIP.git
39
  kolors@git+https://github.com/Kwai-Kolors/Kolors.git#egg=038818d
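Note on the dependency change above: the `+cu118` local-version pins normally resolve only from the PyTorch CUDA 11.8 wheel index (the `--extra-index-url` line that this commit comments out), so the installed build is worth verifying at runtime. A minimal sanity check, assuming torch installed successfully:

# Quick check that the pinned CUDA 11.8 build of torch is the one actually installed.
import torch

print(torch.__version__)          # expected to end in "+cu118"
print(torch.version.cuda)         # expected "11.8"
print(torch.cuda.is_available())  # True only with a compatible GPU and driver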
thirdparty/TRELLIS/trellis/trellis/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from . import models
2
+ from . import modules
3
+ from . import pipelines
4
+ from . import renderers
5
+ from . import representations
6
+ from . import utils
thirdparty/TRELLIS/trellis/trellis/models/__init__.py ADDED
@@ -0,0 +1,70 @@
1
+ import importlib
2
+
3
+ __attributes = {
4
+ 'SparseStructureEncoder': 'sparse_structure_vae',
5
+ 'SparseStructureDecoder': 'sparse_structure_vae',
6
+ 'SparseStructureFlowModel': 'sparse_structure_flow',
7
+ 'SLatEncoder': 'structured_latent_vae',
8
+ 'SLatGaussianDecoder': 'structured_latent_vae',
9
+ 'SLatRadianceFieldDecoder': 'structured_latent_vae',
10
+ 'SLatMeshDecoder': 'structured_latent_vae',
11
+ 'SLatFlowModel': 'structured_latent_flow',
12
+ }
13
+
14
+ __submodules = []
15
+
16
+ __all__ = list(__attributes.keys()) + __submodules
17
+
18
+ def __getattr__(name):
19
+ if name not in globals():
20
+ if name in __attributes:
21
+ module_name = __attributes[name]
22
+ module = importlib.import_module(f".{module_name}", __name__)
23
+ globals()[name] = getattr(module, name)
24
+ elif name in __submodules:
25
+ module = importlib.import_module(f".{name}", __name__)
26
+ globals()[name] = module
27
+ else:
28
+ raise AttributeError(f"module {__name__} has no attribute {name}")
29
+ return globals()[name]
30
+
31
+
32
+ def from_pretrained(path: str, **kwargs):
33
+ """
34
+ Load a model from a pretrained checkpoint.
35
+
36
+ Args:
37
+ path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
38
+ NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
39
+ **kwargs: Additional arguments for the model constructor.
40
+ """
41
+ import os
42
+ import json
43
+ from safetensors.torch import load_file
44
+ is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
45
+
46
+ if is_local:
47
+ config_file = f"{path}.json"
48
+ model_file = f"{path}.safetensors"
49
+ else:
50
+ from huggingface_hub import hf_hub_download
51
+ path_parts = path.split('/')
52
+ repo_id = f'{path_parts[0]}/{path_parts[1]}'
53
+ model_name = '/'.join(path_parts[2:])
54
+ config_file = hf_hub_download(repo_id, f"{model_name}.json")
55
+ model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
56
+
57
+ with open(config_file, 'r') as f:
58
+ config = json.load(f)
59
+ model = __getattr__(config['name'])(**config['args'], **kwargs)
60
+ model.load_state_dict(load_file(model_file))
61
+
62
+ return model
63
+
64
+
65
+ # For Pylance
66
+ if __name__ == '__main__':
67
+ from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder
68
+ from .sparse_structure_flow import SparseStructureFlowModel
69
+ from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder
70
+ from .structured_latent_flow import SLatFlowModel
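Note (illustrative, not part of the commit): the `__getattr__` hook above resolves model classes lazily, and `from_pretrained` expects a path such that `{path}.json` and `{path}.safetensors` exist, either locally or inside a Hugging Face repo. A minimal usage sketch, following the import style used in app.py; the checkpoint path is a placeholder:

# Usage sketch for the lazy model registry above; "ckpts/my_model" is a placeholder path.
from thirdparty.TRELLIS.trellis import models

# Local checkpoint: expects ckpts/my_model.json and ckpts/my_model.safetensors on disk.
model = models.from_pretrained("ckpts/my_model")

# Lazy attribute access: the class module is imported only on first use.
flow_cls = models.SparseStructureFlowModel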
thirdparty/TRELLIS/trellis/trellis/models/sparse_structure_flow.py ADDED
@@ -0,0 +1,200 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8
+ from ..modules.spatial import patchify, unpatchify
9
+
10
+
11
+ class TimestepEmbedder(nn.Module):
12
+ """
13
+ Embeds scalar timesteps into vector representations.
14
+ """
15
+ def __init__(self, hidden_size, frequency_embedding_size=256):
16
+ super().__init__()
17
+ self.mlp = nn.Sequential(
18
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
19
+ nn.SiLU(),
20
+ nn.Linear(hidden_size, hidden_size, bias=True),
21
+ )
22
+ self.frequency_embedding_size = frequency_embedding_size
23
+
24
+ @staticmethod
25
+ def timestep_embedding(t, dim, max_period=10000):
26
+ """
27
+ Create sinusoidal timestep embeddings.
28
+
29
+ Args:
30
+ t: a 1-D Tensor of N indices, one per batch element.
31
+ These may be fractional.
32
+ dim: the dimension of the output.
33
+ max_period: controls the minimum frequency of the embeddings.
34
+
35
+ Returns:
36
+ an (N, D) Tensor of positional embeddings.
37
+ """
38
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
39
+ half = dim // 2
40
+ freqs = torch.exp(
41
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
42
+ ).to(device=t.device)
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ return embedding
48
+
49
+ def forward(self, t):
50
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
51
+ t_emb = self.mlp(t_freq)
52
+ return t_emb
53
+
54
+
55
+ class SparseStructureFlowModel(nn.Module):
56
+ def __init__(
57
+ self,
58
+ resolution: int,
59
+ in_channels: int,
60
+ model_channels: int,
61
+ cond_channels: int,
62
+ out_channels: int,
63
+ num_blocks: int,
64
+ num_heads: Optional[int] = None,
65
+ num_head_channels: Optional[int] = 64,
66
+ mlp_ratio: float = 4,
67
+ patch_size: int = 2,
68
+ pe_mode: Literal["ape", "rope"] = "ape",
69
+ use_fp16: bool = False,
70
+ use_checkpoint: bool = False,
71
+ share_mod: bool = False,
72
+ qk_rms_norm: bool = False,
73
+ qk_rms_norm_cross: bool = False,
74
+ ):
75
+ super().__init__()
76
+ self.resolution = resolution
77
+ self.in_channels = in_channels
78
+ self.model_channels = model_channels
79
+ self.cond_channels = cond_channels
80
+ self.out_channels = out_channels
81
+ self.num_blocks = num_blocks
82
+ self.num_heads = num_heads or model_channels // num_head_channels
83
+ self.mlp_ratio = mlp_ratio
84
+ self.patch_size = patch_size
85
+ self.pe_mode = pe_mode
86
+ self.use_fp16 = use_fp16
87
+ self.use_checkpoint = use_checkpoint
88
+ self.share_mod = share_mod
89
+ self.qk_rms_norm = qk_rms_norm
90
+ self.qk_rms_norm_cross = qk_rms_norm_cross
91
+ self.dtype = torch.float16 if use_fp16 else torch.float32
92
+
93
+ self.t_embedder = TimestepEmbedder(model_channels)
94
+ if share_mod:
95
+ self.adaLN_modulation = nn.Sequential(
96
+ nn.SiLU(),
97
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
98
+ )
99
+
100
+ if pe_mode == "ape":
101
+ pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
102
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
103
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
104
+ pos_emb = pos_embedder(coords)
105
+ self.register_buffer("pos_emb", pos_emb)
106
+
107
+ self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
108
+
109
+ self.blocks = nn.ModuleList([
110
+ ModulatedTransformerCrossBlock(
111
+ model_channels,
112
+ cond_channels,
113
+ num_heads=self.num_heads,
114
+ mlp_ratio=self.mlp_ratio,
115
+ attn_mode='full',
116
+ use_checkpoint=self.use_checkpoint,
117
+ use_rope=(pe_mode == "rope"),
118
+ share_mod=share_mod,
119
+ qk_rms_norm=self.qk_rms_norm,
120
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
121
+ )
122
+ for _ in range(num_blocks)
123
+ ])
124
+
125
+ self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
126
+
127
+ self.initialize_weights()
128
+ if use_fp16:
129
+ self.convert_to_fp16()
130
+
131
+ @property
132
+ def device(self) -> torch.device:
133
+ """
134
+ Return the device of the model.
135
+ """
136
+ return next(self.parameters()).device
137
+
138
+ def convert_to_fp16(self) -> None:
139
+ """
140
+ Convert the torso of the model to float16.
141
+ """
142
+ self.blocks.apply(convert_module_to_f16)
143
+
144
+ def convert_to_fp32(self) -> None:
145
+ """
146
+ Convert the torso of the model to float32.
147
+ """
148
+ self.blocks.apply(convert_module_to_f32)
149
+
150
+ def initialize_weights(self) -> None:
151
+ # Initialize transformer layers:
152
+ def _basic_init(module):
153
+ if isinstance(module, nn.Linear):
154
+ torch.nn.init.xavier_uniform_(module.weight)
155
+ if module.bias is not None:
156
+ nn.init.constant_(module.bias, 0)
157
+ self.apply(_basic_init)
158
+
159
+ # Initialize timestep embedding MLP:
160
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
161
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
162
+
163
+ # Zero-out adaLN modulation layers in DiT blocks:
164
+ if self.share_mod:
165
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
166
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
167
+ else:
168
+ for block in self.blocks:
169
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
170
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
171
+
172
+ # Zero-out output layers:
173
+ nn.init.constant_(self.out_layer.weight, 0)
174
+ nn.init.constant_(self.out_layer.bias, 0)
175
+
176
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
177
+ assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
178
+ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
179
+
180
+ h = patchify(x, self.patch_size)
181
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
182
+
183
+ h = self.input_layer(h)
184
+ h = h + self.pos_emb[None]
185
+ t_emb = self.t_embedder(t)
186
+ if self.share_mod:
187
+ t_emb = self.adaLN_modulation(t_emb)
188
+ t_emb = t_emb.type(self.dtype)
189
+ h = h.type(self.dtype)
190
+ cond = cond.type(self.dtype)
191
+ for block in self.blocks:
192
+ h = block(h, t_emb, cond)
193
+ h = h.type(x.dtype)
194
+ h = F.layer_norm(h, h.shape[-1:])
195
+ h = self.out_layer(h)
196
+
197
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
198
+ h = unpatchify(h, self.patch_size).contiguous()
199
+
200
+ return h
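Note (illustrative, not part of the commit): `TimestepEmbedder.timestep_embedding` is the standard sinusoidal embedding used by DiT-style models. The standalone sketch below mirrors the static method so the frequency layout is easy to inspect:

# Standalone mirror of the sinusoidal timestep embedding defined above.
import numpy as np
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/max_period.
    freqs = torch.exp(-np.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # odd dims get a zero pad, as in the method above
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

print(sinusoidal_embedding(torch.tensor([0.0, 10.0, 500.0]), dim=256).shape)  # torch.Size([3, 256])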
thirdparty/TRELLIS/trellis/trellis/models/sparse_structure_vae.py ADDED
@@ -0,0 +1,306 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
6
+ from ..modules.spatial import pixel_shuffle_3d
7
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+
9
+
10
+ def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
11
+ """
12
+ Return a normalization layer.
13
+ """
14
+ if norm_type == "group":
15
+ return GroupNorm32(32, *args, **kwargs)
16
+ elif norm_type == "layer":
17
+ return ChannelLayerNorm32(*args, **kwargs)
18
+ else:
19
+ raise ValueError(f"Invalid norm type {norm_type}")
20
+
21
+
22
+ class ResBlock3d(nn.Module):
23
+ def __init__(
24
+ self,
25
+ channels: int,
26
+ out_channels: Optional[int] = None,
27
+ norm_type: Literal["group", "layer"] = "layer",
28
+ ):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.out_channels = out_channels or channels
32
+
33
+ self.norm1 = norm_layer(norm_type, channels)
34
+ self.norm2 = norm_layer(norm_type, self.out_channels)
35
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
37
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ h = self.norm1(x)
41
+ h = F.silu(h)
42
+ h = self.conv1(h)
43
+ h = self.norm2(h)
44
+ h = F.silu(h)
45
+ h = self.conv2(h)
46
+ h = h + self.skip_connection(x)
47
+ return h
48
+
49
+
50
+ class DownsampleBlock3d(nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_channels: int,
54
+ out_channels: int,
55
+ mode: Literal["conv", "avgpool"] = "conv",
56
+ ):
57
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
58
+
59
+ super().__init__()
60
+ self.in_channels = in_channels
61
+ self.out_channels = out_channels
62
+
63
+ if mode == "conv":
64
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
65
+ elif mode == "avgpool":
66
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ if hasattr(self, "conv"):
70
+ return self.conv(x)
71
+ else:
72
+ return F.avg_pool3d(x, 2)
73
+
74
+
75
+ class UpsampleBlock3d(nn.Module):
76
+ def __init__(
77
+ self,
78
+ in_channels: int,
79
+ out_channels: int,
80
+ mode: Literal["conv", "nearest"] = "conv",
81
+ ):
82
+ assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
83
+
84
+ super().__init__()
85
+ self.in_channels = in_channels
86
+ self.out_channels = out_channels
87
+
88
+ if mode == "conv":
89
+ self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
90
+ elif mode == "nearest":
91
+ assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ if hasattr(self, "conv"):
95
+ x = self.conv(x)
96
+ return pixel_shuffle_3d(x, 2)
97
+ else:
98
+ return F.interpolate(x, scale_factor=2, mode="nearest")
99
+
100
+
101
+ class SparseStructureEncoder(nn.Module):
102
+ """
103
+ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
104
+
105
+ Args:
106
+ in_channels (int): Channels of the input.
107
+ latent_channels (int): Channels of the latent representation.
108
+ num_res_blocks (int): Number of residual blocks at each resolution.
109
+ channels (List[int]): Channels of the encoder blocks.
110
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
111
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
112
+ use_fp16 (bool): Whether to use FP16.
113
+ """
114
+ def __init__(
115
+ self,
116
+ in_channels: int,
117
+ latent_channels: int,
118
+ num_res_blocks: int,
119
+ channels: List[int],
120
+ num_res_blocks_middle: int = 2,
121
+ norm_type: Literal["group", "layer"] = "layer",
122
+ use_fp16: bool = False,
123
+ ):
124
+ super().__init__()
125
+ self.in_channels = in_channels
126
+ self.latent_channels = latent_channels
127
+ self.num_res_blocks = num_res_blocks
128
+ self.channels = channels
129
+ self.num_res_blocks_middle = num_res_blocks_middle
130
+ self.norm_type = norm_type
131
+ self.use_fp16 = use_fp16
132
+ self.dtype = torch.float16 if use_fp16 else torch.float32
133
+
134
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
135
+
136
+ self.blocks = nn.ModuleList([])
137
+ for i, ch in enumerate(channels):
138
+ self.blocks.extend([
139
+ ResBlock3d(ch, ch)
140
+ for _ in range(num_res_blocks)
141
+ ])
142
+ if i < len(channels) - 1:
143
+ self.blocks.append(
144
+ DownsampleBlock3d(ch, channels[i+1])
145
+ )
146
+
147
+ self.middle_block = nn.Sequential(*[
148
+ ResBlock3d(channels[-1], channels[-1])
149
+ for _ in range(num_res_blocks_middle)
150
+ ])
151
+
152
+ self.out_layer = nn.Sequential(
153
+ norm_layer(norm_type, channels[-1]),
154
+ nn.SiLU(),
155
+ nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
156
+ )
157
+
158
+ if use_fp16:
159
+ self.convert_to_fp16()
160
+
161
+ @property
162
+ def device(self) -> torch.device:
163
+ """
164
+ Return the device of the model.
165
+ """
166
+ return next(self.parameters()).device
167
+
168
+ def convert_to_fp16(self) -> None:
169
+ """
170
+ Convert the torso of the model to float16.
171
+ """
172
+ self.use_fp16 = True
173
+ self.dtype = torch.float16
174
+ self.blocks.apply(convert_module_to_f16)
175
+ self.middle_block.apply(convert_module_to_f16)
176
+
177
+ def convert_to_fp32(self) -> None:
178
+ """
179
+ Convert the torso of the model to float32.
180
+ """
181
+ self.use_fp16 = False
182
+ self.dtype = torch.float32
183
+ self.blocks.apply(convert_module_to_f32)
184
+ self.middle_block.apply(convert_module_to_f32)
185
+
186
+ def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
187
+ h = self.input_layer(x)
188
+ h = h.type(self.dtype)
189
+
190
+ for block in self.blocks:
191
+ h = block(h)
192
+ h = self.middle_block(h)
193
+
194
+ h = h.type(x.dtype)
195
+ h = self.out_layer(h)
196
+
197
+ mean, logvar = h.chunk(2, dim=1)
198
+
199
+ if sample_posterior:
200
+ std = torch.exp(0.5 * logvar)
201
+ z = mean + std * torch.randn_like(std)
202
+ else:
203
+ z = mean
204
+
205
+ if return_raw:
206
+ return z, mean, logvar
207
+ return z
208
+
209
+
210
+ class SparseStructureDecoder(nn.Module):
211
+ """
212
+ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
213
+
214
+ Args:
215
+ out_channels (int): Channels of the output.
216
+ latent_channels (int): Channels of the latent representation.
217
+ num_res_blocks (int): Number of residual blocks at each resolution.
218
+ channels (List[int]): Channels of the decoder blocks.
219
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
220
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
221
+ use_fp16 (bool): Whether to use FP16.
222
+ """
223
+ def __init__(
224
+ self,
225
+ out_channels: int,
226
+ latent_channels: int,
227
+ num_res_blocks: int,
228
+ channels: List[int],
229
+ num_res_blocks_middle: int = 2,
230
+ norm_type: Literal["group", "layer"] = "layer",
231
+ use_fp16: bool = False,
232
+ ):
233
+ super().__init__()
234
+ self.out_channels = out_channels
235
+ self.latent_channels = latent_channels
236
+ self.num_res_blocks = num_res_blocks
237
+ self.channels = channels
238
+ self.num_res_blocks_middle = num_res_blocks_middle
239
+ self.norm_type = norm_type
240
+ self.use_fp16 = use_fp16
241
+ self.dtype = torch.float16 if use_fp16 else torch.float32
242
+
243
+ self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
244
+
245
+ self.middle_block = nn.Sequential(*[
246
+ ResBlock3d(channels[0], channels[0])
247
+ for _ in range(num_res_blocks_middle)
248
+ ])
249
+
250
+ self.blocks = nn.ModuleList([])
251
+ for i, ch in enumerate(channels):
252
+ self.blocks.extend([
253
+ ResBlock3d(ch, ch)
254
+ for _ in range(num_res_blocks)
255
+ ])
256
+ if i < len(channels) - 1:
257
+ self.blocks.append(
258
+ UpsampleBlock3d(ch, channels[i+1])
259
+ )
260
+
261
+ self.out_layer = nn.Sequential(
262
+ norm_layer(norm_type, channels[-1]),
263
+ nn.SiLU(),
264
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1)
265
+ )
266
+
267
+ if use_fp16:
268
+ self.convert_to_fp16()
269
+
270
+ @property
271
+ def device(self) -> torch.device:
272
+ """
273
+ Return the device of the model.
274
+ """
275
+ return next(self.parameters()).device
276
+
277
+ def convert_to_fp16(self) -> None:
278
+ """
279
+ Convert the torso of the model to float16.
280
+ """
281
+ self.use_fp16 = True
282
+ self.dtype = torch.float16
283
+ self.blocks.apply(convert_module_to_f16)
284
+ self.middle_block.apply(convert_module_to_f16)
285
+
286
+ def convert_to_fp32(self) -> None:
287
+ """
288
+ Convert the torso of the model to float32.
289
+ """
290
+ self.use_fp16 = False
291
+ self.dtype = torch.float32
292
+ self.blocks.apply(convert_module_to_f32)
293
+ self.middle_block.apply(convert_module_to_f32)
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ h = self.input_layer(x)
297
+
298
+ h = h.type(self.dtype)
299
+
300
+ h = self.middle_block(h)
301
+ for block in self.blocks:
302
+ h = block(h)
303
+
304
+ h = h.type(x.dtype)
305
+ h = self.out_layer(h)
306
+ return h
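Note (illustrative, not part of the commit): `SparseStructureEncoder.forward` splits its output into `mean` and `logvar` and, when `sample_posterior=True`, draws `z = mean + std * eps`. The sketch below isolates that reparameterization step, plus the standard VAE KL regularizer that a training loop would typically pair with it (the KL term is an assumption; it is not defined in this file):

# Reparameterized sampling matching the encoder's forward(), plus a standard KL term.
import torch

def sample_posterior(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(std)

def kl_to_standard_normal(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mean, exp(logvar)) || N(0, I)), averaged over all elements.
    return 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar).mean()

mean = torch.zeros(2, 8, 16, 16, 16)   # illustrative (B, latent_channels, D, H, W)
logvar = torch.zeros_like(mean)
z = sample_posterior(mean, logvar)
print(z.shape, kl_to_standard_normal(mean, logvar).item())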
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_flow.py ADDED
@@ -0,0 +1,262 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder
8
+ from ..modules.norm import LayerNorm32
9
+ from ..modules import sparse as sp
10
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
11
+ from .sparse_structure_flow import TimestepEmbedder
12
+
13
+
14
+ class SparseResBlock3d(nn.Module):
15
+ def __init__(
16
+ self,
17
+ channels: int,
18
+ emb_channels: int,
19
+ out_channels: Optional[int] = None,
20
+ downsample: bool = False,
21
+ upsample: bool = False,
22
+ ):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.emb_channels = emb_channels
26
+ self.out_channels = out_channels or channels
27
+ self.downsample = downsample
28
+ self.upsample = upsample
29
+
30
+ assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
31
+
32
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
33
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
34
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
35
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
36
+ self.emb_layers = nn.Sequential(
37
+ nn.SiLU(),
38
+ nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
39
+ )
40
+ self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
41
+ self.updown = None
42
+ if self.downsample:
43
+ self.updown = sp.SparseDownsample(2)
44
+ elif self.upsample:
45
+ self.updown = sp.SparseUpsample(2)
46
+
47
+ def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
48
+ if self.updown is not None:
49
+ x = self.updown(x)
50
+ return x
51
+
52
+ def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
53
+ emb_out = self.emb_layers(emb).type(x.dtype)
54
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
55
+
56
+ x = self._updown(x)
57
+ h = x.replace(self.norm1(x.feats))
58
+ h = h.replace(F.silu(h.feats))
59
+ h = self.conv1(h)
60
+ h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
61
+ h = h.replace(F.silu(h.feats))
62
+ h = self.conv2(h)
63
+ h = h + self.skip_connection(x)
64
+
65
+ return h
66
+
67
+
68
+ class SLatFlowModel(nn.Module):
69
+ def __init__(
70
+ self,
71
+ resolution: int,
72
+ in_channels: int,
73
+ model_channels: int,
74
+ cond_channels: int,
75
+ out_channels: int,
76
+ num_blocks: int,
77
+ num_heads: Optional[int] = None,
78
+ num_head_channels: Optional[int] = 64,
79
+ mlp_ratio: float = 4,
80
+ patch_size: int = 2,
81
+ num_io_res_blocks: int = 2,
82
+ io_block_channels: List[int] = None,
83
+ pe_mode: Literal["ape", "rope"] = "ape",
84
+ use_fp16: bool = False,
85
+ use_checkpoint: bool = False,
86
+ use_skip_connection: bool = True,
87
+ share_mod: bool = False,
88
+ qk_rms_norm: bool = False,
89
+ qk_rms_norm_cross: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.resolution = resolution
93
+ self.in_channels = in_channels
94
+ self.model_channels = model_channels
95
+ self.cond_channels = cond_channels
96
+ self.out_channels = out_channels
97
+ self.num_blocks = num_blocks
98
+ self.num_heads = num_heads or model_channels // num_head_channels
99
+ self.mlp_ratio = mlp_ratio
100
+ self.patch_size = patch_size
101
+ self.num_io_res_blocks = num_io_res_blocks
102
+ self.io_block_channels = io_block_channels
103
+ self.pe_mode = pe_mode
104
+ self.use_fp16 = use_fp16
105
+ self.use_checkpoint = use_checkpoint
106
+ self.use_skip_connection = use_skip_connection
107
+ self.share_mod = share_mod
108
+ self.qk_rms_norm = qk_rms_norm
109
+ self.qk_rms_norm_cross = qk_rms_norm_cross
110
+ self.dtype = torch.float16 if use_fp16 else torch.float32
111
+
112
+ assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
113
+ assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
114
+
115
+ self.t_embedder = TimestepEmbedder(model_channels)
116
+ if share_mod:
117
+ self.adaLN_modulation = nn.Sequential(
118
+ nn.SiLU(),
119
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
120
+ )
121
+
122
+ if pe_mode == "ape":
123
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
124
+
125
+ self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0])
126
+ self.input_blocks = nn.ModuleList([])
127
+ for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
128
+ self.input_blocks.extend([
129
+ SparseResBlock3d(
130
+ chs,
131
+ model_channels,
132
+ out_channels=chs,
133
+ )
134
+ for _ in range(num_io_res_blocks-1)
135
+ ])
136
+ self.input_blocks.append(
137
+ SparseResBlock3d(
138
+ chs,
139
+ model_channels,
140
+ out_channels=next_chs,
141
+ downsample=True,
142
+ )
143
+ )
144
+
145
+ self.blocks = nn.ModuleList([
146
+ ModulatedSparseTransformerCrossBlock(
147
+ model_channels,
148
+ cond_channels,
149
+ num_heads=self.num_heads,
150
+ mlp_ratio=self.mlp_ratio,
151
+ attn_mode='full',
152
+ use_checkpoint=self.use_checkpoint,
153
+ use_rope=(pe_mode == "rope"),
154
+ share_mod=self.share_mod,
155
+ qk_rms_norm=self.qk_rms_norm,
156
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
157
+ )
158
+ for _ in range(num_blocks)
159
+ ])
160
+
161
+ self.out_blocks = nn.ModuleList([])
162
+ for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
163
+ self.out_blocks.append(
164
+ SparseResBlock3d(
165
+ prev_chs * 2 if self.use_skip_connection else prev_chs,
166
+ model_channels,
167
+ out_channels=chs,
168
+ upsample=True,
169
+ )
170
+ )
171
+ self.out_blocks.extend([
172
+ SparseResBlock3d(
173
+ chs * 2 if self.use_skip_connection else chs,
174
+ model_channels,
175
+ out_channels=chs,
176
+ )
177
+ for _ in range(num_io_res_blocks-1)
178
+ ])
179
+ self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels)
180
+
181
+ self.initialize_weights()
182
+ if use_fp16:
183
+ self.convert_to_fp16()
184
+
185
+ @property
186
+ def device(self) -> torch.device:
187
+ """
188
+ Return the device of the model.
189
+ """
190
+ return next(self.parameters()).device
191
+
192
+ def convert_to_fp16(self) -> None:
193
+ """
194
+ Convert the torso of the model to float16.
195
+ """
196
+ self.input_blocks.apply(convert_module_to_f16)
197
+ self.blocks.apply(convert_module_to_f16)
198
+ self.out_blocks.apply(convert_module_to_f16)
199
+
200
+ def convert_to_fp32(self) -> None:
201
+ """
202
+ Convert the torso of the model to float32.
203
+ """
204
+ self.input_blocks.apply(convert_module_to_f32)
205
+ self.blocks.apply(convert_module_to_f32)
206
+ self.out_blocks.apply(convert_module_to_f32)
207
+
208
+ def initialize_weights(self) -> None:
209
+ # Initialize transformer layers:
210
+ def _basic_init(module):
211
+ if isinstance(module, nn.Linear):
212
+ torch.nn.init.xavier_uniform_(module.weight)
213
+ if module.bias is not None:
214
+ nn.init.constant_(module.bias, 0)
215
+ self.apply(_basic_init)
216
+
217
+ # Initialize timestep embedding MLP:
218
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
219
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
220
+
221
+ # Zero-out adaLN modulation layers in DiT blocks:
222
+ if self.share_mod:
223
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
224
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
225
+ else:
226
+ for block in self.blocks:
227
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
228
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
229
+
230
+ # Zero-out output layers:
231
+ nn.init.constant_(self.out_layer.weight, 0)
232
+ nn.init.constant_(self.out_layer.bias, 0)
233
+
234
+ def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
235
+ h = self.input_layer(x).type(self.dtype)
236
+ t_emb = self.t_embedder(t)
237
+ if self.share_mod:
238
+ t_emb = self.adaLN_modulation(t_emb)
239
+ t_emb = t_emb.type(self.dtype)
240
+ cond = cond.type(self.dtype)
241
+
242
+ skips = []
243
+ # pack with input blocks
244
+ for block in self.input_blocks:
245
+ h = block(h, t_emb)
246
+ skips.append(h.feats)
247
+
248
+ if self.pe_mode == "ape":
249
+ h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
250
+ for block in self.blocks:
251
+ h = block(h, t_emb, cond)
252
+
253
+ # unpack with output blocks
254
+ for block, skip in zip(self.out_blocks, reversed(skips)):
255
+ if self.use_skip_connection:
256
+ h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
257
+ else:
258
+ h = block(h, t_emb)
259
+
260
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
261
+ h = self.out_layer(h.type(x.dtype))
262
+ return h
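Note (illustrative, not part of the commit): the two `zip(...)` constructions above define a small U-Net-style channel schedule around the transformer trunk. The sketch below replays that bookkeeping for a hypothetical configuration (patch_size = 4, hence two IO stages) to make the down/up pairing visible:

# Channel-schedule sketch mirroring SLatFlowModel's input/output block construction.
# Hypothetical config: patch_size = 4 -> log2(4) = 2 IO stages.
model_channels = 1024
io_block_channels = [128, 256]  # must satisfy len(...) == log2(patch_size)

down_stages = list(zip(io_block_channels, io_block_channels[1:] + [model_channels]))
up_stages = list(zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))))

print(down_stages)  # [(128, 256), (256, 1024)] -> each stage ends with a downsampling SparseResBlock3d
print(up_stages)    # [(256, 1024), (128, 256)] -> mirrored upsampling stages; skips double the input width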
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .encoder import SLatEncoder
2
+ from .decoder_gs import SLatGaussianDecoder
3
+ from .decoder_rf import SLatRadianceFieldDecoder
4
+ from .decoder_mesh import SLatMeshDecoder
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/base.py ADDED
@@ -0,0 +1,117 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
5
+ from ...modules import sparse as sp
6
+ from ...modules.transformer import AbsolutePositionEmbedder
7
+ from ...modules.sparse.transformer import SparseTransformerBlock
8
+
9
+
10
+ def block_attn_config(self):
11
+ """
12
+ Return the attention configuration of the model.
13
+ """
14
+ for i in range(self.num_blocks):
15
+ if self.attn_mode == "shift_window":
16
+ yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
17
+ elif self.attn_mode == "shift_sequence":
18
+ yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
19
+ elif self.attn_mode == "shift_order":
20
+ yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
21
+ elif self.attn_mode == "full":
22
+ yield "full", None, None, None, None
23
+ elif self.attn_mode == "swin":
24
+ yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
25
+
26
+
27
+ class SparseTransformerBase(nn.Module):
28
+ """
29
+ Sparse Transformer without output layers.
30
+ Serves as the base class for the encoder and decoders.
31
+ """
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ model_channels: int,
36
+ num_blocks: int,
37
+ num_heads: Optional[int] = None,
38
+ num_head_channels: Optional[int] = 64,
39
+ mlp_ratio: float = 4.0,
40
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
41
+ window_size: Optional[int] = None,
42
+ pe_mode: Literal["ape", "rope"] = "ape",
43
+ use_fp16: bool = False,
44
+ use_checkpoint: bool = False,
45
+ qk_rms_norm: bool = False,
46
+ ):
47
+ super().__init__()
48
+ self.in_channels = in_channels
49
+ self.model_channels = model_channels
50
+ self.num_blocks = num_blocks
51
+ self.window_size = window_size
52
+ self.num_heads = num_heads or model_channels // num_head_channels
53
+ self.mlp_ratio = mlp_ratio
54
+ self.attn_mode = attn_mode
55
+ self.pe_mode = pe_mode
56
+ self.use_fp16 = use_fp16
57
+ self.use_checkpoint = use_checkpoint
58
+ self.qk_rms_norm = qk_rms_norm
59
+ self.dtype = torch.float16 if use_fp16 else torch.float32
60
+
61
+ if pe_mode == "ape":
62
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63
+
64
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
65
+ self.blocks = nn.ModuleList([
66
+ SparseTransformerBlock(
67
+ model_channels,
68
+ num_heads=self.num_heads,
69
+ mlp_ratio=self.mlp_ratio,
70
+ attn_mode=attn_mode,
71
+ window_size=window_size,
72
+ shift_sequence=shift_sequence,
73
+ shift_window=shift_window,
74
+ serialize_mode=serialize_mode,
75
+ use_checkpoint=self.use_checkpoint,
76
+ use_rope=(pe_mode == "rope"),
77
+ qk_rms_norm=self.qk_rms_norm,
78
+ )
79
+ for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
80
+ ])
81
+
82
+ @property
83
+ def device(self) -> torch.device:
84
+ """
85
+ Return the device of the model.
86
+ """
87
+ return next(self.parameters()).device
88
+
89
+ def convert_to_fp16(self) -> None:
90
+ """
91
+ Convert the torso of the model to float16.
92
+ """
93
+ self.blocks.apply(convert_module_to_f16)
94
+
95
+ def convert_to_fp32(self) -> None:
96
+ """
97
+ Convert the torso of the model to float32.
98
+ """
99
+ self.blocks.apply(convert_module_to_f32)
100
+
101
+ def initialize_weights(self) -> None:
102
+ # Initialize transformer layers:
103
+ def _basic_init(module):
104
+ if isinstance(module, nn.Linear):
105
+ torch.nn.init.xavier_uniform_(module.weight)
106
+ if module.bias is not None:
107
+ nn.init.constant_(module.bias, 0)
108
+ self.apply(_basic_init)
109
+
110
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
111
+ h = self.input_layer(x)
112
+ if self.pe_mode == "ape":
113
+ h = h + self.pos_embedder(x.coords[:, 1:])
114
+ h = h.type(self.dtype)
115
+ for block in self.blocks:
116
+ h = block(h)
117
+ return h
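Note (illustrative, not part of the commit): `block_attn_config` assigns each `SparseTransformerBlock` its attention mode and shift; in the "swin" mode used by the decoders that follow, odd-indexed blocks shift the window by half its size. A standalone replay of that branch:

# Standalone replay of the "swin" branch of block_attn_config above:
# even blocks use an unshifted window, odd blocks shift by window_size // 2.
def swin_attn_config(num_blocks: int, window_size: int):
    for i in range(num_blocks):
        # (attn_mode, window_size, shift_sequence, shift_window, serialize_mode)
        yield "windowed", window_size, None, window_size // 2 * (i % 2), None

for cfg in swin_attn_config(num_blocks=4, window_size=8):
    print(cfg)
# ('windowed', 8, None, 0, None)
# ('windowed', 8, None, 4, None)
# ('windowed', 8, None, 0, None)
# ('windowed', 8, None, 4, None)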
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_gs.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from ...utils.random_utils import hammersley_sequence
7
+ from .base import SparseTransformerBase
8
+ from ...representations import Gaussian
9
+
10
+
11
+ class SLatGaussianDecoder(SparseTransformerBase):
12
+ def __init__(
13
+ self,
14
+ resolution: int,
15
+ model_channels: int,
16
+ latent_channels: int,
17
+ num_blocks: int,
18
+ num_heads: Optional[int] = None,
19
+ num_head_channels: Optional[int] = 64,
20
+ mlp_ratio: float = 4,
21
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22
+ window_size: int = 8,
23
+ pe_mode: Literal["ape", "rope"] = "ape",
24
+ use_fp16: bool = False,
25
+ use_checkpoint: bool = False,
26
+ qk_rms_norm: bool = False,
27
+ representation_config: dict = None,
28
+ ):
29
+ super().__init__(
30
+ in_channels=latent_channels,
31
+ model_channels=model_channels,
32
+ num_blocks=num_blocks,
33
+ num_heads=num_heads,
34
+ num_head_channels=num_head_channels,
35
+ mlp_ratio=mlp_ratio,
36
+ attn_mode=attn_mode,
37
+ window_size=window_size,
38
+ pe_mode=pe_mode,
39
+ use_fp16=use_fp16,
40
+ use_checkpoint=use_checkpoint,
41
+ qk_rms_norm=qk_rms_norm,
42
+ )
43
+ self.resolution = resolution
44
+ self.rep_config = representation_config
45
+ self._calc_layout()
46
+ self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
47
+ self._build_perturbation()
48
+
49
+ self.initialize_weights()
50
+ if use_fp16:
51
+ self.convert_to_fp16()
52
+
53
+ def initialize_weights(self) -> None:
54
+ super().initialize_weights()
55
+ # Zero-out output layers:
56
+ nn.init.constant_(self.out_layer.weight, 0)
57
+ nn.init.constant_(self.out_layer.bias, 0)
58
+
59
+ def _build_perturbation(self) -> None:
60
+ perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
61
+ perturbation = torch.tensor(perturbation).float() * 2 - 1
62
+ perturbation = perturbation / self.rep_config['voxel_size']
63
+ perturbation = torch.atanh(perturbation).to(self.device)
64
+ self.register_buffer('offset_perturbation', perturbation)
65
+
66
+ def _calc_layout(self) -> None:
67
+ self.layout = {
68
+ '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
69
+ '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
70
+ '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
71
+ '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
72
+ '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
73
+ }
74
+ start = 0
75
+ for k, v in self.layout.items():
76
+ v['range'] = (start, start + v['size'])
77
+ start += v['size']
78
+ self.out_channels = start
79
+
80
+ def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
81
+ """
82
+ Convert a batch of network outputs to 3D representations.
83
+
84
+ Args:
85
+ x: The [N x * x C] sparse tensor output by the network.
86
+
87
+ Returns:
88
+ list of representations
89
+ """
90
+ ret = []
91
+ for i in range(x.shape[0]):
92
+ representation = Gaussian(
93
+ sh_degree=0,
94
+ aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
95
+ mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
96
+ scaling_bias = self.rep_config['scaling_bias'],
97
+ opacity_bias = self.rep_config['opacity_bias'],
98
+ scaling_activation = self.rep_config['scaling_activation']
99
+ )
100
+ xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
101
+ for k, v in self.layout.items():
102
+ if k == '_xyz':
103
+ offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
104
+ offset = offset * self.rep_config['lr'][k]
105
+ if self.rep_config['perturb_offset']:
106
+ offset = offset + self.offset_perturbation
107
+ offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
108
+ _xyz = xyz.unsqueeze(1) + offset
109
+ setattr(representation, k, _xyz.flatten(0, 1))
110
+ else:
111
+ feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
112
+ feats = feats * self.rep_config['lr'][k]
113
+ setattr(representation, k, feats)
114
+ ret.append(representation)
115
+ return ret
116
+
117
+ def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
118
+ h = super().forward(x)
119
+ h = h.type(x.dtype)
120
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
121
+ h = self.out_layer(h)
122
+ return self.to_representation(h)
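Note (illustrative, not part of the commit): `_calc_layout` flattens every per-voxel Gaussian attribute into one feature vector and records a channel range for each. The sketch below redoes that bookkeeping for a hypothetical `num_gaussians = 32` so the resulting `out_channels` is easy to verify by hand:

# Layout bookkeeping mirroring SLatGaussianDecoder._calc_layout (num_gaussians = 32 is hypothetical).
num_gaussians = 32
sizes = {
    "_xyz": num_gaussians * 3,
    "_features_dc": num_gaussians * 3,
    "_scaling": num_gaussians * 3,
    "_rotation": num_gaussians * 4,
    "_opacity": num_gaussians * 1,
}
ranges, start = {}, 0
for name, size in sizes.items():
    ranges[name] = (start, start + size)
    start += size
print(ranges)  # per-attribute slices of the decoder's output feature vector
print(start)   # out_channels == 32 * (3 + 3 + 3 + 4 + 1) == 448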
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_mesh.py ADDED
@@ -0,0 +1,167 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ...modules import sparse as sp
8
+ from .base import SparseTransformerBase
9
+ from ...representations import MeshExtractResult
10
+ from ...representations.mesh import SparseFeatures2Mesh
11
+
12
+
13
+ class SparseSubdivideBlock3d(nn.Module):
14
+ """
15
+ A 3D subdivide block that can subdivide the sparse tensor.
16
+
17
+ Args:
18
+ channels: channels in the inputs and outputs.
19
+ out_channels: if specified, the number of output channels.
20
+ num_groups: the number of groups for the group norm.
21
+ """
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ resolution: int,
26
+ out_channels: Optional[int] = None,
27
+ num_groups: int = 32
28
+ ):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.resolution = resolution
32
+ self.out_resolution = resolution * 2
33
+ self.out_channels = out_channels or channels
34
+
35
+ self.act_layers = nn.Sequential(
36
+ sp.SparseGroupNorm32(num_groups, channels),
37
+ sp.SparseSiLU()
38
+ )
39
+
40
+ self.sub = sp.SparseSubdivide()
41
+
42
+ self.out_layers = nn.Sequential(
43
+ sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
44
+ sp.SparseGroupNorm32(num_groups, self.out_channels),
45
+ sp.SparseSiLU(),
46
+ zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
47
+ )
48
+
49
+ if self.out_channels == channels:
50
+ self.skip_connection = nn.Identity()
51
+ else:
52
+ self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
53
+
54
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
55
+ """
56
+ Apply the subdivide block to a sparse tensor.
57
+
58
+ Args:
59
+ x: an [N x C x ...] Tensor of features.
60
+ Returns:
61
+ an [N x C x ...] Tensor of outputs.
62
+ """
63
+ h = self.act_layers(x)
64
+ h = self.sub(h)
65
+ x = self.sub(x)
66
+ h = self.out_layers(h)
67
+ h = h + self.skip_connection(x)
68
+ return h
69
+
70
+
71
+ class SLatMeshDecoder(SparseTransformerBase):
72
+ def __init__(
73
+ self,
74
+ resolution: int,
75
+ model_channels: int,
76
+ latent_channels: int,
77
+ num_blocks: int,
78
+ num_heads: Optional[int] = None,
79
+ num_head_channels: Optional[int] = 64,
80
+ mlp_ratio: float = 4,
81
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
82
+ window_size: int = 8,
83
+ pe_mode: Literal["ape", "rope"] = "ape",
84
+ use_fp16: bool = False,
85
+ use_checkpoint: bool = False,
86
+ qk_rms_norm: bool = False,
87
+ representation_config: dict = None,
88
+ ):
89
+ super().__init__(
90
+ in_channels=latent_channels,
91
+ model_channels=model_channels,
92
+ num_blocks=num_blocks,
93
+ num_heads=num_heads,
94
+ num_head_channels=num_head_channels,
95
+ mlp_ratio=mlp_ratio,
96
+ attn_mode=attn_mode,
97
+ window_size=window_size,
98
+ pe_mode=pe_mode,
99
+ use_fp16=use_fp16,
100
+ use_checkpoint=use_checkpoint,
101
+ qk_rms_norm=qk_rms_norm,
102
+ )
103
+ self.resolution = resolution
104
+ self.rep_config = representation_config
105
+ self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
106
+ self.out_channels = self.mesh_extractor.feats_channels
107
+ self.upsample = nn.ModuleList([
108
+ SparseSubdivideBlock3d(
109
+ channels=model_channels,
110
+ resolution=resolution,
111
+ out_channels=model_channels // 4
112
+ ),
113
+ SparseSubdivideBlock3d(
114
+ channels=model_channels // 4,
115
+ resolution=resolution * 2,
116
+ out_channels=model_channels // 8
117
+ )
118
+ ])
119
+ self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
120
+
121
+ self.initialize_weights()
122
+ if use_fp16:
123
+ self.convert_to_fp16()
124
+
125
+ def initialize_weights(self) -> None:
126
+ super().initialize_weights()
127
+ # Zero-out output layers:
128
+ nn.init.constant_(self.out_layer.weight, 0)
129
+ nn.init.constant_(self.out_layer.bias, 0)
130
+
131
+ def convert_to_fp16(self) -> None:
132
+ """
133
+ Convert the torso of the model to float16.
134
+ """
135
+ super().convert_to_fp16()
136
+ self.upsample.apply(convert_module_to_f16)
137
+
138
+ def convert_to_fp32(self) -> None:
139
+ """
140
+ Convert the torso of the model to float32.
141
+ """
142
+ super().convert_to_fp32()
143
+ self.upsample.apply(convert_module_to_f32)
144
+
145
+ def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
146
+ """
147
+ Convert a batch of network outputs to 3D representations.
148
+
149
+ Args:
150
+ x: The [N x * x C] sparse tensor output by the network.
151
+
152
+ Returns:
153
+ list of representations
154
+ """
155
+ ret = []
156
+ for i in range(x.shape[0]):
157
+ mesh = self.mesh_extractor(x[i], training=self.training)
158
+ ret.append(mesh)
159
+ return ret
160
+
161
+ def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
162
+ h = super().forward(x)
163
+ for block in self.upsample:
164
+ h = block(h)
165
+ h = h.type(x.dtype)
166
+ h = self.out_layer(h)
167
+ return self.to_representation(h)
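The two `SparseSubdivideBlock3d` stages double the grid resolution twice while shrinking the channels to 1/4 and then 1/8 of `model_channels`, which is why `SparseFeatures2Mesh` is built at `resolution * 4`. A dense stand-in for that bookkeeping (the channel and resolution values are illustrative assumptions, and `nn.Upsample` + `nn.Conv3d` merely mimic the sparse subdivide/conv pair):

```python
import torch
import torch.nn as nn

model_channels, resolution = 256, 16   # illustrative sizes, not the shipped config

# Dense stand-ins for the two sparse subdivide blocks: each doubles the grid
# and a 3x3x3 conv maps to fewer channels, so the latent grid ends at 4x res.
blocks = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv3d(model_channels, model_channels // 4, 3, padding=1),
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv3d(model_channels // 4, model_channels // 8, 3, padding=1),
)

h = torch.randn(1, model_channels, resolution, resolution, resolution)
out = blocks(h)
print(out.shape)   # torch.Size([1, 32, 64, 64, 64]) -> features on a resolution*4 grid
```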
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/decoder_rf.py ADDED
@@ -0,0 +1,104 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ...modules import sparse as sp
7
+ from .base import SparseTransformerBase
8
+ from ...representations import Strivec
9
+
10
+
11
+ class SLatRadianceFieldDecoder(SparseTransformerBase):
12
+ def __init__(
13
+ self,
14
+ resolution: int,
15
+ model_channels: int,
16
+ latent_channels: int,
17
+ num_blocks: int,
18
+ num_heads: Optional[int] = None,
19
+ num_head_channels: Optional[int] = 64,
20
+ mlp_ratio: float = 4,
21
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22
+ window_size: int = 8,
23
+ pe_mode: Literal["ape", "rope"] = "ape",
24
+ use_fp16: bool = False,
25
+ use_checkpoint: bool = False,
26
+ qk_rms_norm: bool = False,
27
+ representation_config: dict = None,
28
+ ):
29
+ super().__init__(
30
+ in_channels=latent_channels,
31
+ model_channels=model_channels,
32
+ num_blocks=num_blocks,
33
+ num_heads=num_heads,
34
+ num_head_channels=num_head_channels,
35
+ mlp_ratio=mlp_ratio,
36
+ attn_mode=attn_mode,
37
+ window_size=window_size,
38
+ pe_mode=pe_mode,
39
+ use_fp16=use_fp16,
40
+ use_checkpoint=use_checkpoint,
41
+ qk_rms_norm=qk_rms_norm,
42
+ )
43
+ self.resolution = resolution
44
+ self.rep_config = representation_config
45
+ self._calc_layout()
46
+ self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
47
+
48
+ self.initialize_weights()
49
+ if use_fp16:
50
+ self.convert_to_fp16()
51
+
52
+ def initialize_weights(self) -> None:
53
+ super().initialize_weights()
54
+ # Zero-out output layers:
55
+ nn.init.constant_(self.out_layer.weight, 0)
56
+ nn.init.constant_(self.out_layer.bias, 0)
57
+
58
+ def _calc_layout(self) -> None:
59
+ self.layout = {
60
+ 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
61
+ 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
62
+ 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
63
+ }
64
+ start = 0
65
+ for k, v in self.layout.items():
66
+ v['range'] = (start, start + v['size'])
67
+ start += v['size']
68
+ self.out_channels = start
69
+
70
+ def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
71
+ """
72
+ Convert a batch of network outputs to 3D representations.
73
+
74
+ Args:
75
+ x: The [N x * x C] sparse tensor output by the network.
76
+
77
+ Returns:
78
+ list of representations
79
+ """
80
+ ret = []
81
+ for i in range(x.shape[0]):
82
+ representation = Strivec(
83
+ sh_degree=0,
84
+ resolution=self.resolution,
85
+ aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
86
+ rank=self.rep_config['rank'],
87
+ dim=self.rep_config['dim'],
88
+ device='cuda',
89
+ )
90
+ representation.density_shift = 0.0
91
+ representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
92
+ representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
93
+ for k, v in self.layout.items():
94
+ setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
95
+ representation.trivec = representation.trivec + 1
96
+ ret.append(representation)
97
+ return ret
98
+
99
+ def forward(self, x: sp.SparseTensor) -> List[Strivec]:
100
+ h = super().forward(x)
101
+ h = h.type(x.dtype)
102
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
103
+ h = self.out_layer(h)
104
+ return self.to_representation(h)
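Both sparse decoders turn integer voxel coordinates into normalized cell-centre positions with `(coords + 0.5) / resolution`, and the radiance-field decoder additionally stores the octree depth as `log2(resolution)`. A tiny sketch of that conversion on made-up coordinates:

```python
import numpy as np
import torch

resolution = 64
coords = torch.tensor([[0, 0, 0], [31, 31, 31], [63, 0, 63]])   # toy voxel indices

position = (coords.float() + 0.5) / resolution                  # cell centres in (0, 1)
depth = torch.full((coords.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8)

print(position)           # [[0.0078, 0.0078, 0.0078], [0.4922, ...], [0.9922, 0.0078, 0.9922]]
print(depth.squeeze(1))   # tensor([6, 6, 6], dtype=torch.uint8)
```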
thirdparty/TRELLIS/trellis/trellis/models/structured_latent_vae/encoder.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from .base import SparseTransformerBase
7
+
8
+
9
+ class SLatEncoder(SparseTransformerBase):
10
+ def __init__(
11
+ self,
12
+ resolution: int,
13
+ in_channels: int,
14
+ model_channels: int,
15
+ latent_channels: int,
16
+ num_blocks: int,
17
+ num_heads: Optional[int] = None,
18
+ num_head_channels: Optional[int] = 64,
19
+ mlp_ratio: float = 4,
20
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
21
+ window_size: int = 8,
22
+ pe_mode: Literal["ape", "rope"] = "ape",
23
+ use_fp16: bool = False,
24
+ use_checkpoint: bool = False,
25
+ qk_rms_norm: bool = False,
26
+ ):
27
+ super().__init__(
28
+ in_channels=in_channels,
29
+ model_channels=model_channels,
30
+ num_blocks=num_blocks,
31
+ num_heads=num_heads,
32
+ num_head_channels=num_head_channels,
33
+ mlp_ratio=mlp_ratio,
34
+ attn_mode=attn_mode,
35
+ window_size=window_size,
36
+ pe_mode=pe_mode,
37
+ use_fp16=use_fp16,
38
+ use_checkpoint=use_checkpoint,
39
+ qk_rms_norm=qk_rms_norm,
40
+ )
41
+ self.resolution = resolution
42
+ self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
43
+
44
+ self.initialize_weights()
45
+ if use_fp16:
46
+ self.convert_to_fp16()
47
+
48
+ def initialize_weights(self) -> None:
49
+ super().initialize_weights()
50
+ # Zero-out output layers:
51
+ nn.init.constant_(self.out_layer.weight, 0)
52
+ nn.init.constant_(self.out_layer.bias, 0)
53
+
54
+ def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
55
+ h = super().forward(x)
56
+ h = h.type(x.dtype)
57
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
58
+ h = self.out_layer(h)
59
+
60
+ # Sample from the posterior distribution
61
+ mean, logvar = h.feats.chunk(2, dim=-1)
62
+ if sample_posterior:
63
+ std = torch.exp(0.5 * logvar)
64
+ z = mean + std * torch.randn_like(std)
65
+ else:
66
+ z = mean
67
+ z = h.replace(z)
68
+
69
+ if return_raw:
70
+ return z, mean, logvar
71
+ else:
72
+ return z
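The encoder's sampling step is the usual reparameterisation trick: the network predicts `mean` and `logvar`, and `z = mean + exp(0.5 * logvar) * eps` keeps the draw differentiable. A plain-tensor sketch; the KL term is the standard diagonal-Gaussian form, shown for context and not copied from the TRELLIS training code:

```python
import torch

mean = torch.randn(4, 8, requires_grad=True)    # [num_voxels, latent_channels], toy sizes
logvar = torch.randn(4, 8, requires_grad=True)

std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)          # reparameterised sample, gradients flow

# Standard KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
kl = 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=-1).mean()
print(z.shape, float(kl))
```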
thirdparty/TRELLIS/trellis/trellis/modules/attention/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import *
2
+
3
+ BACKEND = 'flash_attn'
4
+ DEBUG = False
5
+
6
+ def __from_env():
7
+ import os
8
+
9
+ global BACKEND
10
+ global DEBUG
11
+
12
+ env_attn_backend = os.environ.get('ATTN_BACKEND')
13
+ env_sttn_debug = os.environ.get('ATTN_DEBUG')
14
+
15
+ if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
16
+ BACKEND = env_attn_backend
17
+ if env_sttn_debug is not None:
18
+ DEBUG = env_sttn_debug == '1'
19
+
20
+ print(f"[ATTENTION] Using backend: {BACKEND}")
21
+
22
+
23
+ __from_env()
24
+
25
+
26
+ def set_backend(backend: Literal['xformers', 'flash_attn']):
27
+ global BACKEND
28
+ BACKEND = backend
29
+
30
+ def set_debug(debug: bool):
31
+ global DEBUG
32
+ DEBUG = debug
33
+
34
+
35
+ from .full_attn import *
36
+ from .modules import *
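Because `__from_env()` runs at import time, the attention backend has to be chosen through environment variables before the first import; `set_backend` only flips the module-level flag afterwards. A usage sketch, assuming the TRELLIS package is importable as `trellis`:

```python
import os

# Must be set before the first import of trellis.modules.attention.
os.environ["ATTN_BACKEND"] = "sdpa"    # one of: xformers, flash_attn, sdpa, naive
os.environ["ATTN_DEBUG"] = "0"

from trellis.modules import attention   # prints "[ATTENTION] Using backend: sdpa"
print(attention.BACKEND)                 # "sdpa"
```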
thirdparty/TRELLIS/trellis/trellis/modules/attention/full_attn.py ADDED
@@ -0,0 +1,140 @@
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import DEBUG, BACKEND
5
+
6
+ if BACKEND == 'xformers':
7
+ import xformers.ops as xops
8
+ elif BACKEND == 'flash_attn':
9
+ import flash_attn
10
+ elif BACKEND == 'sdpa':
11
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == 'naive':
13
+ pass
14
+ else:
15
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
16
+
17
+
18
+ __all__ = [
19
+ 'scaled_dot_product_attention',
20
+ ]
21
+
22
+
23
+ def _naive_sdpa(q, k, v):
24
+ """
25
+ Naive implementation of scaled dot product attention.
26
+ """
27
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
28
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
29
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
30
+ scale_factor = 1 / math.sqrt(q.size(-1))
31
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
32
+ attn_weight = torch.softmax(attn_weight, dim=-1)
33
+ out = attn_weight @ v
34
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
35
+ return out
36
+
37
+
38
+ @overload
39
+ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
40
+ """
41
+ Apply scaled dot product attention.
42
+
43
+ Args:
44
+ qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
45
+ """
46
+ ...
47
+
48
+ @overload
49
+ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
50
+ """
51
+ Apply scaled dot product attention.
52
+
53
+ Args:
54
+ q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
55
+ kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
56
+ """
57
+ ...
58
+
59
+ @overload
60
+ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Apply scaled dot product attention.
63
+
64
+ Args:
65
+ q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
66
+ k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
67
+ v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
68
+
69
+ Note:
70
+ k and v are assumed to have the same coordinate map.
71
+ """
72
+ ...
73
+
74
+ def scaled_dot_product_attention(*args, **kwargs):
75
+ arg_names_dict = {
76
+ 1: ['qkv'],
77
+ 2: ['q', 'kv'],
78
+ 3: ['q', 'k', 'v']
79
+ }
80
+ num_all_args = len(args) + len(kwargs)
81
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
82
+ for key in arg_names_dict[num_all_args][len(args):]:
83
+ assert key in kwargs, f"Missing argument {key}"
84
+
85
+ if num_all_args == 1:
86
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
87
+ assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
88
+ device = qkv.device
89
+
90
+ elif num_all_args == 2:
91
+ q = args[0] if len(args) > 0 else kwargs['q']
92
+ kv = args[1] if len(args) > 1 else kwargs['kv']
93
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
94
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
95
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
96
+ device = q.device
97
+
98
+ elif num_all_args == 3:
99
+ q = args[0] if len(args) > 0 else kwargs['q']
100
+ k = args[1] if len(args) > 1 else kwargs['k']
101
+ v = args[2] if len(args) > 2 else kwargs['v']
102
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
103
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
104
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
105
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
106
+ device = q.device
107
+
108
+ if BACKEND == 'xformers':
109
+ if num_all_args == 1:
110
+ q, k, v = qkv.unbind(dim=2)
111
+ elif num_all_args == 2:
112
+ k, v = kv.unbind(dim=2)
113
+ out = xops.memory_efficient_attention(q, k, v)
114
+ elif BACKEND == 'flash_attn':
115
+ if num_all_args == 1:
116
+ out = flash_attn.flash_attn_qkvpacked_func(qkv)
117
+ elif num_all_args == 2:
118
+ out = flash_attn.flash_attn_kvpacked_func(q, kv)
119
+ elif num_all_args == 3:
120
+ out = flash_attn.flash_attn_func(q, k, v)
121
+ elif BACKEND == 'sdpa':
122
+ if num_all_args == 1:
123
+ q, k, v = qkv.unbind(dim=2)
124
+ elif num_all_args == 2:
125
+ k, v = kv.unbind(dim=2)
126
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
127
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
128
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
129
+ out = sdpa(q, k, v) # [N, H, L, C]
130
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
131
+ elif BACKEND == 'naive':
132
+ if num_all_args == 1:
133
+ q, k, v = qkv.unbind(dim=2)
134
+ elif num_all_args == 2:
135
+ k, v = kv.unbind(dim=2)
136
+ out = _naive_sdpa(q, k, v)
137
+ else:
138
+ raise ValueError(f"Unknown attention module: {BACKEND}")
139
+
140
+ return out
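All backends above exchange tensors in the `[N, L, H, C]` layout, while PyTorch's built-in SDPA expects `[N, H, L, C]`; the permutes are the whole trick. A stand-alone check that the naive path matches `torch.nn.functional.scaled_dot_product_attention` under that convention:

```python
import math
import torch
import torch.nn.functional as F

N, L, H, C = 2, 7, 4, 16
q, k, v = (torch.randn(N, L, H, C) for _ in range(3))

def naive_sdpa(q, k, v):
    q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))          # -> [N, H, L, C]
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(C), dim=-1)
    return (attn @ v).permute(0, 2, 1, 3)                          # -> [N, L, H, C]

ref = F.scaled_dot_product_attention(q.permute(0, 2, 1, 3),
                                     k.permute(0, 2, 1, 3),
                                     v.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
print(torch.allclose(naive_sdpa(q, k, v), ref, atol=1e-5))         # True
```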
thirdparty/TRELLIS/trellis/trellis/modules/attention/modules.py ADDED
@@ -0,0 +1,146 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+
7
+
8
+ class MultiHeadRMSNorm(nn.Module):
9
+ def __init__(self, dim: int, heads: int):
10
+ super().__init__()
11
+ self.scale = dim ** 0.5
12
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
16
+
17
+
18
+ class RotaryPositionEmbedder(nn.Module):
19
+ def __init__(self, hidden_size: int, in_channels: int = 3):
20
+ super().__init__()
21
+ assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
22
+ self.hidden_size = hidden_size
23
+ self.in_channels = in_channels
24
+ self.freq_dim = hidden_size // in_channels // 2
25
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
26
+ self.freqs = 1.0 / (10000 ** self.freqs)
27
+
28
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
29
+ self.freqs = self.freqs.to(indices.device)
30
+ phases = torch.outer(indices, self.freqs)
31
+ phases = torch.polar(torch.ones_like(phases), phases)
32
+ return phases
33
+
34
+ def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
35
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
36
+ x_rotated = x_complex * phases
37
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
38
+ return x_embed
39
+
40
+ def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
41
+ """
42
+ Args:
43
+ q (sp.SparseTensor): [..., N, D] tensor of queries
44
+ k (sp.SparseTensor): [..., N, D] tensor of keys
45
+ indices (torch.Tensor): [..., N, C] tensor of spatial positions
46
+ """
47
+ if indices is None:
48
+ indices = torch.arange(q.shape[-2], device=q.device)
49
+ if len(q.shape) > 2:
50
+ indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
51
+
52
+ phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
53
+ if phases.shape[1] < self.hidden_size // 2:
54
+ phases = torch.cat([phases, torch.polar(
55
+ torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
56
+ torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
57
+ )], dim=-1)
58
+ q_embed = self._rotary_embedding(q, phases)
59
+ k_embed = self._rotary_embedding(k, phases)
60
+ return q_embed, k_embed
61
+
62
+
63
+ class MultiHeadAttention(nn.Module):
64
+ def __init__(
65
+ self,
66
+ channels: int,
67
+ num_heads: int,
68
+ ctx_channels: Optional[int]=None,
69
+ type: Literal["self", "cross"] = "self",
70
+ attn_mode: Literal["full", "windowed"] = "full",
71
+ window_size: Optional[int] = None,
72
+ shift_window: Optional[Tuple[int, int, int]] = None,
73
+ qkv_bias: bool = True,
74
+ use_rope: bool = False,
75
+ qk_rms_norm: bool = False,
76
+ ):
77
+ super().__init__()
78
+ assert channels % num_heads == 0
79
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
80
+ assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
81
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
82
+
83
+ if attn_mode == "windowed":
84
+ raise NotImplementedError("Windowed attention is not yet implemented")
85
+
86
+ self.channels = channels
87
+ self.head_dim = channels // num_heads
88
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
89
+ self.num_heads = num_heads
90
+ self._type = type
91
+ self.attn_mode = attn_mode
92
+ self.window_size = window_size
93
+ self.shift_window = shift_window
94
+ self.use_rope = use_rope
95
+ self.qk_rms_norm = qk_rms_norm
96
+
97
+ if self._type == "self":
98
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
99
+ else:
100
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
101
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
102
+
103
+ if self.qk_rms_norm:
104
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
105
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
106
+
107
+ self.to_out = nn.Linear(channels, channels)
108
+
109
+ if use_rope:
110
+ self.rope = RotaryPositionEmbedder(channels)
111
+
112
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
113
+ B, L, C = x.shape
114
+ if self._type == "self":
115
+ qkv = self.to_qkv(x)
116
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
117
+ if self.use_rope:
118
+ q, k, v = qkv.unbind(dim=2)
119
+ q, k = self.rope(q, k, indices)
120
+ qkv = torch.stack([q, k, v], dim=2)
121
+ if self.attn_mode == "full":
122
+ if self.qk_rms_norm:
123
+ q, k, v = qkv.unbind(dim=2)
124
+ q = self.q_rms_norm(q)
125
+ k = self.k_rms_norm(k)
126
+ h = scaled_dot_product_attention(q, k, v)
127
+ else:
128
+ h = scaled_dot_product_attention(qkv)
129
+ elif self.attn_mode == "windowed":
130
+ raise NotImplementedError("Windowed attention is not yet implemented")
131
+ else:
132
+ Lkv = context.shape[1]
133
+ q = self.to_q(x)
134
+ kv = self.to_kv(context)
135
+ q = q.reshape(B, L, self.num_heads, -1)
136
+ kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
137
+ if self.qk_rms_norm:
138
+ q = self.q_rms_norm(q)
139
+ k, v = kv.unbind(dim=2)
140
+ k = self.k_rms_norm(k)
141
+ h = scaled_dot_product_attention(q, k, v)
142
+ else:
143
+ h = scaled_dot_product_attention(q, kv)
144
+ h = h.reshape(B, L, -1)
145
+ h = self.to_out(h)
146
+ return h
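`RotaryPositionEmbedder` treats channel pairs as complex numbers and multiplies them by a position-dependent phase, so attention scores end up depending only on relative position. A reduced single-axis sketch of that rotation, plus a check of the shift-invariance property (not the exact multi-axis version used above):

```python
import torch

def rope(x, pos, base=10000.0):
    half = x.shape[-1] // 2
    freqs = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    phase = torch.polar(torch.ones(len(pos), half), torch.outer(pos.float(), freqs))
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * phase).reshape(x.shape)

L, D = 5, 8
q, k, pos = torch.randn(L, D), torch.randn(L, D), torch.arange(L)

scores_a = rope(q, pos) @ rope(k, pos).T
scores_b = rope(q, pos + 3) @ rope(k, pos + 3).T    # shift every position by 3
print(torch.allclose(scores_a, scores_b, atol=1e-5))  # True: only i - j matters
```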
thirdparty/TRELLIS/trellis/trellis/modules/norm.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LayerNorm32(nn.LayerNorm):
6
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
7
+ return super().forward(x.float()).type(x.dtype)
8
+
9
+
10
+ class GroupNorm32(nn.GroupNorm):
11
+ """
12
+ A GroupNorm layer that converts to float32 before the forward pass.
13
+ """
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return super().forward(x.float()).type(x.dtype)
16
+
17
+
18
+ class ChannelLayerNorm32(LayerNorm32):
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ DIM = x.dim()
21
+ x = x.permute(0, *range(2, DIM), 1).contiguous()
22
+ x = super().forward(x)
23
+ x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
24
+ return x
25
+
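`ChannelLayerNorm32` normalises over the channel axis of a channels-first tensor by permuting channels to the last dimension, applying LayerNorm in float32, and permuting back. The same effect on a dense 5-D tensor, as a quick sanity sketch:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 8, 8, 8)                    # [N, C, D, H, W]
ln = nn.LayerNorm(16)                              # normalises the trailing dim

y = ln(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
print(y.shape)                                     # torch.Size([2, 16, 8, 8, 8])
print(y.mean(dim=1).abs().max() < 1e-4)            # per-position mean over C is ~0
```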
thirdparty/TRELLIS/trellis/trellis/modules/sparse/__init__.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import *
2
+
3
+ BACKEND = 'spconv'
4
+ DEBUG = False
5
+ ATTN = 'flash_attn'
6
+
7
+ def __from_env():
8
+ import os
9
+
10
+ global BACKEND
11
+ global DEBUG
12
+ global ATTN
13
+
14
+ env_sparse_backend = os.environ.get('SPARSE_BACKEND')
15
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
+ env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
17
+ if env_sparse_attn is None:
18
+ env_sparse_attn = os.environ.get('ATTN_BACKEND')
19
+
20
+ if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
21
+ BACKEND = env_sparse_backend
22
+ if env_sparse_debug is not None:
23
+ DEBUG = env_sparse_debug == '1'
24
+ if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
+ ATTN = env_sparse_attn
26
+
27
+ print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
+
29
+
30
+ __from_env()
31
+
32
+
33
+ def set_backend(backend: Literal['spconv', 'torchsparse']):
34
+ global BACKEND
35
+ BACKEND = backend
36
+
37
+ def set_debug(debug: bool):
38
+ global DEBUG
39
+ DEBUG = debug
40
+
41
+ def set_attn(attn: Literal['xformers', 'flash_attn']):
42
+ global ATTN
43
+ ATTN = attn
44
+
45
+
46
+ import importlib
47
+
48
+ __attributes = {
49
+ 'SparseTensor': 'basic',
50
+ 'sparse_batch_broadcast': 'basic',
51
+ 'sparse_batch_op': 'basic',
52
+ 'sparse_cat': 'basic',
53
+ 'sparse_unbind': 'basic',
54
+ 'SparseGroupNorm': 'norm',
55
+ 'SparseLayerNorm': 'norm',
56
+ 'SparseGroupNorm32': 'norm',
57
+ 'SparseLayerNorm32': 'norm',
58
+ 'SparseReLU': 'nonlinearity',
59
+ 'SparseSiLU': 'nonlinearity',
60
+ 'SparseGELU': 'nonlinearity',
61
+ 'SparseActivation': 'nonlinearity',
62
+ 'SparseLinear': 'linear',
63
+ 'sparse_scaled_dot_product_attention': 'attention',
64
+ 'SerializeMode': 'attention',
65
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
66
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
67
+ 'SparseMultiHeadAttention': 'attention',
68
+ 'SparseConv3d': 'conv',
69
+ 'SparseInverseConv3d': 'conv',
70
+ 'SparseDownsample': 'spatial',
71
+ 'SparseUpsample': 'spatial',
72
+ 'SparseSubdivide' : 'spatial'
73
+ }
74
+
75
+ __submodules = ['transformer']
76
+
77
+ __all__ = list(__attributes.keys()) + __submodules
78
+
79
+ def __getattr__(name):
80
+ if name not in globals():
81
+ if name in __attributes:
82
+ module_name = __attributes[name]
83
+ module = importlib.import_module(f".{module_name}", __name__)
84
+ globals()[name] = getattr(module, name)
85
+ elif name in __submodules:
86
+ module = importlib.import_module(f".{name}", __name__)
87
+ globals()[name] = module
88
+ else:
89
+ raise AttributeError(f"module {__name__} has no attribute {name}")
90
+ return globals()[name]
91
+
92
+
93
+ # For Pylance
94
+ if __name__ == '__main__':
95
+ from .basic import *
96
+ from .norm import *
97
+ from .nonlinearity import *
98
+ from .linear import *
99
+ from .attention import *
100
+ from .conv import *
101
+ from .spatial import *
102
+ import transformer
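The `__getattr__` hook above is the PEP 562 lazy-import pattern: symbols are resolved in their submodule only on first access and then cached in `globals()`, so importing `trellis.modules.sparse` does not immediately pull in the spconv/torchsparse-backed code. A reduced sketch of the pattern for a hypothetical package `__init__.py`:

```python
# Hypothetical package __init__.py demonstrating PEP 562 lazy attribute loading.
import importlib

_attributes = {"SparseTensor": "basic", "SparseConv3d": "conv"}  # name -> submodule

def __getattr__(name):
    if name in _attributes:
        module = importlib.import_module(f".{_attributes[name]}", __name__)
        value = getattr(module, name)
        globals()[name] = value          # cache: later lookups skip __getattr__
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```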
thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .full_attn import *
2
+ from .serialized_attn import *
3
+ from .windowed_attn import *
4
+ from .modules import *
thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,215 @@
1
+ from typing import *
2
+ import torch
3
+ from .. import SparseTensor
4
+ from .. import DEBUG, ATTN
5
+
6
+ if ATTN == 'xformers':
7
+ import xformers.ops as xops
8
+ elif ATTN == 'flash_attn':
9
+ import flash_attn
10
+ else:
11
+ raise ValueError(f"Unknown attention module: {ATTN}")
12
+
13
+
14
+ __all__ = [
15
+ 'sparse_scaled_dot_product_attention',
16
+ ]
17
+
18
+
19
+ @overload
20
+ def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
21
+ """
22
+ Apply scaled dot product attention to a sparse tensor.
23
+
24
+ Args:
25
+ qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
26
+ """
27
+ ...
28
+
29
+ @overload
30
+ def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
31
+ """
32
+ Apply scaled dot product attention to a sparse tensor.
33
+
34
+ Args:
35
+ q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
36
+ kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
37
+ """
38
+ ...
39
+
40
+ @overload
41
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
42
+ """
43
+ Apply scaled dot product attention to a sparse tensor.
44
+
45
+ Args:
46
+ q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
47
+ kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
48
+ """
49
+ ...
50
+
51
+ @overload
52
+ def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
53
+ """
54
+ Apply scaled dot product attention to a sparse tensor.
55
+
56
+ Args:
57
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
58
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
59
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
60
+
61
+ Note:
62
+ k and v are assumed to have the same coordinate map.
63
+ """
64
+ ...
65
+
66
+ @overload
67
+ def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
68
+ """
69
+ Apply scaled dot product attention to a sparse tensor.
70
+
71
+ Args:
72
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
73
+ k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
74
+ v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
75
+ """
76
+ ...
77
+
78
+ @overload
79
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
80
+ """
81
+ Apply scaled dot product attention to a sparse tensor.
82
+
83
+ Args:
84
+ q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
85
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
86
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
87
+ """
88
+ ...
89
+
90
+ def sparse_scaled_dot_product_attention(*args, **kwargs):
91
+ arg_names_dict = {
92
+ 1: ['qkv'],
93
+ 2: ['q', 'kv'],
94
+ 3: ['q', 'k', 'v']
95
+ }
96
+ num_all_args = len(args) + len(kwargs)
97
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
98
+ for key in arg_names_dict[num_all_args][len(args):]:
99
+ assert key in kwargs, f"Missing argument {key}"
100
+
101
+ if num_all_args == 1:
102
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
103
+ assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
104
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
105
+ device = qkv.device
106
+
107
+ s = qkv
108
+ q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
109
+ kv_seqlen = q_seqlen
110
+ qkv = qkv.feats # [T, 3, H, C]
111
+
112
+ elif num_all_args == 2:
113
+ q = args[0] if len(args) > 0 else kwargs['q']
114
+ kv = args[1] if len(args) > 1 else kwargs['kv']
115
+ assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
116
+ isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
117
+ f"Invalid types, got {type(q)} and {type(kv)}"
118
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
119
+ device = q.device
120
+
121
+ if isinstance(q, SparseTensor):
122
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
123
+ s = q
124
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
125
+ q = q.feats # [T_Q, H, C]
126
+ else:
127
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
128
+ s = None
129
+ N, L, H, C = q.shape
130
+ q_seqlen = [L] * N
131
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
132
+
133
+ if isinstance(kv, SparseTensor):
134
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
135
+ kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
136
+ kv = kv.feats # [T_KV, 2, H, C]
137
+ else:
138
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
139
+ N, L, _, H, C = kv.shape
140
+ kv_seqlen = [L] * N
141
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
142
+
143
+ elif num_all_args == 3:
144
+ q = args[0] if len(args) > 0 else kwargs['q']
145
+ k = args[1] if len(args) > 1 else kwargs['k']
146
+ v = args[2] if len(args) > 2 else kwargs['v']
147
+ assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
148
+ isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
149
+ f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
150
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
151
+ device = q.device
152
+
153
+ if isinstance(q, SparseTensor):
154
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
155
+ s = q
156
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
157
+ q = q.feats # [T_Q, H, Ci]
158
+ else:
159
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
160
+ s = None
161
+ N, L, H, CI = q.shape
162
+ q_seqlen = [L] * N
163
+ q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
164
+
165
+ if isinstance(k, SparseTensor):
166
+ assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
167
+ assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
168
+ kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
169
+ k = k.feats # [T_KV, H, Ci]
170
+ v = v.feats # [T_KV, H, Co]
171
+ else:
172
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
173
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
174
+ N, L, H, CI, CO = *k.shape, v.shape[-1]
175
+ kv_seqlen = [L] * N
176
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
177
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
178
+
179
+ if DEBUG:
180
+ if s is not None:
181
+ for i in range(s.shape[0]):
182
+ assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
183
+ if num_all_args in [2, 3]:
184
+ assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
185
+ if num_all_args == 3:
186
+ assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
187
+ assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
188
+
189
+ if ATTN == 'xformers':
190
+ if num_all_args == 1:
191
+ q, k, v = qkv.unbind(dim=1)
192
+ elif num_all_args == 2:
193
+ k, v = kv.unbind(dim=1)
194
+ q = q.unsqueeze(0)
195
+ k = k.unsqueeze(0)
196
+ v = v.unsqueeze(0)
197
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
198
+ out = xops.memory_efficient_attention(q, k, v, mask)[0]
199
+ elif ATTN == 'flash_attn':
200
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
201
+ if num_all_args in [2, 3]:
202
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
203
+ if num_all_args == 1:
204
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
205
+ elif num_all_args == 2:
206
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
207
+ elif num_all_args == 3:
208
+ out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
209
+ else:
210
+ raise ValueError(f"Unknown attention module: {ATTN}")
211
+
212
+ if s is not None:
213
+ return s.replace(out)
214
+ else:
215
+ return out.reshape(N, L, H, -1)
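The sparse attention path packs the active voxels of every sample into one `[T, H, C]` tensor and describes the batch boundaries with cumulative sequence lengths, which is the `cu_seqlens` format flash-attn's varlen kernels consume. A small stand-alone sketch of that bookkeeping (toy sequence lengths):

```python
import torch

seqlen = [5, 3, 9]                                   # active voxels per sample
H, C = 4, 16
feats = torch.randn(sum(seqlen), H, C)               # packed tokens, T = 17

cu_seqlens = torch.cat([torch.tensor([0]),
                        torch.cumsum(torch.tensor(seqlen), dim=0)]).int()
print(cu_seqlens)                                     # tensor([ 0,  5,  8, 17], dtype=torch.int32)

# Slice each sample's tokens back out of the packed tensor using the offsets.
per_sample = [feats[cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(len(seqlen))]
print([t.shape[0] for t in per_sample])               # [5, 3, 9]
```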
thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .. import SparseTensor
6
+ from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
8
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9
+ from ...attention import RotaryPositionEmbedder
10
+
11
+
12
+ class SparseMultiHeadRMSNorm(nn.Module):
13
+ def __init__(self, dim: int, heads: int):
14
+ super().__init__()
15
+ self.scale = dim ** 0.5
16
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
17
+
18
+ def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
19
+ x_type = x.dtype
20
+ x = x.float()
21
+ if isinstance(x, SparseTensor):
22
+ x = x.replace(F.normalize(x.feats, dim=-1))
23
+ else:
24
+ x = F.normalize(x, dim=-1)
25
+ return (x * self.gamma * self.scale).to(x_type)
26
+
27
+
28
+ class SparseMultiHeadAttention(nn.Module):
29
+ def __init__(
30
+ self,
31
+ channels: int,
32
+ num_heads: int,
33
+ ctx_channels: Optional[int] = None,
34
+ type: Literal["self", "cross"] = "self",
35
+ attn_mode: Literal["full", "serialized", "windowed"] = "full",
36
+ window_size: Optional[int] = None,
37
+ shift_sequence: Optional[int] = None,
38
+ shift_window: Optional[Tuple[int, int, int]] = None,
39
+ serialize_mode: Optional[SerializeMode] = None,
40
+ qkv_bias: bool = True,
41
+ use_rope: bool = False,
42
+ qk_rms_norm: bool = False,
43
+ ):
44
+ super().__init__()
45
+ assert channels % num_heads == 0
46
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
47
+ assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
48
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
49
+ assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
50
+ self.channels = channels
51
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
52
+ self.num_heads = num_heads
53
+ self._type = type
54
+ self.attn_mode = attn_mode
55
+ self.window_size = window_size
56
+ self.shift_sequence = shift_sequence
57
+ self.shift_window = shift_window
58
+ self.serialize_mode = serialize_mode
59
+ self.use_rope = use_rope
60
+ self.qk_rms_norm = qk_rms_norm
61
+
62
+ if self._type == "self":
63
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
64
+ else:
65
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
66
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
67
+
68
+ if self.qk_rms_norm:
69
+ self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
70
+ self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
71
+
72
+ self.to_out = nn.Linear(channels, channels)
73
+
74
+ if use_rope:
75
+ self.rope = RotaryPositionEmbedder(channels)
76
+
77
+ @staticmethod
78
+ def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
79
+ if isinstance(x, SparseTensor):
80
+ return x.replace(module(x.feats))
81
+ else:
82
+ return module(x)
83
+
84
+ @staticmethod
85
+ def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
86
+ if isinstance(x, SparseTensor):
87
+ return x.reshape(*shape)
88
+ else:
89
+ return x.reshape(*x.shape[:2], *shape)
90
+
91
+ def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
92
+ if isinstance(x, SparseTensor):
93
+ x_feats = x.feats.unsqueeze(0)
94
+ else:
95
+ x_feats = x
96
+ x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
97
+ return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
98
+
99
+ def _rope(self, qkv: SparseTensor) -> SparseTensor:
100
+ q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
101
+ q, k = self.rope(q, k, qkv.coords[:, 1:])
102
+ qkv = qkv.replace(torch.stack([q, k, v], dim=1))
103
+ return qkv
104
+
105
+ def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
106
+ if self._type == "self":
107
+ qkv = self._linear(self.to_qkv, x)
108
+ qkv = self._fused_pre(qkv, num_fused=3)
109
+ if self.use_rope:
110
+ qkv = self._rope(qkv)
111
+ if self.qk_rms_norm:
112
+ q, k, v = qkv.unbind(dim=1)
113
+ q = self.q_rms_norm(q)
114
+ k = self.k_rms_norm(k)
115
+ qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
116
+ if self.attn_mode == "full":
117
+ h = sparse_scaled_dot_product_attention(qkv)
118
+ elif self.attn_mode == "serialized":
119
+ h = sparse_serialized_scaled_dot_product_self_attention(
120
+ qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
121
+ )
122
+ elif self.attn_mode == "windowed":
123
+ h = sparse_windowed_scaled_dot_product_self_attention(
124
+ qkv, self.window_size, shift_window=self.shift_window
125
+ )
126
+ else:
127
+ q = self._linear(self.to_q, x)
128
+ q = self._reshape_chs(q, (self.num_heads, -1))
129
+ kv = self._linear(self.to_kv, context)
130
+ kv = self._fused_pre(kv, num_fused=2)
131
+ if self.qk_rms_norm:
132
+ q = self.q_rms_norm(q)
133
+ k, v = kv.unbind(dim=1)
134
+ k = self.k_rms_norm(k)
135
+ kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
136
+ h = sparse_scaled_dot_product_attention(q, kv)
137
+ h = self._reshape_chs(h, (-1,))
138
+ h = self._linear(self.to_out, h)
139
+ return h
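`SparseMultiHeadAttention` runs its projections on the packed `[T, channels]` feature matrix and only then folds the result into `[T, 3, H, C]` for the attention kernels. The fused-QKV reshape in isolation, with toy sizes:

```python
import torch
import torch.nn as nn

T, channels, num_heads = 10, 64, 8                    # toy sizes
to_qkv = nn.Linear(channels, channels * 3)

feats = torch.randn(T, channels)                      # packed sparse features
qkv = to_qkv(feats).reshape(T, 3, num_heads, channels // num_heads)   # [T, 3, H, C]
q, k, v = qkv.unbind(dim=1)
print(q.shape)                                        # torch.Size([10, 8, 8])
```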
thirdparty/TRELLIS/trellis/trellis/modules/sparse/attention/serialized_attn.py ADDED
@@ -0,0 +1,193 @@
1
+ from typing import *
2
+ from enum import Enum
3
+ import torch
4
+ import math
5
+ from .. import SparseTensor
6
+ from .. import DEBUG, ATTN
7
+
8
+ if ATTN == 'xformers':
9
+ import xformers.ops as xops
10
+ elif ATTN == 'flash_attn':
11
+ import flash_attn
12
+ else:
13
+ raise ValueError(f"Unknown attention module: {ATTN}")
14
+
15
+
16
+ __all__ = [
17
+ 'sparse_serialized_scaled_dot_product_self_attention',
18
+ ]
19
+
20
+
21
+ class SerializeMode(Enum):
22
+ Z_ORDER = 0
23
+ Z_ORDER_TRANSPOSED = 1
24
+ HILBERT = 2
25
+ HILBERT_TRANSPOSED = 3
26
+
27
+
28
+ SerializeModes = [
29
+ SerializeMode.Z_ORDER,
30
+ SerializeMode.Z_ORDER_TRANSPOSED,
31
+ SerializeMode.HILBERT,
32
+ SerializeMode.HILBERT_TRANSPOSED
33
+ ]
34
+
35
+
36
+ def calc_serialization(
37
+ tensor: SparseTensor,
38
+ window_size: int,
39
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
+ shift_sequence: int = 0,
41
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
42
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
43
+ """
44
+ Calculate serialization and partitioning for a set of coordinates.
45
+
46
+ Args:
47
+ tensor (SparseTensor): The input tensor.
48
+ window_size (int): The window size to use.
49
+ serialize_mode (SerializeMode): The serialization mode to use.
50
+ shift_sequence (int): The shift of serialized sequence.
51
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
52
+
53
+ Returns:
54
+ (torch.Tensor, torch.Tensor, List[int], List[int]): Forward indices, backward indices, per-window sequence lengths, and per-window batch indices.
55
+ """
56
+ fwd_indices = []
57
+ bwd_indices = []
58
+ seq_lens = []
59
+ seq_batch_indices = []
60
+ offsets = [0]
61
+
62
+ if 'vox2seq' not in globals():
63
+ import vox2seq
64
+
65
+ # Serialize the input
66
+ serialize_coords = tensor.coords[:, 1:].clone()
67
+ serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
68
+ if serialize_mode == SerializeMode.Z_ORDER:
69
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
70
+ elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
71
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
72
+ elif serialize_mode == SerializeMode.HILBERT:
73
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
74
+ elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
75
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
76
+ else:
77
+ raise ValueError(f"Unknown serialize mode: {serialize_mode}")
78
+
79
+ for bi, s in enumerate(tensor.layout):
80
+ num_points = s.stop - s.start
81
+ num_windows = (num_points + window_size - 1) // window_size
82
+ valid_window_size = num_points / num_windows
83
+ to_ordered = torch.argsort(code[s.start:s.stop])
84
+ if num_windows == 1:
85
+ fwd_indices.append(to_ordered)
86
+ bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
87
+ fwd_indices[-1] += s.start
88
+ bwd_indices[-1] += offsets[-1]
89
+ seq_lens.append(num_points)
90
+ seq_batch_indices.append(bi)
91
+ offsets.append(offsets[-1] + seq_lens[-1])
92
+ else:
93
+ # Partition the input
94
+ offset = 0
95
+ mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
96
+ split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
97
+ bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
98
+ for i in range(num_windows):
99
+ mid = mids[i]
100
+ valid_start = split[i]
101
+ valid_end = split[i + 1]
102
+ padded_start = math.floor(mid - 0.5 * window_size)
103
+ padded_end = padded_start + window_size
104
+ fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
105
+ offset += valid_start - padded_start
106
+ bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
107
+ offset += padded_end - valid_start
108
+ fwd_indices[-1] += s.start
109
+ seq_lens.extend([window_size] * num_windows)
110
+ seq_batch_indices.extend([bi] * num_windows)
111
+ bwd_indices.append(bwd_index + offsets[-1])
112
+ offsets.append(offsets[-1] + num_windows * window_size)
113
+
114
+ fwd_indices = torch.cat(fwd_indices)
115
+ bwd_indices = torch.cat(bwd_indices)
116
+
117
+ return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
+
119
+
120
+ def sparse_serialized_scaled_dot_product_self_attention(
121
+ qkv: SparseTensor,
122
+ window_size: int,
123
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
124
+ shift_sequence: int = 0,
125
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
126
+ ) -> SparseTensor:
127
+ """
128
+ Apply serialized scaled dot product self attention to a sparse tensor.
129
+
130
+ Args:
131
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
132
+ window_size (int): The window size to use.
133
+ serialize_mode (SerializeMode): The serialization mode to use.
134
+ shift_sequence (int): The shift of serialized sequence.
135
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
136
+ shift (int): The shift to use.
137
+ """
138
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
139
+
140
+ serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
141
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
142
+ if serialization_spatial_cache is None:
143
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
144
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
145
+ else:
146
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
147
+
148
+ M = fwd_indices.shape[0]
149
+ T = qkv.feats.shape[0]
150
+ H = qkv.feats.shape[2]
151
+ C = qkv.feats.shape[3]
152
+
153
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
154
+
155
+ if DEBUG:
156
+ start = 0
157
+ qkv_coords = qkv.coords[fwd_indices]
158
+ for i in range(len(seq_lens)):
159
+ assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
160
+ start += seq_lens[i]
161
+
162
+ if all([seq_len == window_size for seq_len in seq_lens]):
163
+ B = len(seq_lens)
164
+ N = window_size
165
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
166
+ if ATTN == 'xformers':
167
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
168
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
169
+ elif ATTN == 'flash_attn':
170
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
171
+ else:
172
+ raise ValueError(f"Unknown attention module: {ATTN}")
173
+ out = out.reshape(B * N, H, C) # [M, H, C]
174
+ else:
175
+ if ATTN == 'xformers':
176
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
177
+ q = q.unsqueeze(0) # [1, M, H, C]
178
+ k = k.unsqueeze(0) # [1, M, H, C]
179
+ v = v.unsqueeze(0) # [1, M, H, C]
180
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
181
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
182
+ elif ATTN == 'flash_attn':
183
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
184
+ .to(qkv.device).int()
185
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
186
+
187
+ out = out[bwd_indices] # [T, H, C]
188
+
189
+ if DEBUG:
190
+ qkv_coords = qkv_coords[bwd_indices]
191
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
192
+
193
+ return qkv.replace(out)
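`calc_serialization` splits each sample's serialized sequence into equally spaced windows of exactly `window_size` tokens, padding by wrapping indices around with a modulo when a window overruns the sequence. The index arithmetic in isolation (toy sizes):

```python
import math
import torch

num_points, window_size, shift_sequence = 22, 8, 0    # toy sizes
num_windows = (num_points + window_size - 1) // window_size    # ceil division -> 3
valid_window_size = num_points / num_windows                   # ~7.33 "owned" tokens each

for i in range(num_windows):
    mid = (i + 0.5) * valid_window_size + shift_sequence
    padded_start = math.floor(mid - 0.5 * window_size)
    window = torch.arange(padded_start, padded_start + window_size) % num_points
    print(window.tolist())
# [21, 0, 1, 2, 3, 4, 5, 6]
# [7, 8, 9, 10, 11, 12, 13, 14]
# [14, 15, 16, 17, 18, 19, 20, 21]
```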