EricB (HF Staff) committed
Commit 4dbd9c0 · 1 Parent(s): 16fc7e4

Improve tests for mps

tests/kernels/conftest.py CHANGED

@@ -36,7 +36,6 @@ def create_kv_caches_with_random(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-
     if cache_dtype == "fp8" and head_size % 16:
         raise ValueError(
             f"Does not support key cache of type fp8 with head_size {head_size}"

tests/kernels/test_attention.py CHANGED

@@ -43,6 +43,7 @@ if current_platform.is_mps():
 else:
     DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
+
 def ref_masked_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -232,7 +233,11 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     elif version in ("v2", "rocm"):
@@ -295,7 +300,11 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     else:
@@ -340,7 +349,11 @@ def test_paged_attention(
                 k_scale,
                 v_scale,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
+            cond=(
+                head_size == HEAD_SIZES[0]
+                and block_size == BLOCK_SIZES[0]
+                and not device.startswith("mps")
+            ),
         )
 
     else:
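
The cond= gate added at these three call sites keeps torch.library.opcheck to a single representative configuration and skips it entirely on MPS. Below is a minimal, self-contained sketch of that predicate; the HEAD_SIZES/BLOCK_SIZES values are hypothetical stand-ins for the test module's parametrize lists, not the actual ones.

# Hypothetical stand-ins for the module-level parametrize lists.
HEAD_SIZES = [64, 80, 120]
BLOCK_SIZES = [16, 32]

def should_run_opcheck(head_size: int, block_size: int, device: str) -> bool:
    # Mirrors the shape of the new cond= expression: only the first
    # (head_size, block_size) combination is opcheck'ed, and never on MPS.
    return (
        head_size == HEAD_SIZES[0]
        and block_size == BLOCK_SIZES[0]
        and not device.startswith("mps")
    )

assert should_run_opcheck(64, 16, "cuda:0")        # representative combo on CUDA
assert not should_run_opcheck(64, 16, "mps")       # MPS: opcheck is skipped
assert not should_run_opcheck(80, 32, "cuda:0")    # non-representative combo: skipped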

tests/kernels/test_cache.py CHANGED

@@ -60,7 +60,9 @@ def test_copy_blocks(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -144,13 +146,15 @@ def test_reshape_and_cache(
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
 
-    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
+    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
     _, key, value = qkv.unbind(dim=1)
 
     # Create the KV caches.
@@ -262,7 +266,9 @@ def test_reshape_and_cache_flash(
     if current_platform.is_mps() and kv_cache_dtype == "fp8":
         pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
     current_platform.seed_everything(seed)
-    torch.set_default_device(device)
+    # Don't set MPS as default device to avoid placeholder storage error
+    if not device.startswith("mps"):
+        torch.set_default_device(device)
 
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
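
The recurring pattern in these hunks: skip torch.set_default_device on MPS (where routing allocations through the default device triggers what the diff's comments call the placeholder storage error inside opcheck) and pass device= explicitly when building the test tensors instead. A rough, self-contained sketch of the idea; make_test_tensors and its shapes are illustrative, not the actual test code.

import torch

def make_test_tensors(device: str, num_tokens: int = 8, num_heads: int = 4, head_size: int = 64):
    """Illustrative helper: explicit device placement instead of a global default on MPS."""
    if not device.startswith("mps"):
        # Non-MPS backends can still route allocations through the default device.
        torch.set_default_device(device)
    # Explicit device= works uniformly on CPU, CUDA and MPS.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=device)
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, device=device)
    return slot_mapping, qkv

device = "mps" if torch.backends.mps.is_available() else "cpu"
slot_mapping, qkv = make_test_tensors(device)
print(slot_mapping.device, qkv.device)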

tests/kernels/utils.py CHANGED

@@ -40,10 +40,18 @@ def fp8_allclose(
     """
     torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
 
+    # MPS doesn't support float64, so use float32 for comparison
+    if a.device.type == "mps" or b.device.type == "mps":
+        a_cmp = a.float()
+        b_cmp = b.float()
+    else:
+        a_cmp = a.double()
+        b_cmp = b.double()
+
     return bool(
         torch.all(
             torch.isclose(
-                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+                a_cmp, b_cmp, rtol=rtol, atol=atol, equal_nan=equal_nan
             )
         ).item()
     )
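
The fp8_allclose hunk above works around the fact that MPS has no float64: calling .double() on an MPS tensor raises, so the comparison falls back to float32 there. A standalone sketch of the same fallback; allclose_any_device is a hypothetical name, not the helper in the repo.

import torch

def allclose_any_device(a: torch.Tensor, b: torch.Tensor,
                        rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    # MPS doesn't support float64, so compare in float32 there; elsewhere keep
    # the higher-precision float64 comparison.
    if a.device.type == "mps" or b.device.type == "mps":
        a_cmp, b_cmp = a.float(), b.float()   # a.double() would raise on MPS
    else:
        a_cmp, b_cmp = a.double(), b.double()
    return bool(torch.all(torch.isclose(a_cmp, b_cmp, rtol=rtol, atol=atol)).item())

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(16, device=device)
assert allclose_any_device(x, x.clone())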
 
@@ -68,25 +76,12 @@ def opcheck(
     *,
     test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
     raise_exception: bool = True,
-    cond: bool = True
+    cond: bool = True,
 ) -> Dict[str, str]:
     with unittest.mock.patch("torch.allclose", new=fp8_allclose):
         if not cond:
             return {}
-
-        # Check if any arguments are on MPS device and skip opcheck if so
-        # as MPS has issues with placeholder storage allocation in opcheck
-        def is_mps_tensor(x):
-            return hasattr(x, 'device') and x.device.type == 'mps'
-
-        def check_args_for_mps(args):
-            if isinstance(args, (list, tuple)):
-                return any(check_args_for_mps(arg) for arg in args)
-            return is_mps_tensor(args)
-
-        if check_args_for_mps(args):
-            return {}
-
+
         return torch.library.opcheck(
             op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
         )
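
With this hunk, the opcheck wrapper takes an explicit cond flag and returns an empty result when the caller decides the check should not run (for example on MPS), instead of scanning the argument tree for MPS tensors. A minimal sketch of that calling convention, assuming PyTorch 2.4+ for torch.library.custom_op/opcheck; the demo::scale op and gated_opcheck below are toys invented for illustration, not the repo's code.

import torch
from torch.library import custom_op, opcheck

@custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Toy custom op so there is something for opcheck to exercise.
    return x * factor

@scale.register_fake
def _(x, factor):
    return torch.empty_like(x)

def gated_opcheck(op, args, *, cond: bool = True):
    # Caller-side gating, as in the wrapper above: skip the (expensive) opcheck
    # entirely when cond is False, e.g. for MPS inputs or redundant parametrizations.
    if not cond:
        return {}
    return opcheck(op, args)

x = torch.randn(4)
results = gated_opcheck(scale, (x, 2.0), cond=(x.device.type != "mps"))
print(results)  # dict of opcheck test names -> "SUCCESS" (or {} if skipped)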