Commit 28111ae · 1 Parent(s): e59f1dc
hatmanstack committed

add SD35AdaLayerNormZeroX

Changed files: models_attention.py (+35 -1)
models_attention.py
CHANGED
@@ -22,12 +22,13 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
 from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
 from diffusers.models.embeddings import SinusoidalPositionalEmbedding
-from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
 
 
 logger = logging.get_logger(__name__)
 
 
+
 def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
@@ -42,6 +43,39 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim:
     )
     return ff_output
 
+@maybe_allow_in_graph
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
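For readers trying the new layer out, here is a minimal usage sketch. It assumes the SD35AdaLayerNormZeroX class added above is in scope (for example, imported from models_attention); the batch size, sequence length, and embedding width are illustrative assumptions, not values taken from this Space. The layer projects a conditioning embedding into nine modulation tensors and returns two independently shifted and scaled copies of the normalized hidden states, plus their gates, which is what lets an SD 3.5-style dual-attention block drive two parallel attention branches from a single norm.

import torch

# Illustrative shapes (assumptions): 2 sequences of 16 tokens, embedding width 64.
batch, seq_len, dim = 2, 16, 64

norm_x = SD35AdaLayerNormZeroX(embedding_dim=dim)   # the layer added in this commit
hidden_states = torch.randn(batch, seq_len, dim)    # token features
temb = torch.randn(batch, dim)                      # pooled conditioning embedding

(norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp,
 norm_hidden_states2, gate_msa2) = norm_x(hidden_states, emb=temb)

# Both modulated streams keep the token shape; gates, shifts, and scales are per-sample.
assert norm_hidden_states.shape == (batch, seq_len, dim)
assert norm_hidden_states2.shape == (batch, seq_len, dim)
assert gate_msa.shape == shift_mlp.shape == (batch, dim)

Compared with the stock AdaLayerNormZero, which chunks a 6 * embedding_dim projection into a single modulated stream, the 9 * embedding_dim projection and the extra (shift_msa2, scale_msa2, gate_msa2) triple are what produce the second stream norm_hidden_states2.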