# Spaces: Running on Zero
# Running on Zero
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.
# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Optional, Dict, Tuple, Union, Literal, List, Callable | |
from einops import rearrange | |
from diffusers.utils import deprecate | |
from diffusers.models.attention_processor import Attention, AttnProcessor | |
class AttnUtils:
    """
    Shared utility functions for attention processing.

    This class provides common operations used across different attention processors
    to eliminate code duplication and improve maintainability. It holds no state; every
    member is a static helper, so it can be used either through the class or an instance.
    """

    @staticmethod
    def check_pytorch_compatibility():
        """
        Check PyTorch compatibility for scaled_dot_product_attention.

        Raises:
            ImportError: If PyTorch version doesn't support scaled_dot_product_attention
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def handle_deprecation_warning(args, kwargs):
        """
        Emit a deprecation warning when a legacy ``scale`` argument is supplied.

        Args:
            args: Extra positional arguments passed to the attention processor; any
                non-empty tuple is treated as a legacy ``scale``.
            kwargs: Keyword arguments passed to the attention processor.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # Adjacent literals are concatenated, so each fragment carries its own trailing space.
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored. "
                "Please remove it, as passing it will raise an error in the future. "
                "`scale` should directly be passed while calling the underlying pipeline component "
                "i.e., via `cross_attention_kwargs`."
            )
            deprecate("scale", "1.0.0", deprecation_message)

    @staticmethod
    def prepare_hidden_states(
        hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
    ):
        """
        Common preprocessing of hidden states for attention computation.

        Applies the optional spatial norm, flattens 4D ``(B, C, H, W)`` input into
        ``(B, H*W, C)``, then applies the optional group norm over the channel axis.

        Args:
            hidden_states: Input hidden states tensor (3D or 4D).
            attn: Attention module instance (norm layers are looked up on it by name).
            temb: Optional temporal embedding tensor, forwarded to the spatial norm.
            spatial_norm_attr: Attribute name for spatial normalization.
            group_norm_attr: Attribute name for group normalization.

        Returns:
            Tuple of (processed_hidden_states, residual, input_ndim, shape_info) where
            ``residual`` is the untouched input and ``shape_info`` is
            ``(batch, channel, height, width)`` for 4D input, otherwise four ``None``.
        """
        residual = hidden_states
        spatial_norm = getattr(attn, spatial_norm_attr, None)
        if spatial_norm is not None:
            hidden_states = spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, channel, height, width = None, None, None, None
        group_norm = getattr(attn, group_norm_attr, None)
        if group_norm is not None:
            # Group norm normalizes over the channel axis, which must be dim 1 for the call.
            hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        return hidden_states, residual, input_ndim, (batch_size, channel, height, width)

    @staticmethod
    def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
        """
        Prepare attention mask for scaled_dot_product_attention.

        Args:
            attention_mask: Input attention mask tensor or None.
            attn: Attention module instance (its ``prepare_attention_mask`` is used).
            sequence_length: Length of the key/value sequence.
            batch_size: Batch size.

        Returns:
            ``None`` if no mask was given, otherwise the mask viewed as
            ``(batch_size, attn.heads, -1, key_len)``.
        """
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        return attention_mask

    @staticmethod
    def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
        """
        Reshape a Q/K/V tensor for multi-head attention computation.

        Args:
            tensor: Input tensor of shape ``(batch_size, seq_len, attn_heads * head_dim)``.
            batch_size: Batch size.
            attn_heads: Number of attention heads.
            head_dim: Dimension per attention head.

        Returns:
            Reshaped tensor with shape ``(batch_size, attn_heads, seq_len, head_dim)``.
        """
        return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

    @staticmethod
    def apply_norms(query, key, norm_q, norm_k):
        """
        Apply Q/K normalization layers if available.

        Args:
            query: Query tensor.
            key: Key tensor.
            norm_q: Query normalization layer, or None to skip.
            norm_k: Key normalization layer, or None to skip.

        Returns:
            Tuple of (normalized_query, normalized_key).
        """
        if norm_q is not None:
            query = norm_q(query)
        if norm_k is not None:
            key = norm_k(key)
        return query, key

    @staticmethod
    def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
        """
        Common output processing: projection, dropout, reshaping, residual and rescale.

        Args:
            hidden_states: Attention output of shape ``(B, seq, inner_dim)``.
            input_ndim: Rank of the original input tensor.
            shape_info: ``(batch, channel, height, width)`` from ``prepare_hidden_states``.
            attn: Attention module instance (reads ``residual_connection`` and
                ``rescale_output_factor``).
            residual: Residual connection tensor.
            to_out: Output projection layers ``[linear, dropout]``.

        Returns:
            Final output tensor after all processing steps.
        """
        batch_size, channel, height, width = shape_info
        hidden_states = to_out[0](hidden_states)
        hidden_states = to_out[1](hidden_states)
        if input_ndim == 4:
            # Undo the (B, C, H*W) -> (B, H*W, C) flattening done in prepare_hidden_states.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
# Base class for attention processors (eliminating initialization duplication)
class BaseAttnProcessor(nn.Module):
    """
    Base class for attention processors with common initialization.

    This base class provides shared parameter initialization and module registration
    functionality to reduce code duplication across different attention processor types.
    Several constructor arguments (e.g. ``cross_attention_norm``, ``qk_norm``, ``eps``,
    ``norm_num_groups``, ``spatial_norm_dim``, ``elementwise_affine``, ``processor``) are
    accepted for signature compatibility only and are not stored.
    """

    def __init__(
        self,
        query_dim: int,
        pbr_setting: Optional[List[str]] = None,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: Optional[int] = None,
        out_context_dim: Optional[int] = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        **kwargs,
    ):
        """
        Initialize base attention processor with common parameters.

        Args:
            query_dim: Dimension of query features.
            pbr_setting: PBR material types to process; defaults to ``["albedo", "mr"]``.
                The sequence is copied, so later mutation by the caller has no effect.
            cross_attention_dim: Dimension of cross-attention features (defaults to query_dim).
            heads: Number of attention heads.
            kv_heads: Number of key-value heads; only affects ``inner_kv_dim``.
            dim_head: Dimension per attention head.
            dropout: Dropout rate used in registered output projections.
            bias: Whether to use bias in Q/K/V linear projections.
            upcast_attention: Stored only.
            upcast_softmax: Stored only.
            cross_attention_norm: Accepted for compatibility; unused.
            cross_attention_norm_num_groups: Accepted for compatibility; unused.
            qk_norm: Accepted for compatibility; unused.
            added_kv_proj_dim: Input dimension for the extra key/value projections (optional).
            added_proj_bias: Whether to use bias in the extra key/value projections.
            norm_num_groups: Accepted for compatibility; unused.
            spatial_norm_dim: Accepted for compatibility; unused.
            out_bias: Whether to use bias in registered output projections.
            scale_qk: If True, ``scale`` is ``dim_head ** -0.5``; otherwise 1.0.
            only_cross_attention: Requires ``added_kv_proj_dim`` to be set.
            eps: Accepted for compatibility; unused.
            rescale_output_factor: Stored only.
            residual_connection: Stored only.
            _from_deprecated_attn_block: Stored only.
            processor: Accepted for compatibility; unused.
            out_dim: Output dimension; when given it also overrides the inner dim and head count.
            out_context_dim: Output context dimension (defaults to query_dim).
            context_pre_only: Stored only.
            pre_only: When True, no output projection is registered for PBR tokens.
            elementwise_affine: Accepted for compatibility; unused.
            is_causal: Stored only.
            **kwargs: Additional keyword arguments (ignored here).

        Raises:
            ImportError: If PyTorch lacks ``scaled_dot_product_attention``.
            ValueError: If ``only_cross_attention`` is set without ``added_kv_proj_dim``.
        """
        super().__init__()
        AttnUtils.check_pytorch_compatibility()

        # Copy so the instance never aliases a caller-owned list or a shared default.
        self.pbr_setting = list(pbr_setting) if pbr_setting is not None else ["albedo", "mr"]
        self.n_pbr_tokens = len(self.pbr_setting)
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.out_bias = out_bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal
        self._from_deprecated_attn_block = _from_deprecated_attn_block
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention
        self.added_proj_bias = added_proj_bias

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. "
                "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

    def register_pbr_modules(self, module_types: List[str], **kwargs):
        """
        Register per-material projection modules for every non-albedo PBR token.

        The "albedo" token is skipped; the processors in this file use the host
        ``Attention`` module's own projections for it.

        Args:
            module_types: Module groups to register, any of:
                ``"qkv"`` -> ``to_q_<t>``, ``to_k_<t>``, ``to_v_<t>``;
                ``"v_only"`` -> ``to_v_<t>``;
                ``"out"`` -> ``to_out_<t>`` as ``[Linear, Dropout]`` (``None`` if ``pre_only``);
                ``"add_kv"`` -> ``add_k_proj_<t>``, ``add_v_proj_<t>`` (``None`` if no
                ``added_kv_proj_dim``).
            **kwargs: ``out_bias`` may be given here; otherwise the value passed to the
                constructor is used (falling back to True).
        """
        out_bias = kwargs.get("out_bias", getattr(self, "out_bias", True))
        for pbr_token in self.pbr_setting:
            if pbr_token == "albedo":
                continue
            for module_type in module_types:
                if module_type == "qkv":
                    self.register_module(
                        f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "v_only":
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "out":
                    if not self.pre_only:
                        self.register_module(
                            f"to_out_{pbr_token}",
                            nn.ModuleList(
                                [
                                    nn.Linear(self.inner_dim, self.out_dim, bias=out_bias),
                                    nn.Dropout(self.dropout),
                                ]
                            ),
                        )
                    else:
                        self.register_module(f"to_out_{pbr_token}", None)
                elif module_type == "add_kv":
                    if self.added_kv_proj_dim is not None:
                        self.register_module(
                            f"add_k_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                        self.register_module(
                            f"add_v_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                    else:
                        self.register_module(f"add_k_proj_{pbr_token}", None)
                        self.register_module(f"add_v_proj_{pbr_token}", None)
# Rotary Position Embedding utilities (specialized for PoseRoPE)
class RotaryEmbedding:
    """
    Rotary position embedding utilities for 3D spatial attention.

    Provides functions to compute and apply rotary position embeddings (RoPE)
    for 1D and 3D integer coordinates. Stateless; all members are static.
    """

    @staticmethod
    def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
        """
        Compute 1D rotary position embeddings.

        Args:
            dim: Embedding dimension (must be even).
            pos: 1D position tensor of length ``P``.
            theta: Base frequency for rotary embeddings.
            linear_factor: Divides all frequencies (linear position scaling).
            ntk_factor: Multiplies ``theta`` (NTK-style base scaling).

        Returns:
            Tuple ``(cos, sin)``, each float32 of shape ``(P, dim)``; every frequency
            is repeated twice so adjacent channel pairs share one angle.
        """
        assert dim % 2 == 0
        theta = theta * ntk_factor
        freqs = (
            1.0
            / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
            / linear_factor
        )
        freqs = torch.outer(pos, freqs)
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
        return freqs_cos, freqs_sin

    @staticmethod
    def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
        """
        Compute 3D rotary position embeddings for integer voxel coordinates.

        The embedding is the concatenation of x (3/8 of ``embed_dim``), y (3/8, sharing
        the x frequency table) and z (2/8) components, so ``embed_dim`` must be a
        multiple of 16 for the per-axis widths to be even and sum to ``embed_dim``.

        Args:
            position: Integer tensor of shape ``[..., 3]`` with values in
                ``[0, voxel_resolution)``; it need not be contiguous.
            embed_dim: Embedding dimension.
            voxel_resolution: Resolution of the voxel grid (size of each lookup table).
            theta: Base frequency for rotary embeddings.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``[..., embed_dim]``.
        """
        assert position.shape[-1] == 3
        dim_xy = embed_dim // 8 * 3
        dim_z = embed_dim // 8 * 2
        grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
        xy_cos, xy_sin = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
        z_cos, z_sin = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
        # reshape (not view) so non-contiguous position tensors are accepted.
        flat = position.reshape(-1, position.shape[-1])
        cos = torch.cat((xy_cos[flat[:, 0], :], xy_cos[flat[:, 1], :], z_cos[flat[:, 2], :]), dim=-1)
        sin = torch.cat((xy_sin[flat[:, 0], :], xy_sin[flat[:, 1], :], z_sin[flat[:, 2], :]), dim=-1)
        cos = cos.view(*position.shape[:-1], embed_dim)
        sin = sin.view(*position.shape[:-1], embed_dim)
        return cos, sin

    @staticmethod
    def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor]):
        """
        Apply rotary position embeddings to a multi-head tensor.

        Args:
            x: 4D tensor ``(B, heads, L, D)``; adjacent channel pairs are rotated together.
            freqs_cis: Tuple ``(cos, sin)``, each ``(B, L, D)``; a head axis is inserted
                at dim 1 so they broadcast across heads.

        Returns:
            Tensor like ``x`` (same dtype) with the rotation applied; the math runs in float32.
        """
        cos, sin = freqs_cis
        cos, sin = cos.to(x.device), sin.to(x.device)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # (a, b) -> (-b, a): the 90-degree rotation of each channel pair.
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
# Core attention processing logic (eliminating major duplication)
class AttnCore:
    """
    Core attention processing logic shared across processors.

    This class provides the fundamental attention computation pipeline
    that can be reused across different attention processor implementations.
    It holds no state; its single entry point is a static method.
    """

    @staticmethod
    def process_attention_base(
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        get_qkv_fn: Optional[Callable] = None,
        apply_rope_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Generic attention computation shared by the processors in this module.

        Pipeline:
            1. Hidden state preprocessing (spatial/group norm, 4D flattening).
            2. Attention mask preparation.
            3. Q/K/V computation via ``get_qkv_fn``.
            4. Reshaping to ``(B, heads, L, head_dim)``.
            5. Optional Q/K normalization and RoPE.
            6. ``F.scaled_dot_product_attention`` (no dropout, non-causal).

        Args:
            attn: Attention module instance.
            hidden_states: Input hidden states tensor.
            encoder_hidden_states: Optional key/value source for cross-attention.
            attention_mask: Optional attention mask tensor.
            temb: Optional temporal embedding tensor.
            get_qkv_fn: Required callable ``(attn, hidden_states, encoder_hidden_states,
                **kwargs) -> (query, key, value)``.
            apply_rope_fn: Optional callable ``(query, key, head_dim, **kwargs) -> (query, key)``.
            **kwargs: Forwarded to both ``get_qkv_fn`` and ``apply_rope_fn``.

        Returns:
            Tuple ``(attention_output, residual, input_ndim, shape_info, batch_size,
            num_heads, head_dim)``. ``attention_output`` is ``(B, heads, L, value_head_dim)``,
            where ``value_head_dim`` may exceed ``head_dim`` if the value is wider than the key.

        Raises:
            TypeError: If ``get_qkv_fn`` is not provided.
        """
        if get_qkv_fn is None:
            raise TypeError("process_attention_base() requires `get_qkv_fn`, a callable returning (query, key, value).")

        hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)

        # Mask length follows the key/value source: encoder states if given, else the input itself.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)

        head_dim = key.shape[-1] // attn.heads
        query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
        key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
        # Value may be wider than the key (e.g. several projections concatenated), so its
        # per-head width is derived from its own size.
        value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)

        query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))
        if apply_rope_fn is not None:
            query, key = apply_rope_fn(query, key, head_dim, **kwargs)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
# Specific processor implementations (minimal unique code)
class PoseRoPEAttnProcessor2_0:
    """
    Attention processor that adds 3D rotary position encoding (RoPE) to queries and keys.

    Voxel coordinates supplied through ``position_indices`` are turned into rotary
    tables by ``RotaryEmbedding.get_3d_rotary_pos_embed`` and applied to the
    multi-head query and key tensors before scaled dot-product attention. Projections
    and the output layer are the host ``Attention`` module's own ``to_q``/``to_k``/
    ``to_v``/``to_out``.
    """

    def __init__(self):
        """Create the processor; raises ImportError when PyTorch lacks SDPA."""
        AttnUtils.check_pytorch_compatibility()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_indices: Dict = None,
        temb: Optional[torch.Tensor] = None,
        n_pbrs=1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run attention with 3D RoPE applied to Q and K.

        Args:
            attn: Host attention module.
            hidden_states: Query-side input.
            encoder_hidden_states: Optional key/value source. NOTE(review): the rotary
                table is built for the voxel sequence and applied to K as well, so a key
                sequence of different length will presumably fail to broadcast.
            attention_mask: Optional attention mask.
            position_indices: ``None`` to skip RoPE; otherwise a dict with
                ``"voxel_indices"`` (3D, last axis of size 3) and ``"voxel_resolution"``.
                Computed ``(cos, sin)`` tables are written back into this same dict under
                the integer head dimension and reused on later calls.
            temb: Optional temporal embedding.
            n_pbrs: Number of times the voxel indices are tiled along the batch axis
                before the table is built.
            *args: Legacy positional ``scale`` (triggers a deprecation warning).
            **kwargs: Legacy keyword ``scale`` (triggers a deprecation warning).

        Returns:
            Attention output after projection, reshaping, residual and rescaling.
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        def project_qkv(attn, hidden_states, encoder_hidden_states, **_unused):
            q = attn.to_q(hidden_states)
            k = attn.to_k(encoder_hidden_states)
            v = attn.to_v(encoder_hidden_states)
            return q, k, v

        def rotate_qk(query, key, head_dim, **_unused):
            if position_indices is None:
                return query, key
            if head_dim not in position_indices:
                tiled = position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1)
                position_indices[head_dim] = RotaryEmbedding.get_3d_rotary_pos_embed(
                    rearrange(tiled, "b n_pbrs l c -> (b n_pbrs) l c"),
                    head_dim,
                    voxel_resolution=position_indices["voxel_resolution"],
                )
            table = position_indices[head_dim]
            rotated_query = RotaryEmbedding.apply_rotary_emb(query, table)
            rotated_key = RotaryEmbedding.apply_rotary_emb(key, table)
            return rotated_query, rotated_key

        attended, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            get_qkv_fn=project_qkv,
            apply_rope_fn=rotate_qk,
            position_indices=position_indices,
            n_pbrs=n_pbrs,
        )

        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        merged = attended.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        return AttnUtils.finalize_output(merged, input_ndim, shape_info, attn, residual, attn.to_out)
class SelfAttnProcessor2_0(BaseAttnProcessor):
    """
    Self-attention processor with PBR (Physically Based Rendering) material support.

    This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
    with separate attention computation paths for each material type. The "albedo" path
    uses the host ``Attention`` module's own ``to_q``/``to_k``/``to_v``/``to_out``; every
    other material uses the ``*_<token>`` modules registered on this processor.
    """

    def __init__(self, **kwargs):
        """
        Initialize self-attention processor with PBR support.

        Args:
            **kwargs: Arguments passed to BaseAttnProcessor initialization.
        """
        super().__init__(**kwargs)
        self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)

    def process_single(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        token: Literal["albedo", "mr"] = "albedo",
        multiple_devices=False,
        *args,
        **kwargs,
    ):
        """
        Run attention for a single PBR material type.

        Args:
            attn: Attention module instance.
            hidden_states: Input hidden states tensor.
            encoder_hidden_states: Optional key/value source; None means self-attention.
            attention_mask: Optional attention mask tensor.
            temb: Optional temporal embedding tensor.
            token: PBR material type ("albedo" uses ``attn``'s own layers).
            multiple_devices: If True, move this token's projection modules to cuda:0
                (albedo) or cuda:1 (others) before use.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Attention output for the specified material type.
        """
        target = attn if token == "albedo" else attn.processor
        token_suffix = "" if token == "albedo" else "_" + token

        if multiple_devices:
            # NOTE(review): only the modules are moved, not ``hidden_states``; callers must
            # place inputs on the matching device themselves.
            device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
            for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
                getattr(target, attr).to(device)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            return (
                getattr(target, f"to_q{token_suffix}")(hidden_states),
                getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
                getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
            )

        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        return AttnUtils.finalize_output(
            hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply per-material self-attention.

        ``hidden_states`` is split into size-1 slices along dim 1, and each slice is paired
        with the matching entry of ``pbr_setting``. Each slice is folded into the batch
        axis, attended on its own, and the results are concatenated back along dim 1.

        Args:
            attn: Attention module instance.
            hidden_states: 5D tensor ``(B, n_pbrs, n, L, C)``, where dim 1 indexes the
                materials of ``pbr_setting``.
            encoder_hidden_states: Ignored; each material always self-attends.
            attention_mask: Optional attention mask tensor.
            temb: Optional temporal embedding tensor.
            *args: Legacy positional ``scale`` (triggers a deprecation warning).
            **kwargs: Legacy keyword ``scale`` (triggers a deprecation warning).

        Returns:
            Tensor ``(B, n_pbrs, n, L, C')`` with one attended slice per material.
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)
        batch = hidden_states.size(0)
        outputs = []
        for token, pbr_hs in zip(self.pbr_setting, torch.split(hidden_states, 1, dim=1)):
            # No explicit device move: the slice stays on the caller's device, which is
            # assumed to match the module weights (a hard-coded "cuda:0" broke CPU and
            # non-default-GPU execution).
            flat = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c")
            result = self.process_single(attn, flat, None, attention_mask, temb, token, False)
            outputs.append(rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=batch, n_pbrs=1))
        return torch.cat(outputs, dim=1)
class RefAttnProcessor2_0(BaseAttnProcessor):
    """
    Reference attention processor: one shared query/key, one value projection per material.

    Query and key come from the host ``Attention`` module. The value is the concatenation
    of ``attn.to_v`` with every registered ``to_v_<token>`` (non-albedo materials), so a
    single attention pass produces output for all materials at once. Each material then
    gets its own output projection (``attn.to_out`` for albedo, ``to_out_<token>`` otherwise).
    """

    def __init__(self, **kwargs):
        """
        Build the processor and register per-material value/output projections.

        Args:
            **kwargs: Forwarded to ``BaseAttnProcessor`` and ``register_pbr_modules``.
        """
        super().__init__(**kwargs)
        self.pbr_settings = self.pbr_setting  # Alias for compatibility
        self.register_pbr_modules(["v_only", "out"], **kwargs)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Attend once with shared Q/K and emit one output per PBR material.

        NOTE(review): the per-material values are concatenated along the feature axis
        *before* being split into heads. Each head's value slice of width
        ``n_materials * head_dim`` therefore covers consecutive heads of the concatenated
        projection, rather than pairing the same head index across materials. This layout
        is kept as-is; pretrained weights presumably depend on it.

        Args:
            attn: Host attention module.
            hidden_states: Query-side input.
            encoder_hidden_states: Optional key/value source; defaults to ``hidden_states``.
            attention_mask: Optional attention mask.
            temb: Optional temporal embedding.
            *args: Legacy positional ``scale`` (triggers a deprecation warning).
            **kwargs: Legacy keyword ``scale`` (triggers a deprecation warning).

        Returns:
            Per-material outputs stacked along a new dim 1, in ``pbr_settings`` order.
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)
        settings = self.pbr_settings

        def get_qkv(attn, hidden_states, encoder_hidden_states, **_unused):
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)
            values = [attn.to_v(encoder_hidden_states)]
            for name in settings:
                if name == "albedo":
                    continue
                values.append(getattr(attn.processor, "to_v_" + name)(encoder_hidden_states))
            return query, key, torch.cat(values, dim=-1)

        attended, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        per_material = []
        for idx, chunk in enumerate(torch.split(attended, head_dim, dim=-1)):
            # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
            merged = chunk.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
            name = settings[idx]
            if name == "albedo":
                owner, suffix = attn, ""
            else:
                owner, suffix = attn.processor, "_" + name
            per_material.append(
                AttnUtils.finalize_output(
                    merged, input_ndim, shape_info, attn, residual, getattr(owner, "to_out" + suffix)
                )
            )
        return torch.stack(per_material, dim=1)