import torch
import torch.nn as nn
from einops import rearrange
from diffusers.models.attention import Attention

from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames


class LTXEnhanceAttnProcessor2_0:
    """Attention processor for LTX that implements enhance-a-video functionality."""

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")

    def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
        """Calculate the cross-frame enhancement score for the attention mechanism."""
        head_dim = inner_dim // num_heads

        # Keep only the image tokens; text tokens (if any) are appended at the
        # end of the sequence and are excluded from the temporal statistics.
        if text_seq_length is not None and text_seq_length > 0:
            img_q = query[:, :, :-text_seq_length]
            img_k = key[:, :, :-text_seq_length]
        else:
            img_q, img_k = query, key

        batch_size, num_heads, ST, head_dim = img_q.shape

        # Spatial dimension = total image tokens divided by the number of frames.
        spatial_dim = ST // num_frames
        if spatial_dim * num_frames != ST:
            # The token count does not divide evenly; truncate to the largest
            # evenly divisible length before rearranging.
            spatial_dim = max(1, ST // num_frames)
            ST = spatial_dim * num_frames
            img_q = img_q[:, :, :ST, :]
            img_k = img_k[:, :, :ST, :]

        try:
            # Fold the spatial axis into the batch so attention is computed
            # purely across frames: (B, N, T*S, C) -> (B*S, N, T, C).
            query_image = rearrange(
                img_q, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
            )
            key_image = rearrange(
                img_k, "B N (T S) C -> (B S) N T C", T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
            )
        except Exception:
            # If the rearrangement fails, fall back to a neutral scalar score:
            # multiplying by 1 leaves the hidden states unchanged.
            return torch.ones((), device=img_q.device)

        scale = head_dim**-0.5
        query_image = query_image * scale
        attn_temp = query_image @ key_image.transpose(-2, -1)

        # Upcast to float32 for a numerically stable softmax.
        attn_temp = attn_temp.to(torch.float32)
        attn_temp = attn_temp.softmax(dim=-1)

        # Collapse batch, spatial, and head axes: (B*S*N, num_frames, num_frames).
        attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

        # Mask out the diagonal (each frame's attention to itself).
        diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
        diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
        attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

        # Mean over the off-diagonal entries; each T x T matrix has T*T - T of them.
        num_off_diag = num_frames * num_frames - num_frames
        mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

        enhance_scores = mean_scores.mean() * (num_frames + 4.0)
        enhance_scores = enhance_scores.clamp(min=1)
        return enhance_scores

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask=None,
        **kwargs,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape
        text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        inner_dim = attn.to_q.out_features
        num_heads = attn.heads
        head_dim = inner_dim // num_heads

        # Query, key, value projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Reshape to (batch, heads, tokens, head_dim).
        query = query.view(batch_size, sequence_length, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        original_dtype = query.dtype
        if attn.upcast_attention:
            # scaled_dot_product_attention requires matching dtypes, so upcast
            # all three projections and cast back after attention.
            query = query.float()
            key = key.float()
            value = value.float()

        # Compute the enhance score before attention, while query/key still
        # have shape (batch, heads, tokens, head_dim).
        enhance_scores = None
        if is_enhance_enabled():
            try:
                enhance_scores = self._get_enhance_scores(
                    query, key, inner_dim, num_heads, get_num_frames(), text_seq_length
                )
            except ValueError as e:
                print(f"Warning: Could not calculate enhance scores: {e}")

        # Broadcast a (batch, key_len) padding mask over heads and query positions.
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
            attention_mask = attention_mask.expand(-1, num_heads, -1, -1)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Back to (batch, tokens, inner_dim) in the input dtype.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
        hidden_states = hidden_states.to(original_dtype)

        # Apply the enhancement score if enabled.
        if is_enhance_enabled() and enhance_scores is not None:
            hidden_states = hidden_states * enhance_scores

        # Output projection (linear, then dropout).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


def inject_enhance_for_ltx(model: nn.Module) -> None:
    """
    Inject enhance score for LTX model.
    1. Register a pre-forward hook that records the number of frames.
    2. Replace every attention processor with the enhance processor.
    """
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)

    for _, module in model.named_modules():
        if isinstance(module, Attention):
            module.set_processor(LTXEnhanceAttnProcessor2_0())


def num_frames_hook(module, args, kwargs):
    """Pre-forward hook that records the number of frames for the processors.

    Assumes the frame axis of ``hidden_states`` is dim 2 (e.g. a
    ``[B, C, F, H, W]`` layout); adjust the index if your model packs
    tokens differently.
    """
    hidden_states = kwargs["hidden_states"] if "hidden_states" in kwargs else args[0]
    set_num_frames(hidden_states.shape[2])
    return args, kwargs
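

# ---------------------------------------------------------------------------
# Usage sketch -- a minimal, hedged example of wiring the processor into a
# diffusers LTX pipeline. Assumptions not established by this module: that a
# "Lightricks/LTX-Video" checkpoint is available, and that your copy of
# enhance_a_video exposes an `enable_enhance()` toggle in
# `enhance_a_video.globals` (the name is illustrative; use whatever switch
# makes `is_enhance_enabled()` return True in your version).
#
#   import torch
#   from diffusers import LTXPipeline
#   from enhance_a_video.globals import enable_enhance  # hypothetical toggle
#
#   pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
#   pipe.to("cuda")
#
#   enable_enhance()                         # make is_enhance_enabled() True
#   inject_enhance_for_ltx(pipe.transformer)
#
#   video = pipe(prompt="A sailboat gliding over calm water", num_frames=65).frames[0]
# ---------------------------------------------------------------------------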