Hunyuan3D-2
Huiwenshi committed on
Commit 594ad89 · verified · 1 parent: fa26a41

Delete hunyuan3d-paintpbr-v2-1/attn_processor.py

hunyuan3d-paintpbr-v2-1/attn_processor.py DELETED
@@ -1,839 +0,0 @@
1
- # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
- # except for the third-party components listed below.
3
- # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
- # in the respective licenses of these third-party components.
5
- # Users must comply with all terms and conditions of original licenses of these third-party
6
- # components and must ensure that the usage of the third party components adheres to
7
- # all relevant laws and regulations.
8
-
9
- # For avoidance of doubts, Hunyuan 3D means the large language models and
10
- # their software and algorithms, including trained model weights, parameters (including
11
- # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
- # fine-tuning enabling code and other elements of the foregoing made publicly available
13
- # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
-
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
19
- from einops import rearrange
20
- from diffusers.utils import deprecate
21
- from diffusers.models.attention_processor import Attention, AttnProcessor
22
-
23
-
24
- class AttnUtils:
25
- """
26
- Shared utility functions for attention processing.
27
-
28
- This class provides common operations used across different attention processors
29
- to eliminate code duplication and improve maintainability.
30
- """
31
-
32
- @staticmethod
33
- def check_pytorch_compatibility():
34
- """
35
- Check PyTorch compatibility for scaled_dot_product_attention.
36
-
37
- Raises:
38
- ImportError: If PyTorch version doesn't support scaled_dot_product_attention
39
- """
40
- if not hasattr(F, "scaled_dot_product_attention"):
41
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
42
-
43
- @staticmethod
44
- def handle_deprecation_warning(args, kwargs):
45
- """
46
- Handle deprecation warning for the 'scale' argument.
47
-
48
- Args:
49
- args: Positional arguments passed to attention processor
50
- kwargs: Keyword arguments passed to attention processor
51
- """
52
- if len(args) > 0 or kwargs.get("scale", None) is not None:
53
- deprecation_message = (
54
- "The `scale` argument is deprecated and will be ignored. "
55
- "Please remove it, as passing it will raise an error in the future. "
56
- "`scale` should directly be passed while calling the underlying pipeline component "
57
- "i.e., via `cross_attention_kwargs`."
58
- )
59
- deprecate("scale", "1.0.0", deprecation_message)
60
-
61
- @staticmethod
62
- def prepare_hidden_states(
63
- hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
64
- ):
65
- """
66
- Common preprocessing of hidden states for attention computation.
67
-
68
- Args:
69
- hidden_states: Input hidden states tensor
70
- attn: Attention module instance
71
- temb: Optional temporal embedding tensor
72
- spatial_norm_attr: Attribute name for spatial normalization
73
- group_norm_attr: Attribute name for group normalization
74
-
75
- Returns:
76
- Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
77
- """
78
- residual = hidden_states
79
-
80
- spatial_norm = getattr(attn, spatial_norm_attr, None)
81
- if spatial_norm is not None:
82
- hidden_states = spatial_norm(hidden_states, temb)
83
-
84
- input_ndim = hidden_states.ndim
85
-
86
- if input_ndim == 4:
87
- batch_size, channel, height, width = hidden_states.shape
88
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
89
- else:
90
- batch_size, channel, height, width = None, None, None, None
91
-
92
- group_norm = getattr(attn, group_norm_attr, None)
93
- if group_norm is not None:
94
- hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
95
-
96
- return hidden_states, residual, input_ndim, (batch_size, channel, height, width)
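A minimal shape sketch of prepare_hidden_states for a 4D feature map; the toy sizes and the bare Attention block are illustrative assumptions, and the snippet assumes the imports at the top of this file:

attn = Attention(query_dim=64, heads=4, dim_head=16)          # spatial_norm / group_norm are None here
x = torch.randn(2, 64, 32, 32)                                # [B, C, H, W] feature map
hs, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(x, attn, temb=None)
# hs: [2, 1024, 64] (H*W flattened into a token axis); residual keeps the original 4D tensor;
# input_ndim == 4 and shape_info == (2, 64, 32, 32) let finalize_output restore the layout later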
97
-
98
- @staticmethod
99
- def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
100
- """
101
- Prepare attention mask for scaled_dot_product_attention.
102
-
103
- Args:
104
- attention_mask: Input attention mask tensor or None
105
- attn: Attention module instance
106
- sequence_length: Length of the sequence
107
- batch_size: Batch size
108
-
109
- Returns:
110
- Prepared attention mask tensor reshaped for multi-head attention
111
- """
112
- if attention_mask is not None:
113
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
114
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
115
- return attention_mask
116
-
117
- @staticmethod
118
- def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
119
- """
120
- Reshape Q/K/V tensors for multi-head attention computation.
121
-
122
- Args:
123
- tensor: Input tensor to reshape
124
- batch_size: Batch size
125
- attn_heads: Number of attention heads
126
- head_dim: Dimension per attention head
127
-
128
- Returns:
129
- Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
130
- """
131
- return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)
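A quick sketch of the reshape with toy sizes (illustrative only):

q = torch.randn(2, 1024, 512)                                 # [B, L, heads * head_dim]
q = AttnUtils.reshape_qkv_for_attention(q, batch_size=2, attn_heads=8, head_dim=64)
# q: [2, 8, 1024, 64], the layout expected by F.scaled_dot_product_attention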
132
-
133
- @staticmethod
134
- def apply_norms(query, key, norm_q, norm_k):
135
- """
136
- Apply Q/K normalization layers if available.
137
-
138
- Args:
139
- query: Query tensor
140
- key: Key tensor
141
- norm_q: Query normalization layer (optional)
142
- norm_k: Key normalization layer (optional)
143
-
144
- Returns:
145
- Tuple of (normalized_query, normalized_key)
146
- """
147
- if norm_q is not None:
148
- query = norm_q(query)
149
- if norm_k is not None:
150
- key = norm_k(key)
151
- return query, key
152
-
153
- @staticmethod
154
- def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
155
- """
156
- Common output processing including projection, dropout, reshaping, and residual connection.
157
-
158
- Args:
159
- hidden_states: Processed hidden states from attention
160
- input_ndim: Original input tensor dimensions
161
- shape_info: Tuple containing original shape information
162
- attn: Attention module instance
163
- residual: Residual connection tensor
164
- to_out: Output projection layers [linear, dropout]
165
-
166
- Returns:
167
- Final output tensor after all processing steps
168
- """
169
- batch_size, channel, height, width = shape_info
170
-
171
- # Apply output projection and dropout
172
- hidden_states = to_out[0](hidden_states)
173
- hidden_states = to_out[1](hidden_states)
174
-
175
- # Reshape back if needed
176
- if input_ndim == 4:
177
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
178
-
179
- # Apply residual connection
180
- if attn.residual_connection:
181
- hidden_states = hidden_states + residual
182
-
183
- # Apply rescaling
184
- hidden_states = hidden_states / attn.rescale_output_factor
185
- return hidden_states
186
-
187
-
188
- # Base class for attention processors (eliminating initialization duplication)
189
- class BaseAttnProcessor(nn.Module):
190
- """
191
- Base class for attention processors with common initialization.
192
-
193
- This base class provides shared parameter initialization and module registration
194
- functionality to reduce code duplication across different attention processor types.
195
- """
196
-
197
- def __init__(
198
- self,
199
- query_dim: int,
200
- pbr_setting: List[str] = ["albedo", "mr"],
201
- cross_attention_dim: Optional[int] = None,
202
- heads: int = 8,
203
- kv_heads: Optional[int] = None,
204
- dim_head: int = 64,
205
- dropout: float = 0.0,
206
- bias: bool = False,
207
- upcast_attention: bool = False,
208
- upcast_softmax: bool = False,
209
- cross_attention_norm: Optional[str] = None,
210
- cross_attention_norm_num_groups: int = 32,
211
- qk_norm: Optional[str] = None,
212
- added_kv_proj_dim: Optional[int] = None,
213
- added_proj_bias: Optional[bool] = True,
214
- norm_num_groups: Optional[int] = None,
215
- spatial_norm_dim: Optional[int] = None,
216
- out_bias: bool = True,
217
- scale_qk: bool = True,
218
- only_cross_attention: bool = False,
219
- eps: float = 1e-5,
220
- rescale_output_factor: float = 1.0,
221
- residual_connection: bool = False,
222
- _from_deprecated_attn_block: bool = False,
223
- processor: Optional["AttnProcessor"] = None,
224
- out_dim: int = None,
225
- out_context_dim: int = None,
226
- context_pre_only=None,
227
- pre_only=False,
228
- elementwise_affine: bool = True,
229
- is_causal: bool = False,
230
- **kwargs,
231
- ):
232
- """
233
- Initialize base attention processor with common parameters.
234
-
235
- Args:
236
- query_dim: Dimension of query features
237
- pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
238
- cross_attention_dim: Dimension of cross-attention features (optional)
239
- heads: Number of attention heads
240
- kv_heads: Number of key-value heads for grouped query attention (optional)
241
- dim_head: Dimension per attention head
242
- dropout: Dropout rate
243
- bias: Whether to use bias in linear projections
244
- upcast_attention: Whether to upcast attention computation to float32
245
- upcast_softmax: Whether to upcast softmax computation to float32
246
- cross_attention_norm: Type of cross-attention normalization (optional)
247
- cross_attention_norm_num_groups: Number of groups for cross-attention norm
248
- qk_norm: Type of query-key normalization (optional)
249
- added_kv_proj_dim: Dimension for additional key-value projections (optional)
250
- added_proj_bias: Whether to use bias in additional projections
251
- norm_num_groups: Number of groups for normalization (optional)
252
- spatial_norm_dim: Dimension for spatial normalization (optional)
253
- out_bias: Whether to use bias in output projection
254
- scale_qk: Whether to scale query-key products
255
- only_cross_attention: Whether to only perform cross-attention
256
- eps: Small epsilon value for numerical stability
257
- rescale_output_factor: Factor to rescale output values
258
- residual_connection: Whether to use residual connections
259
- _from_deprecated_attn_block: Flag for deprecated attention blocks
260
- processor: Optional attention processor instance
261
- out_dim: Output dimension (optional)
262
- out_context_dim: Output context dimension (optional)
263
- context_pre_only: Whether to only process context in pre-processing
264
- pre_only: Whether to only perform pre-processing
265
- elementwise_affine: Whether to use element-wise affine transformations
266
- is_causal: Whether to use causal attention masking
267
- **kwargs: Additional keyword arguments
268
- """
269
- super().__init__()
270
- AttnUtils.check_pytorch_compatibility()
271
-
272
- # Store common attributes
273
- self.pbr_setting = pbr_setting
274
- self.n_pbr_tokens = len(self.pbr_setting)
275
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
276
- self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
277
- self.query_dim = query_dim
278
- self.use_bias = bias
279
- self.is_cross_attention = cross_attention_dim is not None
280
- self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
281
- self.upcast_attention = upcast_attention
282
- self.upcast_softmax = upcast_softmax
283
- self.rescale_output_factor = rescale_output_factor
284
- self.residual_connection = residual_connection
285
- self.dropout = dropout
286
- self.fused_projections = False
287
- self.out_dim = out_dim if out_dim is not None else query_dim
288
- self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
289
- self.context_pre_only = context_pre_only
290
- self.pre_only = pre_only
291
- self.is_causal = is_causal
292
- self._from_deprecated_attn_block = _from_deprecated_attn_block
293
- self.scale_qk = scale_qk
294
- self.scale = dim_head**-0.5 if self.scale_qk else 1.0
295
- self.heads = out_dim // dim_head if out_dim is not None else heads
296
- self.sliceable_head_dim = heads
297
- self.added_kv_proj_dim = added_kv_proj_dim
298
- self.only_cross_attention = only_cross_attention
299
- self.added_proj_bias = added_proj_bias
300
-
301
- # Validation
302
- if self.added_kv_proj_dim is None and self.only_cross_attention:
303
- raise ValueError(
304
- "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None."
305
- "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
306
- )
307
-
308
- def register_pbr_modules(self, module_types: List[str], **kwargs):
309
- """
310
- Generic PBR module registration to eliminate code repetition.
311
-
312
- Dynamically registers PyTorch modules for different PBR material types
313
- based on the specified module types and PBR settings.
314
-
315
- Args:
316
- module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
317
- **kwargs: Additional arguments for module configuration
318
- """
319
- for pbr_token in self.pbr_setting:
320
- if pbr_token == "albedo":
321
- continue
322
-
323
- for module_type in module_types:
324
- if module_type == "qkv":
325
- self.register_module(
326
- f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
327
- )
328
- self.register_module(
329
- f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
330
- )
331
- self.register_module(
332
- f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
333
- )
334
- elif module_type == "v_only":
335
- self.register_module(
336
- f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
337
- )
338
- elif module_type == "out":
339
- if not self.pre_only:
340
- self.register_module(
341
- f"to_out_{pbr_token}",
342
- nn.ModuleList(
343
- [
344
- nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
345
- nn.Dropout(self.dropout),
346
- ]
347
- ),
348
- )
349
- else:
350
- self.register_module(f"to_out_{pbr_token}", None)
351
- elif module_type == "add_kv":
352
- if self.added_kv_proj_dim is not None:
353
- self.register_module(
354
- f"add_k_proj_{pbr_token}",
355
- nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
356
- )
357
- self.register_module(
358
- f"add_v_proj_{pbr_token}",
359
- nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
360
- )
361
- else:
362
- self.register_module(f"add_k_proj_{pbr_token}", None)
363
- self.register_module(f"add_v_proj_{pbr_token}", None)
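To make the registration concrete, a hedged sketch using the SelfAttnProcessor2_0 subclass defined further down (toy sizes). "albedo" is skipped because it reuses the base Attention block's own projections, so only the "_mr" modules are added:

proc = SelfAttnProcessor2_0(query_dim=512, heads=8, dim_head=64, pbr_setting=["albedo", "mr"])
print([name for name, _ in proc.named_children()])
# ['to_q_mr', 'to_k_mr', 'to_v_mr', 'to_out_mr']
# (add_k_proj_mr / add_v_proj_mr stay None here because added_kv_proj_dim was not set)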
364
-
365
-
366
- # Rotary Position Embedding utilities (specialized for PoseRoPE)
367
- class RotaryEmbedding:
368
- """
369
- Rotary position embedding utilities for 3D spatial attention.
370
-
371
- Provides functions to compute and apply rotary position embeddings (RoPE)
372
- for 1D and 3D spatial coordinates used in 3D-aware attention mechanisms.
373
- """
374
-
375
- @staticmethod
376
- def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
377
- """
378
- Compute 1D rotary position embeddings.
379
-
380
- Args:
381
- dim: Embedding dimension (must be even)
382
- pos: Position tensor
383
- theta: Base frequency for rotary embeddings
384
- linear_factor: Linear scaling factor
385
- ntk_factor: NTK (Neural Tangent Kernel) scaling factor
386
-
387
- Returns:
388
- Tuple of (cos_embeddings, sin_embeddings)
389
- """
390
- assert dim % 2 == 0
391
- theta = theta * ntk_factor
392
- freqs = (
393
- 1.0
394
- / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
395
- / linear_factor
396
- )
397
- freqs = torch.outer(pos, freqs)
398
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
399
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
400
- return freqs_cos, freqs_sin
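For example, with toy values (assuming the imports at the top of this file):

pos = torch.arange(16, dtype=torch.float32)
cos, sin = RotaryEmbedding.get_1d_rotary_pos_embed(dim=8, pos=pos)
# cos, sin: [16, 8]; each of the dim // 2 frequencies is repeated twice along the last axis
# (repeat_interleave) so it lines up with the pairwise rotation in apply_rotary_emb below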
401
-
402
- @staticmethod
403
- def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
404
- """
405
- Compute 3D rotary position embeddings for spatial coordinates.
406
-
407
- Args:
408
- position: 3D position tensor with shape [..., 3]
409
- embed_dim: Embedding dimension
410
- voxel_resolution: Resolution of the voxel grid
411
- theta: Base frequency for rotary embeddings
412
-
413
- Returns:
414
- Tuple of (cos_embeddings, sin_embeddings) for 3D positions
415
- """
416
- assert position.shape[-1] == 3
417
- dim_xy = embed_dim // 8 * 3
418
- dim_z = embed_dim // 8 * 2
419
-
420
- grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
421
- freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
422
- freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)
423
-
424
- xy_cos, xy_sin = freqs_xy
425
- z_cos, z_sin = freqs_z
426
-
427
- embed_flattn = position.view(-1, position.shape[-1])
428
- x_cos = xy_cos[embed_flattn[:, 0], :]
429
- x_sin = xy_sin[embed_flattn[:, 0], :]
430
- y_cos = xy_cos[embed_flattn[:, 1], :]
431
- y_sin = xy_sin[embed_flattn[:, 1], :]
432
- z_cos = z_cos[embed_flattn[:, 2], :]
433
- z_sin = z_sin[embed_flattn[:, 2], :]
434
-
435
- cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
436
- sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)
437
-
438
- cos = cos.view(*position.shape[:-1], embed_dim)
439
- sin = sin.view(*position.shape[:-1], embed_dim)
440
- return cos, sin
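A shape sketch with toy sizes; embed_dim here plays the role of the per-head dimension and is chosen divisible by 16 so the x/y/z split stays even:

voxels = torch.randint(0, 32, (2, 1024, 3))                   # integer xyz indices in [0, voxel_resolution)
cos, sin = RotaryEmbedding.get_3d_rotary_pos_embed(voxels, embed_dim=64, voxel_resolution=32)
# cos, sin: [2, 1024, 64]; the 64 channels split as 24 (x) + 24 (y) + 16 (z),
# i.e. embed_dim // 8 * 3 for x and for y, and embed_dim // 8 * 2 for z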
441
-
442
- @staticmethod
443
- def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
444
- """
445
- Apply rotary position embeddings to input tensor.
446
-
447
- Args:
448
- x: Input tensor to apply rotary embeddings to
449
- freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor
450
-
451
- Returns:
452
- Tensor with rotary position embeddings applied
453
- """
454
- cos, sin = freqs_cis
455
- cos, sin = cos.to(x.device), sin.to(x.device)
456
- cos = cos.unsqueeze(1)
457
- sin = sin.unsqueeze(1)
458
-
459
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
460
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
461
-
462
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
463
- return out
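Continuing the sketch above, applying the embeddings to a per-head query tensor (toy sizes):

q = torch.randn(2, 8, 1024, 64)                               # [B, heads, L, head_dim]
q_rot = RotaryEmbedding.apply_rotary_emb(q, (cos, sin))       # cos/sin from the 3D example above
# q_rot keeps the shape and dtype of q; cos/sin broadcast over the heads axis via unsqueeze(1)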
464
-
465
-
466
- # Core attention processing logic (eliminating major duplication)
467
- class AttnCore:
468
- """
469
- Core attention processing logic shared across processors.
470
-
471
- This class provides the fundamental attention computation pipeline
472
- that can be reused across different attention processor implementations.
473
- """
474
-
475
- @staticmethod
476
- def process_attention_base(
477
- attn: Attention,
478
- hidden_states: torch.Tensor,
479
- encoder_hidden_states: Optional[torch.Tensor] = None,
480
- attention_mask: Optional[torch.Tensor] = None,
481
- temb: Optional[torch.Tensor] = None,
482
- get_qkv_fn: Callable = None,
483
- apply_rope_fn: Optional[Callable] = None,
484
- **kwargs,
485
- ):
486
- """
487
- Generic attention processing core shared across different processors.
488
-
489
- This function implements the common attention computation pipeline including:
490
- 1. Hidden state preprocessing
491
- 2. Attention mask preparation
492
- 3. Q/K/V computation via provided function
493
- 4. Tensor reshaping for multi-head attention
494
- 5. Optional normalization and RoPE application
495
- 6. Scaled dot-product attention computation
496
-
497
- Args:
498
- attn: Attention module instance
499
- hidden_states: Input hidden states tensor
500
- encoder_hidden_states: Optional encoder hidden states for cross-attention
501
- attention_mask: Optional attention mask tensor
502
- temb: Optional temporal embedding tensor
503
- get_qkv_fn: Function to compute Q, K, V tensors
504
- apply_rope_fn: Optional function to apply rotary position embeddings
505
- **kwargs: Additional keyword arguments passed to subfunctions
506
-
507
- Returns:
508
- Tuple containing (attention_output, residual, input_ndim, shape_info,
509
- batch_size, num_heads, head_dim)
510
- """
511
- # Prepare hidden states
512
- hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)
513
-
514
- batch_size, sequence_length, _ = (
515
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
516
- )
517
-
518
- # Prepare attention mask
519
- attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)
520
-
521
- # Get Q, K, V
522
- if encoder_hidden_states is None:
523
- encoder_hidden_states = hidden_states
524
- elif attn.norm_cross:
525
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
526
-
527
- query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)
528
-
529
- # Reshape for attention
530
- inner_dim = key.shape[-1]
531
- head_dim = inner_dim // attn.heads
532
-
533
- query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
534
- key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
535
- value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)
536
-
537
- # Apply normalization
538
- query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))
539
-
540
- # Apply RoPE if provided
541
- if apply_rope_fn is not None:
542
- query, key = apply_rope_fn(query, key, head_dim, **kwargs)
543
-
544
- # Compute attention
545
- hidden_states = F.scaled_dot_product_attention(
546
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
547
- )
548
-
549
- return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
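The pipeline is driven entirely by the get_qkv_fn callback. A minimal illustrative callback that mirrors a plain self-attention path (toy sizes, not part of the original code):

def plain_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
    # simplest possible callback: use the Attention block's own projections
    return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

attn = Attention(query_dim=512, heads=8, dim_head=64)
out, residual, input_ndim, shape_info, bsz, heads, head_dim = AttnCore.process_attention_base(
    attn, torch.randn(2, 1024, 512), get_qkv_fn=plain_qkv
)
# out: [2, 8, 1024, 64] -- still per-head; callers reshape it and run AttnUtils.finalize_output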
550
-
551
-
552
- # Specific processor implementations (minimal unique code)
553
- class PoseRoPEAttnProcessor2_0:
554
- """
555
- Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.
556
-
557
- This processor extends standard attention with 3D rotary position embeddings
558
- to provide spatial awareness for 3D scene understanding tasks.
559
- """
560
-
561
- def __init__(self):
562
- """Initialize the RoPE attention processor."""
563
- AttnUtils.check_pytorch_compatibility()
564
-
565
- def __call__(
566
- self,
567
- attn: Attention,
568
- hidden_states: torch.Tensor,
569
- encoder_hidden_states: Optional[torch.Tensor] = None,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- position_indices: Dict = None,
572
- temb: Optional[torch.Tensor] = None,
573
- n_pbrs=1,
574
- *args,
575
- **kwargs,
576
- ) -> torch.Tensor:
577
- """
578
- Apply RoPE-enhanced attention computation.
579
-
580
- Args:
581
- attn: Attention module instance
582
- hidden_states: Input hidden states tensor
583
- encoder_hidden_states: Optional encoder hidden states for cross-attention
584
- attention_mask: Optional attention mask tensor
585
- position_indices: Dictionary containing 3D position information for RoPE
586
- temb: Optional temporal embedding tensor
587
- n_pbrs: Number of PBR material types
588
- *args: Additional positional arguments
589
- **kwargs: Additional keyword arguments
590
-
591
- Returns:
592
- Attention output tensor with applied rotary position encodings
593
- """
594
- AttnUtils.handle_deprecation_warning(args, kwargs)
595
-
596
- def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
597
- return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)
598
-
599
- def apply_rope(query, key, head_dim, **kwargs):
600
- if position_indices is not None:
601
- if head_dim in position_indices:
602
- image_rotary_emb = position_indices[head_dim]
603
- else:
604
- image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
605
- rearrange(
606
- position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
607
- "b n_pbrs l c -> (b n_pbrs) l c",
608
- ),
609
- head_dim,
610
- voxel_resolution=position_indices["voxel_resolution"],
611
- )
612
- position_indices[head_dim] = image_rotary_emb
613
-
614
- query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
615
- key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
616
- return query, key
617
-
618
- # Core attention processing
619
- hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
620
- attn,
621
- hidden_states,
622
- encoder_hidden_states,
623
- attention_mask,
624
- temb,
625
- get_qkv_fn=get_qkv,
626
- apply_rope_fn=apply_rope,
627
- position_indices=position_indices,
628
- n_pbrs=n_pbrs,
629
- )
630
-
631
- # Finalize output
632
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
633
- hidden_states = hidden_states.to(hidden_states.dtype)
634
-
635
- return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
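A hedged usage sketch with toy sizes; the position_indices keys follow the apply_rope closure above, and dim_head is chosen divisible by 16 so the 3D RoPE split stays even:

attn = Attention(query_dim=512, heads=8, dim_head=64, processor=PoseRoPEAttnProcessor2_0())
hidden = torch.randn(2, 1024, 512)
position_indices = {
    "voxel_indices": torch.randint(0, 64, (2, 1024, 3)),      # one xyz voxel index per token
    "voxel_resolution": 64,
}
out = attn.processor(attn, hidden, position_indices=position_indices)
# out: [2, 1024, 512]; the computed (cos, sin) pair is cached under position_indices[head_dim]
# so later layers with the same head dimension can reuse it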
636
-
637
-
638
- class SelfAttnProcessor2_0(BaseAttnProcessor):
639
- """
640
- Self-attention processor with PBR (Physically Based Rendering) material support.
641
-
642
- This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
643
- with separate attention computation paths for each material type.
644
- """
645
-
646
- def __init__(self, **kwargs):
647
- """
648
- Initialize self-attention processor with PBR support.
649
-
650
- Args:
651
- **kwargs: Arguments passed to BaseAttnProcessor initialization
652
- """
653
- super().__init__(**kwargs)
654
- self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)
655
-
656
- def process_single(
657
- self,
658
- attn: Attention,
659
- hidden_states: torch.Tensor,
660
- encoder_hidden_states: Optional[torch.Tensor] = None,
661
- attention_mask: Optional[torch.Tensor] = None,
662
- temb: Optional[torch.Tensor] = None,
663
- token: Literal["albedo", "mr"] = "albedo",
664
- multiple_devices=False,
665
- *args,
666
- **kwargs,
667
- ):
668
- """
669
- Process attention for a single PBR material type.
670
-
671
- Args:
672
- attn: Attention module instance
673
- hidden_states: Input hidden states tensor
674
- encoder_hidden_states: Optional encoder hidden states for cross-attention
675
- attention_mask: Optional attention mask tensor
676
- temb: Optional temporal embedding tensor
677
- token: PBR material type to process ("albedo", "mr", etc.)
678
- multiple_devices: Whether to use multiple GPU devices
679
- *args: Additional positional arguments
680
- **kwargs: Additional keyword arguments
681
-
682
- Returns:
683
- Processed attention output for the specified PBR material type
684
- """
685
- target = attn if token == "albedo" else attn.processor
686
- token_suffix = "" if token == "albedo" else "_" + token
687
-
688
- # Device management (if needed)
689
- if multiple_devices:
690
- device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
691
- for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
692
- getattr(target, attr).to(device)
693
-
694
- def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
695
- return (
696
- getattr(target, f"to_q{token_suffix}")(hidden_states),
697
- getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
698
- getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
699
- )
700
-
701
- # Core processing using shared logic
702
- hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
703
- attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
704
- )
705
-
706
- # Finalize
707
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
708
- hidden_states = hidden_states.to(hidden_states.dtype)
709
-
710
- return AttnUtils.finalize_output(
711
- hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
712
- )
713
-
714
- def __call__(
715
- self,
716
- attn: Attention,
717
- hidden_states: torch.Tensor,
718
- encoder_hidden_states: Optional[torch.Tensor] = None,
719
- attention_mask: Optional[torch.Tensor] = None,
720
- temb: Optional[torch.Tensor] = None,
721
- *args,
722
- **kwargs,
723
- ) -> torch.Tensor:
724
- """
725
- Apply self-attention with PBR material processing.
726
-
727
- Processes multiple PBR material types sequentially, applying attention
728
- computation for each material type separately and combining results.
729
-
730
- Args:
731
- attn: Attention module instance
732
- hidden_states: Input hidden states tensor with PBR dimension
733
- encoder_hidden_states: Optional encoder hidden states for cross-attention
734
- attention_mask: Optional attention mask tensor
735
- temb: Optional temporal embedding tensor
736
- *args: Additional positional arguments
737
- **kwargs: Additional keyword arguments
738
-
739
- Returns:
740
- Combined attention output for all PBR material types
741
- """
742
- AttnUtils.handle_deprecation_warning(args, kwargs)
743
-
744
- B = hidden_states.size(0)
745
- pbr_hidden_states = torch.split(hidden_states, 1, dim=1)
746
-
747
- # Process each PBR setting
748
- results = []
749
- for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
750
- processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
751
- result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
752
- results.append(result)
753
-
754
- outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
755
- return torch.cat(outputs, dim=1)
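A hedged usage sketch with toy sizes; the 5D layout [B, n_pbrs, n_views, L, C] is implied by the rearrange pattern above, and the hard-coded .to("cuda:0") means this path assumes a CUDA device:

proc = SelfAttnProcessor2_0(query_dim=512, heads=8, dim_head=64, pbr_setting=["albedo", "mr"])
attn = Attention(query_dim=512, heads=8, dim_head=64, processor=proc).to("cuda:0")
x = torch.randn(2, 2, 6, 1024, 512, device="cuda:0")          # [B, n_pbrs, n_views, L, C]
out = proc(attn, x)
# out: [2, 2, 6, 1024, 512]; the albedo slice goes through attn.to_q/to_k/to_v/to_out,
# the mr slice through the processor's to_q_mr/to_k_mr/to_v_mr/to_out_mr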
756
-
757
-
758
- class RefAttnProcessor2_0(BaseAttnProcessor):
759
- """
760
- Reference attention processor with shared value computation across PBR materials.
761
-
762
- This processor computes query and key once, but uses separate value projections
763
- for different PBR material types, enabling efficient multi-material processing.
764
- """
765
-
766
- def __init__(self, **kwargs):
767
- """
768
- Initialize reference attention processor.
769
-
770
- Args:
771
- **kwargs: Arguments passed to BaseAttnProcessor initialization
772
- """
773
- super().__init__(**kwargs)
774
- self.pbr_settings = self.pbr_setting # Alias for compatibility
775
- self.register_pbr_modules(["v_only", "out"], **kwargs)
776
-
777
- def __call__(
778
- self,
779
- attn: Attention,
780
- hidden_states: torch.Tensor,
781
- encoder_hidden_states: Optional[torch.Tensor] = None,
782
- attention_mask: Optional[torch.Tensor] = None,
783
- temb: Optional[torch.Tensor] = None,
784
- *args,
785
- **kwargs,
786
- ) -> torch.Tensor:
787
- """
788
- Apply reference attention with shared Q/K and separate V projections.
789
-
790
- This method computes query and key tensors once and reuses them across
791
- all PBR material types, while using separate value projections for each
792
- material type to maintain material-specific information.
793
-
794
- Args:
795
- attn: Attention module instance
796
- hidden_states: Input hidden states tensor
797
- encoder_hidden_states: Optional encoder hidden states for cross-attention
798
- attention_mask: Optional attention mask tensor
799
- temb: Optional temporal embedding tensor
800
- *args: Additional positional arguments
801
- **kwargs: Additional keyword arguments
802
-
803
- Returns:
804
- Stacked attention output for all PBR material types
805
- """
806
- AttnUtils.handle_deprecation_warning(args, kwargs)
807
-
808
- def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
809
- query = attn.to_q(hidden_states)
810
- key = attn.to_k(encoder_hidden_states)
811
-
812
- # Concatenate values from all PBR settings
813
- value_list = [attn.to_v(encoder_hidden_states)]
814
- for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
815
- value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
816
- value = torch.cat(value_list, dim=-1)
817
-
818
- return query, key, value
819
-
820
- # Core processing
821
- hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
822
- attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
823
- )
824
-
825
- # Split and process each PBR setting output
826
- hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)
827
- output_hidden_states_list = []
828
-
829
- for i, hs in enumerate(hidden_states_list):
830
- hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
831
- token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
832
- target = attn if self.pbr_settings[i] == "albedo" else attn.processor
833
-
834
- hs = AttnUtils.finalize_output(
835
- hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
836
- )
837
- output_hidden_states_list.append(hs)
838
-
839
- return torch.stack(output_hidden_states_list, dim=1)
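A hedged usage sketch with toy sizes (CPU is fine here): query and key are computed once, while the concatenated values and separate output projections keep the materials apart:

proc = RefAttnProcessor2_0(query_dim=512, heads=8, dim_head=64, pbr_setting=["albedo", "mr"])
attn = Attention(query_dim=512, heads=8, dim_head=64, processor=proc)
x = torch.randn(2, 1024, 512)
out = proc(attn, x)
# out: [2, 2, 1024, 512] -- dim 1 indexes the pbr_setting entries (albedo, mr); to_q / to_k are
# shared, only to_v vs. to_v_mr and to_out vs. to_out_mr differ per material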