davidvgilmore committed
Commit 1b41729 · verified · 1 Parent(s): 2dd594f

Upload hy3dgen/shapegen/pipelines.py with huggingface_hub

Files changed (1)
  1. hy3dgen/shapegen/pipelines.py +589 -0
hy3dgen/shapegen/pipelines.py ADDED
@@ -0,0 +1,589 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import copy
import importlib
import inspect
import logging
import os
from typing import List, Optional, Union

import numpy as np
import torch
import trimesh
import yaml
from PIL import Image
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm

logger = logging.getLogger(__name__)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def export_to_trimesh(mesh_output):
    if isinstance(mesh_output, list):
        outputs = []
        for mesh in mesh_output:
            if mesh is None:
                outputs.append(None)
            else:
                mesh.mesh_f = mesh.mesh_f[:, ::-1]
                mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
                outputs.append(mesh_output)
        return outputs
    else:
        mesh_output.mesh_f = mesh_output.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh_output.mesh_v, mesh_output.mesh_f)
        return mesh_output


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config, **kwargs):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    params = config.get("params", dict())
    kwargs.update(params)
    instance = cls(**kwargs)
    return instance


class Hunyuan3DDiTPipeline:
    @classmethod
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        # load config
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # load ckpt
        if use_safetensors:
            ckpt_path = ckpt_path.replace('.ckpt', '.safetensors')
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")
        logger.info(f"Loading model from {ckpt_path}")

        if use_safetensors:
            # parse safetensors
            import safetensors.torch
            safetensors_ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
            ckpt = {}
            for key, value in safetensors_ckpt.items():
                model_name = key.split('.')[0]
                new_key = key[len(model_name) + 1:]
                if model_name not in ckpt:
                    ckpt[model_name] = {}
                ckpt[model_name][new_key] = value
        else:
            ckpt = torch.load(ckpt_path, map_location='cpu')
        # load model
        model = instantiate_from_config(config['model'])
        model.load_state_dict(ckpt['model'])
        vae = instantiate_from_config(config['vae'])
        vae.load_state_dict(ckpt['vae'])
        conditioner = instantiate_from_config(config['conditioner'])
        if 'conditioner' in ckpt:
            conditioner.load_state_dict(ckpt['conditioner'])
        image_processor = instantiate_from_config(config['image_processor'])
        scheduler = instantiate_from_config(config['scheduler'])

        model_kwargs = dict(
            vae=vae,
            model=model,
            scheduler=scheduler,
            conditioner=conditioner,
            image_processor=image_processor,
            scheduler_cfg=config['scheduler'],
            device=device,
            dtype=dtype,
        )
        model_kwargs.update(kwargs)

        return cls(
            **model_kwargs
        )

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        ckpt_name='model.ckpt',
        config_name='config.yaml',
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        original_model_path = model_path
        if not os.path.exists(model_path):
            # try local path
            base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
            model_path = os.path.expanduser(os.path.join(base_dir, model_path, 'hunyuan3d-dit-v2-0'))
            if not os.path.exists(model_path):
                try:
                    import huggingface_hub
                    # download from huggingface
                    path = huggingface_hub.snapshot_download(repo_id=original_model_path)
                    model_path = os.path.join(path, 'hunyuan3d-dit-v2-0')
                except ImportError:
                    logger.warning(
                        "You need to install HuggingFace Hub to load models from the hub."
                    )
                    raise RuntimeError(f"Model path {model_path} not found")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path {original_model_path} not found")

        config_path = os.path.join(model_path, config_name)
        ckpt_path = os.path.join(model_path, ckpt_name)
        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs
        )

    def __init__(
        self,
        vae,
        model,
        scheduler,
        conditioner,
        image_processor,
        device='cuda',
        dtype=torch.float16,
        **kwargs
    ):
        self.vae = vae
        self.model = model
        self.scheduler = scheduler
        self.conditioner = conditioner
        self.image_processor = image_processor
        self.kwargs = kwargs

        self.to(device, dtype)

    def to(self, device=None, dtype=None):
        if device is not None:
            self.device = torch.device(device)
            self.vae.to(device)
            self.model.to(device)
            self.conditioner.to(device)
        if dtype is not None:
            self.dtype = dtype
            self.vae.to(dtype=dtype)
            self.model.to(dtype=dtype)
            self.conditioner.to(dtype=dtype)

    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
        bsz = image.shape[0]
        cond = self.conditioner(image=image, mask=mask)

        if do_classifier_free_guidance:
            un_cond = self.conditioner.unconditional_embedding(bsz)

            if dual_guidance:
                un_cond_drop_main = copy.deepcopy(un_cond)
                un_cond_drop_main['additional'] = cond['additional']

                def cat_recursive(a, b, c):
                    if isinstance(a, torch.Tensor):
                        return torch.cat([a, b, c], dim=0).to(self.dtype)
                    out = {}
                    for k in a.keys():
                        out[k] = cat_recursive(a[k], b[k], c[k])
                    return out

                cond = cat_recursive(cond, un_cond_drop_main, un_cond)
            else:
                un_cond = self.conditioner.unconditional_embedding(bsz)

                def cat_recursive(a, b):
                    if isinstance(a, torch.Tensor):
                        return torch.cat([a, b], dim=0).to(self.dtype)
                    out = {}
                    for k in a.keys():
                        out[k] = cat_recursive(a[k], b[k])
                    return out

                cond = cat_recursive(cond, un_cond)
        return cond

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, batch_size, dtype, device, generator, latents=None):
        shape = (batch_size, *self.vae.latent_shape)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
        return latents

    def prepare_image(self, image):
        if isinstance(image, str) and not os.path.exists(image):
            raise FileNotFoundError(f"Couldn't find image at path {image}")

        if not isinstance(image, list):
            image = [image]
        image_pts = []
        mask_pts = []
        for img in image:
            image_pt, mask_pt = self.image_processor(img, return_mask=True)
            image_pts.append(image_pt)
            mask_pts.append(mask_pt)

        image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
        if mask_pts[0] is not None:
            mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
        else:
            mask_pts = None
        return image_pts, mask_pts

    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Guidance scale values used to generate the embedding vectors.
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @torch.no_grad()
    def __call__(
        self,
        image: Union[str, List[str], Image.Image] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        eta: float = 0.0,
        guidance_scale: float = 7.5,
        dual_guidance_scale: float = 10.5,
        dual_guidance: bool = True,
        generator=None,
        box_v=1.01,
        octree_resolution=384,
        mc_level=-1 / 512,
        num_chunks=8000,
        mc_algo='mc',
        output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        device = self.device
        dtype = self.dtype
        do_classifier_free_guidance = guidance_scale >= 0 and \
            getattr(self.model, 'guidance_cond_proj_dim', None) is None
        dual_guidance = dual_guidance_scale >= 0 and dual_guidance

        image, mask = self.prepare_image(image)
        cond = self.encode_cond(image=image,
                                mask=mask,
                                do_classifier_free_guidance=do_classifier_free_guidance,
                                dual_guidance=dual_guidance)
        batch_size = image.shape[0]

        t_dtype = torch.long
        scheduler = instantiate_from_config(self.kwargs['scheduler_cfg'])
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler, num_inference_steps, device, timesteps, sigmas
        )

        latents = self.prepare_latents(batch_size, dtype, device, generator)
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        guidance_cond = None
        if getattr(self.model, 'guidance_cond_proj_dim', None) is not None:
            print('Using lcm guidance scale')
            guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(batch_size)
            guidance_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.model.guidance_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:", leave=False)):
            # expand the latents if we are doing classifier free guidance
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * (3 if dual_guidance else 2))
            else:
                latent_model_input = latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            timestep_tensor = torch.tensor([t], dtype=t_dtype, device=device)
            timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
            noise_pred = self.model(latent_model_input, timestep_tensor, cond, guidance_cond=guidance_cond)

            # no drop, drop clip, all drop
            if do_classifier_free_guidance:
                if dual_guidance:
                    noise_pred_clip, noise_pred_dino, noise_pred_uncond = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_clip - noise_pred_dino)
                        + dual_guidance_scale * (noise_pred_dino - noise_pred_uncond)
                    )
                else:
                    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            outputs = scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
            latents = outputs.prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(scheduler, "order", 1)
                callback(step_idx, t, outputs)

        return self._export(
            latents,
            output_type,
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
        )

    def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo):
        if not output_type == "latent":
            latents = 1. / self.vae.scale_factor * latents
            latents = self.vae(latents)
            outputs = self.vae.latents2mesh(
                latents,
                bounds=box_v,
                mc_level=mc_level,
                num_chunks=num_chunks,
                octree_resolution=octree_resolution,
                mc_algo=mc_algo,
            )
        else:
            outputs = latents

        if output_type == 'trimesh':
            outputs = export_to_trimesh(outputs)

        return outputs


class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):

    @torch.no_grad()
    def __call__(
        self,
        image: Union[str, List[str], Image.Image] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        eta: float = 0.0,
        guidance_scale: float = 7.5,
        generator=None,
        box_v=1.01,
        octree_resolution=384,
        mc_level=0.0,
        mc_algo='mc',
        num_chunks=8000,
        output_type: Optional[str] = "trimesh",
        enable_pbar=True,
        **kwargs,
    ) -> List[List[trimesh.Trimesh]]:
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        device = self.device
        dtype = self.dtype
        do_classifier_free_guidance = guidance_scale >= 0 and not (
            hasattr(self.model, 'guidance_embed') and
            self.model.guidance_embed is True
        )

        image, mask = self.prepare_image(image)
        cond = self.encode_cond(
            image=image,
            mask=mask,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
        )
        batch_size = image.shape[0]

        # Prepare timesteps
        # NOTE: this is slightly different from common usage, we start from 0.
        sigmas = np.linspace(0, 1, num_inference_steps) if sigmas is None else sigmas
        scheduler = instantiate_from_config(self.kwargs['scheduler_cfg'])
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
        )
        latents = self.prepare_latents(batch_size, dtype, device, generator)

        guidance = None
        if hasattr(self.model, 'guidance_embed') and \
                self.model.guidance_embed is True:
            guidance = torch.tensor([guidance_scale] * batch_size, device=device, dtype=dtype)

        for i, t in enumerate(tqdm(timesteps, disable=not enable_pbar, desc="Diffusion Sampling:")):
            # expand the latents if we are doing classifier free guidance
            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
            else:
                latent_model_input = latents

            # NOTE: we assume the model takes timesteps in the range 0 to 1
            timestep = t.expand(latent_model_input.shape[0]).to(
                latents.dtype) / scheduler.config.num_train_timesteps
            noise_pred = self.model(latent_model_input, timestep, cond, guidance=guidance)

            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            outputs = scheduler.step(noise_pred, t, latents)
            latents = outputs.prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(scheduler, "order", 1)
                callback(step_idx, t, outputs)

        return self._export(
            latents,
            output_type,
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
        )
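
For reference, a minimal usage sketch of the uploaded pipeline. It assumes the file is importable as hy3dgen.shapegen.pipelines, that a compatible checkpoint repo id such as 'tencent/Hunyuan3D-2' is reachable by from_pretrained, and that demo.png exists; these names are illustrative and are not part of this commit.

# Illustrative usage sketch; the repo id and image path below are assumptions.
from hy3dgen.shapegen.pipelines import Hunyuan3DDiTFlowMatchingPipeline

# from_pretrained falls back to huggingface_hub.snapshot_download when the
# given path is not found locally (see the method above).
pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained('tencent/Hunyuan3D-2')

# __call__ runs flow-matching sampling; with output_type='trimesh' it returns
# a list of trimesh.Trimesh meshes, one per input image.
mesh = pipeline(image='demo.png', num_inference_steps=50, guidance_scale=7.5)[0]
mesh.export('demo.glb')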