# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
from einops import rearrange

from diffusers.utils import deprecate
from diffusers.models.attention_processor import Attention, AttnProcessor


class AttnUtils:
    """
    Shared utility functions for attention processing.

    This class provides common operations used across different attention processors
    to eliminate code duplication and improve maintainability.
    """

    @staticmethod
    def check_pytorch_compatibility():
        """
        Check PyTorch compatibility for scaled_dot_product_attention.

        Raises:
            ImportError: If PyTorch version doesn't support scaled_dot_product_attention
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def handle_deprecation_warning(args, kwargs):
        """
        Handle deprecation warning for the 'scale' argument.

        Args:
            args: Positional arguments passed to attention processor
            kwargs: Keyword arguments passed to attention processor
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored. "
                "Please remove it, as passing it will raise an error in the future. "
                "`scale` should directly be passed while calling the underlying pipeline component, "
                "i.e., via `cross_attention_kwargs`."
            )
            deprecate("scale", "1.0.0", deprecation_message)

    @staticmethod
    def prepare_hidden_states(
        hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
    ):
        """
        Common preprocessing of hidden states for attention computation.
        Args:
            hidden_states: Input hidden states tensor
            attn: Attention module instance
            temb: Optional temporal embedding tensor
            spatial_norm_attr: Attribute name for spatial normalization
            group_norm_attr: Attribute name for group normalization

        Returns:
            Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
        """
        residual = hidden_states

        spatial_norm = getattr(attn, spatial_norm_attr, None)
        if spatial_norm is not None:
            hidden_states = spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, channel, height, width = None, None, None, None

        group_norm = getattr(attn, group_norm_attr, None)
        if group_norm is not None:
            hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        return hidden_states, residual, input_ndim, (batch_size, channel, height, width)

    @staticmethod
    def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
        """
        Prepare attention mask for scaled_dot_product_attention.

        Args:
            attention_mask: Input attention mask tensor or None
            attn: Attention module instance
            sequence_length: Length of the sequence
            batch_size: Batch size

        Returns:
            Prepared attention mask tensor reshaped for multi-head attention
        """
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        return attention_mask

    @staticmethod
    def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
        """
        Reshape Q/K/V tensors for multi-head attention computation.

        Args:
            tensor: Input tensor to reshape
            batch_size: Batch size
            attn_heads: Number of attention heads
            head_dim: Dimension per attention head

        Returns:
            Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
        """
        return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

    @staticmethod
    def apply_norms(query, key, norm_q, norm_k):
        """
        Apply Q/K normalization layers if available.

        Args:
            query: Query tensor
            key: Key tensor
            norm_q: Query normalization layer (optional)
            norm_k: Key normalization layer (optional)

        Returns:
            Tuple of (normalized_query, normalized_key)
        """
        if norm_q is not None:
            query = norm_q(query)
        if norm_k is not None:
            key = norm_k(key)
        return query, key

    @staticmethod
    def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
        """
        Common output processing including projection, dropout, reshaping, and residual connection.
        Args:
            hidden_states: Processed hidden states from attention
            input_ndim: Original input tensor dimensions
            shape_info: Tuple containing original shape information
            attn: Attention module instance
            residual: Residual connection tensor
            to_out: Output projection layers [linear, dropout]

        Returns:
            Final output tensor after all processing steps
        """
        batch_size, channel, height, width = shape_info

        # Apply output projection and dropout
        hidden_states = to_out[0](hidden_states)
        hidden_states = to_out[1](hidden_states)

        # Reshape back if needed
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Apply residual connection
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        # Apply rescaling
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
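
# Illustration only (not part of the original module): a quick shape check for the reshape
# helper above. A (batch, seq_len, heads * head_dim) projection becomes
# (batch, heads, seq_len, head_dim), which is the layout expected by
# F.scaled_dot_product_attention. The sizes below are arbitrary.
def _example_reshape_for_attention() -> torch.Size:
    q = torch.randn(2, 77, 8 * 64)
    q = AttnUtils.reshape_qkv_for_attention(q, batch_size=2, attn_heads=8, head_dim=64)
    return q.shape  # torch.Size([2, 8, 77, 64])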

# Base class for attention processors (eliminating initialization duplication)
class BaseAttnProcessor(nn.Module):
    """
    Base class for attention processors with common initialization.

    This base class provides shared parameter initialization and module registration
    functionality to reduce code duplication across different attention processor types.
    """

    def __init__(
        self,
        query_dim: int,
        pbr_setting: List[str] = ["albedo", "mr"],
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        out_context_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        **kwargs,
    ):
        """
        Initialize base attention processor with common parameters.

        Args:
            query_dim: Dimension of query features
            pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
            cross_attention_dim: Dimension of cross-attention features (optional)
            heads: Number of attention heads
            kv_heads: Number of key-value heads for grouped query attention (optional)
            dim_head: Dimension per attention head
            dropout: Dropout rate
            bias: Whether to use bias in linear projections
            upcast_attention: Whether to upcast attention computation to float32
            upcast_softmax: Whether to upcast softmax computation to float32
            cross_attention_norm: Type of cross-attention normalization (optional)
            cross_attention_norm_num_groups: Number of groups for cross-attention norm
            qk_norm: Type of query-key normalization (optional)
            added_kv_proj_dim: Dimension for additional key-value projections (optional)
            added_proj_bias: Whether to use bias in additional projections
            norm_num_groups: Number of groups for normalization (optional)
            spatial_norm_dim: Dimension for spatial normalization (optional)
            out_bias: Whether to use bias in output projection
            scale_qk: Whether to scale query-key products
            only_cross_attention: Whether to only perform cross-attention
            eps: Small epsilon value for numerical stability
            rescale_output_factor: Factor to rescale output values
            residual_connection: Whether to use residual connections
            _from_deprecated_attn_block: Flag for deprecated attention blocks
            processor: Optional attention processor instance
            out_dim: Output dimension (optional)
            out_context_dim: Output context dimension (optional)
            context_pre_only: Whether to only process context in pre-processing
            pre_only: Whether to only perform pre-processing
            elementwise_affine: Whether to use element-wise affine transformations
            is_causal: Whether to use causal attention masking
            **kwargs: Additional keyword arguments
        """
        super().__init__()
        AttnUtils.check_pytorch_compatibility()

        # Store common attributes
        self.pbr_setting = pbr_setting
        self.n_pbr_tokens = len(self.pbr_setting)
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal
        self._from_deprecated_attn_block = _from_deprecated_attn_block
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention
        self.added_proj_bias = added_proj_bias

        # Validation
        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. "
                "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

    def register_pbr_modules(self, module_types: List[str], **kwargs):
        """
        Generic PBR module registration to eliminate code repetition.

        Dynamically registers PyTorch modules for different PBR material types
        based on the specified module types and PBR settings.

        Args:
            module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
            **kwargs: Additional arguments for module configuration
        """
        for pbr_token in self.pbr_setting:
            if pbr_token == "albedo":
                continue
            for module_type in module_types:
                if module_type == "qkv":
                    self.register_module(
                        f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "v_only":
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "out":
                    if not self.pre_only:
                        self.register_module(
                            f"to_out_{pbr_token}",
                            nn.ModuleList(
                                [
                                    nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
                                    nn.Dropout(self.dropout),
                                ]
                            ),
                        )
                    else:
                        self.register_module(f"to_out_{pbr_token}", None)
                elif module_type == "add_kv":
                    if self.added_kv_proj_dim is not None:
                        self.register_module(
                            f"add_k_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                        self.register_module(
                            f"add_v_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                    else:
                        self.register_module(f"add_k_proj_{pbr_token}", None)
                        self.register_module(f"add_v_proj_{pbr_token}", None)
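
# Naming sketch (illustration only, not part of the original file): with the default
# pbr_setting=["albedo", "mr"], "albedo" is skipped because it reuses the wrapped Attention
# block's own to_q/to_k/to_v/to_out projections, while every other PBR token gets suffixed
# copies registered on the processor itself. The dimensions below are arbitrary.
def _example_pbr_module_names() -> List[str]:
    proc = BaseAttnProcessor(query_dim=320, heads=8, dim_head=40, pbr_setting=["albedo", "mr"])
    proc.register_pbr_modules(["qkv", "out"])
    # Expected: ['to_k_mr', 'to_out_mr', 'to_q_mr', 'to_v_mr']
    return sorted(name for name, _ in proc.named_children())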

# Rotary Position Embedding utilities (specialized for PoseRoPE)
class RotaryEmbedding:
    """
    Rotary position embedding utilities for 3D spatial attention.

    Provides functions to compute and apply rotary position embeddings (RoPE)
    for 1D and 3D spatial coordinates used in 3D-aware attention mechanisms.
    """

    @staticmethod
    def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
        """
        Compute 1D rotary position embeddings.

        Args:
            dim: Embedding dimension (must be even)
            pos: Position tensor
            theta: Base frequency for rotary embeddings
            linear_factor: Linear scaling factor
            ntk_factor: NTK (Neural Tangent Kernel) scaling factor

        Returns:
            Tuple of (cos_embeddings, sin_embeddings)
        """
        assert dim % 2 == 0
        theta = theta * ntk_factor
        freqs = (
            1.0
            / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
            / linear_factor
        )
        freqs = torch.outer(pos, freqs)
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
        return freqs_cos, freqs_sin

    @staticmethod
    def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
        """
        Compute 3D rotary position embeddings for spatial coordinates.

        Args:
            position: 3D position tensor with shape [..., 3]
            embed_dim: Embedding dimension
            voxel_resolution: Resolution of the voxel grid
            theta: Base frequency for rotary embeddings

        Returns:
            Tuple of (cos_embeddings, sin_embeddings) for 3D positions
        """
        assert position.shape[-1] == 3
        dim_xy = embed_dim // 8 * 3
        dim_z = embed_dim // 8 * 2

        grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
        freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
        freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)

        xy_cos, xy_sin = freqs_xy
        z_cos, z_sin = freqs_z

        embed_flattn = position.view(-1, position.shape[-1])
        x_cos = xy_cos[embed_flattn[:, 0], :]
        x_sin = xy_sin[embed_flattn[:, 0], :]
        y_cos = xy_cos[embed_flattn[:, 1], :]
        y_sin = xy_sin[embed_flattn[:, 1], :]
        z_cos = z_cos[embed_flattn[:, 2], :]
        z_sin = z_sin[embed_flattn[:, 2], :]

        cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
        sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)

        cos = cos.view(*position.shape[:-1], embed_dim)
        sin = sin.view(*position.shape[:-1], embed_dim)
        return cos, sin

    @staticmethod
    def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
        """
        Apply rotary position embeddings to input tensor.

        Args:
            x: Input tensor to apply rotary embeddings to
            freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor

        Returns:
            Tensor with rotary position embeddings applied
        """
        cos, sin = freqs_cis
        cos, sin = cos.to(x.device), sin.to(x.device)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
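
# Minimal smoke test for the RoPE helpers (illustration only; the shapes and grid size are
# made up). A 64-dim head splits as 24/24/16 between x/y/z, so `embed_dim` should be a
# multiple of 8, and the voxel indices must be integers in [0, voxel_resolution).
def _example_apply_3d_rope() -> torch.Tensor:
    voxel_indices = torch.randint(0, 16, (2, 4096, 3))  # (B, L, xyz)
    cos, sin = RotaryEmbedding.get_3d_rotary_pos_embed(voxel_indices, 64, voxel_resolution=16)
    query = torch.randn(2, 8, 4096, 64)  # (B, heads, L, head_dim)
    return RotaryEmbedding.apply_rotary_emb(query, (cos, sin))  # same shape as `query`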

# Core attention processing logic (eliminating major duplication)
class AttnCore:
    """
    Core attention processing logic shared across processors.

    This class provides the fundamental attention computation pipeline
    that can be reused across different attention processor implementations.
    """

    @staticmethod
    def process_attention_base(
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        get_qkv_fn: Callable = None,
        apply_rope_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Generic attention processing core shared across different processors.

        This function implements the common attention computation pipeline including:
            1. Hidden state preprocessing
            2. Attention mask preparation
            3. Q/K/V computation via provided function
            4. Tensor reshaping for multi-head attention
            5. Optional normalization and RoPE application
            6. Scaled dot-product attention computation

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            get_qkv_fn: Function to compute Q, K, V tensors
            apply_rope_fn: Optional function to apply rotary position embeddings
            **kwargs: Additional keyword arguments passed to subfunctions

        Returns:
            Tuple containing (attention_output, residual, input_ndim, shape_info,
            batch_size, num_heads, head_dim)
        """
        # Prepare hidden states
        hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # Prepare attention mask
        attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)

        # Get Q, K, V
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)

        # Reshape for attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
        key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
        value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)

        # Apply normalization
        query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))

        # Apply RoPE if provided
        if apply_rope_fn is not None:
            query, key = apply_rope_fn(query, key, head_dim, **kwargs)

        # Compute attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
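
# Extension-point sketch (illustration only): a processor built on AttnCore only has to
# supply `get_qkv_fn`; masking, head reshaping, optional QK-norm/RoPE and SDPA are handled
# by `process_attention_base`. This plain variant simply uses the Attention block's own
# projections, mirroring what the processors below do before adding their extras.
def _example_plain_get_qkv(
    attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, **kwargs
):
    return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)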

# Specific processor implementations (minimal unique code)
class PoseRoPEAttnProcessor2_0:
    """
    Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.

    This processor extends standard attention with 3D rotary position embeddings
    to provide spatial awareness for 3D scene understanding tasks.
    """

    def __init__(self):
        """Initialize the RoPE attention processor."""
        AttnUtils.check_pytorch_compatibility()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_indices: Dict = None,
        temb: Optional[torch.Tensor] = None,
        n_pbrs=1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply RoPE-enhanced attention computation.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            position_indices: Dictionary containing 3D position information for RoPE
            temb: Optional temporal embedding tensor
            n_pbrs: Number of PBR material types
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Attention output tensor with applied rotary position encodings
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

        def apply_rope(query, key, head_dim, **kwargs):
            if position_indices is not None:
                if head_dim in position_indices:
                    image_rotary_emb = position_indices[head_dim]
                else:
                    image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
                        rearrange(
                            position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
                            "b n_pbrs l c -> (b n_pbrs) l c",
                        ),
                        head_dim,
                        voxel_resolution=position_indices["voxel_resolution"],
                    )
                    position_indices[head_dim] = image_rotary_emb

                query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
                key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
            return query, key

        # Core attention processing
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            get_qkv_fn=get_qkv,
            apply_rope_fn=apply_rope,
            position_indices=position_indices,
            n_pbrs=n_pbrs,
        )

        # Finalize output
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(hidden_states.dtype)

        return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
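
# Wiring sketch (an assumption about usage, not something defined in this file): the
# processor is attached with diffusers' standard `Attention.set_processor`, and the voxel
# metadata is forwarded per call as `position_indices`, i.e. integer "voxel_indices" of
# shape (B, L, 3) plus the grid size. The RoPE table computed for a given head_dim is
# cached back into the dict on first use.
def _example_pose_rope_call(
    attn_module: Attention, hidden: torch.Tensor, voxel_indices: torch.Tensor, voxel_resolution: int
) -> torch.Tensor:
    attn_module.set_processor(PoseRoPEAttnProcessor2_0())
    position_indices = {"voxel_indices": voxel_indices, "voxel_resolution": voxel_resolution}
    return attn_module(hidden, position_indices=position_indices)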

class SelfAttnProcessor2_0(BaseAttnProcessor):
    """
    Self-attention processor with PBR (Physically Based Rendering) material support.

    This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
    with separate attention computation paths for each material type.
    """

    def __init__(self, **kwargs):
        """
        Initialize self-attention processor with PBR support.

        Args:
            **kwargs: Arguments passed to BaseAttnProcessor initialization
        """
        super().__init__(**kwargs)
        self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)

    def process_single(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        token: Literal["albedo", "mr"] = "albedo",
        multiple_devices=False,
        *args,
        **kwargs,
    ):
        """
        Process attention for a single PBR material type.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            token: PBR material type to process ("albedo", "mr", etc.)
            multiple_devices: Whether to use multiple GPU devices
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Processed attention output for the specified PBR material type
        """
        target = attn if token == "albedo" else attn.processor
        token_suffix = "" if token == "albedo" else "_" + token

        # Device management (if needed)
        if multiple_devices:
            device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
            for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
                getattr(target, attr).to(device)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            return (
                getattr(target, f"to_q{token_suffix}")(hidden_states),
                getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
                getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
            )

        # Core processing using shared logic
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        # Finalize
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(hidden_states.dtype)

        return AttnUtils.finalize_output(
            hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply self-attention with PBR material processing.

        Processes multiple PBR material types sequentially, applying attention
        computation for each material type separately and combining results.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor with PBR dimension
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Combined attention output for all PBR material types
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        B = hidden_states.size(0)
        pbr_hidden_states = torch.split(hidden_states, 1, dim=1)

        # Process each PBR setting
        results = []
        for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
            processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
            result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
            results.append(result)

        outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
        return torch.cat(outputs, dim=1)
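
# Shape sketch (illustration only; the Attention hyper-parameters, view count and sequence
# length are made up, and a CUDA device is required because __call__ pins each PBR slice
# to "cuda:0"). The input carries an explicit PBR axis and the output keeps the same 5D
# layout: (B, n_pbr, n_views, L, C) -> (B, n_pbr, n_views, L, C).
def _example_self_attn_shapes() -> torch.Size:
    attn = Attention(
        query_dim=320,
        heads=8,
        dim_head=40,
        processor=SelfAttnProcessor2_0(query_dim=320, heads=8, dim_head=40, pbr_setting=["albedo", "mr"]),
    ).to("cuda:0")
    hidden = torch.randn(1, 2, 6, 64, 320, device="cuda:0")  # (B, n_pbr, n_views, L, C)
    return attn(hidden).shape  # torch.Size([1, 2, 6, 64, 320])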

class RefAttnProcessor2_0(BaseAttnProcessor):
    """
    Reference attention processor with shared value computation across PBR materials.

    This processor computes query and key once, but uses separate value projections
    for different PBR material types, enabling efficient multi-material processing.
    """

    def __init__(self, **kwargs):
        """
        Initialize reference attention processor.

        Args:
            **kwargs: Arguments passed to BaseAttnProcessor initialization
        """
        super().__init__(**kwargs)
        self.pbr_settings = self.pbr_setting  # Alias for compatibility
        self.register_pbr_modules(["v_only", "out"], **kwargs)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply reference attention with shared Q/K and separate V projections.

        This method computes query and key tensors once and reuses them across all
        PBR material types, while using separate value projections for each material
        type to maintain material-specific information.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Stacked attention output for all PBR material types
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)

            # Concatenate values from all PBR settings
            value_list = [attn.to_v(encoder_hidden_states)]
            for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
                value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
            value = torch.cat(value_list, dim=-1)

            return query, key, value

        # Core processing
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        # Split and process each PBR setting output
        hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)

        output_hidden_states_list = []
        for i, hs in enumerate(hidden_states_list):
            hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
            token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
            target = attn if self.pbr_settings[i] == "albedo" else attn.processor
            hs = AttnUtils.finalize_output(
                hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
            )
            output_hidden_states_list.append(hs)

        return torch.stack(output_hidden_states_list, dim=1)
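
# Minimal end-to-end sketch (illustration only; the hyper-parameters are made up and the
# real texturing pipeline wires these processors into a full UNet rather than a single
# block). It shows the reference processor returning one slice per PBR token, stacked
# along dim=1, so two tokens yield a (B, 2, L, C) tensor.
def _example_ref_attention() -> torch.Size:
    processor = RefAttnProcessor2_0(query_dim=320, heads=8, dim_head=40, pbr_setting=["albedo", "mr"])
    attn = Attention(query_dim=320, heads=8, dim_head=40, processor=processor)
    hidden = torch.randn(1, 64, 320)  # (B, L, C)
    return attn(hidden).shape  # torch.Size([1, 2, 64, 320])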