jbilcke-hf (HF staff) committed
Commit 60a27e7 · verified · 1 Parent(s): 573d0ba

Update enhance.py

Files changed (1)
  1. enhance.py +8 -7
enhance.py CHANGED
@@ -76,8 +76,13 @@ class LTXEnhanceAttnProcessor2_0:
         attention_mask = None,
         **kwargs
     ) -> torch.Tensor:
-        orig_dtype = hidden_states.dtype  # Store original dtype
-        batch_size, sequence_length, _ = hidden_states.shape
+        # The shape could be [batch_size, sequence_length, channels] or [batch_size, sequence_length, num_heads, head_dim]
+        # We need to handle both cases
+        if hidden_states.ndim == 4:
+            batch_size, sequence_length, num_heads, head_dim = hidden_states.shape
+        else:
+            batch_size, sequence_length, inner_dim = hidden_states.shape
+
         text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
 
         if encoder_hidden_states is None:
@@ -90,11 +95,7 @@ class LTXEnhanceAttnProcessor2_0:
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
-
-        query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
-
+
         if attn.upcast_attention:
             query = query.float()
             key = key.float()
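For reference, here is the new branching logic from the first hunk in isolation — a minimal runnable sketch, with unpack_shape as a hypothetical helper name (the actual commit inlines this directly in __call__):

    import torch

    def unpack_shape(hidden_states: torch.Tensor):
        # 4D input: heads already split out -> [batch, seq, num_heads, head_dim]
        if hidden_states.ndim == 4:
            batch_size, sequence_length, num_heads, head_dim = hidden_states.shape
            inner_dim = num_heads * head_dim
        else:
            # 3D input: channels still fused -> [batch, seq, channels]
            batch_size, sequence_length, inner_dim = hidden_states.shape
        return batch_size, sequence_length, inner_dim

    # Both layouts describe the same underlying sizes:
    assert unpack_shape(torch.randn(2, 128, 512)) == (2, 128, 512)
    assert unpack_shape(torch.randn(2, 128, 8, 64)) == (2, 128, 512)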
 
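The lines deleted in the second hunk were the standard multi-head split. As a standalone illustration of what view(...).transpose(1, 2) did there (illustrative sizes, not the patched code):

    import torch

    batch_size, sequence_length, num_heads, head_dim = 2, 128, 8, 64
    query = torch.randn(batch_size, sequence_length, num_heads * head_dim)

    # Split the fused channel dim into heads, then swap the seq and head axes:
    # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
    query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
    assert query.shape == (2, 8, 128, 64)

After this commit the query/key/value projections are left in their fused [batch, seq, inner_dim] layout at this point; the diff alone does not show where the head split now happens.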