Update enhance.py
Browse files- enhance.py +14 -4
enhance.py
CHANGED
@@ -73,7 +73,7 @@ class LTXEnhanceAttnProcessor2_0:
|
|
73 |
enhance_scores = mean_scores.mean() * (num_frames + 4.0)
|
74 |
enhance_scores = enhance_scores.clamp(min=1)
|
75 |
return enhance_scores
|
76 |
-
|
77 |
def __call__(
|
78 |
self,
|
79 |
attn: Attention,
|
@@ -92,12 +92,13 @@ class LTXEnhanceAttnProcessor2_0:
|
|
92 |
num_heads = attn.heads
|
93 |
head_dim = inner_dim // num_heads
|
94 |
|
|
|
95 |
query = attn.to_q(hidden_states)
|
96 |
key = attn.to_k(encoder_hidden_states)
|
97 |
value = attn.to_v(encoder_hidden_states)
|
98 |
|
99 |
-
# Reshape
|
100 |
-
query = query.view(batch_size,
|
101 |
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
102 |
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
103 |
|
@@ -105,6 +106,7 @@ class LTXEnhanceAttnProcessor2_0:
|
|
105 |
query = query.float()
|
106 |
key = key.float()
|
107 |
|
|
|
108 |
enhance_scores = None
|
109 |
if is_enhance_enabled():
|
110 |
try:
|
@@ -118,6 +120,12 @@ class LTXEnhanceAttnProcessor2_0:
|
|
118 |
except ValueError as e:
|
119 |
print(f"Warning: Could not calculate enhance scores: {e}")
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
122 |
query, key, value,
|
123 |
attn_mask=attention_mask,
|
@@ -125,6 +133,7 @@ class LTXEnhanceAttnProcessor2_0:
|
|
125 |
is_causal=False
|
126 |
)
|
127 |
|
|
|
128 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
|
129 |
hidden_states = hidden_states.to(query.dtype)
|
130 |
|
@@ -132,11 +141,12 @@ class LTXEnhanceAttnProcessor2_0:
|
|
132 |
if is_enhance_enabled() and enhance_scores is not None:
|
133 |
hidden_states = hidden_states * enhance_scores
|
134 |
|
|
|
135 |
hidden_states = attn.to_out[0](hidden_states)
|
136 |
hidden_states = attn.to_out[1](hidden_states)
|
137 |
|
138 |
return hidden_states
|
139 |
-
|
140 |
def inject_enhance_for_ltx(model: nn.Module) -> None:
|
141 |
"""
|
142 |
Inject enhance score for LTX model.
|
|
|
73 |
enhance_scores = mean_scores.mean() * (num_frames + 4.0)
|
74 |
enhance_scores = enhance_scores.clamp(min=1)
|
75 |
return enhance_scores
|
76 |
+
|
77 |
def __call__(
|
78 |
self,
|
79 |
attn: Attention,
|
|
|
92 |
num_heads = attn.heads
|
93 |
head_dim = inner_dim // num_heads
|
94 |
|
95 |
+
# Get query, key, value projections
|
96 |
query = attn.to_q(hidden_states)
|
97 |
key = attn.to_k(encoder_hidden_states)
|
98 |
value = attn.to_v(encoder_hidden_states)
|
99 |
|
100 |
+
# Reshape projections
|
101 |
+
query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
|
102 |
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
103 |
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
104 |
|
|
|
106 |
query = query.float()
|
107 |
key = key.float()
|
108 |
|
109 |
+
# Process attention
|
110 |
enhance_scores = None
|
111 |
if is_enhance_enabled():
|
112 |
try:
|
|
|
120 |
except ValueError as e:
|
121 |
print(f"Warning: Could not calculate enhance scores: {e}")
|
122 |
|
123 |
+
# Make sure attention_mask has correct shape
|
124 |
+
if attention_mask is not None:
|
125 |
+
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
|
126 |
+
attention_mask = attention_mask.expand(-1, num_heads, -1, -1)
|
127 |
+
|
128 |
+
# Compute attention with correct shapes
|
129 |
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
130 |
query, key, value,
|
131 |
attn_mask=attention_mask,
|
|
|
133 |
is_causal=False
|
134 |
)
|
135 |
|
136 |
+
# Reshape output
|
137 |
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
|
138 |
hidden_states = hidden_states.to(query.dtype)
|
139 |
|
|
|
141 |
if is_enhance_enabled() and enhance_scores is not None:
|
142 |
hidden_states = hidden_states * enhance_scores
|
143 |
|
144 |
+
# Output projection
|
145 |
hidden_states = attn.to_out[0](hidden_states)
|
146 |
hidden_states = attn.to_out[1](hidden_states)
|
147 |
|
148 |
return hidden_states
|
149 |
+
|
150 |
def inject_enhance_for_ltx(model: nn.Module) -> None:
|
151 |
"""
|
152 |
Inject enhance score for LTX model.
|