# LTX-Video-0.9.1-HFIE / enhance.py
# Last change: "Update enhance.py" by jbilcke-hf (HF staff), commit 60a27e7 (verified).
# File size: 6.38 kB.
import torch
import torch.nn as nn
from einops import rearrange
from diffusers.models.attention import Attention
from enhance_a_video.enhance import enhance_score
from enhance_a_video.globals import get_num_frames, is_enhance_enabled, set_num_frames
class LTXEnhanceAttnProcessor2_0:
    """Attention processor for LTX that implements enhance-a-video functionality.

    The processor runs scaled-dot-product attention and, when
    ``is_enhance_enabled()`` is true, multiplies the attention output by a
    scalar score derived from how strongly each frame attends to the *other*
    frames (the off-diagonal part of a frame-by-frame attention map).
    """

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("LTXEnhanceAttnProcessor2_0 requires PyTorch 2.0.")

    def _get_enhance_scores(self, query, key, inner_dim, num_heads, num_frames, text_seq_length=None):
        """Calculate the enhancement score for the attention mechanism.

        Args:
            query: Tensor laid out as ``[batch, heads, tokens, head_dim]``.
            key: Tensor with the same layout as ``query``.
            inner_dim: Unused; kept for call compatibility (head count and
                head size are read from ``query.shape``).
            num_heads: Unused; kept for call compatibility.
            num_frames: Number of frames ``T`` the token axis is split into.
            text_seq_length: Number of trailing tokens to drop from ``query``
                and ``key`` before scoring; ``None`` or ``0`` drops nothing.

        Returns:
            A 0-dim tensor in ``query``'s dtype, always ``>= 1``. A neutral
            score of exactly ``1`` is returned whenever the frame-by-frame
            map cannot be built (fewer than two frames, fewer tokens than
            frames, or query/key shapes that differ after token stripping).

        Raises:
            ValueError: If ``query`` or ``key`` is not 4-dimensional.
        """
        orig_dtype = query.dtype  # Store original dtype

        if query.ndim != 4 or key.ndim != 4:
            raise ValueError(
                "expected query/key shaped [batch, heads, tokens, head_dim], "
                f"got {tuple(query.shape)} and {tuple(key.shape)}"
            )

        # A 0-dim one broadcasts against an output of any rank, so multiplying
        # by it is a true no-op (a (B, 1, 1, 1) tensor would add a dimension
        # to a 3-D hidden state).
        neutral = torch.ones((), device=query.device, dtype=orig_dtype)

        # With a single frame there are no off-diagonal entries to average
        # (division by zero -> NaN); zero/None frames cannot be split at all.
        if not num_frames or num_frames < 2:
            return neutral
        num_frames = int(num_frames)

        if text_seq_length:
            img_q = query[:, :, :-text_seq_length]
            img_k = key[:, :, :-text_seq_length]
        else:
            img_q, img_k = query, key

        # The frame-by-frame map needs matching query/key token grids; when
        # they differ (e.g. the key comes from another sequence) skip scoring.
        if img_q.shape != img_k.shape:
            return neutral

        batch_size, heads, num_tokens, head_dim = img_q.shape
        spatial_dim = num_tokens // num_frames
        if spatial_dim == 0:
            # Fewer tokens than frames: nothing to reshape into (T, S).
            return neutral

        # Drop trailing tokens that do not fill a complete (T, S) grid.
        usable_tokens = spatial_dim * num_frames
        img_q = img_q[:, :, :usable_tokens, :]
        img_k = img_k[:, :, :usable_tokens, :]

        def to_frames(tokens):
            # "B N (T S) C -> (B S) N T C": one T-long sequence per spatial site.
            grid = tokens.reshape(batch_size, heads, num_frames, spatial_dim, head_dim)
            grid = grid.permute(0, 3, 1, 2, 4)
            return grid.reshape(batch_size * spatial_dim, heads, num_frames, head_dim)

        # Compute attention in float32 for stability
        with torch.autocast("cuda", enabled=False):
            query_image = to_frames(img_q).float() * head_dim**-0.5
            key_image = to_frames(img_k).float()
            attn_temp = query_image @ key_image.transpose(-2, -1)
            attn_temp = attn_temp.softmax(dim=-1)

        attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

        # Zero the diagonal so only attention to *other* frames is averaged.
        diag_mask = torch.eye(num_frames, device=attn_temp.device, dtype=torch.bool)
        attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

        num_off_diag = num_frames * num_frames - num_frames
        mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

        # NOTE(review): 4.0 is a fixed weight here — presumably the enhance
        # weight; confirm whether it should be configurable.
        enhance_scores = mean_scores.mean() * (num_frames + 4.0)
        enhance_scores = enhance_scores.clamp(min=1)

        # Convert back to original dtype
        return enhance_scores.to(orig_dtype)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask=None,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for ``attn`` and apply the enhance score if enabled.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``,
                ``to_out``, ``heads`` and ``upcast_attention``.
            hidden_states: ``[batch, sequence, channels]`` input, or
                ``[batch, sequence, num_heads, head_dim]`` (flattened to the
                3-D form before projection).
            encoder_hidden_states: Source of keys/values; when ``None`` the
                layer attends over ``hidden_states`` itself.
            attention_mask: Optional mask whose last dimension is the key
                length; it is viewed as ``[batch, 1, 1, key_len]`` and
                expanded over heads.
            **kwargs: Accepted and ignored.
                NOTE(review): anything the caller passes here (for example
                positional embeddings) has no effect — confirm intended.

        Returns:
            The attention output after both ``attn.to_out`` layers, in the
            dtype ``hidden_states`` had on entry.
        """
        # Remember the incoming dtype: q/k/v may be upcast below and the
        # result must be cast back before the output projections.
        orig_dtype = hidden_states.dtype

        # The shape could be [batch_size, sequence_length, channels] or
        # [batch_size, sequence_length, num_heads, head_dim].
        if hidden_states.ndim == 4:
            batch_size, sequence_length = hidden_states.shape[:2]
            # Assumes heads * head_dim matches what attn.to_q expects as
            # input features — TODO confirm with callers that pass 4-D input.
            hidden_states = hidden_states.reshape(batch_size, sequence_length, -1)
        else:
            batch_size, _, _ = hidden_states.shape

        if encoder_hidden_states is None:
            text_seq_length = 0
            encoder_hidden_states = hidden_states
        else:
            text_seq_length = encoder_hidden_states.shape[1]

        inner_dim = attn.to_q.out_features
        num_heads = attn.heads
        head_dim = inner_dim // num_heads

        def split_heads(projected):
            # [B, S, H * D] -> [B, H, S, D], the layout expected by SDPA, by
            # the score computation and by the transpose/reshape further down.
            return projected.reshape(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query = split_heads(attn.to_q(hidden_states))
        key = split_heads(attn.to_k(encoder_hidden_states))
        value = split_heads(attn.to_v(encoder_hidden_states))

        if attn.upcast_attention:
            # SDPA requires q, k and v to share one dtype, so upcast all three.
            query = query.float()
            key = key.float()
            value = value.float()

        enhance_scores = None
        if is_enhance_enabled():
            try:
                enhance_scores = self._get_enhance_scores(
                    query,
                    key,
                    inner_dim,
                    num_heads,
                    get_num_frames(),
                    text_seq_length,
                )
            except ValueError as e:
                # Best effort: attention still runs, just without enhancement.
                print(f"Warning: Could not calculate enhance scores: {e}")

        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
            attention_mask = attention_mask.expand(-1, num_heads, -1, -1)
            if attention_mask.dtype != torch.bool:
                # A float mask must have the same dtype as the query.
                attention_mask = attention_mask.to(query.dtype)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        # [B, H, S, D] -> [B, S, H * D]
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
        hidden_states = hidden_states.to(orig_dtype)  # Ensure we're back to original dtype

        if enhance_scores is not None:
            # Cast the score so the product keeps the dtype to_out expects.
            hidden_states = hidden_states * enhance_scores.to(hidden_states.dtype)

        # Apply output projections while maintaining dtype
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
def inject_enhance_for_ltx(model: nn.Module) -> None:
    """
    Wire enhance-a-video support into an LTX model, in place.

    Two things happen:
    1. ``num_frames_hook`` is registered as a forward pre-hook (with kwargs)
       on ``model`` so the frame count is refreshed on every forward call.
    2. Every ``Attention`` submodule gets its own fresh
       ``LTXEnhanceAttnProcessor2_0`` instance as processor.
    """
    # Step 1: keep the global frame count in sync with each forward call.
    model.register_forward_pre_hook(num_frames_hook, with_kwargs=True)

    # Step 2: swap the processor on each attention layer found in the tree.
    for submodule in model.modules():
        if not isinstance(submodule, Attention):
            continue
        submodule.set_processor(LTXEnhanceAttnProcessor2_0())
def num_frames_hook(module, args, kwargs):
    """Forward pre-hook that records the frame count before each forward call.

    The hidden states are taken from the ``hidden_states`` keyword argument
    when present, otherwise from the first positional argument. The size of
    their dimension 2 is stored via ``set_num_frames``.
    NOTE(review): this assumes dimension 2 is the frame axis — confirm
    against the layout the model actually receives.

    Returns:
        ``(args, kwargs)`` unchanged, as a kwargs-aware pre-hook may return.
    """
    latents = kwargs["hidden_states"] if "hidden_states" in kwargs else args[0]
    set_num_frames(latents.shape[2])
    return args, kwargs