jbilcke-hf HF staff commited on
Commit
de8d4f2
·
verified ·
1 Parent(s): 60a27e7

Update enhance.py

Browse files
Files changed (1) hide show
  1. enhance.py +9 -2
enhance.py CHANGED
@@ -76,6 +76,9 @@ class LTXEnhanceAttnProcessor2_0:
76
  attention_mask = None,
77
  **kwargs
78
  ) -> torch.Tensor:
 
 
 
79
  # The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
80
  # We need to handle both cases
81
  if hidden_states.ndim == 4:
@@ -95,7 +98,11 @@ class LTXEnhanceAttnProcessor2_0:
95
  query = attn.to_q(hidden_states)
96
  key = attn.to_k(encoder_hidden_states)
97
  value = attn.to_v(encoder_hidden_states)
98
-
 
 
 
 
99
  if attn.upcast_attention:
100
  query = query.float()
101
  key = key.float()
@@ -125,7 +132,7 @@ class LTXEnhanceAttnProcessor2_0:
125
  )
126
 
127
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
128
- hidden_states = hidden_states.to(orig_dtype) # Ensure we're back to original dtype
129
 
130
  if is_enhance_enabled() and enhance_scores is not None:
131
  hidden_states = hidden_states * enhance_scores
 
76
  attention_mask = None,
77
  **kwargs
78
  ) -> torch.Tensor:
79
+ # Store original dtype first
80
+ orig_dtype = hidden_states.dtype
81
+
82
  # The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
83
  # We need to handle both cases
84
  if hidden_states.ndim == 4:
 
98
  query = attn.to_q(hidden_states)
99
  key = attn.to_k(encoder_hidden_states)
100
  value = attn.to_v(encoder_hidden_states)
101
+
102
+ query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
103
+ key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
104
+ value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
105
+
106
  if attn.upcast_attention:
107
  query = query.float()
108
  key = key.float()
 
132
  )
133
 
134
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
135
+ hidden_states = hidden_states.to(orig_dtype) # Now orig_dtype is defined
136
 
137
  if is_enhance_enabled() and enhance_scores is not None:
138
  hidden_states = hidden_states * enhance_scores