|
import math
|
|
from typing import Any
|
|
from einops import rearrange
|
|
import torch
|
|
from diffusers.models.attention_processor import Attention
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lower bound applied to each block's softmax row sum in
# FlashAttentionFunction.forward (``.clamp(min=EPSILON)``), so the running row
# sum used as a divisor stays non-zero when every key of a block is masked out.
EPSILON = 1e-6
|
|
|
|
|
|
class FlashAttentionFunction(torch.autograd.function.Function):
    """Chunked ("flash") scaled dot-product attention as a custom autograd Function.

    ``q`` is processed in chunks of ``q_bucket_size`` rows and ``k``/``v`` in
    chunks of ``k_bucket_size`` rows along dim ``-2``, so the full score matrix
    is never materialized. The forward pass keeps a running row maximum and row
    sum per query (online softmax, Algorithm 2 of the FlashAttention paper); the
    backward pass recomputes the probabilities chunk by chunk (Algorithm 4).

    Call as ``FlashAttentionFunction.apply(q, k, v, mask, causal,
    q_bucket_size, k_bucket_size)``. ``q``/``k``/``v`` are contracted over dim
    ``-1``; leading dims are treated as batch dims. ``mask`` is optional: a
    boolean key mask ``(b, n_k)`` (``True`` keeps a key) that is broadcast
    against 4-D scores ``(b, h, n_q, n_k)``.
    """

    @staticmethod
    def _key_mask_chunks(mask, num_keys, k_bucket_size):
        """Tile an optional key mask so it lines up with the key/value chunks.

        Args:
            mask: ``None`` or a 2-D boolean tensor ``(b, n_k)``; ``True`` keeps a key.
            num_keys: ``n_k``, the length of the key sequence (``k.shape[-2]``).
            k_bucket_size: chunk size used to split keys/values along dim ``-2``.

        Returns:
            A tuple with one entry per key chunk: all ``None`` when there is no
            mask, otherwise ``(b, 1, 1, chunk)`` slices that broadcast over the
            head and query axes of a ``(b, h, i, chunk)`` score block.

        Raises:
            ValueError: if ``mask`` is not of shape ``(b, num_keys)``.
        """
        if mask is None:
            return (None,) * math.ceil(num_keys / k_bucket_size)
        if mask.ndim != 2 or mask.shape[-1] != num_keys:
            raise ValueError(
                f"mask must have shape (batch, {num_keys}), got {tuple(mask.shape)}"
            )
        return mask[:, None, None, :].split(k_bucket_size, dim=-1)

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper: tiled forward pass with an online softmax.

        Args:
            ctx: autograd context; stores what ``backward`` needs.
            q: queries ``(..., n_q, d)``.
            k: keys ``(..., n_k, d)``.
            v: values ``(..., n_k, d)``; the output is allocated like ``q``.
            mask: ``None`` or a boolean ``(b, n_k)`` key mask; ``False`` drops a key.
            causal: if true, query ``i`` only attends keys
                ``j <= i + max(n_k - n_q, 0)`` (last query aligned with last key).
            q_bucket_size: query rows per chunk.
            k_bucket_size: key/value rows per chunk.

        Returns:
            The attention output, same shape and dtype as ``q``.

        Raises:
            ValueError: if ``mask`` is given with a shape other than ``(b, n_k)``.
        """
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(dtype).max
        # Offset that right-aligns the last query with the last key (causal only).
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        # Running softmax statistics per query row: sum of exponentials and max logit.
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        # The mask selects *keys*, so it is tiled with the key chunks
        # (k_bucket_size) and paired with them, never with the query chunks.
        key_masks = FlashAttentionFunction._key_mask_chunks(
            mask, k.shape[-2], k_bucket_size
        )

        # o / row_sums / row_maxes chunks are views, so the in-place updates
        # below write straight into the full tensors.
        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_sums, row_maxes) in enumerate(row_splits):
            # Position of this chunk's first query on the key axis.
            q_start_index = ind * q_bucket_size + qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                key_masks,
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if col_mask is not None:
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                # Only build a causal mask when some key in this chunk lies after
                # the first query of the row chunk.
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdim=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if col_mask is not None:
                    exp_weights.masked_fill_(~col_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdim=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                # Rescale old and new contributions to the common new maximum.
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, key_masks, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper: recompute probabilities chunk by chunk.

        Args:
            ctx: context populated by ``forward``.
            do: gradient of the loss w.r.t. the forward output (shape of ``q``).

        Returns:
            ``(dq, dk, dv, None, None, None, None)``: gradients for ``q``, ``k``,
            ``v`` and ``None`` for the non-tensor / non-differentiable inputs.
        """
        causal, scale, key_masks, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, lc, mc, dqc) in enumerate(row_splits):
            # Must match the alignment used in forward.
            q_start_index = ind * q_bucket_size + qk_len_diff

            # D_i = rowsum(dO_i * O_i) depends only on the query chunk.
            D = (doc * oc).sum(dim=-1, keepdim=True)

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
                key_masks,
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Recompute softmax probabilities from the saved row max / row sum.
                exp_attn_weights = torch.exp(attn_weights - mc)

                if col_mask is not None:
                    exp_attn_weights.masked_fill_(~col_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                # Softmax backward; `scale` folds in d(scores)/d(q k^T).
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None
|
|
|
|
|
|
class FlashAttnProcessor:
    """Attention processor that computes attention with ``FlashAttentionFunction``.

    Queries are projected from ``hidden_states``. Keys and values are projected
    from ``encoder_hidden_states``, falling back to ``hidden_states`` when that is
    ``None``. If ``attn`` carries a non-``None`` ``hypernetwork`` attribute, its
    ``forward(hidden_states, encoder_hidden_states)`` supplies the key/value
    contexts instead.

    Args:
        q_bucket_size: query rows per chunk handed to ``FlashAttentionFunction``
            (default 512, the previously hard-coded value).
        k_bucket_size: key/value rows per chunk handed to ``FlashAttentionFunction``
            (default 1024, the previously hard-coded value).

    Raises:
        ValueError: if either bucket size is smaller than 1.
    """

    # Class-level defaults, also used if an instance was created without __init__.
    q_bucket_size = 512
    k_bucket_size = 1024

    def __init__(self, q_bucket_size: int = 512, k_bucket_size: int = 1024):
        if q_bucket_size < 1 or k_bucket_size < 1:
            raise ValueError(
                f"bucket sizes must be >= 1, got q_bucket_size={q_bucket_size}, "
                f"k_bucket_size={k_bucket_size}"
            )
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        """Run attention for ``attn`` and return the output-projected result.

        Args:
            attn: the attention module providing ``heads``, ``to_q``, ``to_k``,
                ``to_v`` and ``to_out``.
            hidden_states: input that queries are projected from; its dtype is the
                dtype all key/value contexts are cast to.
            encoder_hidden_states: optional cross-attention context; ``None`` means
                self-attention.
            attention_mask: forwarded unchanged as ``FlashAttentionFunction``'s
                ``mask``. NOTE(review): that function treats the mask as a 2-D
                ``(batch, n_keys)`` boolean key mask; confirm callers never pass
                an additive or 3-D mask here.

        Returns:
            The attention output after ``attn.to_out[0]`` and ``attn.to_out[1]``.
        """
        h = attn.heads
        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        if getattr(attn, "hypernetwork", None) is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        # Drop references early so the inputs can be freed before attention runs.
        del encoder_hidden_states, hidden_states

        # Split the fused head dimension: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, self.q_bucket_size, self.k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
|
|
|