Update enhance.py
enhance.py (+16 −28) CHANGED
@@ -15,6 +15,7 @@ class LTXEnhanceAttnProcessor2_0:
     def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
         """Calculate enhancement scores for the attention mechanism"""
         head_dim = inner_dim // num_heads
+        orig_dtype = query.dtype  # Store original dtype

         if text_seq_length is not None:
             img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
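The one functional change in this hunk is capturing the caller's dtype before any casts: LTX pipelines commonly run the transformer in bfloat16/float16, while the score math below wants float32. A minimal standalone sketch of the pattern (illustrative only, not the repo's code):

```python
import torch

def compute_in_fp32(x: torch.Tensor) -> torch.Tensor:
    """Do the numerically fragile work in float32, hand back the caller's dtype."""
    orig_dtype = x.dtype                 # remember what the pipeline gave us
    y = x.float().softmax(dim=-1)        # softmax is the risky op in half precision
    return y.to(orig_dtype)              # restore dtype at the boundary

x = torch.randn(2, 4, dtype=torch.bfloat16)
assert compute_in_fp32(x).dtype == torch.bfloat16
```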
@@ -23,16 +24,11 @@
             img_q, img_k = query, key

         batch_size, num_heads, ST, head_dim = img_q.shape
-        # Calculate spatial dimension by dividing total tokens by number of frames
         spatial_dim = ST // num_frames
-        # Ensure spatial_dim is calculated correctly
         if spatial_dim * num_frames != ST:
-            # If we can't divide evenly, we'll need to pad or reshape
             spatial_dim = max(1, ST // num_frames)
-            # Adjust ST to be evenly divisible
             ST = spatial_dim * num_frames

-        # Ensure tensors have the right shape before rearranging
         img_q = img_q[:, :, :ST, :]
         img_k = img_k[:, :, :ST, :]

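The logic kept here just floors the spatial dimension and trims trailing tokens when the sequence length is not an exact multiple of the frame count. A worked example with made-up numbers (the 8 heads and 64-dim heads are assumptions, not model constants):

```python
import torch

num_frames = 9
ST = 4681                                  # hypothetical token count, not divisible by 9
spatial_dim = max(1, ST // num_frames)     # 520
ST = spatial_dim * num_frames              # 4680: the one straggler token is dropped

img_q = torch.randn(1, 8, 4681, 64)
img_q = img_q[:, :, :ST, :]                # same trim as in the diff
print(img_q.shape)                         # torch.Size([1, 8, 4680, 64])
```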
@@ -46,33 +42,31 @@
                 T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
             )
         except Exception as e:
-
-            return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device)
+            return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device, dtype=orig_dtype)

         scale = head_dim**-0.5
         query_image = query_image * scale
-        attn_temp = query_image @ key_image.transpose(-2, -1)  # translate attn to float32
-        attn_temp = attn_temp.to(torch.float32)
-        attn_temp = attn_temp.softmax(dim=-1)

-        #
+        # Compute attention in float32 for stability
+        with torch.cuda.amp.autocast(enabled=False):
+            query_image = query_image.float()
+            key_image = key_image.float()
+            attn_temp = query_image @ key_image.transpose(-2, -1)
+            attn_temp = attn_temp.softmax(dim=-1)
+
         attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
-
-        # Create a mask for diagonal elements
         diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
         diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
-
-        # Zero out diagonal elements
         attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
-
-        # Calculate mean for each token's attention matrix
-        # Number of off-diagonal elements per matrix is n*n - n
+
         num_off_diag = num_frames * num_frames - num_frames
         mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

         enhance_scores = mean_scores.mean() * (num_frames + 4.0)
         enhance_scores = enhance_scores.clamp(min=1)
-
+
+        # Convert back to original dtype
+        return enhance_scores.to(orig_dtype)

     def __call__(
         self,
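Net effect of this hunk: the frame-to-frame attention map is now computed in float32 with autocast disabled, its diagonal (each frame attending to itself) is zeroed, and the mean of the remaining off-diagonal mass is scaled by num_frames + 4 and clamped at 1. A self-contained sketch of the same arithmetic on dummy tensors (every shape below is an assumption):

```python
import torch

B, N, T, S, C = 1, 2, 8, 16, 32             # batch, heads, frames, spatial, channels
q = torch.randn(B * N * S, T, C)             # frame-to-frame queries per spatial location
k = torch.randn(B * N * S, T, C)

attn = (q * C**-0.5 @ k.transpose(-2, -1)).float().softmax(dim=-1)   # [B*N*S, T, T]

diag = torch.eye(T).bool().unsqueeze(0).expand(attn.shape[0], -1, -1)
off_diag_mean = attn.masked_fill(diag, 0).sum(dim=(1, 2)) / (T * T - T)

enhance = (off_diag_mean.mean() * (T + 4.0)).clamp(min=1)
print(enhance)                               # scalar >= 1; larger when frames attend across time
```

Excluding the diagonal matters: it measures only temporal mixing, not each frame's attention to itself.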
@@ -82,6 +76,7 @@
         attention_mask = None,
         **kwargs
     ) -> torch.Tensor:
+        orig_dtype = hidden_states.dtype  # Store original dtype
         batch_size, sequence_length, _ = hidden_states.shape
         text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0

@@ -92,12 +87,10 @@
         num_heads = attn.heads
         head_dim = inner_dim // num_heads

-        # Get query, key, value projections
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-        # Reshape projections
         query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
@@ -106,7 +99,6 @@
         query = query.float()
         key = key.float()

-        # Process attention
         enhance_scores = None
         if is_enhance_enabled():
             try:
@@ -120,12 +112,10 @@
         except ValueError as e:
             print(f"Warning: Could not calculate enhance scores: {e}")

-        # Make sure attention_mask has correct shape
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
             attention_mask = attention_mask.expand(-1, num_heads, -1, -1)

-        # Compute attention with correct shapes
         hidden_states = torch.nn.functional.scaled_dot_product_attention(
             query, key, value,
             attn_mask=attention_mask,
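For context on the mask plumbing above: a per-key padding mask is broadcast over heads and query positions before being handed to scaled_dot_product_attention. A minimal runnable illustration (PyTorch's boolean convention, True = attend; shapes are made up):

```python
import torch
import torch.nn.functional as F

B, H, Lq, Lk, D = 2, 4, 6, 10, 16
q = torch.randn(B, H, Lq, D)
k = torch.randn(B, H, Lk, D)
v = torch.randn(B, H, Lk, D)

# Per-token keep/pad mask, e.g. from a tokenizer: True = attend to this key
mask = torch.ones(B, Lk, dtype=torch.bool)
mask[:, -2:] = False                                  # pretend the last two keys are padding

mask = mask.view(B, 1, 1, Lk).expand(-1, H, -1, -1)   # same reshape as the diff
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 4, 6, 16])
```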
@@ -133,15 +123,13 @@
             is_causal=False
         )

-        # Reshape output
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
-        hidden_states = hidden_states.to(
+        hidden_states = hidden_states.to(orig_dtype)  # Ensure we're back to original dtype

-        # Apply enhancement if enabled
         if is_enhance_enabled() and enhance_scores is not None:
             hidden_states = hidden_states * enhance_scores

-        #
+        # Apply output projections while maintaining dtype
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)

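Two footnotes on the dtype round-trip this commit completes: torch.cuda.amp.autocast(enabled=False) is the legacy spelling (newer PyTorch prefers torch.amp.autocast), and scaled_dot_product_attention requires query, key, and value to share a dtype, so the sketch below casts all three before restoring the original dtype at the end (dummy shapes, CPU-friendly):

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 2, 5, 8
hidden = torch.randn(B, L, H * D, dtype=torch.bfloat16)
orig_dtype = hidden.dtype                        # as in the diff's __call__

q = k = v = hidden.view(B, L, H, D).transpose(1, 2).float()  # matching dtypes for SDPA
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, -1, H * D).to(orig_dtype)
assert out.dtype == torch.bfloat16               # back in the pipeline's dtype
```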