import torch
import torch.nn as nn
from einops import rearrange
from diffusers.models.attention import Attention

from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames


class LTXEnhanceAttnProcessor2_0:
    """Attention processor for LTX that implements enhance-a-video functionality."""

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")

    def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
        """Calculate enhancement scores for the attention mechanism."""
        head_dim = inner_dim // num_heads
        orig_dtype = query.dtype

        # Keep only the video tokens; text tokens sit at the end of the sequence.
        if text_seq_length is not None and text_seq_length > 0:
            img_q = query[:, :, :-text_seq_length]
            img_k = key[:, :, :-text_seq_length]
        else:
            img_q, img_k = query, key

        # Truncate trailing tokens that do not fit an exact (frames x spatial) grid.
        ST = img_q.shape[2]
        spatial_dim = max(1, ST // num_frames)
        ST = spatial_dim * num_frames
        img_q = img_q[:, :, :ST, :]
        img_k = img_k[:, :, :ST, :]

        # Regroup tokens so attention runs across frames: one (T x T) map per
        # spatial location and head.
        try:
            query_image = rearrange(
                img_q, "B N (T S) C -> (B S) N T C",
                T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
            )
            key_image = rearrange(
                img_k, "B N (T S) C -> (B S) N T C",
                T=num_frames, S=spatial_dim, N=num_heads, C=head_dim
            )
        except Exception:
            # Fall back to a neutral scalar so the later multiplication is a
            # no-op; a (B, 1, 1, 1) tensor would not broadcast against the
            # (B, seq, dim) hidden states.
            return torch.ones((), device=img_q.device, dtype=orig_dtype)

        scale = head_dim**-0.5
        query_image = query_image * scale

        # Compute the frame-to-frame attention map in float32 for numerical
        # stability, regardless of any surrounding autocast context.
        with torch.cuda.amp.autocast(enabled=False):
            query_image = query_image.float()
            key_image = key_image.float()
            attn_temp = query_image @ key_image.transpose(-2, -1)
            attn_temp = attn_temp.softmax(dim=-1)

        # Zero the diagonal: a frame's attention to itself carries no
        # cross-frame information.
        attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
        diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
        diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
        attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

        # Average the off-diagonal entries, then scale and clamp so the score
        # never attenuates the output.
        num_off_diag = num_frames * num_frames - num_frames
        mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

        enhance_scores = mean_scores.mean() * (num_frames + 4.0)
        enhance_scores = enhance_scores.clamp(min=1)

        return enhance_scores.to(orig_dtype)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask=None,
        **kwargs,
    ) -> torch.Tensor:
        if hidden_states.ndim == 4:
            # Pre-split heads: merge (B, seq, heads, head_dim) back to
            # (B, seq, inner_dim) so the projections below see token-major layout.
            hidden_states = hidden_states.flatten(2, 3)
        batch_size, sequence_length, _ = hidden_states.shape
        orig_dtype = hidden_states.dtype

        text_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        inner_dim = attn.to_q.out_features
        num_heads = attn.heads
        head_dim = inner_dim // num_heads

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Split the projections into heads:
        # (B, seq, inner_dim) -> (B, heads, seq, head_dim).
        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        if attn.upcast_attention:
            # Upcast value as well: scaled_dot_product_attention requires
            # query, key, and value to share a dtype.
            query = query.float()
            key = key.float()
            value = value.float()

        enhance_scores = None
        if is_enhance_enabled():
            try:
                enhance_scores = self._get_enhance_scores(
                    query, key,
                    inner_dim,
                    num_heads,
                    get_num_frames(),
                    text_seq_length,
                )
            except ValueError as e:
                print(f"Warning: Could not calculate enhance scores: {e}")

        if attention_mask is not None:
            # Broadcast the (B, key_len) mask to (B, heads, 1, key_len).
            attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
            attention_mask = attention_mask.expand(-1, num_heads, -1, -1)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        # Merge heads back: (B, heads, seq, head_dim) -> (B, seq, inner_dim),
        # then restore the input dtype in case attention was upcast.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
        hidden_states = hidden_states.to(orig_dtype)

        # Apply the scalar enhance score before the output projection.
        if is_enhance_enabled() and enhance_scores is not None:
            hidden_states = hidden_states * enhance_scores

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        return hidden_states
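
# Shape walk-through (illustrative sizes, not from the source): with 8 latent
# frames and 16 spatial tokens per frame, `query` reaches _get_enhance_scores
# as (B, heads, 128, head_dim); it is regrouped to (B*16, heads, 8, head_dim),
# an 8x8 frame-to-frame attention map is built per spatial location, its
# off-diagonal mass is averaged, and a single scalar >= 1 comes back to
# rescale the attention output.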


def inject_enhance_for_ltx(model: nn.Module) -> None:
    """
    Inject enhance score for LTX model.
    1. Register a forward pre-hook that records the number of frames.
    2. Replace each attention processor with the enhance processor.
    """
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)

    for _, module in model.named_modules():
        if isinstance(module, Attention):
            module.set_processor(LTXEnhanceAttnProcessor2_0())


def num_frames_hook(module, args, kwargs):
    """Hook to update the number of frames automatically."""
    if "hidden_states" in kwargs:
        hidden_states = kwargs["hidden_states"]
    else:
        hidden_states = args[0]
    # Assumes video latents laid out as (B, C, F, H, W), so dim 2 is the frame axis.
    num_frames = hidden_states.shape[2]
    set_num_frames(num_frames)
    return args, kwargs
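

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module proper). It assumes a diffusers LTX
# pipeline exposing its transformer as `pipe.transformer`; turning the enhance
# branch on is left to whatever toggle `enhance_a_video.globals` provides in
# your checkout:
#
#     from diffusers import LTXPipeline
#
#     pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
#     inject_enhance_for_ltx(pipe.transformer)
#     video = pipe(prompt="...").frames[0]
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Minimal smoke test with made-up sizes: 8 frames x 16 spatial tokens,
    # 4 heads of width 16. The bare Attention block stands in for a real LTX
    # layer; whether the enhance branch actually fires depends on the
    # library's global enable flag.
    set_num_frames(8)
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(LTXEnhanceAttnProcessor2_0())
    tokens = torch.randn(2, 8 * 16, 64)
    out = attn(tokens)
    print(out.shape)  # expected: torch.Size([2, 128, 64])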