xinjie.wang committed on
Commit 5f6e2e5 · 1 Parent(s): af48ff9
app.py CHANGED
@@ -1,35 +1,8 @@
 import os
 import shutil
-from functools import partial
 
 import gradio as gr
-from common import (
-    MAX_SEED,
-    VERSION,
-    TrellisImageTo3DPipeline,
-    active_btn_by_content,
-    extract_3d_representations_v2,
-    extract_urdf,
-    get_seed,
-    image_to_3d,
-    preprocess_image_fn,
-    preprocess_sam_image_fn,
-    select_point,
-)
-from gradio.themes import Default
-from gradio.themes.utils.colors import slate
 from gradio_litmodel3d import LitModel3D
-from asset3d_gen.models.delight import DelightingModel
-from asset3d_gen.models.segment import RembgRemover, SAMPredictor
-from asset3d_gen.models.super_resolution import ImageRealESRGAN
-from asset3d_gen.utils.gpt_clients import GPT_CLIENT
-from asset3d_gen.validators.quality_checkers import (
-    ImageAestheticChecker,
-    ImageSegChecker,
-    MeshGeoChecker,
-)
-from asset3d_gen.validators.urdf_convertor import URDFGenerator
-
 TMP_DIR = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
 )
asset3d_gen/data/backproject.py DELETED
@@ -1,503 +0,0 @@
1
- import argparse
2
- import logging
3
- import math
4
- import os
5
- from typing import List, Literal, Tuple, Union
6
-
7
- import cv2
8
- import numpy as np
9
- import nvdiffrast.torch as dr
10
- import torch
11
- import trimesh
12
- import utils3d
13
- import xatlas
14
- from tqdm import tqdm
15
- from asset3d_gen.data.mesh_operator import MeshFixer
16
- from asset3d_gen.data.utils import (
17
- CameraSetting,
18
- get_images_from_grid,
19
- init_kal_camera,
20
- normalize_vertices_array,
21
- post_process_texture,
22
- save_mesh_with_mtl,
23
- )
24
- from asset3d_gen.models.delight import DelightingModel
25
-
26
- logging.basicConfig(
27
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
28
- )
29
- logger = logging.getLogger(__name__)
30
-
31
-
32
- class TextureBaker(object):
33
- """Baking textures onto a mesh from multiple observations.
34
-
35
- This class take 3D mesh data, camera settings and texture baking parameters
36
- to generate texture map by projecting images to the mesh from diff views.
37
- It supports both a fast texture baking approach and a more optimized method
38
- with total variation regularization.
39
-
40
- Attributes:
41
- vertices (torch.Tensor): The vertices of the mesh.
42
- faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
43
- uvs (torch.Tensor): The UV coordinates of the mesh.
44
- camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
45
- device (str): The device to run computations on ("cpu" or "cuda").
46
- w2cs (torch.Tensor): World-to-camera transformation matrices.
47
- projections (torch.Tensor): Camera projection matrices.
48
-
49
- Example:
50
- >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
51
- >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
52
- >>> images = get_images_from_grid(args.input_image, image_size)
53
- >>> texture = texture_backer.bake_texture(
54
- ... images, texture_size=args.texture_size, mode=args.baker_mode
55
- ... )
56
- >>> texture = post_process_texture(texture)
57
- """
58
-
59
- def __init__(
60
- self,
61
- vertices: np.ndarray,
62
- faces: np.ndarray,
63
- uvs: np.ndarray,
64
- camera_params: CameraSetting,
65
- device: str = "cuda",
66
- ) -> None:
67
- self.vertices = (
68
- torch.tensor(vertices, device=device)
69
- if isinstance(vertices, np.ndarray)
70
- else vertices.to(device)
71
- )
72
- self.faces = (
73
- torch.tensor(faces.astype(np.int32), device=device)
74
- if isinstance(faces, np.ndarray)
75
- else faces.to(device)
76
- )
77
- self.uvs = (
78
- torch.tensor(uvs, device=device)
79
- if isinstance(uvs, np.ndarray)
80
- else uvs.to(device)
81
- )
82
- self.camera_params = camera_params
83
- self.device = device
84
-
85
- camera = init_kal_camera(camera_params)
86
- matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
87
- matrix_mv = kaolin_to_opencv_view(matrix_mv)
88
- matrix_p = (
89
- camera.intrinsics.projection_matrix()
90
- ) # (n_cam 4 4) cam2pixel
91
- self.w2cs = matrix_mv.to(self.device)
92
- self.projections = matrix_p.to(self.device)
93
-
94
- @staticmethod
95
- def parametrize_mesh(
96
- vertices: np.array, faces: np.array
97
- ) -> Union[np.array, np.array, np.array]:
98
- vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
99
-
100
- vertices = vertices[vmapping]
101
- faces = indices
102
-
103
- return vertices, faces, uvs
104
-
105
- def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
106
- texture = torch.zeros(
107
- (texture_size * texture_size, 3), dtype=torch.float32
108
- ).cuda()
109
- texture_weights = torch.zeros(
110
- (texture_size * texture_size), dtype=torch.float32
111
- ).cuda()
112
- rastctx = utils3d.torch.RastContext(backend="cuda")
113
- for observation, w2c, projection in tqdm(
114
- zip(observations, w2cs, projections),
115
- total=len(observations),
116
- desc="Texture baking (fast)",
117
- ):
118
- with torch.no_grad():
119
- rast = utils3d.torch.rasterize_triangle_faces(
120
- rastctx,
121
- self.vertices[None],
122
- self.faces,
123
- observation.shape[1],
124
- observation.shape[0],
125
- uv=self.uvs[None],
126
- view=w2c,
127
- projection=projection,
128
- )
129
- uv_map = rast["uv"][0].detach().flip(0)
130
- mask = rast["mask"][0].detach().bool() & masks[0]
131
-
132
- # nearest neighbor interpolation
133
- uv_map = (uv_map * texture_size).floor().long()
134
- obs = observation[mask]
135
- uv_map = uv_map[mask]
136
- idx = (
137
- uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
138
- )
139
- texture = texture.scatter_add(
140
- 0, idx.view(-1, 1).expand(-1, 3), obs
141
- )
142
- texture_weights = texture_weights.scatter_add(
143
- 0,
144
- idx,
145
- torch.ones(
146
- (obs.shape[0]), dtype=torch.float32, device=texture.device
147
- ),
148
- )
149
-
150
- mask = texture_weights > 0
151
- texture[mask] /= texture_weights[mask][:, None]
152
- texture = np.clip(
153
- texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
154
- 0,
155
- 255,
156
- ).astype(np.uint8)
157
-
158
- # inpaint
159
- mask = (
160
- (texture_weights == 0)
161
- .cpu()
162
- .numpy()
163
- .astype(np.uint8)
164
- .reshape(texture_size, texture_size)
165
- )
166
- texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
167
-
168
- return texture
169
-
170
- def _bake_opt(
171
- self,
172
- observations,
173
- w2cs,
174
- projections,
175
- texture_size,
176
- lambda_tv,
177
- masks,
178
- total_steps,
179
- ):
180
- rastctx = utils3d.torch.RastContext(backend="cuda")
181
- observations = [observations.flip(0) for observations in observations]
182
- masks = [m.flip(0) for m in masks]
183
- _uv = []
184
- _uv_dr = []
185
- for observation, w2c, projection in tqdm(
186
- zip(observations, w2cs, projections),
187
- total=len(w2cs),
188
- ):
189
- with torch.no_grad():
190
- rast = utils3d.torch.rasterize_triangle_faces(
191
- rastctx,
192
- self.vertices[None],
193
- self.faces,
194
- observation.shape[1],
195
- observation.shape[0],
196
- uv=self.uvs[None],
197
- view=w2c,
198
- projection=projection,
199
- )
200
- _uv.append(rast["uv"].detach())
201
- _uv_dr.append(rast["uv_dr"].detach())
202
-
203
- texture = torch.nn.Parameter(
204
- torch.zeros(
205
- (1, texture_size, texture_size, 3), dtype=torch.float32
206
- ).cuda()
207
- )
208
- optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
209
-
210
- def cosine_anealing(step, total_steps, start_lr, end_lr):
211
- return end_lr + 0.5 * (start_lr - end_lr) * (
212
- 1 + np.cos(np.pi * step / total_steps)
213
- )
214
-
215
- def tv_loss(texture):
216
- return torch.nn.functional.l1_loss(
217
- texture[:, :-1, :, :], texture[:, 1:, :, :]
218
- ) + torch.nn.functional.l1_loss(
219
- texture[:, :, :-1, :], texture[:, :, 1:, :]
220
- )
221
-
222
- with tqdm(total=total_steps, desc="Texture baking") as pbar:
223
- for step in range(total_steps):
224
- optimizer.zero_grad()
225
- selected = np.random.randint(0, len(w2cs))
226
- uv, uv_dr, observation, mask = (
227
- _uv[selected],
228
- _uv_dr[selected],
229
- observations[selected],
230
- masks[selected],
231
- )
232
- render = dr.texture(texture, uv, uv_dr)[0]
233
- loss = torch.nn.functional.l1_loss(
234
- render[mask], observation[mask]
235
- )
236
- if lambda_tv > 0:
237
- loss += lambda_tv * tv_loss(texture)
238
- loss.backward()
239
- optimizer.step()
240
-
241
- optimizer.param_groups[0]["lr"] = cosine_anealing(
242
- step, total_steps, 1e-2, 1e-5
243
- )
244
- pbar.set_postfix({"loss": loss.item()})
245
- pbar.update()
246
- texture = np.clip(
247
- texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
248
- ).astype(np.uint8)
249
- mask = 1 - utils3d.torch.rasterize_triangle_faces(
250
- rastctx,
251
- (self.uvs * 2 - 1)[None],
252
- self.faces,
253
- texture_size,
254
- texture_size,
255
- )["mask"][0].detach().cpu().numpy().astype(np.uint8)
256
- texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
257
-
258
- return texture
259
-
260
- def bake_texture(
261
- self,
262
- images: List[np.array],
263
- texture_size: int = 1024,
264
- mode: Literal["fast", "opt"] = "opt",
265
- lambda_tv: float = 1e-2,
266
- opt_step: int = 2000,
267
- ):
268
- masks = [np.any(img > 0, axis=-1) for img in images]
269
- masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
270
- images = [
271
- torch.tensor(obs / 255.0).float().to(self.device) for obs in images
272
- ]
273
-
274
- if mode == "fast":
275
- return self._bake_fast(
276
- images, self.w2cs, self.projections, texture_size, masks
277
- )
278
- elif mode == "opt":
279
- return self._bake_opt(
280
- images,
281
- self.w2cs,
282
- self.projections,
283
- texture_size,
284
- lambda_tv,
285
- masks,
286
- opt_step,
287
- )
288
- else:
289
- raise ValueError(f"Unknown mode: {mode}")
290
-
291
-
292
- def kaolin_to_opencv_view(raw_matrix):
293
- R_orig = raw_matrix[:, :3, :3]
294
- t_orig = raw_matrix[:, :3, 3]
295
-
296
- R_target = torch.zeros_like(R_orig)
297
- R_target[:, :, 0] = R_orig[:, :, 2]
298
- R_target[:, :, 1] = R_orig[:, :, 0]
299
- R_target[:, :, 2] = R_orig[:, :, 1]
300
-
301
- t_target = t_orig
302
-
303
- target_matrix = (
304
- torch.eye(4, device=raw_matrix.device)
305
- .unsqueeze(0)
306
- .repeat(raw_matrix.size(0), 1, 1)
307
- )
308
- target_matrix[:, :3, :3] = R_target
309
- target_matrix[:, :3, 3] = t_target
310
-
311
- return target_matrix
312
-
313
-
314
- def parse_args():
315
- parser = argparse.ArgumentParser(description="Render settings")
316
-
317
- parser.add_argument(
318
- "--mesh_path",
319
- type=str,
320
- nargs="+",
321
- required=True,
322
- help="Paths to the mesh files for rendering.",
323
- )
324
- parser.add_argument(
325
- "--input_image",
326
- type=str,
327
- nargs="+",
328
- required=True,
329
- help="Paths to the mesh files for rendering.",
330
- )
331
- parser.add_argument(
332
- "--output_root",
333
- type=str,
334
- default="./outputs",
335
- help="Root directory for output",
336
- )
337
- parser.add_argument(
338
- "--uuid",
339
- type=str,
340
- nargs="+",
341
- default=None,
342
- help="uuid for rendering saving.",
343
- )
344
- parser.add_argument(
345
- "--num_images", type=int, default=6, help="Number of images to render."
346
- )
347
- parser.add_argument(
348
- "--elevation",
349
- type=float,
350
- nargs="+",
351
- default=[20.0, -10.0],
352
- help="Elevation angles for the camera (default: [20.0, -10.0])",
353
- )
354
- parser.add_argument(
355
- "--distance",
356
- type=float,
357
- default=5,
358
- help="Camera distance (default: 5)",
359
- )
360
- parser.add_argument(
361
- "--resolution_hw",
362
- type=int,
363
- nargs=2,
364
- default=(512, 512),
365
- help="Resolution of the output images (default: (512, 512))",
366
- )
367
- parser.add_argument(
368
- "--fov",
369
- type=float,
370
- default=30,
371
- help="Field of view in degrees (default: 30)",
372
- )
373
- parser.add_argument(
374
- "--device",
375
- type=str,
376
- choices=["cpu", "cuda"],
377
- default="cuda",
378
- help="Device to run on (default: `cuda`)",
379
- )
380
- parser.add_argument(
381
- "--texture_size",
382
- type=int,
383
- default=1024,
384
- help="Texture size for texture baking (default: 1024)",
385
- )
386
- parser.add_argument(
387
- "--baker_mode",
388
- type=str,
389
- default="opt",
390
- help="Texture baking mode, `fast` or `opt` (default: opt)",
391
- )
392
- parser.add_argument(
393
- "--opt_step",
394
- type=int,
395
- default=2500,
396
- help="Optimization steps for texture baking (default: 2500)",
397
- )
398
- parser.add_argument(
399
- "--mesh_sipmlify_ratio",
400
- type=float,
401
- default=0.9,
402
- help="Mesh simplification ratio (default: 0.9)",
403
- )
404
- parser.add_argument(
405
- "--no_coor_trans",
406
- action="store_true",
407
- help="Do not transform the asset coordinate system.",
408
- )
409
- parser.add_argument(
410
- "--delight", action="store_true", help="Use delighting model."
411
- )
412
- parser.add_argument(
413
- "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
414
- )
415
-
416
- args = parser.parse_args()
417
-
418
- if args.uuid is None:
419
- args.uuid = []
420
- for path in args.mesh_path:
421
- uuid = os.path.basename(path).split(".")[0]
422
- args.uuid.append(uuid)
423
-
424
- return args
425
-
426
-
427
- def entrypoint() -> None:
428
- args = parse_args()
429
- camera_params = CameraSetting(
430
- num_images=args.num_images,
431
- elevation=args.elevation,
432
- distance=args.distance,
433
- resolution_hw=args.resolution_hw,
434
- fov=math.radians(args.fov),
435
- device=args.device,
436
- )
437
-
438
- for mesh_path, uuid, img_path in zip(
439
- args.mesh_path, args.uuid, args.input_image
440
- ):
441
- mesh = trimesh.load(mesh_path)
442
- if isinstance(mesh, trimesh.Scene):
443
- mesh = mesh.dump(concatenate=True)
444
- vertices, scale, center = normalize_vertices_array(mesh.vertices)
445
-
446
- if not args.no_coor_trans:
447
- x_rot = torch.Tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
448
- z_rot = torch.Tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
449
- vertices = vertices @ x_rot
450
- vertices = vertices @ z_rot
451
-
452
- faces = mesh.faces.cpu().numpy().astype(np.int32)
453
- vertices = vertices.cpu().numpy().astype(np.float32)
454
-
455
- if not args.skip_fix_mesh:
456
- mesh_fixer = MeshFixer(vertices, faces, args.device)
457
- vertices, faces = mesh_fixer(
458
- filter_ratio=args.mesh_sipmlify_ratio,
459
- max_hole_size=0.04,
460
- resolution=1024,
461
- num_views=1000,
462
- norm_mesh_ratio=0.5,
463
- )
464
-
465
- vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
466
- texture_backer = TextureBaker(
467
- vertices,
468
- faces,
469
- uvs,
470
- camera_params,
471
- )
472
- images = get_images_from_grid(
473
- img_path, img_size=camera_params.resolution_hw[0]
474
- )
475
- if args.delight:
476
- delight_model = DelightingModel(
477
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
478
- )
479
- delight_images = [delight_model(img) for img in images]
480
- images = [np.array(img) for img in delight_images]
481
-
482
- texture = texture_backer.bake_texture(
483
- images=[img[..., :3] for img in images],
484
- texture_size=args.texture_size,
485
- mode=args.baker_mode,
486
- opt_step=args.opt_step,
487
- )
488
- texture = post_process_texture(texture)
489
-
490
- if not args.no_coor_trans:
491
- vertices = vertices @ np.linalg.inv(z_rot)
492
- vertices = vertices @ np.linalg.inv(x_rot)
493
- vertices = vertices / scale
494
- vertices = vertices + center
495
-
496
- output_path = os.path.join(args.output_root, f"{uuid}.obj")
497
- mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)
498
-
499
- return
500
-
501
-
502
- if __name__ == "__main__":
503
- entrypoint()
asset3d_gen/data/backproject_v2.py DELETED
@@ -1,613 +0,0 @@
1
- import argparse
2
- import logging
3
- import math
4
- import os
5
-
6
- import cv2
7
- import numpy as np
8
- import nvdiffrast.torch as dr
9
- import torch
10
- import torch.nn.functional as F
11
- import trimesh
12
- import xatlas
13
- from PIL import Image
14
- from asset3d_gen.data.mesh_operator import MeshFixer
15
- from asset3d_gen.data.utils import (
16
- CameraSetting,
17
- DiffrastRender,
18
- get_images_from_grid,
19
- init_kal_camera,
20
- normalize_vertices_array,
21
- post_process_texture,
22
- save_mesh_with_mtl,
23
- )
24
- from asset3d_gen.models.delight import DelightingModel
25
- from asset3d_gen.models.super_resolution import ImageRealESRGAN
26
-
27
- logging.basicConfig(
28
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
29
- )
30
- logger = logging.getLogger(__name__)
31
-
32
-
33
- __all__ = [
34
- "TextureBacker",
35
- ]
36
-
37
-
38
- def transform_vertices(
39
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
40
- ) -> torch.Tensor:
41
- """Transform 3D vertices using a projection matrix."""
42
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
43
- if pos.size(-1) == 3:
44
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
45
-
46
- result = pos @ t_mtx.T
47
-
48
- return result if keepdim else result.unsqueeze(0)
49
-
50
-
51
- def _bilinear_interpolation_scattering(
52
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
53
- ) -> torch.Tensor:
54
- """Bilinear interpolation scattering for grid-based value accumulation."""
55
- device = values.device
56
- dtype = values.dtype
57
- C = values.shape[-1]
58
-
59
- indices = coords * torch.tensor(
60
- [image_h - 1, image_w - 1], dtype=dtype, device=device
61
- )
62
- i, j = indices.unbind(-1)
63
-
64
- i0, j0 = (
65
- indices.floor()
66
- .long()
67
- .clamp(0, image_h - 2)
68
- .clamp(0, image_w - 2)
69
- .unbind(-1)
70
- )
71
- i1, j1 = i0 + 1, j0 + 1
72
-
73
- w_i = i - i0.float()
74
- w_j = j - j0.float()
75
- weights = torch.stack(
76
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
77
- dim=1,
78
- )
79
-
80
- indices_comb = torch.stack(
81
- [
82
- torch.stack([i0, j0], dim=1),
83
- torch.stack([i0, j1], dim=1),
84
- torch.stack([i1, j0], dim=1),
85
- torch.stack([i1, j1], dim=1),
86
- ],
87
- dim=1,
88
- )
89
-
90
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
91
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
92
-
93
- for k in range(4):
94
- idx = indices_comb[:, k]
95
- w = weights[:, k].unsqueeze(-1)
96
-
97
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
98
- flat_idx = (idx * stride).sum(-1)
99
-
100
- grid.view(-1, C).scatter_add_(
101
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
102
- )
103
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
104
-
105
- mask = cnt.squeeze(-1) > 0
106
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
107
-
108
- return grid
109
-
110
-
111
- def _texture_inpaint_smooth(
112
- texture: np.ndarray,
113
- mask: np.ndarray,
114
- vertices: np.ndarray,
115
- faces: np.ndarray,
116
- uv_map: np.ndarray,
117
- ) -> tuple[np.ndarray, np.ndarray]:
118
- """Perform texture inpainting using vertex-based color propagation."""
119
- image_h, image_w, C = texture.shape
120
- N = vertices.shape[0]
121
-
122
- # Initialize vertex data structures
123
- vtx_mask = np.zeros(N, dtype=np.float32)
124
- vtx_colors = np.zeros((N, C), dtype=np.float32)
125
- unprocessed = []
126
- adjacency = [[] for _ in range(N)]
127
-
128
- # Build adjacency graph and initial color assignment
129
- for face_idx in range(faces.shape[0]):
130
- for k in range(3):
131
- uv_idx_k = faces[face_idx, k]
132
- v_idx = faces[face_idx, k]
133
-
134
- # Convert UV to pixel coordinates with boundary clamping
135
- u = np.clip(
136
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
137
- )
138
- v = np.clip(
139
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
140
- 0,
141
- image_h - 1,
142
- )
143
-
144
- if mask[v, u]:
145
- vtx_mask[v_idx] = 1.0
146
- vtx_colors[v_idx] = texture[v, u]
147
- elif v_idx not in unprocessed:
148
- unprocessed.append(v_idx)
149
-
150
- # Build undirected adjacency graph
151
- neighbor = faces[face_idx, (k + 1) % 3]
152
- if neighbor not in adjacency[v_idx]:
153
- adjacency[v_idx].append(neighbor)
154
- if v_idx not in adjacency[neighbor]:
155
- adjacency[neighbor].append(v_idx)
156
-
157
- # Color propagation with dynamic stopping
158
- remaining_iters, prev_count = 2, 0
159
- while remaining_iters > 0:
160
- current_unprocessed = []
161
-
162
- for v_idx in unprocessed:
163
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
164
- if not valid_neighbors:
165
- current_unprocessed.append(v_idx)
166
- continue
167
-
168
- # Calculate inverse square distance weights
169
- neighbors_pos = vertices[valid_neighbors]
170
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
171
- weights = 1 / np.maximum(dist_sq, 1e-8)
172
-
173
- vtx_colors[v_idx] = np.average(
174
- vtx_colors[valid_neighbors], weights=weights, axis=0
175
- )
176
- vtx_mask[v_idx] = 1.0
177
-
178
- # Update iteration control
179
- if len(current_unprocessed) == prev_count:
180
- remaining_iters -= 1
181
- else:
182
- remaining_iters = min(remaining_iters + 1, 2)
183
- prev_count = len(current_unprocessed)
184
- unprocessed = current_unprocessed
185
-
186
- # Generate output texture
187
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
188
- for face_idx in range(faces.shape[0]):
189
- for k in range(3):
190
- v_idx = faces[face_idx, k]
191
- if not vtx_mask[v_idx]:
192
- continue
193
-
194
- # UV coordinate conversion
195
- uv_idx_k = faces[face_idx, k]
196
- u = np.clip(
197
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
198
- )
199
- v = np.clip(
200
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
201
- 0,
202
- image_h - 1,
203
- )
204
-
205
- inpainted_texture[v, u] = vtx_colors[v_idx]
206
- updated_mask[v, u] = 255
207
-
208
- return inpainted_texture, updated_mask
209
-
210
-
211
- class TextureBacker:
212
- """Texture baking pipeline for multi-view projection and fusion."""
213
-
214
- def __init__(
215
- self,
216
- camera_params: CameraSetting,
217
- view_weights: list[float],
218
- render_wh: tuple[int, int] = (2048, 2048),
219
- texture_wh: tuple[int, int] = (2048, 2048),
220
- bake_angle_thresh: int = 75,
221
- mask_thresh: float = 0.5,
222
- ):
223
- camera = init_kal_camera(camera_params)
224
- mv = camera.view_matrix() # (n 4 4) world2cam
225
- p = camera.intrinsics.projection_matrix()
226
- # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
227
- p[:, 1, 1] = -p[:, 1, 1]
228
- renderer = DiffrastRender(
229
- p_matrix=p,
230
- mv_matrix=mv,
231
- resolution_hw=camera_params.resolution_hw,
232
- context=dr.RasterizeCudaContext(),
233
- mask_thresh=mask_thresh,
234
- grad_db=False,
235
- device=camera_params.device,
236
- antialias_mask=True,
237
- )
238
- self.camera = camera
239
- self.renderer = renderer
240
- self.view_weights = view_weights
241
- self.device = camera_params.device
242
- self.render_wh = render_wh
243
- self.texture_wh = texture_wh
244
-
245
- self.bake_angle_thresh = bake_angle_thresh
246
- self.bake_unreliable_kernel_size = int(
247
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
248
- )
249
-
250
- def load_mesh(self, mesh: trimesh.Trimesh) -> None:
251
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
252
- self.scale, self.center = scale, center
253
-
254
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
255
- uvs[:, 1] = 1 - uvs[:, 1]
256
- mesh.vertices = mesh.vertices[vmapping]
257
- mesh.faces = indices
258
- mesh.visual.uv = uvs
259
-
260
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
261
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
262
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
263
-
264
- def get_mesh_np_attrs(
265
- self,
266
- scale: float = None,
267
- center: np.ndarray = None,
268
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
269
- vertices = self.vertices.cpu().numpy()
270
- faces = self.faces.cpu().numpy()
271
- uv_map = self.uv_map.cpu().numpy()
272
- uv_map[:, 1] = 1.0 - uv_map[:, 1]
273
-
274
- if scale is not None:
275
- vertices = vertices / scale
276
- if center is not None:
277
- vertices = vertices + center
278
-
279
- return vertices, faces, uv_map
280
-
281
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
282
- depth_image_np = depth_image.cpu().numpy()
283
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
284
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
285
- sketch_image = (
286
- torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
287
- )
288
- sketch_image = sketch_image.unsqueeze(-1)
289
-
290
- return sketch_image
291
-
292
- def compute_enhanced_viewnormal(
293
- self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
294
- ) -> torch.Tensor:
295
- rast, _ = self.renderer.compute_dr_raster(vertices, faces)
296
- rendered_view_normals = []
297
- for idx in range(len(mv_mtx)):
298
- pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
299
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
300
- v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
301
- face_norm = F.normalize(
302
- torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
303
- )
304
- vertex_norm = (
305
- torch.from_numpy(
306
- trimesh.geometry.mean_vertex_normals(
307
- len(pos_cam), faces.cpu(), face_norm.cpu()
308
- )
309
- )
310
- .to(vertices.device)
311
- .contiguous()
312
- )
313
- im_base_normals, _ = dr.interpolate(
314
- vertex_norm[None, ...].float(),
315
- rast[idx : idx + 1],
316
- faces.to(torch.int32),
317
- )
318
- rendered_view_normals.append(im_base_normals)
319
-
320
- rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
321
-
322
- return rendered_view_normals
323
-
324
- def back_project(
325
- self, image, vis_mask, depth, normal, uv
326
- ) -> tuple[torch.Tensor, torch.Tensor]:
327
- image = np.array(image)
328
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
329
- if image.ndim == 2:
330
- image = image.unsqueeze(-1)
331
- image = image / 255
332
-
333
- depth_inv = (1.0 - depth) * vis_mask
334
- sketch_image = self._render_depth_edges(depth_inv)
335
-
336
- cos = F.cosine_similarity(
337
- torch.tensor([[0, 0, 1]], device=self.device),
338
- normal.view(-1, 3),
339
- ).view_as(normal[..., :1])
340
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
341
-
342
- k = self.bake_unreliable_kernel_size * 2 + 1
343
- kernel = torch.ones((1, 1, k, k), device=self.device)
344
-
345
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
346
- vis_mask = F.conv2d(
347
- 1.0 - vis_mask,
348
- kernel,
349
- padding=k // 2,
350
- )
351
- vis_mask = 1.0 - (vis_mask > 0).float()
352
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
353
-
354
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
355
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
356
- sketch_image = (sketch_image > 0).float()
357
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
358
- vis_mask = vis_mask * (sketch_image < 0.5)
359
-
360
- cos[vis_mask == 0] = 0
361
- valid_pixels = (vis_mask != 0).view(-1)
362
-
363
- return (
364
- self._scatter_texture(uv, image, valid_pixels),
365
- self._scatter_texture(uv, cos, valid_pixels),
366
- )
367
-
368
- def _scatter_texture(self, uv, data, mask):
369
- def __filter_data(data, mask):
370
- return data.view(-1, data.shape[-1])[mask]
371
-
372
- return _bilinear_interpolation_scattering(
373
- self.texture_wh[1],
374
- self.texture_wh[0],
375
- __filter_data(uv, mask)[..., [1, 0]],
376
- __filter_data(data, mask),
377
- )
378
-
379
- @torch.no_grad()
380
- def fast_bake_texture(
381
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
382
- ) -> tuple[torch.Tensor, torch.Tensor]:
383
- channel = textures[0].shape[-1]
384
- texture_merge = torch.zeros(self.texture_wh + [channel]).to(
385
- self.device
386
- )
387
- trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
388
- for texture, cos_map in zip(textures, confidence_maps):
389
- view_sum = (cos_map > 0).sum()
390
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
391
- if painted_sum / view_sum > 0.99:
392
- continue
393
- texture_merge += texture * cos_map
394
- trust_map_merge += cos_map
395
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
396
-
397
- return texture_merge, trust_map_merge > 1e-8
398
-
399
- def uv_inpaint(
400
- self, texture: torch.Tensor, mask: torch.Tensor
401
- ) -> np.ndarray:
402
- texture_np = texture.cpu().numpy()
403
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
404
- vertices, faces, uv_map = self.get_mesh_np_attrs()
405
-
406
- texture_np, mask_np = _texture_inpaint_smooth(
407
- texture_np, mask_np, vertices, faces, uv_map
408
- )
409
- texture_np = texture_np.clip(0, 1)
410
- texture_np = cv2.inpaint(
411
- (texture_np * 255).astype(np.uint8),
412
- 255 - mask_np,
413
- 3,
414
- cv2.INPAINT_NS,
415
- )
416
-
417
- return texture_np
418
-
419
- def __call__(
420
- self,
421
- colors: list[Image.Image],
422
- mesh: trimesh.Trimesh,
423
- output_path: str,
424
- ) -> trimesh.Trimesh:
425
- self.load_mesh(mesh)
426
- rendered_depth, masks = self.renderer.render_depth(
427
- self.vertices, self.faces
428
- )
429
- norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
430
- render_uvs, _ = self.renderer.render_uv(
431
- self.vertices, self.faces, self.uv_map
432
- )
433
- view_normals = self.compute_enhanced_viewnormal(
434
- self.renderer.mv_mtx, self.vertices, self.faces
435
- )
436
-
437
- textures, weighted_cos_maps = [], []
438
- for color, mask, dep, normal, uv, weight in zip(
439
- colors,
440
- masks,
441
- norm_deps,
442
- view_normals,
443
- render_uvs,
444
- self.view_weights,
445
- ):
446
- texture, cos_map = self.back_project(color, mask, dep, normal, uv)
447
- textures.append(texture)
448
- weighted_cos_maps.append(weight * (cos_map**4))
449
-
450
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
451
- texture_np = self.uv_inpaint(texture, mask)
452
- texture_np = post_process_texture(texture_np)
453
- vertices, faces, uv_map = self.get_mesh_np_attrs(
454
- self.scale, self.center
455
- )
456
-
457
- textured_mesh = save_mesh_with_mtl(
458
- vertices, faces, uv_map, texture_np, output_path
459
- )
460
-
461
- return textured_mesh
462
-
463
-
464
- def parse_args():
465
- parser = argparse.ArgumentParser(description="Backproject texture")
466
- parser.add_argument(
467
- "--color_path",
468
- type=str,
469
- help="Multiview color image in 6x512x512 file path",
470
- )
471
- parser.add_argument(
472
- "--mesh_path",
473
- type=str,
474
- help="Mesh path, .obj, .glb or .ply",
475
- )
476
- parser.add_argument(
477
- "--output_path",
478
- type=str,
479
- help="Output mesh path with suffix",
480
- )
481
- parser.add_argument(
482
- "--num_images", type=int, default=6, help="Number of images to render."
483
- )
484
- parser.add_argument(
485
- "--elevation",
486
- nargs=2,
487
- type=float,
488
- default=[20.0, -10.0],
489
- help="Elevation angles for the camera (default: [20.0, -10.0])",
490
- )
491
- parser.add_argument(
492
- "--distance",
493
- type=float,
494
- default=5,
495
- help="Camera distance (default: 5)",
496
- )
497
- parser.add_argument(
498
- "--resolution_hw",
499
- type=int,
500
- nargs=2,
501
- default=(2048, 2048),
502
- help="Resolution of the output images (default: (2048, 2048))",
503
- )
504
- parser.add_argument(
505
- "--fov",
506
- type=float,
507
- default=30,
508
- help="Field of view in degrees (default: 30)",
509
- )
510
- parser.add_argument(
511
- "--device",
512
- type=str,
513
- choices=["cpu", "cuda"],
514
- default="cuda",
515
- help="Device to run on (default: `cuda`)",
516
- )
517
- parser.add_argument(
518
- "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
519
- )
520
- parser.add_argument(
521
- "--texture_wh",
522
- nargs=2,
523
- type=int,
524
- default=[2048, 2048],
525
- help="Texture resolution width and height",
526
- )
527
- parser.add_argument(
528
- "--mesh_sipmlify_ratio",
529
- type=float,
530
- default=0.9,
531
- help="Mesh simplification ratio (default: 0.9)",
532
- )
533
- parser.add_argument(
534
- "--delight", action="store_true", help="Use delighting model."
535
- )
536
- args = parser.parse_args()
537
-
538
- return args
539
-
540
-
541
- def entrypoint(
542
- delight_model: DelightingModel = None,
543
- imagesr_model: ImageRealESRGAN = None,
544
- **kwargs,
545
- ) -> trimesh.Trimesh:
546
- args = parse_args()
547
- for k, v in kwargs.items():
548
- if hasattr(args, k) and v is not None:
549
- setattr(args, k, v)
550
-
551
- # Setup camera parameters.
552
- camera_params = CameraSetting(
553
- num_images=args.num_images,
554
- elevation=args.elevation,
555
- distance=args.distance,
556
- resolution_hw=args.resolution_hw,
557
- fov=math.radians(args.fov),
558
- device=args.device,
559
- )
560
- view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
561
-
562
- color_grid = Image.open(args.color_path)
563
- if args.delight:
564
- if delight_model is None:
565
- delight_model = DelightingModel(
566
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
567
- )
568
- save_dir = os.path.dirname(args.output_path)
569
- os.makedirs(save_dir, exist_ok=True)
570
- color_grid.save(f"{save_dir}/color_grid.png")
571
- color_grid = delight_model(color_grid)
572
- color_grid.save(f"{save_dir}/color_grid_delight.png")
573
-
574
- multiviews = get_images_from_grid(color_grid, img_size=512)
575
-
576
- # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
577
- if imagesr_model is None:
578
- imagesr_model = ImageRealESRGAN(outscale=4)
579
- multiviews = [imagesr_model(img) for img in multiviews]
580
- multiviews = [img.convert("RGB") for img in multiviews]
581
- mesh = trimesh.load(args.mesh_path)
582
- if isinstance(mesh, trimesh.Scene):
583
- mesh = mesh.dump(concatenate=True)
584
-
585
- if not args.skip_fix_mesh:
586
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
587
- mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
588
- mesh.vertices, mesh.faces = mesh_fixer(
589
- filter_ratio=args.mesh_sipmlify_ratio,
590
- max_hole_size=0.04,
591
- resolution=1024,
592
- num_views=1000,
593
- norm_mesh_ratio=0.5,
594
- )
595
- # Restore scale.
596
- mesh.vertices = mesh.vertices / scale
597
- mesh.vertices = mesh.vertices + center
598
-
599
- # Baking texture to mesh.
600
- texture_backer = TextureBacker(
601
- camera_params=camera_params,
602
- view_weights=view_weights,
603
- render_wh=camera_params.resolution_hw,
604
- texture_wh=args.texture_wh,
605
- )
606
-
607
- textured_mesh = texture_backer(multiviews, mesh, args.output_path)
608
-
609
- return textured_mesh
610
-
611
-
612
- if __name__ == "__main__":
613
- entrypoint()
asset3d_gen/data/backup/backproject_v2 copy.py DELETED
@@ -1,650 +0,0 @@
1
- import argparse
2
- import logging
3
- import math
4
- import os
5
-
6
- import cv2
7
- import numpy as np
8
- import nvdiffrast.torch as dr
9
- import torch
10
- import torch.nn.functional as F
11
- from torchvision.transforms import functional as tF
12
-
13
- import trimesh
14
- import xatlas
15
- from PIL import Image
16
- from asset3d_gen.data.mesh_operator import MeshFixer
17
- from asset3d_gen.data.utils import (
18
- CameraSetting,
19
- DiffrastRender,
20
- get_images_from_grid,
21
- init_kal_camera,
22
- normalize_vertices_array,
23
- post_process_texture,
24
- save_mesh_with_mtl,
25
- )
26
- from asset3d_gen.models.delight import DelightingModel
27
- from asset3d_gen.models.super_resolution import ImageRealESRGAN
28
-
29
- logging.basicConfig(
30
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
31
- )
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- __all__ = [
36
- "TextureBacker",
37
- ]
38
-
39
-
40
- def transform_vertices(
41
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
42
- ) -> torch.Tensor:
43
- """Transform 3D vertices using a projection matrix."""
44
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
45
- if pos.size(-1) == 3:
46
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
47
-
48
- result = pos @ t_mtx.T
49
-
50
- return result if keepdim else result.unsqueeze(0)
51
-
52
-
53
- def _bilinear_interpolation_scattering(
54
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
55
- ) -> torch.Tensor:
56
- """Bilinear interpolation scattering for grid-based value accumulation."""
57
- device = values.device
58
- dtype = values.dtype
59
- C = values.shape[-1]
60
-
61
- indices = coords * torch.tensor(
62
- [image_h - 1, image_w - 1], dtype=dtype, device=device
63
- )
64
- i, j = indices.unbind(-1)
65
-
66
- i0, j0 = (
67
- indices.floor()
68
- .long()
69
- .clamp(0, image_h - 2)
70
- .clamp(0, image_w - 2)
71
- .unbind(-1)
72
- )
73
- i1, j1 = i0 + 1, j0 + 1
74
-
75
- w_i = i - i0.float()
76
- w_j = j - j0.float()
77
- weights = torch.stack(
78
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
79
- dim=1,
80
- )
81
-
82
- indices_comb = torch.stack(
83
- [
84
- torch.stack([i0, j0], dim=1),
85
- torch.stack([i0, j1], dim=1),
86
- torch.stack([i1, j0], dim=1),
87
- torch.stack([i1, j1], dim=1),
88
- ],
89
- dim=1,
90
- )
91
-
92
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
93
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
94
-
95
- for k in range(4):
96
- idx = indices_comb[:, k]
97
- w = weights[:, k].unsqueeze(-1)
98
-
99
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
100
- flat_idx = (idx * stride).sum(-1)
101
-
102
- grid.view(-1, C).scatter_add_(
103
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
104
- )
105
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
106
-
107
- mask = cnt.squeeze(-1) > 0
108
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
109
-
110
- return grid
111
-
112
-
113
- def _texture_inpaint_smooth(
114
- texture: np.ndarray,
115
- mask: np.ndarray,
116
- vertices: np.ndarray,
117
- faces: np.ndarray,
118
- uv_map: np.ndarray,
119
- ) -> tuple[np.ndarray, np.ndarray]:
120
- """Perform texture inpainting using vertex-based color propagation."""
121
- image_h, image_w, C = texture.shape
122
- N = vertices.shape[0]
123
-
124
- # Initialize vertex data structures
125
- vtx_mask = np.zeros(N, dtype=np.float32)
126
- vtx_colors = np.zeros((N, C), dtype=np.float32)
127
- unprocessed = []
128
- adjacency = [[] for _ in range(N)]
129
-
130
- # Build adjacency graph and initial color assignment
131
- for face_idx in range(faces.shape[0]):
132
- for k in range(3):
133
- uv_idx_k = faces[face_idx, k]
134
- v_idx = faces[face_idx, k]
135
-
136
- # Convert UV to pixel coordinates with boundary clamping
137
- u = np.clip(
138
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
139
- )
140
- v = np.clip(
141
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
142
- 0,
143
- image_h - 1,
144
- )
145
-
146
- if mask[v, u]:
147
- vtx_mask[v_idx] = 1.0
148
- vtx_colors[v_idx] = texture[v, u]
149
- elif v_idx not in unprocessed:
150
- unprocessed.append(v_idx)
151
-
152
- # Build undirected adjacency graph
153
- neighbor = faces[face_idx, (k + 1) % 3]
154
- if neighbor not in adjacency[v_idx]:
155
- adjacency[v_idx].append(neighbor)
156
- if v_idx not in adjacency[neighbor]:
157
- adjacency[neighbor].append(v_idx)
158
-
159
- # Color propagation with dynamic stopping
160
- remaining_iters, prev_count = 2, 0
161
- while remaining_iters > 0:
162
- current_unprocessed = []
163
-
164
- for v_idx in unprocessed:
165
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
166
- if not valid_neighbors:
167
- current_unprocessed.append(v_idx)
168
- continue
169
-
170
- # Calculate inverse square distance weights
171
- neighbors_pos = vertices[valid_neighbors]
172
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
173
- weights = 1 / np.maximum(dist_sq, 1e-8)
174
-
175
- vtx_colors[v_idx] = np.average(
176
- vtx_colors[valid_neighbors], weights=weights, axis=0
177
- )
178
- vtx_mask[v_idx] = 1.0
179
-
180
- # Update iteration control
181
- if len(current_unprocessed) == prev_count:
182
- remaining_iters -= 1
183
- else:
184
- remaining_iters = min(remaining_iters + 1, 2)
185
- prev_count = len(current_unprocessed)
186
- unprocessed = current_unprocessed
187
-
188
- # Generate output texture
189
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
190
- for face_idx in range(faces.shape[0]):
191
- for k in range(3):
192
- v_idx = faces[face_idx, k]
193
- if not vtx_mask[v_idx]:
194
- continue
195
-
196
- # UV coordinate conversion
197
- uv_idx_k = faces[face_idx, k]
198
- u = np.clip(
199
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
200
- )
201
- v = np.clip(
202
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
203
- 0,
204
- image_h - 1,
205
- )
206
-
207
- inpainted_texture[v, u] = vtx_colors[v_idx]
208
- updated_mask[v, u] = 255
209
-
210
- return inpainted_texture, updated_mask
211
-
212
-
213
- def interp_tensers(tensors: list[torch.Tensor], target_wh: tuple[int, int]) -> list[torch.Tensor]:
214
- for idx in range(len(tensors)):
215
- tensor = tensors[idx].permute(2, 0, 1)
216
- tensor = tF.resize(tensor, target_wh[::-1], antialias=True)
217
- tensors[idx] = tensor.permute(1, 2, 0)
218
-
219
- return tensors
220
-
221
-
222
- class TextureBacker:
223
- """Texture baking pipeline for multi-view projection and fusion."""
224
-
225
- def __init__(
226
- self,
227
- camera_params: CameraSetting,
228
- view_weights: list[float],
229
- render_wh: tuple[int, int] = (2048, 2048),
230
- texture_wh: tuple[int, int] = (2048, 2048),
231
- bake_angle_thresh: int = 75,
232
- mask_thresh: float = 0.5,
233
- ):
234
- camera = init_kal_camera(camera_params)
235
- mv = camera.view_matrix() # (n 4 4) world2cam
236
- p = camera.intrinsics.projection_matrix()
237
- # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
238
- p[:, 1, 1] = -p[:, 1, 1]
239
- self.renderer = DiffrastRender(
240
- p_matrix=p,
241
- mv_matrix=mv,
242
- resolution_hw=camera_params.resolution_hw,
243
- context=dr.RasterizeCudaContext(),
244
- mask_thresh=mask_thresh,
245
- grad_db=False,
246
- device=camera_params.device,
247
- antialias_mask=True,
248
- )
249
- self.camera = camera
250
- self.view_weights = view_weights
251
- self.device = camera_params.device
252
- self.render_wh = render_wh
253
- self.texture_wh = texture_wh
254
-
255
- self.bake_angle_thresh = bake_angle_thresh
256
- self.bake_unreliable_kernel_size = int(
257
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
258
- )
259
-
260
- def load_mesh(self, mesh: trimesh.Trimesh) -> None:
261
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
262
- self.scale, self.center = scale, center
263
-
264
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
265
- uvs[:, 1] = 1 - uvs[:, 1]
266
- mesh.vertices = mesh.vertices[vmapping]
267
- mesh.faces = indices
268
- mesh.visual.uv = uvs
269
-
270
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
271
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
272
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
273
-
274
- def get_mesh_np_attrs(
275
- self,
276
- scale: float = None,
277
- center: np.ndarray = None,
278
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
279
- vertices = self.vertices.cpu().numpy()
280
- faces = self.faces.cpu().numpy()
281
- uv_map = self.uv_map.cpu().numpy()
282
- uv_map[:, 1] = 1.0 - uv_map[:, 1]
283
-
284
- if scale is not None:
285
- vertices = vertices / scale
286
- if center is not None:
287
- vertices = vertices + center
288
-
289
- return vertices, faces, uv_map
290
-
291
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
292
- depth_image_np = depth_image.cpu().numpy()
293
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
294
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
295
- sketch_image = (
296
- torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
297
- )
298
- sketch_image = sketch_image.unsqueeze(-1)
299
-
300
- return sketch_image
301
-
302
- def compute_enhanced_viewnormal(
303
- self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
304
- ) -> torch.Tensor:
305
- rast, _ = self.renderer.compute_dr_raster(vertices, faces)
306
- rendered_view_normals = []
307
- for idx in range(len(mv_mtx)):
308
- pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
309
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
310
- v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
311
- face_norm = F.normalize(
312
- torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
313
- )
314
- vertex_norm = (
315
- torch.from_numpy(
316
- trimesh.geometry.mean_vertex_normals(
317
- len(pos_cam), faces.cpu(), face_norm.cpu()
318
- )
319
- )
320
- .to(vertices.device)
321
- .contiguous()
322
- )
323
- im_base_normals, _ = dr.interpolate(
324
- vertex_norm[None, ...].float(),
325
- rast[idx : idx + 1],
326
- faces.to(torch.int32),
327
- )
328
- rendered_view_normals.append(im_base_normals)
329
-
330
- rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
331
-
332
- return rendered_view_normals
333
-
334
- def back_project(
335
- self, image, vis_mask, depth, normal, uv
336
- ) -> tuple[torch.Tensor, torch.Tensor]:
337
- image = np.array(image)
338
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
339
- if image.ndim == 2:
340
- image = image.unsqueeze(-1)
341
- image = image / 255
342
-
343
- depth_inv = (1.0 - depth) * vis_mask
344
- sketch_image = self._render_depth_edges(depth_inv)
345
-
346
- cos = F.cosine_similarity(
347
- torch.tensor([[0, 0, 1]], device=self.device),
348
- normal.view(-1, 3),
349
- ).view_as(normal[..., :1])
350
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
351
-
352
- k = self.bake_unreliable_kernel_size * 2 + 1
353
- kernel = torch.ones((1, 1, k, k), device=self.device)
354
-
355
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
356
- vis_mask = F.conv2d(
357
- 1.0 - vis_mask,
358
- kernel,
359
- padding=k // 2,
360
- )
361
- vis_mask = 1.0 - (vis_mask > 0).float()
362
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
363
-
364
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
365
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
366
- sketch_image = (sketch_image > 0).float()
367
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
368
- vis_mask = vis_mask * (sketch_image < 0.5)
369
-
370
- cos[vis_mask == 0] = 0
371
- valid_pixels = (vis_mask != 0).view(-1)
372
-
373
- return (
374
- self._scatter_texture(uv, image, valid_pixels),
375
- self._scatter_texture(uv, cos, valid_pixels),
376
- )
377
-
378
- def _scatter_texture(self, uv, data, mask):
379
- def __filter_data(data, mask):
380
- return data.view(-1, data.shape[-1])[mask]
381
-
382
- return _bilinear_interpolation_scattering(
383
- self.texture_wh[1],
384
- self.texture_wh[0],
385
- __filter_data(uv, mask)[..., [1, 0]],
386
- __filter_data(data, mask),
387
- )
388
-
389
- @torch.no_grad()
390
- def fast_bake_texture(
391
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
392
- ) -> tuple[torch.Tensor, torch.Tensor]:
393
- channel = textures[0].shape[-1]
394
- texture_merge = torch.zeros(self.texture_wh + [channel]).to(
395
- self.device
396
- )
397
- trust_map_merge = torch.zeros(self.texture_wh + [1]).to(self.device)
398
- for texture, cos_map in zip(textures, confidence_maps):
399
- view_sum = (cos_map > 0).sum()
400
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
401
- if painted_sum / view_sum > 0.99:
402
- continue
403
- texture_merge += texture * cos_map
404
- trust_map_merge += cos_map
405
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
406
-
407
- return texture_merge, trust_map_merge > 1e-8
408
-
409
- def uv_inpaint(
410
- self, texture: torch.Tensor, mask: torch.Tensor
411
- ) -> np.ndarray:
412
- texture_np = texture.cpu().numpy()
413
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
414
- vertices, faces, uv_map = self.get_mesh_np_attrs()
415
-
416
- texture_np, mask_np = _texture_inpaint_smooth(
417
- texture_np, mask_np, vertices, faces, uv_map
418
- )
419
- texture_np = texture_np.clip(0, 1)
420
- texture_np = cv2.inpaint(
421
- (texture_np * 255).astype(np.uint8),
422
- 255 - mask_np,
423
- 3,
424
- cv2.INPAINT_NS,
425
- )
426
-
427
- return texture_np
428
-
429
- def __call__(
430
- self,
431
- colors: list[Image.Image],
432
- mesh: trimesh.Trimesh,
433
- output_path: str,
434
- ) -> trimesh.Trimesh:
435
- import time
436
- start = time.time()
437
- self.load_mesh(mesh)
438
- print("load_mesh", time.time() - start)
439
-
440
- start = time.time()
441
- rendered_depth, masks = self.renderer.render_depth(
442
- self.vertices, self.faces
443
- )
444
- norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
445
- render_uvs, _ = self.renderer.render_uv(
446
- self.vertices, self.faces, self.uv_map
447
- )
448
- view_normals = self.compute_enhanced_viewnormal(
449
- self.renderer.mv_mtx, self.vertices, self.faces
450
- )
451
- print("0", time.time() - start)
452
-
453
- textures, weighted_cos_maps = [], []
454
-
455
- start = time.time()
456
- for color, mask, dep, normal, uv, weight in zip(
457
- colors,
458
- masks,
459
- norm_deps,
460
- view_normals,
461
- render_uvs,
462
- self.view_weights,
463
- ):
464
- mask, dep, normal, uv = interp_tensers([mask, dep, normal, uv], self.render_wh)
465
- texture, cos_map = self.back_project(color, mask, dep, normal, uv)
466
- textures.append(texture)
467
- weighted_cos_maps.append(weight * (cos_map**4))
468
- print("1", time.time() - start)
469
- start = time.time()
470
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
471
- print("2", time.time() - start)
472
- start = time.time()
473
- texture_np = self.uv_inpaint(texture, mask)
474
- print("3", time.time() - start)
475
- start = time.time()
476
- texture_np = post_process_texture(texture_np)
477
- vertices, faces, uv_map = self.get_mesh_np_attrs(
478
- self.scale, self.center
479
- )
480
-
481
- textured_mesh = save_mesh_with_mtl(
482
- vertices, faces, uv_map, texture_np, output_path
483
- )
484
- print("4", time.time() - start)
485
-
486
- return textured_mesh
487
-
488
-
489
- def parse_args():
490
- parser = argparse.ArgumentParser(description="Backproject texture")
491
- parser.add_argument(
492
- "--color_path",
493
- type=str,
494
- help="Multiview color image in 6x512x512 file path",
495
- )
496
- parser.add_argument(
497
- "--mesh_path",
498
- type=str,
499
- help="Mesh path, .obj, .glb or .ply",
500
- )
501
- parser.add_argument(
502
- "--output_path",
503
- type=str,
504
- help="Output mesh path with suffix",
505
- )
506
- parser.add_argument(
507
- "--num_images", type=int, default=6, help="Number of images to render."
508
- )
509
- parser.add_argument(
510
- "--elevation",
511
- nargs=2,
512
- type=float,
513
- default=[20.0, -10.0],
514
- help="Elevation angles for the camera (default: [20.0, -10.0])",
515
- )
516
- parser.add_argument(
517
- "--distance",
518
- type=float,
519
- default=5,
520
- help="Camera distance (default: 5)",
521
- )
522
- parser.add_argument(
523
- "--resolution_hw",
524
- type=int,
525
- nargs=2,
526
- default=(2048, 2048),
527
- help="Resolution of the mesh rendering",
528
- )
529
- parser.add_argument(
530
- "--target_hw",
531
- type=int,
532
- nargs=2,
533
- default=(2048, 2048),
534
- help="Target rendering images resolution",
535
- )
536
- parser.add_argument(
537
- "--fov",
538
- type=float,
539
- default=30,
540
- help="Field of view in degrees (default: 30)",
541
- )
542
- parser.add_argument(
543
- "--device",
544
- type=str,
545
- choices=["cpu", "cuda"],
546
- default="cuda",
547
- help="Device to run on (default: `cuda`)",
548
- )
549
- parser.add_argument(
550
- "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
551
- )
552
- parser.add_argument(
553
- "--texture_wh",
554
- nargs=2,
555
- type=int,
556
- default=[2048, 2048],
557
- help="Texture resolution width and height",
558
- )
559
- parser.add_argument(
560
- "--mesh_sipmlify_ratio",
561
- type=float,
562
- default=0.9,
563
- help="Mesh simplification ratio (default: 0.9)",
564
- )
565
- parser.add_argument(
566
- "--delight", action="store_true", help="Use delighting model."
567
- )
568
- args = parser.parse_args()
569
-
570
- return args
571
-
572
-
573
- def entrypoint(
574
- delight_model: DelightingModel = None,
575
- imagesr_model: ImageRealESRGAN = None,
576
- **kwargs,
577
- ) -> trimesh.Trimesh:
578
- args = parse_args()
579
- for k, v in kwargs.items():
580
- if hasattr(args, k) and v is not None:
581
- setattr(args, k, v)
582
-
583
- # Setup camera parameters.
584
- camera_params = CameraSetting(
585
- num_images=args.num_images,
586
- elevation=args.elevation,
587
- distance=args.distance,
588
- resolution_hw=args.resolution_hw,
589
- fov=math.radians(args.fov),
590
- device=args.device,
591
- )
592
- view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
593
-
594
- color_grid = Image.open(args.color_path)
595
- if args.delight:
596
- if delight_model is None:
597
- delight_model = DelightingModel(
598
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
599
- )
600
- save_dir = os.path.dirname(args.output_path)
601
- os.makedirs(save_dir, exist_ok=True)
602
- color_grid.save(f"{save_dir}/color_grid.png")
603
- color_grid = delight_model(color_grid)
604
- color_grid.save(f"{save_dir}/color_grid_delight.png")
605
-
606
- multiviews = get_images_from_grid(color_grid, img_size=512)
607
-
608
- # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
609
- if imagesr_model is None:
610
- imagesr_model = ImageRealESRGAN(outscale=4)
611
- multiviews = [imagesr_model(img.convert("RGB")) for img in multiviews]
612
- multiviews = [img.resize(args.target_hw[::-1]) for img in multiviews]
613
-
614
- mesh = trimesh.load(args.mesh_path)
615
- if isinstance(mesh, trimesh.Scene):
616
- mesh = mesh.dump(concatenate=True)
617
-
618
- if not args.skip_fix_mesh:
619
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
620
- mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
621
- mesh.vertices, mesh.faces = mesh_fixer(
622
- filter_ratio=args.mesh_simplify_ratio,
623
- max_hole_size=0.04,
624
- resolution=1024,
625
- num_views=1000,
626
- norm_mesh_ratio=0.5,
627
- )
628
- # Restore scale.
629
- mesh.vertices = mesh.vertices / scale
630
- mesh.vertices = mesh.vertices + center
631
-
632
- # Baking texture to mesh.
633
- import time
634
- start = time.time()
635
- texture_backer = TextureBacker(
636
- camera_params=camera_params,
637
- view_weights=view_weights,
638
- render_wh=args.target_hw,
639
- texture_wh=args.texture_wh,
640
- )
641
- print(f"TextureBacker init time: {time.time() - start:.2f}s")
642
- start = time.time()
643
- textured_mesh = texture_backer(multiviews, mesh, args.output_path)
644
- print(f"Texture backproject time: {time.time() - start:.2f}s")
645
-
646
- return textured_mesh
647
-
648
-
649
- if __name__ == "__main__":
650
- entrypoint()
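For orientation, the deleted backproject.py could be driven either from the command line via parse_args or programmatically via entrypoint, which overrides the parsed arguments with any keyword arguments it receives. Below is a minimal, hypothetical sketch of a programmatic call; the file paths are placeholders and the import only works while the module still exists on the Python path.

# Sketch only: placeholder paths, module removed by this commit.
from asset3d_gen.data.backproject import entrypoint

textured_mesh = entrypoint(
    color_path="outputs/texture_mesh_gen/multi_view/color_sample0.png",  # 6-view 512x512 grid
    mesh_path="outputs/texture_mesh_gen/texture_mesh/kettle_color.glb",
    output_path="outputs/texture_mesh_gen/textured_mesh/kettle.obj",
    delight=False,        # enable only if the delighting weights are available
    skip_fix_mesh=False,  # keep MeshFixer simplification and hole filling
)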
asset3d_gen/data/backup/backproject_v2.py DELETED
@@ -1,700 +0,0 @@
1
- import logging
2
- import math
3
- from typing import Union
4
-
5
- import custom_rasterizer as cr
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- import trimesh
11
- import xatlas
12
- from PIL import Image
13
- from asset3d_gen.data.utils import (
14
- get_images_from_file,
15
- normalize_vertices_array,
16
- post_process_texture,
17
- save_mesh_with_mtl,
18
- )
19
-
20
- logging.basicConfig(
21
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
- )
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
-
28
-
31
-
32
-
33
- def get_perspective_projection(
34
- fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
- ) -> np.ndarray:
36
- """Compute the perspective projection matrix for 3D rendering."""
37
- fov_rad = math.radians(fov)
38
- tan_half_fov = math.tan(fov_rad / 2.0)
39
-
40
- return np.array(
41
- [
42
- [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
- [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
- [
45
- 0.0,
46
- 0.0,
47
- -(far + near) / (far - near),
48
- -(2.0 * far * near) / (far - near),
49
- ],
50
- [0.0, 0.0, -1.0, 0.0],
51
- ],
52
- dtype=np.float32,
53
- )
54
-
55
-
56
- def transform_vertices(
57
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
- ) -> torch.Tensor:
59
- """Transform 3D vertices using a projection matrix."""
60
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
- if pos.size(-1) == 3:
62
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
-
64
- result = pos @ t_mtx.T
65
-
66
- return result if keepdim else result.unsqueeze(0)
67
-
68
-
69
- def compute_w2c_matrix(
70
- elev_deg: float, azim_deg: float, cam_dist: float
71
- ) -> np.ndarray:
72
- """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
-
74
- elev_rad = math.radians(-elev_deg)
75
- azim_rad = math.radians(azim_deg)
76
-
77
- sin_elev = math.sin(elev_rad)
78
- cos_elev = math.cos(elev_rad)
79
- sin_azim = math.sin(azim_rad)
80
- cos_azim = math.cos(azim_rad)
81
-
82
- cam_pos = np.array(
83
- [
84
- cam_dist * cos_elev * cos_azim,
85
- cam_dist * cos_elev * sin_azim,
86
- cam_dist * sin_elev,
87
- ]
88
- )
89
-
90
- look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
- right_dir = np.cross(look_dir, [0, 0, 1])
92
- right_dir /= np.linalg.norm(right_dir)
93
- up_dir = np.cross(right_dir, look_dir)
94
-
95
- c2w = np.eye(4)
96
- c2w[:3, 0] = right_dir
97
- c2w[:3, 1] = up_dir
98
- c2w[:3, 2] = -look_dir
99
- c2w[:3, 3] = cam_pos
100
-
101
- try:
102
- w2c = np.linalg.inv(c2w)
103
- except np.linalg.LinAlgError as e:
104
- raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
-
106
- return w2c.astype(np.float32)
107
-
108
-
109
- def _bilinear_interpolation_scattering(
110
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
- ) -> torch.Tensor:
112
- """Bilinear interpolation scattering for grid-based value accumulation."""
113
- device = values.device
114
- dtype = values.dtype
115
- C = values.shape[-1]
116
-
117
- indices = coords * torch.tensor(
118
- [image_h - 1, image_w - 1], dtype=dtype, device=device
119
- )
120
- i, j = indices.unbind(-1)
121
-
122
- i0, j0 = (
123
- indices.floor()
124
- .long()
125
- .clamp(0, image_h - 2)
126
- .clamp(0, image_w - 2)
127
- .unbind(-1)
128
- )
129
- i1, j1 = i0 + 1, j0 + 1
130
-
131
- w_i = i - i0.float()
132
- w_j = j - j0.float()
133
- weights = torch.stack(
134
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
- dim=1,
136
- )
137
-
138
- indices_comb = torch.stack(
139
- [
140
- torch.stack([i0, j0], dim=1),
141
- torch.stack([i0, j1], dim=1),
142
- torch.stack([i1, j0], dim=1),
143
- torch.stack([i1, j1], dim=1),
144
- ],
145
- dim=1,
146
- )
147
-
148
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
-
151
- for k in range(4):
152
- idx = indices_comb[:, k]
153
- w = weights[:, k].unsqueeze(-1)
154
-
155
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
- flat_idx = (idx * stride).sum(-1)
157
-
158
- grid.view(-1, C).scatter_add_(
159
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
- )
161
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
-
163
- mask = cnt.squeeze(-1) > 0
164
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
-
166
- return grid
167
-
168
-
169
- def _texture_inpaint_smooth(
170
- texture: np.ndarray,
171
- mask: np.ndarray,
172
- vertices: np.ndarray,
173
- faces: np.ndarray,
174
- uv_map: np.ndarray,
175
- ) -> tuple[np.ndarray, np.ndarray]:
176
- """Perform texture inpainting using vertex-based color propagation."""
177
- image_h, image_w, C = texture.shape
178
- N = vertices.shape[0]
179
-
180
- # Initialize vertex data structures
181
- vtx_mask = np.zeros(N, dtype=np.float32)
182
- vtx_colors = np.zeros((N, C), dtype=np.float32)
183
- unprocessed = []
184
- adjacency = [[] for _ in range(N)]
185
-
186
- # Build adjacency graph and initial color assignment
187
- for face_idx in range(faces.shape[0]):
188
- for k in range(3):
189
- uv_idx_k = faces[face_idx, k]
190
- v_idx = faces[face_idx, k]
191
-
192
- # Convert UV to pixel coordinates with boundary clamping
193
- u = np.clip(
194
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
- )
196
- v = np.clip(
197
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
- 0,
199
- image_h - 1,
200
- )
201
-
202
- if mask[v, u]:
203
- vtx_mask[v_idx] = 1.0
204
- vtx_colors[v_idx] = texture[v, u]
205
- elif v_idx not in unprocessed:
206
- unprocessed.append(v_idx)
207
-
208
- # Build undirected adjacency graph
209
- neighbor = faces[face_idx, (k + 1) % 3]
210
- if neighbor not in adjacency[v_idx]:
211
- adjacency[v_idx].append(neighbor)
212
- if v_idx not in adjacency[neighbor]:
213
- adjacency[neighbor].append(v_idx)
214
-
215
- # Color propagation with dynamic stopping
216
- remaining_iters, prev_count = 2, 0
217
- while remaining_iters > 0:
218
- current_unprocessed = []
219
-
220
- for v_idx in unprocessed:
221
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
- if not valid_neighbors:
223
- current_unprocessed.append(v_idx)
224
- continue
225
-
226
- # Calculate inverse square distance weights
227
- neighbors_pos = vertices[valid_neighbors]
228
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
- weights = 1 / np.maximum(dist_sq, 1e-8)
230
-
231
- vtx_colors[v_idx] = np.average(
232
- vtx_colors[valid_neighbors], weights=weights, axis=0
233
- )
234
- vtx_mask[v_idx] = 1.0
235
-
236
- # Update iteration control
237
- if len(current_unprocessed) == prev_count:
238
- remaining_iters -= 1
239
- else:
240
- remaining_iters = min(remaining_iters + 1, 2)
241
- prev_count = len(current_unprocessed)
242
- unprocessed = current_unprocessed
243
-
244
- # Generate output texture
245
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
- for face_idx in range(faces.shape[0]):
247
- for k in range(3):
248
- v_idx = faces[face_idx, k]
249
- if not vtx_mask[v_idx]:
250
- continue
251
-
252
- # UV coordinate conversion
253
- uv_idx_k = faces[face_idx, k]
254
- u = np.clip(
255
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
- )
257
- v = np.clip(
258
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
- 0,
260
- image_h - 1,
261
- )
262
-
263
- inpainted_texture[v, u] = vtx_colors[v_idx]
264
- updated_mask[v, u] = 255
265
-
266
- return inpainted_texture, updated_mask
267
-
268
-
269
- class TextureBacker:
270
- """Texture baking pipeline for multi-view projection and fusion."""
271
-
272
- def __init__(
273
- self,
274
- camera_elevs: list[float],
275
- camera_azims: list[float],
276
- camera_distance: int,
277
- camera_fov: float,
278
- view_weights: list[float] = None,
279
- render_wh: tuple[int, int] = (2048, 2048),
280
- texture_wh: tuple[int, int] = (2048, 2048),
281
- use_antialias: bool = True,
282
- bake_angle_thres: int = 75,
283
- device="cuda",
284
- ):
285
- self.camera_elevs = camera_elevs
286
- self.camera_azims = camera_azims
287
- self.view_weights = (
288
- view_weights
289
- if view_weights is not None
290
- else [1] * len(camera_elevs)
291
- )
292
- self.device = device
293
- self.render_wh = render_wh
294
- self.texture_wh = texture_wh
295
-
296
- self.camera_distance = camera_distance
297
- self.use_antialias = use_antialias
298
-
299
- self.bake_angle_thres = bake_angle_thres
300
- self.bake_unreliable_kernel_size = int(
301
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
- )
303
-
304
- self.camera_proj_mat = get_perspective_projection(
305
- camera_fov,
306
- self.render_wh[1] / self.render_wh[0],
307
- )
308
- self.cnt = 0
309
-
310
- def rasterize_mesh(
311
- self,
312
- vertex: torch.Tensor,
313
- face: torch.Tensor,
314
- resolution: tuple[int, int],
315
- ) -> torch.Tensor:
316
- vertex = vertex[None] if vertex.ndim == 2 else vertex
317
- indices, weights = cr.rasterize(vertex, face, resolution)
318
-
319
- return torch.cat(
320
- [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
- ).unsqueeze(0)
322
-
323
- def raster_interpolate(
324
- self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
- ) -> torch.Tensor:
326
- barycentric = rast_out[0, ..., :-1]
327
- findices = rast_out[0, ..., -1]
328
- if uv.dim() == 2:
329
- uv = uv.unsqueeze(0)
330
-
331
- return cr.interpolate(uv, findices, barycentric, faces)[0]
332
-
333
- def load_mesh(self, mesh_path: str) -> None:
334
- mesh = trimesh.load(mesh_path)
335
- if isinstance(mesh, trimesh.Scene):
336
- mesh = mesh.dump(concatenate=True)
337
-
338
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
- self.scale, self.center = scale, center
340
-
341
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
- mesh.vertices = mesh.vertices[vmapping]
343
- mesh.faces = indices
344
- mesh.visual.uv = uvs
345
-
346
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
-
350
- # Transformation of coordinate system
351
- self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
- self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
- self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
-
355
- def get_mesh_attrs(
356
- self,
357
- scale: float = None,
358
- center: np.ndarray = None,
359
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
- vertices = self.vertices.cpu().numpy()
361
- faces = self.faces.cpu().numpy()
362
- uv_map = self.uv_map.cpu().numpy()
363
-
364
- # Inverse transformation of coordinate system
365
- vertices[:, [1, 2]] = vertices[:, [2, 1]]
366
- vertices[:, [0, 1]] = -vertices[:, [0, 1]]
367
- uv_map[:, 1] = 1.0 - uv_map[:, 1]
368
-
369
- if scale is not None:
370
- vertices = vertices / scale
371
- if center is not None:
372
- vertices = vertices + center
373
-
374
- return vertices, faces, uv_map
375
-
376
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
377
- depth_image_np = depth_image.cpu().numpy()
378
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
379
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
380
- combined_edges = depth_edges
381
- sketch_image = (
382
- torch.from_numpy(combined_edges).to(depth_image.device).float()
383
- / 255
384
- )
385
- sketch_image = sketch_image.unsqueeze(-1)
386
-
387
- return sketch_image
388
-
389
- def back_project(
390
- self, image: Image.Image, elev: float, azim: float
391
- ) -> tuple[torch.Tensor, torch.Tensor]:
392
- if isinstance(image, Image.Image):
393
- image = np.array(image)
394
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
395
- if image.ndim == 2:
396
- image = image.unsqueeze(-1)
397
- image = image / 255.0
398
-
399
- view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
403
- pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
404
- pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
405
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
406
-
407
- v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
408
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
409
- vertex_norm = (
410
- torch.from_numpy(
411
- trimesh.geometry.mean_vertex_normals(
412
- len(pos_cam), self.faces.cpu(), face_norm.cpu()
413
- )
414
- )
415
- .to(self.device)
416
- .contiguous()
417
- )
418
-
419
- rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
420
- vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
421
-
422
- interp_data = {
423
- "normal": self.raster_interpolate(
424
- vertex_norm[None], rast_out, self.faces
425
- ),
426
- "uv": self.raster_interpolate(
427
- self.uv_map[None], rast_out, self.faces
428
- ),
429
- "depth": self.raster_interpolate(
430
- pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
431
- ),
432
- }
433
-
434
- valid_depth = interp_data["depth"][vis_mask > 0]
435
- depth_norm = (interp_data["depth"] - valid_depth.min()) / (
436
- valid_depth.max() - valid_depth.min()
437
- )
438
- # depth_norm[vis_mask <= 0] = 0
439
- sketch_image = self._render_depth_edges(depth_norm * vis_mask)
440
-
441
- # ddd = depth_norm * vis_mask
442
- # cv2.imwrite(f"v2_depth_d{self.cnt}.png", (ddd.cpu().numpy() * 255).astype(np.uint8))
443
-
444
- cv2.imwrite(
445
- f"v2_vis_mask{self.cnt}.png",
446
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
447
- )
448
- cv2.imwrite(
449
- f"v2_normal{self.cnt}.png",
450
- (interp_data["normal"].cpu().numpy() * 255).astype(np.uint8),
451
- )
452
- cv2.imwrite(
453
- f"v2_depth{self.cnt}.png",
454
- (depth_norm.cpu().numpy() * 255).astype(np.uint8),
455
- )
456
- cv2.imwrite(
457
- f"v2_uv{self.cnt}.png",
458
- (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
459
- )
460
- cv2.imwrite(
461
- f"v2_sketch{self.cnt}.png",
462
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
463
- )
464
-
465
- self.cnt += 1
466
-
467
- cos = F.cosine_similarity(
468
- torch.tensor([[0, 0, -1]], device=self.device),
469
- interp_data["normal"].view(-1, 3),
470
- ).view_as(interp_data["normal"][..., :1])
471
- cos[cos < np.cos(np.radians(self.bake_angle_thres))] = 0
472
-
473
- cv2.imwrite(
474
- f"v2_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8)
475
- )
476
-
477
- k = self.bake_unreliable_kernel_size * 2 + 1
478
- kernel = torch.ones((1, 1, k, k), device=self.device)
479
-
480
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
481
- vis_mask = F.conv2d(
482
- 1.0 - vis_mask,
483
- kernel,
484
- padding=k // 2,
485
- )
486
- vis_mask = 1.0 - (vis_mask > 0).float()
487
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
488
-
489
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
490
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
491
- sketch_image = (sketch_image > 0).float()
492
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
493
- vis_mask = vis_mask * (sketch_image < 0.5)
494
-
495
- cos[vis_mask == 0] = 0
496
-
497
- vis_mask = cv2.imread(
498
- f"v3_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE
499
- )
500
- vis_mask = (
501
- torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
502
- )
503
- # cos2 = cv2.imread(f"v3_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
504
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
505
- # cos = cos2
506
-
507
- valid_pixels = (vis_mask != 0).view(-1)
508
- # import pdb; pdb.set_trace()
509
-
510
- cv2.imwrite(
511
- f"v2_db_sketch{self.cnt}.png",
512
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
513
- )
514
- cv2.imwrite(
515
- f"v2_db_uv{self.cnt}.png",
516
- (interp_data["uv"][..., 0].cpu().numpy() * 255).astype(np.uint8),
517
- )
518
- cv2.imwrite(
519
- f"v2_db_uv2{self.cnt}.png",
520
- (interp_data["uv"][..., 1].cpu().numpy() * 255).astype(np.uint8),
521
- )
522
- cv2.imwrite(
523
- f"v2_db_color{self.cnt}.png",
524
- (image.cpu().numpy() * 255).astype(np.uint8),
525
- )
526
- cv2.imwrite(
527
- f"v2_db_cos{self.cnt}.png",
528
- (cos.cpu().numpy() * 255).astype(np.uint8),
529
- )
530
- cv2.imwrite(
531
- f"v2_db_mask{self.cnt}.png",
532
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
533
- )
534
- # import pdb; pdb.set_trace()
535
- return (
536
- self._scatter_texture(interp_data["uv"], image, valid_pixels),
537
- self._scatter_texture(interp_data["uv"], cos, valid_pixels),
538
- )
539
-
540
- def _scatter_texture(self, uv, data, mask):
541
- def __filter_data(data, mask):
542
- return data.view(-1, data.shape[-1])[mask]
543
-
544
- return _bilinear_interpolation_scattering(
545
- self.texture_wh[1],
546
- self.texture_wh[0],
547
- __filter_data(uv, mask)[..., [1, 0]],
548
- __filter_data(data, mask),
549
- )
550
-
551
- @torch.no_grad()
552
- def fast_bake_texture(
553
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
554
- ) -> tuple[torch.Tensor, torch.Tensor]:
555
- channel = textures[0].shape[-1]
556
- texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
557
- self.device
558
- )
559
- trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
560
- for texture, cos_map in zip(textures, confidence_maps):
561
- view_sum = (cos_map > 0).sum()
562
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
563
- if painted_sum / view_sum > 0.99:
564
- continue
565
- texture_merge += texture * cos_map
566
- trust_map_merge += cos_map
567
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
568
-
569
- return texture_merge, trust_map_merge > 1e-8
570
-
571
- def uv_inpaint(
572
- self, texture: torch.Tensor, mask: torch.Tensor
573
- ) -> np.ndarray:
574
- texture_np = texture.cpu().numpy()
575
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
576
- vertices, faces, uv_map = self.get_mesh_attrs()
577
- # import pdb; pdb.set_trace()
578
- texture_np, mask_np = _texture_inpaint_smooth(
579
- texture_np, mask_np, vertices, faces, uv_map
580
- )
581
- texture_np = texture_np.clip(0, 1)
582
- texture_np = cv2.inpaint(
583
- (texture_np * 255).astype(np.uint8),
584
- 255 - mask_np,
585
- 3,
586
- cv2.INPAINT_NS,
587
- )
588
-
589
- return texture_np
590
-
591
- def __call__(
592
- self, colors: list[Image.Image], input_mesh: str, output_path: str
593
- ) -> trimesh.Trimesh:
594
- self.load_mesh(input_mesh)
595
-
596
- textures, weighted_cos_maps = [], []
597
- for color, cam_elev, cam_azim, weight in zip(
598
- colors, self.camera_elevs, self.camera_azims, self.view_weights
599
- ):
600
- texture, cos_map = self.back_project(color, cam_elev, cam_azim)
601
- cv2.imwrite(
602
- f"v2_texture{self.cnt}.png",
603
- (texture.cpu().numpy() * 255).astype(np.uint8),
604
- )
605
- cv2.imwrite(
606
- f"v2_texture_cos{self.cnt}.png",
607
- (cos_map.cpu().numpy() * 255).astype(np.uint8),
608
- )
609
- # import pdb; pdb.set_trace()
610
- textures.append(texture)
611
- weighted_cos_maps.append(weight * (cos_map**4))
612
-
613
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
614
- texture_np = self.uv_inpaint(texture, mask)
615
- texture_np = post_process_texture(texture_np)
616
- vertices, faces, uvs = self.get_mesh_attrs(self.scale, self.center)
617
- # import pdb; pdb.set_trace()
618
- cv2.imwrite("v2_texture_np.png", texture_np)
619
-
620
- textured_mesh = save_mesh_with_mtl(
621
- vertices, faces, uvs, texture_np, output_path
622
- )
623
-
624
- return textured_mesh
625
-
626
-
627
- class Image_Super_Net:
628
- def __init__(self, device="cuda"):
629
- from diffusers import StableDiffusionUpscalePipeline
630
-
631
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
632
- "stabilityai/stable-diffusion-x4-upscaler",
633
- torch_dtype=torch.float16,
634
- ).to(device)
635
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
636
-
637
- def __call__(self, image, prompt=""):
638
- with torch.no_grad():
639
- upscaled_image = self.up_pipeline_x4(
640
- prompt=[prompt],
641
- image=image,
642
- num_inference_steps=10,
643
- ).images[0]
644
-
645
- return upscaled_image
646
-
647
-
648
- class Image_GANNet:
649
- def __init__(self, outscale: int):
650
- from basicsr.archs.rrdbnet_arch import RRDBNet
651
- from realesrgan import RealESRGANer
652
-
653
- self.outscale = outscale
654
- model = RRDBNet(
655
- num_in_ch=3,
656
- num_out_ch=3,
657
- num_feat=64,
658
- num_block=23,
659
- num_grow_ch=32,
660
- scale=4,
661
- )
662
- self.upsampler = RealESRGANer(
663
- scale=4,
664
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
665
- model=model,
666
- pre_pad=0,
667
- half=True,
668
- )
669
-
670
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
671
- if isinstance(image, Image.Image):
672
- image = np.array(image)
673
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
674
-
675
- return Image.fromarray(output)
676
-
677
-
678
- if __name__ == "__main__":
679
- device = "cuda"
680
- color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
681
- mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
682
- output_path = "robot_test_v2/robot.obj"
683
- target_image_size = (2048, 2048)
684
-
685
- super_model = Image_GANNet(outscale=4)
686
- multiviews = get_images_from_file(color_path, img_size=512)
687
-
688
- texture_backer = TextureBacker(
689
- camera_elevs=[20, 20, 20, -10, -10, -10],
690
- camera_azims=[-180, -60, 60, -120, 0, 120],
691
- view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
692
- camera_distance=5,
693
- camera_fov=30,
694
- render_wh=(2048, 2048),
695
- texture_wh=(2048, 2048),
696
- )
697
-
698
- multiviews = [super_model(img) for img in multiviews]
699
- multiviews = [img.convert("RGB") for img in multiviews]
700
- textured_mesh = texture_backer(multiviews, mesh_path, output_path)
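The core of fast_bake_texture above is a per-texel, confidence-weighted blend of the per-view textures, where __call__ supplies weight * cos_map**4 as the confidence. The snippet below is a self-contained sketch of just that fusion step with toy tensor shapes; it omits the early-skip heuristic for views that add almost no new coverage.

import torch

def blend_views(textures, confidence_maps, eps=1e-8):
    # Accumulate texture * confidence and the confidence itself, then normalize per texel.
    acc = torch.zeros_like(textures[0])
    conf = torch.zeros_like(confidence_maps[0])
    for tex, cos in zip(textures, confidence_maps):
        acc = acc + tex * cos
        conf = conf + cos
    return acc / conf.clamp(min=eps), conf > eps

# Toy usage: two 8x8 RGB textures with cos**4-style confidences, as __call__ builds them.
texs = [torch.rand(8, 8, 3) for _ in range(2)]
confs = [torch.rand(8, 8, 1) ** 4 for _ in range(2)]
baked, valid_mask = blend_views(texs, confs)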
asset3d_gen/data/backup/backproject_v3.py DELETED
@@ -1,866 +0,0 @@
1
- import logging
2
- import math
3
- from typing import Union
4
-
5
- import custom_rasterizer as cr
6
- import cv2
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as F
10
- import trimesh
11
- import xatlas
12
- from PIL import Image
13
- from asset3d_gen.data.utils import (
14
- get_images_from_file,
15
- normalize_vertices_array,
16
- post_process_texture,
17
- save_mesh_with_mtl,
18
- )
19
-
20
- logging.basicConfig(
21
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
22
- )
23
- logger = logging.getLogger(__name__)
24
-
25
-
26
- __all__ = ["TextureBacker", "Image_Super_Net", "Image_GANNet"]
27
-
28
-
31
-
32
-
33
- def get_perspective_projection(
34
- fov: float, aspect_wh: float, near: float = 0.01, far: float = 100
35
- ) -> np.ndarray:
36
- """Compute the perspective projection matrix for 3D rendering."""
37
- fov_rad = math.radians(fov)
38
- tan_half_fov = math.tan(fov_rad / 2.0)
39
-
40
- return np.array(
41
- [
42
- [1.0 / (tan_half_fov * aspect_wh), 0.0, 0.0, 0.0],
43
- [0.0, 1.0 / tan_half_fov, 0.0, 0.0],
44
- [
45
- 0.0,
46
- 0.0,
47
- -(far + near) / (far - near),
48
- -(2.0 * far * near) / (far - near),
49
- ],
50
- [0.0, 0.0, -1.0, 0.0],
51
- ],
52
- dtype=np.float32,
53
- )
54
-
55
-
56
- def transform_vertices(
57
- mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
58
- ) -> torch.Tensor:
59
- """Transform 3D vertices using a projection matrix."""
60
- t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
61
- if pos.size(-1) == 3:
62
- pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
63
-
64
- result = pos @ t_mtx.T
65
-
66
- return result if keepdim else result.unsqueeze(0)
67
-
68
-
69
- def compute_w2c_matrix(
70
- elev_deg: float, azim_deg: float, cam_dist: float
71
- ) -> np.ndarray:
72
- """Compute w2c 4x4 transformation matrix from spherical coordinates."""
73
-
74
- elev_rad = math.radians(-elev_deg)
75
- azim_rad = math.radians(azim_deg)
76
-
77
- sin_elev = math.sin(elev_rad)
78
- cos_elev = math.cos(elev_rad)
79
- sin_azim = math.sin(azim_rad)
80
- cos_azim = math.cos(azim_rad)
81
-
82
- cam_pos = np.array(
83
- [
84
- cam_dist * cos_elev * cos_azim,
85
- cam_dist * cos_elev * sin_azim,
86
- cam_dist * sin_elev,
87
- ]
88
- )
89
-
90
- look_dir = -cam_pos / np.linalg.norm(cam_pos)
91
- right_dir = np.cross(look_dir, [0, 0, 1])
92
- right_dir /= np.linalg.norm(right_dir)
93
- up_dir = np.cross(right_dir, look_dir)
94
-
95
- c2w = np.eye(4)
96
- c2w[:3, 0] = right_dir
97
- c2w[:3, 1] = up_dir
98
- c2w[:3, 2] = -look_dir
99
- c2w[:3, 3] = cam_pos
100
-
101
- try:
102
- w2c = np.linalg.inv(c2w)
103
- except np.linalg.LinAlgError as e:
104
- raise ArithmeticError("Failed to invert camera-to-world matrix") from e
105
-
106
- return w2c.astype(np.float32)
107
-
108
-
109
- def _bilinear_interpolation_scattering(
110
- image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
111
- ) -> torch.Tensor:
112
- """Bilinear interpolation scattering for grid-based value accumulation."""
113
- device = values.device
114
- dtype = values.dtype
115
- C = values.shape[-1]
116
-
117
- indices = coords * torch.tensor(
118
- [image_h - 1, image_w - 1], dtype=dtype, device=device
119
- )
120
- i, j = indices.unbind(-1)
121
-
122
- i0, j0 = (
123
- indices.floor()
124
- .long()
125
- .clamp(0, image_h - 2)
126
- .clamp(0, image_w - 2)
127
- .unbind(-1)
128
- )
129
- i1, j1 = i0 + 1, j0 + 1
130
-
131
- w_i = i - i0.float()
132
- w_j = j - j0.float()
133
- weights = torch.stack(
134
- [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
135
- dim=1,
136
- )
137
-
138
- indices_comb = torch.stack(
139
- [
140
- torch.stack([i0, j0], dim=1),
141
- torch.stack([i0, j1], dim=1),
142
- torch.stack([i1, j0], dim=1),
143
- torch.stack([i1, j1], dim=1),
144
- ],
145
- dim=1,
146
- )
147
-
148
- grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
149
- cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
150
-
151
- for k in range(4):
152
- idx = indices_comb[:, k]
153
- w = weights[:, k].unsqueeze(-1)
154
-
155
- stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
156
- flat_idx = (idx * stride).sum(-1)
157
-
158
- grid.view(-1, C).scatter_add_(
159
- 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
160
- )
161
- cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
162
-
163
- mask = cnt.squeeze(-1) > 0
164
- grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
165
-
166
- return grid
167
-
168
-
169
- def _texture_inpaint_smooth(
170
- texture: np.ndarray,
171
- mask: np.ndarray,
172
- vertices: np.ndarray,
173
- faces: np.ndarray,
174
- uv_map: np.ndarray,
175
- ) -> tuple[np.ndarray, np.ndarray]:
176
- """Perform texture inpainting using vertex-based color propagation."""
177
- image_h, image_w, C = texture.shape
178
- N = vertices.shape[0]
179
-
180
- # Initialize vertex data structures
181
- vtx_mask = np.zeros(N, dtype=np.float32)
182
- vtx_colors = np.zeros((N, C), dtype=np.float32)
183
- unprocessed = []
184
- adjacency = [[] for _ in range(N)]
185
-
186
- # Build adjacency graph and initial color assignment
187
- for face_idx in range(faces.shape[0]):
188
- for k in range(3):
189
- uv_idx_k = faces[face_idx, k]
190
- v_idx = faces[face_idx, k]
191
-
192
- # Convert UV to pixel coordinates with boundary clamping
193
- u = np.clip(
194
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
195
- )
196
- v = np.clip(
197
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
198
- 0,
199
- image_h - 1,
200
- )
201
-
202
- if mask[v, u]:
203
- vtx_mask[v_idx] = 1.0
204
- vtx_colors[v_idx] = texture[v, u]
205
- elif v_idx not in unprocessed:
206
- unprocessed.append(v_idx)
207
-
208
- # Build undirected adjacency graph
209
- neighbor = faces[face_idx, (k + 1) % 3]
210
- if neighbor not in adjacency[v_idx]:
211
- adjacency[v_idx].append(neighbor)
212
- if v_idx not in adjacency[neighbor]:
213
- adjacency[neighbor].append(v_idx)
214
-
215
- # Color propagation with dynamic stopping
216
- remaining_iters, prev_count = 2, 0
217
- while remaining_iters > 0:
218
- current_unprocessed = []
219
-
220
- for v_idx in unprocessed:
221
- valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
222
- if not valid_neighbors:
223
- current_unprocessed.append(v_idx)
224
- continue
225
-
226
- # Calculate inverse square distance weights
227
- neighbors_pos = vertices[valid_neighbors]
228
- dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
229
- weights = 1 / np.maximum(dist_sq, 1e-8)
230
-
231
- vtx_colors[v_idx] = np.average(
232
- vtx_colors[valid_neighbors], weights=weights, axis=0
233
- )
234
- vtx_mask[v_idx] = 1.0
235
-
236
- # Update iteration control
237
- if len(current_unprocessed) == prev_count:
238
- remaining_iters -= 1
239
- else:
240
- remaining_iters = min(remaining_iters + 1, 2)
241
- prev_count = len(current_unprocessed)
242
- unprocessed = current_unprocessed
243
-
244
- # Generate output texture
245
- inpainted_texture, updated_mask = texture.copy(), mask.copy()
246
- for face_idx in range(faces.shape[0]):
247
- for k in range(3):
248
- v_idx = faces[face_idx, k]
249
- if not vtx_mask[v_idx]:
250
- continue
251
-
252
- # UV coordinate conversion
253
- uv_idx_k = faces[face_idx, k]
254
- u = np.clip(
255
- int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
256
- )
257
- v = np.clip(
258
- int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
259
- 0,
260
- image_h - 1,
261
- )
262
-
263
- inpainted_texture[v, u] = vtx_colors[v_idx]
264
- updated_mask[v, u] = 255
265
-
266
- return inpainted_texture, updated_mask
267
-
268
-
269
- class TextureBacker:
270
- """Texture baking pipeline for multi-view projection and fusion."""
271
-
272
- def __init__(
273
- self,
274
- camera_elevs: list[float],
275
- camera_azims: list[float],
276
- camera_distance: int,
277
- camera_fov: float,
278
- view_weights: list[float] = None,
279
- render_wh: tuple[int, int] = (2048, 2048),
280
- texture_wh: tuple[int, int] = (2048, 2048),
281
- use_antialias: bool = True,
282
- bake_angle_thresh: int = 75,
283
- device="cuda",
284
- ):
285
- self.camera_elevs = camera_elevs
286
- self.camera_azims = camera_azims
287
- self.view_weights = (
288
- view_weights
289
- if view_weights is not None
290
- else [1] * len(camera_elevs)
291
- )
292
- self.device = device
293
- self.render_wh = render_wh
294
- self.texture_wh = texture_wh
295
-
296
- self.camera_distance = camera_distance
297
- self.use_antialias = use_antialias
298
-
299
- self.bake_angle_thresh = bake_angle_thresh
300
- self.bake_unreliable_kernel_size = int(
301
- (2 / 512) * max(self.render_wh[0], self.render_wh[1])
302
- )
303
-
304
- self.camera_proj_mat = get_perspective_projection(
305
- camera_fov,
306
- self.render_wh[1] / self.render_wh[0],
307
- )
308
- self.cnt = 0
309
-
310
- def rasterize_mesh(
311
- self,
312
- vertex: torch.Tensor,
313
- face: torch.Tensor,
314
- resolution: tuple[int, int],
315
- ) -> torch.Tensor:
316
- vertex = vertex[None] if vertex.ndim == 2 else vertex
317
- indices, weights = cr.rasterize(vertex, face, resolution)
318
-
319
- return torch.cat(
320
- [weights, indices.unsqueeze(-1).to(weights.dtype)], dim=-1
321
- ).unsqueeze(0)
322
-
323
- def raster_interpolate(
324
- self, uv: torch.Tensor, rast_out: torch.Tensor, faces: torch.Tensor
325
- ) -> torch.Tensor:
326
- barycentric = rast_out[0, ..., :-1]
327
- findices = rast_out[0, ..., -1]
328
- if uv.dim() == 2:
329
- uv = uv.unsqueeze(0)
330
-
331
- return cr.interpolate(uv, findices, barycentric, faces)[0]
332
-
333
- def load_mesh(self, mesh_path: str) -> None:
334
- mesh = trimesh.load(mesh_path)
335
- if isinstance(mesh, trimesh.Scene):
336
- mesh = mesh.dump(concatenate=True)
337
-
338
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
339
- self.scale, self.center = scale, center
340
-
341
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
342
- mesh.vertices = mesh.vertices[vmapping]
343
- mesh.faces = indices
344
- mesh.visual.uv = uvs
345
-
346
- self.vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
347
- self.faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
348
- self.uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
349
-
350
- # Transformation of coordinate system
351
- self.vertices[:, [0, 1]] = -self.vertices[:, [0, 1]]
352
- self.vertices[:, [1, 2]] = self.vertices[:, [2, 1]]
353
- self.uv_map[:, 1] = 1 - self.uv_map[:, 1]
354
-
355
- def get_mesh_attrs(
356
- self,
357
- scale: float = None,
358
- center: np.ndarray = None,
359
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
360
- vertices = self.vertices.cpu().numpy()
361
- faces = self.faces.cpu().numpy()
362
- uv_map = self.uv_map.cpu().numpy()
363
-
364
- if scale is not None:
365
- vertices = vertices / scale
366
- if center is not None:
367
- vertices = vertices + center
368
-
369
- return vertices, faces, uv_map
370
-
371
- def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
372
- depth_image_np = depth_image.cpu().numpy()
373
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
374
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
375
- sketch_image = (
376
- torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
377
- )
378
- sketch_image = sketch_image.unsqueeze(-1)
379
-
380
- return sketch_image
381
-
382
- def back_project(
383
- self, image: Image.Image, elev: float, azim: float
384
- ) -> tuple[torch.Tensor, torch.Tensor]:
385
- if isinstance(image, Image.Image):
386
- image = np.array(image)
387
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
388
- if image.ndim == 2:
389
- image = image.unsqueeze(-1)
390
- image = image / 255.0
391
-
392
- view_mat = compute_w2c_matrix(elev, azim, self.camera_distance)
393
- pos_cam = transform_vertices(view_mat, self.vertices, keepdim=True)
394
- pos_clip = transform_vertices(self.camera_proj_mat, pos_cam)
395
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
396
-
397
- v0, v1, v2 = (pos_cam[self.faces[:, i]] for i in range(3))
398
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
399
- vertex_norm = (
400
- torch.from_numpy(
401
- trimesh.geometry.mean_vertex_normals(
402
- len(pos_cam), self.faces.cpu(), face_norm.cpu()
403
- )
404
- )
405
- .to(self.device)
406
- .contiguous()
407
- )
408
-
409
- rast_out = self.rasterize_mesh(pos_clip, self.faces, image.shape[:2])
410
- vis_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0]
411
-
412
- interp_data = {
413
- "normal": self.raster_interpolate(
414
- vertex_norm[None], rast_out, self.faces
415
- ),
416
- "uv": self.raster_interpolate(
417
- self.uv_map[None], rast_out, self.faces
418
- ),
419
- "depth": self.raster_interpolate(
420
- pos_cam[:, 2].reshape(1, -1, 1), rast_out, self.faces
421
- ),
422
- }
423
-
424
- valid_depth = interp_data["depth"][vis_mask > 0]
425
- depth_norm = (interp_data["depth"] - valid_depth.min()) / (
426
- valid_depth.max() - valid_depth.min()
427
- )
428
- depth_norm[vis_mask <= 0] = 0
429
- sketch_image = self._render_depth_edges(depth_norm * vis_mask)
430
-
431
- # cv2.imwrite("vis_mask.png", (vis_mask.cpu().numpy() * 255).astype(np.uint8))
432
- # cv2.imwrite("normal.png", (interp_data['normal'].cpu().numpy() * 255).astype(np.uint8))
433
- # cv2.imwrite("depth.png", (depth_norm.cpu().numpy() * 255).astype(np.uint8))
434
- # cv2.imwrite("uv.png", (interp_data['uv'][..., 0].cpu().numpy() * 255).astype(np.uint8))
435
- # import pdb; pdb.set_trace()
436
-
437
- cos = F.cosine_similarity(
438
- torch.tensor([[0, 0, -1]], device=self.device),
439
- interp_data["normal"].view(-1, 3),
440
- ).view_as(interp_data["normal"][..., :1])
441
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
442
-
443
- k = self.bake_unreliable_kernel_size * 2 + 1
444
- kernel = torch.ones((1, 1, k, k), device=self.device)
445
-
446
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
447
- vis_mask = F.conv2d(
448
- 1.0 - vis_mask,
449
- kernel,
450
- padding=k // 2,
451
- )
452
- vis_mask = 1.0 - (vis_mask > 0).float()
453
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
454
-
455
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
456
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
457
- sketch_image = (sketch_image > 0).float()
458
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
459
- vis_mask = vis_mask * (sketch_image < 0.5)
460
-
461
- cos[vis_mask == 0] = 0
462
- valid_pixels = (vis_mask != 0).view(-1)
463
-
464
- return (
465
- self._scatter_texture(interp_data["uv"], image, valid_pixels),
466
- self._scatter_texture(interp_data["uv"], cos, valid_pixels),
467
- )
468
-
469
- def back_project2(
470
- self, image, vis_mask, depth, normal, uv
471
- ) -> tuple[torch.Tensor, torch.Tensor]:
472
- if isinstance(image, Image.Image):
473
- image = np.array(image)
474
- image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
475
- if image.ndim == 2:
476
- image = image.unsqueeze(-1)
477
- image = image / 255.0
478
-
479
- depth_inv = (1.0 - depth) * vis_mask
480
- sketch_image = self._render_depth_edges(depth_inv)
481
-
482
- cv2.imwrite(
483
- f"v3_depth_inv{self.cnt}.png",
484
- (depth_inv.cpu().numpy() * 255).astype(np.uint8),
485
- )
486
-
487
- cos = F.cosine_similarity(
488
- torch.tensor([[0, 0, 1]], device=self.device),
489
- normal.view(-1, 3),
490
- ).view_as(normal[..., :1])
491
- cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
492
- # import pdb; pdb.set_trace()
493
- # cv2.imwrite(f"v3_cos{self.cnt}.png", (cos.cpu().numpy() * 255).astype(np.uint8))
494
- # cv2.imwrite(f"v3_sketch{self.cnt}.png", (sketch_image.cpu().numpy() * 255).astype(np.uint8))
495
-
496
- # cos2 = cv2.imread(f"v2_cos{self.cnt+1}.png", cv2.IMREAD_GRAYSCALE)
497
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
498
- # cos = cos2
499
-
500
- self.cnt += 1
501
-
502
- k = self.bake_unreliable_kernel_size * 2 + 1
503
- kernel = torch.ones((1, 1, k, k), device=self.device)
504
-
505
- vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
506
- vis_mask = F.conv2d(
507
- 1.0 - vis_mask,
508
- kernel,
509
- padding=k // 2,
510
- )
511
- vis_mask = 1.0 - (vis_mask > 0).float()
512
- vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
513
-
514
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
515
- sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
516
- sketch_image = (sketch_image > 0).float()
517
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
518
- vis_mask = vis_mask * (sketch_image < 0.5)
519
- # import pdb; pdb.set_trace()
520
- cv2.imwrite(
521
- f"v3_db_sketch{self.cnt}.png",
522
- (sketch_image.cpu().numpy() * 255).astype(np.uint8),
523
- )
524
-
525
- cos[vis_mask == 0] = 0
526
- # import pdb; pdb.set_trace()
527
- # vis_mask = cv2.imread(f"v2_db_mask{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
528
- # vis_mask = torch.from_numpy(vis_mask[..., None]).to(self.device).float() / 255
529
- # cos2 = cv2.imread(f"v2_db_cos{self.cnt}.png", cv2.IMREAD_GRAYSCALE)
530
- # cos2 = torch.from_numpy(cos2[..., None]).to(self.device).float() / 255
531
- # cos = cos2
532
-
533
- valid_pixels = (vis_mask != 0).view(-1)
534
- # import pdb; pdb.set_trace()
535
- cv2.imwrite(
536
- f"v3_db_uv{self.cnt}.png",
537
- (uv[..., 0].cpu().numpy() * 255).astype(np.uint8),
538
- )
539
- cv2.imwrite(
540
- f"v3_db_uv2{self.cnt}.png",
541
- (uv[..., 1].cpu().numpy() * 255).astype(np.uint8),
542
- )
543
- cv2.imwrite(
544
- f"v3_db_color{self.cnt}.png",
545
- (image.cpu().numpy() * 255).astype(np.uint8),
546
- )
547
- cv2.imwrite(
548
- f"v3_db_cos{self.cnt}.png",
549
- (cos.cpu().numpy() * 255).astype(np.uint8),
550
- )
551
- cv2.imwrite(
552
- f"v3_db_mask{self.cnt}.png",
553
- (vis_mask.cpu().numpy() * 255).astype(np.uint8),
554
- )
555
-
556
- return (
557
- self._scatter_texture(uv, image, valid_pixels),
558
- self._scatter_texture(uv, cos, valid_pixels),
559
- )
560
-
561
- def _scatter_texture(self, uv, data, mask):
562
- def __filter_data(data, mask):
563
- return data.view(-1, data.shape[-1])[mask]
564
-
565
- return _bilinear_interpolation_scattering(
566
- self.texture_wh[1],
567
- self.texture_wh[0],
568
- __filter_data(uv, mask)[..., [1, 0]],
569
- __filter_data(data, mask),
570
- )
571
-
572
- @torch.no_grad()
573
- def fast_bake_texture(
574
- self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
575
- ) -> tuple[torch.Tensor, torch.Tensor]:
576
- channel = textures[0].shape[-1]
577
- texture_merge = torch.zeros(self.texture_wh + (channel,)).to(
578
- self.device
579
- )
580
- trust_map_merge = torch.zeros(self.texture_wh + (1,)).to(self.device)
581
- for texture, cos_map in zip(textures, confidence_maps):
582
- view_sum = (cos_map > 0).sum()
583
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
584
- if painted_sum / view_sum > 0.99:
585
- continue
586
- texture_merge += texture * cos_map
587
- trust_map_merge += cos_map
588
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
589
-
590
- return texture_merge, trust_map_merge > 1e-8
591
-
592
- def uv_inpaint(
593
- self, texture: torch.Tensor, mask: torch.Tensor
594
- ) -> np.ndarray:
595
- texture_np = texture.cpu().numpy()
596
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
597
- vertices, faces, uv_map = self.get_mesh_attrs()
598
- # import pdb; pdb.set_trace()
599
- texture_np, mask_np = _texture_inpaint_smooth(
600
- texture_np, mask_np, vertices, faces, uv_map
601
- )
602
- texture_np = texture_np.clip(0, 1)
603
- texture_np = cv2.inpaint(
604
- (texture_np * 255).astype(np.uint8),
605
- 255 - mask_np,
606
- 3,
607
- cv2.INPAINT_NS,
608
- )
609
-
610
- return texture_np
611
-
612
- def __call__(
613
- self, colors: list[Image.Image], input_mesh: str, output_path: str
614
- ) -> trimesh.Trimesh:
615
- self.load_mesh(input_mesh)
616
-
617
- textures, weighted_cos_maps = [], []
618
- for color, cam_elev, cam_azim, weight in zip(
619
- colors, self.camera_elevs, self.camera_azims, self.view_weights
620
- ):
621
- texture, cos_map = self.back_project(color, cam_elev, cam_azim)
622
- textures.append(texture)
623
- weighted_cos_maps.append(weight * (cos_map**4))
624
-
625
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
626
- texture_np = self.uv_inpaint(texture, mask)
627
- texture_np = post_process_texture(texture_np)
628
- vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
629
- # import pdb; pdb.set_trace()
630
- textured_mesh = save_mesh_with_mtl(
631
- vertices, faces, uv_map, texture_np, output_path
632
- )
633
-
634
- return textured_mesh
635
-
636
- def forward(
637
- self,
638
- colors: list[Image.Image],
639
- masks,
640
- depths,
641
- normals,
642
- uvs,
643
- ) -> trimesh.Trimesh:
644
- textures, weighted_cos_maps = [], []
645
- for color, mask, depth, normal, uv, weight in zip(
646
- colors, masks, depths, normals, uvs, self.view_weights
647
- ):
648
- texture, cos_map = self.back_project2(
649
- color, mask, depth, normal, uv
650
- )
651
- cv2.imwrite(
652
- f"v3_texture{self.cnt}.png",
653
- (texture.cpu().numpy() * 255).astype(np.uint8),
654
- )
655
- cv2.imwrite(
656
- f"v3_texture_cos{self.cnt}.png",
657
- (cos_map.cpu().numpy() * 255).astype(np.uint8),
658
- )
659
-
660
- textures.append(texture)
661
- weighted_cos_maps.append(weight * (cos_map**4))
662
-
663
- texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
664
- texture_np = self.uv_inpaint(texture, mask)
665
- texture_np = post_process_texture(texture_np)
666
- vertices, faces, uv_map = self.get_mesh_attrs(self.scale, self.center)
667
- # import pdb; pdb.set_trace()
668
- cv2.imwrite("v3_texture_np.png", texture_np)
669
- textured_mesh = save_mesh_with_mtl(
670
- vertices, faces, uv_map, texture_np, output_path
671
- )
672
-
673
- return textured_mesh
674
-
675
-
676
- class Image_Super_Net:
677
- def __init__(self, device="cuda"):
678
- from diffusers import StableDiffusionUpscalePipeline
679
-
680
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
681
- "stabilityai/stable-diffusion-x4-upscaler",
682
- torch_dtype=torch.float16,
683
- ).to(device)
684
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
685
-
686
- def __call__(self, image, prompt=""):
687
- with torch.no_grad():
688
- upscaled_image = self.up_pipeline_x4(
689
- prompt=[prompt],
690
- image=image,
691
- num_inference_steps=10,
692
- ).images[0]
693
-
694
- return upscaled_image
695
-
696
-
697
- class Image_GANNet:
698
- def __init__(self, outscale: int):
699
- from basicsr.archs.rrdbnet_arch import RRDBNet
700
- from realesrgan import RealESRGANer
701
-
702
- self.outscale = outscale
703
- model = RRDBNet(
704
- num_in_ch=3,
705
- num_out_ch=3,
706
- num_feat=64,
707
- num_block=23,
708
- num_grow_ch=32,
709
- scale=4,
710
- )
711
- self.upsampler = RealESRGANer(
712
- scale=4,
713
- model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth", # noqa
714
- model=model,
715
- pre_pad=0,
716
- half=True,
717
- )
718
-
719
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
720
- if isinstance(image, Image.Image):
721
- image = np.array(image)
722
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
723
-
724
- return Image.fromarray(output)
725
-
726
-
727
- if __name__ == "__main__":
728
- device = "cuda"
729
- color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
730
- mesh_path = "outputs/texture_mesh_gen/texture_mesh/kettle_color.glb"
731
- output_path = "robot_test_v6/robot.obj"
732
- target_image_size = (2048, 2048)
733
-
734
- super_model = Image_GANNet(outscale=4)
735
- multiviews = get_images_from_file(color_path, img_size=512)
736
- multiviews = [super_model(img) for img in multiviews]
737
- multiviews = [img.convert("RGB") for img in multiviews]
738
-
739
- from asset3d_gen.data.utils import (
740
- CameraSetting,
741
- init_kal_camera,
742
- DiffrastRender,
743
- )
744
- import nvdiffrast.torch as dr
745
-
746
- camera_params = CameraSetting(
747
- num_images=6,
748
- elevation=[20.0, -10.0],
749
- distance=5,
750
- resolution_hw=(2048, 2048),
751
- fov=math.radians(30),
752
- device="cuda",
753
- )
754
- camera = init_kal_camera(camera_params)
755
- mv = camera.view_matrix() # (n 4 4) world2cam
756
- p = camera.intrinsics.projection_matrix()
757
- # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
758
- p[:, 1, 1] = -p[:, 1, 1]
759
- renderer = DiffrastRender(
760
- p_matrix=p,
761
- mv_matrix=mv,
762
- resolution_hw=camera_params.resolution_hw,
763
- context=dr.RasterizeCudaContext(),
764
- mask_thresh=0.5,
765
- grad_db=False,
766
- device=camera_params.device,
767
- antialias_mask=True,
768
- )
769
-
770
- mesh = trimesh.load(mesh_path)
771
- if isinstance(mesh, trimesh.Scene):
772
- mesh = mesh.dump(concatenate=True)
773
-
774
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
775
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
776
- uvs[:, 1] = 1 - uvs[:, 1]
777
- mesh.vertices = mesh.vertices[vmapping]
778
- mesh.faces = indices
779
- mesh.visual.uv = uvs
780
-
781
- vertices = torch.from_numpy(mesh.vertices).to(camera_params.device).float()
782
- faces = (
783
- torch.from_numpy(mesh.faces).to(camera_params.device).to(torch.int64)
784
- )
785
- uvs = torch.from_numpy(mesh.visual.uv).to(camera_params.device).float()
786
-
787
- rendered_view_normals = []
788
- rast, vertices_clip = renderer.compute_dr_raster(vertices, faces)
789
- for idx in range(len(mv)):
790
- pos_cam = transform_vertices(mv[idx], vertices, keepdim=True)
791
- pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
792
- v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
793
- face_norm = F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)
794
- vertex_norm = (
795
- torch.from_numpy(
796
- trimesh.geometry.mean_vertex_normals(
797
- len(pos_cam), faces.cpu(), face_norm.cpu()
798
- )
799
- )
800
- .to(camera_params.device)
801
- .contiguous()
802
- )
803
- im_base_normals, _ = dr.interpolate(
804
- vertex_norm[None, ...].float(),
805
- rast[idx : idx + 1],
806
- faces.to(torch.int32),
807
- )
808
- rendered_view_normals.append(im_base_normals)
809
-
810
- rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
811
-
812
- rendered_depth, masks = renderer.render_depth(vertices, faces)
813
- norm_depths = []
814
- for idx in range(len(rendered_depth)):
815
- norm_depth = renderer.normalize_map_by_mask(
816
- rendered_depth[idx : idx + 1], masks[idx : idx + 1]
817
- )
818
- norm_depths.append(norm_depth)
819
- norm_depths = torch.cat(norm_depths, dim=0)
820
- render_uvs, _ = renderer.render_uv(vertices, faces, uvs)
821
-
822
- for index in range(6):
823
- cv2.imwrite(
824
- f"v3_mask{index}.png",
825
- (masks[index] * 255).cpu().numpy().astype(np.uint8),
826
- )
827
- cv2.imwrite(
828
- f"v3_normalv2{index}.png",
829
- (rendered_view_normals[index] * 255)
830
- .cpu()
831
- .numpy()
832
- .astype(np.uint8)[..., ::-1],
833
- )
834
- cv2.imwrite(
835
- f"v3_depth{index}.png",
836
- (norm_depths[index] * 255).cpu().numpy().astype(np.uint8),
837
- )
838
- cv2.imwrite(
839
- f"v3_uv{index}.png",
840
- (render_uvs[index, ..., 0] * 255).cpu().numpy().astype(np.uint8),
841
- )
842
- multiviews[index].save(f"v3_color{index}.png")
843
-
844
- texture_backer = TextureBacker(
845
- camera_elevs=[20, 20, 20, -10, -10, -10],
846
- camera_azims=[-180, -60, 60, -120, 0, 120],
847
- view_weights=[1, 0.2, 0.2, 0.2, 1, 0.2],
848
- camera_distance=5,
849
- camera_fov=30,
850
- render_wh=(2048, 2048),
851
- texture_wh=(2048, 2048),
852
- )
853
- texture_backer.vertices = vertices
854
- texture_backer.faces = faces
855
- uvs[:, 1] = 1.0 - uvs[:, 1]
856
- texture_backer.uv_map = uvs
857
- texture_backer.center = center
858
- texture_backer.scale = scale
859
-
860
- textured_mesh = texture_backer.forward(
861
- multiviews, masks, norm_depths, rendered_view_normals, render_uvs
862
- )
863
-
864
- # multiviews = [super_model(img) for img in multiviews]
865
- # multiviews = [img.convert("RGB") for img in multiviews]
866
- # textured_mesh = texture_backer(multiviews, mesh_path, output_path)
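Both backup variants derive their cameras from the same spherical parameterization via get_perspective_projection and compute_w2c_matrix. The short check below exercises those module-level helpers with the values used in the __main__ block; the assertions are only sanity checks and assume the two functions are in scope.

import numpy as np

proj = get_perspective_projection(fov=30, aspect_wh=1.0)                 # 4x4 clip-space projection
w2c = compute_w2c_matrix(elev_deg=20.0, azim_deg=-180.0, cam_dist=5.0)   # world-to-camera transform

cam_center = np.linalg.inv(w2c)[:3, 3]  # camera position in world coordinates
assert np.isclose(np.linalg.norm(cam_center), 5.0)                       # sits at the requested distance
assert proj.shape == (4, 4) and w2c.shape == (4, 4)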
asset3d_gen/data/backup/backprojectv2.py DELETED
@@ -1,835 +0,0 @@
1
- from PIL import Image
2
- import torch
3
- import torch.nn.functional as F
4
- import numpy as np
5
- import math
6
- import trimesh
7
- import cv2
8
- import xatlas
9
- from typing import Union
10
-
11
-
12
- def get_perspective_projection_matrix(fovy, aspect_wh, near, far):
13
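-     # Build an OpenGL-style perspective projection matrix from a vertical FOV
-     # (degrees), a width/height aspect ratio, and near/far clip planes.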
- fovy_rad = math.radians(fovy)
14
- return np.array(
15
- [
16
- [1.0 / (math.tan(fovy_rad / 2.0) * aspect_wh), 0, 0, 0],
17
- [0, 1.0 / math.tan(fovy_rad / 2.0), 0, 0],
18
- [
19
- 0,
20
- 0,
21
- -(far + near) / (far - near),
22
- -2.0 * far * near / (far - near),
23
- ],
24
- [0, 0, -1, 0],
25
- ]
26
- ).astype(np.float32)
27
-
28
-
29
- def load_mesh(mesh):
30
- vtx_pos = mesh.vertices if hasattr(mesh, "vertices") else None
31
- pos_idx = mesh.faces if hasattr(mesh, "faces") else None
32
-
33
- vtx_uv = mesh.visual.uv if hasattr(mesh.visual, "uv") else None
34
- uv_idx = mesh.faces if hasattr(mesh, "faces") else None
35
-
36
- texture_data = None
37
-
38
- return vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data
39
-
40
-
41
- def save_mesh(mesh, texture_data):
42
- material = trimesh.visual.texture.SimpleMaterial(
43
- image=texture_data, diffuse=(255, 255, 255)
44
- )
45
- texture_visuals = trimesh.visual.TextureVisuals(
46
- uv=mesh.visual.uv, image=texture_data, material=material
47
- )
48
- mesh.visual = texture_visuals
49
- return mesh
50
-
51
-
52
- def transform_pos(mtx, pos, keepdim=False):
53
- t_mtx = (
54
- torch.from_numpy(mtx).to(pos.device)
55
- if isinstance(mtx, np.ndarray)
56
- else mtx
57
- )
58
- if pos.shape[-1] == 3:
59
- posw = torch.cat(
60
- [pos, torch.ones([pos.shape[0], 1]).to(pos.device)], axis=1
61
- )
62
- else:
63
- posw = pos
64
-
65
- if keepdim:
66
- return torch.matmul(posw, t_mtx.t())[...]
67
- else:
68
- return torch.matmul(posw, t_mtx.t())[None, ...]
69
-
70
-
71
- def get_mv_matrix(elev, azim, camera_distance, center=None):
72
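-     # Build a 4x4 world-to-camera (view) matrix for a camera orbiting `center`
-     # at the given elevation/azimuth (degrees) and distance, with +Z as up.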
- elev = -elev
73
-
74
- elev_rad = math.radians(elev)
75
- azim_rad = math.radians(azim)
76
-
77
- camera_position = np.array(
78
- [
79
- camera_distance * math.cos(elev_rad) * math.cos(azim_rad),
80
- camera_distance * math.cos(elev_rad) * math.sin(azim_rad),
81
- camera_distance * math.sin(elev_rad),
82
- ]
83
- )
84
-
85
- if center is None:
86
- center = np.array([0, 0, 0])
87
- else:
88
- center = np.array(center)
89
-
90
- lookat = center - camera_position
91
- lookat = lookat / np.linalg.norm(lookat)
92
-
93
- up = np.array([0, 0, 1.0])
94
- right = np.cross(lookat, up)
95
- right = right / np.linalg.norm(right)
96
- up = np.cross(right, lookat)
97
- up = up / np.linalg.norm(up)
98
-
99
- c2w = np.concatenate(
100
- [np.stack([right, up, -lookat], axis=-1), camera_position[:, None]],
101
- axis=-1,
102
- )
103
-
104
- w2c = np.zeros((4, 4))
105
- w2c[:3, :3] = np.transpose(c2w[:3, :3], (1, 0))
106
- w2c[:3, 3:] = -np.matmul(np.transpose(c2w[:3, :3], (1, 0)), c2w[:3, 3:])
107
- w2c[3, 3] = 1.0
108
-
109
- return w2c.astype(np.float32)
110
-
111
-
112
- def stride_from_shape(shape):
113
- stride = [1]
114
- for x in reversed(shape[1:]):
115
- stride.append(stride[-1] * x)
116
- return list(reversed(stride))
117
-
118
-
119
- def scatter_add_nd_with_count(input, count, indices, values, weights=None):
120
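-     # Scatter-add per-point values (and their weights) into a flattened N-D grid,
-     # returning the accumulated grid and the per-cell weight counts.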
- # input: [..., C], D dimension + C channel
121
- # count: [..., 1], D dimension
122
- # indices: [N, D], long
123
- # values: [N, C]
124
-
125
- D = indices.shape[-1]
126
- C = input.shape[-1]
127
- size = input.shape[:-1]
128
- stride = stride_from_shape(size)
129
-
130
- assert len(size) == D
131
-
132
- input = input.view(-1, C) # [HW, C]
133
- count = count.view(-1, 1)
134
-
135
- flatten_indices = (
136
- indices * torch.tensor(stride, dtype=torch.long, device=indices.device)
137
- ).sum(
138
- -1
139
- ) # [N]
140
-
141
- if weights is None:
142
- weights = torch.ones_like(values[..., :1])
143
-
144
- input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
145
- count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
146
-
147
- return input.view(*size, C), count.view(*size, 1)
148
-
149
-
150
- def linear_grid_put_2d(H, W, coords, values, return_count=False):
151
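-     # Bilinearly splat per-point values into an (H, W) grid; unless return_count
-     # is set, each cell is normalized by its accumulated weight.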
- # coords: [N, 2], float in [0, 1]
152
- # values: [N, C]
153
-
154
- C = values.shape[-1]
155
-
156
- indices = coords * torch.tensor(
157
- [H - 1, W - 1], dtype=torch.float32, device=coords.device
158
- )
159
- indices_00 = indices.floor().long() # [N, 2]
160
- indices_00[:, 0].clamp_(0, H - 2)
161
- indices_00[:, 1].clamp_(0, W - 2)
162
- indices_01 = indices_00 + torch.tensor(
163
- [0, 1], dtype=torch.long, device=indices.device
164
- )
165
- indices_10 = indices_00 + torch.tensor(
166
- [1, 0], dtype=torch.long, device=indices.device
167
- )
168
- indices_11 = indices_00 + torch.tensor(
169
- [1, 1], dtype=torch.long, device=indices.device
170
- )
171
-
172
- h = indices[..., 0] - indices_00[..., 0].float()
173
- w = indices[..., 1] - indices_00[..., 1].float()
174
- w_00 = (1 - h) * (1 - w)
175
- w_01 = (1 - h) * w
176
- w_10 = h * (1 - w)
177
- w_11 = h * w
178
-
179
- result = torch.zeros(
180
- H, W, C, device=values.device, dtype=values.dtype
181
- ) # [H, W, C]
182
- count = torch.zeros(
183
- H, W, 1, device=values.device, dtype=values.dtype
184
- ) # [H, W, 1]
185
- weights = torch.ones_like(values[..., :1]) # [N, 1]
186
-
187
- result, count = scatter_add_nd_with_count(
188
- result,
189
- count,
190
- indices_00,
191
- values * w_00.unsqueeze(1),
192
- weights * w_00.unsqueeze(1),
193
- )
194
- result, count = scatter_add_nd_with_count(
195
- result,
196
- count,
197
- indices_01,
198
- values * w_01.unsqueeze(1),
199
- weights * w_01.unsqueeze(1),
200
- )
201
- result, count = scatter_add_nd_with_count(
202
- result,
203
- count,
204
- indices_10,
205
- values * w_10.unsqueeze(1),
206
- weights * w_10.unsqueeze(1),
207
- )
208
- result, count = scatter_add_nd_with_count(
209
- result,
210
- count,
211
- indices_11,
212
- values * w_11.unsqueeze(1),
213
- weights * w_11.unsqueeze(1),
214
- )
215
-
216
- if return_count:
217
- return result, count
218
-
219
- mask = count.squeeze(-1) > 0
220
- result[mask] = result[mask] / count[mask].repeat(1, C)
221
-
222
- return result
223
-
224
-
225
- def meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx):
226
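-     # Fill uncolored texels by propagating vertex colors along mesh edges
-     # (inverse-squared-distance weighted), then write them back to texture/mask.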
- texture_height, texture_width, texture_channel = texture.shape
227
- vtx_num = vtx_pos.shape[0]
228
-
229
- vtx_mask = np.zeros(vtx_num, dtype=np.float32)
230
- vtx_color = [
231
- np.zeros(texture_channel, dtype=np.float32) for _ in range(vtx_num)
232
- ]
233
- uncolored_vtxs = []
234
- G = [[] for _ in range(vtx_num)]
235
-
236
- for i in range(uv_idx.shape[0]):
237
- for k in range(3):
238
- vtx_uv_idx = uv_idx[i, k]
239
- vtx_idx = pos_idx[i, k]
240
- uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
241
- uv_u = int(
242
- round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
243
- )
244
- if mask[uv_u, uv_v] > 0:
245
- vtx_mask[vtx_idx] = 1.0
246
- vtx_color[vtx_idx] = texture[uv_u, uv_v]
247
- else:
248
- uncolored_vtxs.append(vtx_idx)
249
- G[pos_idx[i, k]].append(pos_idx[i, (k + 1) % 3])
250
-
251
- smooth_count = 2
252
- last_uncolored_vtx_count = 0
253
- while smooth_count > 0:
254
- uncolored_vtx_count = 0
255
- for vtx_idx in uncolored_vtxs:
256
- sum_color = np.zeros(texture_channel, dtype=np.float32)
257
- total_weight = 0.0
258
- vtx_0 = vtx_pos[vtx_idx]
259
- for connected_idx in G[vtx_idx]:
260
- if vtx_mask[connected_idx] > 0:
261
- vtx1 = vtx_pos[connected_idx]
262
- dist = np.sqrt(np.sum((vtx_0 - vtx1) ** 2))
263
- dist_weight = 1.0 / max(dist, 1e-4)
264
- dist_weight *= dist_weight
265
- sum_color += vtx_color[connected_idx] * dist_weight
266
- total_weight += dist_weight
267
- if total_weight > 0:
268
- vtx_color[vtx_idx] = sum_color / total_weight
269
- vtx_mask[vtx_idx] = 1.0
270
- else:
271
- uncolored_vtx_count += 1
272
-
273
- if last_uncolored_vtx_count == uncolored_vtx_count:
274
- smooth_count -= 1
275
- else:
276
- smooth_count += 1
277
- last_uncolored_vtx_count = uncolored_vtx_count
278
-
279
- new_texture = texture.copy()
280
- new_mask = mask.copy()
281
- for face_idx in range(uv_idx.shape[0]):
282
- for k in range(3):
283
- vtx_uv_idx = uv_idx[face_idx, k]
284
- vtx_idx = pos_idx[face_idx, k]
285
- if vtx_mask[vtx_idx] == 1.0:
286
- uv_v = int(round(vtx_uv[vtx_uv_idx, 0] * (texture_width - 1)))
287
- uv_u = int(
288
- round((1.0 - vtx_uv[vtx_uv_idx, 1]) * (texture_height - 1))
289
- )
290
- new_texture[uv_u, uv_v] = vtx_color[vtx_idx]
291
- new_mask[uv_u, uv_v] = 255
292
-
293
- return new_texture, new_mask
294
-
295
-
296
- def mesh_uv_wrap(mesh):
297
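-     # Unwrap the mesh with xatlas, re-indexing vertices/faces and attaching UVs.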
- if isinstance(mesh, trimesh.Scene):
298
- mesh = mesh.dump(concatenate=True)
299
-
300
- if len(mesh.faces) > 500000000:
301
- raise ValueError(
302
- "The mesh has more than 500,000,000 faces, which is not supported."
303
- )
304
-
305
- vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
306
-
307
- mesh.vertices = mesh.vertices[vmapping]
308
- mesh.faces = indices
309
- mesh.visual.uv = uvs
310
-
311
- return mesh
312
-
313
-
314
- class MeshRender:
315
- def __init__(
316
- self,
317
- camera_distance=1.45,
318
- default_resolution=1024,
319
- texture_size=1024,
320
- use_antialias=True,
321
- max_mip_level=None,
322
- filter_mode="linear",
323
- bake_mode="linear",
324
- raster_mode="cr",
325
- device="cuda",
326
- ):
327
-
328
- self.device = device
329
-
330
- self.set_default_render_resolution(default_resolution)
331
- self.set_default_texture_resolution(texture_size)
332
-
333
- self.camera_distance = camera_distance
334
- self.use_antialias = use_antialias
335
- self.max_mip_level = max_mip_level
336
- self.filter_mode = filter_mode
337
-
338
- self.bake_angle_thres = 75
339
- self.bake_unreliable_kernel_size = int(
340
- (2 / 512)
341
- * max(self.default_resolution[0], self.default_resolution[1])
342
- )
343
- self.bake_mode = bake_mode
344
-
345
- self.raster_mode = raster_mode
346
- if self.raster_mode == "cr":
347
- import custom_rasterizer as cr
348
-
349
- self.raster = cr
350
- else:
351
-             raise ValueError(f"No raster named {self.raster_mode}")
352
-
353
- fov = 30
354
- self.camera_proj_mat = get_perspective_projection_matrix(
355
- fov,
356
- self.default_resolution[1] / self.default_resolution[0],
357
- 0.01,
358
- 100.0,
359
- )
360
-
361
- def raster_rasterize(
362
- self, pos, tri, resolution, ranges=None, grad_db=True
363
- ):
364
-
365
- if self.raster_mode == "cr":
366
- rast_out_db = None
367
- if pos.dim() == 2:
368
- pos = pos.unsqueeze(0)
369
- findices, barycentric = self.raster.rasterize(pos, tri, resolution)
370
- rast_out = torch.cat((barycentric, findices.unsqueeze(-1)), dim=-1)
371
- rast_out = rast_out.unsqueeze(0)
372
- else:
373
-             raise ValueError(f"No raster named {self.raster_mode}")
374
-
375
- return rast_out, rast_out_db
376
-
377
- def raster_interpolate(
378
- self, uv, rast_out, uv_idx, rast_db=None, diff_attrs=None
379
- ):
380
-
381
- if self.raster_mode == "cr":
382
- textd = None
383
- barycentric = rast_out[0, ..., :-1]
384
- findices = rast_out[0, ..., -1]
385
- if uv.dim() == 2:
386
- uv = uv.unsqueeze(0)
387
- textc = self.raster.interpolate(uv, findices, barycentric, uv_idx)
388
- else:
389
-             raise ValueError(f"No raster named {self.raster_mode}")
390
-
391
- return textc, textd
392
-
393
- def load_mesh(
394
- self,
395
- mesh,
396
- ):
397
- vtx_pos, pos_idx, vtx_uv, uv_idx, texture_data = load_mesh(mesh)
398
- self.mesh_copy = mesh
399
- self.set_mesh(
400
- vtx_pos,
401
- pos_idx,
402
- vtx_uv=vtx_uv,
403
- uv_idx=uv_idx,
404
- )
405
- if texture_data is not None:
406
- self.set_texture(texture_data)
407
-
408
- def save_mesh(self):
409
- texture_data = self.get_texture()
410
- texture_data = Image.fromarray((texture_data * 255).astype(np.uint8))
411
- return save_mesh(self.mesh_copy, texture_data)
412
-
413
- def set_mesh(
414
- self,
415
- vtx_pos,
416
- pos_idx,
417
- vtx_uv=None,
418
- uv_idx=None,
419
- ):
420
-
421
- self.vtx_pos = torch.from_numpy(vtx_pos).to(self.device).float()
422
- self.pos_idx = torch.from_numpy(pos_idx).to(self.device).to(torch.int)
423
- if (vtx_uv is not None) and (uv_idx is not None):
424
- self.vtx_uv = torch.from_numpy(vtx_uv).to(self.device).float()
425
- self.uv_idx = (
426
- torch.from_numpy(uv_idx).to(self.device).to(torch.int)
427
- )
428
- else:
429
- self.vtx_uv = None
430
- self.uv_idx = None
431
-
432
- self.vtx_pos[:, [0, 1]] = -self.vtx_pos[:, [0, 1]]
433
- self.vtx_pos[:, [1, 2]] = self.vtx_pos[:, [2, 1]]
434
- if (vtx_uv is not None) and (uv_idx is not None):
435
- self.vtx_uv[:, 1] = 1.0 - self.vtx_uv[:, 1]
436
-
437
- def set_texture(self, tex):
438
- if isinstance(tex, np.ndarray):
439
- tex = Image.fromarray((tex * 255).astype(np.uint8))
440
- elif isinstance(tex, torch.Tensor):
441
- tex = tex.cpu().numpy()
442
- tex = Image.fromarray((tex * 255).astype(np.uint8))
443
-
444
- tex = tex.resize(self.texture_size).convert("RGB")
445
- tex = np.array(tex) / 255.0
446
- self.tex = torch.from_numpy(tex).to(self.device)
447
- self.tex = self.tex.float()
448
-
449
- def set_default_render_resolution(self, default_resolution):
450
- if isinstance(default_resolution, int):
451
- default_resolution = (default_resolution, default_resolution)
452
- self.default_resolution = default_resolution
453
-
454
- def set_default_texture_resolution(self, texture_size):
455
- if isinstance(texture_size, int):
456
- texture_size = (texture_size, texture_size)
457
- self.texture_size = texture_size
458
-
459
- def get_mesh(self):
460
- vtx_pos = self.vtx_pos.cpu().numpy()
461
- pos_idx = self.pos_idx.cpu().numpy()
462
- vtx_uv = self.vtx_uv.cpu().numpy()
463
- uv_idx = self.uv_idx.cpu().numpy()
464
-
465
-         # Undo the axis flips/swaps applied in set_mesh.
466
- vtx_pos[:, [1, 2]] = vtx_pos[:, [2, 1]]
467
- vtx_pos[:, [0, 1]] = -vtx_pos[:, [0, 1]]
468
-
469
- vtx_uv[:, 1] = 1.0 - vtx_uv[:, 1]
470
- return vtx_pos, pos_idx, vtx_uv, uv_idx
471
-
472
- def get_texture(self):
473
- return self.tex.cpu().numpy()
474
-
475
- def render_sketch_from_depth(self, depth_image):
476
- depth_image_np = depth_image.cpu().numpy()
477
- depth_image_np = (depth_image_np * 255).astype(np.uint8)
478
- depth_edges = cv2.Canny(depth_image_np, 30, 80)
479
- combined_edges = depth_edges
480
- sketch_image = (
481
- torch.from_numpy(combined_edges).to(depth_image.device).float()
482
- / 255.0
483
- )
484
- sketch_image = sketch_image.unsqueeze(-1)
485
- return sketch_image
486
-
487
- def back_project(
488
- self, image, elev, azim, camera_distance=None, center=None, method=None
489
- ):
490
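-         # Project one view image back onto the UV texture: rasterize the mesh from
-         # (elev, azim), drop grazing-angle and boundary pixels, and splat the rest
-         # into texture space, returning (texture, cos_map, boundary_map).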
- if isinstance(image, Image.Image):
491
- image = torch.tensor(np.array(image) / 255.0)
492
- elif isinstance(image, np.ndarray):
493
- image = torch.tensor(image)
494
- if image.dim() == 2:
495
- image = image.unsqueeze(-1)
496
- image = image.float().to(self.device)
497
- resolution = image.shape[:2]
498
- channel = image.shape[-1]
499
- texture = torch.zeros(self.texture_size + (channel,)).to(self.device)
500
- cos_map = torch.zeros(self.texture_size + (1,)).to(self.device)
501
-
502
- proj = self.camera_proj_mat
503
- r_mv = get_mv_matrix(
504
- elev=elev,
505
- azim=azim,
506
- camera_distance=(
507
- self.camera_distance
508
- if camera_distance is None
509
- else camera_distance
510
- ),
511
- center=center,
512
- )
513
- pos_camera = transform_pos(r_mv, self.vtx_pos, keepdim=True)
514
- pos_clip = transform_pos(proj, pos_camera)
515
- pos_camera = pos_camera[:, :3] / pos_camera[:, 3:4]
516
- v0 = pos_camera[self.pos_idx[:, 0], :]
517
- v1 = pos_camera[self.pos_idx[:, 1], :]
518
- v2 = pos_camera[self.pos_idx[:, 2], :]
519
- face_normals = F.normalize(
520
- torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
521
- )
522
- vertex_normals = trimesh.geometry.mean_vertex_normals(
523
- vertex_count=self.vtx_pos.shape[0],
524
- faces=self.pos_idx.cpu(),
525
- face_normals=face_normals.cpu(),
526
- )
527
- vertex_normals = (
528
- torch.from_numpy(vertex_normals)
529
- .float()
530
- .to(self.device)
531
- .contiguous()
532
- )
533
- tex_depth = pos_camera[:, 2].reshape(1, -1, 1).contiguous()
534
- rast_out, rast_out_db = self.raster_rasterize(
535
- pos_clip, self.pos_idx, resolution=resolution
536
- )
537
- visible_mask = torch.clamp(rast_out[..., -1:], 0, 1)[0, ...]
538
-
539
- normal, _ = self.raster_interpolate(
540
- vertex_normals[None, ...], rast_out, self.pos_idx
541
- )
542
- normal = normal[0, ...]
543
-
544
- uv, _ = self.raster_interpolate(
545
- self.vtx_uv[None, ...], rast_out, self.uv_idx
546
- )
547
- depth, _ = self.raster_interpolate(tex_depth, rast_out, self.pos_idx)
548
- depth = depth[0, ...]
549
-
550
- depth_max, depth_min = (
551
- depth[visible_mask > 0].max(),
552
- depth[visible_mask > 0].min(),
553
- )
554
- depth_normalized = (depth - depth_min) / (depth_max - depth_min)
555
- depth_image = depth_normalized * visible_mask # Mask out background.
556
-
557
- sketch_image = self.render_sketch_from_depth(depth_image)
558
-
559
- cv2.imwrite("d_depth.png", depth_image.cpu().numpy() * 255)
560
- cv2.imwrite("d_normal.png", normal.cpu().numpy() * 255)
561
- cv2.imwrite(
562
- "d_image.png", image.cpu().numpy()[..., :3][..., ::-1] * 255
563
- )
564
- cv2.imwrite("d_sketch_image.png", sketch_image.cpu().numpy() * 255)
565
- cv2.imwrite("d_uv1.png", uv.cpu().numpy()[0, ..., 0] * 255)
566
- cv2.imwrite("d_uv2.png", uv.cpu().numpy()[0, ..., 1] * 255)
567
- # p uv[0,...,0].mean(axis=0)
568
- # import pdb; pdb.set_trace()
569
-
570
- # depth_image = None
571
- # normal = None
572
- # image = None
573
-
574
- sketch_image = self.render_sketch_from_depth(depth_image)
575
- channel = image.shape[-1]
576
-
577
- lookat = torch.tensor([[0, 0, -1]], device=self.device)
578
- cos_image = torch.nn.functional.cosine_similarity(
579
- lookat, normal.view(-1, 3)
580
- )
581
- cos_image = cos_image.view(normal.shape[0], normal.shape[1], 1)
582
-
583
- cos_thres = np.cos(self.bake_angle_thres / 180 * np.pi)
584
- cos_image[cos_image < cos_thres] = 0
585
-
586
- # shrink
587
- kernel_size = self.bake_unreliable_kernel_size * 2 + 1
588
- kernel = torch.ones(
589
- (1, 1, kernel_size, kernel_size), dtype=torch.float32
590
- ).to(sketch_image.device)
591
-
592
- visible_mask = visible_mask.permute(2, 0, 1).unsqueeze(0).float()
593
- visible_mask = F.conv2d(
594
- 1.0 - visible_mask, kernel, padding=kernel_size // 2
595
- )
596
-         visible_mask = 1.0 - (visible_mask > 0).float()  # binarize
597
- visible_mask = visible_mask.squeeze(0).permute(1, 2, 0)
598
-
599
- sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
600
- sketch_image = F.conv2d(sketch_image, kernel, padding=kernel_size // 2)
601
-         sketch_image = (sketch_image > 0).float()  # binarize
602
- sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
603
- visible_mask = visible_mask * (sketch_image < 0.5)
604
-
605
- cos_image[visible_mask == 0] = 0
606
- proj_mask = (visible_mask != 0).view(-1)
607
- uv = uv.squeeze(0).contiguous().view(-1, 2)[proj_mask]
608
- image = image.squeeze(0).contiguous().view(-1, channel)[proj_mask]
609
- cos_image = cos_image.contiguous().view(-1, 1)[proj_mask]
610
- sketch_image = sketch_image.contiguous().view(-1, 1)[proj_mask]
611
614
- texture = linear_grid_put_2d(
615
- self.texture_size[1], self.texture_size[0], uv[..., [1, 0]], image
616
- )
617
- cos_map = linear_grid_put_2d(
618
- self.texture_size[1],
619
- self.texture_size[0],
620
- uv[..., [1, 0]],
621
- cos_image,
622
- )
623
- boundary_map = linear_grid_put_2d(
624
- self.texture_size[1],
625
- self.texture_size[0],
626
- uv[..., [1, 0]],
627
- sketch_image,
628
- )
629
-
630
- return texture, cos_map, boundary_map
631
-
632
- @torch.no_grad()
633
- def fast_bake_texture(self, textures, cos_maps):
634
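-         # Blend the per-view textures weighted by their cosine maps, skipping views
-         # that would add almost no new texel coverage.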
-
635
- channel = textures[0].shape[-1]
636
- texture_merge = torch.zeros(self.texture_size + (channel,)).to(
637
- self.device
638
- )
639
- trust_map_merge = torch.zeros(self.texture_size + (1,)).to(self.device)
640
- for texture, cos_map in zip(textures, cos_maps):
641
- view_sum = (cos_map > 0).sum()
642
- painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
643
- if painted_sum / view_sum > 0.99:
644
- continue
645
- texture_merge += texture * cos_map
646
- trust_map_merge += cos_map
647
- texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
648
-
649
- return texture_merge, trust_map_merge > 1e-8
650
-
651
- def uv_inpaint(self, texture, mask):
652
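-         # Fill unpainted texels: first propagate colors across mesh vertices, then
-         # run OpenCV Navier-Stokes inpainting on the remaining holes.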
-
653
- if isinstance(texture, torch.Tensor):
654
- texture_np = texture.cpu().numpy()
655
- elif isinstance(texture, np.ndarray):
656
- texture_np = texture
657
- elif isinstance(texture, Image.Image):
658
- texture_np = np.array(texture) / 255.0
659
-
660
- vtx_pos, pos_idx, vtx_uv, uv_idx = self.get_mesh()
661
-
662
- texture_np, mask = meshVerticeInpaint_smooth(
663
- texture_np, mask, vtx_pos, vtx_uv, pos_idx, uv_idx
664
- )
665
-
666
- texture_np = cv2.inpaint(
667
- (texture_np * 255).astype(np.uint8), 255 - mask, 3, cv2.INPAINT_NS
668
- )
669
-
670
- return texture_np
671
-
672
-
673
- def get_images_from_file(img_path: str, img_size: int) -> list[np.array]:
674
- input_image = Image.open(img_path)
675
- view_images = np.array(input_image)
676
- view_images = np.concatenate(
677
- [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
678
- )
679
- images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
680
-
681
- return images
682
-
683
-
684
- def bake_from_multiview(
685
- render, views, camera_elevs, camera_azims, view_weights, method="fast"
686
- ):
687
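-     # Back-project every view into texture space and merge them, weighting each
-     # view by view_weight * cos^4 of the viewing angle.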
- project_textures, project_weighted_cos_maps = [], []
688
- project_boundary_maps = []
689
- for view, camera_elev, camera_azim, weight in zip(
690
- views, camera_elevs, camera_azims, view_weights
691
- ):
692
- project_texture, project_cos_map, project_boundary_map = (
693
- render.back_project(view, camera_elev, camera_azim)
694
- )
695
- project_cos_map = weight * (project_cos_map**4)
696
- project_textures.append(project_texture)
697
- project_weighted_cos_maps.append(project_cos_map)
698
- project_boundary_maps.append(project_boundary_map)
699
-
700
- if method == "fast":
701
- texture, ori_trust_map = render.fast_bake_texture(
702
- project_textures, project_weighted_cos_maps
703
- )
704
- else:
705
-         raise ValueError(f"No baking method named {method}")
706
-
707
- return texture, ori_trust_map > 1e-8
708
-
709
-
710
- def post_process(texture: np.ndarray, iter: int = 2) -> np.ndarray:
711
- for _ in range(iter):
712
- texture = cv2.fastNlMeansDenoisingColored(texture, None, 11, 11, 9, 25)
713
- texture = cv2.bilateralFilter(
714
- texture, d=7, sigmaColor=80, sigmaSpace=80
715
- )
716
-
717
- return texture
718
-
719
-
720
- class Image_Super_Net:
721
- def __init__(self, device="cuda"):
722
- from diffusers import StableDiffusionUpscalePipeline
723
-
724
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
725
- "stabilityai/stable-diffusion-x4-upscaler",
726
- torch_dtype=torch.float16,
727
- ).to(device)
728
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
729
-
730
- def __call__(self, image, prompt=""):
731
- with torch.no_grad():
732
- upscaled_image = self.up_pipeline_x4(
733
- prompt=[prompt],
734
- image=image,
735
- num_inference_steps=10,
736
- ).images[0]
737
-
738
- return upscaled_image
739
-
740
-
741
- class Image_GANNet:
742
- def __init__(self, outscale: int):
743
- from realesrgan import RealESRGANer
744
- from basicsr.archs.rrdbnet_arch import RRDBNet
745
-
746
- self.outscale = outscale
747
- model = RRDBNet(
748
- num_in_ch=3,
749
- num_out_ch=3,
750
- num_feat=64,
751
- num_block=23,
752
- num_grow_ch=32,
753
- scale=4,
754
- )
755
- self.upsampler = RealESRGANer(
756
- scale=4,
757
- model_path="/home/users/xinjie.wang/xinjie/Real-ESRGAN/weights/RealESRGAN_x4plus.pth",
758
- model=model,
759
- pre_pad=0,
760
- half=True,
761
- )
762
-
763
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
764
- if isinstance(image, Image.Image):
765
- image = np.array(image)
766
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
767
-
768
- return Image.fromarray(output)
769
-
770
-
771
- if __name__ == "__main__":
772
- device = "cuda"
773
-
774
- # super_model = Image_Super_Net(device)
775
- super_model = Image_GANNet(outscale=4)
776
-
777
- selected_camera_elevs = [20, 20, 20, -10, -10, -10]
778
- selected_camera_azims = [-180, -60, 60, -120, 0, 120]
779
- selected_view_weights = [1, 0.2, 0.2, 0.2, 1, 0.2]
780
- # selected_view_weights = [1, 0.1, 0.5, 0.1, 0.05, 0.05]
781
-
782
- multiviews = get_images_from_file(
783
- "scripts/apps/texture_sessions/mfq4e7u4ko/multi_view/color_sample1.png",
784
- 512,
785
- )
786
- target_image_size = (2048, 2048)
787
-
788
- render = MeshRender(
789
- camera_distance=5,
790
- default_resolution=2048,
791
- texture_size=2048,
792
- )
793
-
794
- mesh = trimesh.load("scripts/apps/assets/example_texture/meshes/robot.obj")
795
- from asset3d_gen.data.utils import normalize_vertices_array
796
-
797
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
798
- mesh = mesh_uv_wrap(mesh)
799
- render.load_mesh(mesh)
800
-
801
- # multiviews = [Image.fromarray(img) for img in multiviews]
802
- # multiviews = [Image.fromarray(img).convert("RGB") for img in multiviews]
803
- # for idx, img in enumerate(multiviews):
804
- # img.save(f"robot/raw/res_{idx}.png")
805
-
806
- multiviews = [super_model(img) for img in multiviews]
807
- multiviews = [img.convert("RGB") for img in multiviews]
808
- for idx, img in enumerate(multiviews):
809
- img.save(f"robot/super_gan_res_{idx}.png")
810
-
811
- texture, mask = bake_from_multiview(
812
- render,
813
- multiviews,
814
- selected_camera_elevs,
815
- selected_camera_azims,
816
- selected_view_weights,
817
- )
818
-
819
- texture_np = (texture.cpu().numpy() * 255).astype(np.uint8)[..., :3][
820
- ..., ::-1
821
- ]
822
- cv2.imwrite("robot/raw_texture.png", texture_np)
823
- print("texture done.")
824
-
825
- mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
826
- texture_np = render.uv_inpaint(texture, mask_np)
827
- cv2.imwrite("robot/inpaint_texture.png", texture_np[..., ::-1])
828
- # texture_np = post_process(texture_np, 2)
829
- # cv2.imwrite("robot/inpaint_conv_texture.png", texture_np[..., ::-1])
830
- print("inpaint done.")
831
-
832
- texture = torch.tensor(texture_np / 255).float().to(texture.device)
833
- render.set_texture(texture)
834
- textured_mesh = render.save_mesh()
835
- _ = textured_mesh.export("robot/robot.obj")
 
asset3d_gen/data/backup/gpt_qwen.py DELETED
@@ -1,70 +0,0 @@
1
- import torch
2
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3
- from qwen_vl_utils import process_vision_info
4
- import os
5
- os.environ["https_proxy"] = "10.9.0.31:8838"
6
-
7
-
8
- # # default: Load the model on the available device(s)
9
- # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
10
- # "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
11
- # )
12
-
13
- # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
14
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
15
- "Qwen/Qwen2.5-VL-7B-Instruct",
16
- torch_dtype=torch.bfloat16,
17
- attn_implementation="flash_attention_2",
18
- device_map="auto",
19
- )
20
-
21
-
22
- # default processor
23
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
24
-
25
- # The default range for the number of visual tokens per image in the model is 4-16384.
26
- # You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
27
- # min_pixels = 256*28*28
28
- # max_pixels = 1280*28*28
29
- # processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
30
-
31
- messages = [
32
- {
33
- "role": "user",
34
- "content": [
35
- {
36
- "type": "image",
37
- "image": "outputs/text2image/demo_objects/bed/sample_0.jpg",
38
- },
39
- {
40
- "type": "image",
41
- "image": "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png",
42
- },
43
-             {"type": "text", "text": "Describe the second image."},
44
- ],
45
- }
46
- ]
47
-
48
- # Preparation for inference
49
- text = processor.apply_chat_template(
50
- messages, tokenize=False, add_generation_prompt=True
51
- )
52
- image_inputs, video_inputs = process_vision_info(messages)
53
- inputs = processor(
54
- text=[text],
55
- images=image_inputs,
56
- videos=video_inputs,
57
- padding=True,
58
- return_tensors="pt",
59
- )
60
- inputs = inputs.to("cuda")
61
-
62
- # Inference: Generation of the output
63
- generated_ids = model.generate(**inputs, max_new_tokens=128)
64
- generated_ids_trimmed = [
65
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
66
- ]
67
- output_text = processor.batch_decode(
68
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
69
- )
70
- print(output_text)
 
asset3d_gen/data/backup/quat.py DELETED
@@ -1,49 +0,0 @@
1
- import numpy as np
2
-
3
- def quaternion_rotation_x_counterclockwise(angle_degrees):
4
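-     # Quaternion (x, y, z, w) for a counterclockwise rotation about the X axis.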
- angle_radians = np.radians(angle_degrees)
5
- w = np.cos(angle_radians / 2)
6
- x = np.sin(angle_radians / 2)
7
- y, z = 0.0, 0.0
8
- return np.array([x, y, z, w]).round(4).tolist()
9
-
10
-
11
- def quaternion_rotation_y_counterclockwise(angle_degrees):
12
- angle_radians = np.radians(angle_degrees)
13
- w = np.cos(angle_radians / 2)
14
- y = np.sin(angle_radians / 2)
15
- x, z = 0.0, 0.0
16
- return np.array([x, y, z, w]).round(4).tolist()
17
-
18
-
19
- def quaternion_rotation_z_counterclockwise(angle_degrees):
20
- angle_radians = np.radians(angle_degrees)
21
- w = np.cos(angle_radians / 2)
22
- z = np.sin(angle_radians / 2)
23
- x, y = 0.0, 0.0
24
- return np.array([x, y, z, w]).round(4).tolist()
25
-
26
-
27
- def quaternion_multiply(q1, q2):
28
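-     # Hamilton product of two quaternions given in (x, y, z, w) order.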
- x1, y1, z1, w1 = q1
29
- x2, y2, z2, w2 = q2
30
- w = w1*w2 - x1*x2 - y1*y2 - z1*z2
31
- x = w1*x2 + x1*w2 + y1*z2 - z1*y2
32
- y = w1*y2 - x1*z2 + y1*w2 + z1*x2
33
- z = w1*z2 + x1*y2 - y1*x2 + z1*w2
34
- return np.array([w, x, y, z])
35
-
36
-
37
-
38
- angle = 180
39
-
40
- print(f"X axis counterclockwise rotation by {angle} deg: {quaternion_rotation_x_counterclockwise(angle)}")
41
- print(f"Y axis counterclockwise rotation by {angle} deg: {quaternion_rotation_y_counterclockwise(angle)}")
42
- print(f"Z axis counterclockwise rotation by {angle} deg: {quaternion_rotation_z_counterclockwise(angle)}")
43
-
44
-
45
- q_1 = np.array([1.0, 0.0, 0.0, 0.0])
46
- q_2 = np.array([0.0, 0.0, 1.0, 0.0])
47
-
48
- q_total = quaternion_multiply(q_2, q_1)
49
- print(q_total.round(4).tolist())
 
asset3d_gen/data/datasets.py DELETED
@@ -1,239 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import random
5
- from typing import Any, Callable, Dict, List, Tuple, Union
6
-
7
- import torch
8
- import torch.utils.checkpoint
9
- from PIL import Image
10
- from torch import nn
11
- from torch.utils.data import Dataset
12
- from torchvision import transforms
13
-
14
- logging.basicConfig(
15
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
16
- )
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- __all__ = [
21
- "Asset3dGenDataset",
22
- ]
23
-
24
-
25
- class Asset3dGenDataset(Dataset):
26
- def __init__(
27
- self,
28
- index_file: str,
29
- target_hw: Tuple[int, int],
30
- transform: Callable = None,
31
- control_transform: Callable = None,
32
- max_train_samples: int = None,
33
- sub_idxs: List[List[int]] = None,
34
- seed: int = 79,
35
- ) -> None:
36
- if not os.path.exists(index_file):
37
- raise FileNotFoundError(f"{index_file} index_file not found.")
38
-
39
- self.index_file = index_file
40
- self.target_hw = target_hw
41
- self.transform = transform
42
- self.control_transform = control_transform
43
- self.max_train_samples = max_train_samples
44
- self.meta_info = self.prepare_data_index(index_file)
45
- self.data_list = sorted(self.meta_info.keys())
46
- self.sub_idxs = sub_idxs # sub_idxs [[0,1,2], [3,4,5], [...], ...]
47
-         self.image_num = 6  # hard-coded for now
48
- random.seed(seed)
49
- logger.info(f"Trainset: {len(self)} asset3d instances.")
50
-
51
- def __len__(self) -> int:
52
- return len(self.meta_info)
53
-
54
- def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
55
- with open(index_file, "r") as fin:
56
- meta_info = json.load(fin)
57
-
58
- meta_info_filtered = dict()
59
- for idx, uid in enumerate(meta_info):
60
- if "status" not in meta_info[uid]:
61
- continue
62
- if meta_info[uid]["status"] != "success":
63
- continue
64
- if self.max_train_samples and idx >= self.max_train_samples:
65
- break
66
-
67
- meta_info_filtered[uid] = meta_info[uid]
68
-
69
- logger.info(
70
- f"Load {len(meta_info)} assets, keep {len(meta_info_filtered)} valids." # noqa
71
- )
72
-
73
- return meta_info_filtered
74
-
75
- def fetch_sample_images(
76
- self,
77
- uid: str,
78
- attrs: List[str],
79
- sub_index: int = None,
80
- transform: Callable = None,
81
- ) -> torch.Tensor:
82
- sample = self.meta_info[uid]
83
- images = []
84
- for attr in attrs:
85
- item = sample[attr]
86
- if sub_index is not None:
87
- item = item[sub_index]
88
- mode = "L" if attr == "image_mask" else "RGB"
89
- image = Image.open(item).convert(mode)
90
- if transform is not None:
91
- image = transform(image)
92
- if len(image.shape) == 2:
93
- image = image[..., None]
94
- images.append(image)
95
-
96
- images = torch.cat(images, dim=0)
97
-
98
- return images
99
-
100
- def fetch_sample_grid_images(
101
- self,
102
- uid: str,
103
- attrs: List[str],
104
- sub_idxs: List[List[int]],
105
- transform: Callable = None,
106
- ) -> torch.Tensor:
107
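-         # Stack the requested image attributes per sub-index into rows (along width)
-         # and stack the rows along height, yielding a single grid image tensor.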
- assert transform is not None
108
-
109
- grid_image = []
110
- for row_idxs in sub_idxs:
111
- row_image = []
112
- for row_idx in row_idxs:
113
- image = self.fetch_sample_images(
114
- uid, attrs, row_idx, transform
115
- )
116
- row_image.append(image)
117
- row_image = torch.cat(row_image, dim=2) # (c h w)
118
- grid_image.append(row_image)
119
-
120
- grid_image = torch.cat(grid_image, dim=1)
121
-
122
- return grid_image
123
-
124
- def compute_text_embeddings(
125
- self, embed_path: str, original_size: Tuple[int, int]
126
- ) -> Dict[str, nn.Module]:
127
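-         # Load precomputed prompt embeddings and assemble the extra UNet conditioning:
-         # pooled text embeds plus time ids built from the original size, crop offset,
-         # and target size.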
- data_dict = torch.load(embed_path)
128
- prompt_embeds = data_dict["prompt_embeds"][0]
129
- add_text_embeds = data_dict["pooled_prompt_embeds"][0]
130
-
131
-         # Needs changing if random crop is used: set crops_coords_top_left to [y1, x1]; center crop uses [0, 0].  # noqa
132
- crops_coords_top_left = (0, 0)
133
- add_time_ids = list(
134
- original_size + crops_coords_top_left + self.target_hw
135
- )
136
- add_time_ids = torch.tensor([add_time_ids])
137
- # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
138
-
139
- unet_added_cond_kwargs = {
140
- "text_embeds": add_text_embeds,
141
- "time_ids": add_time_ids,
142
- }
143
-
144
- return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
145
-
146
- def visualize_item(
147
- self,
148
- control: torch.Tensor,
149
- color: torch.Tensor,
150
- save_dir: str = None,
151
- ) -> List[Image.Image]:
152
- to_pil = transforms.ToPILImage()
153
-
154
- color = (color + 1) / 2
155
- color_pil = to_pil(color)
156
- normal_pil = to_pil(control[0:3])
157
- position_pil = to_pil(control[3:6])
158
- mask_pil = to_pil(control[6:])
159
-
160
- if save_dir is not None:
161
- os.makedirs(save_dir, exist_ok=True)
162
- color_pil.save(f"{save_dir}/rgb.jpg")
163
- normal_pil.save(f"{save_dir}/normal.jpg")
164
- position_pil.save(f"{save_dir}/position.jpg")
165
- mask_pil.save(f"{save_dir}/mask.jpg")
166
- logger.info(f"Visualization in {save_dir}")
167
-
168
- return normal_pil, position_pil, mask_pil, color_pil
169
-
170
- def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
171
- uid = self.data_list[index]
172
-
173
- sub_idxs = self.sub_idxs
174
- if sub_idxs is None:
175
- sub_idxs = [[random.randint(0, self.image_num - 1)]]
176
-
177
- input_image = self.fetch_sample_grid_images(
178
- uid,
179
- attrs=["image_view_normal", "image_position", "image_mask"],
180
- sub_idxs=sub_idxs,
181
- transform=self.control_transform,
182
- )
183
- assert input_image.shape[1:] == self.target_hw
184
-
185
- output_image = self.fetch_sample_grid_images(
186
- uid,
187
- attrs=["image_color"],
188
- sub_idxs=sub_idxs,
189
- transform=self.transform,
190
- )
191
-
192
- sample = self.meta_info[uid]
193
- text_feats = self.compute_text_embeddings(
194
- sample["text_feat"], tuple(sample["image_hw"])
195
- )
196
-
197
- data = dict(
198
- pixel_values=output_image,
199
- conditioning_pixel_values=input_image,
200
- prompt_embeds=text_feats["prompt_embeds"],
201
- text_embeds=text_feats["text_embeds"],
202
- time_ids=text_feats["time_ids"],
203
- )
204
-
205
- return data
206
-
207
-
208
- if __name__ == "__main__":
209
- index_file = "/horizon-bucket/robot_lab/users/xinjie.wang/datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
210
- target_hw = (512, 512)
211
- transform_list = [
212
- transforms.Resize(
213
- target_hw, interpolation=transforms.InterpolationMode.BILINEAR
214
- ),
215
- transforms.CenterCrop(target_hw),
216
- transforms.ToTensor(),
217
- transforms.Normalize([0.5], [0.5]),
218
- ]
219
- image_transform = transforms.Compose(transform_list)
220
- control_transform = transforms.Compose(transform_list[:-1])
221
-
222
- sub_idxs = [[0, 1, 2], [3, 4, 5]] # None
223
- if sub_idxs is not None:
224
- target_hw = (
225
- target_hw[0] * len(sub_idxs),
226
- target_hw[1] * len(sub_idxs[0]),
227
- )
228
-
229
- dataset = Asset3dGenDataset(
230
- index_file,
231
- target_hw,
232
- image_transform,
233
- control_transform,
234
- sub_idxs=sub_idxs,
235
- )
236
- data = dataset[0]
237
- dataset.visualize_item(
238
- data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
239
- )
 
asset3d_gen/data/differentiable_render.py DELETED
@@ -1,520 +0,0 @@
1
- import argparse
2
- import json
3
- import logging
4
- import math
5
- import os
6
- from collections import defaultdict
7
- from typing import List, Union
8
-
9
- import cv2
10
- import imageio
11
- import numpy as np
12
- import nvdiffrast.torch as dr
13
- import torch
14
- from PIL import Image
15
- from tqdm import tqdm
16
- from asset3d_gen.data.utils import (
17
- CameraSetting,
18
- DiffrastRender,
19
- RenderItems,
20
- as_list,
21
- calc_vertex_normals,
22
- import_kaolin_mesh,
23
- init_kal_camera,
24
- normalize_vertices_array,
25
- render_pbr,
26
- save_images,
27
- )
28
-
29
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
30
- os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
31
- "~/.cache/torch_extensions"
32
- )
33
- logging.basicConfig(
34
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
35
- )
36
- logger = logging.getLogger(__name__)
37
-
38
-
39
- def create_gif_from_images(images, output_path, fps=10):
40
- pil_images = []
41
- for image in images:
42
- image = image.clip(min=0, max=1)
43
- image = (255.0 * image).astype(np.uint8)
44
- image = Image.fromarray(image, mode="RGBA")
45
- pil_images.append(image.convert("RGB"))
46
-
47
- duration = 1000 // fps
48
- pil_images[0].save(
49
- output_path,
50
- save_all=True,
51
- append_images=pil_images[1:],
52
- duration=duration,
53
- loop=0,
54
- )
55
-
56
- logger.info(f"GIF saved to {output_path}")
57
-
58
-
59
- def create_mp4_from_images(images, output_path, fps=10, prompt=None):
60
-     font = cv2.FONT_HERSHEY_SIMPLEX  # font face
61
-     font_scale = 0.5  # font size
62
-     font_thickness = 1  # font thickness
63
-     color = (255, 255, 255)  # text color (white)
64
-     position = (20, 25)  # top-left corner (x, y)
65
-
66
- with imageio.get_writer(output_path, fps=fps) as writer:
67
- for image in images:
68
- image = image.clip(min=0, max=1)
69
- image = (255.0 * image).astype(np.uint8)
70
- image = image[..., :3]
71
- if prompt is not None:
72
- cv2.putText(
73
- image,
74
- prompt,
75
- position,
76
- font,
77
- font_scale,
78
- color,
79
- font_thickness,
80
- )
81
-
82
- writer.append_data(image)
83
-
84
- logger.info(f"MP4 video saved to {output_path}")
85
-
86
-
87
- class ImageRender(object):
88
- def __init__(
89
- self,
90
- render_items: list[RenderItems],
91
- camera_params: CameraSetting,
92
- recompute_vtx_normal: bool = True,
93
- device: str = "cuda",
94
- with_mtl: bool = False,
95
- gen_color_gif: bool = False,
96
- gen_color_mp4: bool = False,
97
- gen_viewnormal_mp4: bool = False,
98
- gen_glonormal_mp4: bool = False,
99
- no_index_file: bool = False,
100
- light_factor: float = 1.0,
101
- ) -> None:
102
- camera_params.device = device
103
- camera = init_kal_camera(camera_params)
104
- self.camera = camera
105
-
106
- # Setup MVP matrix and renderer.
107
- mv = camera.view_matrix() # (n 4 4) world2cam
108
- p = camera.intrinsics.projection_matrix()
109
- # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
110
- p[:, 1, 1] = -p[:, 1, 1]
111
- # mvp = torch.bmm(p, mv) # camera.view_projection_matrix()
112
- self.mv = mv
113
- self.p = p
114
-
115
- renderer = DiffrastRender(
116
- p_matrix=p,
117
- mv_matrix=mv,
118
- resolution_hw=camera_params.resolution_hw,
119
- context=dr.RasterizeCudaContext(),
120
- mask_thresh=0.5,
121
- grad_db=False,
122
- device=camera_params.device,
123
- antialias_mask=True,
124
- )
125
- self.renderer = renderer
126
- self.recompute_vtx_normal = recompute_vtx_normal
127
- self.render_items = render_items
128
- self.device = device
129
- self.with_mtl = with_mtl
130
- self.gen_color_gif = gen_color_gif
131
- self.gen_color_mp4 = gen_color_mp4
132
- self.gen_viewnormal_mp4 = gen_viewnormal_mp4
133
- self.gen_glonormal_mp4 = gen_glonormal_mp4
134
- self.light_factor = light_factor
135
- self.no_index_file = no_index_file
136
-
137
- def render_mesh(
138
- self,
139
- mesh_path: Union[str, List[str]],
140
- output_root: str,
141
- uuid: Union[str, List[str]] = None,
142
- prompts: List[str] = None,
143
- ) -> None:
144
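-         # Render every mesh into its own subdirectory under output_root and, unless
-         # disabled, write an index.json mapping each uuid to its rendered asset paths.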
- mesh_path = as_list(mesh_path)
145
- if uuid is None:
146
- uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
147
- uuid = as_list(uuid)
148
- assert len(mesh_path) == len(uuid)
149
- os.makedirs(output_root, exist_ok=True)
150
-
151
- meta_info = dict()
152
- for idx, (path, uid) in tqdm(
153
- enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
154
- ):
155
- output_dir = os.path.join(output_root, uid)
156
- os.makedirs(output_dir, exist_ok=True)
157
- prompt = prompts[idx] if prompts else None
158
- data_dict = self(path, output_dir, prompt)
159
- meta_info[uid] = data_dict
160
-
161
- if self.no_index_file:
162
- return
163
-
164
- index_file = os.path.join(output_root, "index.json")
165
- with open(index_file, "w") as fout:
166
- json.dump(meta_info, fout)
167
-
168
- logger.info(f"Rendering meta info logged in {index_file}")
169
-
170
- def __call__(
171
- self, mesh_path: str, output_dir: str, prompt: str = None
172
- ) -> dict[str, str]:
173
- try:
174
- mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
175
- except Exception as e:
176
- logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
177
- return
178
-
179
- mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
180
- if self.recompute_vtx_normal:
181
- mesh.vertex_normals = calc_vertex_normals(
182
- mesh.vertices, mesh.faces
183
- )
184
-
185
- mesh = mesh.to(self.device)
186
- vertices, faces, vertex_normals = (
187
- mesh.vertices,
188
- mesh.faces,
189
- mesh.vertex_normals,
190
- )
191
-
192
- # Perform rendering.
193
- data_dict = defaultdict(list)
194
- if RenderItems.ALPHA.value in self.render_items:
195
- masks, _ = self.renderer.render_rast_alpha(vertices, faces)
196
- render_paths = save_images(
197
- masks, f"{output_dir}/{RenderItems.ALPHA}"
198
- )
199
- data_dict[RenderItems.ALPHA.value] = render_paths
200
-
201
- if RenderItems.GLOBAL_NORMAL.value in self.render_items:
202
- rendered_normals, masks = self.renderer.render_global_normal(
203
- vertices, faces, vertex_normals
204
- )
205
- if self.gen_glonormal_mp4:
206
- if isinstance(rendered_normals, torch.Tensor):
207
- rendered_normals = rendered_normals.detach().cpu().numpy()
208
- create_mp4_from_images(
209
- rendered_normals,
210
- output_path=f"{output_dir}/normal.mp4",
211
- fps=15,
212
- prompt=prompt,
213
- )
214
- else:
215
- render_paths = save_images(
216
- rendered_normals,
217
- f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
218
- cvt_color=cv2.COLOR_BGR2RGB,
219
- )
220
- data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
221
-
222
- if RenderItems.VIEW_NORMAL.value in self.render_items:
223
- assert (
224
- RenderItems.GLOBAL_NORMAL in self.render_items
225
- ), f"Must render global normal firstly, got render_items: {self.render_items}." # noqa
226
- rendered_view_normals = self.renderer.transform_normal(
227
- rendered_normals, self.mv, masks, to_view=True
228
- )
229
- # rendered_inv_view_normals = renderer.transform_normal(rendered_view_normals, torch.linalg.inv(mv), masks, to_view=False) # noqa
230
- if self.gen_viewnormal_mp4:
231
- create_mp4_from_images(
232
- rendered_view_normals,
233
- output_path=f"{output_dir}/view_normal.mp4",
234
- fps=15,
235
- prompt=prompt,
236
- )
237
- else:
238
- render_paths = save_images(
239
- rendered_view_normals,
240
- f"{output_dir}/{RenderItems.VIEW_NORMAL}",
241
- cvt_color=cv2.COLOR_BGR2RGB,
242
- )
243
- data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
244
-
245
- if RenderItems.POSITION_MAP.value in self.render_items:
246
- rendered_position, masks = self.renderer.render_position(
247
- vertices, faces
248
- )
249
- norm_position = self.renderer.normalize_map_by_mask(
250
- rendered_position, masks
251
- )
252
- render_paths = save_images(
253
- norm_position,
254
- f"{output_dir}/{RenderItems.POSITION_MAP}",
255
- cvt_color=cv2.COLOR_BGR2RGB,
256
- )
257
- data_dict[RenderItems.POSITION_MAP.value] = render_paths
258
-
259
- if RenderItems.DEPTH.value in self.render_items:
260
- rendered_depth, masks = self.renderer.render_depth(vertices, faces)
261
- norm_depth = self.renderer.normalize_map_by_mask(
262
- rendered_depth, masks
263
- )
264
- render_paths = save_images(
265
- norm_depth,
266
- f"{output_dir}/{RenderItems.DEPTH}",
267
- )
268
- data_dict[RenderItems.DEPTH.value] = render_paths
269
-
270
- render_paths = save_images(
271
- rendered_depth,
272
- f"{output_dir}/{RenderItems.DEPTH}_exr",
273
- to_uint8=False,
274
- format=".exr",
275
- )
276
- data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
277
-
278
- if RenderItems.IMAGE.value in self.render_items:
279
- images = []
280
- albedos = []
281
- diffuses = []
282
- masks, _ = self.renderer.render_rast_alpha(vertices, faces)
283
- try:
284
- for idx, cam in enumerate(self.camera):
285
- image, albedo, diffuse, _ = render_pbr(
286
- mesh, cam, light_factor=self.light_factor
287
- )
288
- image = torch.cat([image[0], masks[idx]], axis=-1)
289
- images.append(image.detach().cpu().numpy())
290
-
291
- if RenderItems.ALBEDO.value in self.render_items:
292
- albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
293
- albedos.append(albedo.detach().cpu().numpy())
294
-
295
- if RenderItems.DIFFUSE.value in self.render_items:
296
- diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
297
- diffuses.append(diffuse.detach().cpu().numpy())
298
-
299
- except Exception as e:
300
- logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
301
- return
302
-
303
- if self.gen_color_gif:
304
- create_gif_from_images(
305
- images,
306
- output_path=f"{output_dir}/color.gif",
307
- fps=15,
308
- )
309
-
310
- if self.gen_color_mp4:
311
- create_mp4_from_images(
312
- images,
313
- output_path=f"{output_dir}/color.mp4",
314
- fps=15,
315
- prompt=prompt,
316
- )
317
-
318
- if self.gen_color_mp4 or self.gen_color_gif:
319
- return data_dict
320
-
321
- render_paths = save_images(
322
- images,
323
- f"{output_dir}/{RenderItems.IMAGE}",
324
- cvt_color=cv2.COLOR_BGRA2RGBA,
325
- )
326
- data_dict[RenderItems.IMAGE.value] = render_paths
327
-
328
- render_paths = save_images(
329
- albedos,
330
- f"{output_dir}/{RenderItems.ALBEDO}",
331
- cvt_color=cv2.COLOR_BGRA2RGBA,
332
- )
333
- data_dict[RenderItems.ALBEDO.value] = render_paths
334
-
335
- render_paths = save_images(
336
- diffuses,
337
- f"{output_dir}/{RenderItems.DIFFUSE}",
338
- cvt_color=cv2.COLOR_BGRA2RGBA,
339
- )
340
- data_dict[RenderItems.DIFFUSE.value] = render_paths
341
-
342
- data_dict["status"] = "success"
343
-
344
- logger.info(f"Finish rendering in {output_dir}")
345
-
346
- return data_dict
347
-
348
-
349
- def parse_args():
350
- parser = argparse.ArgumentParser(description="Render settings")
351
-
352
- parser.add_argument(
353
- "--mesh_path",
354
- type=str,
355
- nargs="+",
356
- required=True,
357
- help="Paths to the mesh files for rendering.",
358
- )
359
- parser.add_argument(
360
- "--output_root",
361
- type=str,
362
- required=True,
363
- help="Root directory for output",
364
- )
365
- parser.add_argument(
366
- "--uuid",
367
- type=str,
368
- nargs="+",
369
- default=None,
370
- help="uuid for rendering saving.",
371
- )
372
- parser.add_argument(
373
- "--num_images", type=int, default=6, help="Number of images to render."
374
- )
375
- parser.add_argument(
376
- "--elevation",
377
- type=float,
378
- nargs="+",
379
- default=[20.0, -10.0],
380
- help="Elevation angles for the camera (default: [20.0, -10.0])",
381
- )
382
- parser.add_argument(
383
- "--distance",
384
- type=float,
385
- default=5,
386
- help="Camera distance (default: 5)",
387
- )
388
- parser.add_argument(
389
- "--resolution_hw",
390
- type=int,
391
- nargs=2,
392
- default=(512, 512),
393
- help="Resolution of the output images (default: (512, 512))",
394
- )
395
- parser.add_argument(
396
- "--fov",
397
- type=float,
398
- default=30,
399
- help="Field of view in degrees (default: 30)",
400
- )
401
- parser.add_argument(
402
- "--pbr_light_factor",
403
- type=float,
404
- default=1.0,
405
- help="Light factor for mesh PBR rendering (default: 2.)",
406
-         help="Light factor for mesh PBR rendering (default: 1.0).",
407
- parser.add_argument(
408
- "--device",
409
- type=str,
410
- choices=["cpu", "cuda"],
411
- default="cuda",
412
- help="Device to run on (default: 'cuda')",
413
- )
414
- parser.add_argument(
415
- "--with_mtl",
416
- action="store_true",
417
- help="Whether to render with mesh material.",
418
- )
419
- parser.add_argument(
420
- "--gen_color_gif",
421
- action="store_true",
422
- help="Whether to generate color .gif rendering file.",
423
- )
424
- parser.add_argument(
425
- "--gen_color_mp4",
426
- action="store_true",
427
- help="Whether to generate color .mp4 rendering file.",
428
- )
429
- parser.add_argument(
430
- "--gen_viewnormal_mp4",
431
- action="store_true",
432
- help="Whether to generate view normal .mp4 rendering file.",
433
- )
434
- parser.add_argument(
435
- "--gen_glonormal_mp4",
436
- action="store_true",
437
- help="Whether to generate global normal .mp4 rendering file.",
438
- )
439
- parser.add_argument(
440
- "--prompts",
441
- type=str,
442
- nargs="+",
443
- default=None,
444
- help="Text prompts for the rendering.",
445
- )
446
-
447
- args = parser.parse_args()
448
-
449
- if args.uuid is None:
450
- args.uuid = []
451
- for path in args.mesh_path:
452
- uuid = os.path.basename(path).split(".")[0]
453
- args.uuid.append(uuid)
454
-
455
- return args
456
-
457
-
458
- def entrypoint() -> None:
459
- args = parse_args()
460
-
461
- camera_settings = CameraSetting(
462
- num_images=args.num_images,
463
- elevation=args.elevation,
464
- distance=args.distance,
465
- resolution_hw=args.resolution_hw,
466
- fov=math.radians(args.fov),
467
- device=args.device,
468
- )
469
-
470
- render_items = [
471
- RenderItems.ALPHA.value,
472
- RenderItems.GLOBAL_NORMAL.value,
473
- RenderItems.VIEW_NORMAL.value,
474
- RenderItems.POSITION_MAP.value,
475
- RenderItems.IMAGE.value,
476
- RenderItems.DEPTH.value,
477
- # RenderItems.ALBEDO.value,
478
- # RenderItems.DIFFUSE.value,
479
- ]
480
-
481
- gen_video = (
482
- args.gen_color_gif
483
- or args.gen_color_mp4
484
- or args.gen_viewnormal_mp4
485
- or args.gen_glonormal_mp4
486
- )
487
- if gen_video:
488
- render_items = []
489
- if args.gen_color_gif or args.gen_color_mp4:
490
- render_items.append(RenderItems.IMAGE.value)
491
- if args.gen_glonormal_mp4:
492
- render_items.append(RenderItems.GLOBAL_NORMAL.value)
493
- if args.gen_viewnormal_mp4:
494
- render_items.append(RenderItems.VIEW_NORMAL.value)
495
- if RenderItems.GLOBAL_NORMAL.value not in render_items:
496
- render_items.append(RenderItems.GLOBAL_NORMAL.value)
497
-
498
- image_render = ImageRender(
499
- render_items=render_items,
500
- camera_params=camera_settings,
501
- with_mtl=args.with_mtl,
502
- gen_color_gif=args.gen_color_gif,
503
- gen_color_mp4=args.gen_color_mp4,
504
- gen_viewnormal_mp4=args.gen_viewnormal_mp4,
505
- gen_glonormal_mp4=args.gen_glonormal_mp4,
506
- light_factor=args.pbr_light_factor,
507
- no_index_file=gen_video,
508
- )
509
- image_render.render_mesh(
510
- mesh_path=args.mesh_path,
511
- output_root=args.output_root,
512
- uuid=args.uuid,
513
- prompts=args.prompts,
514
- )
515
-
516
- return
517
-
518
-
519
- if __name__ == "__main__":
520
- entrypoint()
 
asset3d_gen/data/mesh_operator.py DELETED
@@ -1,425 +0,0 @@
1
- import logging
2
- from typing import Tuple, Union
3
-
4
- import igraph
5
- import numpy as np
6
- import pyvista as pv
7
- import torch
8
- import utils3d
9
- from pymeshfix import _meshfix
10
- from tqdm import tqdm
11
-
12
- logging.basicConfig(
13
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
14
- )
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- __all__ = ["MeshFixer"]
19
-
20
-
21
- def radical_inverse(base, n):
22
- val = 0
23
- inv_base = 1.0 / base
24
- inv_base_n = inv_base
25
- while n > 0:
26
- digit = n % base
27
- val += digit * inv_base_n
28
- n //= base
29
- inv_base_n *= inv_base
30
- return val
31
-
32
-
33
- def halton_sequence(dim, n):
34
- PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
35
- return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
36
-
37
-
38
- def hammersley_sequence(dim, n, num_samples):
39
- return [n / num_samples] + halton_sequence(dim - 1, n)
40
-
41
-
42
- def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
43
- """Generate a point on a unit sphere using the Hammersley sequence.
44
-
45
- Args:
46
- n (int): The index of the sample.
47
- num_samples (int): The total number of samples.
48
- offset (tuple, optional): Offset for the u and v coordinates.
49
- remap (bool, optional): Whether to remap the u coordinate.
50
-
51
- Returns:
52
- list: A list containing the spherical coordinates [phi, theta].
53
- """
54
- u, v = hammersley_sequence(2, n, num_samples)
55
- u += offset[0] / num_samples
56
- v += offset[1]
57
-
58
- if remap:
59
- u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
60
-
61
- theta = np.arccos(1 - 2 * u) - np.pi / 2
62
- phi = v * 2 * np.pi
63
- return [phi, theta]
64
-
65
-
66
- class MeshFixer(object):
67
- """Reduce and postprocess 3D meshes, simplifying and filling holes."""
68
-
69
- def __init__(
70
- self,
71
- vertices: Union[torch.Tensor, np.ndarray],
72
- faces: Union[torch.Tensor, np.ndarray],
73
- device: str = "cuda",
74
- ) -> None:
75
- self.device = device
76
- self.vertices = (
77
- torch.tensor(vertices, device=device)
78
- if isinstance(vertices, np.ndarray)
79
- else vertices.to(device)
80
- )
81
- self.faces = (
82
- torch.tensor(faces.astype(np.int32), device=device)
83
- if isinstance(faces, np.ndarray)
84
- else faces.to(device)
85
- )
86
-
87
- @staticmethod
88
- def log_mesh_changes(method):
89
- def wrapper(self, *args, **kwargs):
90
- logger.info(
91
- f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
92
- )
93
- result = method(self, *args, **kwargs)
94
- logger.info(
95
- f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
96
- )
97
- return result
98
-
99
- return wrapper
100
-
101
- @log_mesh_changes
102
- def fill_holes(
103
- self,
104
- max_hole_size: float,
105
- max_hole_nbe: int,
106
- resolution: int,
107
- num_views: int,
108
- norm_mesh_ratio: float = 1.0,
109
- ) -> None:
110
- self.vertices = self.vertices * norm_mesh_ratio
111
- vertices, self.faces = self._fill_holes(
112
- self.vertices,
113
- self.faces,
114
- max_hole_size,
115
- max_hole_nbe,
116
- resolution,
117
- num_views,
118
- )
119
- self.vertices = vertices / norm_mesh_ratio
120
-
121
- @staticmethod
122
- @torch.no_grad()
123
- def _fill_holes(
124
- vertices: torch.Tensor,
125
- faces: torch.Tensor,
126
- max_hole_size: float,
127
- max_hole_nbe: int,
128
- resolution: int,
129
- num_views: int,
130
- ) -> Tuple[torch.Tensor, torch.Tensor]:
131
- yaws, pitchs = [], []
132
- for i in range(num_views):
133
- y, p = sphere_hammersley_sequence(i, num_views)
134
- yaws.append(y)
135
- pitchs.append(p)
136
-
137
- yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
138
- pitchs
139
- ).to(vertices)
140
- radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
141
- projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
142
-
143
- views = []
144
- for yaw, pitch in zip(yaws, pitchs):
145
- orig = (
146
- torch.tensor(
147
- [
148
- torch.sin(yaw) * torch.cos(pitch),
149
- torch.cos(yaw) * torch.cos(pitch),
150
- torch.sin(pitch),
151
- ]
152
- ).to(vertices)
153
- * radius
154
- )
155
- view = utils3d.torch.view_look_at(
156
- orig,
157
- torch.tensor([0, 0, 0]).to(vertices),
158
- torch.tensor([0, 0, 1]).to(vertices),
159
- )
160
- views.append(view)
161
- views = torch.stack(views, dim=0)
162
-
163
- # Rasterize the mesh
164
- visibility = torch.zeros(
165
- faces.shape[0], dtype=torch.int32, device=faces.device
166
- )
167
- rastctx = utils3d.torch.RastContext(backend="cuda")
168
-
169
- for i in tqdm(
170
- range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
171
- ):
172
- view = views[i]
173
- buffers = utils3d.torch.rasterize_triangle_faces(
174
- rastctx,
175
- vertices[None],
176
- faces,
177
- resolution,
178
- resolution,
179
- view=view,
180
- projection=projection,
181
- )
182
- face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
183
- face_id = torch.unique(face_id).long()
184
- visibility[face_id] += 1
185
-
186
- # Normalize visibility by the number of views
187
- visibility = visibility.float() / num_views
188
-
189
- # Mincut: Identify outer and inner faces
190
- edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
191
- boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
192
- connected_components = utils3d.torch.compute_connected_components(
193
- faces, edges, face2edge
194
- )
195
-
196
- outer_face_indices = torch.zeros(
197
- faces.shape[0], dtype=torch.bool, device=faces.device
198
- )
199
- for i in range(len(connected_components)):
200
- outer_face_indices[connected_components[i]] = visibility[
201
- connected_components[i]
202
- ] > min(
203
- max(
204
- visibility[connected_components[i]].quantile(0.75).item(),
205
- 0.25,
206
- ),
207
- 0.5,
208
- )
209
-
210
- outer_face_indices = outer_face_indices.nonzero().reshape(-1)
211
- inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
212
-
213
- if inner_face_indices.shape[0] == 0:
214
- return vertices, faces
215
-
216
- # Construct dual graph (faces as nodes, edges as edges)
217
- dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
218
- face2edge
219
- )
220
- dual_edge2edge = edges[dual_edge2edge]
221
- dual_edges_weights = torch.norm(
222
- vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
223
- dim=1,
224
- )
225
-
226
- # Mincut: Construct main graph and solve the mincut problem
227
- g = igraph.Graph()
228
- g.add_vertices(faces.shape[0])
229
- g.add_edges(dual_edges.cpu().numpy())
230
- g.es["weight"] = dual_edges_weights.cpu().numpy()
231
-
232
- g.add_vertex("s") # source
233
- g.add_vertex("t") # target
234
-
235
- g.add_edges(
236
- [(f, "s") for f in inner_face_indices],
237
- attributes={
238
- "weight": torch.ones(
239
- inner_face_indices.shape[0], dtype=torch.float32
240
- )
241
- .cpu()
242
- .numpy()
243
- },
244
- )
245
- g.add_edges(
246
- [(f, "t") for f in outer_face_indices],
247
- attributes={
248
- "weight": torch.ones(
249
- outer_face_indices.shape[0], dtype=torch.float32
250
- )
251
- .cpu()
252
- .numpy()
253
- },
254
- )
255
-
256
- cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
257
- remove_face_indices = torch.tensor(
258
- [v for v in cut.partition[0] if v < faces.shape[0]],
259
- dtype=torch.long,
260
- device=faces.device,
261
- )
262
-
263
- # Check if the cut is valid with each connected component
264
- to_remove_cc = utils3d.torch.compute_connected_components(
265
- faces[remove_face_indices]
266
- )
267
- valid_remove_cc = []
268
- cutting_edges = []
269
- for cc in to_remove_cc:
270
- # Check visibility median for connected component
271
- visibility_median = visibility[remove_face_indices[cc]].median()
272
- if visibility_median > 0.25:
273
- continue
274
-
275
- # Check if the cutting loop is small enough
276
- cc_edge_indices, cc_edges_degree = torch.unique(
277
- face2edge[remove_face_indices[cc]], return_counts=True
278
- )
279
- cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
280
- cc_new_boundary_edge_indices = cc_boundary_edge_indices[
281
- ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
282
- ]
283
- if len(cc_new_boundary_edge_indices) > 0:
284
- cc_new_boundary_edge_cc = (
285
- utils3d.torch.compute_edge_connected_components(
286
- edges[cc_new_boundary_edge_indices]
287
- )
288
- )
289
- cc_new_boundary_edges_cc_center = [
290
- vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
291
- .mean(dim=1)
292
- .mean(dim=0)
293
- for edge_cc in cc_new_boundary_edge_cc
294
- ]
295
- cc_new_boundary_edges_cc_area = []
296
- for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
297
- _e1 = (
298
- vertices[
299
- edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
300
- ]
301
- - cc_new_boundary_edges_cc_center[i]
302
- )
303
- _e2 = (
304
- vertices[
305
- edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
306
- ]
307
- - cc_new_boundary_edges_cc_center[i]
308
- )
309
- cc_new_boundary_edges_cc_area.append(
310
- torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
311
- * 0.5
312
- )
313
- cutting_edges.append(cc_new_boundary_edge_indices)
314
- if any(
315
- [
316
- _l > max_hole_size
317
- for _l in cc_new_boundary_edges_cc_area
318
- ]
319
- ):
320
- continue
321
-
322
- valid_remove_cc.append(cc)
323
-
324
- if len(valid_remove_cc) > 0:
325
- remove_face_indices = remove_face_indices[
326
- torch.cat(valid_remove_cc)
327
- ]
328
- mask = torch.ones(
329
- faces.shape[0], dtype=torch.bool, device=faces.device
330
- )
331
- mask[remove_face_indices] = 0
332
- faces = faces[mask]
333
- faces, vertices = utils3d.torch.remove_unreferenced_vertices(
334
- faces, vertices
335
- )
336
-
337
- tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
338
- else:
339
- tqdm.write(f"Removed 0 faces by mincut")
340
-
341
- # Fill small boundaries (holes)
342
- mesh = _meshfix.PyTMesh()
343
- mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
344
- mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
345
-
346
- _vertices, _faces = mesh.return_arrays()
347
- vertices = torch.tensor(_vertices).to(vertices)
348
- faces = torch.tensor(_faces).to(faces)
349
-
350
- return vertices, faces
351
-
352
- @property
353
- def vertices_np(self) -> np.ndarray:
354
- return self.vertices.cpu().numpy()
355
-
356
- @property
357
- def faces_np(self) -> np.ndarray:
358
- return self.faces.cpu().numpy()
359
-
360
- @log_mesh_changes
361
- def simplify(self, ratio: float) -> None:
362
- """Simplify the mesh using quadric edge collapse decimation.
363
-
364
- Args:
365
- ratio (float): Ratio of faces to filter out.
366
- """
367
- if ratio <= 0 or ratio >= 1:
368
- raise ValueError("Simplify ratio must be between 0 and 1.")
369
-
370
- # Convert to PyVista format for simplification
371
- mesh = pv.PolyData(
372
- self.vertices_np,
373
- np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
374
- )
375
- mesh = mesh.decimate(ratio, progress_bar=True)
376
-
377
- # Update vertices and faces
378
- self.vertices = torch.tensor(
379
- mesh.points, device=self.device, dtype=torch.float32
380
- )
381
- self.faces = torch.tensor(
382
- mesh.faces.reshape(-1, 4)[:, 1:],
383
- device=self.device,
384
- dtype=torch.int32,
385
- )
386
-
387
- def __call__(
388
- self,
389
- filter_ratio: float,
390
- max_hole_size: float,
391
- resolution: int,
392
- num_views: int,
393
- norm_mesh_ratio: float = 1.0,
394
- ) -> Tuple[np.ndarray, np.ndarray]:
395
- """Post-process the mesh by simplifying and filling holes.
396
-
397
- This method performs a two-step process:
398
- 1. Simplifies mesh by reducing faces using quadric edge decimation.
399
- 2. Fills holes by removing invisible faces, repairing small boundaries.
400
-
401
- Args:
402
- filter_ratio (float): Ratio of faces to simplify out.
403
- Must be in the range (0, 1).
404
- max_hole_size (float): Maximum area of a hole to fill. Connected
405
- components of holes larger than this size will not be repaired.
406
- resolution (int): Resolution of the rasterization buffer.
407
- num_views (int): Number of viewpoints to sample for rasterization.
408
- norm_mesh_ratio (float, optional): A scaling factor applied to the
409
- vertices of the mesh during processing.
410
-
411
- Returns:
412
- Tuple[np.ndarray, np.ndarray]:
413
- - vertices: Simplified and repaired vertex array of (V, 3).
414
- - faces: Simplified and repaired face array of (F, 3).
415
- """
416
- self.simplify(ratio=filter_ratio)
417
- self.fill_holes(
418
- max_hole_size=max_hole_size,
419
- max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
420
- resolution=resolution,
421
- num_views=num_views,
422
- norm_mesh_ratio=norm_mesh_ratio,
423
- )
424
-
425
- return self.vertices_np, self.faces_np
 
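Note: a hedged usage sketch (not part of this commit) for the MeshFixer removed above; the input path and parameter values are illustrative assumptions.

import trimesh
from asset3d_gen.data.mesh_operator import MeshFixer

mesh = trimesh.load("raw_mesh.obj")  # hypothetical input mesh
fixer = MeshFixer(mesh.vertices, mesh.faces, device="cuda")
vertices, faces = fixer(
    filter_ratio=0.5,     # decimate half of the faces
    max_hole_size=0.04,   # holes larger than this area are left open
    resolution=1024,      # rasterization buffer resolution
    num_views=1000,       # viewpoints used to find invisible faces
)
trimesh.Trimesh(vertices, faces).export("fixed_mesh.obj")
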
asset3d_gen/data/utils.py DELETED
@@ -1,943 +0,0 @@
1
- import math
2
- import os
3
- import random
4
- from glob import glob
5
- from typing import List, Tuple, Union
6
-
7
- import cv2
8
- import kaolin as kal
9
- import numpy as np
10
- import nvdiffrast.torch as dr
11
- import torch
12
- import torch.nn.functional as F
13
- from PIL import Image
14
-
15
- try:
16
- from kolors.models.modeling_chatglm import ChatGLMModel
17
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
18
- except ImportError:
19
- ChatGLMTokenizer = None
20
- ChatGLMModel = None
21
- import logging
22
- from dataclasses import dataclass, field
23
- from enum import Enum
24
-
25
- import trimesh
26
- from kaolin.render.camera import Camera
27
- from torch import nn
28
-
29
- logger = logging.getLogger(__name__)
30
-
31
-
32
- __all__ = [
33
- "center_points",
34
- "get_points_stat",
35
- "DiffrastRender",
36
- "compute_cam_pts_by_az_el",
37
- "compute_cam_pts_by_views",
38
- "save_images",
39
- "render_pbr",
40
- "load_llm_models",
41
- "prelabel_text_feature",
42
- "calc_vertex_normals",
43
- "normalize_vertices_array",
44
- "load_mesh_to_unit_cube",
45
- "as_list",
46
- "CameraSetting",
47
- "RenderItems",
48
- "import_kaolin_mesh",
49
- "save_mesh_with_mtl",
50
- "get_images_from_grid",
51
- "post_process_texture",
52
- ]
53
-
54
-
55
- def get_points_stat(
56
- points: torch.FloatTensor, eps: float = 1e-6
57
- ) -> torch.FloatTensor:
58
- assert (
59
- len(points.shape) == 3
60
- ), f"Points have unexpected shape {points.shape}"
61
-
62
- vmin = points.min(dim=1, keepdim=True)[0]
63
- vmax = points.max(dim=1, keepdim=True)[0]
64
- pts_center = (vmin + vmax) / 2
65
-
66
- pts_dim = (vmax - vmin).max(dim=-1, keepdim=True)[0].clip(min=eps)
67
-
68
- return pts_center, pts_dim
69
-
70
-
71
- def center_points(
72
- points: torch.FloatTensor, normalize: bool = False, eps: float = 1e-6
73
- ) -> torch.FloatTensor:
74
- vmid, den = get_points_stat(points)
75
-
76
- res = points - vmid
77
-
78
- if normalize:
79
- res = res / den
80
-
81
- return res
82
-
83
-
84
- class DiffrastRender(object):
85
- """A class to handle differentiable rendering using nvdiffrast.
86
-
87
- This class provides methods to render position, depth, and normal maps
88
- with optional anti-aliasing and gradient disabling for rasterization.
89
-
90
- Attributes:
91
- p_mtx (torch.Tensor): Projection matrix.
92
- mv_mtx (torch.Tensor): Model-view matrix.
93
- mvp_mtx (torch.Tensor): Model-view-projection matrix, calculated as
94
- p_mtx @ mv_mtx if not provided.
95
- resolution_hw (Tuple[int, int]): Height and width of the rendering resolution. # noqa
96
- _ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): Rasterization context. # noqa
97
- mask_thresh (float): Threshold for mask creation.
98
- grad_db (bool): Whether to disable gradients during rasterization.
99
- antialias_mask (bool): Whether to apply anti-aliasing to the mask.
100
- device (str): Device used for rendering ('cuda' or 'cpu').
101
-
102
- Methods:
103
- _warmup(glctx): Warmup rasterization by rendering a simple triangle.
104
- compute_dr_raster(vertices, faces): Rasterizes the mesh and returns
105
- rasterized outputs and transformed vertices.
106
- transform_vertices(vertices, matrix): Transforms the vertices using
107
- the provided transformation matrix.
108
- normalize_map_by_mask_separately(map, mask): Normalizes each map in
109
- the batch separately using the mask.
110
- normalize_map_by_mask(map, mask): Normalizes the entire map using the
111
- mask, keeping the output in the range [0, 1].
112
- render_position(vertices, faces): Renders the position map and
113
- alpha mask from the given vertices and faces.
114
- render_depth(vertices, faces): Renders the depth map and alpha
115
- mask from the given vertices and faces.
116
- _compute_mask(rast, vertices_clip, faces): Computes the mask from the
117
- rasterization output.
118
- render_global_normal(vertices, faces, vertice_normals): Renders the
119
- normal map and alpha mask from the given vertices, faces, and
120
- vertex normals.
121
- transform_normal(normals, trans_matrix, masks, to_view): Transforms the
122
- normals between global and view space using the given transform matrix.
123
- """
124
-
125
- def __init__(
126
- self,
127
- p_matrix: torch.Tensor,
128
- mv_matrix: torch.Tensor,
129
- resolution_hw: Tuple[int, int],
130
- context: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None,
131
- mvp_matrix: torch.Tensor = None,
132
- mask_thresh: float = 0.5,
133
- grad_db: bool = False,
134
- antialias_mask: bool = True,
135
- align_coordinate: bool = True,
136
- device: str = "cuda",
137
- ) -> None:
138
- self.p_mtx = p_matrix
139
- self.mv_mtx = mv_matrix
140
- if mvp_matrix is None:
141
- self.mvp_mtx = torch.bmm(p_matrix, mv_matrix)
142
-
143
- self.resolution_hw = resolution_hw
144
- if context is None:
145
- context = dr.RasterizeCudaContext(device=device)
146
- self._ctx = context
147
- self.mask_thresh = mask_thresh
148
- self.grad_db = grad_db
149
- self.antialias_mask = antialias_mask
150
- self.align_coordinate = align_coordinate
151
- self.device = device
152
- # self._warmup(self._ctx)
153
-
154
- def _warmup(self, glctx):
155
- # Seem solved. https://github.com/NVlabs/nvdiffrast/issues/59
156
- def tensor(*args, **kwargs):
157
- return torch.tensor(*args, device=self.device, **kwargs)
158
-
159
- pos = tensor(
160
- [[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]],
161
- dtype=torch.float32,
162
- )
163
- tri = tensor([[0, 1, 2]], dtype=torch.int32)
164
- dr.rasterize(glctx, pos, tri, resolution=[256, 256])
165
-
166
- def compute_dr_raster(
167
- self,
168
- vertices: torch.Tensor,
169
- faces: torch.Tensor,
170
- ) -> Tuple[torch.Tensor, torch.Tensor]:
171
- vertices_clip = self.transform_vertices(vertices, matrix=self.mvp_mtx)
172
- rast, _ = dr.rasterize(
173
- self._ctx,
174
- vertices_clip,
175
- faces.int(),
176
- resolution=self.resolution_hw,
177
- grad_db=self.grad_db,
178
- )
179
-
180
- return rast, vertices_clip
181
-
182
- def transform_vertices(
183
- self,
184
- vertices: torch.Tensor,
185
- matrix: torch.Tensor,
186
- ) -> torch.Tensor:
187
- verts_ones = torch.ones((len(vertices), 1)).to(vertices)
188
- verts_homo = torch.cat([vertices, verts_ones], dim=-1)
189
- trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))
190
-
191
- return trans_vertices
192
-
193
- def normalize_map_by_mask_separately(
194
- self, map: torch.Tensor, mask: torch.Tensor
195
- ) -> torch.Tensor:
196
- # Normalize each map separately by mask, normalized map in [0, 1].
197
- normalized_maps = []
198
- for map_item, mask_item in zip(map, mask):
199
- normalized_map = self.normalize_map_by_mask(map_item, mask_item)
200
- normalized_maps.append(normalized_map)
201
-
202
- normalized_maps = torch.stack(normalized_maps, dim=0)
203
-
204
- return normalized_maps
205
-
206
- def normalize_map_by_mask(
207
- self, map: torch.Tensor, mask: torch.Tensor
208
- ) -> torch.Tensor:
209
- # Normalize all maps in total by mask, normalized map in [0, 1].
210
- foreground = (mask == 1).squeeze(dim=-1)
211
- foreground_elements = map[foreground]
212
- if len(foreground_elements) == 0:
213
- return map
214
-
215
- min_val, _ = foreground_elements.min(dim=0)
216
- max_val, _ = foreground_elements.max(dim=0)
217
- val_range = (max_val - min_val).clip(min=1e-6)
218
-
219
- normalized_map = (map - min_val) / val_range
220
- normalized_map = torch.lerp(
221
- torch.zeros_like(normalized_map), normalized_map, mask
222
- )
223
- normalized_map[normalized_map < 0] = 0
224
-
225
- return normalized_map
226
-
227
- def _compute_mask(
228
- self,
229
- rast: torch.Tensor,
230
- vertices_clip: torch.Tensor,
231
- faces: torch.Tensor,
232
- ) -> torch.Tensor:
233
- mask = (rast[..., 3:] > 0).float()
234
- mask = mask.clip(min=0, max=1)
235
-
236
- if self.antialias_mask is True:
237
- mask = dr.antialias(mask, rast, vertices_clip, faces)
238
- else:
239
- foreground = mask > self.mask_thresh
240
- mask[foreground] = 1
241
- mask[~foreground] = 0
242
-
243
- return mask
244
-
245
- def render_rast_alpha(
246
- self,
247
- vertices: torch.Tensor,
248
- faces: torch.Tensor,
249
- ):
250
- faces = faces.to(torch.int32)
251
- rast, vertices_clip = self.compute_dr_raster(vertices, faces)
252
- mask = self._compute_mask(rast, vertices_clip, faces)
253
-
254
- return mask, rast
255
-
256
- def render_position(
257
- self,
258
- vertices: torch.Tensor,
259
- faces: torch.Tensor,
260
- ) -> Tuple[torch.Tensor, torch.Tensor]:
261
- # Vertices in model coordinate system, real position coordinate number.
262
- faces = faces.to(torch.int32)
263
- mask, rast = self.render_rast_alpha(vertices, faces)
264
-
265
- vertices_model = vertices[None, ...].contiguous().float()
266
- position_map, _ = dr.interpolate(vertices_model, rast, faces)
267
- # Align with blender.
268
- if self.align_coordinate:
269
- position_map = position_map[..., [0, 2, 1]]
270
- position_map[..., 1] = -position_map[..., 1]
271
-
272
- position_map = torch.lerp(
273
- torch.zeros_like(position_map), position_map, mask
274
- )
275
-
276
- return position_map, mask
277
-
278
- def render_uv(
279
- self,
280
- vertices: torch.Tensor,
281
- faces: torch.Tensor,
282
- vtx_uv: torch.Tensor,
283
- ) -> Tuple[torch.Tensor, torch.Tensor]:
284
- faces = faces.to(torch.int32)
285
- mask, rast = self.render_rast_alpha(vertices, faces)
286
- uv_map, _ = dr.interpolate(vtx_uv, rast, faces)
287
- uv_map = torch.lerp(torch.zeros_like(uv_map), uv_map, mask)
288
-
289
- return uv_map, mask
290
-
291
- def render_depth(
292
- self,
293
- vertices: torch.Tensor,
294
- faces: torch.Tensor,
295
- ) -> Tuple[torch.Tensor, torch.Tensor]:
296
- # Vertices in model coordinate system, real depth coordinate number.
297
- faces = faces.to(torch.int32)
298
- mask, rast = self.render_rast_alpha(vertices, faces)
299
-
300
- vertices_camera = self.transform_vertices(vertices, matrix=self.mv_mtx)
301
- vertices_camera = vertices_camera[..., 2:3].contiguous().float()
302
- depth_map, _ = dr.interpolate(vertices_camera, rast, faces)
303
- # Flip the sign so camera depth becomes positive.
304
- if self.align_coordinate:
305
- depth_map = -depth_map
306
- depth_map = torch.lerp(torch.zeros_like(depth_map), depth_map, mask)
307
-
308
- return depth_map, mask
309
-
310
- def render_global_normal(
311
- self,
312
- vertices: torch.Tensor,
313
- faces: torch.Tensor,
314
- vertice_normals: torch.Tensor,
315
- ) -> Tuple[torch.Tensor, torch.Tensor]:
316
- # NOTE: vertice_normals in [-1, 1], return normal in [0, 1].
317
- # vertices / vertice_normals in model coordinate system.
318
- faces = faces.to(torch.int32)
319
- mask, rast = self.render_rast_alpha(vertices, faces)
320
- im_base_normals, _ = dr.interpolate(
321
- vertice_normals[None, ...].float(), rast, faces
322
- )
323
-
324
- if im_base_normals is not None:
325
- faces = faces.to(torch.int64)
326
- vertices_cam = self.transform_vertices(
327
- vertices, matrix=self.mv_mtx
328
- )
329
- face_vertices_ndc = kal.ops.mesh.index_vertices_by_faces(
330
- vertices_cam[..., :3], faces
331
- )
332
- face_normal_sign = kal.ops.mesh.face_normals(face_vertices_ndc)[
333
- ..., 2
334
- ]
335
- for idx in range(len(im_base_normals)):
336
- face_idx = (rast[idx, ..., -1].long() - 1).contiguous()
337
- im_normal_sign = torch.sign(face_normal_sign[idx, face_idx])
338
- im_normal_sign[face_idx == -1] = 0
339
- im_base_normals[idx] *= im_normal_sign.unsqueeze(-1)
340
-
341
- normal = (im_base_normals + 1) / 2
342
- normal = normal.clip(min=0, max=1)
343
- normal = torch.lerp(torch.zeros_like(normal), normal, mask)
344
-
345
- return normal, mask
346
-
347
- def transform_normal(
348
- self,
349
- normals: torch.Tensor,
350
- trans_matrix: torch.Tensor,
351
- masks: torch.Tensor,
352
- to_view: bool,
353
- ) -> torch.Tensor:
354
- # NOTE: input normals in [0, 1], output normals in [0, 1].
355
- normals = normals.clone()
356
- assert len(normals) == len(trans_matrix)
357
-
358
- if not to_view:
359
- # Flip the sign on the x-axis to match inv bae system for global transformation. # noqa
360
- normals[..., 0] = 1 - normals[..., 0]
361
-
362
- normals = 2 * normals - 1
363
- b, h, w, c = normals.shape
364
-
365
- transformed_normals = []
366
- for normal, matrix in zip(normals, trans_matrix):
367
- # Transform normals using the transformation matrix (4x4).
368
- reshaped_normals = normal.view(-1, c) # (h w 3) -> (hw 3)
369
- padded_vectors = torch.nn.functional.pad(
370
- reshaped_normals, pad=(0, 1), mode="constant", value=0.0
371
- )
372
- transformed_normal = torch.matmul(
373
- padded_vectors, matrix.transpose(0, 1)
374
- )[..., :3]
375
-
376
- # Normalize and clip the normals to [0, 1] range.
377
- transformed_normal = F.normalize(transformed_normal, p=2, dim=-1)
378
- transformed_normal = (transformed_normal + 1) / 2
379
-
380
- if to_view:
381
- # Flip the sign on the x-axis to match bae system for view transformation. # noqa
382
- transformed_normal[..., 0] = 1 - transformed_normal[..., 0]
383
-
384
- transformed_normals.append(transformed_normal.view(h, w, c))
385
-
386
- transformed_normals = torch.stack(transformed_normals, dim=0)
387
-
388
- if masks is not None:
389
- transformed_normals = torch.lerp(
390
- torch.zeros_like(transformed_normals),
391
- transformed_normals,
392
- masks,
393
- )
394
-
395
- return transformed_normals
396
-
397
-
398
- def az_el_to_points(
399
- azimuths: np.ndarray, elevations: np.ndarray
400
- ) -> np.ndarray:
401
- x = np.cos(azimuths) * np.cos(elevations)
402
- y = np.sin(azimuths) * np.cos(elevations)
403
- z = np.sin(elevations)
404
-
405
- return np.stack([x, y, z], axis=-1)
406
-
407
-
408
- def compute_az_el_by_views(
409
- num_view: int, el: float
410
- ) -> Tuple[np.ndarray, np.ndarray]:
411
- azimuths = np.arange(num_view) / num_view * np.pi * 2
412
- elevations = np.deg2rad(np.array([el] * num_view))
413
-
414
- return azimuths, elevations
415
-
416
-
417
- def compute_cam_pts_by_az_el(
418
- azs: np.ndarray,
419
- els: np.ndarray,
420
- distance: float,
421
- extra_pts: np.ndarray = None,
422
- ) -> np.ndarray:
423
- distances = np.array([distance for _ in range(len(azs))])
424
- cam_pts = az_el_to_points(azs, els) * distances[:, None]
425
-
426
- if extra_pts is not None:
427
- cam_pts = np.concatenate([cam_pts, extra_pts], axis=0)
428
-
429
- # Align coordinate system.
430
- cam_pts = cam_pts[:, [0, 2, 1]] # xyz -> xzy
431
- cam_pts[..., 2] = -cam_pts[..., 2]
432
-
433
- return cam_pts
434
-
435
-
436
- def compute_cam_pts_by_views(
437
- num_view: int, el: float, distance: float, extra_pts: np.ndarray = None
438
- ) -> torch.Tensor:
439
- """Computes object-center camera points for a given number of views.
440
-
441
- Args:
442
- num_view (int): The number of views (camera positions) to compute.
443
- el (float): The elevation angle in degrees.
444
- distance (float): The distance from the origin to the camera.
445
- extra_pts (np.ndarray): Extra camera points position.
446
-
447
- Returns:
448
- torch.Tensor: A tensor containing the camera points for each view, with shape `(num_view, 3)`. # noqa
449
- """
450
- azimuths, elevations = compute_az_el_by_views(num_view, el)
451
- cam_pts = compute_cam_pts_by_az_el(
452
- azimuths, elevations, distance, extra_pts
453
- )
454
-
455
- return cam_pts
456
-
457
-
458
- def save_images(
459
- images: Union[list[np.ndarray], list[torch.Tensor]],
460
- output_dir: str,
461
- cvt_color: str = None,
462
- format: str = ".png",
463
- to_uint8: bool = True,
464
- verbose: bool = False,
465
- ) -> List[str]:
466
- # NOTE: images in [0, 1]
467
- os.makedirs(output_dir, exist_ok=True)
468
- save_paths = []
469
- for idx, image in enumerate(images):
470
- if isinstance(image, torch.Tensor):
471
- image = image.detach().cpu().numpy()
472
- if to_uint8:
473
- image = image.clip(min=0, max=1)
474
- image = (255.0 * image).astype(np.uint8)
475
- if cvt_color is not None:
476
- image = cv2.cvtColor(image, cvt_color)
477
- save_path = os.path.join(output_dir, f"{idx:04d}{format}")
478
- save_paths.append(save_path)
479
-
480
- cv2.imwrite(save_path, image)
481
-
482
- if verbose:
483
- logger.info(f"Images saved in {output_dir}")
484
-
485
- return save_paths
486
-
487
-
488
- def current_lighting(
489
- azimuths: List[float],
490
- elevations: List[float],
491
- light_factor: float = 1.0,
492
- device: str = "cuda",
493
- ):
494
- # azimuths, elevations in degrees.
495
- directions = []
496
- for az, el in zip(azimuths, elevations):
497
- az, el = math.radians(az), math.radians(el)
498
- direction = kal.render.lighting.sg_direction_from_azimuth_elevation(
499
- az, el
500
- )
501
- directions.append(direction)
502
- directions = torch.cat(directions, dim=0)
503
-
504
- amplitude = torch.ones_like(directions) * light_factor
505
- light_condition = kal.render.lighting.SgLightingParameters(
506
- amplitude=amplitude,
507
- direction=directions,
508
- sharpness=3,
509
- ).to(device)
510
-
511
- # light_condition = kal.render.lighting.SgLightingParameters.from_sun(
512
- # directions, strength=1, angle=90, color=None
513
- # ).to(device)
514
-
515
- return light_condition
516
-
517
-
518
- def render_pbr(
519
- mesh,
520
- camera,
521
- device="cuda",
522
- cxt=None,
523
- custom_materials=None,
524
- light_factor=1.0,
525
- ):
526
- if cxt is None:
527
- cxt = dr.RasterizeCudaContext()
528
-
529
- light_condition = current_lighting(
530
- azimuths=[0, 90, 180, 270],
531
- elevations=[90, 60, 30, 20],
532
- light_factor=light_factor,
533
- device=device,
534
- )
535
- render_res = kal.render.easy_render.render_mesh(
536
- camera,
537
- mesh,
538
- lighting=light_condition,
539
- nvdiffrast_context=cxt,
540
- custom_materials=custom_materials,
541
- )
542
-
543
- image = render_res[kal.render.easy_render.RenderPass.render]
544
- image = image.clip(0, 1)
545
-
546
- albedo = render_res[kal.render.easy_render.RenderPass.albedo]
547
- albedo = albedo.clip(0, 1)
548
-
549
- diffuse = render_res[kal.render.easy_render.RenderPass.diffuse]
550
- diffuse = diffuse.clip(0, 1)
551
-
552
- normal = render_res[kal.render.easy_render.RenderPass.normals]
553
- normal = normal.clip(-1, 1)
554
-
555
- return image, albedo, diffuse, normal
556
-
557
-
558
- def load_saved_normal(path: str) -> np.ndarray:
559
- image_paths = glob(os.path.join(path, "*.jpg"))
560
- images = []
561
- for path in sorted(image_paths):
562
- image = cv2.imread(path)
563
- image = image[..., ::-1] # bgr -> rgb
564
- images.append(image)
565
- images = np.stack(images, axis=0)
566
-
567
- return images
568
-
569
-
570
- def _move_to_target_device(data, device: str):
571
- if isinstance(data, dict):
572
- for key, value in data.items():
573
- data[key] = _move_to_target_device(value, device)
574
- elif isinstance(data, torch.Tensor):
575
- return data.to(device)
576
-
577
- return data
578
-
579
-
580
- def _encode_prompt(
581
- prompt_batch,
582
- text_encoders,
583
- tokenizers,
584
- proportion_empty_prompts=0,
585
- is_train=True,
586
- ):
587
- prompt_embeds_list = []
588
-
589
- captions = []
590
- for caption in prompt_batch:
591
- if random.random() < proportion_empty_prompts:
592
- captions.append("")
593
- elif isinstance(caption, str):
594
- captions.append(caption)
595
- elif isinstance(caption, (list, np.ndarray)):
596
- captions.append(random.choice(caption) if is_train else caption[0])
597
-
598
- with torch.no_grad():
599
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
600
- text_inputs = tokenizer(
601
- captions,
602
- padding="max_length",
603
- max_length=256,
604
- truncation=True,
605
- return_tensors="pt",
606
- ).to(text_encoder.device)
607
-
608
- output = text_encoder(
609
- input_ids=text_inputs.input_ids,
610
- attention_mask=text_inputs.attention_mask,
611
- position_ids=text_inputs.position_ids,
612
- output_hidden_states=True,
613
- )
614
-
615
- # We are only interested in the pooled output of the text encoder.
616
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
617
- pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
618
- bs_embed, seq_len, _ = prompt_embeds.shape
619
- prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
620
- prompt_embeds_list.append(prompt_embeds)
621
-
622
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
623
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
624
-
625
- return prompt_embeds, pooled_prompt_embeds
626
-
627
-
628
- def load_llm_models(pretrained_model_name_or_path: str, device: str):
629
- tokenizer = ChatGLMTokenizer.from_pretrained(
630
- pretrained_model_name_or_path,
631
- subfolder="text_encoder",
632
- )
633
- text_encoder = ChatGLMModel.from_pretrained(
634
- pretrained_model_name_or_path,
635
- subfolder="text_encoder",
636
- ).to(device)
637
-
638
- text_encoders = [
639
- text_encoder,
640
- ]
641
- tokenizers = [
642
- tokenizer,
643
- ]
644
-
645
- logger.info(f"Load model from {pretrained_model_name_or_path} done.")
646
-
647
- return tokenizers, text_encoders
648
-
649
-
650
- def prelabel_text_feature(
651
- prompt_batch: List[str],
652
- output_dir: str,
653
- tokenizers: nn.Module,
654
- text_encoders: nn.Module,
655
- ) -> List[str]:
656
- os.makedirs(output_dir, exist_ok=True)
657
-
658
- # prompt_batch ["text..."]
659
- prompt_embeds, pooled_prompt_embeds = _encode_prompt(
660
- prompt_batch, text_encoders, tokenizers
661
- )
662
-
663
- prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
664
- pooled_prompt_embeds = _move_to_target_device(
665
- pooled_prompt_embeds, device="cpu"
666
- )
667
-
668
- data_dict = dict(
669
- prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
670
- )
671
-
672
- save_path = os.path.join(output_dir, "text_feat.pth")
673
- torch.save(data_dict, save_path)
674
-
675
- return save_path
676
-
677
-
678
- def calc_face_normals(
679
- vertices: torch.Tensor, # V,3 first vertex may be unreferenced
680
- faces: torch.Tensor, # F,3 long, first face may be all zero
681
- normalize: bool = False,
682
- ) -> torch.Tensor: # F,3
683
- full_vertices = vertices[faces] # F,C=3,3
684
- v0, v1, v2 = full_vertices.unbind(dim=1) # F,3
685
- face_normals = torch.cross(v1 - v0, v2 - v0, dim=1) # F,3
686
- if normalize:
687
- face_normals = F.normalize(
688
- face_normals, eps=1e-6, dim=1
689
- ) # TODO inplace?
690
- return face_normals # F,3
691
-
692
-
693
- def calc_vertex_normals(
694
- vertices: torch.Tensor, # V,3 first vertex may be unreferenced
695
- faces: torch.Tensor, # F,3 long, first face may be all zero
696
- face_normals: torch.Tensor = None, # F,3, not normalized
697
- ) -> torch.Tensor: # V,3
698
- _F = faces.shape[0]
699
-
700
- if face_normals is None:
701
- face_normals = calc_face_normals(vertices, faces)
702
-
703
- vertex_normals = torch.zeros(
704
- (vertices.shape[0], 3, 3), dtype=vertices.dtype, device=vertices.device
705
- ) # V,C=3,3
706
- vertex_normals.scatter_add_(
707
- dim=0,
708
- index=faces[:, :, None].expand(_F, 3, 3),
709
- src=face_normals[:, None, :].expand(_F, 3, 3),
710
- )
711
- vertex_normals = vertex_normals.sum(dim=1) # V,3
712
- return F.normalize(vertex_normals, eps=1e-6, dim=1)
713
-
714
-
715
- def normalize_vertices_array(
716
- vertices: Union[torch.Tensor, np.ndarray],
717
- mesh_scale: float = 1.0,
718
- exec_norm: bool = True,
719
- ):
720
- if isinstance(vertices, torch.Tensor):
721
- bbmin, bbmax = vertices.min(0)[0], vertices.max(0)[0]
722
- else:
723
- bbmin, bbmax = vertices.min(0), vertices.max(0) # (3,)
724
- center = (bbmin + bbmax) * 0.5
725
- bbsize = bbmax - bbmin
726
- scale = 2 * mesh_scale / bbsize.max()
727
- if exec_norm:
728
- vertices = (vertices - center) * scale
729
-
730
- return vertices, scale, center
731
-
732
-
733
- def load_mesh_to_unit_cube(
734
- mesh_file: str,
735
- mesh_scale: float = 1.0,
736
- ) -> tuple[trimesh.Trimesh, float, list[float]]:
737
- if not os.path.exists(mesh_file):
738
- raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
739
-
740
- mesh = trimesh.load(mesh_file)
741
- if isinstance(mesh, trimesh.Scene):
742
- mesh = trimesh.util.concatenate(mesh.dump())  # merge Scene geometries into one mesh
743
-
744
- vertices, scale, center = normalize_vertices_array(
745
- mesh.vertices, mesh_scale
746
- )
747
- mesh.vertices = vertices
748
-
749
- return mesh, scale, center
750
-
751
-
752
- def as_list(obj):
753
- if isinstance(obj, (list, tuple)):
754
- return obj
755
- elif isinstance(obj, set):
756
- return list(obj)
757
- else:
758
- return [obj]
759
-
760
-
761
- @dataclass
762
- class CameraSetting:
763
- """Camera settings for images rendering."""
764
-
765
- num_images: int
766
- elevation: list[float]
767
- distance: float
768
- resolution_hw: tuple[int, int]
769
- fov: float
770
- at: tuple[float, float, float] = field(
771
- default_factory=lambda: (0.0, 0.0, 0.0)
772
- )
773
- up: tuple[float, float, float] = field(
774
- default_factory=lambda: (0.0, 1.0, 0.0)
775
- )
776
- device: str = "cuda"
777
- near: float = 1e-2
778
- far: float = 1e2
779
-
780
- def __post_init__(
781
- self,
782
- ):
783
- h = self.resolution_hw[0]
784
- f = (h / 2) / math.tan(self.fov / 2)
785
- cx = self.resolution_hw[1] / 2
786
- cy = self.resolution_hw[0] / 2
787
- Ks = [
788
- [f, 0, cx],
789
- [0, f, cy],
790
- [0, 0, 1],
791
- ]
792
-
793
- self.Ks = Ks
794
-
795
-
796
- @dataclass
797
- class RenderItems(str, Enum):
798
- IMAGE = "image_color"
799
- ALPHA = "image_mask"
800
- VIEW_NORMAL = "image_view_normal"
801
- GLOBAL_NORMAL = "image_global_normal"
802
- POSITION_MAP = "image_position"
803
- DEPTH = "image_depth"
804
- ALBEDO = "image_albedo"
805
- DIFFUSE = "image_diffuse"
806
-
807
-
808
- def compute_az_el_by_camera_params(
809
- camera_params: CameraSetting, flip_az: bool = False
810
- ):
811
- num_view = camera_params.num_images // len(camera_params.elevation)
812
- view_interval = 2 * np.pi / num_view / 2
813
- azimuths = []
814
- elevations = []
815
- for idx, el in enumerate(camera_params.elevation):
816
- azs = np.arange(num_view) / num_view * np.pi * 2 + idx * view_interval
817
- if flip_az:
818
- azs *= -1
819
- els = np.deg2rad(np.array([el] * num_view))
820
- azimuths.append(azs)
821
- elevations.append(els)
822
-
823
- azimuths = np.concatenate(azimuths, axis=0)
824
- elevations = np.concatenate(elevations, axis=0)
825
-
826
- return azimuths, elevations
827
-
828
-
829
- def init_kal_camera(camera_params: CameraSetting) -> Camera:
830
- azimuths, elevations = compute_az_el_by_camera_params(camera_params)
831
- cam_pts = compute_cam_pts_by_az_el(
832
- azimuths, elevations, camera_params.distance
833
- )
834
-
835
- up = torch.cat(
836
- [
837
- torch.tensor(camera_params.up).repeat(camera_params.num_images, 1),
838
- ],
839
- dim=0,
840
- )
841
-
842
- camera = Camera.from_args(
843
- eye=torch.tensor(cam_pts),
844
- at=torch.tensor(camera_params.at),
845
- up=up,
846
- fov=camera_params.fov,
847
- height=camera_params.resolution_hw[0],
848
- width=camera_params.resolution_hw[1],
849
- near=camera_params.near,
850
- far=camera_params.far,
851
- device=camera_params.device,
852
- )
853
-
854
- return camera
855
-
856
-
857
- def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
858
- if mesh_path.endswith(".glb"):
859
- mesh = kal.io.gltf.import_mesh(mesh_path)
860
- elif mesh_path.endswith(".obj"):
861
- with_material = True if with_mtl else False
862
- mesh = kal.io.obj.import_mesh(mesh_path, with_materials=with_material)
863
- if with_mtl and mesh.materials and len(mesh.materials) > 0:
864
- material = kal.render.materials.PBRMaterial()
865
- assert (
866
- "map_Kd" in mesh.materials[0]
867
- ), "'map_Kd' not found in materials."
868
- material.diffuse_texture = mesh.materials[0]["map_Kd"] / 255.0
869
- mesh.materials = [material]
870
- elif mesh_path.endswith(".ply"):
871
- mesh = trimesh.load(mesh_path)
872
- mesh_path = mesh_path.replace(".ply", ".obj")
873
- mesh.export(mesh_path)
874
- mesh = kal.io.obj.import_mesh(mesh_path)
875
- elif mesh_path.endswith(".off"):
876
- mesh = kal.io.off.import_mesh(mesh_path)
877
- else:
878
- raise RuntimeError(
879
- f"{mesh_path} mesh type not supported, "
880
- "supported mesh type `.glb`, `.obj`, `.ply`, `.off`."
881
- )
882
-
883
- return mesh
884
-
885
-
886
- def save_mesh_with_mtl(
887
- vertices: np.ndarray,
888
- faces: np.ndarray,
889
- uvs: np.ndarray,
890
- texture: Union[Image.Image, np.ndarray],
891
- output_path: str,
892
- material_base=(250, 250, 250, 255),
893
- ) -> trimesh.Trimesh:
894
- if isinstance(texture, np.ndarray):
895
- texture = Image.fromarray(texture)
896
-
897
- mesh = trimesh.Trimesh(
898
- vertices,
899
- faces,
900
- visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
901
- )
902
- mesh.visual.material = trimesh.visual.material.SimpleMaterial(
903
- image=texture,
904
- diffuse=material_base,
905
- ambient=material_base,
906
- specular=material_base,
907
- )
908
-
909
- dir_name = os.path.dirname(output_path)
910
- os.makedirs(dir_name, exist_ok=True)
911
-
912
- _ = mesh.export(output_path)
913
- # texture.save(os.path.join(dir_name, f"{file_name}_texture.png"))
914
-
915
- logger.info(f"Saved mesh with texture to {output_path}")
916
-
917
- return mesh
918
-
919
-
920
- def get_images_from_grid(
921
- image: Union[str, Image.Image], img_size: int
922
- ) -> list[Image.Image]:
923
- if isinstance(image, str):
924
- image = Image.open(image)
925
-
926
- view_images = np.array(image)
927
- view_images = np.concatenate(
928
- [view_images[:img_size, ...], view_images[img_size:, ...]], axis=1
929
- )
930
- images = np.split(view_images, view_images.shape[1] // img_size, axis=1)
931
- images = [Image.fromarray(img) for img in images]
932
-
933
- return images
934
-
935
-
936
- def post_process_texture(texture: np.ndarray, iter: int = 2) -> np.ndarray:
937
- for _ in range(iter):
938
- texture = cv2.fastNlMeansDenoisingColored(texture, None, 13, 13, 9, 27)
939
- texture = cv2.bilateralFilter(
940
- texture, d=9, sigmaColor=80, sigmaSpace=80
941
- )
942
-
943
- return texture
 
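Note: a hedged sketch (not part of this commit) of how the deleted camera helpers above fit together; all concrete values and the input path are assumptions.

import math
from asset3d_gen.data.utils import (
    CameraSetting,
    init_kal_camera,
    load_mesh_to_unit_cube,
)

camera_params = CameraSetting(
    num_images=8,
    elevation=[20.0],
    distance=5.0,
    resolution_hw=(512, 512),
    fov=math.radians(30),
    device="cuda",
)
camera = init_kal_camera(camera_params)  # batched kaolin Camera, one pose per view
mesh, scale, center = load_mesh_to_unit_cube("asset.glb")  # vertices normalized to [-1, 1]
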
asset3d_gen/models/delight.py DELETED
@@ -1,162 +0,0 @@
1
- import os
2
- from typing import Union
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- from diffusers import (
8
- EulerAncestralDiscreteScheduler,
9
- StableDiffusionInstructPix2PixPipeline,
10
- )
11
- from huggingface_hub import snapshot_download
12
- from PIL import Image
13
- from asset3d_gen.models.segment import RembgRemover
14
-
15
- __all__ = [
16
- "DelightingModel",
17
- ]
18
-
19
-
20
- class DelightingModel(object):
21
- def __init__(
22
- self,
23
- model_path: str = None,
24
- num_infer_step: int = 50,
25
- mask_erosion_size: int = 3,
26
- image_guide_scale: float = 1.5,
27
- text_guide_scale: float = 1.0,
28
- device: str = "cuda",
29
- seed: int = 0,
30
- ) -> None:
31
- self.image_guide_scale = image_guide_scale
32
- self.text_guide_scale = text_guide_scale
33
- self.num_infer_step = num_infer_step
34
- self.mask_erosion_size = mask_erosion_size
35
- self.kernel = np.ones(
36
- (self.mask_erosion_size, self.mask_erosion_size), np.uint8
37
- )
38
- self.seed = seed
39
- self.device = device
40
- self.bg_remover = RembgRemover()
41
-
42
- if model_path is None:
43
- suffix = "hunyuan3d-delight-v2-0"
44
- model_path = snapshot_download(
45
- repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
46
- )
47
- model_path = os.path.join(model_path, suffix)
48
-
49
- pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
50
- model_path,
51
- torch_dtype=torch.float16,
52
- safety_checker=None,
53
- )
54
- pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
55
- pipeline.scheduler.config
56
- )
57
- pipeline.set_progress_bar_config(disable=True)
58
-
59
- pipeline.to(self.device, torch.float16)
60
- # pipeline.enable_model_cpu_offload()
61
- # pipeline.enable_xformers_memory_efficient_attention()
62
- self.pipeline = pipeline
63
-
64
- def recenter_image(
65
- self, image: Image.Image, border_ratio: float = 0.2
66
- ) -> Image.Image:
67
- if image.mode == "RGB":
68
- return image
69
- elif image.mode == "L":
70
- image = image.convert("RGB")
71
- return image
72
-
73
- alpha_channel = np.array(image)[:, :, 3]
74
- non_zero_indices = np.argwhere(alpha_channel > 0)
75
- if non_zero_indices.size == 0:
76
- raise ValueError("Image is fully transparent")
77
-
78
- min_row, min_col = non_zero_indices.min(axis=0)
79
- max_row, max_col = non_zero_indices.max(axis=0)
80
-
81
- cropped_image = image.crop(
82
- (min_col, min_row, max_col + 1, max_row + 1)
83
- )
84
-
85
- width, height = cropped_image.size
86
- border_width = int(width * border_ratio)
87
- border_height = int(height * border_ratio)
88
-
89
- new_width = width + 2 * border_width
90
- new_height = height + 2 * border_height
91
-
92
- square_size = max(new_width, new_height)
93
-
94
- new_image = Image.new(
95
- "RGBA", (square_size, square_size), (255, 255, 255, 0)
96
- )
97
-
98
- paste_x = (square_size - new_width) // 2 + border_width
99
- paste_y = (square_size - new_height) // 2 + border_height
100
-
101
- new_image.paste(cropped_image, (paste_x, paste_y))
102
-
103
- return new_image
104
-
105
- @torch.no_grad()
106
- def __call__(
107
- self,
108
- image: Union[str, np.ndarray, Image.Image],
109
- preprocess: bool = False,
110
- target_wh: tuple[int, int] = None,
111
- ) -> Image.Image:
112
- if isinstance(image, str):
113
- image = Image.open(image)
114
- elif isinstance(image, np.ndarray):
115
- image = Image.fromarray(image)
116
-
117
- if preprocess:
118
- image = self.bg_remover(image)
119
- image = self.recenter_image(image)
120
-
121
- if target_wh is not None:
122
- image = image.resize(target_wh)
123
- else:
124
- target_wh = image.size
125
-
126
- image_array = np.array(image)
127
- assert image_array.shape[-1] == 4, "Image must have alpha channel"
128
-
129
- raw_alpha_channel = image_array[:, :, 3]
130
- alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
131
- image_array[alpha_channel == 0, :3] = 255 # must be white background
132
- image_array[:, :, 3] = alpha_channel
133
-
134
- image = self.pipeline(
135
- prompt="",
136
- image=Image.fromarray(image_array).convert("RGB"),
137
- generator=torch.manual_seed(self.seed),
138
- num_inference_steps=self.num_infer_step,
139
- image_guidance_scale=self.image_guide_scale,
140
- guidance_scale=self.text_guide_scale,
141
- ).images[0]
142
-
143
- alpha_channel = Image.fromarray(alpha_channel)
144
- rgba_image = image.convert("RGBA").resize(target_wh)
145
- rgba_image.putalpha(alpha_channel)
146
-
147
- return rgba_image
148
-
149
-
150
- if __name__ == "__main__":
151
- delighting_model = DelightingModel(
152
- # model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
153
- )
154
- image_path = "scripts/apps/assets/example_image/room_bottle_002.jpeg"
155
- image = delighting_model(
156
- image_path, preprocess=True, target_wh=(512, 512)
157
- ) # noqa
158
- image.save("delight.png")
159
-
160
- # image_path = "asset3d_gen/scripts/test_robot.png"
161
- # image = delighting_model(image_path)
162
- # image.save("delighting_image_a2.png")
 
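Note: a hedged sketch (not part of this commit) combining two of the deleted modules: splitting a multi-view grid with get_images_from_grid and delighting each view; the grid path and sizes are assumptions.

from asset3d_gen.data.utils import get_images_from_grid
from asset3d_gen.models.delight import DelightingModel

delighter = DelightingModel()
views = get_images_from_grid("color_grid.png", img_size=512)  # hypothetical grid image
delighted_views = [
    delighter(view, preprocess=True, target_wh=(512, 512)) for view in views
]
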
asset3d_gen/models/gs_model.py DELETED
@@ -1,540 +0,0 @@
1
- import logging
2
- import os
3
- import struct
4
- from dataclasses import dataclass, field
5
- from typing import Optional, Union
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from gsplat.cuda._wrapper import spherical_harmonics
11
- from gsplat.rendering import rasterization
12
- from plyfile import PlyData
13
- from scipy.spatial.transform import Rotation
14
- from torch.nn import functional as F
15
-
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- __all__ = [
21
- "RenderResult",
22
- "GaussianOperator",
23
- ]
24
-
25
-
26
- def quat_mult(q1, q2):
27
- # NOTE:
28
- # Q1 is the quaternion that rotates the vector from the original position to the final position # noqa
29
- # Q2 is the quaternion being rotated
30
- w1, x1, y1, z1 = q1.T
31
- w2, x2, y2, z2 = q2.T
32
- w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
33
- x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
34
- y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
35
- z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
36
- return torch.stack([w, x, y, z]).T
37
-
38
-
39
- def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor:
40
- """Convert quaternion to rotation matrix."""
41
- quats = F.normalize(quats, p=2, dim=-1)
42
-
43
- if mode == "xyzw":
44
- x, y, z, w = torch.unbind(quats, dim=-1)
45
- elif mode == "wxyz":
46
- w, x, y, z = torch.unbind(quats, dim=-1)
47
- else:
48
- raise ValueError(f"Invalid mode: {mode}.")
49
-
50
- R = torch.stack(
51
- [
52
- 1 - 2 * (y**2 + z**2),
53
- 2 * (x * y - w * z),
54
- 2 * (x * z + w * y),
55
- 2 * (x * y + w * z),
56
- 1 - 2 * (x**2 + z**2),
57
- 2 * (y * z - w * x),
58
- 2 * (x * z - w * y),
59
- 2 * (y * z + w * x),
60
- 1 - 2 * (x**2 + y**2),
61
- ],
62
- dim=-1,
63
- )
64
-
65
- return R.reshape(quats.shape[:-1] + (3, 3))
66
-
67
-
68
- def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
69
- C0 = 0.28209479177387814 # Constant for normalization in spherical harmonics # noqa
70
- # Clip to the range [0.0, 1.0], apply gamma correction, and then un-clip back # noqa
71
- new_shs = torch.clip(shs * C0 + 0.5, 0.0, 1.0)
72
- new_shs = (torch.pow(new_shs, gamma) - 0.5) / C0
73
- return new_shs
74
-
75
-
76
- @dataclass
77
- class RenderResult:
78
- rgb: np.ndarray
79
- depth: np.ndarray
80
- opacity: np.ndarray
81
- mask_threshold: float = 10
82
- mask: Optional[np.ndarray] = None
83
- rgba: Optional[np.ndarray] = None
84
-
85
- def __post_init__(self):
86
- if isinstance(self.rgb, torch.Tensor):
87
- rgb = self.rgb.detach().cpu().numpy()
88
- rgb = (rgb * 255).astype(np.uint8)
89
- self.rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
90
- if isinstance(self.depth, torch.Tensor):
91
- self.depth = self.depth.detach().cpu().numpy()
92
- if isinstance(self.opacity, torch.Tensor):
93
- opacity = self.opacity.detach().cpu().numpy()
94
- opacity = (opacity * 255).astype(np.uint8)
95
- self.opacity = cv2.cvtColor(opacity, cv2.COLOR_GRAY2RGB)
96
- mask = np.where(self.opacity > self.mask_threshold, 255, 0)
97
- self.mask = mask[..., 0:1].astype(np.uint8)
98
- self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
99
-
100
-
101
- @dataclass
102
- class GaussianBase:
103
- _opacities: torch.Tensor
104
- _means: torch.Tensor
105
- _scales: torch.Tensor
106
- _quats: torch.Tensor
107
- _rgbs: Optional[torch.Tensor] = None
108
- _features_dc: Optional[torch.Tensor] = None
109
- _features_rest: Optional[torch.Tensor] = None
110
- sh_degree: Optional[int] = 0
111
- device: str = "cuda"
112
-
113
- def __post_init__(self):
114
- self.active_sh_degree: int = self.sh_degree
115
- self.to(self.device)
116
-
117
- def to(self, device: str) -> None:
118
- for k, v in self.__dict__.items():
119
- if not isinstance(v, torch.Tensor):
120
- continue
121
- self.__dict__[k] = v.to(device)
122
-
123
- def get_numpy_data(self):
124
- data = {}
125
- for k, v in self.__dict__.items():
126
- if not isinstance(v, torch.Tensor):
127
- continue
128
- data[k] = v.detach().cpu().numpy()
129
-
130
- return data
131
-
132
- def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
133
- return x / x.norm(dim=-1, keepdim=True)
134
-
135
- @classmethod
136
- def load_from_ply(
137
- cls,
138
- path: str,
139
- gamma: float = 1.0,
140
- ) -> "GaussianBase":
141
- plydata = PlyData.read(path)
142
- xyz = torch.stack(
143
- (
144
- torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
145
- torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
146
- torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
147
- ),
148
- dim=1,
149
- )
150
-
151
- opacities = torch.tensor(
152
- plydata.elements[0]["opacity"], dtype=torch.float32
153
- ).unsqueeze(-1)
154
- features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
155
- features_dc[:, 0] = torch.tensor(
156
- plydata.elements[0]["f_dc_0"], dtype=torch.float32
157
- )
158
- features_dc[:, 1] = torch.tensor(
159
- plydata.elements[0]["f_dc_1"], dtype=torch.float32
160
- )
161
- features_dc[:, 2] = torch.tensor(
162
- plydata.elements[0]["f_dc_2"], dtype=torch.float32
163
- )
164
-
165
- scale_names = [
166
- p.name
167
- for p in plydata.elements[0].properties
168
- if p.name.startswith("scale_")
169
- ]
170
- scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
171
- scales = torch.zeros(
172
- (xyz.shape[0], len(scale_names)), dtype=torch.float32
173
- )
174
- for idx, attr_name in enumerate(scale_names):
175
- scales[:, idx] = torch.tensor(
176
- plydata.elements[0][attr_name], dtype=torch.float32
177
- )
178
-
179
- rot_names = [
180
- p.name
181
- for p in plydata.elements[0].properties
182
- if p.name.startswith("rot_")
183
- ]
184
- rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
185
- rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
186
- for idx, attr_name in enumerate(rot_names):
187
- rots[:, idx] = torch.tensor(
188
- plydata.elements[0][attr_name], dtype=torch.float32
189
- )
190
-
191
- rots = rots / torch.norm(rots, dim=-1, keepdim=True)
192
-
193
- # extra features
194
- extra_f_names = [
195
- p.name
196
- for p in plydata.elements[0].properties
197
- if p.name.startswith("f_rest_")
198
- ]
199
- extra_f_names = sorted(
200
- extra_f_names, key=lambda x: int(x.split("_")[-1])
201
- )
202
-
203
- max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
204
- if max_sh_degree != 0:
205
- features_extra = torch.zeros(
206
- (xyz.shape[0], len(extra_f_names)), dtype=torch.float32
207
- )
208
- for idx, attr_name in enumerate(extra_f_names):
209
- features_extra[:, idx] = torch.tensor(
210
- plydata.elements[0][attr_name], dtype=torch.float32
211
- )
212
-
213
- features_extra = features_extra.view(
214
- (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
215
- )
216
- features_extra = features_extra.permute(0, 2, 1)
217
-
218
- if abs(gamma - 1.0) > 1e-3:
219
- features_dc = gamma_shs(features_dc, gamma)
220
- features_extra[..., :] = 0.0
221
- opacities *= 0.8
222
-
223
- shs = torch.cat(
224
- [
225
- features_dc.reshape(-1, 3),
226
- features_extra.reshape(len(features_dc), -1),
227
- ],
228
- dim=-1,
229
- )
230
- else:
231
- # sh_dim is 0, only dc features
232
- shs = features_dc
233
- features_extra = None
234
-
235
- return cls(
236
- sh_degree=max_sh_degree,
237
- _means=xyz,
238
- _opacities=opacities,
239
- _rgbs=shs,
240
- _scales=scales,
241
- _quats=rots,
242
- _features_dc=features_dc,
243
- _features_rest=features_extra,
244
- )
245
-
246
- def save_to_ply(
247
- self, path: str, colors: torch.Tensor = None, enable_mask: bool = False
248
- ):
249
- os.makedirs(os.path.dirname(path), exist_ok=True)
250
- numpy_data = self.get_numpy_data()
251
- means = numpy_data["_means"]
252
- scales = numpy_data["_scales"]
253
- quats = numpy_data["_quats"]
254
- opacities = numpy_data["_opacities"]
255
- sh0 = numpy_data["_features_dc"]
256
- shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
257
-
258
- # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa
259
- if enable_mask:
260
- invalid_mask = (
261
- np.isnan(means).any(axis=1)
262
- | np.isinf(means).any(axis=1)
263
- | np.isnan(scales).any(axis=1)
264
- | np.isinf(scales).any(axis=1)
265
- | np.isnan(quats).any(axis=1)
266
- | np.isinf(quats).any(axis=1)
267
- | np.isnan(opacities).any(axis=0)
268
- | np.isinf(opacities).any(axis=0)
269
- | np.isnan(sh0).any(axis=1)
270
- | np.isinf(sh0).any(axis=1)
271
- | np.isnan(shN).any(axis=1)
272
- | np.isinf(shN).any(axis=1)
273
- )
274
-
275
- # Filter out rows with NaNs or Infs from all data arrays
276
- means = means[~invalid_mask]
277
- scales = scales[~invalid_mask]
278
- quats = quats[~invalid_mask]
279
- opacities = opacities[~invalid_mask]
280
- sh0 = sh0[~invalid_mask]
281
- shN = shN[~invalid_mask]
282
-
283
- num_points = means.shape[0]
284
-
285
- with open(path, "wb") as f:
286
- # Write PLY header
287
- f.write(b"ply\n")
288
- f.write(b"format binary_little_endian 1.0\n")
289
- f.write(f"element vertex {num_points}\n".encode())
290
- f.write(b"property float x\n")
291
- f.write(b"property float y\n")
292
- f.write(b"property float z\n")
293
- f.write(b"property float nx\n")
294
- f.write(b"property float ny\n")
295
- f.write(b"property float nz\n")
296
-
297
- if colors is not None:
298
- for j in range(colors.shape[1]):
299
- f.write(f"property float f_dc_{j}\n".encode())
300
- else:
301
- for i, data in enumerate([sh0, shN]):
302
- prefix = "f_dc" if i == 0 else "f_rest"
303
- for j in range(data.shape[1]):
304
- f.write(f"property float {prefix}_{j}\n".encode())
305
-
306
- f.write(b"property float opacity\n")
307
-
308
- for i in range(scales.shape[1]):
309
- f.write(f"property float scale_{i}\n".encode())
310
- for i in range(quats.shape[1]):
311
- f.write(f"property float rot_{i}\n".encode())
312
-
313
- f.write(b"end_header\n")
314
-
315
- # Write vertex data
316
- for i in range(num_points):
317
- f.write(struct.pack("<fff", *means[i])) # x, y, z
318
- f.write(struct.pack("<fff", 0, 0, 0)) # nx, ny, nz (zeros)
319
-
320
- if colors is not None:
321
- color = colors.detach().cpu().numpy()
322
- for j in range(color.shape[1]):
323
- f_dc = (color[i, j] - 0.5) / 0.2820947917738781
324
- f.write(struct.pack("<f", f_dc))
325
- else:
326
- for data in [sh0, shN]:
327
- for j in range(data.shape[1]):
328
- f.write(struct.pack("<f", data[i, j]))
329
-
330
- f.write(struct.pack("<f", opacities[i])) # opacity
331
-
332
- for data in [scales, quats]:
333
- for j in range(data.shape[1]):
334
- f.write(struct.pack("<f", data[i, j]))
335
-
336
-
337
- @dataclass
338
- class GaussianOperator(GaussianBase):
339
-
340
- def _compute_transform(
341
- self,
342
- means: torch.Tensor,
343
- quats: torch.Tensor,
344
- instance_pose: torch.Tensor,
345
- ):
346
- """Compute the transform of the GS models.
347
-
348
- Args:
349
- means: tensor of gs means.
350
- quats: tensor of gs quaternions.
351
- instance_pose: instances poses in [x y z qx qy qz qw] format.
352
-
353
- """
354
- # (x y z qx qy qz qw) -> (x y z qw qx qy qz)
355
- instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
356
- cur_instances_quats = self.quat_norm(instance_pose[3:])
357
- rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
358
-
359
- # update the means
360
- num_gs = means.shape[0]
361
- trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
362
- quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
363
- rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3)
364
-
365
- # update the means
366
- cur_means = (
367
- torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
368
- + trans_per_pts
369
- )
370
-
371
- # update the quats
372
- _quats = self.quat_norm(quats)
373
- cur_quats = quat_mult(quat_per_pts, _quats)
374
-
375
- return cur_means, cur_quats
376
-
377
- def get_gaussians(
378
- self,
379
- c2w: torch.Tensor = None,
380
- instance_pose: torch.Tensor = None,
381
- apply_activate: bool = False,
382
- ) -> "GaussianBase":
383
- """Get Gaussian data under the given instance_pose."""
384
- if c2w is None:
385
- c2w = torch.eye(4).to(self.device)
386
-
387
- if instance_pose is not None:
388
- # compute the transformed gs means and quats
389
- world_means, world_quats = self._compute_transform(
390
- self._means, self._quats, instance_pose.float().to(self.device)
391
- )
392
- else:
393
- world_means, world_quats = self._means, self._quats
394
-
395
- # get colors of gaussians
396
- if self._features_rest is not None:
397
- colors = torch.cat(
398
- (self._features_dc[:, None, :], self._features_rest), dim=1
399
- )
400
- else:
401
- colors = self._features_dc[:, None, :]
402
-
403
- if self.sh_degree > 0:
404
- viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3)
405
- viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
406
- rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
407
- rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
408
- else:
409
- rgbs = torch.sigmoid(colors[:, 0, :])
410
-
411
- gs_dict = dict(
412
- _means=world_means,
413
- _opacities=(
414
- torch.sigmoid(self._opacities)
415
- if apply_activate
416
- else self._opacities
417
- ),
418
- _rgbs=rgbs,
419
- _scales=(
420
- torch.exp(self._scales) if apply_activate else self._scales
421
- ),
422
- _quats=self.quat_norm(world_quats),
423
- _features_dc=self._features_dc,
424
- _features_rest=self._features_rest,
425
- sh_degree=self.sh_degree,
426
- )
427
-
428
- return GaussianOperator(**gs_dict)
429
-
430
- def rescale(self, scale: float):
431
- if scale != 1.0:
432
- self._means *= scale
433
- self._scales += torch.log(self._scales.new_tensor(scale))
434
-
435
- def set_scale_by_height(self, real_height: float) -> None:
436
- def _ptp(tensor, dim):
437
- val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
438
- return val.tolist()
439
-
440
- xyz_scale = max(_ptp(self._means, dim=0))
441
- self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5]
442
- raw_height = _ptp(self._means, dim=0)[1]
443
- scale = real_height / raw_height
444
-
445
- self.rescale(scale)
446
-
447
- return
448
-
449
- @staticmethod
450
- def resave_ply(
451
- in_ply: str,
452
- out_ply: str,
453
- real_height: float = None,
454
- instance_pose: np.ndarray = None,
455
- sh_degree: int = 0,
456
- ) -> None:
457
- gs_model = GaussianOperator.load_from_ply(in_ply)  # note: load_from_ply's second arg is gamma, not sh_degree; sh_degree is inferred from the PLY
458
-
459
- if instance_pose is not None:
460
- gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
461
-
462
- if real_height is not None:
463
- gs_model.set_scale_by_height(real_height)
464
-
465
- gs_model.save_to_ply(out_ply)
466
-
467
- return
468
-
469
- @staticmethod
470
- def trans_to_quatpose(
471
- rot_matrix: list[list[float]],
472
- trans_matrix: list[float] = [0, 0, 0],
473
- ) -> torch.Tensor:
474
- if isinstance(rot_matrix, list):
475
- rot_matrix = np.array(rot_matrix)
476
-
477
- rot = Rotation.from_matrix(rot_matrix)
478
- qx, qy, qz, qw = rot.as_quat()
479
- instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
480
-
481
- return instance_pose
482
-
483
- def render(
484
- self,
485
- c2w: torch.Tensor,
486
- Ks: torch.Tensor,
487
- image_width: int,
488
- image_height: int,
489
- ) -> RenderResult:
490
- gs = self.get_gaussians(c2w, apply_activate=True)
491
- renders, alphas, _ = rasterization(
492
- means=gs._means,
493
- quats=gs._quats,
494
- scales=gs._scales,
495
- opacities=gs._opacities.squeeze(),
496
- colors=gs._rgbs,
497
- viewmats=torch.linalg.inv(c2w)[None, ...],
498
- Ks=Ks[None, ...],
499
- width=image_width,
500
- height=image_height,
501
- packed=False,
502
- absgrad=True,
503
- sparse_grad=False,
504
- # rasterize_mode="classic",
505
- rasterize_mode="antialiased",
506
- **{
507
- "near_plane": 0.01,
508
- "far_plane": 1000000000,
509
- "radius_clip": 0.0,
510
- "render_mode": "RGB+ED",
511
- },
512
- )
513
- renders = renders[0]
514
- alphas = alphas[0].squeeze(-1)
515
-
516
- assert renders.shape[-1] == 4, f"Must render rgb, depth and alpha"
517
- rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
518
-
519
- return RenderResult(
520
- torch.clamp(rendered_rgb, min=0, max=1),
521
- rendered_depth,
522
- alphas[..., None],
523
- )
524
-
525
-
526
- if __name__ == "__main__":
527
- input_gs = "outputs/test/debug.ply"
528
- output_gs = "./debug_v3.ply"
529
- gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
530
-
531
- # Rotate 180° around the x-axis
532
- R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
533
- instance_pose = gs_model.trans_to_quatpose(R_x)
534
- gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
535
-
536
- gs_model.rescale(2)
537
-
538
- gs_model.set_scale_by_height(1.3)
539
-
540
- gs_model.save_to_ply(output_gs)
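The listing above closes the Gaussian-splat model module. A minimal usage sketch follows, assuming a valid 3DGS .ply on disk, a CUDA device, and gsplat's OpenCV-style camera convention used by GaussianOperator.render; the paths and camera values are placeholders, not part of the original file.

import math

import torch

from asset3d_gen.models.gs_model import GaussianOperator

gs = GaussianOperator.load_from_ply("outputs/example/gs_model.ply")  # placeholder path
gs.set_scale_by_height(1.0)  # rescale so the asset stands about 1 m tall

# Front-on camera roughly 3 m from the origin (OpenCV convention: looks along +z).
c2w = torch.eye(4, device=gs.device)
c2w[2, 3] = -3.0
fx = fy = 512 / (2 * math.tan(math.radians(30) / 2))  # 30 deg FOV at 512 px
Ks = torch.tensor(
    [[fx, 0, 256], [0, fy, 256], [0, 0, 1]], dtype=torch.float32, device=gs.device
)

result = gs.render(c2w, Ks=Ks, image_width=512, image_height=512)
print(result.rgb.shape, result.mask.shape)  # (512, 512, 3) and (512, 512, 1)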
asset3d_gen/models/segment.py DELETED
@@ -1,375 +0,0 @@
1
- import logging
2
- import os
3
- from typing import Literal, Union
4
-
5
- import cv2
6
- import numpy as np
7
- import rembg
8
- import torch
9
- from huggingface_hub import snapshot_download
10
- from PIL import Image
11
- from segment_anything import (
12
- SamAutomaticMaskGenerator,
13
- SamPredictor,
14
- sam_model_registry,
15
- )
16
- from asset3d_gen.utils.process_media import filter_small_connected_components
17
- from asset3d_gen.validators.quality_checkers import ImageSegChecker
18
-
19
- logging.basicConfig(level=logging.INFO)
20
- logger = logging.getLogger(__name__)
21
-
22
-
23
- __all__ = [
24
- "resize_pil",
25
- "trellis_preprocess",
26
- "SAMRemover",
27
- "SAMPredictor",
28
- "RembgRemover",
29
- "get_segmented_image",
30
- ]
31
-
32
-
33
- def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
34
- cur_size = max(image.size)
35
- scale = min(1, max_size / cur_size)
36
- if scale < 1:
37
- new_size = (int(image.width * scale), int(image.height * scale))
38
- image = image.resize(new_size, Image.Resampling.LANCZOS)
39
-
40
- return image
41
-
42
-
43
- def trellis_preprocess(image: Image.Image) -> Image.Image:
44
- """Preprocess the input image the same way TRELLIS does."""
45
- image_np = np.array(image)
46
- alpha = image_np[:, :, 3]
47
- bbox = np.argwhere(alpha > 0.8 * 255)
48
- bbox = (
49
- np.min(bbox[:, 1]),
50
- np.min(bbox[:, 0]),
51
- np.max(bbox[:, 1]),
52
- np.max(bbox[:, 0]),
53
- )
54
- center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
55
- size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
56
- size = int(size * 1.2)
57
- bbox = (
58
- center[0] - size // 2,
59
- center[1] - size // 2,
60
- center[0] + size // 2,
61
- center[1] + size // 2,
62
- )
63
- image = image.crop(bbox)
64
- image = image.resize((518, 518), Image.Resampling.LANCZOS)
65
- image = np.array(image).astype(np.float32) / 255
66
- image = image[:, :, :3] * image[:, :, 3:4]
67
- image = Image.fromarray((image * 255).astype(np.uint8))
68
-
69
- return image
70
-
71
-
72
- class SAMRemover(object):
73
- """Loading SAM models and performing background removal on images.
74
-
75
- Attributes:
76
- checkpoint (str): Path to the model checkpoint.
77
- model_type (str): Type of the SAM model to load (default: "vit_h").
78
- area_ratio (float): Area-ratio threshold used to filter out small connected components.
79
- """
80
-
81
- def __init__(
82
- self,
83
- checkpoint: str = None,
84
- model_type: str = "vit_h",
85
- area_ratio: float = 15,
86
- ):
87
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
88
- self.model_type = model_type
89
- self.area_ratio = area_ratio
90
-
91
- if checkpoint is None:
92
- suffix = "sam"
93
- model_path = snapshot_download(
94
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
95
- )
96
- checkpoint = os.path.join(
97
- model_path, suffix, "sam_vit_h_4b8939.pth"
98
- )
99
-
100
- self.mask_generator = self._load_sam_model(checkpoint)
101
-
102
- def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
103
- sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
104
- sam.to(device=self.device)
105
-
106
- return SamAutomaticMaskGenerator(sam)
107
-
108
- def __call__(
109
- self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
110
- ) -> Image.Image:
111
- """Removes the background from an image using the SAM model.
112
-
113
- Args:
114
- image (Union[str, Image.Image, np.ndarray]): Input image,
115
- can be a file path, PIL Image, or numpy array.
116
- save_path (str): Path to save the output image (default: None).
117
-
118
- Returns:
119
- Image.Image: The image with background removed,
120
- including an alpha channel.
121
- """
122
- # Convert input to numpy array
123
- if isinstance(image, str):
124
- image = Image.open(image)
125
- elif isinstance(image, np.ndarray):
126
- image = Image.fromarray(image).convert("RGB")
127
- image = resize_pil(image)
128
- image = np.array(image.convert("RGB"))
129
-
130
- # Generate masks
131
- masks = self.mask_generator.generate(image)
132
- masks = sorted(masks, key=lambda x: x["area"], reverse=True)
133
-
134
- if not masks:
135
- logger.warning(
136
- "Segmentation failed: No mask generated, return raw image."
137
- )
138
- output_image = Image.fromarray(image, mode="RGB")
139
- else:
140
- # Use the largest mask
141
- best_mask = masks[0]["segmentation"]
142
- mask = (best_mask * 255).astype(np.uint8)
143
- mask = filter_small_connected_components(
144
- mask, area_ratio=self.area_ratio
145
- )
146
- # Apply the mask to remove the background
147
- background_removed = cv2.bitwise_and(image, image, mask=mask)
148
- output_image = np.dstack((background_removed, mask))
149
- output_image = Image.fromarray(output_image, mode="RGBA")
150
-
151
- if save_path is not None:
152
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
153
- output_image.save(save_path)
154
-
155
- return output_image
156
-
157
-
158
- class SAMPredictor(object):
159
- def __init__(
160
- self,
161
- checkpoint: str = None,
162
- model_type: str = "vit_h",
163
- binary_thresh: float = 0.1,
164
- device: str = "cuda",
165
- ):
166
- self.device = device
167
- self.model_type = model_type
168
-
169
- if checkpoint is None:
170
- suffix = "sam"
171
- model_path = snapshot_download(
172
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
173
- )
174
- checkpoint = os.path.join(
175
- model_path, suffix, "sam_vit_h_4b8939.pth"
176
- )
177
-
178
- self.predictor = self._load_sam_model(checkpoint)
179
- self.binary_thresh = binary_thresh
180
-
181
- def _load_sam_model(self, checkpoint: str) -> SamPredictor:
182
- sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
183
- sam.to(device=self.device)
184
-
185
- return SamPredictor(sam)
186
-
187
- def preprocess_image(self, image: Image.Image) -> np.ndarray:
188
- if isinstance(image, str):
189
- image = Image.open(image)
190
- elif isinstance(image, np.ndarray):
191
- image = Image.fromarray(image).convert("RGB")
192
-
193
- image = resize_pil(image)
194
- image = np.array(image.convert("RGB"))
195
-
196
- return image
197
-
198
- def generate_masks(
199
- self,
200
- image: np.ndarray,
201
- selected_points: list[list[int]],
202
- ) -> np.ndarray:
203
- if len(selected_points) == 0:
204
- return []
205
-
206
- points = (
207
- torch.Tensor([p for p, _ in selected_points])
208
- .to(self.predictor.device)
209
- .unsqueeze(1)
210
- )
211
-
212
- labels = (
213
- torch.Tensor([int(l) for _, l in selected_points])
214
- .to(self.predictor.device)
215
- .unsqueeze(1)
216
- )
217
-
218
- transformed_points = self.predictor.transform.apply_coords_torch(
219
- points, image.shape[:2]
220
- )
221
-
222
- masks, scores, _ = self.predictor.predict_torch(
223
- point_coords=transformed_points,
224
- point_labels=labels,
225
- multimask_output=True,
226
- )
227
- valid_mask = masks[:, torch.argmax(scores, dim=1)]
228
- masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
229
- masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
230
- if len(masks_neg) == 0:
231
- masks_neg = np.zeros_like(masks_pos)
232
- if len(masks_pos) == 0:
233
- masks_pos = np.zeros_like(masks_neg)
234
- masks_neg = masks_neg.max(axis=0, keepdims=True)
235
- masks_pos = masks_pos.max(axis=0, keepdims=True)
236
- valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
237
-
238
- binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
239
-
240
- return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
241
-
242
- def get_segmented_image(
243
- self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
244
- ) -> Image.Image:
245
- seg_image = Image.fromarray(image, mode="RGB")
246
- alpha_channel = np.zeros(
247
- (seg_image.height, seg_image.width), dtype=np.uint8
248
- )
249
- for mask, _ in masks:
250
- # Use the maximum to combine multiple masks
251
- alpha_channel = np.maximum(alpha_channel, mask)
252
-
253
- alpha_channel = np.clip(alpha_channel, 0, 1)
254
- alpha_channel = (alpha_channel * 255).astype(np.uint8)
255
- alpha_image = Image.fromarray(alpha_channel, mode="L")
256
- r, g, b = seg_image.split()
257
- seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
258
-
259
- return seg_image
260
-
261
- def __call__(
262
- self,
263
- image: Union[str, Image.Image, np.ndarray],
264
- selected_points: list[list[int]],
265
- ) -> Image.Image:
266
- image = self.preprocess_image(image)
267
- self.predictor.set_image(image)
268
- masks = self.generate_masks(image, selected_points)
269
-
270
- return self.get_segmented_image(image, masks)
271
-
272
-
273
- class RembgRemover(object):
274
- def __init__(self):
275
- self.rembg_session = rembg.new_session("u2net")
276
-
277
- def __call__(
278
- self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
279
- ) -> Image.Image:
280
- if isinstance(image, str):
281
- image = Image.open(image)
282
- elif isinstance(image, np.ndarray):
283
- image = Image.fromarray(image)
284
-
285
- image = resize_pil(image)
286
- output_image = rembg.remove(image, session=self.rembg_session)
287
-
288
- if save_path is not None:
289
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
290
- output_image.save(save_path)
291
-
292
- return output_image
293
-
294
-
295
- def invert_rgba_pil(
296
- image: Image.Image, mask: Image.Image, save_path: str = None
297
- ) -> Image.Image:
298
- mask = (255 - np.array(mask))[..., None]
299
- image_array = np.concatenate([np.array(image), mask], axis=-1)
300
- inverted_image = Image.fromarray(image_array, "RGBA")
301
-
302
- if save_path is not None:
303
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
304
- inverted_image.save(save_path)
305
-
306
- return inverted_image
307
-
308
-
309
- def get_segmented_image(
310
- image: Image.Image,
311
- sam_remover: SAMRemover,
312
- rbg_remover: RembgRemover,
313
- seg_checker: ImageSegChecker = None,
314
- save_path: str = None,
315
- mode: Literal["loose", "strict"] = "loose",
316
- ) -> Image.Image:
317
- def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
318
- if seg_checker is None:
319
- return True
320
- return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]
321
-
322
- out_sam = f"{save_path}_sam.png" if save_path else None
323
- out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
324
- out_rbg = f"{save_path}_rbg.png" if save_path else None
325
-
326
- seg_image = sam_remover(image, out_sam)
327
- seg_image = seg_image.convert("RGBA")
328
- _, _, _, alpha = seg_image.split()
329
- seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
330
- seg_image_rbg = rbg_remover(image, out_rbg)
331
-
332
- final_image = None
333
- if _is_valid_seg(image, seg_image):
334
- final_image = seg_image
335
- elif _is_valid_seg(image, seg_image_inv):
336
- final_image = seg_image_inv
337
- elif _is_valid_seg(image, seg_image_rbg):
338
- logger.warning(f"Failed to segment by `SAM`, retry with `rembg`.")
339
- final_image = seg_image_rbg
340
- else:
341
- if mode == "strict":
342
- raise RuntimeError(
343
- f"Failed to segment by `SAM` or `rembg`, abort."
344
- )
345
- logger.warning("Failed to segment by SAM or rembg, use raw image.")
346
- final_image = image.convert("RGBA")
347
-
348
- if save_path:
349
- final_image.save(save_path)
350
-
351
- final_image = trellis_preprocess(final_image)
352
-
353
- return final_image
354
-
355
-
356
- if __name__ == "__main__":
357
- input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
358
- output_image = "sample_0_seg2.png"
359
-
360
- # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
361
- # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
362
-
363
- # input_image = "outputs/text2image/tmp/bucket.jpeg"
364
- # output_image = "outputs/text2image/tmp/bucket_seg.png"
365
-
366
- remover = SAMRemover(
367
- # checkpoint="/horizon-bucket/robot_lab/users/xinjie.wang/weights/sam/sam_vit_h_4b8939.pth", # noqa
368
- model_type="vit_h",
369
- )
370
- remover = RembgRemover()
371
- # clean_image = remover(input_image)
372
- # clean_image.save(output_image)
373
- get_segmented_image(
374
- Image.open(input_image), remover, remover, None, "./test_seg.png"
375
- )
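The segmentation module above can be driven either fully automatically (SAMRemover / RembgRemover) or from user clicks (SAMPredictor). A minimal sketch, assuming the default checkpoints can be fetched and using placeholder image paths; click coordinates refer to the image after resize_pil has capped its longer side at 1024 px.

from PIL import Image

from asset3d_gen.models.segment import RembgRemover, SAMPredictor

# Fully automatic matting with rembg (u2net).
remover = RembgRemover()
rgba = remover("path/to/object.jpg", save_path="outputs/object_rgba.png")

# Click-driven segmentation: one positive click ([x, y], label=1) on the object.
predictor = SAMPredictor(model_type="vit_h")
image = Image.open("path/to/object.jpg")
seg = predictor(image, selected_points=[[[256, 256], 1]])
seg.save("outputs/object_sam_rgba.png")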
asset3d_gen/models/super_resolution.py DELETED
@@ -1,116 +0,0 @@
1
- import logging
2
- import os
3
- from typing import Union
4
-
5
- import numpy as np
6
- import torch
7
- from huggingface_hub import snapshot_download
8
- from PIL import Image
9
- from asset3d_gen.data.utils import get_images_from_grid
10
-
11
- logging.basicConfig(
12
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
13
- )
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- __all__ = [
18
- "ImageStableSR",
19
- "ImageRealESRGAN",
20
- ]
21
-
22
-
23
- class ImageStableSR:
24
- def __init__(
25
- self,
26
- model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
27
- device="cuda",
28
- ) -> None:
29
- from diffusers import StableDiffusionUpscalePipeline
30
-
31
- self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
32
- model_path,
33
- torch_dtype=torch.float16,
34
- ).to(device)
35
- self.up_pipeline_x4.set_progress_bar_config(disable=True)
36
- # self.up_pipeline_x4.enable_model_cpu_offload()
37
-
38
- def __call__(
39
- self,
40
- image: Union[Image.Image, np.ndarray],
41
- prompt: str = "",
42
- infer_step: int = 20,
43
- ) -> Image.Image:
44
- if isinstance(image, np.ndarray):
45
- image = Image.fromarray(image)
46
-
47
- image = image.convert("RGB")
48
-
49
- with torch.no_grad():
50
- upscaled_image = self.up_pipeline_x4(
51
- image=image,
52
- prompt=[prompt],
53
- num_inference_steps=infer_step,
54
- ).images[0]
55
-
56
- return upscaled_image
57
-
58
-
59
- class ImageRealESRGAN:
60
- def __init__(self, outscale: int, model_path: str = None) -> None:
61
- from basicsr.archs.rrdbnet_arch import RRDBNet
62
- from realesrgan import RealESRGANer
63
-
64
- self.outscale = outscale
65
- model = RRDBNet(
66
- num_in_ch=3,
67
- num_out_ch=3,
68
- num_feat=64,
69
- num_block=23,
70
- num_grow_ch=32,
71
- scale=4,
72
- )
73
- if model_path is None:
74
- suffix = "super_resolution"
75
- model_path = snapshot_download(
76
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
77
- )
78
- model_path = os.path.join(
79
- model_path, suffix, "RealESRGAN_x4plus.pth"
80
- )
81
-
82
- self.upsampler = RealESRGANer(
83
- scale=4,
84
- model_path=model_path,
85
- model=model,
86
- pre_pad=0,
87
- half=True,
88
- )
89
-
90
- def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
91
- if isinstance(image, Image.Image):
92
- image = np.array(image)
93
-
94
- with torch.no_grad():
95
- output, _ = self.upsampler.enhance(image, outscale=self.outscale)
96
-
97
- return Image.fromarray(output)
98
-
99
-
100
- if __name__ == "__main__":
101
- color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
102
-
103
- # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
104
- # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth" # noqa
105
- super_model = ImageRealESRGAN(outscale=4)
106
- multiviews = get_images_from_grid(color_path, img_size=512)
107
- multiviews = [super_model(img.convert("RGB")) for img in multiviews]
108
- for idx, img in enumerate(multiviews):
109
- img.save(f"sr{idx}.png")
110
-
111
- # # Use stable diffusion for x4 (512->2048) image super resolution.
112
- # super_model = ImageStableSR()
113
- # multiviews = get_images_from_grid(color_path, img_size=512)
114
- # multiviews = [super_model(img) for img in multiviews]
115
- # for idx, img in enumerate(multiviews):
116
- # img.save(f"sr_stable{idx}.png")
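A minimal sketch of upscaling a single render with the Real-ESRGAN wrapper above, assuming the default RealESRGAN_x4plus weights resolve via snapshot_download and a CUDA GPU is available; the image path is a placeholder.

from PIL import Image

from asset3d_gen.models.super_resolution import ImageRealESRGAN

sr_model = ImageRealESRGAN(outscale=4)  # 4x upscaling, e.g. 512 -> 2048
lowres = Image.open("path/to/render_512.png").convert("RGB")
highres = sr_model(lowres)
highres.save("render_2048.png")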
asset3d_gen/models/text_model.py DELETED
@@ -1,143 +0,0 @@
1
- import logging
2
-
3
- import torch
4
- from diffusers import (
5
- AutoencoderKL,
6
- EulerDiscreteScheduler,
7
- UNet2DConditionModel,
8
- )
9
- from kolors.models.modeling_chatglm import ChatGLMModel
10
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
11
- from kolors.models.unet_2d_condition import (
12
- UNet2DConditionModel as UNet2DConditionModelIP,
13
- )
14
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
15
- StableDiffusionXLPipeline,
16
- )
17
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
18
- StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
19
- )
20
- from PIL import Image
21
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
22
-
23
- logging.basicConfig(level=logging.INFO)
24
- logger = logging.getLogger(__name__)
25
-
26
-
27
- __all__ = [
28
- "build_text2img_ip_pipeline",
29
- "build_text2img_pipeline",
30
- "text2img_gen",
31
- ]
32
-
33
-
34
- def build_text2img_ip_pipeline(
35
- ckpt_dir: str,
36
- ref_scale: float,
37
- device: str = "cuda",
38
- ) -> StableDiffusionXLPipelineIP:
39
- text_encoder = ChatGLMModel.from_pretrained(
40
- f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
41
- ).half()
42
- tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
43
- vae = AutoencoderKL.from_pretrained(
44
- f"{ckpt_dir}/vae", revision=None
45
- ).half()
46
- scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
47
- unet = UNet2DConditionModelIP.from_pretrained(
48
- f"{ckpt_dir}/unet", revision=None
49
- ).half()
50
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
51
- f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
52
- ignore_mismatched_sizes=True,
53
- ).to(dtype=torch.float16)
54
- clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
55
-
56
- pipe = StableDiffusionXLPipelineIP(
57
- vae=vae,
58
- text_encoder=text_encoder,
59
- tokenizer=tokenizer,
60
- unet=unet,
61
- scheduler=scheduler,
62
- image_encoder=image_encoder,
63
- feature_extractor=clip_image_processor,
64
- force_zeros_for_empty_prompt=False,
65
- )
66
-
67
- if hasattr(pipe.unet, "encoder_hid_proj"):
68
- pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
69
-
70
- pipe.load_ip_adapter(
71
- f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
72
- subfolder="",
73
- weight_name=["ip_adapter_plus_general.bin"],
74
- )
75
- pipe.set_ip_adapter_scale([ref_scale])
76
-
77
- pipe = pipe.to(device)
78
- # pipe.enable_model_cpu_offload()
79
- # pipe.enable_xformers_memory_efficient_attention()
80
- # pipe.enable_vae_slicing()
81
-
82
- return pipe
83
-
84
-
85
- def build_text2img_pipeline(
86
- ckpt_dir: str,
87
- device: str = "cuda",
88
- ) -> StableDiffusionXLPipeline:
89
- text_encoder = ChatGLMModel.from_pretrained(
90
- f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
91
- ).half()
92
- tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
93
- vae = AutoencoderKL.from_pretrained(
94
- f"{ckpt_dir}/vae", revision=None
95
- ).half()
96
- scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
97
- unet = UNet2DConditionModel.from_pretrained(
98
- f"{ckpt_dir}/unet", revision=None
99
- ).half()
100
- pipe = StableDiffusionXLPipeline(
101
- vae=vae,
102
- text_encoder=text_encoder,
103
- tokenizer=tokenizer,
104
- unet=unet,
105
- scheduler=scheduler,
106
- force_zeros_for_empty_prompt=False,
107
- )
108
- pipe = pipe.to(device)
109
- # pipe.enable_model_cpu_offload()
110
- # pipe.enable_xformers_memory_efficient_attention()
111
-
112
- return pipe
113
-
114
-
115
- def text2img_gen(
116
- prompt: str,
117
- n_sample: int,
118
- guidance_scale: float,
119
- pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
120
- ip_image: Image.Image | str = None,
121
- image_wh: tuple[int, int] = [1024, 1024],
122
- infer_step: int = 50,
123
- ip_image_size: int = 512,
124
- ) -> list[Image.Image]:
125
- prompt = "Single " + prompt + ", in the center of the image"
126
- prompt += ", high quality, high resolution, best quality, white background, 3D style," # noqa
127
- logger.info(f"Processing prompt: {prompt}")
128
-
129
- kwargs = dict(
130
- prompt=prompt,
131
- height=image_wh[1],
132
- width=image_wh[0],
133
- num_inference_steps=infer_step,
134
- guidance_scale=guidance_scale,
135
- num_images_per_prompt=n_sample,
136
- )
137
- if ip_image is not None:
138
- if isinstance(ip_image, str):
139
- ip_image = Image.open(ip_image)
140
- ip_image = ip_image.resize((ip_image_size, ip_image_size))
141
- kwargs.update(ip_adapter_image=[ip_image])
142
-
143
- return pipeline(**kwargs).images
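A minimal sketch of driving the Kolors text-to-image helpers above, assuming the base weights live under ./weights/Kolors (with ../Kolors-IP-Adapter-Plus next to them for the IP-adapter variant); the prompt and output names are placeholders.

from asset3d_gen.models.text_model import build_text2img_pipeline, text2img_gen

pipe = build_text2img_pipeline("weights/Kolors")
images = text2img_gen(
    prompt="wooden chair",
    n_sample=2,
    guidance_scale=12.0,
    pipeline=pipe,
    infer_step=50,
)
for i, img in enumerate(images):
    img.save(f"chair_{i}.png")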
asset3d_gen/models/texture_model.py DELETED
@@ -1,91 +0,0 @@
1
- import os
2
-
3
- import torch
4
- from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
5
- from huggingface_hub import snapshot_download
6
- from kolors.models.controlnet import ControlNetModel
7
- from kolors.models.modeling_chatglm import ChatGLMModel
8
- from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9
- from kolors.models.unet_2d_condition import UNet2DConditionModel
10
- from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
11
- StableDiffusionXLControlNetImg2ImgPipeline,
12
- )
13
- from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
14
-
15
- __all__ = [
16
- "build_texture_gen_pipe",
17
- ]
18
-
19
-
20
- def build_texture_gen_pipe(
21
- base_ckpt_dir: str,
22
- controlnet_ckpt: str = None,
23
- ip_adapt_scale: float = 0,
24
- device: str = "cuda",
25
- ) -> DiffusionPipeline:
26
- tokenizer = ChatGLMTokenizer.from_pretrained(
27
- f"{base_ckpt_dir}/Kolors/text_encoder"
28
- )
29
- text_encoder = ChatGLMModel.from_pretrained(
30
- f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
31
- ).half()
32
- vae = AutoencoderKL.from_pretrained(
33
- f"{base_ckpt_dir}/Kolors/vae", revision=None
34
- ).half()
35
- unet = UNet2DConditionModel.from_pretrained(
36
- f"{base_ckpt_dir}/Kolors/unet", revision=None
37
- ).half()
38
- scheduler = EulerDiscreteScheduler.from_pretrained(
39
- f"{base_ckpt_dir}/Kolors/scheduler"
40
- )
41
-
42
- if controlnet_ckpt is None:
43
- suffix = "geo_cond_mv"
44
- model_path = snapshot_download(
45
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
46
- )
47
- controlnet_ckpt = os.path.join(model_path, suffix)
48
-
49
- controlnet = ControlNetModel.from_pretrained(
50
- controlnet_ckpt, use_safetensors=True
51
- ).half()
52
-
53
- # IP-Adapter model
54
- image_encoder = None
55
- clip_image_processor = None
56
- if ip_adapt_scale > 0:
57
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
58
- f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
59
- # ignore_mismatched_sizes=True,
60
- ).to(dtype=torch.float16)
61
- ip_img_size = 336
62
- clip_image_processor = CLIPImageProcessor(
63
- size=ip_img_size, crop_size=ip_img_size
64
- )
65
-
66
- pipe = StableDiffusionXLControlNetImg2ImgPipeline(
67
- vae=vae,
68
- controlnet=controlnet,
69
- text_encoder=text_encoder,
70
- tokenizer=tokenizer,
71
- unet=unet,
72
- scheduler=scheduler,
73
- image_encoder=image_encoder,
74
- feature_extractor=clip_image_processor,
75
- force_zeros_for_empty_prompt=False,
76
- )
77
-
78
- if ip_adapt_scale > 0:
79
- if hasattr(pipe.unet, "encoder_hid_proj"):
80
- pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
81
- pipe.load_ip_adapter(
82
- f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
83
- subfolder="",
84
- weight_name=["ip_adapter_plus_general.bin"],
85
- )
86
- pipe.set_ip_adapter_scale([ip_adapt_scale])
87
-
88
- pipe = pipe.to(device)
89
- # pipe.enable_model_cpu_offload()
90
-
91
- return pipe
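A minimal sketch of building the texture-generation pipeline above, assuming the Kolors base weights sit under ./weights and the default ControlNet falls back to the hosted geo_cond_mv snapshot.

from asset3d_gen.models.texture_model import build_texture_gen_pipe

pipe = build_texture_gen_pipe(
    base_ckpt_dir="./weights",
    controlnet_ckpt=None,  # None -> download the geo_cond_mv checkpoint
    ip_adapt_scale=0,      # > 0 additionally loads the IP-Adapter branch
    device="cuda",
)
# The resulting pipe is what render_mv.py (below) consumes for multi-view texturing.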
asset3d_gen/scripts/render_gs.py DELETED
@@ -1,156 +0,0 @@
1
- import argparse
2
- import logging
3
- import math
4
- import os
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- from tqdm import tqdm
10
- from asset3d_gen.data.utils import (
11
- CameraSetting,
12
- init_kal_camera,
13
- normalize_vertices_array,
14
- )
15
- from asset3d_gen.models.gs_model import GaussianOperator
16
-
17
- logging.basicConfig(
18
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
19
- )
20
- logger = logging.getLogger(__name__)
21
-
22
-
23
- def parse_args():
24
- parser = argparse.ArgumentParser(description="Render GS color images")
25
-
26
- parser.add_argument(
27
- "--input_gs", type=str, help="Input render GS.ply path."
28
- )
29
- parser.add_argument(
30
- "--output_path",
31
- type=str,
32
- help="Output grid image path for rendered GS color images.",
33
- )
34
- parser.add_argument(
35
- "--num_images", type=int, default=6, help="Number of images to render."
36
- )
37
- parser.add_argument(
38
- "--elevation",
39
- type=float,
40
- nargs="+",
41
- default=[20.0, -10.0],
42
- help="Elevation angles for the camera (default: [20.0, -10.0])",
43
- )
44
- parser.add_argument(
45
- "--distance",
46
- type=float,
47
- default=5,
48
- help="Camera distance (default: 5)",
49
- )
50
- parser.add_argument(
51
- "--resolution_hw",
52
- type=int,
53
- nargs=2,
54
- default=(512, 512),
55
- help="Resolution of the output images (default: (512, 512))",
56
- )
57
- parser.add_argument(
58
- "--fov",
59
- type=float,
60
- default=30,
61
- help="Field of view in degrees (default: 30)",
62
- )
63
- parser.add_argument(
64
- "--device",
65
- type=str,
66
- choices=["cpu", "cuda"],
67
- default="cuda",
68
- help="Device to run on (default: `cuda`)",
69
- )
70
- parser.add_argument(
71
- "--image_size",
72
- type=int,
73
- default=512,
74
- help="Output image size for single view in color grid (default: 512)",
75
- )
76
-
77
- args = parser.parse_args()
78
-
79
- return args
80
-
81
-
82
- def load_gs_model(
83
- input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
84
- ) -> GaussianOperator:
85
- gs_model = GaussianOperator.load_from_ply(input_gs)
86
- # Normalize vertices to [-1, 1], center to (0, 0, 0).
87
- _, scale, center = normalize_vertices_array(gs_model._means)
88
- scale, center = float(scale), center.tolist()
89
- transpose = [*[-v for v in center], *pre_quat]
90
- instance_pose = torch.tensor(transpose).to(gs_model.device)
91
- gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
92
- gs_model.rescale(scale)
93
-
94
- return gs_model
95
-
96
-
97
- def entrypoint(input_gs: str = None, output_path: str = None) -> None:
98
- args = parse_args()
99
- if isinstance(input_gs, str):
100
- args.input_gs = input_gs
101
- if isinstance(output_path, str):
102
- args.output_path = output_path
103
-
104
- # Setup camera parameters
105
- camera_params = CameraSetting(
106
- num_images=args.num_images,
107
- elevation=args.elevation,
108
- distance=args.distance,
109
- resolution_hw=args.resolution_hw,
110
- fov=math.radians(args.fov),
111
- device=args.device,
112
- )
113
- camera = init_kal_camera(camera_params)
114
- matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
115
- matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
116
- w2cs = matrix_mv.to(camera_params.device)
117
- c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
118
- Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
119
-
120
- # Load GS model and normalize.
121
- gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])
122
-
123
- # Render GS color images.
124
- images = []
125
- for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
126
- result = gs_model.render(
127
- c2ws[idx],
128
- Ks=Ks,
129
- image_width=camera_params.resolution_hw[1],
130
- image_height=camera_params.resolution_hw[0],
131
- )
132
- color = cv2.resize(
133
- result.rgba,
134
- (args.image_size, args.image_size),
135
- interpolation=cv2.INTER_AREA,
136
- )
137
- images.append(color)
138
-
139
- # Cat color images into grid image and save.
140
- select_idxs = [[0, 2, 1], [5, 4, 3]] # fix order for 6 views
141
- grid_image = []
142
- for row_idxs in select_idxs:
143
- row_image = []
144
- for row_idx in row_idxs:
145
- row_image.append(images[row_idx])
146
- row_image = np.concatenate(row_image, axis=1)
147
- grid_image.append(row_image)
148
-
149
- grid_image = np.concatenate(grid_image, axis=0)
150
- os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
151
- cv2.imwrite(args.output_path, grid_image)
152
- logger.info(f"Saved grid image to {args.output_path}")
153
-
154
-
155
- if __name__ == "__main__":
156
- entrypoint()
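A minimal sketch of calling the GS renderer above from Python, assuming the input .ply exists; note that entrypoint() still runs argparse, so unspecified options take their CLI defaults and stray flags in sys.argv would interfere. The paths are placeholders.

from asset3d_gen.scripts.render_gs import entrypoint

entrypoint(
    input_gs="outputs/test/gs_model.ply",
    output_path="outputs/test/gs_color_grid.png",  # 2x3 grid of rendered views
)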
asset3d_gen/scripts/render_mv.py DELETED
@@ -1,183 +0,0 @@
1
- import logging
2
- import os
3
- import random
4
- from typing import List, Tuple
5
-
6
- import fire
7
- import numpy as np
8
- import torch
9
- from diffusers.utils import make_image_grid
10
- from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
11
- StableDiffusionXLControlNetImg2ImgPipeline,
12
- )
13
- from PIL import Image, ImageEnhance, ImageFilter
14
- from torchvision import transforms
15
- from asset3d_gen.data.datasets import Asset3dGenDataset
16
- from asset3d_gen.models.texture_model import build_texture_gen_pipe
17
-
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- def get_init_noise_image(image: Image.Image) -> Image.Image:
23
- blurred_image = image.convert("L").filter(
24
- ImageFilter.GaussianBlur(radius=3)
25
- )
26
-
27
- enhancer = ImageEnhance.Contrast(blurred_image)
28
- image_decreased_contrast = enhancer.enhance(factor=0.5)
29
-
30
- return image_decreased_contrast
31
-
32
-
33
- def infer_pipe(
34
- index_file: str,
35
- controlnet_ckpt: str = None,
36
- uid: str = None,
37
- prompt: str = None,
38
- controlnet_cond_scale: float = 0.4,
39
- control_guidance_end: float = 0.9,
40
- strength: float = 1.0,
41
- num_inference_steps: int = 50,
42
- guidance_scale: float = 10,
43
- ip_adapt_scale: float = 0,
44
- ip_img_path: str = None,
45
- sub_idxs: List[List[int]] = None,
46
- num_images_per_prompt: int = 3, # increase if you want more similar images.
47
- device: str = "cuda",
48
- save_dir: str = "infer_vis",
49
- seed: int = None,
50
- target_hw: tuple[int, int] = (512, 512),
51
- pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
52
- ) -> str:
53
- # sub_idxs = [[0, 1, 2], [3, 4, 5]] # None for single image.
54
- if sub_idxs is None:
55
- sub_idxs = [[random.randint(0, 5)]] # 6 views.
56
- target_hw = [2 * size for size in target_hw]
57
-
58
- transform_list = [
59
- transforms.Resize(
60
- target_hw, interpolation=transforms.InterpolationMode.BILINEAR
61
- ),
62
- transforms.CenterCrop(target_hw),
63
- transforms.ToTensor(),
64
- transforms.Normalize([0.5], [0.5]),
65
- ]
66
- image_transform = transforms.Compose(transform_list)
67
- control_transform = transforms.Compose(transform_list[:-1])
68
-
69
- grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
70
- dataset = Asset3dGenDataset(
71
- index_file, target_hw=grid_hw, sub_idxs=sub_idxs
72
- )
73
-
74
- if uid is None:
75
- uid = random.choice(list(dataset.meta_info.keys()))
76
- if prompt is None:
77
- prompt = dataset.meta_info[uid]["capture"]
78
- if isinstance(prompt, List) or isinstance(prompt, Tuple):
79
- prompt = ", ".join(map(str, prompt))
80
- # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
81
- # prompt += "高品质,清晰,细节"  # i.e. "high quality, clear, detailed"
82
- prompt += ", high quality, high resolution, best quality"
83
- # prompt += ", with diffuse lighting, showing no reflections."
84
- logger.info(f"Inference with prompt: {prompt}")
85
-
86
- negative_prompt = (
87
- "nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯,高光,镜面反射"  # "nsfw, face shadows, low resolution, jpeg artifacts, blurry, poor quality, dark face, neon lights, highlights, specular reflections"
88
- )
89
-
90
- control_image = dataset.fetch_sample_grid_images(
91
- uid,
92
- attrs=["image_view_normal", "image_position", "image_mask"],
93
- sub_idxs=sub_idxs,
94
- transform=control_transform,
95
- )
96
-
97
- color_image = dataset.fetch_sample_grid_images(
98
- uid,
99
- attrs=["image_color"],
100
- sub_idxs=sub_idxs,
101
- transform=image_transform,
102
- )
103
-
104
- normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
105
- control_image,
106
- color_image,
107
- save_dir=save_dir,
108
- )
109
-
110
- if pipeline is None:
111
- pipeline = build_texture_gen_pipe(
112
- base_ckpt_dir="./weights",
113
- controlnet_ckpt=controlnet_ckpt,
114
- ip_adapt_scale=ip_adapt_scale,
115
- device=device,
116
- )
117
-
118
- if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
119
- ip_image = Image.open(ip_img_path).convert("RGB")
120
- ip_image = ip_image.resize(target_hw[::-1])
121
- ip_image = [ip_image]
122
- pipeline.set_ip_adapter_scale([ip_adapt_scale])
123
- else:
124
- ip_image = None
125
-
126
- generator = None
127
- if seed is not None:
128
- generator = torch.Generator(device).manual_seed(seed)
129
- torch.manual_seed(seed)
130
- np.random.seed(seed)
131
- random.seed(seed)
132
-
133
- init_image = get_init_noise_image(normal_pil)
134
- # init_image = get_init_noise_image(color_pil)
135
-
136
- images = []
137
- row_num, col_num = 2, 3
138
- img_save_paths = []
139
- while len(images) < col_num:
140
- image = pipeline(
141
- prompt=prompt,
142
- image=init_image,
143
- controlnet_conditioning_scale=controlnet_cond_scale,
144
- control_guidance_end=control_guidance_end,
145
- strength=strength,
146
- control_image=control_image[None, ...],
147
- negative_prompt=negative_prompt,
148
- num_inference_steps=num_inference_steps,
149
- guidance_scale=guidance_scale,
150
- num_images_per_prompt=num_images_per_prompt,
151
- ip_adapter_image=ip_image,
152
- generator=generator,
153
- ).images
154
- images.extend(image)
155
-
156
- grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
157
- # save_dir = os.path.join(save_dir, uid)
158
- os.makedirs(save_dir, exist_ok=True)
159
-
160
- for idx in range(col_num):
161
- rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
162
- img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
163
- rgba_image.save(img_save_path)
164
- img_save_paths.append(img_save_path)
165
-
166
- sub_idxs = "_".join(
167
- [str(item) for sublist in sub_idxs for item in sublist]
168
- )
169
- save_path = os.path.join(
170
- save_dir, f"sample_idx{str(sub_idxs)}_ip{ip_adapt_scale}.jpg"
171
- )
172
- make_image_grid(grid_image, row_num, col_num).save(save_path)
173
- logger.info(f"Visualize in {save_path}")
174
-
175
- return img_save_paths
176
-
177
-
178
- def entrypoint() -> None:
179
- fire.Fire(infer_pipe)
180
-
181
-
182
- if __name__ == "__main__":
183
- entrypoint()
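A minimal sketch of running the multi-view texture inference above, assuming a hypothetical Asset3dGenDataset index file and the default texture ControlNet weights; passing a prebuilt pipeline avoids reloading weights between calls.

from asset3d_gen.models.texture_model import build_texture_gen_pipe
from asset3d_gen.scripts.render_mv import infer_pipe

pipe = build_texture_gen_pipe(base_ckpt_dir="./weights", device="cuda")
image_paths = infer_pipe(
    index_file="data/index.json",     # hypothetical dataset index
    prompt="a red ceramic coffee mug",
    sub_idxs=[[0, 1, 2], [3, 4, 5]],  # all 6 views laid out as a 2x3 grid
    seed=0,
    save_dir="outputs/texture_mesh_gen/multi_view",
    pipeline=pipe,
)
print(image_paths)  # per-view RGBA color samples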
asset3d_gen/scripts/text2image.py DELETED
@@ -1,145 +0,0 @@
1
- import argparse
2
- import logging
3
- import os
4
-
5
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
6
- StableDiffusionXLPipeline,
7
- )
8
- from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
9
- StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
10
- )
11
- from tqdm import tqdm
12
- from asset3d_gen.models.text_model import (
13
- build_text2img_ip_pipeline,
14
- build_text2img_pipeline,
15
- text2img_gen,
16
- )
17
-
18
- logging.basicConfig(level=logging.INFO)
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- def parse_args():
23
- parser = argparse.ArgumentParser(description="Text to Image.")
24
- parser.add_argument(
25
- "--prompts",
26
- type=str,
27
- nargs="+",
28
- help="List of prompts (space-separated).",
29
- )
30
- parser.add_argument(
31
- "--ref_image",
32
- type=str,
33
- nargs="+",
34
- help="List of ref_image paths (space-separated).",
35
- )
36
- parser.add_argument(
37
- "--output_root",
38
- type=str,
39
- help="Root directory for saving outputs.",
40
- )
41
- parser.add_argument(
42
- "--guidance_scale",
43
- type=float,
44
- default=12.0,
45
- help="Guidance scale for the diffusion model.",
46
- )
47
- parser.add_argument(
48
- "--ref_scale",
49
- type=float,
50
- default=0.3,
51
- help="Reference image scale for the IP adapter.",
52
- )
53
- parser.add_argument(
54
- "--n_sample",
55
- type=int,
56
- default=1,
57
- )
58
- parser.add_argument(
59
- "--resolution",
60
- type=int,
61
- default=1024,
62
- )
63
- parser.add_argument(
64
- "--infer_step",
65
- type=int,
66
- default=50,
67
- )
68
- args = parser.parse_args()
69
-
70
- return args
71
-
72
-
73
- def entrypoint(
74
- pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
75
- **kwargs,
76
- ) -> list[str]:
77
- args = parse_args()
78
- for k, v in kwargs.items():
79
- if hasattr(args, k) and v is not None:
80
- setattr(args, k, v)
81
-
82
- prompts = args.prompts
83
- if len(prompts) == 1 and prompts[0].endswith(".txt"):
84
- with open(prompts[0], "r") as f:
85
- prompts = f.readlines()
86
- prompts = [
87
- prompt.strip() for prompt in prompts if prompt.strip() != ""
88
- ]
89
-
90
- os.makedirs(args.output_root, exist_ok=True)
91
-
92
- ip_img_paths = args.ref_image
93
- if ip_img_paths is None or len(ip_img_paths) == 0:
94
- args.ref_scale = 0
95
- ip_img_paths = [None] * len(prompts)
96
- elif isinstance(ip_img_paths, str):
97
- ip_img_paths = [ip_img_paths] * len(prompts)
98
- elif isinstance(ip_img_paths, list):
99
- if len(ip_img_paths) == 1:
100
- ip_img_paths = ip_img_paths * len(prompts)
101
- else:
102
- raise ValueError("Invalid ref_image paths.")
103
- assert len(ip_img_paths) == len(
104
- prompts
105
- ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}" # noqa
106
-
107
- if pipeline is None:
108
- if args.ref_scale > 0:
109
- pipeline = build_text2img_ip_pipeline(
110
- "weights/Kolors",
111
- ref_scale=args.ref_scale,
112
- )
113
- else:
114
- pipeline = build_text2img_pipeline("weights/Kolors")
115
-
116
- for idx, (prompt, ip_img_path) in tqdm(
117
- enumerate(zip(prompts, ip_img_paths)),
118
- desc="Generating images",
119
- total=len(prompts),
120
- ):
121
- images = text2img_gen(
122
- prompt=prompt,
123
- n_sample=args.n_sample,
124
- guidance_scale=args.guidance_scale,
125
- pipeline=pipeline,
126
- ip_image=ip_img_path,
127
- image_wh=[args.resolution, args.resolution],
128
- infer_step=args.infer_step,
129
- )
130
-
131
- save_paths = []
132
- for sub_idx, image in enumerate(images):
133
- save_path = (
134
- f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
135
- )
136
- image.save(save_path)
137
- save_paths.append(save_path)
138
-
139
- logger.info(f"Images saved to {args.output_root}")
140
-
141
- return save_paths
142
-
143
-
144
- if __name__ == "__main__":
145
- entrypoint()
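A minimal sketch of calling the text-to-image entrypoint above programmatically, assuming Kolors weights under ./weights/Kolors; keyword overrides take precedence over parsed CLI flags, but argparse still runs, so stray flags in sys.argv would interfere.

from asset3d_gen.scripts.text2image import entrypoint

paths = entrypoint(
    prompts=["a small wooden stool"],
    output_root="outputs/text2image/demo",
    n_sample=2,
    guidance_scale=12.0,
)
print(paths)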
asset3d_gen/utils/gpt_clients.py DELETED
@@ -1,191 +0,0 @@
1
- import base64
2
- import logging
3
- import os
4
- from io import BytesIO
5
- from typing import Optional
6
-
7
- from openai import AzureOpenAI, OpenAI # pip install openai
8
- from PIL import Image
9
- from tenacity import (
10
- retry,
11
- stop_after_attempt,
12
- stop_after_delay,
13
- wait_random_exponential,
14
- )
15
- from asset3d_gen.utils.process_media import combine_images_to_base64
16
-
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
- class GPTclient:
22
- """A client to interact with the GPT model via OpenAI or Azure API."""
23
-
24
- def __init__(
25
- self,
26
- endpoint: str,
27
- api_key: str,
28
- model_name: str = "yfb-gpt-4o",
29
- api_version: str = None,
30
- verbose: bool = False,
31
- ):
32
- if api_version is not None:
33
- self.client = AzureOpenAI(
34
- azure_endpoint=endpoint,
35
- api_key=api_key,
36
- api_version=api_version,
37
- )
38
- else:
39
- self.client = OpenAI(
40
- base_url=endpoint,
41
- api_key=api_key,
42
- )
43
-
44
- self.endpoint = endpoint
45
- self.model_name = model_name
46
- self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
47
- self.verbose = verbose
48
-
49
- @retry(
50
- wait=wait_random_exponential(min=1, max=20),
51
- stop=(stop_after_attempt(10) | stop_after_delay(30)),
52
- )
53
- def completion_with_backoff(self, **kwargs):
54
- return self.client.chat.completions.create(**kwargs)
55
-
56
- def query(
57
- self,
58
- text_prompt: str,
59
- image_base64: Optional[list[str | Image.Image]] = None,
60
- system_role: Optional[str] = None,
61
- ) -> Optional[str]:
62
- """Queries the GPT model with a text and optional image prompts.
63
-
64
- Args:
65
- text_prompt (str): The main text input that the model responds to.
66
- image_base64 (Optional[List[str]]): A list of image base64 strings
67
- or local image paths or PIL.Image to accompany the text prompt.
68
- system_role (Optional[str]): Optional system-level instructions
69
- that specify the behavior of the assistant.
70
-
71
- Returns:
72
- Optional[str]: The response content generated by the model based on
73
- the prompt. Returns `None` if an error occurs.
74
- """
75
- if system_role is None:
76
- system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
77
-
78
- content_user = [
79
- {
80
- "type": "text",
81
- "text": text_prompt,
82
- },
83
- ]
84
-
85
- # Process images if provided
86
- if image_base64 is not None:
87
- image_base64 = (
88
- image_base64
89
- if isinstance(image_base64, list)
90
- else [image_base64]
91
- )
92
- for img in image_base64:
93
- if isinstance(img, Image.Image):
94
- buffer = BytesIO()
95
- img.save(buffer, format=img.format or "PNG")
96
- buffer.seek(0)
97
- image_binary = buffer.read()
98
- img = base64.b64encode(image_binary).decode("utf-8")
99
- elif (
100
- len(os.path.splitext(img)) > 1
101
- and os.path.splitext(img)[-1].lower() in self.image_formats
102
- ):
103
- if not os.path.exists(img):
104
- raise FileNotFoundError(f"Image file not found: {img}")
105
- with open(img, "rb") as f:
106
- img = base64.b64encode(f.read()).decode("utf-8")
107
-
108
- content_user.append(
109
- {
110
- "type": "image_url",
111
- "image_url": {"url": f"data:image/png;base64,{img}"},
112
- }
113
- )
114
-
115
- payload = {
116
- "messages": [
117
- {"role": "system", "content": system_role},
118
- {"role": "user", "content": content_user},
119
- ],
120
- "temperature": 0.1,
121
- "max_tokens": 500,
122
- "top_p": 0.1,
123
- "frequency_penalty": 0,
124
- "presence_penalty": 0,
125
- "stop": None,
126
- }
127
- payload.update({"model": self.model_name})
128
-
129
- response = None
130
- try:
131
- response = self.completion_with_backoff(**payload)
132
- response = response.choices[0].message.content
133
- except Exception as e:
134
- logger.error(f"Error in GPTclient {self.endpoint} API call: {e}")
135
- response = None
136
-
137
- if self.verbose:
138
- logger.info(f"Prompt: {text_prompt}")
139
- logger.info(f"Response: {response}")
140
-
141
- return response
142
-
143
-
144
- endpoint = os.environ.get("endpoint", None)
145
- api_key = os.environ.get("api_key", None)
146
- api_version = os.environ.get("api_version", None)
147
- if endpoint and api_key and api_version:
148
- GPT_CLIENT = GPTclient(
149
- endpoint=endpoint,
150
- api_key=api_key,
151
- api_version=api_version,
152
- model_name="yfb-gpt-4o-sweden" if "sweden" in endpoint else None,
153
- )
154
- else:
155
- GPT_CLIENT = GPTclient(
156
- endpoint="https://openrouter.ai/api/v1",
157
- api_key="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", # noqa
158
- model_name="qwen/qwen2.5-vl-72b-instruct:free",
159
- )
160
-
161
-
162
- if __name__ == "__main__":
163
- if "openrouter" in GPT_CLIENT.endpoint:
164
- response = GPT_CLIENT.query(
165
- text_prompt="What is the content in each image?",
166
- image_base64=combine_images_to_base64(
167
- [
168
- "outputs/text2image/demo_objects/bed/sample_0.jpg",
169
- "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa
170
- "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
171
- ]
172
- ), # input raw image_path if only one image
173
- )
174
- print(response)
175
- else:
176
- response = GPT_CLIENT.query(
177
- text_prompt="What is the content in the images?",
178
- image_base64=[
179
- Image.open("outputs/text2image/demo_objects/bed/sample_0.jpg"),
180
- Image.open(
181
- "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png" # noqa
182
- ),
183
- ],
184
- )
185
- print(response)
186
-
187
- # test2: text prompt
188
- response = GPT_CLIENT.query(
189
- text_prompt="What is the capital of China?"
190
- )
191
- print(response)
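The removed completion_with_backoff wraps every chat call in tenacity's exponential backoff. Below is a minimal, self-contained sketch of that retry policy, runnable without any API credentials; flaky_call is a hypothetical stand-in for the real chat-completion request.

import random

from tenacity import (
    retry,
    stop_after_attempt,
    stop_after_delay,
    wait_random_exponential,
)


@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=(stop_after_attempt(10) | stop_after_delay(30)),
)
def flaky_call() -> str:
    # Hypothetical stand-in for the chat completion request: fails most of
    # the time so the exponential backoff actually kicks in.
    if random.random() < 0.7:
        raise RuntimeError("transient failure")
    return "ok"


if __name__ == "__main__":
    print(flaky_call())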
asset3d_gen/utils/process_media.py DELETED
@@ -1,194 +0,0 @@
1
- import base64
2
- import logging
3
- import math
4
- import os
5
- import subprocess
6
- from glob import glob
7
- from io import BytesIO
8
- from typing import Union
9
-
10
- import cv2
11
- import imageio
12
- import numpy as np
13
- import PIL.Image as Image
14
- from moviepy.editor import VideoFileClip, clips_array
15
-
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
- __all__ = [
21
- "render_asset3d",
22
- "merge_images_video",
23
- "filter_small_connected_components",
24
- "filter_image_small_connected_components",
25
- "combine_images_to_base64",
26
- ]
27
-
28
-
29
- def render_asset3d(
30
- mesh_path: str,
31
- output_root: str,
32
- distance: float = 5.0,
33
- num_images: int = 1,
34
- elevation: list[float] = (0.0,),
35
- pbr_light_factor: float = 1.5,
36
- return_key: str = "image_color/*",
37
- output_subdir: str = "renders",
38
- gen_color_mp4: bool = False,
39
- gen_viewnormal_mp4: bool = False,
40
- gen_glonormal_mp4: bool = False,
41
- device: str = "cpu",
42
- ) -> list[str]:
43
- command = [
44
- "python3",
45
- "asset3d_gen/data/differentiable_render.py",
46
- "--mesh_path",
47
- mesh_path,
48
- "--output_root",
49
- output_root,
50
- "--uuid",
51
- output_subdir,
52
- "--distance",
53
- str(distance),
54
- "--num_images",
55
- str(num_images),
56
- "--elevation",
57
- *map(str, elevation),
58
- "--pbr_light_factor",
59
- str(pbr_light_factor),
60
- "--with_mtl",
61
- "--device",
62
- device,
63
- ]
64
- if gen_color_mp4:
65
- command.append("--gen_color_mp4")
66
- if gen_viewnormal_mp4:
67
- command.append("--gen_viewnormal_mp4")
68
- if gen_glonormal_mp4:
69
- command.append("--gen_glonormal_mp4")
70
- try:
71
- subprocess.run(command, check=True)
72
- except subprocess.CalledProcessError as e:
73
- logger.error(f"Error occurred during rendering: {e}.")
74
-
75
- dst_paths = glob(os.path.join(output_root, output_subdir, return_key))
76
-
77
- return dst_paths
78
-
79
-
80
- def merge_images_video(color_images, normal_images, output_path) -> None:
81
- width = color_images[0].shape[1]
82
- combined_video = [
83
- np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
84
- for rgb_img, normal_img in zip(color_images, normal_images)
85
- ]
86
- imageio.mimsave(output_path, combined_video, fps=50)
87
-
88
- return
89
-
90
-
91
- def merge_video_video(
92
- video_path1: str, video_path2: str, output_path: str
93
- ) -> None:
94
- """Merge two videos by the left half and the right half of the videos."""
95
- clip1 = VideoFileClip(video_path1)
96
- clip2 = VideoFileClip(video_path2)
97
-
98
- if clip1.size != clip2.size:
99
- raise ValueError("The resolutions of the two videos do not match.")
100
-
101
- width, height = clip1.size
102
- clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
103
- clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
104
- final_clip = clips_array([[clip1_half, clip2_half]])
105
- final_clip.write_videofile(output_path, codec="libx264")
106
-
107
-
108
- def filter_small_connected_components(
109
- mask: Union[Image.Image, np.ndarray],
110
- area_ratio: float,
111
- connectivity: int = 8,
112
- ) -> np.ndarray:
113
- if isinstance(mask, Image.Image):
114
- mask = np.array(mask)
115
- num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
116
- mask,
117
- connectivity=connectivity,
118
- )
119
-
120
- small_components = np.zeros_like(mask, dtype=np.uint8)
121
- mask_area = (mask != 0).sum()
122
- min_area = mask_area // area_ratio
123
- for label in range(1, num_labels):
124
- area = stats[label, cv2.CC_STAT_AREA]
125
- if area < min_area:
126
- small_components[labels == label] = 255
127
-
128
- mask = cv2.bitwise_and(mask, cv2.bitwise_not(small_components))
129
-
130
- return mask
131
-
132
-
133
- def filter_image_small_connected_components(
134
- image: Union[Image.Image, np.ndarray],
135
- area_ratio: float = 10,
136
- connectivity: int = 8,
137
- ) -> np.ndarray:
138
- if isinstance(image, Image.Image):
139
- image = image.convert("RGBA")
140
- image = np.array(image)
141
-
142
- mask = image[..., 3]
143
- mask = filter_small_connected_components(mask, area_ratio, connectivity)
144
- image[..., 3] = mask
145
-
146
- return image
147
-
148
-
149
- def combine_images_to_base64(
150
- images: list[str | Image.Image],
151
- cat_row_col: tuple[int, int] = None,
152
- target_wh: tuple[int, int] = (512, 512),
153
- ) -> str:
154
- n_images = len(images)
155
- if cat_row_col is None:
156
- n_col = math.ceil(math.sqrt(n_images))
157
- n_row = math.ceil(n_images / n_col)
158
- else:
159
- n_row, n_col = cat_row_col
160
-
161
- images = [
162
- Image.open(p).convert("RGB") if isinstance(p, str) else p
163
- for p in images[: n_row * n_col]
164
- ]
165
- images = [img.resize(target_wh) for img in images]
166
-
167
- grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
168
- grid = Image.new("RGB", (grid_w, grid_h), (255, 255, 255))
169
-
170
- for idx, img in enumerate(images):
171
- row, col = divmod(idx, n_col)
172
- grid.paste(img, (col * target_wh[0], row * target_wh[1]))
173
-
174
- buffer = BytesIO()
175
- grid.save(buffer, format="PNG")
176
-
177
- return base64.b64encode(buffer.getvalue()).decode("utf-8")
178
-
179
-
180
- if __name__ == "__main__":
181
- # Example usage:
182
- merge_video_video(
183
- "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh_glo_normal.mp4", # noqa
184
- "outputs/imageto3d/room_bottle7/room_bottle_007/URDF_room_bottle_007/mesh.mp4", # noqa
185
- "merge.mp4",
186
- )
187
-
188
- image_base64 = combine_images_to_base64(
189
- [
190
- "outputs/text2image/demo_objects/bed/sample_0.jpg",
191
- "outputs/imageto3d/v2/cups/sample_69/URDF_sample_69/qa_renders/image_color/003.png", # noqa
192
- "outputs/text2image/demo_objects/cardboard/sample_1.jpg",
193
- ]
194
- )
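combine_images_to_base64 packs n images into a ceil(sqrt(n))-column grid and returns the grid as a base64-encoded PNG. A minimal round-trip sketch, assuming the module is still importable under the same path; the solid-colour tiles are dummies.

import base64
from io import BytesIO

from PIL import Image

from asset3d_gen.utils.process_media import combine_images_to_base64

# Three dummy tiles; with n=3 the grid packs into 2 columns x 2 rows.
dummies = [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue")]
b64 = combine_images_to_base64(dummies, target_wh=(128, 128))

grid = Image.open(BytesIO(base64.b64decode(b64)))
print(grid.size)  # (256, 256): two 128-px tiles per side in each direction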
asset3d_gen/utils/tags.py DELETED
@@ -1 +0,0 @@
1
- VERSION = "v0.0.2"
asset3d_gen/validators/aesthetic_predictor.py DELETED
@@ -1,134 +0,0 @@
1
- import os
2
-
3
- import clip
4
- import numpy as np
5
- import pytorch_lightning as pl
6
- import torch
7
- import torch.nn as nn
8
- from huggingface_hub import snapshot_download
9
- from PIL import Image
10
-
11
-
12
- class AestheticPredictor:
13
- """Aesthetic Score Predictor.
14
-
15
- Args:
16
- clip_model_dir (str): Path to the directory of the CLIP model.
17
- sac_model_path (str): Path to the pre-trained SAC model.
18
- device (str): Device to use for computation ("cuda" or "cpu").
19
- """
20
-
21
- def __init__(self, clip_model_dir=None, sac_model_path=None, device=None):
22
-
23
- self.device = device or (
24
- "cuda" if torch.cuda.is_available() else "cpu"
25
- )
26
-
27
- if clip_model_dir is None:
28
- model_path = snapshot_download(
29
- repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*"
30
- )
31
- suffix = "aesthetic"
32
- model_path = snapshot_download(
33
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
34
- )
35
- clip_model_dir = os.path.join(model_path, suffix)
36
-
37
- if sac_model_path is None:
38
- model_path = snapshot_download(
39
- repo_id="xinjjj/RoboAssetGen", allow_patterns="aesthetic/*"
40
- )
41
- suffix = "aesthetic"
42
- model_path = snapshot_download(
43
- repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
44
- )
45
- sac_model_path = os.path.join(
46
- model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth"
47
- )
48
-
49
- self.clip_model, self.preprocess = self._load_clip_model(
50
- clip_model_dir
51
- )
52
- self.sac_model = self._load_sac_model(sac_model_path, input_size=768)
53
-
54
- class MLP(pl.LightningModule): # noqa
55
- def __init__(self, input_size):
56
- super().__init__()
57
- self.layers = nn.Sequential(
58
- nn.Linear(input_size, 1024),
59
- nn.Dropout(0.2),
60
- nn.Linear(1024, 128),
61
- nn.Dropout(0.2),
62
- nn.Linear(128, 64),
63
- nn.Dropout(0.1),
64
- nn.Linear(64, 16),
65
- nn.Linear(16, 1),
66
- )
67
-
68
- def forward(self, x):
69
- return self.layers(x)
70
-
71
- @staticmethod
72
- def normalized(a, axis=-1, order=2):
73
- """Normalize the array to unit norm."""
74
- l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
75
- l2[l2 == 0] = 1
76
- return a / np.expand_dims(l2, axis)
77
-
78
- def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
79
- """Load the CLIP model."""
80
- model, preprocess = clip.load(
81
- model_name, download_root=model_dir, device=self.device
82
- )
83
- return model, preprocess
84
-
85
- def _load_sac_model(self, model_path, input_size):
86
- """Load the SAC model."""
87
- model = self.MLP(input_size)
88
- ckpt = torch.load(model_path)
89
- model.load_state_dict(ckpt)
90
- model.to(self.device)
91
- model.eval()
92
- return model
93
-
94
- def predict(self, image_path):
95
- """Predict the aesthetic score for a given image.
96
-
97
- Args:
98
- image_path (str): Path to the image file.
99
-
100
- Returns:
101
- float: Predicted aesthetic score.
102
- """
103
- pil_image = Image.open(image_path)
104
- image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
105
-
106
- with torch.no_grad():
107
- # Extract CLIP features
108
- image_features = self.clip_model.encode_image(image)
109
- # Normalize features
110
- normalized_features = self.normalized(
111
- image_features.cpu().detach().numpy()
112
- )
113
- # Predict score
114
- prediction = self.sac_model(
115
- torch.from_numpy(normalized_features)
116
- .type(torch.FloatTensor)
117
- .to(self.device)
118
- )
119
-
120
- return prediction.item()
121
-
122
-
123
- if __name__ == "__main__":
124
- # Configuration
125
- img_path = "/home/users/xinjie.wang/xinjie/asset3d-gen/outputs/imageto3d/demo_objects/bed/sample_0/sample_0_raw.png" # noqa
126
- # clip_model_dir = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/clip" # noqa
127
- # sac_model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/sac/sac+logos+ava1-l14-linearMSE.pth" # noqa
128
-
129
- # Initialize the predictor
130
- predictor = AestheticPredictor()
131
-
132
- # Predict the aesthetic score
133
- score = predictor.predict(img_path)
134
- print("Aesthetic score predicted by the model:", score)
 
asset3d_gen/validators/quality_checkers.py DELETED
@@ -1,195 +0,0 @@
1
- import logging
2
- import os
3
-
4
- from tqdm import tqdm
5
- from asset3d_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
6
- from asset3d_gen.utils.process_media import render_asset3d
7
- from asset3d_gen.validators.aesthetic_predictor import AestheticPredictor
8
-
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class BaseChecker:
14
- def __init__(self, prompt: str = None, verbose: bool = False) -> None:
15
- self.prompt = prompt
16
- self.verbose = verbose
17
-
18
- def query(self, *args, **kwargs):
19
- raise NotImplementedError(
20
- "Subclasses must implement the query method."
21
- )
22
-
23
- def __call__(self, *args, **kwargs) -> bool:
24
- response = self.query(*args, **kwargs)
25
- if response is None:
26
- response = "Error when calling gpt api."
27
-
28
- if self.verbose and response != "YES":
29
- logger.info(response)
30
-
31
- flag = "YES" in response
32
- response = "YES" if flag else response
33
-
34
- return flag, response
35
-
36
- @staticmethod
37
- def validate(
38
- checkers: list["BaseChecker"], images_list: list[list[str]]
39
- ) -> list:
40
- assert len(checkers) == len(images_list)
41
- results = []
42
- overall_result = True
43
- for checker, images in zip(checkers, images_list):
44
- qa_flag, qa_info = checker(images)
45
- if isinstance(qa_info, str):
46
- qa_info = qa_info.replace("\n", ".")
47
- results.append([checker.__class__.__name__, qa_info])
48
- if qa_flag is False:
49
- overall_result = False
50
-
51
- results.append(["overall", "YES" if overall_result else "NO"])
52
-
53
- return results
54
-
55
-
56
- class MeshGeoChecker(BaseChecker):
57
- def __init__(
58
- self,
59
- gpt_client: GPTclient,
60
- prompt: str = None,
61
- verbose: bool = False,
62
- ) -> None:
63
- super().__init__(prompt, verbose)
64
- self.gpt_client = gpt_client
65
- if self.prompt is None:
66
- self.prompt = """
67
- Refer to the provided multi-view rendering images to evaluate
68
- whether the geometry of the 3D object asset is complete and
69
- whether the asset can be placed stably on the ground.
70
- Return "YES" only if reach the requirments,
71
- otherwise "NO" and explain the reason very briefly.
72
- """
73
-
74
- def query(self, image_paths: str) -> str:
75
- # Temporary workaround: the openrouter endpoint cannot take multiple images.
76
- if "openrouter" in self.gpt_client.endpoint:
77
- from asset3d_gen.utils.process_media import (
78
- combine_images_to_base64,
79
- )
80
-
81
- image_paths = combine_images_to_base64(image_paths)
82
-
83
- return self.gpt_client.query(
84
- text_prompt=self.prompt,
85
- image_base64=image_paths,
86
- )
87
-
88
-
89
- class ImageSegChecker(BaseChecker):
90
- def __init__(
91
- self,
92
- gpt_client: GPTclient,
93
- prompt: str = None,
94
- verbose: bool = False,
95
- ) -> None:
96
- super().__init__(prompt, verbose)
97
- self.gpt_client = gpt_client
98
- if self.prompt is None:
99
- self.prompt = """
100
- The first image is the original, and the second image is the
101
- result after segmenting the main object. Evaluate the segmentation
102
- quality to ensure the main object is clearly segmented without
103
- significant truncation. Note that the foreground of the object
104
- needs to be extracted instead of the background.
105
- Minor imperfections can be ignored. If segmentation is acceptable,
106
- return "YES" only; otherwise, return "NO" with
107
- very brief explanation.
108
- """
109
-
110
- def query(self, image_paths: list[str]) -> str:
111
- if len(image_paths) != 2:
112
- raise ValueError(
113
- "ImageSegChecker requires exactly two images: [raw_image, seg_image]." # noqa
114
- )
115
- # Temporary workaround: the openrouter endpoint cannot take multiple images.
116
- if "openrouter" in self.gpt_client.endpoint:
117
- from asset3d_gen.utils.process_media import (
118
- combine_images_to_base64,
119
- )
120
-
121
- image_paths = combine_images_to_base64(image_paths)
122
-
123
- return self.gpt_client.query(
124
- text_prompt=self.prompt,
125
- image_base64=image_paths,
126
- )
127
-
128
-
129
- class ImageAestheticChecker(BaseChecker):
130
- def __init__(
131
- self,
132
- clip_model_dir: str = None,
133
- sac_model_path: str = None,
134
- thresh: float = 4.50,
135
- verbose: bool = False,
136
- ) -> None:
137
- super().__init__(verbose=verbose)
138
- self.clip_model_dir = clip_model_dir
139
- self.sac_model_path = sac_model_path
140
- self.thresh = thresh
141
- self.predictor = AestheticPredictor(clip_model_dir, sac_model_path)
142
-
143
- def query(self, image_paths: list[str]) -> float:
144
- scores = [self.predictor.predict(img_path) for img_path in image_paths]
145
- return sum(scores) / len(scores)
146
-
147
- def __call__(self, image_paths: list[str], **kwargs) -> bool:
148
- avg_score = self.query(image_paths)
149
- if self.verbose:
150
- logger.info(f"Average aesthetic score: {avg_score}")
151
- return avg_score > self.thresh, avg_score
152
-
153
-
154
- if __name__ == "__main__":
155
- geo_checker = MeshGeoChecker(GPT_CLIENT)
156
- seg_checker = ImageSegChecker(GPT_CLIENT)
157
- aesthetic_checker = ImageAestheticChecker(
158
- "/horizon-bucket/robot_lab/users/xinjie.wang/weights/clip",
159
- "/horizon-bucket/robot_lab/users/xinjie.wang/weights/sac/sac+logos+ava1-l14-linearMSE.pth", # noqa
160
- )
161
-
162
- checkers = [geo_checker, seg_checker, aesthetic_checker]
163
-
164
- output_root = "outputs/test_gpt"
165
-
166
- fails = []
167
- for idx in tqdm(range(150)):
168
- mesh_path = f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}.obj" # noqa
169
- if not os.path.exists(mesh_path):
170
- continue
171
- image_paths = render_asset3d(
172
- mesh_path,
173
- f"{output_root}/{idx}",
174
- num_images=8,
175
- elevation=(30, -30),
176
- distance=5.5,
177
- )
178
-
179
- for cid, checker in enumerate(checkers):
180
- if isinstance(checker, ImageSegChecker):
181
- images = [
182
- f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_raw.png", # noqa
183
- f"outputs/imageto3d/demo_objects/cups/sample_{idx}/sample_{idx}_cond.png", # noqa
184
- ]
185
- else:
186
- images = image_paths
187
- result, info = checker(images)
188
- logger.info(
189
- f"Checker {checker.__class__.__name__}: {result}, {info}, mesh {mesh_path}" # noqa
190
- )
191
-
192
- if result is False:
193
- fails.append((idx, cid, info))
194
-
195
- break
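A sketch of how BaseChecker.validate aggregates per-checker verdicts into (checker name, verdict) rows of the kind URDFGenerator.add_quality_tag consumes, using hypothetical stub checkers so no GPT call is needed (assumes the module is still importable):

from asset3d_gen.validators.quality_checkers import BaseChecker


class AlwaysPass(BaseChecker):
    def query(self, images):
        return "YES"


class AlwaysFail(BaseChecker):
    def query(self, images):
        return "NO: geometry looks incomplete"


# One image list per checker; the paths are hypothetical placeholders.
results = BaseChecker.validate(
    [AlwaysPass(), AlwaysFail()],
    [["render_0.png"], ["render_1.png"]],
)
print(results)
# [['AlwaysPass', 'YES'], ['AlwaysFail', 'NO: geometry looks incomplete'], ['overall', 'NO']]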
asset3d_gen/validators/urdf_convertor.py DELETED
@@ -1,423 +0,0 @@
1
- import logging
2
- import os
3
- import shutil
4
- import xml.etree.ElementTree as ET
5
- import zipfile
6
- from datetime import datetime
7
- from xml.dom.minidom import parseString
8
-
9
- import numpy as np
10
- import trimesh
11
- from asset3d_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
12
- from asset3d_gen.utils.process_media import render_asset3d
13
- from asset3d_gen.utils.tags import VERSION
14
-
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- __all__ = ["URDFGenerator"]
20
-
21
-
22
- URDF_TEMPLATE = """
23
- <robot name="template_robot">
24
- <link name="template_link">
25
- <visual>
26
- <geometry>
27
- <mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
28
- </geometry>
29
- </visual>
30
- <collision>
31
- <geometry>
32
- <mesh filename="mesh.obj" scale="1.0 1.0 1.0"/>
33
- </geometry>
34
- <gazebo>
35
- <mu1>0.8</mu1> <!-- primary friction coefficient -->
37
- <mu2>0.6</mu2> <!-- secondary friction coefficient -->
37
- </gazebo>
38
- </collision>
39
- <inertial>
40
- <mass value="1.0"/>
41
- <origin xyz="0 0 0"/>
42
- <inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
43
- </inertial>
44
- <extra_info>
45
- <scale>1.0</scale>
46
- <version>"0.0.0"</version>
47
- <category>"unknown"</category>
48
- <description>"unknown"</description>
49
- <min_height>0.0</min_height>
50
- <max_height>0.0</max_height>
51
- <real_height>0.0</real_height>
52
- <min_mass>0.0</min_mass>
53
- <max_mass>0.0</max_mass>
54
- <generate_time>"-1"</generate_time>
55
- <gs_model>""</gs_model>
56
- </extra_info>
57
- </link>
58
- </robot>
59
- """
60
-
61
-
62
- def zip_files(input_paths: list[str], output_zip: str) -> str:
63
- with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
64
- for input_path in input_paths:
65
- if not os.path.exists(input_path):
66
- raise FileNotFoundError(f"File not found: {input_path}")
67
-
68
- if os.path.isdir(input_path):
69
- for root, _, files in os.walk(input_path):
70
- for file in files:
71
- file_path = os.path.join(root, file)
72
- arcname = os.path.relpath(
73
- file_path, start=os.path.commonpath(input_paths)
74
- )
75
- zipf.write(file_path, arcname=arcname)
76
- else:
77
- arcname = os.path.relpath(
78
- input_path, start=os.path.commonpath(input_paths)
79
- )
80
- zipf.write(input_path, arcname=arcname)
81
-
82
- return output_zip
83
-
84
-
85
- class URDFGenerator(object):
86
- def __init__(
87
- self,
88
- gpt_client: GPTclient,
89
- mesh_file_list: list[str] = ["material_0.png", "material.mtl"],
90
- prompt_template: str = None,
91
- attrs_name: list[str] = None,
92
- render_dir: str = "urdf_renders",
93
- render_view_num: int = 4,
94
- ) -> None:
95
- if mesh_file_list is None:
96
- mesh_file_list = []
97
- self.mesh_file_list = mesh_file_list
98
- self.output_mesh_dir = "mesh"
99
- self.output_render_dir = render_dir
100
- self.gpt_client = gpt_client
101
- self.render_view_num = render_view_num
102
- if render_view_num == 4:
103
- view_desc = "This is orthographic projection showing the front, left, right and back views " # noqa
104
- else:
105
- view_desc = "This is the rendered views "
106
-
107
- if prompt_template is None:
108
- prompt_template = (
109
- view_desc
110
- + """of the 3D object asset,
111
- category: {category}.
112
- Give the category of this object asset (within 3 words),
113
- (if category is already provided, use it directly),
114
- accurately describe this 3D object asset (within 15 words),
115
- and give the recommended geometric height range (unit: meter),
116
- weight range (unit: kilogram), the average static friction
117
- coefficient of the object relative to rubber and the average
118
- dynamic friction coefficient of the object relative to rubber.
119
- Return response format as shown in Example.
120
-
121
- Example:
122
- Category: cup
123
- Description: shiny golden cup with floral design
124
- Height: 0.1-0.15 m
125
- Weight: 0.3-0.6 kg
126
- Static friction coefficient: 1.1
127
- Dynamic friction coefficient: 0.9
128
- """
129
- )
130
-
131
- self.prompt_template = prompt_template
132
- if attrs_name is None:
133
- attrs_name = [
134
- "category",
135
- "description",
136
- "min_height",
137
- "max_height",
138
- "real_height",
139
- "min_mass",
140
- "max_mass",
141
- "version",
142
- "generate_time",
143
- "gs_model",
144
- ]
145
- self.attrs_name = attrs_name
146
-
147
- def parse_response(self, response: str) -> dict[str, any]:
148
- lines = response.split("\n")
149
- lines = [line.strip() for line in lines if line]
150
- category = lines[0].split(": ")[1]
151
- description = lines[1].split(": ")[1]
152
- min_height, max_height = map(
153
- lambda x: float(x.strip().replace(",", "").split()[0]),
154
- lines[2].split(": ")[1].split("-"),
155
- )
156
- min_mass, max_mass = map(
157
- lambda x: float(x.strip().replace(",", "").split()[0]),
158
- lines[3].split(": ")[1].split("-"),
159
- )
160
- mu1 = float(lines[4].split(": ")[1].replace(",", ""))
161
- mu2 = float(lines[5].split(": ")[1].replace(",", ""))
162
-
163
- return {
164
- "category": category.lower(),
165
- "description": description.lower(),
166
- "min_height": round(min_height, 4),
167
- "max_height": round(max_height, 4),
168
- "real_height": round((min_height + max_height) / 2, 4),
169
- "min_mass": round(min_mass, 4),
170
- "max_mass": round(max_mass, 4),
171
- "mu1": round(mu1, 2),
172
- "mu2": round(mu2, 2),
173
- "version": VERSION,
174
- "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
175
- }
176
-
177
- def generate_urdf(
178
- self,
179
- input_mesh: str,
180
- output_dir: str,
181
- attr_dict: dict,
182
- output_name: str = None,
183
- ) -> str:
184
- """Generate a URDF file for a given mesh with specified attributes.
185
-
186
- Args:
187
- input_mesh (str): Path to the input mesh file.
188
- output_dir (str): Directory to store the generated URDF
189
- and processed mesh.
190
- attr_dict (dict): Dictionary containing attributes like height,
191
- mass, and friction coefficients.
192
- output_name (str, optional): Name for the generated URDF and robot.
193
-
194
- Returns:
195
- str: Path to the generated URDF file.
196
- """
197
-
198
- # 1. Load and normalize the mesh
199
- mesh = trimesh.load(input_mesh)
200
- mesh_scale = np.ptp(mesh.vertices, axis=0).max()
201
- mesh.vertices /= mesh_scale # Normalize to [-0.5, 0.5]
202
- raw_height = np.ptp(mesh.vertices, axis=0)[1]
203
-
204
- # 2. Scale the mesh to real height
205
- real_height = attr_dict["real_height"]
206
- scale = round(real_height / raw_height, 6)
207
- mesh = mesh.apply_scale(scale)
208
-
209
- # 3. Prepare output directories and save scaled mesh
210
- mesh_folder = os.path.join(output_dir, self.output_mesh_dir)
211
- os.makedirs(mesh_folder, exist_ok=True)
212
-
213
- obj_name = os.path.basename(input_mesh)
214
- mesh_output_path = os.path.join(mesh_folder, obj_name)
215
- mesh.export(mesh_output_path)
216
-
217
- # 4. Copy additional mesh files, if any
218
- input_dir = os.path.dirname(input_mesh)
219
- for file in self.mesh_file_list:
220
- src_file = os.path.join(input_dir, file)
221
- dest_file = os.path.join(mesh_folder, file)
222
- if os.path.isfile(src_file):
223
- shutil.copy(src_file, dest_file)
224
-
225
- # 5. Determine output name
226
- if output_name is None:
227
- output_name = os.path.splitext(obj_name)[0]
228
-
229
- # 6. Load URDF template and update attributes
230
- robot = ET.fromstring(URDF_TEMPLATE)
231
- robot.set("name", output_name)
232
-
233
- link = robot.find("link")
234
- if link is None:
235
- raise ValueError("URDF template is missing 'link' element.")
236
- link.set("name", output_name)
237
-
238
- # Update visual geometry
239
- visual = link.find("visual/geometry/mesh")
240
- if visual is not None:
241
- visual.set(
242
- "filename", os.path.join(self.output_mesh_dir, obj_name)
243
- )
244
- visual.set("scale", "1.0 1.0 1.0")
245
-
246
- # Update collision geometry
247
- collision = link.find("collision/geometry/mesh")
248
- if collision is not None:
249
- collision.set(
250
- "filename", os.path.join(self.output_mesh_dir, obj_name)
251
- )
252
- collision.set("scale", "1.0 1.0 1.0")
253
-
254
- # Update friction coefficients
255
- gazebo = link.find("collision/gazebo")
256
- if gazebo is not None:
257
- for key in ("mu1", "mu2"):
258
- element = gazebo.find(key)
259
- if element is not None:
260
- element.text = f"{attr_dict[key]:.2f}"
261
-
262
- # Update mass
263
- inertial = link.find("inertial/mass")
264
- if inertial is not None:
265
- mass_value = (attr_dict["min_mass"] + attr_dict["max_mass"]) / 2
266
- inertial.set("value", f"{mass_value:.4f}")
267
-
268
- # Add extra_info element to the link
269
- extra_info = link.find("extra_info/scale")
270
- if extra_info is not None:
271
- extra_info.text = f"{scale:.6f}"
272
-
273
- for key in self.attrs_name:
274
- extra_info = link.find(f"extra_info/{key}")
275
- if extra_info is not None and key in attr_dict:
276
- extra_info.text = f"{attr_dict[key]}"
277
-
278
- # 7. Write URDF to file
279
- os.makedirs(output_dir, exist_ok=True)
280
- urdf_path = os.path.join(output_dir, f"{output_name}.urdf")
281
- tree = ET.ElementTree(robot)
282
- tree.write(urdf_path, encoding="utf-8", xml_declaration=True)
283
-
284
- logger.info(f"URDF file saved to {urdf_path}")
285
-
286
- return urdf_path
287
-
288
- @staticmethod
289
- def get_attr_from_urdf(
290
- urdf_path: str,
291
- attr_root: str = ".//link/extra_info",
292
- attr_name: str = "scale",
293
- ) -> float:
294
- if not os.path.exists(urdf_path):
295
- raise FileNotFoundError(f"URDF file not found: {urdf_path}")
296
-
297
- mesh_scale = 1.0
298
- tree = ET.parse(urdf_path)
299
- root = tree.getroot()
300
- extra_info = root.find(attr_root)
301
- if extra_info is not None:
302
- scale_element = extra_info.find(attr_name)
303
- if scale_element is not None:
304
- mesh_scale = float(scale_element.text)
305
-
306
- return mesh_scale
307
-
308
- @staticmethod
309
- def add_quality_tag(
310
- urdf_path: str, results, output_path: str = None
311
- ) -> None:
312
- if output_path is None:
313
- output_path = urdf_path
314
-
315
- tree = ET.parse(urdf_path)
316
- root = tree.getroot()
317
- custom_data = ET.SubElement(root, "custom_data")
318
- quality = ET.SubElement(custom_data, "quality")
319
- for key, value in results:
320
- checker_tag = ET.SubElement(quality, key)
321
- checker_tag.text = str(value)
322
-
323
- rough_string = ET.tostring(root, encoding="utf-8")
324
- formatted_string = parseString(rough_string).toprettyxml(indent=" ")
325
- cleaned_string = "\n".join(
326
- [line for line in formatted_string.splitlines() if line.strip()]
327
- )
328
-
329
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
330
- with open(output_path, "w", encoding="utf-8") as f:
331
- f.write(cleaned_string)
332
-
333
- logger.info(f"URDF files saved to {output_path}")
334
-
335
- def get_estimated_attributes(self, asset_attrs: dict):
336
- estimated_attrs = {
337
- "height": round(
338
- (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4
339
- ),
340
- "mass": round(
341
- (asset_attrs["min_mass"] + asset_attrs["max_mass"]) / 2, 4
342
- ),
343
- "mu": round((asset_attrs["mu1"] + asset_attrs["mu2"]) / 2, 4),
344
- "category": asset_attrs["category"],
345
- }
346
-
347
- return estimated_attrs
348
-
349
- def __call__(
350
- self,
351
- mesh_path: str,
352
- output_root: str,
353
- text_prompt: str = None,
354
- category: str = "unknown",
355
- **kwargs,
356
- ):
357
- if text_prompt is None or len(text_prompt) == 0:
358
- text_prompt = self.prompt_template
359
- text_prompt = text_prompt.format(category=category.lower())
360
-
361
- image_path = render_asset3d(
362
- mesh_path,
363
- output_root,
364
- num_images=self.render_view_num,
365
- output_subdir=self.output_render_dir,
366
- )
367
-
368
- # Temporary workaround: the openrouter endpoint cannot take multiple images.
369
- if "openrouter" in self.gpt_client.endpoint:
370
- from asset3d_gen.utils.process_media import (
371
- combine_images_to_base64,
372
- )
373
-
374
- image_path = combine_images_to_base64(image_path)
375
-
376
- response = self.gpt_client.query(text_prompt, image_path)
377
- if response is None:
378
- asset_attrs = {
379
- "category": "unknown",
380
- "description": "unknown",
381
- "min_height": 1,
382
- "max_height": 1,
383
- "real_height": 1,
384
- "min_mass": 1,
385
- "max_mass": 1,
386
- "mu1": 0.8,
387
- "mu2": 0.6,
388
- "version": VERSION,
389
- "generate_time": datetime.now().strftime("%Y%m%d%H%M%S"),
390
- }
391
- else:
392
- asset_attrs = self.parse_response(response)
393
- for key in self.attrs_name:
394
- if key in kwargs:
395
- asset_attrs[key] = kwargs[key]
396
-
397
- self.estimated_attrs = self.get_estimated_attributes(asset_attrs)
398
-
399
- urdf_path = self.generate_urdf(mesh_path, output_root, asset_attrs)
400
-
401
- logger.info(f"response: {response}")
402
-
403
- return urdf_path
404
-
405
-
406
- if __name__ == "__main__":
407
- urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)
408
- urdf_path = urdf_gen(
409
- mesh_path="scripts/apps/assets/example_texture/meshes/robot.obj",
410
- output_root="outputs/test_urdf",
411
- # category="coffee machine",
412
- # min_height=1.0,
413
- # max_height=1.2,
414
- version=VERSION,
415
- )
416
-
417
- # zip_files(
418
- # input_paths=[
419
- # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/mesh",
420
- # "scripts/apps/tmp/2umpdum3e5n/URDF_sample/sample.urdf"
421
- # ],
422
- # output_zip="zip.zip"
423
- # )
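To make the reply format requested by the prompt template concrete, a small sketch that feeds the documented example response through URDFGenerator.parse_response (no GPT call or mesh is needed for this step; assumes the modules are still importable):

from asset3d_gen.utils.gpt_clients import GPT_CLIENT
from asset3d_gen.validators.urdf_convertor import URDFGenerator

urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4)

# A reply in exactly the format requested by the prompt template.
response = """Category: cup
Description: shiny golden cup with floral design
Height: 0.1-0.15 m
Weight: 0.3-0.6 kg
Static friction coefficient: 1.1
Dynamic friction coefficient: 0.9"""

attrs = urdf_gen.parse_response(response)
print(attrs["category"], attrs["real_height"], attrs["min_mass"], attrs["mu1"])
# cup 0.125 0.3 1.1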