Spaces:
ifire
/
Runtime error

ifire's picture
Format code and change app.py.
a6bbecf
from typing import *
from enum import Enum
import torch
import math
from .. import SparseTensor
from .. import DEBUG, ATTN
if ATTN == "xformers":
import xformers.ops as xops
elif ATTN == "flash_attn":
import flash_attn
else:
raise ValueError(f"Unknown attention module: {ATTN}")
# Public API of this module: only the attention entry point is exported.
__all__ = ["sparse_serialized_scaled_dot_product_self_attention"]
class SerializeMode(Enum):
    """
    Serialization order used by ``calc_serialization``.

    ``Z_ORDER`` / ``HILBERT`` select the ``mode`` string passed to
    ``vox2seq.encode``; the ``*_TRANSPOSED`` variants use the same mode with
    ``permute=[1, 0, 2]`` instead of ``[0, 1, 2]``.
    """

    Z_ORDER = 0  # mode="z_order", permute=[0, 1, 2]
    Z_ORDER_TRANSPOSED = 1  # mode="z_order", permute=[1, 0, 2]
    HILBERT = 2  # mode="hilbert", permute=[0, 1, 2]
    HILBERT_TRANSPOSED = 3  # mode="hilbert", permute=[1, 0, 2]
# Every serialization mode as a list. Iterating an Enum yields its members in
# definition order, so this is
# [Z_ORDER, Z_ORDER_TRANSPOSED, HILBERT, HILBERT_TRANSPOSED].
SerializeModes = list(SerializeMode)
def _partition_windows(
    to_ordered: torch.Tensor,
    num_points: int,
    window_size: int,
    shift_sequence: int,
    device: torch.device,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    """
    Cut one serialized batch element into fixed-size, wrap-around-padded windows.

    Args:
        to_ordered (torch.Tensor): Local row indices of the element, sorted by
            serialization code (output of ``torch.argsort``).
        num_points (int): Number of rows in the element (must be > 0).
        window_size (int): Number of rows gathered per window.
        shift_sequence (int): Shift applied to the window boundaries along the
            serialized order.
        device (torch.device): Device on which index tensors are created.

    Returns:
        (List[torch.Tensor], torch.Tensor): Per-window gather indices (local to
        the element, each of length ``window_size``) and, for every local row,
        its position in the concatenation of those windows.
    """
    num_windows = (num_points + window_size - 1) // window_size
    windows = []
    bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=device)
    offset = 0
    for i in range(num_windows):
        # Range of rows this window "owns":
        #   floor(i * num_points / num_windows + shift_sequence).
        # The exact integer product is divided once. Computing
        # (num_points / num_windows) * i rounds twice and can land just below
        # an integer boundary (e.g. (4096 / 49) * 49 < 4096), which would leave
        # a row whose backward index is never written.
        valid_start = math.floor(i * num_points / num_windows + shift_sequence)
        valid_end = math.floor((i + 1) * num_points / num_windows + shift_sequence)
        # Exactly `window_size` rows centred on the owned range:
        #   floor((i + 0.5) * num_points / num_windows + shift_sequence
        #         - 0.5 * window_size)
        padded_start = math.floor(
            ((2 * i + 1) * num_points - num_windows * window_size)
            / (2 * num_windows)
            + shift_sequence
        )
        padded_end = padded_start + window_size
        # Positions outside [0, num_points) wrap around the serialized order.
        window = to_ordered[
            torch.arange(padded_start, padded_end, device=device) % num_points
        ]
        windows.append(window)
        # Only the owned rows are mapped back; the padding rows of this window
        # are gathered for context but their outputs are never selected.
        offset += valid_start - padded_start
        bwd_index.scatter_(
            0,
            window[valid_start - padded_start : valid_end - padded_start],
            torch.arange(offset, offset + valid_end - valid_start, device=device),
        )
        offset += padded_end - valid_start
    return windows, bwd_index


def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0),
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    The spatial coordinates ``tensor.coords[:, 1:]`` (three columns), shifted
    by ``shift_window``, are encoded with ``vox2seq.encode``. Each batch
    element (one slice of ``tensor.layout``) is then sorted by that code.
    - An element with at most ``window_size`` rows becomes one sequence of its
      own length.
    - A larger element is cut into ``ceil(num_points / window_size)`` windows
      of exactly ``window_size`` rows, padded by wrapping around the element's
      sorted order.
    - Empty elements contribute no sequences.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor, torch.Tensor, List[int], List[int]):
            - fwd_indices: int64 row indices into the tensor that gather all
              sequences back to back.
            - bwd_indices: for every original row (in ``tensor.layout`` order),
              its position in the gathered sequence.
            - seq_lens: length of each gathered sequence.
            - seq_batch_indices: index into ``tensor.layout`` of the batch
              element each sequence comes from.

    Raises:
        ValueError: If ``serialize_mode`` is not a known SerializeMode.
    """
    # Plain local import. It is cached in sys.modules after the first call.
    # Guarding it with a globals() check would make `vox2seq` an unbound local
    # whenever the guard skipped the import.
    import vox2seq

    # Serialize the input.
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(
        shift_window, dtype=torch.int32, device=tensor.device
    ).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode="z_order", permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode="z_order", permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offset = 0  # running length of the gathered sequence so far
    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        if num_points == 0:
            # Nothing to gather. Skipping also avoids dividing by a zero
            # window count.
            continue
        to_ordered = torch.argsort(code[s.start : s.stop])
        if num_points <= window_size:
            # Single variable-length sequence: the backward index is the
            # inverse permutation of `to_ordered`.
            bwd_index = torch.zeros_like(to_ordered).scatter_(
                0, to_ordered, torch.arange(num_points, device=tensor.device)
            )
            fwd_indices.append(to_ordered + s.start)
            bwd_indices.append(bwd_index + offset)
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offset += num_points
        else:
            windows, bwd_index = _partition_windows(
                to_ordered, num_points, window_size, shift_sequence, tensor.device
            )
            fwd_indices.extend(window + s.start for window in windows)
            bwd_indices.append(bwd_index + offset)
            seq_lens.extend([window_size] * len(windows))
            seq_batch_indices.extend([bi] * len(windows))
            offset += len(windows) * window_size

    if not fwd_indices:
        # Every batch element was empty. torch.cat would reject an empty list.
        empty = torch.zeros((0,), dtype=torch.int64, device=tensor.device)
        return empty, empty.clone(), seq_lens, seq_batch_indices
    return torch.cat(fwd_indices), torch.cat(bwd_indices), seq_lens, seq_batch_indices
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0),
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    The rows of ``qkv`` are gathered into sequences by ``calc_serialization``.
    The result is cached on ``qkv`` via its spatial cache, keyed by all
    serialization arguments.
    - If every sequence has exactly ``window_size`` rows, they are stacked
      into a dense [B, window_size, 3, H, C] batch.
    - Otherwise a variable-length kernel attends within each sequence.
    The outputs are then scattered back to the original row order.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: ``qkv.replace(out)``, where ``out`` holds one [H, C]
        attention output per original row.

    Raises:
        ValueError: If the configured attention backend ``ATTN`` is unknown.
    """
    assert (
        len(qkv.shape) == 4 and qkv.shape[1] == 3
    ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The serialization depends only on coordinates and these arguments, so
    # it is computed once and reused through the tensor's spatial cache.
    cache_name = (
        f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}"
    )
    serialization = qkv.get_spatial_cache(cache_name)
    if serialization is None:
        serialization = calc_serialization(
            qkv, window_size, serialize_mode, shift_sequence, shift_window
        )
        qkv.register_spatial_cache(cache_name, serialization)
    fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization

    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]
    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        # Every gathered sequence must come from a single batch element.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for seq_len, batch_index in zip(seq_lens, seq_batch_indices):
            assert (qkv_coords[start : start + seq_len, 0] == batch_index).all(), (
                "SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            )
            start += seq_len

    if all(seq_len == window_size for seq_len in seq_lens):
        # Uniform sequences: run a dense batched kernel.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == "xformers":
            q, k, v = qkv_feats.unbind(dim=2)  # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)  # [B, N, H, C]
        elif ATTN == "flash_attn":
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)  # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)  # [M, H, C]
    else:
        # Mixed lengths: attend within each sequence of the packed layout.
        if ATTN == "xformers":
            q, k, v = qkv_feats.unbind(dim=1)  # [M, H, C]
            q = q.unsqueeze(0)  # [1, M, H, C]
            k = k.unsqueeze(0)  # [1, M, H, C]
            v = v.unsqueeze(0)  # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]  # [M, H, C]
        elif ATTN == "flash_attn":
            # Sequence start offsets [0, len0, len0 + len1, ...] as int32.
            cu_seqlens = (
                torch.cat(
                    [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)],
                    dim=0,
                )
                .to(qkv.device)
                .int()
            )
            out = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv_feats, cu_seqlens, max(seq_lens)
            )  # [M, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")

    # Scatter back: bwd_indices selects, for every original row, its own
    # (non-padding) output in the gathered layout.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(
            qkv_coords, qkv.coords
        ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)