Spaces:
ifire
/
Runtime error

ifire's picture
Format code and change app.py.
a6bbecf
from typing import *
import torch
from .. import SparseTensor
from .. import DEBUG, ATTN
# Import only the attention backend selected by the package-level ATTN
# setting. An unrecognised value fails at import time rather than on the
# first attention call.
if ATTN == "xformers":
    import xformers.ops as xops
elif ATTN == "flash_attn":
    import flash_attn
else:
    raise ValueError(f"Unknown attention module: {ATTN}")

# Public API of this module.
__all__ = [
    "sparse_scaled_dot_product_attention",
]
@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Packed form: queries, keys, and values all come from one sparse tensor.

    Args:
        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.

    Returns:
        SparseTensor: The attention result.
    """
    ...
@overload
def sparse_scaled_dot_product_attention(
    q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]
) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Sparse queries with packed keys/values that may be sparse or dense.

    Args:
        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.

    Returns:
        SparseTensor: The attention result.
    """
    ...
@overload
def sparse_scaled_dot_product_attention(
    q: torch.Tensor, kv: SparseTensor
) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Dense queries with packed sparse keys/values.

    Args:
        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
        kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.

    Returns:
        torch.Tensor: The attention result as a dense tensor.
    """
    ...
@overload
def sparse_scaled_dot_product_attention(
    q: SparseTensor, k: SparseTensor, v: SparseTensor
) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Fully separate sparse queries, keys, and values.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Returns:
        SparseTensor: The attention result.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...
@overload
def sparse_scaled_dot_product_attention(
    q: SparseTensor, k: torch.Tensor, v: torch.Tensor
) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Sparse queries with separate dense keys and values.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.

    Returns:
        SparseTensor: The attention result.
    """
    ...
@overload
def sparse_scaled_dot_product_attention(
    q: torch.Tensor, k: SparseTensor, v: SparseTensor
) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Dense queries with separate sparse keys and values.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Returns:
        torch.Tensor: The attention result as a dense tensor.
    """
    ...
def sparse_scaled_dot_product_attention(*args, **kwargs):
    """
    Apply variable-length scaled dot product attention where at least one
    operand is a SparseTensor.

    The call form is selected by the total number of arguments, which may be
    passed positionally, by keyword, or mixed:

        (qkv)      packed sparse tensor, feats [T, 3, H, C]
        (q, kv)    queries plus packed keys/values
        (q, k, v)  separate queries, keys, and values

    Sparse operands contribute one sequence per batch entry, with the length
    taken from their ``layout`` slices; dense operands contribute ``N``
    sequences of length ``L``. All operands are flattened to ``[T, ...]`` and
    handed to the backend chosen by ``ATTN``.

    Returns:
        ``s.replace(out)`` (a SparseTensor) when the query operand is sparse,
        otherwise a dense tensor reshaped to ``[N, L, H, -1]`` using the dense
        query's leading dimensions.

    Raises:
        AssertionError: on a wrong argument count, a missing keyword, an
            unsupported type combination, a bad rank, a batch size mismatch,
            or (with DEBUG enabled) an inconsistent layout.
        ValueError: if ``ATTN`` is neither ``"xformers"`` nor ``"flash_attn"``.
    """
    arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]}
    num_all_args = len(args) + len(kwargs)
    assert (
        num_all_args in arg_names_dict
    ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args) :]:
        assert key in kwargs, f"Missing argument {key}"

    # Resolve every expected name from the positionals first, then keywords.
    bound = {
        name: args[i] if i < len(args) else kwargs[name]
        for i, name in enumerate(arg_names_dict[num_all_args])
    }

    # (N, L, H) of a dense query; only set (and only needed) when the query
    # is dense, to un-flatten the backend output at the end.
    dense_q_shape = None

    if num_all_args == 1:
        qkv = bound["qkv"]
        assert isinstance(
            qkv, SparseTensor
        ), f"qkv must be a SparseTensor, got {type(qkv)}"
        assert (
            len(qkv.shape) == 4 and qkv.shape[1] == 3
        ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device
        s = qkv
        q_seqlen = _sparse_seqlens(qkv)
        kv_seqlen = q_seqlen
        qkv = qkv.feats  # [T, 3, H, C]
    elif num_all_args == 2:
        q = bound["q"]
        kv = bound["kv"]
        assert (
            isinstance(q, SparseTensor)
            and isinstance(kv, (SparseTensor, torch.Tensor))
            or isinstance(q, torch.Tensor)
            and isinstance(kv, SparseTensor)
        ), f"Invalid types, got {type(q)} and {type(kv)}"
        assert (
            q.shape[0] == kv.shape[0]
        ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert (
                len(q.shape) == 3
            ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = _sparse_seqlens(q)
            q = q.feats  # [T_Q, H, C]
        else:
            assert (
                len(q.shape) == 4
            ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            dense_q_shape = (N, L, H)
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)  # [T_Q, H, C]

        if isinstance(kv, SparseTensor):
            assert (
                len(kv.shape) == 4 and kv.shape[1] == 2
            ), f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = _sparse_seqlens(kv)
            kv = kv.feats  # [T_KV, 2, H, C]
        else:
            assert (
                len(kv.shape) == 5
            ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
            kv_n, kv_l, _, kv_h, kv_c = kv.shape
            kv_seqlen = [kv_l] * kv_n
            kv = kv.reshape(kv_n * kv_l, 2, kv_h, kv_c)  # [T_KV, 2, H, C]
    elif num_all_args == 3:
        q = bound["q"]
        k = bound["k"]
        v = bound["v"]
        assert (
            isinstance(q, SparseTensor)
            and isinstance(k, (SparseTensor, torch.Tensor))
            and type(k) == type(v)
            or isinstance(q, torch.Tensor)
            and isinstance(k, SparseTensor)
            and isinstance(v, SparseTensor)
        ), f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
        assert (
            q.shape[0] == k.shape[0] == v.shape[0]
        ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert (
                len(q.shape) == 3
            ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
            s = q
            q_seqlen = _sparse_seqlens(q)
            q = q.feats  # [T_Q, H, Ci]
        else:
            assert (
                len(q.shape) == 4
            ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
            s = None
            N, L, H, CI = q.shape
            dense_q_shape = (N, L, H)
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

        if isinstance(k, SparseTensor):
            assert (
                len(k.shape) == 3
            ), f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
            assert (
                len(v.shape) == 3
            ), f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
            # Sequence lengths come from k only; v is flattened alongside it.
            kv_seqlen = _sparse_seqlens(k)
            k = k.feats  # [T_KV, H, Ci]
            v = v.feats  # [T_KV, H, Co]
        else:
            assert (
                len(k.shape) == 4
            ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
            assert (
                len(v.shape) == 4
            ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
            kv_n, kv_l, kv_h, kv_ci = k.shape
            kv_co = v.shape[-1]
            kv_seqlen = [kv_l] * kv_n
            k = k.reshape(kv_n * kv_l, kv_h, kv_ci)  # [T_KV, H, Ci]
            v = v.reshape(kv_n * kv_l, kv_h, kv_co)  # [T_KV, H, Co]

    if DEBUG:
        if s is not None:
            for i in range(s.shape[0]):
                assert (
                    s.coords[s.layout[i]] == i
                ).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
        # q/k/v are still flattened to [T, H, C] here (the batch dim of 1 is
        # only added later for xformers), so compare the token count alone.
        # Comparing a torch.Size slice with a list would never be equal.
        if num_all_args in [2, 3]:
            assert q.shape[0] == sum(
                q_seqlen
            ), f"SparseScaledDotProductSelfAttention: q shape mismatch"
        if num_all_args == 3:
            assert k.shape[0] == sum(
                kv_seqlen
            ), f"SparseScaledDotProductSelfAttention: k shape mismatch"
            assert v.shape[0] == sum(
                kv_seqlen
            ), f"SparseScaledDotProductSelfAttention: v shape mismatch"

    if ATTN == "xformers":
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        # All sequences are concatenated into a single batch entry; the
        # block-diagonal mask built from the sequence lengths separates them.
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif ATTN == "flash_attn":
        cu_seqlens_q = _cumulative_seqlens(q_seqlen, device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = _cumulative_seqlens(kv_seqlen, device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens_q, max(q_seqlen)
            )
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(
                q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
            )
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(
                q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)
            )
    else:
        raise ValueError(f"Unknown attention module: {ATTN}")

    if s is not None:
        return s.replace(out)
    # Dense query: restore the [N, L, H, C_out] layout of the original q.
    return out.reshape(*dense_q_shape, -1)


# NOTE: the helpers below are deliberately defined after the entry point so
# that the chain of @overload stubs stays directly followed by its
# implementation.
def _sparse_seqlens(sparse):
    """Return the per-batch-entry sequence lengths read from ``sparse.layout``."""
    return [
        sparse.layout[i].stop - sparse.layout[i].start
        for i in range(sparse.shape[0])
    ]


def _cumulative_seqlens(seqlens, device):
    """Return int32 cumulative sequence offsets, starting at 0, on ``device``."""
    return (
        torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seqlens), dim=0)])
        .int()
        .to(device)
    )