Update enhance.py
enhance.py CHANGED (+8 -7)
@@ -76,8 +76,13 @@ class LTXEnhanceAttnProcessor2_0:
         attention_mask = None,
         **kwargs
     ) -> torch.Tensor:
-
-
+        # The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
+        # We need to handle both cases
+        if hidden_states.ndim == 4:
+            batch_size, sequence_length, num_heads, head_dim = hidden_states.shape
+        else:
+            batch_size, sequence_length, inner_dim = hidden_states.shape
+
         text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0

         if encoder_hidden_states is None:
@@ -90,11 +95,7 @@ class LTXEnhanceAttnProcessor2_0:
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
-
-        query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-
+
         if attn.upcast_attention:
             query = query.float()
             key = key.float()
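The snippet below is a minimal, standalone sketch of the shape handling this commit adds: hidden_states can arrive either as a 3-D [batch, seq, channels] tensor or as a 4-D [batch, seq, num_heads, head_dim] tensor, and both are unpacked to a common (batch_size, sequence_length, inner_dim) triple. The helper name unpack_hidden_state_shape is hypothetical and not part of enhance.py; only the ndim branch mirrors the diff above.

import torch

# Hypothetical helper (not in enhance.py): mirrors the ndim check added above.
# hidden_states may be [batch, seq, channels] (3-D) or already split into
# heads as [batch, seq, num_heads, head_dim] (4-D).
def unpack_hidden_state_shape(hidden_states: torch.Tensor):
    if hidden_states.ndim == 4:
        batch_size, sequence_length, num_heads, head_dim = hidden_states.shape
        inner_dim = num_heads * head_dim
    else:
        batch_size, sequence_length, inner_dim = hidden_states.shape
    return batch_size, sequence_length, inner_dim

# Both layouts resolve to the same logical shape:
print(unpack_hidden_state_shape(torch.randn(2, 128, 512)))    # (2, 128, 512)
print(unpack_hidden_state_shape(torch.randn(2, 128, 8, 64)))  # (2, 128, 512)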