Update enhance.py
enhance.py (CHANGED: +67 -67)
@@ -13,66 +13,66 @@ class LTXEnhanceAttnProcessor2_0:
         raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")
 
     def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
+        """Calculate enhancement scores for the attention mechanism"""
+        head_dim = inner_dim // num_heads
+
+        if text_seq_length is not None:
+            img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
+            img_k = key[:, :, :-text_seq_length] if text_seq_length > 0 else key
+        else:
+            img_q, img_k = query, key
+
+        batch_size, num_heads, ST, head_dim = img_q.shape
+        # Calculate spatial dimension by dividing total tokens by number of frames
+        spatial_dim = ST // num_frames
+        # Ensure spatial_dim is calculated correctly
+        if spatial_dim * num_frames != ST:
+            # If we can't divide evenly, we'll need to pad or reshape
+            spatial_dim = max(1, ST // num_frames)
+            # Adjust ST to be evenly divisible
+            ST = spatial_dim * num_frames
+
+        # Ensure tensors have the right shape before rearranging
+        img_q = img_q[:, :, :ST, :]
+        img_k = img_k[:, :, :ST, :]
+
+        try:
+            query_image = rearrange(
+                img_q, "B N (T S) C -> (B S) N T C",
+                T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+            )
+            key_image = rearrange(
+                img_k, "B N (T S) C -> (B S) N T C",
+                T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
+            )
+        except Exception as e:
+            # If rearrangement fails, return a default enhancement score
+            return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device)
+
+        scale = head_dim**-0.5
+        query_image = query_image * scale
+        attn_temp = query_image @ key_image.transpose(-2, -1)  # translate attn to float32
+        attn_temp = attn_temp.to(torch.float32)
+        attn_temp = attn_temp.softmax(dim=-1)
+
+        # Reshape to [batch_size * num_tokens, num_frames, num_frames]
+        attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
 
+        # Create a mask for diagonal elements
+        diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
+        diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
 
+        # Zero out diagonal elements
+        attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
 
+        # Calculate mean for each token's attention matrix
+        # Number of off-diagonal elements per matrix is n*n - n
+        num_off_diag = num_frames * num_frames - num_frames
+        mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
+
+        enhance_scores = mean_scores.mean() * (num_frames + 4.0)
+        enhance_scores = enhance_scores.clamp(min=1)
+        return enhance_scores
 
     def __call__(
         self,
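As a sanity check on what `_get_enhance_scores` returns: per spatial token it builds a frame-to-frame attention map, masks out the diagonal (self-frame) entries, and collapses the mean cross-frame attention into one scalar, scaled by (num_frames + 4) and floored at 1. A minimal standalone sketch, with toy sizes that are illustrative assumptions rather than anything from this file:

    import torch
    from einops import rearrange

    # Toy sizes (assumptions): batch 1, 2 heads, 3 frames x 4 spatial tokens, dim 8.
    B, N, T, S, C = 1, 2, 3, 4, 8
    img_q = torch.randn(B, N, T * S, C)
    img_k = torch.randn(B, N, T * S, C)

    # Same grouping as the rearrange above: fold the spatial axis into the
    # batch so attention runs across the T frames at each spatial location.
    q = rearrange(img_q, "B N (T S) C -> (B S) N T C", T=T, S=S) * C**-0.5
    k = rearrange(img_k, "B N (T S) C -> (B S) N T C", T=T, S=S)
    attn_temp = (q @ k.transpose(-2, -1)).to(torch.float32).softmax(dim=-1)
    attn_temp = attn_temp.reshape(-1, T, T)

    # Zero the diagonal and average the remaining off-diagonal mass.
    diag_mask = torch.eye(T, dtype=torch.bool).expand(attn_temp.shape[0], -1, -1)
    mean_scores = attn_temp.masked_fill(diag_mask, 0).sum(dim=(1, 2)) / (T * T - T)

    enhance_scores = (mean_scores.mean() * (T + 4.0)).clamp(min=1)
    print(enhance_scores)  # scalar tensor >= 1, used to rescale hidden_states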
@@ -91,19 +91,20 @@ class LTXEnhanceAttnProcessor2_0:
         inner_dim = attn.to_q.out_features
         num_heads = attn.heads
         head_dim = inner_dim // num_heads
+
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        # Reshape query, key, value to match expected dimensions
         query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
         if attn.upcast_attention:
             query = query.float()
             key = key.float()
+
         enhance_scores = None
         if is_enhance_enabled():
             try:
@@ -116,25 +117,24 @@ class LTXEnhanceAttnProcessor2_0:
             )
         except ValueError as e:
             print(f"Warning: Could not calculate enhance scores: {e}")
+
         hidden_states = torch.nn.functional.scaled_dot_product_attention(
             query, key, value,
             attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=False
         )
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
         hidden_states = hidden_states.to(query.dtype)
+
         # Apply enhancement if enabled
         if is_enhance_enabled() and enhance_scores is not None:
             hidden_states = hidden_states * enhance_scores
+
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
+
         return hidden_states
 
 def inject_enhance_for_ltx(model: nn.Module) -> None:
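After the change, the entry point is unchanged: `inject_enhance_for_ltx(model)` wires `LTXEnhanceAttnProcessor2_0` into the given model, and the multiplier is only applied when `is_enhance_enabled()` is true at call time. A minimal usage sketch, assuming the file is importable as `enhance`; the `transformer` argument is a placeholder for whatever LTX transformer module the caller holds:

    import torch.nn as nn

    # Hypothetical import path; only inject_enhance_for_ltx is defined in
    # the file shown above.
    from enhance import inject_enhance_for_ltx

    def prepare(transformer: nn.Module) -> nn.Module:
        # Install the enhance-aware attention processors so enhance scores
        # rescale the attention output whenever enhancement is enabled.
        inject_enhance_for_ltx(transformer)
        return transformer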