Improve tests for mps

Changed files:
- tests/kernels/conftest.py        +0  -1
- tests/kernels/test_attention.py  +16 -3
- tests/kernels/test_cache.py      +11 -5
- tests/kernels/utils.py           +11 -16
tests/kernels/conftest.py

@@ -36,7 +36,6 @@ def create_kv_caches_with_random(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-
     if cache_dtype == "fp8" and head_size % 16:
         raise ValueError(
             f"Does not support key cache of type fp8 with head_size {head_size}"
tests/kernels/test_attention.py

@@ -43,6 +43,7 @@ if current_platform.is_mps():
 else:
     DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

+
 def ref_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,

@@ -232,7 +233,11 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )

     elif version in ("v2", "rocm"):

@@ -295,7 +300,11 @@ def test_paged_attention(
                     64,
                     0,
                 ),
-                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+                cond=(
+                    head_size == HEAD_SIZES[0]
+                    and block_size == BLOCK_SIZES[0]
+                    and not device.startswith("mps")
+                ),
             )

         else:

@@ -340,7 +349,11 @@ def test_paged_attention(
                     k_scale,
                     v_scale,
                 ),
-                cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+                cond=(
+                    head_size == HEAD_SIZES[0]
+                    and block_size == BLOCK_SIZES[0]
+                    and not device.startswith("mps")
+                ),
             )

     else:
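For context, `cond=` is the gate the project's `opcheck` wrapper uses to decide whether to run `torch.library.opcheck` at all (see the `tests/kernels/utils.py` change below): when it is false, the wrapper returns early. A rough, self-contained sketch of the predicate these hunks add; the `HEAD_SIZES`/`BLOCK_SIZES` values are placeholders, not the test file's actual parameter lists:

```python
# Hypothetical parameter lists; the real ones live in tests/kernels/test_attention.py.
HEAD_SIZES = [64, 128]
BLOCK_SIZES = [16, 32]

def should_run_opcheck(head_size: int, block_size: int, device: str) -> bool:
    # Same predicate the hunks above pass as cond=: only the first
    # head-size/block-size combination is schema-checked, and never on MPS.
    return (
        head_size == HEAD_SIZES[0]
        and block_size == BLOCK_SIZES[0]
        and not device.startswith("mps")
    )

assert should_run_opcheck(64, 16, "cuda:0")
assert not should_run_opcheck(64, 16, "mps")
assert not should_run_opcheck(128, 16, "cuda:0")
```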
tests/kernels/test_cache.py

@@ -60,7 +60,9 @@ def test_copy_blocks(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks

@@ -144,13 +146,15 @@ def test_reshape_and_cache(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

-    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
+    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
     _, key, value = qkv.unbind(dim=1)

     # Create the KV caches.

@@ -262,7 +266,9 @@ def test_reshape_and_cache_flash(
     if current_platform.is_mps() and kv_cache_dtype == "fp8":
         pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)

     # Create a random slot mapping.
     num_slots = block_size * num_blocks
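The pattern behind these hunks: on MPS the tests avoid `torch.set_default_device` and instead pass `device=` explicitly when allocating tensors. A minimal sketch of that placement style, assuming a recent PyTorch build; the sizes below are made up for illustration:

```python
import torch

# Pick MPS when available, otherwise fall back to CPU (illustrative choice).
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Explicit placement instead of torch.set_default_device(device):
num_tokens, num_heads, head_size = 4, 8, 64          # made-up sizes
slot_mapping = torch.tensor([3, 1, 7, 5], dtype=torch.long, device=device)
qkv = torch.randn(num_tokens, 3, num_heads, head_size,
                  dtype=torch.float16, device=device)
_, key, value = qkv.unbind(dim=1)

assert key.device == qkv.device
assert key.shape == (num_tokens, num_heads, head_size)
```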
tests/kernels/utils.py

@@ -40,10 +40,18 @@ def fp8_allclose(
     """
     torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

+    # MPS doesn't support float64, so use float32 for comparison
+    if a.device.type == "mps" or b.device.type == "mps":
+        a_cmp = a.float()
+        b_cmp = b.float()
+    else:
+        a_cmp = a.double()
+        b_cmp = b.double()
+
     return bool(
         torch.all(
             torch.isclose(
-                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+                a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan
             )
         ).item()
     )

@@ -68,25 +76,12 @@ def opcheck(
     *,
     test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
     raise_exception: bool = True,
-    cond: bool = True
+    cond: bool = True,
 ) -> Dict[str, str]:
     with unittest.mock.patch("torch.allclose", new=fp8_allclose):
         if not cond:
             return {}
-
-        # Check if any arguments are on MPS device and skip opcheck if so
-        # as MPS has issues with placeholder storage allocation in opcheck
-        def is_mps_tensor(x):
-            return hasattr(x, 'device') and x.device.type == 'mps'
-
-        def check_args_for_mps(args):
-            if isinstance(args, (list, tuple)):
-                return any(check_args_for_mps(arg) for arg in args)
-            return is_mps_tensor(args)
-
-        if check_args_for_mps(args):
-            return {}
-
+
         return torch.library.opcheck(
             op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
         )
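The `fp8_allclose` change works around the MPS backend's lack of float64 kernels: when either tensor lives on MPS the comparison is promoted only to float32, otherwise to float64 as before. A standalone sketch of that dtype selection (the helper name is made up; the real logic sits inside `fp8_allclose` and keeps the `torch._refs._check_close_args` validation shown above):

```python
import torch

def allclose_mps_safe(a: torch.Tensor, b: torch.Tensor,
                      rtol: float = 1e-5, atol: float = 1e-8,
                      equal_nan: bool = False) -> bool:
    # MPS has no float64 support, so compare in float32 there and in
    # float64 everywhere else, mirroring the hunk above.
    if a.device.type == "mps" or b.device.type == "mps":
        a_cmp, b_cmp = a.float(), b.float()
    else:
        a_cmp, b_cmp = a.double(), b.double()
    return bool(
        torch.all(
            torch.isclose(a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan)
        ).item()
    )

x = torch.randn(16)
assert allclose_mps_safe(x, x.clone())
assert not allclose_mps_safe(x, x + 1.0)
```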