Image-to-3D
Hunyuan3D-2
Diffusers
Safetensors
English
Chinese
text-to-3d
Huiwenshi commited on
Commit
6e9aaf2
·
verified ·
1 Parent(s): 594ad89

Upload hunyuan3d-paintpbr-v2-1/unet/attn_processor.py with huggingface_hub

Browse files
hunyuan3d-paintpbr-v2-1/unet/attn_processor.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the repsective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
19
+ from einops import rearrange
20
+ from diffusers.utils import deprecate
21
+ from diffusers.models.attention_processor import Attention, AttnProcessor
22
+
23
+
24
+ class AttnUtils:
25
+ """
26
+ Shared utility functions for attention processing.
27
+
28
+ This class provides common operations used across different attention processors
29
+ to eliminate code duplication and improve maintainability.
30
+ """
31
+
32
+ @staticmethod
33
+ def check_pytorch_compatibility():
34
+ """
35
+ Check PyTorch compatibility for scaled_dot_product_attention.
36
+
37
+ Raises:
38
+ ImportError: If PyTorch version doesn't support scaled_dot_product_attention
39
+ """
40
+ if not hasattr(F, "scaled_dot_product_attention"):
41
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
42
+
43
+ @staticmethod
44
+ def handle_deprecation_warning(args, kwargs):
45
+ """
46
+ Handle deprecation warning for the 'scale' argument.
47
+
48
+ Args:
49
+ args: Positional arguments passed to attention processor
50
+ kwargs: Keyword arguments passed to attention processor
51
+ """
52
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
53
+ deprecation_message = (
54
+ "The `scale` argument is deprecated and will be ignored."
55
+ "Please remove it, as passing it will raise an error in the future."
56
+ "`scale` should directly be passed while calling the underlying pipeline component"
57
+ "i.e., via `cross_attention_kwargs`."
58
+ )
59
+ deprecate("scale", "1.0.0", deprecation_message)
60
+
61
+ @staticmethod
62
+ def prepare_hidden_states(
63
+ hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
64
+ ):
65
+ """
66
+ Common preprocessing of hidden states for attention computation.
67
+
68
+ Args:
69
+ hidden_states: Input hidden states tensor
70
+ attn: Attention module instance
71
+ temb: Optional temporal embedding tensor
72
+ spatial_norm_attr: Attribute name for spatial normalization
73
+ group_norm_attr: Attribute name for group normalization
74
+
75
+ Returns:
76
+ Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
77
+ """
78
+ residual = hidden_states
79
+
80
+ spatial_norm = getattr(attn, spatial_norm_attr, None)
81
+ if spatial_norm is not None:
82
+ hidden_states = spatial_norm(hidden_states, temb)
83
+
84
+ input_ndim = hidden_states.ndim
85
+
86
+ if input_ndim == 4:
87
+ batch_size, channel, height, width = hidden_states.shape
88
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
89
+ else:
90
+ batch_size, channel, height, width = None, None, None, None
91
+
92
+ group_norm = getattr(attn, group_norm_attr, None)
93
+ if group_norm is not None:
94
+ hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
95
+
96
+ return hidden_states, residual, input_ndim, (batch_size, channel, height, width)
97
+
98
+ @staticmethod
99
+ def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
100
+ """
101
+ Prepare attention mask for scaled_dot_product_attention.
102
+
103
+ Args:
104
+ attention_mask: Input attention mask tensor or None
105
+ attn: Attention module instance
106
+ sequence_length: Length of the sequence
107
+ batch_size: Batch size
108
+
109
+ Returns:
110
+ Prepared attention mask tensor reshaped for multi-head attention
111
+ """
112
+ if attention_mask is not None:
113
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
114
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
115
+ return attention_mask
116
+
117
+ @staticmethod
118
+ def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
119
+ """
120
+ Reshape Q/K/V tensors for multi-head attention computation.
121
+
122
+ Args:
123
+ tensor: Input tensor to reshape
124
+ batch_size: Batch size
125
+ attn_heads: Number of attention heads
126
+ head_dim: Dimension per attention head
127
+
128
+ Returns:
129
+ Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
130
+ """
131
+ return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
132
+
133
+ @staticmethod
134
+ def apply_norms(query, key, norm_q, norm_k):
135
+ """
136
+ Apply Q/K normalization layers if available.
137
+
138
+ Args:
139
+ query: Query tensor
140
+ key: Key tensor
141
+ norm_q: Query normalization layer (optional)
142
+ norm_k: Key normalization layer (optional)
143
+
144
+ Returns:
145
+ Tuple of (normalized_query, normalized_key)
146
+ """
147
+ if norm_q is not None:
148
+ query = norm_q(query)
149
+ if norm_k is not None:
150
+ key = norm_k(key)
151
+ return query, key
152
+
153
+ @staticmethod
154
+ def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
155
+ """
156
+ Common output processing including projection, dropout, reshaping, and residual connection.
157
+
158
+ Args:
159
+ hidden_states: Processed hidden states from attention
160
+ input_ndim: Original input tensor dimensions
161
+ shape_info: Tuple containing original shape information
162
+ attn: Attention module instance
163
+ residual: Residual connection tensor
164
+ to_out: Output projection layers [linear, dropout]
165
+
166
+ Returns:
167
+ Final output tensor after all processing steps
168
+ """
169
+ batch_size, channel, height, width = shape_info
170
+
171
+ # Apply output projection and dropout
172
+ hidden_states = to_out[0](hidden_states)
173
+ hidden_states = to_out[1](hidden_states)
174
+
175
+ # Reshape back if needed
176
+ if input_ndim == 4:
177
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
178
+
179
+ # Apply residual connection
180
+ if attn.residual_connection:
181
+ hidden_states = hidden_states + residual
182
+
183
+ # Apply rescaling
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+ return hidden_states
186
+
187
+
188
+ # Base class for attention processors (eliminating initialization duplication)
189
+ class BaseAttnProcessor(nn.Module):
190
+ """
191
+ Base class for attention processors with common initialization.
192
+
193
+ This base class provides shared parameter initialization and module registration
194
+ functionality to reduce code duplication across different attention processor types.
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ query_dim: int,
200
+ pbr_setting: List[str] = ["albedo", "mr"],
201
+ cross_attention_dim: Optional[int] = None,
202
+ heads: int = 8,
203
+ kv_heads: Optional[int] = None,
204
+ dim_head: int = 64,
205
+ dropout: float = 0.0,
206
+ bias: bool = False,
207
+ upcast_attention: bool = False,
208
+ upcast_softmax: bool = False,
209
+ cross_attention_norm: Optional[str] = None,
210
+ cross_attention_norm_num_groups: int = 32,
211
+ qk_norm: Optional[str] = None,
212
+ added_kv_proj_dim: Optional[int] = None,
213
+ added_proj_bias: Optional[bool] = True,
214
+ norm_num_groups: Optional[int] = None,
215
+ spatial_norm_dim: Optional[int] = None,
216
+ out_bias: bool = True,
217
+ scale_qk: bool = True,
218
+ only_cross_attention: bool = False,
219
+ eps: float = 1e-5,
220
+ rescale_output_factor: float = 1.0,
221
+ residual_connection: bool = False,
222
+ _from_deprecated_attn_block: bool = False,
223
+ processor: Optional["AttnProcessor"] = None,
224
+ out_dim: int = None,
225
+ out_context_dim: int = None,
226
+ context_pre_only=None,
227
+ pre_only=False,
228
+ elementwise_affine: bool = True,
229
+ is_causal: bool = False,
230
+ **kwargs,
231
+ ):
232
+ """
233
+ Initialize base attention processor with common parameters.
234
+
235
+ Args:
236
+ query_dim: Dimension of query features
237
+ pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
238
+ cross_attention_dim: Dimension of cross-attention features (optional)
239
+ heads: Number of attention heads
240
+ kv_heads: Number of key-value heads for grouped query attention (optional)
241
+ dim_head: Dimension per attention head
242
+ dropout: Dropout rate
243
+ bias: Whether to use bias in linear projections
244
+ upcast_attention: Whether to upcast attention computation to float32
245
+ upcast_softmax: Whether to upcast softmax computation to float32
246
+ cross_attention_norm: Type of cross-attention normalization (optional)
247
+ cross_attention_norm_num_groups: Number of groups for cross-attention norm
248
+ qk_norm: Type of query-key normalization (optional)
249
+ added_kv_proj_dim: Dimension for additional key-value projections (optional)
250
+ added_proj_bias: Whether to use bias in additional projections
251
+ norm_num_groups: Number of groups for normalization (optional)
252
+ spatial_norm_dim: Dimension for spatial normalization (optional)
253
+ out_bias: Whether to use bias in output projection
254
+ scale_qk: Whether to scale query-key products
255
+ only_cross_attention: Whether to only perform cross-attention
256
+ eps: Small epsilon value for numerical stability
257
+ rescale_output_factor: Factor to rescale output values
258
+ residual_connection: Whether to use residual connections
259
+ _from_deprecated_attn_block: Flag for deprecated attention blocks
260
+ processor: Optional attention processor instance
261
+ out_dim: Output dimension (optional)
262
+ out_context_dim: Output context dimension (optional)
263
+ context_pre_only: Whether to only process context in pre-processing
264
+ pre_only: Whether to only perform pre-processing
265
+ elementwise_affine: Whether to use element-wise affine transformations
266
+ is_causal: Whether to use causal attention masking
267
+ **kwargs: Additional keyword arguments
268
+ """
269
+ super().__init__()
270
+ AttnUtils.check_pytorch_compatibility()
271
+
272
+ # Store common attributes
273
+ self.pbr_setting = pbr_setting
274
+ self.n_pbr_tokens = len(self.pbr_setting)
275
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
276
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
277
+ self.query_dim = query_dim
278
+ self.use_bias = bias
279
+ self.is_cross_attention = cross_attention_dim is not None
280
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
281
+ self.upcast_attention = upcast_attention
282
+ self.upcast_softmax = upcast_softmax
283
+ self.rescale_output_factor = rescale_output_factor
284
+ self.residual_connection = residual_connection
285
+ self.dropout = dropout
286
+ self.fused_projections = False
287
+ self.out_dim = out_dim if out_dim is not None else query_dim
288
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
289
+ self.context_pre_only = context_pre_only
290
+ self.pre_only = pre_only
291
+ self.is_causal = is_causal
292
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
293
+ self.scale_qk = scale_qk
294
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
295
+ self.heads = out_dim // dim_head if out_dim is not None else heads
296
+ self.sliceable_head_dim = heads
297
+ self.added_kv_proj_dim = added_kv_proj_dim
298
+ self.only_cross_attention = only_cross_attention
299
+ self.added_proj_bias = added_proj_bias
300
+
301
+ # Validation
302
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
303
+ raise ValueError(
304
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None."
305
+ "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
306
+ )
307
+
308
+ def register_pbr_modules(self, module_types: List[str], **kwargs):
309
+ """
310
+ Generic PBR module registration to eliminate code repetition.
311
+
312
+ Dynamically registers PyTorch modules for different PBR material types
313
+ based on the specified module types and PBR settings.
314
+
315
+ Args:
316
+ module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
317
+ **kwargs: Additional arguments for module configuration
318
+ """
319
+ for pbr_token in self.pbr_setting:
320
+ if pbr_token == "albedo":
321
+ continue
322
+
323
+ for module_type in module_types:
324
+ if module_type == "qkv":
325
+ self.register_module(
326
+ f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
327
+ )
328
+ self.register_module(
329
+ f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
330
+ )
331
+ self.register_module(
332
+ f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
333
+ )
334
+ elif module_type == "v_only":
335
+ self.register_module(
336
+ f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
337
+ )
338
+ elif module_type == "out":
339
+ if not self.pre_only:
340
+ self.register_module(
341
+ f"to_out_{pbr_token}",
342
+ nn.ModuleList(
343
+ [
344
+ nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
345
+ nn.Dropout(self.dropout),
346
+ ]
347
+ ),
348
+ )
349
+ else:
350
+ self.register_module(f"to_out_{pbr_token}", None)
351
+ elif module_type == "add_kv":
352
+ if self.added_kv_proj_dim is not None:
353
+ self.register_module(
354
+ f"add_k_proj_{pbr_token}",
355
+ nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
356
+ )
357
+ self.register_module(
358
+ f"add_v_proj_{pbr_token}",
359
+ nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
360
+ )
361
+ else:
362
+ self.register_module(f"add_k_proj_{pbr_token}", None)
363
+ self.register_module(f"add_v_proj_{pbr_token}", None)
364
+
365
+
366
+ # Rotary Position Embedding utilities (specialized for PoseRoPE)
367
+ class RotaryEmbedding:
368
+ """
369
+ Rotary position embedding utilities for 3D spatial attention.
370
+
371
+ Provides functions to compute and apply rotary position embeddings (RoPE)
372
+ for 1D, 3D spatial coordinates used in 3D-aware attention mechanisms.
373
+ """
374
+
375
+ @staticmethod
376
+ def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
377
+ """
378
+ Compute 1D rotary position embeddings.
379
+
380
+ Args:
381
+ dim: Embedding dimension (must be even)
382
+ pos: Position tensor
383
+ theta: Base frequency for rotary embeddings
384
+ linear_factor: Linear scaling factor
385
+ ntk_factor: NTK (Neural Tangent Kernel) scaling factor
386
+
387
+ Returns:
388
+ Tuple of (cos_embeddings, sin_embeddings)
389
+ """
390
+ assert dim % 2 == 0
391
+ theta = theta * ntk_factor
392
+ freqs = (
393
+ 1.0
394
+ / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
395
+ / linear_factor
396
+ )
397
+ freqs = torch.outer(pos, freqs)
398
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
399
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
400
+ return freqs_cos, freqs_sin
401
+
402
+ @staticmethod
403
+ def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
404
+ """
405
+ Compute 3D rotary position embeddings for spatial coordinates.
406
+
407
+ Args:
408
+ position: 3D position tensor with shape [..., 3]
409
+ embed_dim: Embedding dimension
410
+ voxel_resolution: Resolution of the voxel grid
411
+ theta: Base frequency for rotary embeddings
412
+
413
+ Returns:
414
+ Tuple of (cos_embeddings, sin_embeddings) for 3D positions
415
+ """
416
+ assert position.shape[-1] == 3
417
+ dim_xy = embed_dim // 8 * 3
418
+ dim_z = embed_dim // 8 * 2
419
+
420
+ grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
421
+ freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
422
+ freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
423
+
424
+ xy_cos, xy_sin = freqs_xy
425
+ z_cos, z_sin = freqs_z
426
+
427
+ embed_flattn = position.view(-1, position.shape[-1])
428
+ x_cos = xy_cos[embed_flattn[:, 0], :]
429
+ x_sin = xy_sin[embed_flattn[:, 0], :]
430
+ y_cos = xy_cos[embed_flattn[:, 1], :]
431
+ y_sin = xy_sin[embed_flattn[:, 1], :]
432
+ z_cos = z_cos[embed_flattn[:, 2], :]
433
+ z_sin = z_sin[embed_flattn[:, 2], :]
434
+
435
+ cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
436
+ sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
437
+
438
+ cos = cos.view(*position.shape[:-1], embed_dim)
439
+ sin = sin.view(*position.shape[:-1], embed_dim)
440
+ return cos, sin
441
+
442
+ @staticmethod
443
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
444
+ """
445
+ Apply rotary position embeddings to input tensor.
446
+
447
+ Args:
448
+ x: Input tensor to apply rotary embeddings to
449
+ freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor
450
+
451
+ Returns:
452
+ Tensor with rotary position embeddings applied
453
+ """
454
+ cos, sin = freqs_cis
455
+ cos, sin = cos.to(x.device), sin.to(x.device)
456
+ cos = cos.unsqueeze(1)
457
+ sin = sin.unsqueeze(1)
458
+
459
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
460
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
461
+
462
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
463
+ return out
464
+
465
+
466
+ # Core attention processing logic (eliminating major duplication)
467
+ class AttnCore:
468
+ """
469
+ Core attention processing logic shared across processors.
470
+
471
+ This class provides the fundamental attention computation pipeline
472
+ that can be reused across different attention processor implementations.
473
+ """
474
+
475
+ @staticmethod
476
+ def process_attention_base(
477
+ attn: Attention,
478
+ hidden_states: torch.Tensor,
479
+ encoder_hidden_states: Optional[torch.Tensor] = None,
480
+ attention_mask: Optional[torch.Tensor] = None,
481
+ temb: Optional[torch.Tensor] = None,
482
+ get_qkv_fn: Callable = None,
483
+ apply_rope_fn: Optional[Callable] = None,
484
+ **kwargs,
485
+ ):
486
+ """
487
+ Generic attention processing core shared across different processors.
488
+
489
+ This function implements the common attention computation pipeline including:
490
+ 1. Hidden state preprocessing
491
+ 2. Attention mask preparation
492
+ 3. Q/K/V computation via provided function
493
+ 4. Tensor reshaping for multi-head attention
494
+ 5. Optional normalization and RoPE application
495
+ 6. Scaled dot-product attention computation
496
+
497
+ Args:
498
+ attn: Attention module instance
499
+ hidden_states: Input hidden states tensor
500
+ encoder_hidden_states: Optional encoder hidden states for cross-attention
501
+ attention_mask: Optional attention mask tensor
502
+ temb: Optional temporal embedding tensor
503
+ get_qkv_fn: Function to compute Q, K, V tensors
504
+ apply_rope_fn: Optional function to apply rotary position embeddings
505
+ **kwargs: Additional keyword arguments passed to subfunctions
506
+
507
+ Returns:
508
+ Tuple containing (attention_output, residual, input_ndim, shape_info,
509
+ batch_size, num_heads, head_dim)
510
+ """
511
+ # Prepare hidden states
512
+ hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)
513
+
514
+ batch_size, sequence_length, _ = (
515
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
516
+ )
517
+
518
+ # Prepare attention mask
519
+ attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)
520
+
521
+ # Get Q, K, V
522
+ if encoder_hidden_states is None:
523
+ encoder_hidden_states = hidden_states
524
+ elif attn.norm_cross:
525
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
526
+
527
+ query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)
528
+
529
+ # Reshape for attention
530
+ inner_dim = key.shape[-1]
531
+ head_dim = inner_dim // attn.heads
532
+
533
+ query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
534
+ key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
535
+ value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)
536
+
537
+ # Apply normalization
538
+ query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))
539
+
540
+ # Apply RoPE if provided
541
+ if apply_rope_fn is not None:
542
+ query, key = apply_rope_fn(query, key, head_dim, **kwargs)
543
+
544
+ # Compute attention
545
+ hidden_states = F.scaled_dot_product_attention(
546
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
547
+ )
548
+
549
+ return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
550
+
551
+
552
+ # Specific processor implementations (minimal unique code)
553
+ class PoseRoPEAttnProcessor2_0:
554
+ """
555
+ Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.
556
+
557
+ This processor extends standard attention with 3D rotary position embeddings
558
+ to provide spatial awareness for 3D scene understanding tasks.
559
+ """
560
+
561
+ def __init__(self):
562
+ """Initialize the RoPE attention processor."""
563
+ AttnUtils.check_pytorch_compatibility()
564
+
565
+ def __call__(
566
+ self,
567
+ attn: Attention,
568
+ hidden_states: torch.Tensor,
569
+ encoder_hidden_states: Optional[torch.Tensor] = None,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ position_indices: Dict = None,
572
+ temb: Optional[torch.Tensor] = None,
573
+ n_pbrs=1,
574
+ *args,
575
+ **kwargs,
576
+ ) -> torch.Tensor:
577
+ """
578
+ Apply RoPE-enhanced attention computation.
579
+
580
+ Args:
581
+ attn: Attention module instance
582
+ hidden_states: Input hidden states tensor
583
+ encoder_hidden_states: Optional encoder hidden states for cross-attention
584
+ attention_mask: Optional attention mask tensor
585
+ position_indices: Dictionary containing 3D position information for RoPE
586
+ temb: Optional temporal embedding tensor
587
+ n_pbrs: Number of PBR material types
588
+ *args: Additional positional arguments
589
+ **kwargs: Additional keyword arguments
590
+
591
+ Returns:
592
+ Attention output tensor with applied rotary position encodings
593
+ """
594
+ AttnUtils.handle_deprecation_warning(args, kwargs)
595
+
596
+ def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
597
+ return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
598
+
599
+ def apply_rope(query, key, head_dim, **kwargs):
600
+ if position_indices is not None:
601
+ if head_dim in position_indices:
602
+ image_rotary_emb = position_indices[head_dim]
603
+ else:
604
+ image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
605
+ rearrange(
606
+ position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
607
+ "b n_pbrs l c -> (b n_pbrs) l c",
608
+ ),
609
+ head_dim,
610
+ voxel_resolution=position_indices["voxel_resolution"],
611
+ )
612
+ position_indices[head_dim] = image_rotary_emb
613
+
614
+ query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
615
+ key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
616
+ return query, key
617
+
618
+ # Core attention processing
619
+ hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
620
+ attn,
621
+ hidden_states,
622
+ encoder_hidden_states,
623
+ attention_mask,
624
+ temb,
625
+ get_qkv_fn=get_qkv,
626
+ apply_rope_fn=apply_rope,
627
+ position_indices=position_indices,
628
+ n_pbrs=n_pbrs,
629
+ )
630
+
631
+ # Finalize output
632
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
633
+ hidden_states = hidden_states.to(hidden_states.dtype)
634
+
635
+ return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
636
+
637
+
638
+ class SelfAttnProcessor2_0(BaseAttnProcessor):
639
+ """
640
+ Self-attention processor with PBR (Physically Based Rendering) material support.
641
+
642
+ This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
643
+ with separate attention computation paths for each material type.
644
+ """
645
+
646
+ def __init__(self, **kwargs):
647
+ """
648
+ Initialize self-attention processor with PBR support.
649
+
650
+ Args:
651
+ **kwargs: Arguments passed to BaseAttnProcessor initialization
652
+ """
653
+ super().__init__(**kwargs)
654
+ self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)
655
+
656
+ def process_single(
657
+ self,
658
+ attn: Attention,
659
+ hidden_states: torch.Tensor,
660
+ encoder_hidden_states: Optional[torch.Tensor] = None,
661
+ attention_mask: Optional[torch.Tensor] = None,
662
+ temb: Optional[torch.Tensor] = None,
663
+ token: Literal["albedo", "mr"] = "albedo",
664
+ multiple_devices=False,
665
+ *args,
666
+ **kwargs,
667
+ ):
668
+ """
669
+ Process attention for a single PBR material type.
670
+
671
+ Args:
672
+ attn: Attention module instance
673
+ hidden_states: Input hidden states tensor
674
+ encoder_hidden_states: Optional encoder hidden states for cross-attention
675
+ attention_mask: Optional attention mask tensor
676
+ temb: Optional temporal embedding tensor
677
+ token: PBR material type to process ("albedo", "mr", etc.)
678
+ multiple_devices: Whether to use multiple GPU devices
679
+ *args: Additional positional arguments
680
+ **kwargs: Additional keyword arguments
681
+
682
+ Returns:
683
+ Processed attention output for the specified PBR material type
684
+ """
685
+ target = attn if token == "albedo" else attn.processor
686
+ token_suffix = "" if token == "albedo" else "_" + token
687
+
688
+ # Device management (if needed)
689
+ if multiple_devices:
690
+ device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
691
+ for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
692
+ getattr(target, attr).to(device)
693
+
694
+ def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
695
+ return (
696
+ getattr(target, f"to_q{token_suffix}")(hidden_states),
697
+ getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
698
+ getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
699
+ )
700
+
701
+ # Core processing using shared logic
702
+ hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
703
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
704
+ )
705
+
706
+ # Finalize
707
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
708
+ hidden_states = hidden_states.to(hidden_states.dtype)
709
+
710
+ return AttnUtils.finalize_output(
711
+ hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
712
+ )
713
+
714
+ def __call__(
715
+ self,
716
+ attn: Attention,
717
+ hidden_states: torch.Tensor,
718
+ encoder_hidden_states: Optional[torch.Tensor] = None,
719
+ attention_mask: Optional[torch.Tensor] = None,
720
+ temb: Optional[torch.Tensor] = None,
721
+ *args,
722
+ **kwargs,
723
+ ) -> torch.Tensor:
724
+ """
725
+ Apply self-attention with PBR material processing.
726
+
727
+ Processes multiple PBR material types sequentially, applying attention
728
+ computation for each material type separately and combining results.
729
+
730
+ Args:
731
+ attn: Attention module instance
732
+ hidden_states: Input hidden states tensor with PBR dimension
733
+ encoder_hidden_states: Optional encoder hidden states for cross-attention
734
+ attention_mask: Optional attention mask tensor
735
+ temb: Optional temporal embedding tensor
736
+ *args: Additional positional arguments
737
+ **kwargs: Additional keyword arguments
738
+
739
+ Returns:
740
+ Combined attention output for all PBR material types
741
+ """
742
+ AttnUtils.handle_deprecation_warning(args, kwargs)
743
+
744
+ B = hidden_states.size(0)
745
+ pbr_hidden_states = torch.split(hidden_states, 1, dim=1)
746
+
747
+ # Process each PBR setting
748
+ results = []
749
+ for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
750
+ processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
751
+ result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
752
+ results.append(result)
753
+
754
+ outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
755
+ return torch.cat(outputs, dim=1)
756
+
757
+
758
+ class RefAttnProcessor2_0(BaseAttnProcessor):
759
+ """
760
+ Reference attention processor with shared value computation across PBR materials.
761
+
762
+ This processor computes query and key once, but uses separate value projections
763
+ for different PBR material types, enabling efficient multi-material processing.
764
+ """
765
+
766
+ def __init__(self, **kwargs):
767
+ """
768
+ Initialize reference attention processor.
769
+
770
+ Args:
771
+ **kwargs: Arguments passed to BaseAttnProcessor initialization
772
+ """
773
+ super().__init__(**kwargs)
774
+ self.pbr_settings = self.pbr_setting # Alias for compatibility
775
+ self.register_pbr_modules(["v_only", "out"], **kwargs)
776
+
777
+ def __call__(
778
+ self,
779
+ attn: Attention,
780
+ hidden_states: torch.Tensor,
781
+ encoder_hidden_states: Optional[torch.Tensor] = None,
782
+ attention_mask: Optional[torch.Tensor] = None,
783
+ temb: Optional[torch.Tensor] = None,
784
+ *args,
785
+ **kwargs,
786
+ ) -> torch.Tensor:
787
+ """
788
+ Apply reference attention with shared Q/K and separate V projections.
789
+
790
+ This method computes query and key tensors once and reuses them across
791
+ all PBR material types, while using separate value projections for each
792
+ material type to maintain material-specific information.
793
+
794
+ Args:
795
+ attn: Attention module instance
796
+ hidden_states: Input hidden states tensor
797
+ encoder_hidden_states: Optional encoder hidden states for cross-attention
798
+ attention_mask: Optional attention mask tensor
799
+ temb: Optional temporal embedding tensor
800
+ *args: Additional positional arguments
801
+ **kwargs: Additional keyword arguments
802
+
803
+ Returns:
804
+ Stacked attention output for all PBR material types
805
+ """
806
+ AttnUtils.handle_deprecation_warning(args, kwargs)
807
+
808
+ def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
809
+ query = attn.to_q(hidden_states)
810
+ key = attn.to_k(encoder_hidden_states)
811
+
812
+ # Concatenate values from all PBR settings
813
+ value_list = [attn.to_v(encoder_hidden_states)]
814
+ for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
815
+ value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
816
+ value = torch.cat(value_list, dim=-1)
817
+
818
+ return query, key, value
819
+
820
+ # Core processing
821
+ hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
822
+ attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
823
+ )
824
+
825
+ # Split and process each PBR setting output
826
+ hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)
827
+ output_hidden_states_list = []
828
+
829
+ for i, hs in enumerate(hidden_states_list):
830
+ hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
831
+ token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
832
+ target = attn if self.pbr_settings[i] == "albedo" else attn.processor
833
+
834
+ hs = AttnUtils.finalize_output(
835
+ hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
836
+ )
837
+ output_hidden_states_list.append(hs)
838
+
839
+ return torch.stack(output_hidden_states_list, dim=1)