from typing import List, Optional, Tuple, Union

import paged_attention as ops
import pytest
import torch


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches_with_random


@pytest.fixture()
def kv_cache_factory_flashinfer():
    return create_kv_caches_with_random_flash


STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches filled with random data.

    Returns one key cache and one value cache tensor per layer, using the
    paged-attention layout (keys are packed along a 16-byte inner dimension).
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"fp8 key cache requires head_size to be a multiple of 16, "
            f"got {head_size}"
        )

    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    # Match the magnitude of attention inputs: 1 / sqrt(head_size).
    scale = head_size**-0.5
    # x is the number of elements that fit in a 16-byte vectorized access;
    # the key cache splits head_size into (head_size // x, x) accordingly.
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Unsupported key cache dtype: {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(
            size=value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Unsupported value cache dtype: {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches
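

# Worked example of the key-cache layout above (illustrative numbers, not
# part of the original utilities): for torch.float16, element_size() == 2,
# so x = 16 // 2 = 8 elements per 16-byte group. With num_blocks=128,
# num_heads=8, head_size=64, block_size=16, the key cache shape is
# (128, 8, 64 // 8, 16, 8) == (128, 8, 8, 16, 8), while the value cache
# stays unpacked at (128, 8, 64, 16).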


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create per-layer key and value caches in the FlashInfer layout.

    Keys and values are allocated together as
    (num_blocks, 2, block_size, num_heads, head_size); the returned key and
    value tensors are views into that shared buffer.
    """
    from paged_attention.platforms import current_platform

    current_platform.seed_everything(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=key_value_cache_shape, dtype=torch_dtype, device=device
        )
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Unsupported KV cache dtype: {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches


def get_kv_cache_torch_dtype(
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
) -> torch.dtype:
    """Resolve the torch dtype used to store the KV cache.

    "auto" defers to model_dtype; fp8 caches are stored as uint8.
    """
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE(zhaoyang): Because of how fp8 encodes NaN and Inf, filling the
    # buffer directly with torch.randint can produce Inf or NaN values.
    # For example, s.11111.00 in the fp8e5m2 format represents Inf.
    #      | E4M3       | E5M2
    # -----|------------|-------------------
    # Inf  | N/A        | s.11111.00
    # NaN  | s.1111.111 | s.11111.{01,10,11}
    # Instead, sample uniform fp16 values and convert them to fp8.
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp
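

# Minimal usage sketch (illustrative, not part of the original utilities):
# allocate a single layer of fp16 caches and check their shapes. This
# assumes a CUDA device is available; pass device="cpu" to run elsewhere.
if __name__ == "__main__":
    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=128,
        block_size=16,
        num_layers=1,
        num_heads=8,
        head_size=64,
        cache_dtype="half",
    )
    # With float16, x == 8, so keys pack head_size into (64 // 8, 8).
    assert key_caches[0].shape == (128, 8, 8, 16, 8)
    assert value_caches[0].shape == (128, 8, 64, 16)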