# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple, Union, Literal, List, Callable
from einops import rearrange
from diffusers.utils import deprecate
from diffusers.models.attention_processor import Attention, AttnProcessor


class AttnUtils:
    """
    Shared utility functions for attention processing.

    This class provides common operations used across different attention processors
    to eliminate code duplication and improve maintainability.
    """

    @staticmethod
    def check_pytorch_compatibility():
        """
        Check PyTorch compatibility for scaled_dot_product_attention.

        Raises:
            ImportError: If PyTorch version doesn't support scaled_dot_product_attention
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def handle_deprecation_warning(args, kwargs):
        """
        Handle deprecation warning for the 'scale' argument.

        Args:
            args: Positional arguments passed to attention processor
            kwargs: Keyword arguments passed to attention processor
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored."
                "Please remove it, as passing it will raise an error in the future."
                "`scale` should directly be passed while calling the underlying pipeline component"
                "i.e., via `cross_attention_kwargs`."
            )
            deprecate("scale", "1.0.0", deprecation_message)

    @staticmethod
    def prepare_hidden_states(
        hidden_states, attn, temb, spatial_norm_attr="spatial_norm", group_norm_attr="group_norm"
    ):
        """
        Common preprocessing of hidden states for attention computation.

        Args:
            hidden_states: Input hidden states tensor
            attn: Attention module instance
            temb: Optional temporal embedding tensor
            spatial_norm_attr: Attribute name for spatial normalization
            group_norm_attr: Attribute name for group normalization

        Returns:
            Tuple of (processed_hidden_states, residual, input_ndim, shape_info)
        """
        residual = hidden_states

        spatial_norm = getattr(attn, spatial_norm_attr, None)
        if spatial_norm is not None:
            hidden_states = spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, channel, height, width = None, None, None, None

        group_norm = getattr(attn, group_norm_attr, None)
        if group_norm is not None:
            hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        return hidden_states, residual, input_ndim, (batch_size, channel, height, width)

    @staticmethod
    def prepare_attention_mask(attention_mask, attn, sequence_length, batch_size):
        """
        Prepare attention mask for scaled_dot_product_attention.

        Args:
            attention_mask: Input attention mask tensor or None
            attn: Attention module instance
            sequence_length: Length of the sequence
            batch_size: Batch size

        Returns:
            Prepared attention mask tensor reshaped for multi-head attention
        """
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        return attention_mask

    @staticmethod
    def reshape_qkv_for_attention(tensor, batch_size, attn_heads, head_dim):
        """
        Reshape Q/K/V tensors for multi-head attention computation.

        Args:
            tensor: Input tensor to reshape
            batch_size: Batch size
            attn_heads: Number of attention heads
            head_dim: Dimension per attention head

        Returns:
            Reshaped tensor with shape [batch_size, attn_heads, seq_len, head_dim]
        """
        return tensor.view(batch_size, -1, attn_heads, head_dim).transpose(1, 2)

    @staticmethod
    def apply_norms(query, key, norm_q, norm_k):
        """
        Apply Q/K normalization layers if available.

        Args:
            query: Query tensor
            key: Key tensor
            norm_q: Query normalization layer (optional)
            norm_k: Key normalization layer (optional)

        Returns:
            Tuple of (normalized_query, normalized_key)
        """
        if norm_q is not None:
            query = norm_q(query)
        if norm_k is not None:
            key = norm_k(key)
        return query, key

    @staticmethod
    def finalize_output(hidden_states, input_ndim, shape_info, attn, residual, to_out):
        """
        Common output processing including projection, dropout, reshaping, and residual connection.

        Args:
            hidden_states: Processed hidden states from attention
            input_ndim: Original input tensor dimensions
            shape_info: Tuple containing original shape information
            attn: Attention module instance
            residual: Residual connection tensor
            to_out: Output projection layers [linear, dropout]

        Returns:
            Final output tensor after all processing steps
        """
        batch_size, channel, height, width = shape_info

        # Apply output projection and dropout
        hidden_states = to_out[0](hidden_states)
        hidden_states = to_out[1](hidden_states)

        # Reshape back if needed
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # Apply residual connection
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        # Apply rescaling
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
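

# Hypothetical usage sketch (added for illustration, not part of the upstream file): exercises the
# shape helpers above with a bare stand-in object instead of a real diffusers `Attention` module.
# The tensor sizes and the `_demo_attn_utils_shapes` name are arbitrary assumptions.
def _demo_attn_utils_shapes():
    import types

    dummy_attn = types.SimpleNamespace()  # has neither spatial_norm nor group_norm
    x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)

    # 4D feature maps are flattened to (batch, height*width, channels) before attention
    hs, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(x, dummy_attn, temb=None)
    assert hs.shape == (2, 16, 8) and input_ndim == 4 and shape_info == (2, 8, 4, 4)

    # A projected tensor of shape (batch, seq_len, inner_dim) is split into per-head layout
    q_in = torch.randn(2, 16, 8)
    q = AttnUtils.reshape_qkv_for_attention(q_in, batch_size=2, attn_heads=2, head_dim=4)
    assert q.shape == (2, 2, 16, 4)  # (batch, heads, seq_len, head_dim)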


# Base class for attention processors (eliminating initialization duplication)
class BaseAttnProcessor(nn.Module):
    """
    Base class for attention processors with common initialization.

    This base class provides shared parameter initialization and module registration
    functionality to reduce code duplication across different attention processor types.
    """

    def __init__(
        self,
        query_dim: int,
        pbr_setting: List[str] = ["albedo", "mr"],
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: Optional[int] = None,
        out_context_dim: Optional[int] = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        **kwargs,
    ):
        """
        Initialize base attention processor with common parameters.

        Args:
            query_dim: Dimension of query features
            pbr_setting: List of PBR material types to process (e.g., ["albedo", "mr"])
            cross_attention_dim: Dimension of cross-attention features (optional)
            heads: Number of attention heads
            kv_heads: Number of key-value heads for grouped query attention (optional)
            dim_head: Dimension per attention head
            dropout: Dropout rate
            bias: Whether to use bias in linear projections
            upcast_attention: Whether to upcast attention computation to float32
            upcast_softmax: Whether to upcast softmax computation to float32
            cross_attention_norm: Type of cross-attention normalization (optional)
            cross_attention_norm_num_groups: Number of groups for cross-attention norm
            qk_norm: Type of query-key normalization (optional)
            added_kv_proj_dim: Dimension for additional key-value projections (optional)
            added_proj_bias: Whether to use bias in additional projections
            norm_num_groups: Number of groups for normalization (optional)
            spatial_norm_dim: Dimension for spatial normalization (optional)
            out_bias: Whether to use bias in output projection
            scale_qk: Whether to scale query-key products
            only_cross_attention: Whether to only perform cross-attention
            eps: Small epsilon value for numerical stability
            rescale_output_factor: Factor to rescale output values
            residual_connection: Whether to use residual connections
            _from_deprecated_attn_block: Flag for deprecated attention blocks
            processor: Optional attention processor instance
            out_dim: Output dimension (optional)
            out_context_dim: Output context dimension (optional)
            context_pre_only: Whether to only process context in pre-processing
            pre_only: Whether to only perform pre-processing
            elementwise_affine: Whether to use element-wise affine transformations
            is_causal: Whether to use causal attention masking
            **kwargs: Additional keyword arguments
        """
        super().__init__()
        AttnUtils.check_pytorch_compatibility()

        # Store common attributes
        self.pbr_setting = pbr_setting
        self.n_pbr_tokens = len(self.pbr_setting)
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.is_causal = is_causal
        self._from_deprecated_attn_block = _from_deprecated_attn_block
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention
        self.added_proj_bias = added_proj_bias

        # Validation
        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None."
                "Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

    def register_pbr_modules(self, module_types: List[str], **kwargs):
        """
        Generic PBR module registration to eliminate code repetition.

        Dynamically registers PyTorch modules for different PBR material types
        based on the specified module types and PBR settings.

        Args:
            module_types: List of module types to register ("qkv", "v_only", "out", "add_kv")
            **kwargs: Additional arguments for module configuration
        """
        for pbr_token in self.pbr_setting:
            if pbr_token == "albedo":
                continue

            for module_type in module_types:
                if module_type == "qkv":
                    self.register_module(
                        f"to_q_{pbr_token}", nn.Linear(self.query_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_k_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "v_only":
                    self.register_module(
                        f"to_v_{pbr_token}", nn.Linear(self.cross_attention_dim, self.inner_dim, bias=self.use_bias)
                    )
                elif module_type == "out":
                    if not self.pre_only:
                        self.register_module(
                            f"to_out_{pbr_token}",
                            nn.ModuleList(
                                [
                                    nn.Linear(self.inner_dim, self.out_dim, bias=kwargs.get("out_bias", True)),
                                    nn.Dropout(self.dropout),
                                ]
                            ),
                        )
                    else:
                        self.register_module(f"to_out_{pbr_token}", None)
                elif module_type == "add_kv":
                    if self.added_kv_proj_dim is not None:
                        self.register_module(
                            f"add_k_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                        self.register_module(
                            f"add_v_proj_{pbr_token}",
                            nn.Linear(self.added_kv_proj_dim, self.inner_kv_dim, bias=self.added_proj_bias),
                        )
                    else:
                        self.register_module(f"add_k_proj_{pbr_token}", None)
                        self.register_module(f"add_v_proj_{pbr_token}", None)


# Rotary Position Embedding utilities (specialized for PoseRoPE)
class RotaryEmbedding:
    """
    Rotary position embedding utilities for 3D spatial attention.

    Provides functions to compute and apply rotary position embeddings (RoPE)
    for 1D, 3D spatial coordinates used in 3D-aware attention mechanisms.
    """

    @staticmethod
    def get_1d_rotary_pos_embed(dim: int, pos: torch.Tensor, theta: float = 10000.0, linear_factor=1.0, ntk_factor=1.0):
        """
        Compute 1D rotary position embeddings.

        Args:
            dim: Embedding dimension (must be even)
            pos: Position tensor
            theta: Base frequency for rotary embeddings
            linear_factor: Linear scaling factor
            ntk_factor: NTK (Neural Tangent Kernel) scaling factor

        Returns:
            Tuple of (cos_embeddings, sin_embeddings)
        """
        assert dim % 2 == 0
        theta = theta * ntk_factor
        freqs = (
            1.0
            / (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim))
            / linear_factor
        )
        freqs = torch.outer(pos, freqs)
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()
        return freqs_cos, freqs_sin

    @staticmethod
    def get_3d_rotary_pos_embed(position, embed_dim, voxel_resolution, theta: int = 10000):
        """
        Compute 3D rotary position embeddings for spatial coordinates.

        Args:
            position: 3D position tensor with shape [..., 3]
            embed_dim: Embedding dimension
            voxel_resolution: Resolution of the voxel grid
            theta: Base frequency for rotary embeddings

        Returns:
            Tuple of (cos_embeddings, sin_embeddings) for 3D positions
        """
        assert position.shape[-1] == 3
        dim_xy = embed_dim // 8 * 3
        dim_z = embed_dim // 8 * 2

        grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device)
        freqs_xy = RotaryEmbedding.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta)
        freqs_z = RotaryEmbedding.get_1d_rotary_pos_embed(dim_z, grid, theta=theta)

        xy_cos, xy_sin = freqs_xy
        z_cos, z_sin = freqs_z

        embed_flattn = position.view(-1, position.shape[-1])
        x_cos = xy_cos[embed_flattn[:, 0], :]
        x_sin = xy_sin[embed_flattn[:, 0], :]
        y_cos = xy_cos[embed_flattn[:, 1], :]
        y_sin = xy_sin[embed_flattn[:, 1], :]
        z_cos = z_cos[embed_flattn[:, 2], :]
        z_sin = z_sin[embed_flattn[:, 2], :]

        cos = torch.cat((x_cos, y_cos, z_cos), dim=-1)
        sin = torch.cat((x_sin, y_sin, z_sin), dim=-1)

        cos = cos.view(*position.shape[:-1], embed_dim)
        sin = sin.view(*position.shape[:-1], embed_dim)
        return cos, sin

    @staticmethod
    def apply_rotary_emb(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
        """
        Apply rotary position embeddings to input tensor.

        Args:
            x: Input tensor to apply rotary embeddings to
            freqs_cis: Tuple of (cos_embeddings, sin_embeddings) or single tensor

        Returns:
            Tensor with rotary position embeddings applied
        """
        cos, sin = freqs_cis
        cos, sin = cos.to(x.device), sin.to(x.device)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
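

# Hypothetical usage sketch (added for illustration, not part of the upstream file): shape check for
# the 3D RoPE helpers. Positions are integer voxel indices in [0, voxel_resolution), matching how
# PoseRoPEAttnProcessor2_0 builds them; head_dim and voxel_resolution are arbitrary assumptions.
def _demo_pose_rope_embeddings():
    head_dim, voxel_resolution = 64, 16
    batch, heads, seq_len = 2, 4, 10

    positions = torch.randint(0, voxel_resolution, (batch, seq_len, 3))  # (x, y, z) per token
    cos, sin = RotaryEmbedding.get_3d_rotary_pos_embed(positions, head_dim, voxel_resolution)
    assert cos.shape == (batch, seq_len, head_dim)  # x/y each take 3/8 of the dims, z takes 2/8

    # Rotation is applied to per-head query/key tensors of shape (batch, heads, seq_len, head_dim)
    query = torch.randn(batch, heads, seq_len, head_dim)
    assert RotaryEmbedding.apply_rotary_emb(query, (cos, sin)).shape == query.shape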


# Core attention processing logic (eliminating major duplication)
class AttnCore:
    """
    Core attention processing logic shared across processors.

    This class provides the fundamental attention computation pipeline
    that can be reused across different attention processor implementations.
    """

    @staticmethod
    def process_attention_base(
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        get_qkv_fn: Optional[Callable] = None,
        apply_rope_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Generic attention processing core shared across different processors.

        This function implements the common attention computation pipeline including:
        1. Hidden state preprocessing
        2. Attention mask preparation
        3. Q/K/V computation via provided function
        4. Tensor reshaping for multi-head attention
        5. Optional normalization and RoPE application
        6. Scaled dot-product attention computation

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            get_qkv_fn: Function to compute Q, K, V tensors
            apply_rope_fn: Optional function to apply rotary position embeddings
            **kwargs: Additional keyword arguments passed to subfunctions

        Returns:
            Tuple containing (attention_output, residual, input_ndim, shape_info,
            batch_size, num_heads, head_dim)
        """
        # Prepare hidden states
        hidden_states, residual, input_ndim, shape_info = AttnUtils.prepare_hidden_states(hidden_states, attn, temb)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # Prepare attention mask
        attention_mask = AttnUtils.prepare_attention_mask(attention_mask, attn, sequence_length, batch_size)

        # Get Q, K, V
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        query, key, value = get_qkv_fn(attn, hidden_states, encoder_hidden_states, **kwargs)

        # Reshape for attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = AttnUtils.reshape_qkv_for_attention(query, batch_size, attn.heads, head_dim)
        key = AttnUtils.reshape_qkv_for_attention(key, batch_size, attn.heads, head_dim)
        value = AttnUtils.reshape_qkv_for_attention(value, batch_size, attn.heads, value.shape[-1] // attn.heads)

        # Apply normalization
        query, key = AttnUtils.apply_norms(query, key, getattr(attn, "norm_q", None), getattr(attn, "norm_k", None))

        # Apply RoPE if provided
        if apply_rope_fn is not None:
            query, key = apply_rope_fn(query, key, head_dim, **kwargs)

        # Compute attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        return hidden_states, residual, input_ndim, shape_info, batch_size, attn.heads, head_dim
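

# Hypothetical usage sketch (added for illustration, not part of the upstream file): wires a plain
# diffusers Attention module through the shared core, with get_qkv simply reusing the module's own
# projections (the same wiring PoseRoPEAttnProcessor2_0 uses before adding RoPE). Sizes are
# arbitrary assumptions.
def _demo_attention_core():
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    x = torch.randn(2, 10, 64)  # (batch, seq_len, query_dim)

    def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
        return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

    out, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
        attn, x, get_qkv_fn=get_qkv
    )
    assert out.shape == (2, 4, 10, 16)  # still per-head; callers reshape and project via finalize_output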


# Specific processor implementations (minimal unique code)
class PoseRoPEAttnProcessor2_0:
    """
    Attention processor with Rotary Position Encoding (RoPE) for 3D spatial awareness.

    This processor extends standard attention with 3D rotary position embeddings
    to provide spatial awareness for 3D scene understanding tasks.
    """

    def __init__(self):
        """Initialize the RoPE attention processor."""
        AttnUtils.check_pytorch_compatibility()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_indices: Dict = None,
        temb: Optional[torch.Tensor] = None,
        n_pbrs=1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply RoPE-enhanced attention computation.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            position_indices: Dictionary containing 3D position information for RoPE
            temb: Optional temporal embedding tensor
            n_pbrs: Number of PBR material types
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Attention output tensor with applied rotary position encodings
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            return attn.to_q(hidden_states), attn.to_k(encoder_hidden_states), attn.to_v(encoder_hidden_states)

        def apply_rope(query, key, head_dim, **kwargs):
            if position_indices is not None:
                if head_dim in position_indices:
                    image_rotary_emb = position_indices[head_dim]
                else:
                    image_rotary_emb = RotaryEmbedding.get_3d_rotary_pos_embed(
                        rearrange(
                            position_indices["voxel_indices"].unsqueeze(1).repeat(1, n_pbrs, 1, 1),
                            "b n_pbrs l c -> (b n_pbrs) l c",
                        ),
                        head_dim,
                        voxel_resolution=position_indices["voxel_resolution"],
                    )
                    position_indices[head_dim] = image_rotary_emb

                query = RotaryEmbedding.apply_rotary_emb(query, image_rotary_emb)
                key = RotaryEmbedding.apply_rotary_emb(key, image_rotary_emb)
            return query, key

        # Core attention processing
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            get_qkv_fn=get_qkv,
            apply_rope_fn=apply_rope,
            position_indices=position_indices,
            n_pbrs=n_pbrs,
        )

        # Finalize output
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(hidden_states.dtype)

        return AttnUtils.finalize_output(hidden_states, input_ndim, shape_info, attn, residual, attn.to_out)
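

# Hypothetical usage sketch (added for illustration, not part of the upstream file): calls the RoPE
# processor directly with a minimal diffusers Attention module. The position_indices dict carries
# integer voxel coordinates plus the grid resolution; the computed (cos, sin) pair is cached back
# into the dict keyed by head_dim. All sizes are arbitrary assumptions.
def _demo_pose_rope_processor():
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    processor = PoseRoPEAttnProcessor2_0()

    x = torch.randn(2, 10, 64)  # (batch, seq_len, channels)
    position_indices = {
        "voxel_indices": torch.randint(0, 16, (2, 10, 3)),  # (batch, seq_len, xyz)
        "voxel_resolution": 16,
    }
    out = processor(attn, x, position_indices=position_indices)
    assert out.shape == x.shape
    assert 16 in position_indices  # rotary embeddings for head_dim=16 are now cached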


class SelfAttnProcessor2_0(BaseAttnProcessor):
    """
    Self-attention processor with PBR (Physically Based Rendering) material support.

    This processor handles multiple PBR material types (e.g., albedo, metallic-roughness)
    with separate attention computation paths for each material type.
    """

    def __init__(self, **kwargs):
        """
        Initialize self-attention processor with PBR support.

        Args:
            **kwargs: Arguments passed to BaseAttnProcessor initialization
        """
        super().__init__(**kwargs)
        self.register_pbr_modules(["qkv", "out", "add_kv"], **kwargs)

    def process_single(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        token: Literal["albedo", "mr"] = "albedo",
        multiple_devices=False,
        *args,
        **kwargs,
    ):
        """
        Process attention for a single PBR material type.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            token: PBR material type to process ("albedo", "mr", etc.)
            multiple_devices: Whether to use multiple GPU devices
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Processed attention output for the specified PBR material type
        """
        target = attn if token == "albedo" else attn.processor
        token_suffix = "" if token == "albedo" else "_" + token

        # Device management (if needed)
        if multiple_devices:
            device = torch.device("cuda:0") if token == "albedo" else torch.device("cuda:1")
            for attr in [f"to_q{token_suffix}", f"to_k{token_suffix}", f"to_v{token_suffix}", f"to_out{token_suffix}"]:
                getattr(target, attr).to(device)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            return (
                getattr(target, f"to_q{token_suffix}")(hidden_states),
                getattr(target, f"to_k{token_suffix}")(encoder_hidden_states),
                getattr(target, f"to_v{token_suffix}")(encoder_hidden_states),
            )

        # Core processing using shared logic
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        # Finalize
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(hidden_states.dtype)

        return AttnUtils.finalize_output(
            hidden_states, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply self-attention with PBR material processing.

        Processes multiple PBR material types sequentially, applying attention
        computation for each material type separately and combining results.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor with PBR dimension
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Combined attention output for all PBR material types
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        B = hidden_states.size(0)
        pbr_hidden_states = torch.split(hidden_states, 1, dim=1)

        # Process each PBR setting
        results = []
        for token, pbr_hs in zip(self.pbr_setting, pbr_hidden_states):
            processed_hs = rearrange(pbr_hs, "b n_pbrs n l c -> (b n_pbrs n) l c").to("cuda:0")
            result = self.process_single(attn, processed_hs, None, attention_mask, temb, token, False)
            results.append(result)

        outputs = [rearrange(result, "(b n_pbrs n) l c -> b n_pbrs n l c", b=B, n_pbrs=1) for result in results]
        return torch.cat(outputs, dim=1)
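

# Hypothetical usage sketch (added for illustration, not part of the upstream file): the self-attention
# processor expects a 5D input laid out as (batch, n_pbrs, n_views, seq_len, channels) and returns the
# same layout. Because __call__ pins activations to "cuda:0", this sketch only runs on a CUDA machine.
# All dimensions are arbitrary assumptions.
def _demo_self_attn_pbr():
    if not torch.cuda.is_available():
        return
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(SelfAttnProcessor2_0(query_dim=64, heads=4, dim_head=16, pbr_setting=["albedo", "mr"]))
    attn.to("cuda:0")  # moves both the Attention module and the registered processor submodules

    x = torch.randn(2, 2, 1, 10, 64, device="cuda:0")  # (batch, n_pbrs, n_views, seq_len, channels)
    out = attn.processor(attn, x)
    assert out.shape == x.shape  # one slice per PBR material along dim=1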


class RefAttnProcessor2_0(BaseAttnProcessor):
    """
    Reference attention processor with shared value computation across PBR materials.

    This processor computes query and key once, but uses separate value projections
    for different PBR material types, enabling efficient multi-material processing.
    """

    def __init__(self, **kwargs):
        """
        Initialize reference attention processor.

        Args:
            **kwargs: Arguments passed to BaseAttnProcessor initialization
        """
        super().__init__(**kwargs)
        self.pbr_settings = self.pbr_setting  # Alias for compatibility
        self.register_pbr_modules(["v_only", "out"], **kwargs)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply reference attention with shared Q/K and separate V projections.

        This method computes query and key tensors once and reuses them across
        all PBR material types, while using separate value projections for each
        material type to maintain material-specific information.

        Args:
            attn: Attention module instance
            hidden_states: Input hidden states tensor
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask tensor
            temb: Optional temporal embedding tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Stacked attention output for all PBR material types
        """
        AttnUtils.handle_deprecation_warning(args, kwargs)

        def get_qkv(attn, hidden_states, encoder_hidden_states, **kwargs):
            query = attn.to_q(hidden_states)
            key = attn.to_k(encoder_hidden_states)

            # Concatenate values from all PBR settings
            value_list = [attn.to_v(encoder_hidden_states)]
            for token in ["_" + token for token in self.pbr_settings if token != "albedo"]:
                value_list.append(getattr(attn.processor, f"to_v{token}")(encoder_hidden_states))
            value = torch.cat(value_list, dim=-1)

            return query, key, value

        # Core processing
        hidden_states, residual, input_ndim, shape_info, batch_size, heads, head_dim = AttnCore.process_attention_base(
            attn, hidden_states, encoder_hidden_states, attention_mask, temb, get_qkv_fn=get_qkv
        )

        # Split and process each PBR setting output
        hidden_states_list = torch.split(hidden_states, head_dim, dim=-1)
        output_hidden_states_list = []

        for i, hs in enumerate(hidden_states_list):
            hs = hs.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(hs.dtype)
            token_suffix = "_" + self.pbr_settings[i] if self.pbr_settings[i] != "albedo" else ""
            target = attn if self.pbr_settings[i] == "albedo" else attn.processor

            hs = AttnUtils.finalize_output(
                hs, input_ndim, shape_info, attn, residual, getattr(target, f"to_out{token_suffix}")
            )
            output_hidden_states_list.append(hs)

        return torch.stack(output_hidden_states_list, dim=1)
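

# Hypothetical usage sketch (added for illustration, not part of the upstream file): the reference
# processor computes one shared Q/K pair, concatenates a separate V per PBR material, and splits the
# attention output back into per-material slices, so the result gains a PBR dimension. Dimensions are
# arbitrary assumptions.
def _demo_ref_attn_pbr():
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(RefAttnProcessor2_0(query_dim=64, heads=4, dim_head=16, pbr_setting=["albedo", "mr"]))

    x = torch.randn(2, 10, 64)  # (batch, seq_len, channels)
    out = attn.processor(attn, x)
    assert out.shape == (2, 2, 10, 64)  # (batch, n_pbrs, seq_len, channels)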