|
from typing import * |
|
from enum import Enum |
|
import torch |
|
import math |
|
from .. import SparseTensor |
|
from .. import DEBUG, ATTN |
|
|
|
# Import only the attention backend selected by ATTN; any other value is
# rejected as soon as this module is imported.
if ATTN not in ("xformers", "flash_attn"):
    raise ValueError(f"Unknown attention module: {ATTN}")
if ATTN == "xformers":
    import xformers.ops as xops
else:
    import flash_attn
|
|
|
|
|
# Public API of this module: only the attention entry point is exported.
__all__ = ["sparse_serialized_scaled_dot_product_self_attention"]
|
|
|
|
|
class SerializeMode(Enum):
    """
    Space-filling-curve orderings available for serializing voxel coordinates.

    The plain variants encode the coordinate axes in their given order; the
    ``*_TRANSPOSED`` variants swap the first two axes before encoding.
    """

    # Z-order encodings.
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1

    # Hilbert encodings.
    HILBERT = 2
    HILBERT_TRANSPOSED = 3
|
|
|
|
|
# Every serialization mode, in declaration order (Z_ORDER, Z_ORDER_TRANSPOSED,
# HILBERT, HILBERT_TRANSPOSED). Iterating an Enum class yields its members in
# definition order, so this matches the explicit listing.
SerializeModes = list(SerializeMode)
|
|
|
|
|
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0),
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    The points of every batch element are sorted by a space-filling-curve code
    and cut into windows. A batch element that fits into a single window keeps
    its own length; otherwise every window holds exactly ``window_size``
    entries, where neighbouring windows overlap (wrapping around the sorted
    sequence) to fill the padding.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use. Must be positive.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor, torch.Tensor, List[int], List[int]):
            - Forward indices: rows of ``tensor`` to gather, in serialized
              (window-padded) order.
            - Backward indices: for every original row, its position inside
              the serialized sequence (one entry per point).
            - Length of every attention sequence (window).
            - Batch index of every attention sequence.

    Raises:
        ValueError: If ``window_size`` is not positive or ``serialize_mode``
            is unknown.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    # Position in the serialized sequence where the current batch element starts.
    out_offset = 0

    # Imported unconditionally: a guarded function-level import would leave the
    # local name unbound (UnboundLocalError) whenever the guard is skipped.
    # Repeated imports are served from the module cache.
    import vox2seq

    # Work on a shifted copy so the coordinates of the tensor stay untouched.
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(
        shift_window, dtype=torch.int32, device=tensor.device
    ).reshape(1, 3)

    if serialize_mode == SerializeMode.Z_ORDER:
        curve, permute = "z_order", [0, 1, 2]
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        curve, permute = "z_order", [1, 0, 2]
    elif serialize_mode == SerializeMode.HILBERT:
        curve, permute = "hilbert", [0, 1, 2]
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        curve, permute = "hilbert", [1, 0, 2]
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")
    code = vox2seq.encode(serialize_coords, mode=curve, permute=permute)

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        if num_points == 0:
            # An empty batch element contributes no sequence; without this
            # guard the window size below would be a division by zero.
            continue
        num_windows = (num_points + window_size - 1) // window_size
        # Average (possibly fractional) number of points owned by each window.
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start : s.stop])

        if num_windows == 1:
            # Single window: the sequence is the sorted points, unpadded.
            inverse = torch.zeros_like(to_ordered).scatter_(
                0, to_ordered, torch.arange(num_points, device=tensor.device)
            )
            fwd_indices.append(to_ordered + s.start)
            bwd_indices.append(inverse + out_offset)
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            out_offset += num_points
        else:
            # Several windows of exactly window_size entries. Each window owns
            # the "valid" range [valid_start, valid_end) of the sorted sequence
            # and is padded on both sides with neighbouring points.
            offset = 0
            bwd_index = torch.zeros(
                (num_points,), dtype=torch.int64, device=tensor.device
            )
            for i in range(num_windows):
                mid = (i + 0.5) * valid_window_size + shift_sequence
                valid_start = math.floor(i * valid_window_size + shift_sequence)
                valid_end = math.floor((i + 1) * valid_window_size + shift_sequence)
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                # The modulo wraps windows that reach past either end.
                window = to_ordered[
                    torch.arange(padded_start, padded_end, device=tensor.device)
                    % num_points
                ]
                # Skip the leading padding: offset now points at the first
                # valid entry of this window inside the serialized sequence.
                offset += valid_start - padded_start
                bwd_index.scatter_(
                    0,
                    window[valid_start - padded_start : valid_end - padded_start],
                    torch.arange(
                        offset, offset + valid_end - valid_start, device=tensor.device
                    ),
                )
                # Advance to the beginning of the next window.
                offset += padded_end - valid_start
                fwd_indices.append(window + s.start)
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + out_offset)
            out_offset += num_windows * window_size

    if fwd_indices:
        fwd_indices = torch.cat(fwd_indices)
        bwd_indices = torch.cat(bwd_indices)
    else:
        # No points at all: torch.cat rejects an empty list, so return empty
        # index tensors instead.
        fwd_indices = torch.zeros((0,), dtype=torch.int64, device=tensor.device)
        bwd_indices = torch.zeros((0,), dtype=torch.int64, device=tensor.device)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
|
|
|
|
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0),
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    The points are serialized and partitioned into windows by
    ``calc_serialization`` (the result is stored in the spatial cache of
    ``qkv`` and reused on later calls), attention is evaluated independently
    inside every window, and the output is scattered back to the original
    point order.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
            Its features are indexed as [T, 3, H, C].
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: ``qkv`` with its features replaced by the attention
        output, one [H, C] entry per point.

    Raises:
        ValueError: If the configured attention module is unknown.
    """
    assert (
        len(qkv.shape) == 4 and qkv.shape[1] == 3
    ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The serialization depends only on the coordinates and these settings, so
    # it is computed once per configuration and kept in the spatial cache.
    cache_name = (
        f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}"
    )
    serialization = qkv.get_spatial_cache(cache_name)
    if serialization is None:
        serialization = calc_serialization(
            qkv, window_size, serialize_mode, shift_sequence, shift_window
        )
        qkv.register_spatial_cache(cache_name, tuple(serialization))
    fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization

    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]

    # Gather the features into serialized, window-padded order.
    qkv_feats = qkv.feats[fwd_indices]

    if DEBUG:
        # Every sequence must contain points of a single batch element only.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for seq_len, batch_index in zip(seq_lens, seq_batch_indices):
            assert (
                qkv_coords[start : start + seq_len, 0] == batch_index
            ).all(), "SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_len

    if all(seq_len == window_size for seq_len in seq_lens):
        # All sequences have the same length: run dense batched attention.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == "xformers":
            q, k, v = qkv_feats.unbind(dim=2)
            out = xops.memory_efficient_attention(q, k, v)
        elif ATTN == "flash_attn":
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)
    else:
        # Sequences differ in length: use the variable-length attention path.
        if ATTN == "xformers":
            q, k, v = qkv_feats.unbind(dim=1)
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]
        elif ATTN == "flash_attn":
            # Cumulative sequence boundaries, starting at 0.
            cu_seqlens = (
                torch.cumsum(torch.tensor([0, *seq_lens]), dim=0)
                .to(qkv.device)
                .int()
            )
            out = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv_feats, cu_seqlens, max(seq_lens)
            )
        else:
            # Mirrors the dense branch; without it `out` would be unbound.
            raise ValueError(f"Unknown attention module: {ATTN}")

    # Drop the window padding and restore the original point order.
    out = out[bwd_indices]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(
            qkv_coords, qkv.coords
        ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
|
|