import os

import torch
import torch.nn as nn
import yaml

from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
from .volume_decoders import VanillaVolumeDecoder
from ...utils import logger, synchronize_timer


class VectsetVAE(nn.Module):
    """Base module pairing a volume decoder with a surface extractor.

    This class provides the checkpoint-loading constructors and
    ``latents2mesh``. Note that ``latents2mesh`` reads ``self.geo_decoder``,
    which is not created here; subclasses (e.g. ``ShapeVAE``) must define it.
    """

    @classmethod
    @synchronize_timer('VectsetVAE Model Loading')
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        """Build a model from a YAML config file and a single checkpoint file.

        Args:
            ckpt_path: Path to the checkpoint. When ``use_safetensors`` is
                truthy, a trailing ``.ckpt`` extension is replaced by
                ``.safetensors`` before the file is looked up.
            config_path: Path to a YAML file whose ``params`` mapping holds
                the constructor keyword arguments.
            device: Device the model is moved to after loading.
            dtype: Dtype the model is cast to after loading.
            use_safetensors: If truthy, load with ``safetensors``; otherwise
                (including the default ``None``) load with ``torch.load``.
            **kwargs: Overrides merged on top of the config's ``params``.

        Returns:
            The constructed model with the checkpoint's state dict loaded.

        Raises:
            FileNotFoundError: If the resolved checkpoint path does not exist.
        """
        # load config
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # load ckpt
        if use_safetensors:
            # Swap only the file extension; a plain str.replace would also
            # rewrite '.ckpt' occurring in directory names along the path.
            root, ext = os.path.splitext(ckpt_path)
            if ext == '.ckpt':
                ckpt_path = root + '.safetensors'
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        if use_safetensors:
            # Imported lazily so safetensors is only required when used.
            import safetensors.torch
            ckpt = safetensors.torch.load_file(ckpt_path, device='cpu')
        else:
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

        # Caller-supplied kwargs take precedence over the config values;
        # build a new dict so the parsed config is left untouched.
        model_kwargs = {**config['params'], **kwargs}
        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)
        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device='cuda',
        dtype=torch.float16,
        use_safetensors=True,
        variant='fp16',
        subfolder='hunyuan3d-vae-v2-0',
        **kwargs,
    ):
        """Locate a model directory (local cache first, then the hub) and load it.

        The local candidate is ``<base_dir>/<model_path>/<subfolder>``, where
        ``base_dir`` comes from the ``HY3DGEN_MODELS`` environment variable
        and defaults to ``~/.cache/hy3dgen``. If that directory does not
        exist, ``model_path`` is used as a repo id for
        ``huggingface_hub.snapshot_download``.

        Args:
            model_path: Directory name under the local base dir, also used as
                the hub repo id when the local directory is missing.
            device: Forwarded to ``from_single_file``.
            dtype: Forwarded to ``from_single_file``.
            use_safetensors: Selects the ``safetensors`` extension (truthy) or
                ``ckpt`` (falsy) and is forwarded to ``from_single_file``.
            variant: Optional infix of the checkpoint file name, giving
                ``model.<variant>.<ext>``; ``None`` or an empty string gives
                ``model.<ext>``.
            subfolder: Sub-directory holding ``config.yaml`` and the checkpoint.
            **kwargs: Forwarded to ``from_single_file``.

        Returns:
            The model returned by ``from_single_file``.

        Raises:
            RuntimeError: If an ``ImportError`` occurs while importing or
                using ``huggingface_hub`` for the download fallback.
            FileNotFoundError: If the model directory still does not exist
                after the download attempt.
        """
        original_model_path = model_path
        # try local path
        base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
        model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
        logger.info(f'Try to load model from local path: {model_path}')
        if not os.path.exists(model_path):
            logger.info('Model path not exists, try to download from huggingface')
            try:
                import huggingface_hub
                # download from huggingface
                path = huggingface_hub.snapshot_download(repo_id=original_model_path)
                model_path = os.path.join(path, subfolder)
            except ImportError as err:
                logger.warning(
                    "You need to install HuggingFace Hub to load models from the hub."
                )
                # Chain the cause so the missing-dependency traceback is kept.
                raise RuntimeError(f"Model path {model_path} not found") from err
            # Any other download error propagates to the caller unchanged.

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path {original_model_path} not found")

        extension = 'safetensors' if use_safetensors else 'ckpt'
        # Treat an empty variant like None to avoid a 'model..ext' file name.
        variant = f'.{variant}' if variant else ''
        ckpt_name = f'model{variant}.{extension}'
        config_path = os.path.join(model_path, 'config.yaml')
        ckpt_path = os.path.join(model_path, ckpt_name)

        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs
        )

    def __init__(
        self,
        volume_decoder=None,
        surface_extractor=None
    ):
        """Store the volume decoder and surface extractor.

        Args:
            volume_decoder: Callable used by ``latents2mesh``; defaults to a
                new ``VanillaVolumeDecoder()``.
            surface_extractor: Callable used by ``latents2mesh``; defaults to
                a new ``MCSurfaceExtractor()``.
        """
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        """Decode latents with the volume decoder, then run the surface extractor.

        Requires ``self.geo_decoder`` to be set by a subclass. The same
        ``**kwargs`` are passed to both the volume decoder and the surface
        extractor.

        Returns:
            Whatever ``self.surface_extractor`` returns.
        """
        with synchronize_timer('Volume decoding'):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer('Surface extraction'):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs


class ShapeVAE(VectsetVAE):
    """``VectsetVAE`` subclass that builds the latent transformer and geometry decoder."""

    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
    ):
        """Create the sub-modules; all arguments are keyword-only.

        Args:
            num_latents: Passed as ``n_ctx`` to the transformer and as
                ``num_latents`` to the geometry decoder.
            embed_dim: Input feature size of the ``post_kl`` linear layer.
            width: Output size of ``post_kl`` and transformer width; the
                geometry decoder gets ``width // geo_decoder_downsample_ratio``.
            heads: Transformer head count; the geometry decoder gets
                ``heads // geo_decoder_downsample_ratio``.
            num_decoder_layers: Number of transformer layers.
            geo_decoder_downsample_ratio: Divisor applied to ``width`` and
                ``heads`` for the geometry decoder, also passed to it as
                ``downsample_ratio``.
            geo_decoder_mlp_expand_ratio: Passed to the geometry decoder as
                ``mlp_expand_ratio``.
            geo_decoder_ln_post: Passed to the geometry decoder as
                ``enable_ln_post``.
            num_freqs: Passed to ``FourierEmbedder``.
            include_pi: Passed to ``FourierEmbedder``.
            qkv_bias: Passed to both the transformer and the geometry decoder.
            qk_norm: Passed to both the transformer and the geometry decoder.
            label_type: Passed to the geometry decoder.
            drop_path_rate: Passed to the transformer.
            scale_factor: Stored as ``self.scale_factor``; not otherwise used
                in this class.
        """
        super().__init__()
        self.geo_decoder_ln_post = geo_decoder_ln_post

        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        self.post_kl = nn.Linear(embed_dim, width)

        self.transformer = Transformer(
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

    def forward(self, latents):
        """Apply the ``post_kl`` linear layer and then the transformer to ``latents``."""
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents