hatmanstack committed on
Commit 28111ae · 1 Parent(s): e59f1dc

add SD35AdaLayerNormZeroX

Files changed (1):
  1. models_attention.py +35 -1
models_attention.py CHANGED
@@ -22,12 +22,13 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
 from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
 from diffusers.models.embeddings import SinusoidalPositionalEmbedding
-from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
 
 
 logger = logging.get_logger(__name__)
 
 
+
 def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
@@ -42,6 +43,39 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
     )
     return ff_output
 
+@maybe_allow_in_graph
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
 
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
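
For reference, a minimal usage sketch of the newly added class (not part of the commit). The import path and tensor shapes are assumptions based on the usual AdaLN-Zero calling convention, where emb is a pooled conditioning embedding of shape (batch, embedding_dim):

import torch
from models_attention import SD35AdaLayerNormZeroX  # assumes this repo's models_attention.py is importable

batch, seq_len, dim = 2, 16, 64
norm_x = SD35AdaLayerNormZeroX(embedding_dim=dim)

hidden_states = torch.randn(batch, seq_len, dim)  # token features (assumed shape)
emb = torch.randn(batch, dim)                     # pooled conditioning embedding (assumed shape)

# Returns the modulated hidden states, the MSA/MLP gates, shifts, and scales,
# plus a second modulated stream (norm_hidden_states2) with its own gate.
out, gate_msa, shift_mlp, scale_mlp, gate_mlp, out2, gate_msa2 = norm_x(hidden_states, emb)
print(out.shape, out2.shape, gate_msa.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 16, 64]) torch.Size([2, 64])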