Update enhance.py
Browse files- enhance.py +9 -2
enhance.py
CHANGED
@@ -76,6 +76,9 @@ class LTXEnhanceAttnProcessor2_0:
|
|
76 |
attention_mask = None,
|
77 |
**kwargs
|
78 |
) -> torch.Tensor:
|
|
|
|
|
|
|
79 |
# The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
|
80 |
# We need to handle both cases
|
81 |
if hidden_states.ndim == 4:
|
@@ -95,7 +98,11 @@ class LTXEnhanceAttnProcessor2_0:
|
|
95 |
query = attn.to_q(hidden_states)
|
96 |
key = attn.to_k(encoder_hidden_states)
|
97 |
value = attn.to_v(encoder_hidden_states)
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
if attn.upcast_attention:
|
100 |
query = query.float()
|
101 |
key = key.float()
|
@@ -125,7 +132,7 @@ class LTXEnhanceAttnProcessor2_0:
|
|
125 |
)
|
126 |
|
127 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
|
128 |
-
hidden_states = hidden_states.to(orig_dtype) #
|
129 |
|
130 |
if is_enhance_enabled() and enhance_scores is not None:
|
131 |
hidden_states = hidden_states * enhance_scores
|
|
|
76 |
attention_mask = None,
|
77 |
**kwargs
|
78 |
) -> torch.Tensor:
|
79 |
+
# Store original dtype first
|
80 |
+
orig_dtype = hidden_states.dtype
|
81 |
+
|
82 |
# The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
|
83 |
# We need to handle both cases
|
84 |
if hidden_states.ndim == 4:
|
|
|
98 |
query = attn.to_q(hidden_states)
|
99 |
key = attn.to_k(encoder_hidden_states)
|
100 |
value = attn.to_v(encoder_hidden_states)
|
101 |
+
|
102 |
+
query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
|
103 |
+
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
104 |
+
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
105 |
+
|
106 |
if attn.upcast_attention:
|
107 |
query = query.float()
|
108 |
key = key.float()
|
|
|
132 |
)
|
133 |
|
134 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
|
135 |
+
hidden_states = hidden_states.to(orig_dtype) # Now orig_dtype is defined
|
136 |
|
137 |
if is_enhance_enabled() and enhance_scores is not None:
|
138 |
hidden_states = hidden_states * enhance_scores
|