Spaces:
Running
Running
Merge pull request #278 from jhj0517/fix/sparse-mps
Browse files
modules/whisper/whisper_base.py
CHANGED
|
@@ -458,10 +458,30 @@ class WhisperBase(ABC):
|
|
| 458 |
if torch.cuda.is_available():
|
| 459 |
return "cuda"
|
| 460 |
elif torch.backends.mps.is_available():
|
|
|
|
|
|
|
|
|
|
| 461 |
return "mps"
|
| 462 |
else:
|
| 463 |
return "cpu"
|
| 464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
@staticmethod
|
| 466 |
def release_cuda_memory():
|
| 467 |
"""Release memory"""
|
|
|
|
| 458 |
if torch.cuda.is_available():
|
| 459 |
return "cuda"
|
| 460 |
elif torch.backends.mps.is_available():
|
| 461 |
+
if not WhisperBase.is_sparse_api_supported():
|
| 462 |
+
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
| 463 |
+
return "cpu"
|
| 464 |
return "mps"
|
| 465 |
else:
|
| 466 |
return "cpu"
|
| 467 |
|
| 468 |
+
@staticmethod
|
| 469 |
+
def is_sparse_api_supported():
|
| 470 |
+
if not torch.backends.mps.is_available():
|
| 471 |
+
return False
|
| 472 |
+
|
| 473 |
+
try:
|
| 474 |
+
device = torch.device("mps")
|
| 475 |
+
sparse_tensor = torch.sparse_coo_tensor(
|
| 476 |
+
indices=torch.tensor([[0, 1], [2, 3]]),
|
| 477 |
+
values=torch.tensor([1, 2]),
|
| 478 |
+
size=(4, 4),
|
| 479 |
+
device=device
|
| 480 |
+
)
|
| 481 |
+
return True
|
| 482 |
+
except RuntimeError:
|
| 483 |
+
return False
|
| 484 |
+
|
| 485 |
@staticmethod
|
| 486 |
def release_cuda_memory():
|
| 487 |
"""Release memory"""
|