|
import random |
|
from typing import List, Tuple |
|
|
|
import paged_attention as ops |
|
import pytest |
|
import torch |
|
from paged_attention.platforms import current_platform |
|
|
|
from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck |
|
|
|
# Parameter grids consumed by the @pytest.mark.parametrize decorators below.

# (source, destination) device kinds; used by test_swap_blocks.
COPYING_DIRECTION = [
    ("cuda", "cpu"),
    ("cuda", "cuda"),
    ("cpu", "cuda"),
]
# torch.float16 / torch.float32 are the same dtype objects as torch.half /
# torch.float.
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
NUM_TOKENS = [42]
NUM_LAYERS = [1]
NUM_HEADS = [8]
# The first entry is also the head size for which the tests run opcheck.
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]
SEEDS = [0]
# Exactly one visible GPU -> only "cuda:0"; any other device count (including
# zero) yields both "cuda:0" and "cuda:1".
CUDA_DEVICES = (
    ["cuda:0"] if torch.cuda.device_count() == 1 else ["cuda:0", "cuda:1"]
)

KV_CACHE_DTYPE = ["auto", "fp8"]
|
|
|
|
|
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ``ops.copy_blocks`` against a reference built with ``copy_``.

    Each of ``num_mappings`` randomly chosen source blocks is mapped to two
    distinct destination blocks. The op is run on the caches returned by
    ``kv_cache_factory``; the same mapping is applied block by block to clones
    taken beforehand, and both results are compared with
    ``torch.testing.assert_close``.

    The case is skipped for ``kv_cache_dtype == "fp8"`` when ``head_size`` is
    not a multiple of 16.
    """
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Sources and destinations are disjoint and every source has two
    # destinations, so 3 * num_mappings distinct blocks are required.
    # (The previous bound of 2 * num_mappings let random.sample raise
    # ValueError for num_blocks in [2n, 3n).)
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Caches the op will modify in place.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        num_layers,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )

    # Snapshots used to build the expected result.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # One (src, dst) pair per row.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    # opcheck is only enabled for the first head size in HEAD_SIZES.
    opcheck(
        ops.ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Reference: apply the same mapping to the clones, one block at a time.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # The op output must match the reference for every returned cache.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` against a per-token reference.

    Random keys/values for ``num_tokens`` tokens are written into randomly
    chosen, distinct slots of a single-layer cache. The expected cache is
    produced by writing each token into a snapshot of the cache at
    ``[block, :, :, offset, :]`` (key) and ``[block, :, :, offset]`` (value),
    where ``block = slot // block_size`` and ``offset = slot % block_size``.

    For ``kv_cache_dtype == "fp8"`` both the snapshot and the op result are
    first converted to float16 with ``ops.convert_fp8`` and compared with a
    loose tolerance; cases with ``head_size % 16 != 0`` are skipped.
    """
    is_fp8 = kv_cache_dtype == "fp8"
    if is_fp8 and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    def _to_float16(cache: torch.Tensor) -> torch.Tensor:
        # Convert a cache to float16 via ops.convert_fp8 (fp8 path only).
        converted = torch.empty_like(cache, dtype=torch.float16)
        ops.convert_fp8(converted, cache)
        return converted

    # Distinct random slots, one per token (created on the default device).
    slot_mapping = torch.tensor(
        random.sample(range(block_size * num_blocks), num_tokens),
        dtype=torch.long,
    )

    # Only the key and value slices of the random qkv tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
    _, key, value = qkv.unbind(dim=1)

    # A single layer of caches is requested.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    # Snapshot the caches before the op mutates them.
    if is_fp8:
        cloned_key_cache = _to_float16(key_cache)
        cloned_value_cache = _to_float16(value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Unit scales, shared by key and value.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # The same argument tuple is used for opcheck and for the real call.
    op_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        ops.ops.reshape_and_cache,
        op_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(*op_args)

    if is_fp8:
        result_key_cache = _to_float16(key_cache)
        result_value_cache = _to_float16(value_cache)

    # Reference: write each token into its slot of the snapshot.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_ids = (
        torch.div(slot_mapping, block_size, rounding_mode="floor").cpu().tolist()
    )
    offsets = (slot_mapping % block_size).cpu().tolist()
    for token, (block_idx, block_offset) in enumerate(zip(block_ids, offsets)):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[token]
        cloned_value_cache[block_idx, :, :, block_offset] = value[token]

    if is_fp8:
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache_flash`` against a per-token reference.

    Random keys/values for ``num_tokens`` tokens are written into randomly
    chosen, distinct slots of a single-layer cache from
    ``kv_cache_factory_flashinfer``. The expected cache is produced by writing
    each token into a snapshot at ``[block, offset, :, :]``, where
    ``block = slot // block_size`` and ``offset = slot % block_size``.

    Scales are ``amax / 256`` of the key and value tensors. For
    ``kv_cache_dtype == "fp8"`` the snapshot and the op result are converted
    to float16 with ``ops.convert_fp8`` and compared with a loose tolerance.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Distinct random slots, one per token.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # Only the key and value slices of the random qkv tensor are used.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # A single layer of caches is requested.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
    # Drop the list references; only the contiguous layer-0 tensors are used.
    del key_caches
    del value_caches

    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Snapshot the caches before the op mutates them. convert_fp8 is called
    # the same way here as for the post-op conversion below (Python-float
    # scale, keyword kv_dtype) so both sides share one call form.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            cloned_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            cloned_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # opcheck is only enabled for the first head size in HEAD_SIZES.
    opcheck(
        ops.ops.reshape_and_cache_flash,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )

    # Reference: write each token into its slot of the snapshot.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i, (block_idx, block_offset) in enumerate(
        zip(block_indices_lst, block_offsets_lst)
    ):
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
|
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check that ``ops.swap_blocks`` copies source blocks to the destination.

    ``direction`` selects, for source and destination independently, between
    the parametrized CUDA ``device`` and ``"cpu"``. After swapping, every
    mapped destination block must equal the pre-swap contents of its source
    block, for both the key and the value cache.

    fp8 cases are skipped when either side is on the CPU or when
    ``head_size`` is not a multiple of 16.
    """
    if kv_cache_dtype == "fp8" and ("cpu" in direction or head_size % 16):
        pytest.skip()

    current_platform.seed_everything(seed)

    src_kind, dst_kind = direction[0], direction[1]
    src_device = device if src_kind == "cuda" else "cpu"
    dst_device = device if dst_kind == "cuda" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # On the same device the destinations are drawn from blocks that are not
    # sources; across devices any block index may be a destination.
    if src_device == dst_device:
        dst_candidates = list(set(range(num_blocks)) - set(src_blocks))
    else:
        dst_candidates = range(num_blocks)
    dst_blocks = random.sample(dst_candidates, num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is always created on the CPU, one (src, dst) per row.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Source caches (single layer).
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )
    # Destination caches (single layer).
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )
    src_key, src_value = src_key_caches[0], src_value_caches[0]
    dst_key, dst_value = dst_key_caches[0], dst_value_caches[0]

    # Pre-swap snapshots of the source caches.
    expected_key = src_key.clone()
    expected_value = src_value.clone()

    # opcheck is only enabled for the first head size in HEAD_SIZES.
    run_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        ops.ops.swap_blocks,
        (src_key, dst_key, block_mapping_tensor),
        cond=run_opcheck,
    )
    opcheck(
        ops.ops.swap_blocks,
        (src_value, dst_value, block_mapping_tensor),
        cond=run_opcheck,
    )

    ops.swap_blocks(src_key, dst_key, block_mapping_tensor)
    ops.swap_blocks(src_value, dst_value, block_mapping_tensor)

    # Compare on the CPU so both sides are on the same device.
    for src_idx, dst_idx in zip(src_blocks, dst_blocks):
        torch.testing.assert_close(
            expected_key[src_idx].cpu(), dst_key[dst_idx].cpu()
        )
        torch.testing.assert_close(
            expected_value[src_idx].cpu(), dst_value[dst_idx].cpu()
        )
|
|
|
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``ops.convert_fp8`` and back.

    A ``(num_blocks, num_heads, head_size, block_size)`` tensor filled
    uniformly from [-224, 224) is converted into a uint8 buffer and then back
    into a tensor of the original dtype; the result must match the input
    within ``atol=0.001, rtol=0.1``.
    """
    current_platform.seed_everything(seed)

    bound = 224.0
    original = torch.empty(
        (num_blocks, num_heads, head_size, block_size),
        dtype=dtype,
        device=device,
    ).uniform_(-bound, bound)

    # Forward conversion into uint8 storage.
    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    # Backward conversion into the original dtype.
    round_tripped = torch.empty_like(original)
    ops.convert_fp8(round_tripped, quantized)

    torch.testing.assert_close(original, round_tripped, atol=0.001, rtol=0.1)
|
|