jbilcke-hf HF staff committed on
Commit
573d0ba
·
verified ·
1 Parent(s): 70db765

Update enhance.py

Browse files
Files changed (1) hide show
  1. enhance.py +16 -28
enhance.py CHANGED
@@ -15,6 +15,7 @@ class LTXEnhanceAttnProcessor2_0:
15
  def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
16
  """Calculate enhancement scores for the attention mechanism"""
17
  head_dim = inner_dim // num_heads
 
18
 
19
  if text_seq_length is not None:
20
  img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
@@ -23,16 +24,11 @@ class LTXEnhanceAttnProcessor2_0:
23
  img_q, img_k = query, key
24
 
25
  batch_size, num_heads, ST, head_dim = img_q.shape
26
- # Calculate spatial dimension by dividing total tokens by number of frames
27
  spatial_dim = ST // num_frames
28
- # Ensure spatial_dim is calculated correctly
29
  if spatial_dim * num_frames != ST:
30
- # If we can't divide evenly, we'll need to pad or reshape
31
  spatial_dim = max(1, ST // num_frames)
32
- # Adjust ST to be evenly divisible
33
  ST = spatial_dim * num_frames
34
 
35
- # Ensure tensors have the right shape before rearranging
36
  img_q = img_q[:, :, :ST, :]
37
  img_k = img_k[:, :, :ST, :]
38
 
@@ -46,33 +42,31 @@ class LTXEnhanceAttnProcessor2_0:
46
  T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
47
  )
48
  except Exception as e:
49
- # If rearrangement fails, return a default enhancement score
50
- return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device)
51
 
52
  scale = head_dim**-0.5
53
  query_image = query_image * scale
54
- attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
55
- attn_temp = attn_temp.to(torch.float32)
56
- attn_temp = attn_temp.softmax(dim=-1)
57
 
58
- # Reshape to [batch_size * num_tokens, num_frames, num_frames]
 
 
 
 
 
 
59
  attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
60
-
61
- # Create a mask for diagonal elements
62
  diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
63
  diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
64
-
65
- # Zero out diagonal elements
66
  attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
67
-
68
- # Calculate mean for each token's attention matrix
69
- # Number of off-diagonal elements per matrix is n*n - n
70
  num_off_diag = num_frames * num_frames - num_frames
71
  mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
72
 
73
  enhance_scores = mean_scores.mean() * (num_frames + 4.0)
74
  enhance_scores = enhance_scores.clamp(min=1)
75
- return enhance_scores
 
 
76
 
77
  def __call__(
78
  self,
@@ -82,6 +76,7 @@ class LTXEnhanceAttnProcessor2_0:
82
  attention_mask = None,
83
  **kwargs
84
  ) -> torch.Tensor:
 
85
  batch_size, sequence_length, _ = hidden_states.shape
86
  text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
87
 
@@ -92,12 +87,10 @@ class LTXEnhanceAttnProcessor2_0:
92
  num_heads = attn.heads
93
  head_dim = inner_dim // num_heads
94
 
95
- # Get query, key, value projections
96
  query = attn.to_q(hidden_states)
97
  key = attn.to_k(encoder_hidden_states)
98
  value = attn.to_v(encoder_hidden_states)
99
 
100
- # Reshape projections
101
  query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
102
  key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
103
  value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
@@ -106,7 +99,6 @@ class LTXEnhanceAttnProcessor2_0:
106
  query = query.float()
107
  key = key.float()
108
 
109
- # Process attention
110
  enhance_scores = None
111
  if is_enhance_enabled():
112
  try:
@@ -120,12 +112,10 @@ class LTXEnhanceAttnProcessor2_0:
120
  except ValueError as e:
121
  print(f"Warning: Could not calculate enhance scores: {e}")
122
 
123
- # Make sure attention_mask has correct shape
124
  if attention_mask is not None:
125
  attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
126
  attention_mask = attention_mask.expand(-1, num_heads, -1, -1)
127
 
128
- # Compute attention with correct shapes
129
  hidden_states = torch.nn.functional.scaled_dot_product_attention(
130
  query, key, value,
131
  attn_mask=attention_mask,
@@ -133,15 +123,13 @@ class LTXEnhanceAttnProcessor2_0:
133
  is_causal=False
134
  )
135
 
136
- # Reshape output
137
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
138
- hidden_states = hidden_states.to(query.dtype)
139
 
140
- # Apply enhancement if enabled
141
  if is_enhance_enabled() and enhance_scores is not None:
142
  hidden_states = hidden_states * enhance_scores
143
 
144
- # Output projection
145
  hidden_states = attn.to_out[0](hidden_states)
146
  hidden_states = attn.to_out[1](hidden_states)
147
 
 
15
  def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
16
  """Calculate enhancement scores for the attention mechanism"""
17
  head_dim = inner_dim // num_heads
18
+ orig_dtype = query.dtype # Store original dtype
19
 
20
  if text_seq_length is not None:
21
  img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
 
24
  img_q, img_k = query, key
25
 
26
  batch_size, num_heads, ST, head_dim = img_q.shape
 
27
  spatial_dim = ST // num_frames
 
28
  if spatial_dim * num_frames != ST:
 
29
  spatial_dim = max(1, ST // num_frames)
 
30
  ST = spatial_dim * num_frames
31
 
 
32
  img_q = img_q[:, :, :ST, :]
33
  img_k = img_k[:, :, :ST, :]
34
 
 
42
  T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
43
  )
44
  except Exception as e:
45
+ return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device, dtype=orig_dtype)
 
46
 
47
  scale = head_dim**-0.5
48
  query_image = query_image * scale
 
 
 
49
 
50
+ # Compute attention in float32 for stability
51
+ with torch.cuda.amp.autocast(enabled=False):
52
+ query_image = query_image.float()
53
+ key_image = key_image.float()
54
+ attn_temp = query_image @ key_image.transpose(-2, -1)
55
+ attn_temp = attn_temp.softmax(dim=-1)
56
+
57
  attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
 
 
58
  diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
59
  diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
 
 
60
  attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
61
+
 
 
62
  num_off_diag = num_frames * num_frames - num_frames
63
  mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
64
 
65
  enhance_scores = mean_scores.mean() * (num_frames + 4.0)
66
  enhance_scores = enhance_scores.clamp(min=1)
67
+
68
+ # Convert back to original dtype
69
+ return enhance_scores.to(orig_dtype)
70
 
71
  def __call__(
72
  self,
 
76
  attention_mask = None,
77
  **kwargs
78
  ) -> torch.Tensor:
79
+ orig_dtype = hidden_states.dtype # Store original dtype
80
  batch_size, sequence_length, _ = hidden_states.shape
81
  text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
82
 
 
87
  num_heads = attn.heads
88
  head_dim = inner_dim // num_heads
89
 
 
90
  query = attn.to_q(hidden_states)
91
  key = attn.to_k(encoder_hidden_states)
92
  value = attn.to_v(encoder_hidden_states)
93
 
 
94
  query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
95
  key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
96
  value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
99
  query = query.float()
100
  key = key.float()
101
 
 
102
  enhance_scores = None
103
  if is_enhance_enabled():
104
  try:
 
112
  except ValueError as e:
113
  print(f"Warning: Could not calculate enhance scores: {e}")
114
 
 
115
  if attention_mask is not None:
116
  attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
117
  attention_mask = attention_mask.expand(-1, num_heads, -1, -1)
118
 
 
119
  hidden_states = torch.nn.functional.scaled_dot_product_attention(
120
  query, key, value,
121
  attn_mask=attention_mask,
 
123
  is_causal=False
124
  )
125
 
 
126
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
127
+ hidden_states = hidden_states.to(orig_dtype) # Ensure we're back to original dtype
128
 
 
129
  if is_enhance_enabled() and enhance_scores is not None:
130
  hidden_states = hidden_states * enhance_scores
131
 
132
+ # Apply output projections while maintaining dtype
133
  hidden_states = attn.to_out[0](hidden_states)
134
  hidden_states = attn.to_out[1](hidden_states)
135