jbilcke-hf HF staff committed on
Commit
9bc27ab
·
verified ·
1 Parent(s): 930d076

Update enhance.py

Browse files
Files changed (1) hide show
  1. enhance.py +67 -67
enhance.py CHANGED
@@ -13,66 +13,66 @@ class LTXEnhanceAttnProcessor2_0:
13
  raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")
14
 
15
  def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
16
- """Calculate enhancement scores for the attention mechanism"""
17
- head_dim = inner_dim // num_heads
18
-
19
- if text_seq_length is not None:
20
- img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
21
- img_k = key[:, :, :-text_seq_length] if text_seq_length > 0 else key
22
- else:
23
- img_q, img_k = query, key
24
-
25
- batch_size, num_heads, ST, head_dim = img_q.shape
26
- # Calculate spatial dimension by dividing total tokens by number of frames
27
- spatial_dim = ST // num_frames
28
- # Ensure spatial_dim is calculated correctly
29
- if spatial_dim * num_frames != ST:
30
- # If we can't divide evenly, we'll need to pad or reshape
31
- spatial_dim = max(1, ST // num_frames)
32
- # Adjust ST to be evenly divisible
33
- ST = spatial_dim * num_frames
34
-
35
- # Ensure tensors have the right shape before rearranging
36
- img_q = img_q[:, :, :ST, :]
37
- img_k = img_k[:, :, :ST, :]
38
-
39
- try:
40
- query_image = rearrange(
41
- img_q, "B N (T S) C -> (B S) N T C",
42
- T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
43
- )
44
- key_image = rearrange(
45
- img_k, "B N (T S) C -> (B S) N T C",
46
- T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
47
- )
48
- except Exception as e:
49
- # If rearrangement fails, return a default enhancement score
50
- return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device)
51
-
52
- scale = head_dim**-0.5
53
- query_image = query_image * scale
54
- attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
55
- attn_temp = attn_temp.to(torch.float32)
56
- attn_temp = attn_temp.softmax(dim=-1)
57
-
58
- # Reshape to [batch_size * num_tokens, num_frames, num_frames]
59
- attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
60
-
61
- # Create a mask for diagonal elements
62
- diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
63
- diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
64
 
65
- # Zero out diagonal elements
66
- attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
 
67
 
68
- # Calculate mean for each token's attention matrix
69
- # Number of off-diagonal elements per matrix is n*n - n
70
- num_off_diag = num_frames * num_frames - num_frames
71
- mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
72
 
73
- enhance_scores = mean_scores.mean() * (num_frames + 4.0)
74
- enhance_scores = enhance_scores.clamp(min=1)
75
- return enhance_scores
 
 
 
 
 
76
 
77
  def __call__(
78
  self,
@@ -91,19 +91,20 @@ class LTXEnhanceAttnProcessor2_0:
91
  inner_dim = attn.to_q.out_features
92
  num_heads = attn.heads
93
  head_dim = inner_dim // num_heads
94
-
95
  query = attn.to_q(hidden_states)
96
  key = attn.to_k(encoder_hidden_states)
97
  value = attn.to_v(encoder_hidden_states)
98
-
 
99
  query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
100
  key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
101
  value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
102
-
103
  if attn.upcast_attention:
104
  query = query.float()
105
  key = key.float()
106
-
107
  enhance_scores = None
108
  if is_enhance_enabled():
109
  try:
@@ -116,25 +117,24 @@ class LTXEnhanceAttnProcessor2_0:
116
  )
117
  except ValueError as e:
118
  print(f"Warning: Could not calculate enhance scores: {e}")
119
- # Continue without enhancement if calculation fails
120
-
121
  hidden_states = torch.nn.functional.scaled_dot_product_attention(
122
  query, key, value,
123
  attn_mask=attention_mask,
124
  dropout_p=0.0,
125
  is_causal=False
126
  )
127
-
128
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
129
  hidden_states = hidden_states.to(query.dtype)
130
-
131
  # Apply enhancement if enabled
132
  if is_enhance_enabled() and enhance_scores is not None:
133
  hidden_states = hidden_states * enhance_scores
134
-
135
  hidden_states = attn.to_out[0](hidden_states)
136
  hidden_states = attn.to_out[1](hidden_states)
137
-
138
  return hidden_states
139
 
140
  def inject_enhance_for_ltx(model: nn.Module) -> None:
 
13
  raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")
14
 
15
  def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
16
+ """Calculate enhancement scores for the attention mechanism"""
17
+ head_dim = inner_dim // num_heads
18
+
19
+ if text_seq_length is not None:
20
+ img_q = query[:, :, :-text_seq_length] if text_seq_length > 0 else query
21
+ img_k = key[:, :, :-text_seq_length] if text_seq_length > 0 else key
22
+ else:
23
+ img_q, img_k = query, key
24
+
25
+ batch_size, num_heads, ST, head_dim = img_q.shape
26
+ # Calculate spatial dimension by dividing total tokens by number of frames
27
+ spatial_dim = ST // num_frames
28
+ # Ensure spatial_dim is calculated correctly
29
+ if spatial_dim * num_frames != ST:
30
+ # If we can't divide evenly, we'll need to pad or reshape
31
+ spatial_dim = max(1, ST // num_frames)
32
+ # Adjust ST to be evenly divisible
33
+ ST = spatial_dim * num_frames
34
+
35
+ # Ensure tensors have the right shape before rearranging
36
+ img_q = img_q[:, :, :ST, :]
37
+ img_k = img_k[:, :, :ST, :]
38
+
39
+ try:
40
+ query_image = rearrange(
41
+ img_q, "B N (T S) C -> (B S) N T C",
42
+ T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
43
+ )
44
+ key_image = rearrange(
45
+ img_k, "B N (T S) C -> (B S) N T C",
46
+ T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
47
+ )
48
+ except Exception as e:
49
+ # If rearrangement fails, return a default enhancement score
50
+ return torch.ones(img_q.shape[0], 1, 1, 1, device=img_q.device)
51
+
52
+ scale = head_dim**-0.5
53
+ query_image = query_image * scale
54
+ attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32
55
+ attn_temp = attn_temp.to(torch.float32)
56
+ attn_temp = attn_temp.softmax(dim=-1)
57
+
58
+ # Reshape to [batch_size * num_tokens, num_frames, num_frames]
59
+ attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
 
 
 
 
60
 
61
+ # Create a mask for diagonal elements
62
+ diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
63
+ diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
64
 
65
+ # Zero out diagonal elements
66
+ attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
 
 
67
 
68
+ # Calculate mean for each token's attention matrix
69
+ # Number of off-diagonal elements per matrix is n*n - n
70
+ num_off_diag = num_frames * num_frames - num_frames
71
+ mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
72
+
73
+ enhance_scores = mean_scores.mean() * (num_frames + 4.0)
74
+ enhance_scores = enhance_scores.clamp(min=1)
75
+ return enhance_scores
76
 
77
  def __call__(
78
  self,
 
91
  inner_dim = attn.to_q.out_features
92
  num_heads = attn.heads
93
  head_dim = inner_dim // num_heads
94
+
95
  query = attn.to_q(hidden_states)
96
  key = attn.to_k(encoder_hidden_states)
97
  value = attn.to_v(encoder_hidden_states)
98
+
99
+ # Reshape query, key, value to match expected dimensions
100
  query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
101
  key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
102
  value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
103
+
104
  if attn.upcast_attention:
105
  query = query.float()
106
  key = key.float()
107
+
108
  enhance_scores = None
109
  if is_enhance_enabled():
110
  try:
 
117
  )
118
  except ValueError as e:
119
  print(f"Warning: Could not calculate enhance scores: {e}")
120
+
 
121
  hidden_states = torch.nn.functional.scaled_dot_product_attention(
122
  query, key, value,
123
  attn_mask=attention_mask,
124
  dropout_p=0.0,
125
  is_causal=False
126
  )
127
+
128
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
129
  hidden_states = hidden_states.to(query.dtype)
130
+
131
  # Apply enhancement if enabled
132
  if is_enhance_enabled() and enhance_scores is not None:
133
  hidden_states = hidden_states * enhance_scores
134
+
135
  hidden_states = attn.to_out[0](hidden_states)
136
  hidden_states = attn.to_out[1](hidden_states)
137
+
138
  return hidden_states
139
 
140
  def inject_enhance_for_ltx(model: nn.Module) -> None: