# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os
import json
import copy
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Literal

import diffusers
from diffusers.utils import deprecate
from diffusers import (
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock
from .attn_processor import SelfAttnProcessor2_0, RefAttnProcessor2_0, PoseRoPEAttnProcessor2_0
from transformers import AutoImageProcessor, AutoModel


class Dino_v2(nn.Module):
    """Frozen DINOv2 wrapper used to extract features from reference images.

    Args:
        dino_v2_path: Path or identifier handed to ``from_pretrained`` for both
            the image processor and the model. It is required; there is no
            built-in default.
    """

    def __init__(self, dino_v2_path):
        super(Dino_v2, self).__init__()
        self.dino_processor = AutoImageProcessor.from_pretrained(dino_v2_path)
        self.dino_v2 = AutoModel.from_pretrained(dino_v2_path)

        # The backbone is used as a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False
        self.dino_v2.eval()

    def forward(self, images):
        """Run images through the DINOv2 backbone.

        Args:
            images: Either a tensor ``[B, N, C, H, W]`` (passed to the
                processor with ``do_rescale=False``) or anything else the
                image processor accepts (e.g. a list of PIL images), which is
                treated as a single batch element.

        Returns:
            torch.Tensor: First output of the backbone reshaped to
            ``[B, N * tokens_per_image, feature_dim]``.
        """
        if isinstance(images, torch.Tensor):
            batch_size = images.shape[0]
            pixel_values = self.dino_processor(
                images=rearrange(images, "b n c h w -> (b n) c h w"),
                return_tensors="pt",
                do_rescale=False,
            ).pixel_values
        else:
            batch_size = 1
            pixel_values = self.dino_processor(images=images, return_tensors="pt").pixel_values
            pixel_values = torch.stack([torch.from_numpy(np.array(image)) for image in pixel_values], dim=0)

        # Match the device and dtype of the backbone weights.
        dino_param = next(self.dino_v2.parameters())
        pixel_values = pixel_values.to(dino_param)
        dino_hidden_states = self.dino_v2(pixel_values)[0]
        dino_hidden_states = rearrange(dino_hidden_states.to(dino_param), "(b n) l c -> b (n l) c", b=batch_size)
        return dino_hidden_states


def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    """Apply a feed-forward module chunk by chunk to reduce peak memory.

    Args:
        ff: Feed-forward module to apply.
        hidden_states: Input tensor.
        chunk_dim: Dimension along which the input is split.
        chunk_size: Size of each chunk; must divide ``hidden_states.shape[chunk_dim]``.

    Returns:
        torch.Tensor: Per-chunk outputs concatenated along ``chunk_dim``.

    Raises:
        ValueError: If the chunked dimension is not divisible by ``chunk_size``.
    """
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} "
            f"has to be divisible by chunk size: {chunk_size}. "
            "Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    return ff_output


def _pool_position_grid(position, grid_resolution):
    """Average the valid positions inside each cell of a coarse grid.

    The image plane is split into ``grid_resolution x grid_resolution`` cells.
    A pixel counts as invalid when any of its channels equals exactly 1; such
    pixels are excluded from the average. The input tensor is never modified.

    Args:
        position: Position maps ``[B, N, 3, H, W]``.
        grid_resolution: Number of grid cells per spatial side; must divide
            both ``H`` and ``W``.

    Returns:
        tuple: ``(grid_position, count_masked)``, both shaped
        ``[B, N, grid_resolution, grid_resolution, 3]``. ``grid_position`` is
        the fp16 mean position per cell, ``count_masked`` the number of valid
        pixels that contributed to it.
    """
    position = position.half()
    _, _, _, H, W = position.shape
    assert H % grid_resolution == 0 and W % grid_resolution == 0

    valid_mask = (position != 1).all(dim=2, keepdim=True).expand_as(position)
    # Out-of-place on purpose: `.half()` returns the caller's tensor when it is
    # already fp16, so an in-place write would corrupt the caller's maps.
    position = position.masked_fill(~valid_mask, 0)

    pattern = "b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w"
    position = rearrange(position, pattern, num_h=grid_resolution, num_w=grid_resolution)
    valid_mask = rearrange(valid_mask, pattern, num_h=grid_resolution, num_w=grid_resolution)

    grid_position = position.sum(dim=(-2, -1))
    count_masked = valid_mask.sum(dim=(-2, -1))
    # clamp avoids division by zero for cells without any valid pixel.
    grid_position = grid_position / count_masked.clamp(min=1)
    return grid_position, count_masked


@torch.no_grad()
def compute_voxel_grid_mask(position, grid_resolution=8):
    """Build a view-to-view mask from the 3D proximity of grid cells.

    Positions are averaged per grid cell; two cells (possibly from different
    views) are marked as related when the Euclidean distance between their
    mean positions is below ``1.73 / grid_resolution``.

    Args:
        position: Position maps ``[B, N, 3, H, W]``. A pixel with any channel
            equal to 1 is treated as invalid.
        grid_resolution: Number of grid cells per spatial side.

    Returns:
        torch.Tensor: Boolean mask ``[B, N, N, L, L]`` with
        ``L = grid_resolution ** 2``.
    """
    grid_position, count_masked = _pool_position_grid(position, grid_resolution)
    # Cells backed by fewer than 5 valid pixels are collapsed to the origin.
    grid_position[count_masked < 5] = 0

    grid_position = grid_position.permute(0, 1, 4, 2, 3)
    grid_position = rearrange(grid_position, "b n c h w -> b n (h w) c")

    grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4)  # shape becomes B, N, 1, L, 1, 3
    grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3)  # shape becomes B, 1, N, 1, L, 3

    # Pairwise Euclidean distance, shape B, N, N, L, L
    distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1)

    # NOTE(review): 1.73 is presumably sqrt(3), the diagonal of a unit cube - confirm.
    grid_distance = 1.73 / grid_resolution
    return distances < grid_distance


def compute_multi_resolution_mask(position_maps, grid_resolutions=(32, 16, 8)):
    """Compute position-based attention masks at several grid resolutions.

    Args:
        position_maps: Position maps ``[B, N, 3, H, W]``.
        grid_resolutions: Iterable of grid resolutions (cells per side).

    Returns:
        dict: Maps the flattened token count ``N * grid_resolution ** 2`` to
        a boolean mask of shape ``[B, N * L, N * L]``.
    """
    position_attn_mask = {}
    with torch.no_grad():
        for grid_resolution in grid_resolutions:
            position_mask = compute_voxel_grid_mask(position_maps, grid_resolution)
            position_mask = rearrange(position_mask, "b ni nj li lj -> b (ni li) (nj lj)")
            position_attn_mask[position_mask.shape[1]] = position_mask
    return position_attn_mask


@torch.no_grad()
def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=128):
    """Quantize per-cell mean positions into integer voxel indices.

    Args:
        position: Position maps ``[B, N, 3, H, W]``. A pixel with any channel
            equal to 1 is treated as invalid.
        grid_resolution: Number of grid cells per spatial side.
        voxel_resolution: Number of quantization levels per axis.

    Returns:
        torch.Tensor: Long tensor ``[B, N, 3, grid_resolution, grid_resolution]``
        with values in ``[0, voxel_resolution - 1]``.
    """
    _, _, _, H, W = position.shape
    grid_position, count_masked = _pool_position_grid(position, grid_resolution)

    # Cells with fewer than 1/16 of their pixels valid are collapsed to the origin.
    voxel_mask_thres = (H // grid_resolution) * (W // grid_resolution) // (4 * 4)
    grid_position[count_masked < voxel_mask_thres] = 0

    grid_position = grid_position.permute(0, 1, 4, 2, 3).clamp(0, 1)  # B N C H W
    voxel_indices = grid_position * (voxel_resolution - 1)
    voxel_indices = torch.round(voxel_indices).long()
    return voxel_indices


def calc_multires_voxel_idxs(position_maps, grid_resolutions=(64, 32, 16, 8), voxel_resolutions=(512, 256, 128, 64)):
    """Compute quantized voxel indices at several paired resolutions.

    Args:
        position_maps: Position maps ``[B, N, 3, H, W]``.
        grid_resolutions: Iterable of grid resolutions (cells per side).
        voxel_resolutions: Quantization levels, paired element-wise with
            ``grid_resolutions``.

    Returns:
        dict: Maps the flattened token count ``N * grid_resolution ** 2`` to
        ``{"voxel_indices": [B, N * g * g, 3], "voxel_resolution": int}``.
    """
    voxel_indices = {}
    with torch.no_grad():
        for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions):
            voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution)
            voxel_indice = rearrange(voxel_indice, "b n c h w -> b (n h w) c")
            voxel_indices[voxel_indice.shape[1]] = {"voxel_indices": voxel_indice, "voxel_resolution": voxel_resolution}
    return voxel_indices


def _zero_parameters(module: nn.Module) -> None:
    """Set every parameter of ``module`` to zero in place."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()


class Basic2p5DTransformerBlock(torch.nn.Module):
    """Transformer block wrapper for multiview 2.5D image generation.

    Wraps a diffusers ``BasicTransformerBlock`` and optionally adds:
    - material-dimension attention (MDA) on the block's own self-attention,
    - multiview attention (MA),
    - reference attention (RA),
    - DINO feature cross-attention.

    Attributes not found on the wrapper are looked up on the wrapped block.

    Args:
        transformer: Base transformer block.
        layer_name: Identifier of this layer; used as the key under which
            reference features are stored and retrieved.
        use_ma: Enable multiview attention.
        use_ra: Enable reference attention.
        use_mda: Enable material-dimension attention.
        use_dino: Enable DINO feature attention.
        pbr_setting: List of PBR material names (e.g. ``["albedo", "mr"]``).
    """

    def __init__(
        self,
        transformer: BasicTransformerBlock,
        layer_name,
        use_ma=True,
        use_ra=True,
        use_mda=True,
        use_dino=True,
        pbr_setting=None,
    ) -> None:
        """
        Initialization:
            1. Material-Dimension Attention (MDA):
                - Installs SelfAttnProcessor2_0 on the base self-attention
            2. Multiview Attention (MA):
                - Adds an Attention layer with PoseRoPEAttnProcessor2_0
                - Output projection is zero-initialized
            3. Reference Attention (RA):
                - Adds an Attention layer with RefAttnProcessor2_0
                - Output projections are zero-initialized
            4. DINO Attention:
                - Adds a cross-attention layer configured like the base attn2
                - Output projection is zero-initialized
        """
        super().__init__()
        self.transformer = transformer
        self.layer_name = layer_name
        self.use_ma = use_ma
        self.use_ra = use_ra
        self.use_mda = use_mda
        self.use_dino = use_dino
        self.pbr_setting = pbr_setting

        # Configuration shared by all self-attention style layers; the values
        # are read from the wrapped block through __getattr__.
        self_attn_kwargs = dict(
            query_dim=self.dim,
            heads=self.num_attention_heads,
            dim_head=self.attention_head_dim,
            dropout=self.dropout,
            bias=self.attention_bias,
            cross_attention_dim=None,
            upcast_attention=self.attn1.upcast_attention,
            out_bias=True,
        )

        if self.use_mda:
            self.attn1.set_processor(SelfAttnProcessor2_0(pbr_setting=self.pbr_setting, **self_attn_kwargs))

        # multiview attn
        if self.use_ma:
            self.attn_multiview = Attention(processor=PoseRoPEAttnProcessor2_0(), **self_attn_kwargs)

        # ref attn
        if self.use_ra:
            self.attn_refview = Attention(
                processor=RefAttnProcessor2_0(pbr_setting=self.pbr_setting, **self_attn_kwargs),
                **self_attn_kwargs,
            )

        # dino attn
        if self.use_dino:
            self.attn_dino = Attention(
                query_dim=self.dim,
                heads=self.num_attention_heads,
                dim_head=self.attention_head_dim,
                dropout=self.dropout,
                bias=self.attention_bias,
                cross_attention_dim=self.cross_attention_dim,
                upcast_attention=self.attn2.upcast_attention,
                out_bias=True,
            )

        self._initialize_attn_weights()

    def _initialize_attn_weights(self):
        """Initialize the added attention layers from the base block.

        - MDA: every non-albedo material projection copies the base attn1 weights.
        - MA / RA: copy attn1 weights, then zero the output projections.
        - DINO: copy attn2 weights, then zero the output projection.

        With zeroed output projections the added branches contribute nothing
        until they are trained or a checkpoint is loaded.
        """
        if self.use_mda:
            for token in self.pbr_setting:
                if token == "albedo":
                    continue
                getattr(self.attn1.processor, f"to_q_{token}").load_state_dict(self.attn1.to_q.state_dict())
                getattr(self.attn1.processor, f"to_k_{token}").load_state_dict(self.attn1.to_k.state_dict())
                getattr(self.attn1.processor, f"to_v_{token}").load_state_dict(self.attn1.to_v.state_dict())
                getattr(self.attn1.processor, f"to_out_{token}").load_state_dict(self.attn1.to_out.state_dict())

        if self.use_ma:
            self.attn_multiview.load_state_dict(self.attn1.state_dict(), strict=False)
            _zero_parameters(self.attn_multiview.to_out)

        if self.use_ra:
            self.attn_refview.load_state_dict(self.attn1.state_dict(), strict=False)
            for token in self.pbr_setting:
                if token == "albedo":
                    continue
                # NOTE(review): the value projection is initialised from `to_q`,
                # not `to_v`. This looks like a copy-paste slip but is kept
                # as-is to preserve existing initialisation - confirm.
                getattr(self.attn_refview.processor, f"to_v_{token}").load_state_dict(
                    self.attn_refview.to_q.state_dict()
                )
                getattr(self.attn_refview.processor, f"to_out_{token}").load_state_dict(
                    self.attn_refview.to_out.state_dict()
                )
            _zero_parameters(self.attn_refview.to_out)
            for token in self.pbr_setting:
                if token == "albedo":
                    continue
                _zero_parameters(getattr(self.attn_refview.processor, f"to_out_{token}"))

        if self.use_dino:
            self.attn_dino.load_state_dict(self.attn2.state_dict(), strict=False)
            _zero_parameters(self.attn_dino.to_out)

    def __getattr__(self, name: str):
        """Fall back to the wrapped transformer block for unknown attributes."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.transformer, name)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Forward pass with the additional attention branches.

        Processing stages:
            1. Self-attention (material-aware when MDA is enabled)
            2. Reference attention: store features (mode contains "w") and/or
               attend to stored features (mode contains "r")
            3. Multiview attention with optional voxel position indices
            4. GLIGEN control (if provided)
            5. Cross-attention on ``encoder_hidden_states``, then DINO attention
            6. Feed-forward network

        Args:
            hidden_states: Input features ``[B * N_materials * N_views, L, C]``.
            cross_attention_kwargs: May carry the block-specific keys
                ``num_in_batch`` (views per sample, default 1), ``mode``
                (string containing "w" and/or "r"; missing is treated as ""),
                ``mva_scale``, ``ref_scale``, ``condition_embed_dict``,
                ``dino_hidden_states``, ``position_voxel_indices`` and
                ``gligen``. These are removed before the remaining kwargs are
                forwarded to the attention layers. The caller's dict is not
                modified.
            Other parameters follow the base transformer block.

        Returns:
            torch.Tensor: Output features with the same layout as the input.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        batch_size = hidden_states.shape[0]

        # Work on a copy so popping block-specific keys does not affect the caller.
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        num_in_batch = cross_attention_kwargs.pop("num_in_batch", 1)
        # A missing mode means "neither write nor read reference features".
        mode = cross_attention_kwargs.pop("mode", None) or ""
        mva_scale = cross_attention_kwargs.pop("mva_scale", 1.0)
        ref_scale = cross_attention_kwargs.pop("ref_scale", 1.0)
        condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None)
        dino_hidden_states = cross_attention_kwargs.pop("dino_hidden_states", None)
        position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None)
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
        N_pbr = len(self.pbr_setting) if self.pbr_setting is not None else 1

        # 0. Pre-attention normalization
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Self-Attention
        if self.use_mda:
            # Expose the material and view axes to the material-aware processor.
            mda_norm_hidden_states = rearrange(
                norm_hidden_states, "(b n_pbr n) l c -> b n_pbr n l c", n=num_in_batch, n_pbr=N_pbr
            )
            attn_output = self.attn1(
                mda_norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            attn_output = rearrange(attn_output, "b n_pbr n l c -> (b n_pbr n) l c")
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )

        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 Reference Attention
        if "w" in mode:
            # Write pass: store this layer's normalized features for later reads.
            condition_embed_dict[self.layer_name] = rearrange(
                norm_hidden_states, "(b n) l c -> b (n l) c", n=num_in_batch
            )  # B, (N L), C

        if "r" in mode and self.use_ra:
            condition_embed = condition_embed_dict[self.layer_name]

            #! Only using albedo features for reference attention
            ref_norm_hidden_states = rearrange(
                norm_hidden_states, "(b n_pbr n) l c -> b n_pbr (n l) c", n=num_in_batch, n_pbr=N_pbr
            )[:, 0, ...]

            attn_output = self.attn_refview(
                ref_norm_hidden_states,
                encoder_hidden_states=condition_embed,
                attention_mask=None,
                **cross_attention_kwargs,
            )  # b (n l) c
            attn_output = rearrange(attn_output, "b n_pbr (n l) c -> (b n_pbr n) l c", n=num_in_batch, n_pbr=N_pbr)

            ref_scale_timing = ref_scale
            if isinstance(ref_scale, torch.Tensor):
                # Expand a per-sample scale to one value per (material, view)
                # row, then add trailing axes so it broadcasts over attn_output.
                ref_scale_timing = ref_scale.unsqueeze(1).repeat(1, num_in_batch * N_pbr).view(-1)
                for _ in range(attn_output.ndim - 1):
                    ref_scale_timing = ref_scale_timing.unsqueeze(-1)

            hidden_states = ref_scale_timing * attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

        # 1.3 Multiview Attention
        if num_in_batch > 1 and self.use_ma:
            # Concatenate the tokens of all views so they can attend to each other.
            multivew_hidden_states = rearrange(
                norm_hidden_states, "(b n_pbr n) l c -> (b n_pbr) (n l) c", n_pbr=N_pbr, n=num_in_batch
            )

            # Voxel indices are keyed by total token count; use them only if
            # an entry exists for this layer's resolution.
            position_indices = None
            if position_voxel_indices is not None:
                if multivew_hidden_states.shape[1] in position_voxel_indices:
                    position_indices = position_voxel_indices[multivew_hidden_states.shape[1]]

            attn_output = self.attn_multiview(
                multivew_hidden_states,
                encoder_hidden_states=multivew_hidden_states,
                position_indices=position_indices,
                n_pbrs=N_pbr,
                **cross_attention_kwargs,
            )
            attn_output = rearrange(attn_output, "(b n_pbr) (n l) c -> (b n_pbr n) l c", n_pbr=N_pbr, n=num_in_batch)

            hidden_states = mva_scale * attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

        # 2. GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

            # dino attn (reuses the cross-attention normalized features as queries)
            if self.use_dino:
                # Share one set of DINO tokens across all materials and views.
                dino_hidden_states = dino_hidden_states.unsqueeze(1).repeat(1, N_pbr * num_in_batch, 1, 1)
                dino_hidden_states = rearrange(dino_hidden_states, "b n l c -> (b n) l c")
                attn_output = self.attn_dino(
                    norm_hidden_states,
                    encoder_hidden_states=dino_hidden_states,
                    attention_mask=None,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        # i2vgen doesn't have this norm
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class ImageProjModel(torch.nn.Module):
    """Projects image embeddings into cross-attention context tokens.

    Args:
        cross_attention_dim: Channel dimension of the produced tokens.
        clip_embeddings_dim: Channel dimension of the input embeddings.
        clip_extra_context_tokens: Number of tokens produced per input embedding.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project embeddings to normalized context tokens.

        Args:
            image_embeds: Input embeddings ``[B, N, C]`` or ``[B, C]``.

        Returns:
            torch.Tensor: Context tokens
            ``[B, N * clip_extra_context_tokens, cross_attention_dim]``
            (``N = 1`` for 2D input).
        """
        embeds = image_embeds
        num_token = 1
        if embeds.dim() == 3:
            num_token = embeds.shape[1]
            embeds = rearrange(embeds, "b n c -> (b n) c")

        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        clip_extra_context_tokens = rearrange(clip_extra_context_tokens, "(b nt) n c -> b (nt n) c", nt=num_token)
        return clip_extra_context_tokens


class UNet2p5DConditionModel(torch.nn.Module):
    """2.5D UNet wrapper for multiview PBR generation.

    Wraps a 2D UNet and replaces its transformer blocks with
    ``Basic2p5DTransformerBlock`` to add multiview, material-aware, reference
    and DINO attention. When dual-stream reference processing is enabled, a
    separate copy of the UNet is kept for encoding the reference images.

    Attributes not found on the wrapper are looked up on the wrapped UNet.

    Args:
        unet: Base 2D UNet model.
        train_sched: Training scheduler (stored, not used in this class).
        val_sched: Validation scheduler (stored, not used in this class).
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        train_sched: Optional[DDPMScheduler] = None,
        val_sched: Optional[EulerAncestralDiscreteScheduler] = None,
    ) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        self.use_ma = True
        self.use_ra = True
        self.use_mda = True
        self.use_dino = True
        self.use_position_rope = True
        self.use_learned_text_clip = True
        self.use_dual_stream = True
        self.pbr_setting = ["albedo", "mr"]
        self.pbr_token_channels = 77

        if self.use_dual_stream and self.use_ra:
            # The reference stream must be copied BEFORE the main UNet is
            # modified; its blocks are wrapped with all extra branches disabled.
            self.unet_dual = copy.deepcopy(unet)
            self.init_attention(self.unet_dual)

        self.init_attention(
            self.unet,
            use_ma=self.use_ma,
            use_ra=self.use_ra,
            use_dino=self.use_dino,
            use_mda=self.use_mda,
            pbr_setting=self.pbr_setting,
        )
        self.init_condition(use_dino=self.use_dino)

    @staticmethod
    def from_pretrained(pretrained_model_name_or_path, **kwargs):
        """Build the model from a local checkpoint directory.

        Reads ``config.json`` and ``diffusion_pytorch_model.bin`` from the
        directory, constructs the base UNet from the config, wraps it, swaps
        the main UNet's input convolution for a 12-channel one and loads the
        weights strictly.

        Args:
            pretrained_model_name_or_path: Directory holding the two files.
            **kwargs: Only ``torch_dtype`` (default ``torch.float32``) is used.

        Returns:
            UNet2p5DConditionModel: The loaded model cast to ``torch_dtype``.
        """
        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        unet_ckpt_path = os.path.join(pretrained_model_name_or_path, "diffusion_pytorch_model.bin")

        with open(config_path, "r", encoding="utf-8") as file:
            config = json.load(file)

        unet = UNet2DConditionModel(**config)
        unet_2p5d = UNet2p5DConditionModel(unet)
        # Replaced after wrapping so that only the main stream gets the
        # 12-channel input; the dual stream copy keeps the configured conv_in.
        unet_2p5d.unet.conv_in = torch.nn.Conv2d(
            12,
            unet.conv_in.out_channels,
            kernel_size=unet.conv_in.kernel_size,
            stride=unet.conv_in.stride,
            padding=unet.conv_in.padding,
            dilation=unet.conv_in.dilation,
            groups=unet.conv_in.groups,
            bias=unet.conv_in.bias is not None,
        )

        unet_ckpt = torch.load(unet_ckpt_path, map_location="cpu", weights_only=True)
        unet_2p5d.load_state_dict(unet_ckpt, strict=True)
        unet_2p5d = unet_2p5d.to(torch_dtype)
        return unet_2p5d

    def init_condition(self, use_dino):
        """Register the conditioning parameters on the wrapped UNet.

        Sets up:
            1. Learned text embeddings: one zero-initialized
               ``[pbr_token_channels, 1024]`` parameter per material plus one
               for the reference stream.
            2. DINO projector: ``ImageProjModel`` mapping 1536-d features to
               the UNet's cross-attention dimension.

        Args:
            use_dino: Whether to create the DINO projector.
        """
        if self.use_learned_text_clip:
            for token in self.pbr_setting:
                self.unet.register_parameter(
                    f"learned_text_clip_{token}", nn.Parameter(torch.zeros(self.pbr_token_channels, 1024))
                )
            self.unet.learned_text_clip_ref = nn.Parameter(torch.zeros(self.pbr_token_channels, 1024))

        if use_dino:
            self.unet.image_proj_model_dino = ImageProjModel(
                cross_attention_dim=self.unet.config.cross_attention_dim,
                clip_embeddings_dim=1536,
                clip_extra_context_tokens=4,
            )

    def init_attention(self, unet, use_ma=False, use_ra=False, use_mda=False, use_dino=False, pbr_setting=None):
        """Replace the UNet's transformer blocks with 2.5D blocks in place.

        Visits the down blocks, the mid block and the up blocks (in that
        order). Every ``BasicTransformerBlock`` inside a block that has cross
        attention is wrapped in ``Basic2p5DTransformerBlock`` and given a
        layer name of the form ``down_{i}_{attn}_{t}``, ``mid_{attn}_{t}`` or
        ``up_{i}_{attn}_{t}``.

        Args:
            unet: UNet model to modify.
            use_ma: Enable multiview attention.
            use_ra: Enable reference attention.
            use_mda: Enable material-dimension attention.
            use_dino: Enable DINO feature attention.
            pbr_setting: List of PBR material names.
        """

        def wrap_transformers(block, prefix):
            # Blocks without cross attention carry no transformer blocks to wrap.
            if not getattr(block, "has_cross_attention", False):
                return
            for attn_i, attn in enumerate(block.attentions):
                for transformer_i, transformer in enumerate(attn.transformer_blocks):
                    if isinstance(transformer, BasicTransformerBlock):
                        attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(
                            transformer,
                            f"{prefix}_{attn_i}_{transformer_i}",
                            use_ma,
                            use_ra,
                            use_mda,
                            use_dino,
                            pbr_setting,
                        )

        for down_block_i, down_block in enumerate(unet.down_blocks):
            wrap_transformers(down_block, f"down_{down_block_i}")

        wrap_transformers(unet.mid_block, "mid")

        for up_block_i, up_block in enumerate(unet.up_blocks):
            wrap_transformers(up_block, f"up_{up_block_i}")

    def __getattr__(self, name: str):
        """Fall back to the wrapped UNet for unknown attributes."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def _compute_condition_embed_dict(self, cached_condition, batch_size, added_cond_kwargs):
        """Run the reference stream once and collect its per-layer features.

        The reference latents are fed, without added noise and at timestep 0,
        through the reference UNet in write mode so that every wrapped block
        stores its features under its layer name.
        """
        condition_embed_dict = {}
        ref_latents = cached_condition["ref_latents"]
        N_ref = ref_latents.shape[1]

        if not self.use_dual_stream:
            # The shared UNet expects the extra normal/position channels;
            # pad the reference latents with zeros in their place.
            ref_parts = [ref_latents]
            if "embeds_normal" in cached_condition:
                ref_parts.append(torch.zeros_like(ref_latents))
            if "embeds_position" in cached_condition:
                ref_parts.append(torch.zeros_like(ref_latents))
            ref_latents = torch.cat(ref_parts, dim=2)

        ref_latents = rearrange(ref_latents, "b n c h w -> (b n) c h w")

        encoder_hidden_states_ref = self.unet.learned_text_clip_ref.repeat(batch_size, N_ref, 1, 1)
        encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, "b n l c -> (b n) l c")

        if added_cond_kwargs is not None:
            text_embeds_ref = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_ref, 1)
            text_embeds_ref = rearrange(text_embeds_ref, "b n c -> (b n) c")
            time_ids_ref = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_ref, 1)
            time_ids_ref = rearrange(time_ids_ref, "b n c -> (b n) c")
            added_cond_kwargs_ref = {
                "text_embeds": text_embeds_ref,
                "time_ids": time_ids_ref,
            }
        else:
            added_cond_kwargs_ref = None

        noisy_ref_latents = ref_latents
        timestep_ref = 0
        unet_ref = self.unet_dual if self.use_dual_stream else self.unet

        # The output is discarded; the call is made for its side effect of
        # filling condition_embed_dict.
        unet_ref(
            noisy_ref_latents,
            timestep_ref,
            encoder_hidden_states=encoder_hidden_states_ref,
            class_labels=None,
            added_cond_kwargs=added_cond_kwargs_ref,
            return_dict=False,
            cross_attention_kwargs={
                "mode": "w",
                "num_in_batch": N_ref,
                "condition_embed_dict": condition_embed_dict,
            },
        )
        return condition_embed_dict

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        *args,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        down_intrablock_additional_residuals=None,
        down_block_res_samples=None,
        mid_block_res_sample=None,
        **cached_condition,
    ):
        """Forward pass with multiview/material conditioning.

        Key stages:
            1. Input preparation (concatenate normal/position embeddings)
            2. Position encoding (multi-resolution voxel indices)
            3. DINO feature projection
            4. Reference feature extraction
            5. Main UNet call in read mode

        Args:
            sample: Input latents ``[B, N_pbr, N_gen, C, H, W]`` with ``H == W``.
            timestep: Timestep forwarded to the main UNet.
            encoder_hidden_states: Text conditioning ``[B, N_pbr, L, C]``.
            added_cond_kwargs: Optional dict with ``text_embeds`` and ``time_ids``.
            cross_attention_kwargs: Accepted for interface compatibility; not
                used (the kwargs passed to the UNet are built internally).
            down_intrablock_additional_residuals, down_block_res_samples,
            mid_block_res_sample: Optional residuals, cast to the UNet dtype.
            cached_condition: Keyword conditions:
                - embeds_normal / embeds_position: extra input channels
                - ref_latents: reference image latents ``[B, N_ref, C, H, W]``
                - dino_hidden_states: DINO features (required unless a
                  projected version is cached)
                - position_maps: 3D position maps
                - mva_scale / ref_scale: attention scales (default 1.0)
                - cache: optional dict; when the caller passes the same dict
                  on every call, voxel indices, projected DINO features and
                  reference features are computed once and reused.

        Returns:
            The output of the wrapped UNet called with ``return_dict=False``.
        """
        B, N_pbr, N_gen, _, H, W = sample.shape
        assert H == W

        if "cache" not in cached_condition:
            cached_condition["cache"] = {}
        cache = cached_condition["cache"]

        # 1. Input preparation: append conditioning channels, flatten batch axes.
        sample_parts = [sample]
        if "embeds_normal" in cached_condition:
            sample_parts.append(cached_condition["embeds_normal"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1))
        if "embeds_position" in cached_condition:
            sample_parts.append(cached_condition["embeds_position"].unsqueeze(1).repeat(1, N_pbr, 1, 1, 1, 1))
        sample = torch.cat(sample_parts, dim=-3)
        sample = rearrange(sample, "b n_pbr n c h w -> (b n_pbr n) c h w")

        encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(-3).repeat(1, 1, N_gen, 1, 1)
        encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, "b n_pbr n l c -> (b n_pbr n) l c")

        if added_cond_kwargs is not None:
            text_embeds_gen = added_cond_kwargs["text_embeds"].unsqueeze(1).repeat(1, N_gen, 1)
            text_embeds_gen = rearrange(text_embeds_gen, "b n c -> (b n) c")
            time_ids_gen = added_cond_kwargs["time_ids"].unsqueeze(1).repeat(1, N_gen, 1)
            time_ids_gen = rearrange(time_ids_gen, "b n c -> (b n) c")
            added_cond_kwargs_gen = {"text_embeds": text_embeds_gen, "time_ids": time_ids_gen}
        else:
            added_cond_kwargs_gen = None

        # 2. Position encoding. Defaults to None so the name is always bound,
        # including when RoPE is enabled but no position maps were supplied.
        position_voxel_indices = None
        if self.use_position_rope:
            if "position_voxel_indices" in cache:
                position_voxel_indices = cache["position_voxel_indices"]
            elif "position_maps" in cached_condition:
                position_voxel_indices = calc_multires_voxel_idxs(
                    cached_condition["position_maps"],
                    grid_resolutions=[H, H // 2, H // 4, H // 8],
                    voxel_resolutions=[H * 8, H * 4, H * 2, H],
                )
                cache["position_voxel_indices"] = position_voxel_indices

        # 3. DINO feature projection
        if self.use_dino:
            if "dino_hidden_states_proj" in cache:
                dino_hidden_states = cache["dino_hidden_states_proj"]
            else:
                assert "dino_hidden_states" in cached_condition
                dino_hidden_states = cached_condition["dino_hidden_states"]
                dino_hidden_states = self.image_proj_model_dino(dino_hidden_states)
                cache["dino_hidden_states_proj"] = dino_hidden_states
        else:
            dino_hidden_states = None

        # 4. Reference feature extraction
        if self.use_ra:
            if "condition_embed_dict" in cache:
                condition_embed_dict = cache["condition_embed_dict"]
            else:
                condition_embed_dict = self._compute_condition_embed_dict(cached_condition, B, added_cond_kwargs)
                cache["condition_embed_dict"] = condition_embed_dict
        else:
            condition_embed_dict = None

        mva_scale = cached_condition.get("mva_scale", 1.0)
        ref_scale = cached_condition.get("ref_scale", 1.0)

        # 5. Main UNet in read mode
        unet_dtype = self.unet.dtype
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states_gen,
            *args,
            class_labels=None,
            added_cond_kwargs=added_cond_kwargs_gen,
            down_intrablock_additional_residuals=(
                [res.to(dtype=unet_dtype) for res in down_intrablock_additional_residuals]
                if down_intrablock_additional_residuals is not None
                else None
            ),
            down_block_additional_residuals=(
                [res.to(dtype=unet_dtype) for res in down_block_res_samples]
                if down_block_res_samples is not None
                else None
            ),
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=unet_dtype) if mid_block_res_sample is not None else None
            ),
            return_dict=False,
            cross_attention_kwargs={
                "mode": "r",
                "num_in_batch": N_gen,
                "dino_hidden_states": dino_hidden_states,
                "condition_embed_dict": condition_embed_dict,
                "mva_scale": mva_scale,
                "ref_scale": ref_scale,
                "position_voxel_indices": position_voxel_indices,
            },
        )