jbilcke-hf (HF staff) committed
Commit 70db765 · verified · 1 Parent(s): 9bc27ab

Update enhance.py

Files changed (1):
  1. enhance.py  +14 -4
enhance.py CHANGED

@@ -73,7 +73,7 @@ class LTXEnhanceAttnProcessor2_0:
         enhance_scores = mean_scores.mean() * (num_frames + 4.0)
         enhance_scores = enhance_scores.clamp(min=1)
         return enhance_scores
-
+
     def __call__(
         self,
         attn: Attention,
@@ -92,12 +92,13 @@ class LTXEnhanceAttnProcessor2_0:
         num_heads = attn.heads
         head_dim = inner_dim // num_heads
 
+        # Get query, key, value projections
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
-        # Reshape query, key, value to match expected dimensions
-        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        # Reshape projections
+        query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
@@ -105,6 +106,7 @@ class LTXEnhanceAttnProcessor2_0:
         query = query.float()
         key = key.float()
 
+        # Process attention
         enhance_scores = None
         if is_enhance_enabled():
             try:
@@ -118,6 +120,12 @@ class LTXEnhanceAttnProcessor2_0:
             except ValueError as e:
                 print(f"Warning: Could not calculate enhance scores: {e}")
 
+        # Make sure attention_mask has correct shape
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+            attention_mask = attention_mask.expand(-1, num_heads, -1, -1)
+
+        # Compute attention with correct shapes
         hidden_states = torch.nn.functional.scaled_dot_product_attention(
             query, key, value,
             attn_mask=attention_mask,
@@ -125,6 +133,7 @@ class LTXEnhanceAttnProcessor2_0:
             is_causal=False
         )
 
+        # Reshape output
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -132,11 +141,12 @@ class LTXEnhanceAttnProcessor2_0:
         if is_enhance_enabled() and enhance_scores is not None:
             hidden_states = hidden_states * enhance_scores
 
+        # Output projection
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
-
+
 def inject_enhance_for_ltx(model: nn.Module) -> None:
     """
     Inject enhance score for LTX model.
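The main functional change in this commit is the explicit 4-D mask handed to torch.nn.functional.scaled_dot_product_attention. Below is a minimal, self-contained sketch of that broadcasting pattern, separate from the LTX code; the tensor sizes and the toy padding mask are made up purely for illustration and are not taken from the model.

import torch
import torch.nn.functional as F

# Toy sizes for illustration only (not from the LTX model).
batch_size, num_heads, head_dim = 2, 4, 8
q_len, kv_len = 16, 12

query = torch.randn(batch_size, num_heads, q_len, head_dim)
key = torch.randn(batch_size, num_heads, kv_len, head_dim)
value = torch.randn(batch_size, num_heads, kv_len, head_dim)

# A 2-D padding mask over the key/value tokens: True = attend, False = ignore.
attention_mask = torch.ones(batch_size, kv_len, dtype=torch.bool)
attention_mask[:, -3:] = False  # pretend the last 3 tokens are padding

# The same reshape the commit applies: (batch, kv_len) -> (batch, 1, 1, kv_len),
# then expand the head dimension so the mask lines up with (batch, heads, q_len, kv_len).
attention_mask = attention_mask.view(batch_size, 1, 1, kv_len)
attention_mask = attention_mask.expand(-1, num_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 16, 8])

Note that scaled_dot_product_attention would also accept the (batch, 1, 1, kv_len) mask as-is, since attn_mask only needs to be broadcastable to (batch, heads, q_len, kv_len); the explicit expand in the commit simply makes the per-head dimension visible.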