import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict

import torch
from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
    fallback_dispatcher,
    semi_sparse_values,
    semi_sparse_indices,
    semi_sparse_detach,
    semi_sparse_t,
    semi_sparse_view,
    semi_sparse_mm,
    semi_sparse_addmm,
    semi_sparse_linear,
)

__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)


class SparseSemiStructuredTensor(torch.Tensor):
    """
    This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero,
    with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    There are two backends available for semi-structured sparsity: cuSPARSELt and CUTLASS.
    This class serves as a base class for both implementations. SparseSemiStructuredTensorCUTLASS
    and SparseSemiStructuredTensorCUSPARSELT both inherit from this class and define three backend-specific items:

    - `_DTYPE_SHAPE_CONSTRAINTS` - a dictionary holding backend-specific dense/sparse min shape constraints
    - `def from_dense()` - backend-specific compression routine
    - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)

    As such, this class cannot be instantiated directly.
    """
    _DEFAULT_ALG_ID: int = 0
    _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
    _FORCE_CUTLASS: bool = True
    _FUSE_TRANSPOSE: bool = False
    _PROTOTYPE_WARNING_SHOWN: bool = False

    SPARSE_DISPATCH: Dict[Callable, Callable]

    packed: Optional[torch.Tensor]
    meta: Optional[torch.Tensor]
    packed_t: Optional[torch.Tensor]
    meta_t: Optional[torch.Tensor]
    threads_masks: Optional[torch.Tensor]
    fuse_transpose_cusparselt: bool
    alg_id_cusparselt: int

    __slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        packed: Optional[torch.Tensor],
        meta: Optional[torch.Tensor],
        packed_t: Optional[torch.Tensor],
        meta_t: Optional[torch.Tensor],
        threads_masks: Optional[torch.Tensor],
        fuse_transpose_cusparselt: bool = False,
        alg_id_cusparselt: int = 0,
        requires_grad: bool = False,
    ):
        """
        Create a new instance of the tensor subclass from the compressed sparse representation.

        We have the option to create the subclass with the compressed representations of both X and X', for training.
        For inference, we only need a single representation (either X or X'), and the corresponding other set will be None.
        Depending on the backend selected (CUTLASS vs cuSPARSELt), certain fields will be set to None.

        Args:
            shape: The shape of the original dense tensor
            packed: The compressed representation of the original dense tensor
            meta: The metadata of the original dense tensor, if it is stored separately
            packed_t: The compressed representation of the transposed original dense tensor
            meta_t: The metadata of the transposed original dense tensor, if it is stored separately
            threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
                           Used for pointwise ops.
            fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
                                       with a matmul, which is useful in the case of 2:4 sparse training.
            alg_id_cusparselt: The algorithm id to use when using cuSPARSELt, which affects matmul performance.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If all of the tensor arguments are None.
        """
        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

            # Because this only runs once, we also load the dispatch table here as well.
            # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead.
            # This is useful since it allows users to overload the dispatch table for debugging / testing.
            cls._load_dispatch_table()

        if packed is not None:
            previous_tensor = packed
        elif packed_t is not None:
            previous_tensor = packed_t
        else:
            raise ValueError("At least one of packed or packed_t must be provided")

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
        tensor.packed = packed
        tensor.meta = meta
        tensor.packed_t = packed_t
        tensor.meta_t = meta_t
        tensor.threads_masks = threads_masks
        tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
        tensor.alg_id_cusparselt = alg_id_cusparselt
        return tensor
    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"

    def __tensor_flatten__(
        self,
    ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (
            self.shape,
            self.fuse_transpose_cusparselt,
            self.alg_id_cusparselt,
            self.requires_grad,
        )
        return inner_tensors, tensor_meta
    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool, int, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
        return cls(
            shape=shape,
            packed=inner_tensors.get("packed", None),
            meta=inner_tensors.get("meta", None),
            packed_t=inner_tensors.get("packed_t", None),
            meta_t=inner_tensors.get("meta_t", None),
            threads_masks=inner_tensors.get("threads_masks", None),
            fuse_transpose_cusparselt=fuse_transpose_cusparselt,
            alg_id_cusparselt=alg_id_cusparselt,
            requires_grad=requires_grad,
        )
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        if func._overloadpacket not in cls.SPARSE_DISPATCH:
            raise NotImplementedError(
                f"{cls.__name__} only supports a specific set of operations, "
                f"can't perform requested op ({func.__name__})"
            )
        return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
    @classmethod
    def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
        """
        Loads the op overload sparse dispatch table for the current class.
        """
        if getattr(cls, "SPARSE_DISPATCH", None) is None:
            cls.SPARSE_DISPATCH = {
                torch.ops.aten.values: semi_sparse_values,
                torch.ops.aten.indices: semi_sparse_indices,
                torch.ops.aten.is_same_size: fallback_dispatcher,
                torch.ops.aten.detach_: fallback_dispatcher,
                torch.ops.aten.detach: semi_sparse_detach,
                torch.ops.aten.t: semi_sparse_t,
                torch.ops.aten.view: semi_sparse_view,
                torch.ops.aten.mm: semi_sparse_mm,
                torch.ops.aten.matmul: semi_sparse_mm,
                torch.ops.aten.addmm: semi_sparse_addmm,
                torch.ops.aten.linear: semi_sparse_linear,
            }
            if custom_dispatch_table is not None:
                cls.SPARSE_DISPATCH.update(custom_dispatch_table)
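    # Hedged sketch (hypothetical override, not part of this class): because the table is
    # only populated once, a custom entry can be injected before first use to intercept an
    # op for debugging / testing. Handlers are invoked as `handler(func, types, args, kwargs)`
    # from __torch_dispatch__ above, so an illustrative wrapper could look like
    #
    #   def noisy_mm(func, types, args=(), kwargs=None):
    #       print("dispatching", func)                       # illustrative side effect
    #       return semi_sparse_mm(func, types, args, kwargs)
    #
    #   SparseSemiStructuredTensor._load_dispatch_table(
    #       custom_dispatch_table={torch.ops.aten.mm: noisy_mm}
    #   )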
    @classmethod
    def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
        """
        Assert that the given tensor is valid for semi-structured sparse compression.
        """
        # check device
        if not original_tensor.is_cuda:
            raise RuntimeError(
                f"Error original_tensor.device= {original_tensor.device} is not supported! "
                "Only CUDA tensors are currently supported."
            )

        # check dim
        if original_tensor.dim() != 2:
            raise RuntimeError(
                f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                "Only 2d tensors are currently supported."
            )

        # check contiguous
        if not original_tensor.is_contiguous():
            raise RuntimeError(
                "Error original_tensor is not contiguous! "
                "Only contiguous tensors are currently supported."
            )

        # check dtype
        if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
            raise RuntimeError(
                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                f"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
            )

        # check shape
        m, n = original_tensor.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
        if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
            # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
            raise RuntimeError(
                f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                f"Both dimensions must be at least, and a multiple of, ({min_rows}, {min_cols})"
            )
    @classmethod
    def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
        """
        Calculates padding for dense tensor and pads tensor if necessary.
        If padding is not required, this function returns the original tensor.
        """
        # only 2d matmul
        assert dense_input.dim() == 2

        # check shape
        m, n = dense_input.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols

        # calculate padding
        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
        if to_pad_m or to_pad_n:
            return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
        else:
            return dense_input
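    # Worked example (illustrative, using the fp16 dense constraints of (8, 8) from the
    # backend tables below): a (10, 20) fp16 operand has to_pad_m = -10 % 8 = 6 and
    # to_pad_n = -20 % 8 = 4, so F.pad grows it to (16, 24); a (16, 24) input is returned
    # unchanged.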
    def to_dense(self):
        col = self.shape[-1]
        # Multiplying by an identity matrix routes through the sparse mm kernel and
        # yields the dense representation of this tensor.
        return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))

    @classmethod
    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
        raise NotImplementedError

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        raise NotImplementedError


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
    `_DTYPE_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): deprecated arg to be removed in another release. Do not use.

    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor

    Raises:
        None

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                ...,
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.],
                [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
        >>> A_sparse.values()
        tensor([[1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                ...,
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.],
                [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse.indices()
        tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370],
                [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)
    """
    if transposed:
        raise DeprecationWarning(
            "Setting transposed from to_sparse_semi_structured is deprecated and will be removed in a future release. "
            "SparseSemiStructuredTensor only supports contiguous input tensors."
        )

    sparse_subclass = (
        torch.sparse.SparseSemiStructuredTensorCUTLASS
        if SparseSemiStructuredTensor._FORCE_CUTLASS
        else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
    )
    return sparse_subclass.from_dense(original_tensor)
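# Illustrative note (assuming a build where both backends are available): the conversion
# path above is picked by the `_FORCE_CUTLASS` flag, so switching to the cuSPARSELt
# backend before converting looks like
#
#   SparseSemiStructuredTensor._FORCE_CUTLASS = False
#   A_sparse = to_sparse_semi_structured(A)   # now a SparseSemiStructuredTensorCUSPARSELT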


class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
    """
    This class implements semi-structured sparsity for the CUTLASS backend.

    In this implementation, the specified elements and metadata are stored separately,
    in packed and meta respectively.

    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into
    _sparse_semi_structured_linear for matmul and sparse_semi_structured_from_dense_cutlass
    for conversion to the compressed format.
    """

    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
    }
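    # Illustrative reading of the table above: each entry is
    # (sparse_min_rows, sparse_min_cols, dense_min_rows, dense_min_cols), so a float16
    # tensor can only be compressed by this backend if its shape is a multiple of
    # (32, 64); e.g. (64, 128) is accepted while (64, 100) is rejected by
    # _validate_device_dim_dtype_shape.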
    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUTLASS":
        cls._validate_device_dim_dtype_shape(original_tensor)
        (
            sparse_tensor_cutlass,
            meta_tensor_cutlass,
        ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
        return cls(
            original_tensor.shape,
            packed=sparse_tensor_cutlass,
            meta=meta_tensor_cutlass,
            packed_t=None,
            meta_t=None,
            threads_masks=None,
            requires_grad=original_tensor.requires_grad,
        )

    def to_dense(self):
        assert self.meta is not None and self.packed is not None
        return (
            sparse_semi_structured_to_dense_cutlass(
                self.packed,
                self.meta,
            )
            if self.meta.ndim == 2
            else super().to_dense()
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        cls_name = self.__class__.__name__
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{cls_name}` matmul: Broadcasting is not implemented"
            )
        if self.packed is None or self.meta is None:
            raise NotImplementedError(
                f"`{cls_name}` matmul: operation is not supported"
            )
        else:
            # The underlying op is applied as a linear layer with the sparse operand as
            # the weight, so B is transposed on the way in and the result is transposed
            # back to produce A @ B.
            res = torch._sparse_semi_structured_linear(
                B.t(), self.packed, self.meta, bias=bias
            ).t()
            return res[: self.shape[0]]


class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """
    This class implements semi-structured sparsity for the cuSPARSELt backend.

    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
    packed = [ specified elements of original tensor | metadata ]
    For an original tensor of size (m, k), we expect the first m * k // 2 elements to be the kept elements;
    the rest of the tensor is metadata. Since there is only one tensor per orientation, only the packed and
    packed_t attributes are used.

    cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
    as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
    """
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUSPARSELT":
        cls._validate_device_dim_dtype_shape(original_tensor)
        return cls(
            shape=original_tensor.shape,
            packed=torch._cslt_compress(original_tensor),
            meta=None,
            packed_t=None,
            meta_t=None,
            threads_masks=None,
            fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
            alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
            requires_grad=original_tensor.requires_grad,
        )

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
            )
        if B.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
                f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
                "This operation is only supported when A and B have the same data type."
            )
        if bias is not None and bias.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
                f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. "
                "This operation is only supported when A, B and C have the same data type."
            )
        if self.packed is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )
        else:
            res = torch._cslt_sparse_mm(
                self.packed,
                B,
                bias=bias,
                transpose_result=self.fuse_transpose_cusparselt,
                alg_id=self.alg_id_cusparselt,
            )
            return res.t() if self.fuse_transpose_cusparselt else res
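

# Hedged end-to-end sketch (not executed here; the layer sizes and the 2:4 mask below are
# only illustrative): accelerating an inference nn.Linear typically means zeroing 2 out of
# every 4 elements along each weight row, then swapping the dense weight for its compressed
# form so the sparse kernels are used by the dispatched aten.linear:
#
#   linear = torch.nn.Linear(10240, 3072).half().cuda().eval()
#   mask = torch.Tensor([0, 0, 1, 1]).tile((3072, 2560)).cuda().bool()   # hypothetical 2:4 mask
#   linear.weight = torch.nn.Parameter(to_sparse_semi_structured(linear.weight * mask))
#   out = linear(torch.rand(3072, 10240, dtype=torch.float16, device="cuda"))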