# hunyuan3d-paintpbr-v2-1/unet/attn_processor.py
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
from einops import rearrange
from diffusers.utils import deprecate
from diffusers.models.attention_processor import Attention, AttnProcessor
class AttnUtils:
"""
Shared utility functions for attention processing.
This class provides common operations used across different attention processors
to eliminate code duplication and improve maintainability.
"""
@staticmethod
def check_pytorch_compatibility():
"""
Check PyTorch compatibility for scaled_dot_product_attention.
Raises:
ImportError: If PyTorch version doesn't support scaled_dot_product_attention
"""
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@staticmethod
def handle_deprecation_warning(args, kwargs):
"""
Handle deprecation warning for the 'scale' argument.
Args:
args: Positional arguments passed to attention processor
kwargs: Keyword arguments passed to attention processor
"""
if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored. "
                "Please remove it, as passing it will raise an error in the future. "
                "`scale` should directly be passed while calling the underlying pipeline component, "
                "i.e., via `cross_attention_kwargs`."
            )
deprecate("scale", "1.0.0", deprecation_message)
@staticmethod
def prepare_hidden_states(
hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
):
"""
Common preprocessing of hidden states for attention computation.
Args:
hidden_states: Input hidden states tensor
attn: Attention module instance
temb: Optional temporal embedding tensor
spatial_norm_attr: Attribute name for spatial normalization
group_norm_attr: Attribute name for group normalization
Returns:
Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
"""
residual = hidden_states
spatial_norm = getattr(attn, spatial_norm_attr, None)
if spatial_norm is not None:
hidden_states = spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
else:
batch_size, channel, height, width = None, None, None, None
group_norm = getattr(attn, group_norm_attr, None)
if group_norm is not None:
hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
return hidden_states, residual, input_ndim, (batch_size, channel, height, width)
@staticmethod
def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
"""
Prepare attention mask for scaled_dot_product_attention.
Args:
attention_mask: Input attention mask tensor or None
attn: Attention module instance
sequence_length: Length of the sequence
batch_size: Batch size
Returns:
Prepared attention mask tensor reshaped for multi-head attention
"""
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
return attention_mask
@staticmethod
def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
"""
Reshape Q/K/V tensors for multi-head attention computation.
Args:
tensor: Input tensor to reshape
batch_size: Batch size
attn_heads: Number of attention heads
head_dim: Dimension per attention head
Returns:
Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
"""
return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
@staticmethod
def apply_norms(query, key, norm_q, norm_k):
"""
Apply Q/K normalization layers if available.
Args:
query: Query tensor
key: Key tensor
norm_q: Query normalization layer (optional)
norm_k: Key normalization layer (optional)
Returns:
Tuple of (normalized_query, normalized_key)
"""
if norm_q is not None:
query = norm_q(query)
if norm_k is not None:
key = norm_k(key)
return query, key
@staticmethod
def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
"""
Common output processing including projection, dropout, reshaping, and residual connection.
Args:
hidden_states: Processed hidden states from attention
input_ndim: Original input tensor dimensions
shape_info: Tuple containing original shape information
attn: Attention module instance
residual: Residual connection tensor
to_out: Output projection layers [linear, dropout]
Returns:
Final output tensor after all processing steps
"""
batch_size, channel, height, width = shape_info
# Apply output projection and dropout
hidden_states = to_out[0](hidden_states)
hidden_states = to_out[1](hidden_states)
# Reshape back if needed
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# Apply residual connection
if attn.residual_connection:
hidden_states = hidden_states + residual
# Apply rescaling
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
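# --- Illustrative sketch (not part of the original file) ---
# Hedged shape check for the Q/K/V layout used by AttnUtils: projections come in as
# [batch, seq, heads * head_dim] and are reshaped to [batch, heads, seq, head_dim],
# the layout expected by F.scaled_dot_product_attention. The _demo_* name and all
# values below are illustrative assumptions.
def _demo_reshape_qkv_for_attention():
    batch, seq, heads, head_dim = 2, 16, 8, 64
    projected = torch.randn(batch, seq, heads * head_dim)
    reshaped = AttnUtils.reshape_qkv_for_attention(projected, batch, heads, head_dim)
    assert reshaped.shape == (batch, heads, seq, head_dim)
    return reshaped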
# Base class for attention processors (eliminating initialization duplication)
class BaseAttnProcessor(nn.Module):
"""
Base class for attention processors with common initialization.
This base class provides shared parameter initialization and module registration
functionality to reduce code duplication across different attention processor types.
"""
def __init__(
self,
query_dim: int,
pbr_setting: List[str] = ["albedo", "mr"],
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
**kwargs,
):
"""
Initialize base attention processor with common parameters.
Args:
query_dim: Dimension of query features
pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
cross_attention_dim: Dimension of cross-attention features (optional)
heads: Number of attention heads
kv_heads: Number of key-value heads for grouped query attention (optional)
dim_head: Dimension per attention head
dropout: Dropout rate
bias: Whether to use bias in linear projections
upcast_attention: Whether to upcast attention computation to float32
upcast_softmax: Whether to upcast softmax computation to float32
cross_attention_norm: Type of cross-attention normalization (optional)
cross_attention_norm_num_groups: Number of groups for cross-attention norm
qk_norm: Type of query-key normalization (optional)
added_kv_proj_dim: Dimension for additional key-value projections (optional)
added_proj_bias: Whether to use bias in additional projections
norm_num_groups: Number of groups for normalization (optional)
spatial_norm_dim: Dimension for spatial normalization (optional)
out_bias: Whether to use bias in output projection
scale_qk: Whether to scale query-key products
only_cross_attention: Whether to only perform cross-attention
eps: Small epsilon value for numerical stability
rescale_output_factor: Factor to rescale output values
residual_connection: Whether to use residual connections
_from_deprecated_attn_block: Flag for deprecated attention blocks
processor: Optional attention processor instance
out_dim: Output dimension (optional)
out_context_dim: Output context dimension (optional)
context_pre_only: Whether to only process context in pre-processing
pre_only: Whether to only perform pre-processing
elementwise_affine: Whether to use element-wise affine transformations
is_causal: Whether to use causal attention masking
**kwargs: Additional keyword arguments
"""
super().__init__()
AttnUtils.check_pytorch_compatibility()
# Store common attributes
self.pbr_setting = pbr_setting
self.n_pbr_tokens = len(self.pbr_setting)
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
self.added_proj_bias = added_proj_bias
# Validation
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None."
"Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
def register_pbr_modules(self, module_types: List[str], **kwargs):
"""
Generic PBR module registration to eliminate code repetition.
Dynamically registers PyTorch modules for different PBR material types
based on the specified module types and PBR settings.
Args:
module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
**kwargs: Additional arguments for module configuration
"""
for pbr_token in self.pbr_setting:
if pbr_token == "albedo":
continue
for module_type in module_types:
if module_type == "qkv":
self.register_module(
f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
)
self.register_module(
f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
)
self.register_module(
f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
)
elif module_type == "v_only":
self.register_module(
f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
)
elif module_type == "out":
if not self.pre_only:
self.register_module(
f"to_out_{pbr_token}",
nn.ModuleList(
[
nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
nn.Dropout(self.dropout),
]
),
)
else:
self.register_module(f"to_out_{pbr_token}", None)
elif module_type == "add_kv":
if self.added_kv_proj_dim is not None:
self.register_module(
f"add_k_proj_{pbr_token}",
nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
)
self.register_module(
f"add_v_proj_{pbr_token}",
nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
)
else:
self.register_module(f"add_k_proj_{pbr_token}", None)
self.register_module(f"add_v_proj_{pbr_token}", None)
# Rotary Position Embedding utilities (specialized for PoseRoPE)
class RotaryEmbedding:
"""
Rotary position embedding utilities for 3D spatial attention.
Provides functions to compute and apply rotary position embeddings (RoPE)
    for 1D and 3D spatial coordinates used in 3D-aware attention mechanisms.
"""
@staticmethod
def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
"""
Compute 1D rotary position embeddings.
Args:
dim: Embedding dimension (must be even)
pos: Position tensor
theta: Base frequency for rotary embeddings
linear_factor: Linear scaling factor
ntk_factor: NTK (Neural Tangent Kernel) scaling factor
Returns:
Tuple of (cos_embeddings, sin_embeddings)
"""
assert dim % 2 == 0
theta = theta * ntk_factor
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
)
freqs = torch.outer(pos, freqs)
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
return freqs_cos, freqs_sin
@staticmethod
def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
"""
Compute 3D rotary position embeddings for spatial coordinates.
Args:
position: 3D position tensor with shape [..., 3]
embed_dim: Embedding dimension
voxel_resolution: Resolution of the voxel grid
theta: Base frequency for rotary embeddings
Returns:
Tuple of (cos_embeddings, sin_embeddings) for 3D positions
"""
assert position.shape[-1] == 3
dim_xy = embed_dim // 8 * 3
dim_z = embed_dim // 8 * 2
grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
xy_cos, xy_sin = freqs_xy
z_cos, z_sin = freqs_z
embed_flattn = position.view(-1, position.shape[-1])
x_cos = xy_cos[embed_flattn[:, 0], :]
x_sin = xy_sin[embed_flattn[:, 0], :]
y_cos = xy_cos[embed_flattn[:, 1], :]
y_sin = xy_sin[embed_flattn[:, 1], :]
z_cos = z_cos[embed_flattn[:, 2], :]
z_sin = z_sin[embed_flattn[:, 2], :]
cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
cos = cos.view(*position.shape[:-1], embed_dim)
sin = sin.view(*position.shape[:-1], embed_dim)
return cos, sin
@staticmethod
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
"""
Apply rotary position embeddings to input tensor.
Args:
x: Input tensor to apply rotary embeddings to
freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor
Returns:
Tensor with rotary position embeddings applied
"""
cos, sin = freqs_cis
cos, sin = cos.to(x.device), sin.to(x.device)
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
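# --- Illustrative sketch (not part of the original file) ---
# Hedged example of the 3D PoseRoPE tables: get_3d_rotary_pos_embed splits the head
# dimension as 3/8 (x) + 3/8 (y) + 2/8 (z), so embed_dim should be divisible by 16
# (e.g. 64). The _demo_* name and all shapes/values below are illustrative assumptions.
def _demo_3d_rope():
    head_dim, voxel_resolution, batch, seq, heads = 64, 16, 2, 10, 8
    voxel_indices = torch.randint(0, voxel_resolution, (batch, seq, 3))  # integer (x, y, z) per token
    cos, sin = RotaryEmbedding.get_3d_rotary_pos_embed(voxel_indices, head_dim, voxel_resolution)
    query = torch.randn(batch, heads, seq, head_dim)  # [batch, heads, seq, head_dim]
    rotated = RotaryEmbedding.apply_rotary_emb(query, (cos, sin))
    assert cos.shape == (batch, seq, head_dim) and rotated.shape == query.shape
    return rotated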
# Core attention processing logic (eliminating major duplication)
class AttnCore:
"""
Core attention processing logic shared across processors.
This class provides the fundamental attention computation pipeline
that can be reused across different attention processor implementations.
"""
@staticmethod
def process_attention_base(
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
get_qkv_fn: Callable = None,
apply_rope_fn: Optional[Callable] = None,
**kwargs,
):
"""
Generic attention processing core shared across different processors.
This function implements the common attention computation pipeline including:
1. Hidden state preprocessing
2. Attention mask preparation
3. Q/K/V computation via provided function
4. Tensor reshaping for multi-head attention
5. Optional normalization and RoPE application
6. Scaled dot-product attention computation
Args:
attn: Attention module instance
hidden_states: Input hidden states tensor
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask tensor
temb: Optional temporal embedding tensor
get_qkv_fn: Function to compute Q, K, V tensors
apply_rope_fn: Optional function to apply rotary position embeddings
**kwargs: Additional keyword arguments passed to subfunctions
Returns:
Tuple containing (attention_output, residual, input_ndim, shape_info,
batch_size, num_heads, head_dim)
"""
# Prepare hidden states
hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# Prepare attention mask
attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)
# Get Q, K, V
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)
# Reshape for attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)
# Apply normalization
query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))
# Apply RoPE if provided
if apply_rope_fn is not None:
query, key = apply_rope_fn(query, key, head_dim, **kwargs)
# Compute attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
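# --- Illustrative sketch (not part of the original file) ---
# Hedged example of the get_qkv_fn callback protocol used by AttnCore: the shared
# pipeline only needs a function returning (query, key, value) projections; mask
# preparation, head reshaping, and SDPA are handled here. The _demo_* name and the
# diffusers Attention parameters are illustrative assumptions.
def _demo_process_attention_base():
    attn = Attention(query_dim=320, heads=8, dim_head=40)  # plain diffusers Attention block
    hidden_states = torch.randn(2, 16, 320)
    def get_qkv(attn, hs, ehs, **kwargs):
        return attn.to_q(hs), attn.to_k(ehs), attn.to_v(ehs)
    out, residual, input_ndim, shape_info, batch, heads, head_dim = AttnCore.process_attention_base(
        attn, hidden_states, get_qkv_fn=get_qkv
    )
    assert out.shape == (2, 8, 16, 40)  # [batch, heads, seq, head_dim]
    return out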
# Specific processor implementations (minimal unique code)
class PoseRoPEAttnProcessor2_0:
"""
Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.
This processor extends standard attention with 3D rotary position embeddings
to provide spatial awareness for 3D scene understanding tasks.
"""
def __init__(self):
"""Initialize the RoPE attention processor."""
AttnUtils.check_pytorch_compatibility()
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_indices: Dict = None,
temb: Optional[torch.Tensor] = None,
n_pbrs=1,
*args,
**kwargs,
) -> torch.Tensor:
"""
Apply RoPE-enhanced attention computation.
Args:
attn: Attention module instance
hidden_states: Input hidden states tensor
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask tensor
position_indices: Dictionary containing 3D position information for RoPE
temb: Optional temporal embedding tensor
n_pbrs: Number of PBR material types
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Attention output tensor with applied rotary position encodings
"""
AttnUtils.handle_deprecation_warning(args, kwargs)
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
def apply_rope(query, key, head_dim, **kwargs):
if position_indices is not None:
if head_dim in position_indices:
image_rotary_emb = position_indices[head_dim]
else:
image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
rearrange(
position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
"b n_pbrs l c -> (b n_pbrs) l c",
),
head_dim,
voxel_resolution=position_indices["voxel_resolution"],
)
position_indices[head_dim] = image_rotary_emb
query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
return query, key
# Core attention processing
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
attn,
hidden_states,
encoder_hidden_states,
attention_mask,
temb,
get_qkv_fn=get_qkv,
apply_rope_fn=apply_rope,
position_indices=position_indices,
n_pbrs=n_pbrs,
)
# Finalize output
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
hidden_states = hidden_states.to(hidden_states.dtype)
return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
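# --- Illustrative sketch (not part of the original file) ---
# Hedged usage of the RoPE processor on a diffusers Attention block, calling the
# processor directly; in a full pipeline, position_indices would typically be routed
# via cross_attention_kwargs instead. dim_head should be divisible by 16 for the 3D
# RoPE split. The _demo_* name and parameter values are illustrative assumptions.
def _demo_pose_rope_processor():
    attn = Attention(query_dim=512, heads=8, dim_head=64)
    attn.set_processor(PoseRoPEAttnProcessor2_0())
    hidden_states = torch.randn(2, 16, 512)
    position_indices = {
        "voxel_indices": torch.randint(0, 8, (2, 16, 3)),  # [batch, seq, xyz]
        "voxel_resolution": 8,
    }
    out = attn.processor(attn, hidden_states, position_indices=position_indices)
    assert out.shape == hidden_states.shape
    return out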
class SelfAttnProcessor2_0(BaseAttnProcessor):
"""
Self-attention processor with PBR (Physically Based Rendering) material support.
This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
with separate attention computation paths for each material type.
"""
def __init__(self, **kwargs):
"""
Initialize self-attention processor with PBR support.
Args:
**kwargs: Arguments passed to BaseAttnProcessor initialization
"""
super().__init__(**kwargs)
self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)
def process_single(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
token: Literal["albedo", "mr"] = "albedo",
multiple_devices=False,
*args,
**kwargs,
):
"""
Process attention for a single PBR material type.
Args:
attn: Attention module instance
hidden_states: Input hidden states tensor
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask tensor
temb: Optional temporal embedding tensor
token: PBR material type to process ("albedo", "mr", etc.)
multiple_devices: Whether to use multiple GPU devices
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Processed attention output for the specified PBR material type
"""
target = attn if token == "albedo" else attn.processor
token_suffix = "" if token == "albedo" else "_" + token
# Device management (if needed)
if multiple_devices:
device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
getattr(target, attr).to(device)
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
return (
getattr(target, f"to_q{token_suffix}")(hidden_states),
getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
)
# Core processing using shared logic
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
)
# Finalize
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
hidden_states = hidden_states.to(hidden_states.dtype)
return AttnUtils.finalize_output(
hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
"""
Apply self-attention with PBR material processing.
Processes multiple PBR material types sequentially, applying attention
computation for each material type separately and combining results.
Args:
attn: Attention module instance
hidden_states: Input hidden states tensor with PBR dimension
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask tensor
temb: Optional temporal embedding tensor
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Combined attention output for all PBR material types
"""
AttnUtils.handle_deprecation_warning(args, kwargs)
B = hidden_states.size(0)
pbr_hidden_states = torch.split(hidden_states, 1, dim=1)
# Process each PBR setting
results = []
for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
results.append(result)
outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
return torch.cat(outputs, dim=1)
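# --- Illustrative sketch (not part of the original file) ---
# Hedged usage of SelfAttnProcessor2_0. The processor expects a 5D input of shape
# [batch, n_pbr, n_views, seq, dim] and moves activations to cuda:0 internally, so
# this sketch requires a CUDA device. The _demo_* name and parameter values are
# illustrative assumptions.
def _demo_self_attn_processor():
    attn = Attention(query_dim=320, heads=8, dim_head=40)
    attn.set_processor(SelfAttnProcessor2_0(query_dim=320, heads=8, dim_head=40, pbr_setting=["albedo", "mr"]))
    attn.to("cuda:0")  # parameters must live on the device the processor pins activations to
    hidden_states = torch.randn(2, 2, 1, 16, 320, device="cuda:0")  # [b, n_pbr, n_views, l, c]
    out = attn.processor(attn, hidden_states)
    assert out.shape == hidden_states.shape
    return out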
class RefAttnProcessor2_0(BaseAttnProcessor):
"""
Reference attention processor with shared value computation across PBR materials.
This processor computes query and key once, but uses separate value projections
for different PBR material types, enabling efficient multi-material processing.
"""
def __init__(self, **kwargs):
"""
Initialize reference attention processor.
Args:
**kwargs: Arguments passed to BaseAttnProcessor initialization
"""
super().__init__(**kwargs)
self.pbr_settings = self.pbr_setting # Alias for compatibility
self.register_pbr_modules(["v_only", "out"], **kwargs)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
"""
Apply reference attention with shared Q/K and separate V projections.
This method computes query and key tensors once and reuses them across
all PBR material types, while using separate value projections for each
material type to maintain material-specific information.
Args:
attn: Attention module instance
hidden_states: Input hidden states tensor
encoder_hidden_states: Optional encoder hidden states for cross-attention
attention_mask: Optional attention mask tensor
temb: Optional temporal embedding tensor
*args: Additional positional arguments
**kwargs: Additional keyword arguments
Returns:
Stacked attention output for all PBR material types
"""
AttnUtils.handle_deprecation_warning(args, kwargs)
def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
# Concatenate values from all PBR settings
value_list = [attn.to_v(encoder_hidden_states)]
for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
value = torch.cat(value_list, dim=-1)
return query, key, value
# Core processing
hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
)
# Split and process each PBR setting output
hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)
output_hidden_states_list = []
for i, hs in enumerate(hidden_states_list):
hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
target = attn if self.pbr_settings[i] == "albedo" else attn.processor
hs = AttnUtils.finalize_output(
hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
)
output_hidden_states_list.append(hs)
return torch.stack(output_hidden_states_list, dim=1)
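# --- Illustrative sketch (not part of the original file) ---
# Hedged end-to-end usage of RefAttnProcessor2_0: Q/K are computed once, while each
# PBR channel ("albedo", "mr") keeps its own V and output projection, and the
# per-channel results are stacked along dim 1. The _demo_* name and parameter values
# are illustrative assumptions.
def _demo_ref_attn_processor():
    attn = Attention(query_dim=320, heads=8, dim_head=40)
    attn.set_processor(RefAttnProcessor2_0(query_dim=320, heads=8, dim_head=40, pbr_setting=["albedo", "mr"]))
    hidden_states = torch.randn(2, 16, 320)
    out = attn.processor(attn, hidden_states)
    assert out.shape == (2, 2, 16, 320)  # [batch, n_pbr, seq, dim]
    return out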