danieldk (HF Staff) committed
Commit 150f8c2 · Parent: daf6c87

Build (aarch64-linux)

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. build/torch28-cxx11-cu126-aarch64-linux/quantization/__init__.py +48 -0
  2. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/__init__.cpython-313.pyc +0 -0
  3. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/_ops.cpython-313.pyc +0 -0
  4. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/compressed_tensors.cpython-313.pyc +0 -0
  5. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/cutlass.cpython-313.pyc +0 -0
  6. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/marlin.cpython-313.pyc +0 -0
  7. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/platforms.cpython-313.pyc +0 -0
  8. build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/scalar_type.cpython-313.pyc +0 -0
  9. build/torch28-cxx11-cu126-aarch64-linux/quantization/_ops.py +9 -0
  10. build/torch28-cxx11-cu126-aarch64-linux/quantization/_quantization_eabe7c2.abi3.so +3 -0
  11. build/torch28-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py +113 -0
  12. build/torch28-cxx11-cu126-aarch64-linux/quantization/cutlass.py +69 -0
  13. build/torch28-cxx11-cu126-aarch64-linux/quantization/marlin.py +174 -0
  14. build/torch28-cxx11-cu126-aarch64-linux/quantization/platforms.py +104 -0
  15. build/torch28-cxx11-cu126-aarch64-linux/quantization/scalar_type.py +347 -0
  16. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__init__.py +0 -0
  17. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/__init__.cpython-313.pyc +0 -0
  18. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils.cpython-313.pyc +0 -0
  19. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp4.cpython-313.pyc +0 -0
  20. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp8.cpython-313.pyc +0 -0
  21. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/quant_utils.cpython-313.pyc +0 -0
  22. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py +451 -0
  23. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp4.py +281 -0
  24. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py +122 -0
  25. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test.py +161 -0
  26. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test_24.py +472 -0
  27. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test_qqq.py +125 -0
  28. build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/quant_utils.py +470 -0
  29. build/torch28-cxx11-cu128-aarch64-linux/quantization/__init__.py +48 -0
  30. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/__init__.cpython-313.pyc +0 -0
  31. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/_ops.cpython-313.pyc +0 -0
  32. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/compressed_tensors.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/cutlass.cpython-313.pyc +0 -0
  34. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/marlin.cpython-313.pyc +0 -0
  35. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/platforms.cpython-313.pyc +0 -0
  36. build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/scalar_type.cpython-313.pyc +0 -0
  37. build/torch28-cxx11-cu128-aarch64-linux/quantization/_ops.py +9 -0
  38. build/torch28-cxx11-cu128-aarch64-linux/quantization/_quantization_eabe7c2.abi3.so +3 -0
  39. build/torch28-cxx11-cu128-aarch64-linux/quantization/compressed_tensors.py +113 -0
  40. build/torch28-cxx11-cu128-aarch64-linux/quantization/cutlass.py +69 -0
  41. build/torch28-cxx11-cu128-aarch64-linux/quantization/marlin.py +174 -0
  42. build/torch28-cxx11-cu128-aarch64-linux/quantization/platforms.py +104 -0
  43. build/torch28-cxx11-cu128-aarch64-linux/quantization/scalar_type.py +347 -0
  44. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__init__.py +0 -0
  45. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/__init__.cpython-313.pyc +0 -0
  46. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils.cpython-313.pyc +0 -0
  47. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp4.cpython-313.pyc +0 -0
  48. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp8.cpython-313.pyc +0 -0
  49. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/quant_utils.cpython-313.pyc +0 -0
  50. build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils.py +451 -0
build/torch28-cxx11-cu126-aarch64-linux/quantization/__init__.py ADDED
@@ -0,0 +1,48 @@
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
+ from .cutlass import (
+     cutlass_scaled_mm_supports_block_fp8,
+     cutlass_scaled_mm_supports_fp8,
+     cutlass_scaled_mm,
+     cutlass_scaled_mm_azp,
+ )
+ from .marlin import (
+     awq_marlin_repack,
+     gptq_marlin_gemm,
+     gptq_marlin_repack,
+     gptq_marlin_24_gemm,
+     marlin_qqq_gemm,
+     marlin_gemm,
+ )
+ from .scalar_type import (
+     ScalarType,
+     scalar_types,
+ )
+ from ._ops import ops
+
+ from .utils import marlin_utils
+ from .utils import marlin_utils_fp4
+ from .utils import marlin_utils_fp8
+ from .utils import quant_utils
+
+
+ __all__ = [
+     "ScalarType",
+     "awq_marlin_repack",
+     "cutlass_scaled_mm",
+     "cutlass_scaled_mm_azp",
+     "cutlass_scaled_mm_supports_block_fp8",
+     "cutlass_scaled_mm_supports_fp8",
+     "gptq_marlin_24_gemm",
+     "gptq_marlin_gemm",
+     "gptq_marlin_repack",
+     "marlin_gemm",
+     "marlin_qqq_gemm",
+     "marlin_utils",
+     "marlin_utils_fp4",
+     "marlin_utils_fp8",
+     "ops",
+     "quant_utils",
+     "scalar_types",
+     "scaled_fp8_quant",
+     "scaled_int8_quant",
+ ]
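
For orientation: this module only re-exports the compiled kernel bindings, so downstream code imports everything from the package root. A minimal sketch, assuming the wheel is importable as `quantization` (the name is inferred from the build path and may differ):

    # Hedged sketch; the package name "quantization" is assumed from the build path.
    from quantization import ScalarType, scalar_types, ops

    print(scalar_types.uint4b8)        # ScalarType.uint4b8, a GPTQ-style 4-bit type with bias 8
    print(scalar_types.float4_e2m1f)   # ScalarType.float4_e2m1f, the OCP MX FP4 format
    print(isinstance(scalar_types.uint4b8, ScalarType))  # True
    # `ops` is the torch.ops namespace backed by the bundled shared object
    # (here `_quantization_eabe7c2`), e.g. ops.gptq_marlin_gemm.
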
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.02 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (533 Bytes).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/compressed_tensors.cpython-313.pyc ADDED
Binary file (5.2 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/cutlass.cpython-313.pyc ADDED
Binary file (3.87 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/marlin.cpython-313.pyc ADDED
Binary file (7.84 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/platforms.cpython-313.pyc ADDED
Binary file (5.8 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/__pycache__/scalar_type.cpython-313.pyc ADDED
Binary file (14.2 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _quantization_eabe7c2
+ ops = torch.ops._quantization_eabe7c2
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_quantization_eabe7c2::{op_name}"
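
For illustration, the helper above just prepends the build-specific op namespace used for torch.library registration. A tiny hedged sketch (module path assumed from the build layout):

    from quantization._ops import ops, add_op_namespace_prefix  # path assumed

    print(add_op_namespace_prefix("gptq_marlin_gemm"))
    # -> "_quantization_eabe7c2::gptq_marlin_gemm"
    # `ops` is torch.ops._quantization_eabe7c2, i.e. the raw custom-op namespace.
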
build/torch28-cxx11-cu126-aarch64-linux/quantization/_quantization_eabe7c2.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7914703f365b7ffaf34d6ea723d2ab56b1ba91e32f591115d8022ccfad26eb7
+ size 159926048
build/torch28-cxx11-cu126-aarch64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,113 @@
+ from typing import Optional, Union
+
+ import torch
+
+ from ._ops import ops
+ from .platforms import current_platform
+
+
+ # fp8
+ def scaled_fp8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     num_token_padding: Optional[int] = None,
+     scale_ub: Optional[torch.Tensor] = None,
+     use_per_token_if_dynamic: bool = False,
+     output: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Quantize input tensor to FP8 and return quantized tensor and scale.
+
+     This function supports both static and dynamic quantization: If you
+     provide the scale, it will use static scaling and if you omit it,
+     the scale will be determined dynamically. The function also allows
+     optional padding of the output tensors for downstream kernels that
+     will benefit from padding.
+
+     Args:
+         input: The input tensor to be quantized to FP8
+         scale: Optional scaling factor for the FP8 quantization
+         scale_ub: Optional upper bound for scaling factor in dynamic
+             per token case
+         num_token_padding: If specified, pad the first dimension
+             of the output to at least this value.
+         use_per_token_if_dynamic: Whether to do per_tensor or per_token
+             in the dynamic quantization case.
+
+     Returns:
+         tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+             scaling factor.
+     """
+     # This code assumes batch_dim and num_tokens are flattened
+     assert (input.ndim == 2)
+     shape: Union[tuple[int, int], torch.Size] = input.shape
+     # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
+     out_dtype: torch.dtype = current_platform.fp8_dtype()
+     if num_token_padding:
+         shape = (max(num_token_padding, input.shape[0]), shape[1])
+     if output is None:
+         output = torch.empty(shape, device=input.device, dtype=out_dtype)
+     else:
+         assert num_token_padding is None, \
+             "padding not supported if output passed in"
+         assert output.dtype == out_dtype
+
+     if scale is None:
+         if use_per_token_if_dynamic:
+             scale = torch.empty((shape[0], 1),
+                                 device=input.device,
+                                 dtype=torch.float32)
+             ops.dynamic_per_token_scaled_fp8_quant(
+                 output, input.contiguous(), scale, scale_ub)
+         else:
+             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+             ops.dynamic_scaled_fp8_quant(output, input, scale)
+     else:
+         # num_token_padding not implemented for this case
+         assert (scale.numel() == 1 and num_token_padding is None)
+         ops.static_scaled_fp8_quant(output, input, scale)
+
+     return output, scale
+
+
+ # int8
+ def scaled_int8_quant(
+     input: torch.Tensor,
+     scale: Optional[torch.Tensor] = None,
+     azp: Optional[torch.Tensor] = None,
+     symmetric: bool = True
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+     """
+     Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
+
+     Args:
+         input: The input tensor to be quantized to int8.
+         scale: Optional scaling factor for the int8 quantization.
+             When not provided, we invoke dynamic-per-token quantization.
+         azp: Optional zero-point for the int8 quantization.
+             Must be provided for asymmetric quantization if `scale` is provided.
+         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+     Returns:
+         tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
+     """
+     output = torch.empty_like(input, dtype=torch.int8)
+     if scale is not None:
+         # static-per-tensor quantization.
+         assert symmetric == (
+             azp
+             is None), "azp must only be provided for asymmetric quantization."
+         ops.static_scaled_int8_quant(output, input, scale, azp)
+         return output, scale, azp
+
+     # dynamic-per-token quantization.
+     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+     input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                         dtype=torch.int32)
+     ops.dynamic_scaled_int8_quant(output, input.contiguous(),
+                                   input_scales, input_azp)
+     return output, input_scales, input_azp
+
+
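
A hedged usage sketch of the two quantization entry points defined above (requires a CUDA-capable GPU with the extension loaded; the importable package name `quantization` is assumed from the build path):

    import torch
    from quantization import scaled_fp8_quant, scaled_int8_quant  # name assumed

    x = torch.randn(32, 4096, device="cuda", dtype=torch.float16)

    # Dynamic per-tensor FP8: the kernel computes a single scale.
    x_fp8, scale = scaled_fp8_quant(x)                                          # scale: (1,)

    # Dynamic per-token FP8: one scale per row of the flattened input.
    x_fp8_tok, tok_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=True)  # (32, 1)

    # Symmetric dynamic per-token INT8: azp comes back as None.
    x_i8, i8_scales, azp = scaled_int8_quant(x)
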
build/torch28-cxx11-cu126-aarch64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,69 @@
+ from typing import Optional
+
+ import torch
+
+ from ._ops import ops
+ from .platforms import current_platform
+
+
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+     return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
+ def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
+     return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
+
+
+ def cutlass_scaled_mm(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale_a: torch.Tensor,
+     scale_b: torch.Tensor,
+     out_dtype: torch.dtype,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+     assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+     assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
+
+     m = a.shape[0]
+     n = b.shape[1]
+
+     cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+     if not cutlass_compatible_b:
+         from .triton_scaled_mm import triton_scaled_mm
+         return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+     ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+
+     return out
+
+
+ def cutlass_scaled_mm_azp(
+     a: torch.Tensor,
+     b: torch.Tensor,
+     scale_a: torch.Tensor,
+     scale_b: torch.Tensor,
+     out_dtype: torch.dtype,
+     azp_adj: torch.Tensor,
+     azp: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     """
+     :param azp_adj: In the per-tensor case, this should include the azp.
+         Always per-channel.
+     :param azp: Only set in the per-token case. Per-token if set.
+     """
+     assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
+     assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+     assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
+     assert azp is None or azp.numel() == a.shape[0]
+
+     m = a.shape[0]
+     n = b.shape[1]
+     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+     ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
+     return out
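
A hedged sketch of driving the scaled GEMM above with FP8 inputs (assumes an FP8-capable GPU and the assumed package name; passing the weight as `w_q.t()` mirrors how vLLM-style callers supply a column-major B, which is an assumption here, not something the wrapper checks):

    import torch
    from quantization import (cutlass_scaled_mm,
                              cutlass_scaled_mm_supports_fp8,
                              scaled_fp8_quant)  # package name assumed

    major, minor = torch.cuda.get_device_capability()
    if cutlass_scaled_mm_supports_fp8(major * 10 + minor):
        a = torch.randn(128, 4096, device="cuda", dtype=torch.float16)
        w = torch.randn(8192, 4096, device="cuda", dtype=torch.float16)

        a_q, a_scale = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
        w_q, w_scale = scaled_fp8_quant(w)

        out = cutlass_scaled_mm(a_q, w_q.t(), a_scale, w_scale,
                                out_dtype=torch.float16)
        print(out.shape)  # torch.Size([128, 8192])
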
build/torch28-cxx11-cu126-aarch64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,174 @@
+ from typing import TYPE_CHECKING, Optional
+
+ import torch
+
+ # neuron has torch version that doesn't even have impl_abstract
+ if TYPE_CHECKING:
+     def register_fake(fn):
+         return lambda name: fn
+ else:
+     try:
+         from torch.library import register_fake
+     except ImportError:
+         from torch.library import impl_abstract as register_fake
+
+ try:
+     from ._ops import ops, add_op_namespace_prefix
+ except ImportError as e:
+     # Fallback for local development.
+     try:
+         import _quantization
+
+         ops = torch.ops._quantization
+
+         def add_op_namespace_prefix(op_name: str):
+             return f"_quantization::{op_name}"
+     except ImportError:
+         raise e
+
+
+ from .scalar_type import ScalarType
+
+
+ # gptq_marlin
+ def gptq_marlin_gemm(a: torch.Tensor,
+                      c: Optional[torch.Tensor],
+                      b_q_weight: torch.Tensor,
+                      b_scales: torch.Tensor,
+                      global_scale: Optional[torch.Tensor],
+                      b_zeros: Optional[torch.Tensor],
+                      g_idx: Optional[torch.Tensor],
+                      perm: Optional[torch.Tensor],
+                      workspace: torch.Tensor,
+                      b_q_type: ScalarType,
+                      size_m: int,
+                      size_n: int,
+                      size_k: int,
+                      is_k_full: bool = True,
+                      use_atomic_add: bool = False,
+                      use_fp32_reduce: bool = False,
+                      is_zp_float: bool = False) -> torch.Tensor:
+     return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
+                                 global_scale, b_zeros, g_idx, perm,
+                                 workspace, b_q_type.id, size_m,
+                                 size_n, size_k, is_k_full,
+                                 use_atomic_add, use_fp32_reduce,
+                                 is_zp_float)
+
+ # gptq_marlin
+ def gptq_marlin_repack(
+     b_q_weight: torch.Tensor,
+     perm: torch.Tensor,
+     size_k: int,
+     size_n: int,
+     num_bits: int,
+ ) -> torch.Tensor:
+     return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
+
+
+ # gptq_marlin
+ def awq_marlin_repack(
+     b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+ ) -> torch.Tensor:
+     return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
+
+
+ # marlin
+ def marlin_gemm(
+     a: torch.Tensor,
+     b_q_weight: torch.Tensor,
+     b_scales: torch.Tensor,
+     workspace: torch.Tensor,
+     size_m: int,
+     size_n: int,
+     size_k: int,
+ ) -> torch.Tensor:
+     return ops.marlin_gemm(
+         a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
+     )
+
+
+ # marlin_24
+ def gptq_marlin_24_gemm(
+     a: torch.Tensor,
+     b_q_weight: torch.Tensor,
+     b_meta: torch.Tensor,
+     b_scales: torch.Tensor,
+     workspace: torch.Tensor,
+     b_q_type: ScalarType,
+     size_m: int,
+     size_n: int,
+     size_k: int,
+ ) -> torch.Tensor:
+     return ops.gptq_marlin_24_gemm(
+         a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
+     )
+
+
+ # qqq ops
+ def marlin_qqq_gemm(
+     a: torch.Tensor,
+     b_q_weight: torch.Tensor,
+     s_tok: torch.Tensor,
+     s_ch: torch.Tensor,
+     s_group: torch.Tensor,
+     workspace: torch.Tensor,
+     size_m: int,
+     size_n: int,
+     size_k: int,
+ ) -> torch.Tensor:
+     return ops.marlin_qqq_gemm(
+         a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
+     )
+
+
+ # Fake ops
+
+ if hasattr(ops, "gptq_marlin_24_gemm"):
+     @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
+     def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                                   b_meta: torch.Tensor, b_scales: torch.Tensor,
+                                   workspace: torch.Tensor,
+                                   b_q_type: ScalarType, size_m: torch.SymInt,
+                                   size_n: torch.SymInt,
+                                   size_k: torch.SymInt) -> torch.Tensor:
+         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+     @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
+     def _gptq_marlin_gemm_fake(a: torch.Tensor,
+                                c: Optional[torch.Tensor],
+                                b_q_weight: torch.Tensor,
+                                b_scales: torch.Tensor,
+                                global_scale: Optional[torch.Tensor],
+                                b_zeros: Optional[torch.Tensor],
+                                g_idx: Optional[torch.Tensor],
+                                perm: Optional[torch.Tensor],
+                                workspace: torch.Tensor,
+                                b_q_type_id: int,
+                                size_m: torch.SymInt,
+                                size_n: torch.SymInt,
+                                size_k: torch.SymInt,
+                                is_k_full: bool = True,
+                                use_atomic_add: bool = False,
+                                use_fp32_reduce: bool = False,
+                                is_zp_float: bool = False) -> torch.Tensor:
+         return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+     @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
+     def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                               s_tok: torch.Tensor, s_ch: torch.Tensor,
+                               s_group: torch.Tensor, workspace: torch.Tensor,
+                               size_m: torch.SymInt, size_n: torch.SymInt,
+                               size_k: torch.SymInt) -> torch.Tensor:
+         return torch.empty((size_m, size_n),
+                            dtype=torch.float16,
+                            device=a.device)
+
+     @register_fake(add_op_namespace_prefix("marlin_gemm"))
+     def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                           b_scales: torch.Tensor, workspace: torch.Tensor,
+                           size_m: torch.SymInt, size_n: torch.SymInt,
+                           size_k: torch.SymInt) -> torch.Tensor:
+         return torch.empty((size_m, size_n),
+                            dtype=torch.float16,
+                            device=a.device)
build/torch28-cxx11-cu126-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,104 @@
+ from abc import ABC, abstractmethod
+ from functools import lru_cache
+ from typing import NamedTuple
+
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+
+
+ class DeviceCapability(NamedTuple):
+     major: int
+     minor: int
+
+     def as_version_str(self) -> str:
+         return f"{self.major}.{self.minor}"
+
+     def to_int(self) -> int:
+         """
+         Express device capability as an integer ``<major><minor>``.
+
+         It is assumed that the minor version is always a single digit.
+         """
+         assert 0 <= self.minor < 10
+         return self.major * 10 + self.minor
+
+
+ class Platform(ABC):
+     simple_compile_backend: str = "inductor"
+
+     @classmethod
+     def fp8_dtype(cls) -> torch.dtype:
+         """
+         Returns the preferred FP8 type on the current platform.
+
+         See the documentation for is_fp8_fnuz for details.
+         """
+         return torch.float8_e4m3fn
+
+     @classmethod
+     def is_fp8_fnuz(cls) -> bool:
+         """
+         Returns whether the preferred FP8 type is FNUZ on the current platform.
+
+         There are two representations of FP8, OCP FP8 and FNUZ FP8.
+         The OCP specification can be found at https://tinyurl.com/b7jvwpft.
+         The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
+
+         AMD's MI300 and MI325 have native hardware support for FNUZ. All other
+         hardware has converged on the OCP FP8 standard.
+         """
+         return False
+
+     @classmethod
+     @abstractmethod
+     def get_device_name(cls, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_rocm(self): ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+         major, minor = torch.cuda.get_device_capability(device_id)
+         return DeviceCapability(major=major, minor=minor)
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(0)
+
+     def is_rocm(self):
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     def fp8_dtype(cls) -> torch.dtype:
+         if cls.is_fp8_fnuz():
+             return torch.float8_e4m3fnuz
+         else:
+             return torch.float8_e4m3fn
+
+     @classmethod
+     def is_fp8_fnuz(cls) -> bool:
+         # only device 0 is checked, this assumes MI300 platforms are homogeneous
+         return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+         major, minor = torch.cuda.get_device_capability(device_id)
+         return DeviceCapability(major=major, minor=minor)
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_rocm(self):
+         return True
+
+
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
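
For illustration, a small hedged sketch of how the platform shim above is consumed (needs a visible GPU; the module path is assumed from the build layout):

    from quantization.platforms import current_platform  # path assumed

    print(current_platform.is_rocm())
    # torch.float8_e4m3fn on NVIDIA and non-gfx94 ROCm; the FNUZ variant on MI300-class parts.
    print(current_platform.fp8_dtype())
    print(current_platform.get_device_capability().to_int())  # e.g. 90 on an SM90 GPU
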
build/torch28-cxx11-cu126-aarch64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,347 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import functools
+ import struct
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Optional, Union
+
+ _SCALAR_TYPES_ID_MAP = {}
+
+
+ # Mirrors enum in `core/scalar_type.hpp`
+ class NanRepr(Enum):
+     NONE = 0  # nans are not supported
+     IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+     EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
+ # in sync until the inductor fully supports custom C++ classes.
+ @dataclass(frozen=True)
+ class ScalarType:
+     """
+     ScalarType can represent a wide range of floating point and integer
+     types, in particular it can be used to represent sub-byte data types
+     (something that torch.dtype currently does not support). It is also
+     capable of representing types with a bias, i.e.:
+       `stored_value = value + bias`,
+     this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+     of 8). The implementation for this class can be found in
+     csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+     with that file.
+     """
+
+     exponent: int
+     """
+     Number of bits in the exponent if this is a floating point type
+     (zero if this an integer type)
+     """
+
+     mantissa: int
+     """
+     Number of bits in the mantissa if this is a floating point type,
+     or the number bits representing an integer excluding the sign bit if
+     this an integer type.
+     """
+
+     signed: bool
+     "If the type is signed (i.e. has a sign bit)"
+
+     bias: int
+     """
+     bias used to encode the values in this scalar type
+     (value = stored_value - bias, default 0) for example if we store the
+     type as an unsigned integer with a bias of 128 then the value 0 will be
+     stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+     """
+
+     _finite_values_only: bool = False
+     """
+     Private: if infs are supported, used `has_infs()` instead.
+     """
+
+     nan_repr: NanRepr = NanRepr.IEEE_754
+     """
+     How NaNs are represent in this scalar type, returns NanRepr value.
+     (not applicable for integer types)
+     """
+
+     def _floating_point_max_int(self) -> int:
+         assert (
+             self.mantissa <= 52 and self.exponent <= 11
+         ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+         max_mantissa = (1 << self.mantissa) - 1
+         if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+             max_mantissa = max_mantissa - 1
+
+         max_exponent = (1 << self.exponent) - 2
+         if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                 or self.nan_repr == NanRepr.NONE):
+             assert (
+                 self.exponent < 11
+             ), f"Cannot represent max/min as a double for type {self.__str__()}"
+             max_exponent = max_exponent + 1
+
+         # adjust the exponent to match that of a double
+         # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
+         # e is the exponent bits), there is some precedent for non-standard
+         # biases, example `float8_e4m3b11fnuz` here:
+         # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+         # complication we are just assuming the standard exponent bias until
+         # there is a need to support non-standard biases
+         exponent_bias = (1 << (self.exponent - 1)) - 1
+         exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+         max_exponent_double = (max_exponent - exponent_bias +
+                                exponent_bias_double)
+
+         # shift the mantissa and exponent into the proper positions for an
+         # IEEE double and bitwise-or them together.
+         return (max_mantissa <<
+                 (52 - self.mantissa)) | (max_exponent_double << 52)
+
+     def _floating_point_max(self) -> float:
+         double_raw = self._floating_point_max_int()
+         return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
+
+     def _raw_max(self) -> Union[int, float]:
+         if self.is_floating_point():
+             return self._floating_point_max()
+         else:
+             assert (self.size_bits < 64 or self.size_bits == 64
+                     and self.is_signed()), "Cannot represent max as an int"
+             return (1 << self.mantissa) - 1
+
+     def _raw_min(self) -> Union[int, float]:
+         if self.is_floating_point():
+             assert self.is_signed(
+             ), "We currently assume all floating point types are signed"
+             sign_bit_double = 1 << 63
+
+             max_raw = self._floating_point_max_int()
+             min_raw = max_raw | sign_bit_double
+             return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
+         else:
+             assert (not self.is_signed() or self.size_bits
+                     <= 64), "Cannot represent min as a int64_t"
+
+             if self.is_signed():
+                 return -(1 << (self.size_bits - 1))
+             else:
+                 return 0
+
+     @functools.cached_property
+     def id(self) -> int:
+         """
+         Convert the ScalarType to an int which can be passed to pytorch custom
+         ops. This layout of the int must be kept in sync with the C++
+         ScalarType's from_id method.
+         """
+         val = 0
+         offset = 0
+
+         def or_and_advance(member, bit_width):
+             nonlocal val
+             nonlocal offset
+             bit_mask = (1 << bit_width) - 1
+             val = val | (int(member) & bit_mask) << offset
+             offset = offset + bit_width
+
+         or_and_advance(self.exponent, 8)
+         or_and_advance(self.mantissa, 8)
+         or_and_advance(self.signed, 1)
+         or_and_advance(self.bias, 32)
+         or_and_advance(self._finite_values_only, 1)
+         or_and_advance(self.nan_repr.value, 8)
+
+         assert offset <= 64, \
+             f"ScalarType fields too big {offset} to fit into an int64"
+
+         _SCALAR_TYPES_ID_MAP[val] = self
+
+         return val
+
+     @property
+     def size_bits(self) -> int:
+         return self.exponent + self.mantissa + int(self.signed)
+
+     def min(self) -> Union[int, float]:
+         """
+         Min representable value for this scalar type.
+         (accounting for bias if there is one)
+         """
+         return self._raw_min() - self.bias
+
+     def max(self) -> Union[int, float]:
+         """
+         Max representable value for this scalar type.
+         (accounting for bias if there is one)
+         """
+         return self._raw_max() - self.bias
+
+     def is_signed(self) -> bool:
+         """
+         If the type is signed (i.e. has a sign bit), same as `signed`
+         added for consistency with:
+         https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+         """
+         return self.signed
+
+     def is_floating_point(self) -> bool:
+         "If the type is a floating point type"
+         return self.exponent != 0
+
+     def is_integer(self) -> bool:
+         "If the type is an integer type"
+         return self.exponent == 0
+
+     def has_bias(self) -> bool:
+         "If the type has a non-zero bias"
+         return self.bias != 0
+
+     def has_infs(self) -> bool:
+         "If the type is floating point and supports infinity"
+         return not self._finite_values_only
+
+     def has_nans(self) -> bool:
+         return self.nan_repr != NanRepr.NONE.value
+
+     def is_ieee_754(self) -> bool:
+         """
+         If the type is a floating point type that follows IEEE 754
+         conventions
+         """
+         return self.nan_repr == NanRepr.IEEE_754.value and \
+             not self._finite_values_only
+
+     def __str__(self) -> str:
+         """
+         naming generally follows: https://github.com/jax-ml/ml_dtypes
+         for floating point types (leading f) the scheme is:
+         `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+         flags:
+           - no-flags: means it follows IEEE 754 conventions
+           - f: means finite values only (no infinities)
+           - n: means nans are supported (non-standard encoding)
+         for integer types the scheme is:
+         `[u]int<size_bits>[b<bias>]`
+           - if bias is not present it means its zero
+         """
+         if self.is_floating_point():
+             ret = "float" + str(self.size_bits) + "_e" + str(
+                 self.exponent) + "m" + str(self.mantissa)
+
+             if not self.is_ieee_754():
+                 if self._finite_values_only:
+                     ret = ret + "f"
+                 if self.nan_repr != NanRepr.NONE:
+                     ret = ret + "n"
+
+             return ret
+         else:
+             ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+             if self.has_bias():
+                 ret = ret + "b" + str(self.bias)
+             return ret
+
+     def __repr__(self) -> str:
+         return "ScalarType." + self.__str__()
+
+     # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+     # opcheck to work.
+     def __len__(self) -> int:
+         raise TypeError
+
+     #
+     # Convenience Constructors
+     #
+
+     @classmethod
+     def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+         "Create a signed integer scalar type (size_bits includes sign-bit)."
+         ret = cls(0, size_bits - 1, True, bias if bias else 0)
+         ret.id  # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+         """Create a unsigned integer scalar type."""
+         ret = cls(0, size_bits, False, bias if bias else 0)
+         ret.id  # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+         """
+         Create a standard floating point type
+         (i.e. follows IEEE 754 conventions).
+         """
+         assert (mantissa > 0 and exponent > 0)
+         ret = cls(exponent, mantissa, True, 0)
+         ret.id  # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+                nan_repr: NanRepr) -> 'ScalarType':
+         """
+         Create a non-standard floating point type
+         (i.e. does not follow IEEE 754 conventions).
+         """
+         assert (mantissa > 0 and exponent > 0)
+         assert (nan_repr != NanRepr.IEEE_754), (
+             "use `float_IEEE754` constructor for floating point types that "
+             "follow IEEE 754 conventions")
+         ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
+         ret.id  # noqa B018: make sure the id is cached
+         return ret
+
+     @classmethod
+     def from_id(cls, scalar_type_id: int):
+         if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
+             raise ValueError(
+                 f"scalar_type_id {scalar_type_id} doesn't exists.")
+         return _SCALAR_TYPES_ID_MAP[scalar_type_id]
+
+
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
+ # for floating point types (leading f) the scheme is:
+ #   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+ # flags:
+ #   - no-flags: means it follows IEEE 754 conventions
+ #   - f: means finite values only (no infinities)
+ #   - n: means nans are supported (non-standard encoding)
+ # for integer types the scheme is:
+ #   `[u]int<size_bits>[b<bias>]`
+ #   - if bias is not present it means its zero
+
+
+ class scalar_types:
+     int4 = ScalarType.int_(4, None)
+     uint4 = ScalarType.uint(4, None)
+     int8 = ScalarType.int_(8, None)
+     uint8 = ScalarType.uint(8, None)
+     float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
+     float8_e5m2 = ScalarType.float_IEEE754(5, 2)
+     float16_e8m7 = ScalarType.float_IEEE754(8, 7)
+     float16_e5m10 = ScalarType.float_IEEE754(5, 10)
+
+     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
+     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
+
+     # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+     float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
+     # "gptq" types
+     uint2b2 = ScalarType.uint(2, 2)
+     uint3b4 = ScalarType.uint(3, 4)
+     uint4b8 = ScalarType.uint(4, 8)
+     uint8b128 = ScalarType.uint(8, 128)
+
+     # colloquial names
+     bfloat16 = float16_e8m7
+     float16 = float16_e5m10
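
A quick hedged illustration of the scalar-type algebra above; this part is pure Python and needs no GPU (module path assumed from the build layout):

    from quantization.scalar_type import ScalarType, scalar_types  # path assumed

    t = scalar_types.uint4b8                       # GPTQ 4-bit: stored_value = value + 8
    print(t.size_bits, t.min(), t.max())           # 4 -8 7
    print(t.is_integer(), t.has_bias())            # True True

    f = scalar_types.float8_e4m3fn
    print(f, f.is_floating_point(), f.has_infs())  # float8_e4m3fn True False

    # The packed integer id round-trips through the custom-op boundary.
    assert ScalarType.from_id(t.id) is t
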
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__init__.py ADDED
File without changes
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (186 Bytes).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils.cpython-313.pyc ADDED
Binary file (17.7 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp4.cpython-313.pyc ADDED
Binary file (11.8 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp8.cpython-313.pyc ADDED
Binary file (5.29 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/__pycache__/quant_utils.cpython-313.pyc ADDED
Binary file (19.9 kB).
 
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils.py ADDED
@@ -0,0 +1,451 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ from typing import Optional
+
+ import numpy
+ import torch
+
+ from .. import ScalarType, gptq_marlin_gemm, scalar_types
+
+ from .quant_utils import pack_cols, unpack_cols
+
+ GPTQ_MARLIN_TILE = 16
+ GPTQ_MARLIN_MIN_THREAD_N = 64
+ GPTQ_MARLIN_MIN_THREAD_K = 128
+ GPTQ_MARLIN_MAX_PARALLEL = 16
+
+ GPTQ_MARLIN_24_TILE = 16
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
+
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+ MARLIN_QQQ_TILE = 16
+ MARLIN_QQQ_MIN_THREAD_N = 64
+ MARLIN_QQQ_MIN_THREAD_K = 128
+ MARLIN_QQQ_MAX_PARALLEL = 16
+
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
+
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+ # In case there is a performance issue with Marlin, the variable below can be
+ # changed to False, which allows Marlin to perform global reductions in fp16
+ # precision (instead of fp32), and therefore, save on some memory movements.
+ USE_FP32_REDUCE_DEFAULT = True
+
+
+ # For binary size and compile time, we don't support the same types for with and
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
+ # TODO: we may want to move this into the C++ so its closer to the actual impl
+ def query_marlin_supported_quant_types(
+     has_zp: Optional[bool] = None,
+     include_fp_type: bool = True,
+     device_capability: Optional[int] = None,
+ ):
+     if device_capability is None:
+         capability_tuple = torch.cuda.get_device_capability()
+         device_capability = capability_tuple[0] * 10 + capability_tuple[1]
+
+     if device_capability < 80:
+         return []
+
+     # - has_zp is True: return quant_types that has zero points
+     # - has_zp is False: return quant_types that has not zero points
+     # - has_zp is None: both
+     if has_zp is None:
+         types0 = query_marlin_supported_quant_types(False, include_fp_type,
+                                                     device_capability)
+         types1 = query_marlin_supported_quant_types(True, include_fp_type,
+                                                     device_capability)
+         return types0 + types1
+
+     if has_zp:
+         # AWQ style, unsigned + runtime zero-point
+         return [scalar_types.uint4]
+     else:
+         # GPTQ style, unsigned + symmetric bias
+         res = [scalar_types.uint4b8, scalar_types.uint8b128]
+         if include_fp_type:
+             res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
+         return res
+
+
+ def _check_marlin_supported(
+         quant_type: ScalarType,
+         group_size: Optional[int],
+         has_zp: bool,
+         device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
+
+     if device_capability is None:
+         capability_tuple = torch.cuda.get_device_capability()
+         device_capability = capability_tuple[0] * 10 + capability_tuple[1]
+
+     supported_types = query_marlin_supported_quant_types(
+         has_zp, True, device_capability)
+
+     if quant_type not in supported_types:
+         return (False, f"Marlin does not support weight_bits = {quant_type}. "
+                 f"Only types = {supported_types} "
+                 f"are supported (for group_size = {group_size}, "
+                 f"device_capability = {device_capability}, zp = {has_zp}).")
+     if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
+         return (False, f"Marlin does not support group_size = {group_size}. "
+                 f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+                 "are supported.")
+
+     return True, None
+
+
+ def check_marlin_supported(quant_type: ScalarType,
+                            group_size: int,
+                            has_zp: bool = False,
+                            device_capability: Optional[int] = None) -> bool:
+     cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
+                                       device_capability)
+     return cond
+
+
+ def verify_marlin_supported(quant_type: ScalarType,
+                             group_size: int,
+                             has_zp: bool = False) -> None:
+     cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
+     if not cond:
+         assert err_msg is not None
+         raise ValueError(err_msg)
+
+
+ def verify_marlin_supports_shape(output_size_per_partition: int,
+                                  input_size_per_partition: int,
+                                  input_size: int, group_size: int) -> None:
+
+     # Validate output_size_per_partition
+     if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
+         raise ValueError(f"Weight output_size_per_partition = "
+                          f"{output_size_per_partition} is not divisible by "
+                          f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+                          "Consider reducing tensor_parallel_size or running "
+                          "with --quantization gptq.")
+
+     # Validate input_size_per_partition
+     if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
+         raise ValueError(f"Weight input_size_per_partition = "
+                          f"{input_size_per_partition} is not divisible "
+                          f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+                          "Consider reducing tensor_parallel_size or running "
+                          "with --quantization gptq.")
+
+     if (group_size < input_size
+             and input_size_per_partition % group_size != 0):
+         raise ValueError(
+             f"Weight input_size_per_partition = {input_size_per_partition}"
+             f" is not divisible by group_size = {group_size}. "
+             "Consider reducing tensor_parallel_size or running "
+             "with --quantization gptq.")
+
+
+ def check_marlin_supports_shape(output_size_per_partition: int,
+                                 input_size_per_partition: int,
+                                 input_size: int, group_size: int) \
+                                     -> tuple[bool, Optional[str]]:
+     try:
+         verify_marlin_supports_shape(output_size_per_partition,
+                                      input_size_per_partition, input_size,
+                                      group_size)
+     except ValueError as e:
+         return False, e.__str__()
+     return True, None
+
+
+ def marlin_make_workspace(output_size_per_partition: int,
+                           device: torch.device) -> torch.Tensor:
+     max_workspace_size = (output_size_per_partition //
+                           GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+
+     return torch.zeros(max_workspace_size,
+                        dtype=torch.int,
+                        device=device,
+                        requires_grad=False)
+
+
+ def marlin_make_workspace_new(device: torch.device,
+                               max_blocks_per_sm: int = 1) -> torch.Tensor:
+     # In the new marlin kernel, we use the num of threadblocks as workspace
+     # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
+     sms = torch.cuda.get_device_properties(device).multi_processor_count
+     return torch.zeros(sms * max_blocks_per_sm,
+                        dtype=torch.int,
+                        device=device,
+                        requires_grad=False)
+
+
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+     return (not act_order) or (act_order and not is_row_parallel)
+
+
+ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+                                       is_row_parallel: bool) -> bool:
+     # Need to repeat scales on every rank if act_ordering or
+     # channelwise and RowParallelLinear
+     is_channelwise = group_size == -1
+     return act_order or (is_channelwise and is_row_parallel)
+
+
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
+     return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                               requires_grad=False)
+
+
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
+     return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
+                               requires_grad=False)
+
+
+ def marlin_sort_g_idx(
+         g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
+     return g_idx[g_idx_sort_indices], g_idx_sort_indices
+
+
+ def get_scale_perms():
+     scale_perm: list[int] = []
+     for i in range(8):
+         scale_perm.extend([i + 8 * j for j in range(8)])
+     scale_perm_single: list[int] = []
+     for i in range(4):
+         scale_perm_single.extend(
+             [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+     return scale_perm, scale_perm_single
+
+
+ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
+                           group_size: int) -> torch.Tensor:
+
+     scale_perm, scale_perm_single = get_scale_perms()
+     if group_size < size_k and group_size != -1:
+         s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+     else:
+         s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+     s = s.reshape((-1, size_n)).contiguous()
+
+     return s
+
+
+ def marlin_moe_permute_scales(
+     s: torch.Tensor,
+     size_k: int,
+     size_n: int,
+     group_size: int,
+ ):
+     num_experts = s.shape[0]
+     output = torch.empty(
+         (num_experts, s.shape[1], s.shape[2]),
+         device=s.device,
+         dtype=s.dtype,
+     )
+
+     for e in range(num_experts):
+         output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
+     return output
+
+
+ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
+                        num_bits: int) -> torch.Tensor:
+     # Permute zero-points in a similar way to scales, but do not use the
+     # "single" permutation, since zero-points are applied on every MMA
+     scale_perm, _ = get_scale_perms()
+     zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
+
+     # Interleave column dim (for the dequantize code) and pack it to int32
+     if num_bits == 4:
+         interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+     elif num_bits == 8:
+         interleave = numpy.array([0, 2, 1, 3])
+     else:
+         raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+     zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+     zp = zp.reshape((-1, size_n)).contiguous()
+     zp = pack_cols(zp, num_bits, size_k, size_n)
+
+     return zp
+
+
+ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
+                               size_n: int, num_bits: int) -> torch.Tensor:
+     # AWQ zero-points are quantized and packed on the column dim.
+     # In addition, the values are permuted based on dequantizer.
+     # Here we undo both of these, and then apply marlin permutation
+     # and pack it back.
+     q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+     # Undo interleaving (use argsort(..) to get inverse perm)
+     if num_bits == 4:
+         undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+     elif num_bits == 8:
+         undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+     else:
+         raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+     q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+     q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+     marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+     return marlin_zp
+
+
+ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
+                                   size_n: int, num_bits: int):
+     num_experts = q_zp_packed.shape[0]
+     output = torch.empty(
+         (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
+         device=q_zp_packed.device,
+         dtype=q_zp_packed.dtype,
+     )
+     for e in range(num_experts):
+         output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
+                                               num_bits)
+     return output
+
+
+ def maybe_warn_marlin_atomic_add(device, dtype):
+     if torch.compiler.is_dynamo_compiling():
+         return
+     device_capability = torch.cuda.get_device_capability(device)
+     if device_capability[0] < 9 and dtype == torch.bfloat16:
+         logger.info_once(
+             "You are running Marlin kernel with bf16 on GPUs before SM90. "
+             "You can consider change to fp16 to achieve better performance "
+             "if possible.")
+
+
+ def maybe_warn_marlin_atomic_add_env():
+     if torch.compiler.is_dynamo_compiling():
+         return
+     if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+         return
+     logger.info_once(
+         "Marlin kernel can achieve better performance for small size_n "
+         "with experimental use_atomic_add feature. "
+         "You can consider set environment variable "
+         "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
+
+
+ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
+                                  dtype: torch.dtype) -> bool:
+
+     # the performance of atomicAdd is better than global reduce
+     # only when m*n is small and k is large
+     if n >= 2048 or k < 2048 or device.type != "cuda":
+         return False
+
+     # disable atomicAdd reduce by default,
+     # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
+     if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+         maybe_warn_marlin_atomic_add_env()
+         return False
+
+     # sm8x doesn't support atomicAdd + bfloat16 natively
+     device_capability = torch.cuda.get_device_capability(device)
+     if device_capability[0] < 9 and dtype == torch.bfloat16:
+         maybe_warn_marlin_atomic_add(device, dtype)
+         return False
+
+     return True
+
+
+ def apply_gptq_marlin_linear(
+         input: torch.Tensor,
+         weight: torch.Tensor,
+         weight_scale: torch.Tensor,
+         weight_zp: torch.Tensor,
+         g_idx: torch.Tensor,
+         g_idx_sort_indices: torch.Tensor,
+         workspace: torch.Tensor,
+         wtype: ScalarType,
+         output_size_per_partition: int,
+         input_size_per_partition: int,
+         is_k_full: bool,
+         bias: Optional[torch.Tensor] = None,
+         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
+     reshaped_x = input.reshape(-1, input.shape[-1])
+     out_shape = input.shape[:-1] + (output_size_per_partition, )
+
+     use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                   n=output_size_per_partition,
+                                                   k=reshaped_x.size(1),
+                                                   device=input.device,
+                                                   dtype=input.dtype)
+
+     output = gptq_marlin_gemm(reshaped_x,
+                               None,
+                               weight,
+                               weight_scale,
+                               None,
+                               weight_zp,
+                               g_idx,
+                               g_idx_sort_indices,
+                               workspace,
+                               wtype,
+                               size_m=reshaped_x.shape[0],
+                               size_n=output_size_per_partition,
+                               size_k=input_size_per_partition,
+                               is_k_full=is_k_full,
+                               use_atomic_add=use_atomic_add,
+                               use_fp32_reduce=use_fp32_reduce,
+                               is_zp_float=False)
+
+     if bias is not None:
+         output.add_(bias)  # In-place add
+
+     return output.reshape(out_shape)
+
+
+ def apply_awq_marlin_linear(
+         input: torch.Tensor,
+         weight: torch.Tensor,
+         weight_scale: torch.Tensor,
+         weight_zp: torch.Tensor,
+         g_idx: torch.Tensor,
+         g_idx_sort_indices: torch.Tensor,
+         workspace: torch.Tensor,
+         quant_type: ScalarType,
+         output_size_per_partition: int,
+         input_size_per_partition: int,
+         bias: Optional[torch.Tensor] = None,
+         use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
+     reshaped_x = input.reshape(-1, input.shape[-1])
+     out_shape = input.shape[:-1] + (output_size_per_partition, )
+
+     use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
+                                                   n=output_size_per_partition,
+                                                   k=reshaped_x.size(1),
+                                                   device=input.device,
+                                                   dtype=input.dtype)
+
+     output = gptq_marlin_gemm(reshaped_x,
+                               None,
+                               weight,
+                               weight_scale,
+                               None,
+                               weight_zp,
+                               g_idx,
+                               g_idx_sort_indices,
+                               workspace,
+                               quant_type,
+                               size_m=reshaped_x.shape[0],
+                               size_n=output_size_per_partition,
+                               size_k=input_size_per_partition,
+                               use_atomic_add=use_atomic_add,
+                               use_fp32_reduce=use_fp32_reduce,
+                               is_zp_float=False)
+
+     if bias is not None:
+         output.add_(bias)  # In-place add
+
+     return output.reshape(out_shape)
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ from .. import gptq_marlin_gemm, gptq_marlin_repack
9
+ from .marlin_utils import (
10
+ USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
11
+ should_use_atomic_add_reduce)
12
+ from ..scalar_type import scalar_types
13
+
14
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
15
+
16
+
17
+ def is_fp4_marlin_supported():
18
+ capability = torch.cuda.get_device_capability()
19
+ capability = capability[0] * 10 + capability[1]
20
+ return capability >= 80
21
+
22
+
23
+ def fp4_marlin_process_scales(marlin_scales):
24
+ if not (marlin_scales >= 0).all():
25
+ logger.warning_once(
26
+ "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
27
+ "negative scales. Accuracy will likely be degraded. This is "
28
+ "because it changes the scales from FP8-S1E4M3 to a special "
29
+ "FP8-S0E5M3 format to speedup the dequantization.")
30
+
31
+ # convert to half first, we would convert to fp8 later
32
+ marlin_scales = marlin_scales.to(torch.half)
33
+
34
+ # 8 is the number of scale number using by one thread
35
+ marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
36
+ marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
37
+ marlin_scales.size(0) * 2, -1)
38
+
39
+ # fit the layout of fp8 dequantization
40
+ marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
41
+ marlin_scales.size(0), -1)
42
+
43
+ # We assume that weight_scale (FP8-S1E4M3) is always greater
44
+ # than or equal to 0. So we can convert
45
+ # weight_scale * (2 ** 7) to a special FP8-S0E5M3 format.
46
+ # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1
47
+ # when weight_scale > 0. This allows us to have an exponent bias
48
+ # closer to zero after dequantization.
49
+
50
+ marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
51
+ marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
52
+ marlin_scales = marlin_scales[:, 1::2].contiguous()
53
+
54
+ return marlin_scales
55
+
56
+
57
+ def fp4_marlin_process_global_scale(global_scale):
58
+ assert global_scale.dtype in [torch.half, torch.bfloat16]
59
+ fp4_exponent = 2
60
+ if global_scale.dtype == torch.half:
61
+ target_exponent = 5
62
+ elif global_scale.dtype == torch.bfloat16:
63
+ target_exponent = 8
64
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14
65
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126
66
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1)
67
+ return global_scale * (2.0**(exponent_bias - 7))
68
+
69
+
70
+ def apply_fp4_marlin_linear(
71
+ input: torch.Tensor,
72
+ weight: torch.Tensor,
73
+ weight_scale: torch.Tensor,
74
+ weight_scale_2: torch.Tensor,
75
+ workspace: torch.Tensor,
76
+ size_n: int,
77
+ size_k: int,
78
+ bias: Optional[torch.Tensor] = None,
79
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
80
+ # For GPUs that lack FP4 hardware support, we can leverage the
81
+ # Marlin kernel for fast weight-only FP4 quantization
82
+
83
+ reshaped_x = input.reshape(-1, input.shape[-1])
84
+ out_shape = input.shape[:-1] + (size_n, )
85
+
86
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
87
+ n=size_n,
88
+ k=size_k,
89
+ device=input.device,
90
+ dtype=input.dtype)
91
+
92
+ output = gptq_marlin_gemm(a=reshaped_x,
93
+ c=None,
94
+ b_q_weight=weight,
95
+ b_scales=weight_scale,
96
+ global_scale=weight_scale_2,
97
+ b_zeros=None,
98
+ g_idx=None,
99
+ perm=None,
100
+ workspace=workspace,
101
+ b_q_type=scalar_types.float4_e2m1f,
102
+ size_m=reshaped_x.size(0),
103
+ size_n=size_n,
104
+ size_k=size_k,
105
+ use_atomic_add=use_atomic_add,
106
+ use_fp32_reduce=use_fp32_reduce)
107
+
108
+ if bias is not None:
109
+ output.add_(bias) # In-place add
110
+
111
+ return output.reshape(out_shape)
112
+
113
+
114
+ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
115
+ logger.warning_once(
116
+ "Your GPU does not have native support for FP4 computation but "
117
+ "FP4 quantization is being used. Weight-only FP4 compression will "
118
+ "be used leveraging the Marlin kernel. This may degrade "
119
+ "performance for compute-heavy workloads.")
120
+
121
+ part_size_n = layer.output_size_per_partition
122
+ part_size_k = layer.input_size_per_partition
123
+ param_dtype = layer.params_dtype
124
+
125
+ assert layer.weight.shape == (part_size_n, part_size_k // 2)
126
+
127
+ device = layer.weight.device
128
+
129
+ # WORKSPACE
130
+ layer.workspace = marlin_make_workspace_new(device)
131
+
132
+ # WEIGHT
133
+ # Repack weights to marlin format
134
+ perm = torch.empty(0, dtype=torch.int, device=device)
135
+ qweight = layer.weight.view(torch.int32).T.contiguous()
136
+
137
+ marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
138
+ perm=perm,
139
+ size_k=part_size_k,
140
+ size_n=part_size_n,
141
+ num_bits=4)
142
+ layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
143
+
144
+ # WEIGHT SCALES
145
+ # Permute scales
146
+ weight_scale = layer.weight_scale.T.to(param_dtype)
147
+ weight_scale = marlin_permute_scales(s=weight_scale,
148
+ size_k=part_size_k,
149
+ size_n=part_size_n,
150
+ group_size=16)
151
+ weight_scale = fp4_marlin_process_scales(weight_scale)
152
+ layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
153
+
154
+ weight_scale_2 = layer.weight_scale_2.to(param_dtype)
155
+ weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
156
+ layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
157
+ requires_grad=False)
158
+
159
+ return
160
+
161
+
162
+ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
163
+ logger.warning_once(
164
+ "Your GPU does not have native support for FP4 computation but "
165
+ "FP4 quantization is being used. Weight-only FP4 compression will "
166
+ "be used leveraging the Marlin kernel. This may degrade "
167
+ "performance for compute-heavy workloads.")
168
+
169
+ e = layer.num_experts
170
+ k = layer.hidden_size
171
+ n = layer.intermediate_size_per_partition
172
+
173
+ # WORKSPACE
174
+ device = layer.w13_weight.device
175
+ param_dtype = layer.params_dtype
176
+ layer.workspace = marlin_make_workspace_new(device, 4)
177
+ perm = torch.empty(0, dtype=torch.int, device=device)
178
+
179
+ # WEIGHT
180
+ # Repack weights to marlin format
181
+ for name in ["w13_weight", "w2_weight"]:
182
+ weight = getattr(layer, name)
183
+ tensor_list = []
184
+ if "w13" in name:
185
+ size_n, size_k = n * 2, k
186
+ else:
187
+ size_n, size_k = k, n
188
+
189
+ assert weight.shape == (e, size_n, size_k // 2)
190
+
191
+ for i in range(e):
192
+ qweight = weight[i].view(torch.int32).T.contiguous()
193
+
194
+ marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
195
+ perm=perm,
196
+ size_k=size_k,
197
+ size_n=size_n,
198
+ num_bits=4)
199
+ tensor_list.append(marlin_qweight)
200
+
201
+ weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
202
+ weight = torch.nn.Parameter(weight, requires_grad=False)
203
+
204
+ setattr(layer, name, weight)
205
+
206
+ # WEIGHT SCALES
207
+ # Permute scales
208
+ for name in ["w13", "w2"]:
209
+ scales = getattr(layer, name + "_weight_scale").to(param_dtype)
210
+ global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
211
+
212
+ tensor_list = []
213
+ if "w13" in name:
214
+ size_n, size_k = n * 2, k
215
+ else:
216
+ size_n, size_k = k, n
217
+
218
+ for i in range(e):
219
+ marlin_scales = marlin_permute_scales(s=scales[i].T,
220
+ size_k=size_k,
221
+ size_n=size_n,
222
+ group_size=16)
223
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
224
+ tensor_list.append(marlin_scales)
225
+
226
+ scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
227
+ scales = torch.nn.Parameter(scales, requires_grad=False)
228
+ setattr(layer, name + "_weight_scale", scales)
229
+
230
+ global_scale = fp4_marlin_process_global_scale(global_scale)
231
+ global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
232
+ setattr(layer, name + "_weight_scale_2", global_scale)
233
+
234
+
235
+ def rand_marlin_weight_fp4_like(weight, group_size):
236
+ assert group_size > 0
237
+ size_n, size_k = weight.shape
238
+ device = weight.device
239
+
240
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6
241
+ global_scale = scales.max() / 448
242
+ scales = (scales / global_scale).to(torch.float8_e4m3fn)
243
+
244
+ fp4_weight = torch.randint(0,
245
+ 256, (size_n, size_k // 2),
246
+ dtype=torch.uint8,
247
+ device=weight.device)
248
+ fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
249
+ ((fp4_weight & 0b01110000) >> 2))
250
+ fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
251
+ fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
252
+
253
+ fp4_weight2 = fp4_weight << 4
254
+ fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
255
+ ((fp4_weight2 & 0b01110000) >> 2))
256
+ fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
257
+ fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
258
+
259
+ weight_ref = torch.cat(
260
+ [fp4_weight_part_2.unsqueeze(2),
261
+ fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
262
+ weight_ref = weight_ref * global_scale.to(weight.dtype) * \
263
+ scales.repeat_interleave(group_size, 1).to(weight.dtype)
264
+
265
+ marlin_qweight = gptq_marlin_repack(
266
+ b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
267
+ perm=torch.empty(0, dtype=torch.int, device=device),
268
+ size_k=size_k,
269
+ size_n=size_n,
270
+ num_bits=4,
271
+ )
272
+
273
+ marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
274
+ size_k=size_k,
275
+ size_n=size_n,
276
+ group_size=group_size)
277
+ marlin_scales = fp4_marlin_process_scales(marlin_scales)
278
+
279
+ global_scale = fp4_marlin_process_global_scale(global_scale)
280
+
281
+ return weight_ref.T, marlin_qweight, marlin_scales, global_scale
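A rough round-trip sketch for the FP4 helpers above (assumes a CUDA GPU with the compiled extension; sizes are illustrative and group_size=16 is the only group size this file declares as supported):

    import torch
    from quantization.utils.marlin_utils import marlin_make_workspace_new
    from quantization.utils.marlin_utils_fp4 import (apply_fp4_marlin_linear,
                                                     rand_marlin_weight_fp4_like)

    size_n, size_k = 256, 512
    w = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
    w_ref, qweight, scales, global_scale = rand_marlin_weight_fp4_like(w, group_size=16)

    x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
    y = apply_fp4_marlin_linear(
        input=x, weight=qweight, weight_scale=scales,
        weight_scale_2=global_scale,
        workspace=marlin_make_workspace_new(w.device),
        size_n=size_n, size_k=size_k)
    print((y - x @ w_ref).abs().max())  # should be small (kernel numerics only)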
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_fp8.py ADDED
@@ -0,0 +1,122 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+ from .. import gptq_marlin_gemm, gptq_marlin_repack
9
+
10
+ from .marlin_utils import (USE_FP32_REDUCE_DEFAULT, marlin_make_workspace,
+ marlin_permute_scales, should_use_atomic_add_reduce)
+ from ..scalar_type import scalar_types
11
+
12
+
13
+ def is_fp8_marlin_supported():
14
+ capability = torch.cuda.get_device_capability()
15
+ capability = capability[0] * 10 + capability[1]
16
+ return capability >= 80
17
+
18
+
19
+ def fp8_fused_exponent_bias_into_scales(scales):
20
+ # guard against unexpected dtypes (mirrors the fp4 helper's assertion)
+ assert scales.dtype in [torch.half, torch.bfloat16]
+ fp8_exponent = 4
21
+ if scales.dtype == torch.half:
22
+ target_exponent = 5
23
+ elif scales.dtype == torch.bfloat16:
24
+ target_exponent = 8
25
+ # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
26
+ # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
27
+ exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
28
+ s = torch.ones_like(scales) * 2
29
+ s = s**exponent_bias
30
+ return scales * s
31
+
32
+
33
+ def apply_fp8_marlin_linear(
34
+ input: torch.Tensor,
35
+ weight: torch.Tensor,
36
+ weight_scale: torch.Tensor,
37
+ workspace: torch.Tensor,
38
+ size_n: int,
39
+ size_k: int,
40
+ bias: Optional[torch.Tensor],
41
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
42
+ # For GPUs that lack FP8 hardware support, we can leverage the
43
+ # Marlin kernel for fast weight-only FP8 quantization
44
+
45
+ reshaped_x = input.reshape(-1, input.shape[-1])
46
+ out_shape = input.shape[:-1] + (size_n, )
47
+
48
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
49
+ n=size_n,
50
+ k=size_k,
51
+ device=input.device,
52
+ dtype=input.dtype)
53
+
54
+ output = gptq_marlin_gemm(a=reshaped_x,
55
+ c=None,
56
+ b_q_weight=weight,
57
+ b_scales=weight_scale,
58
+ global_scale=None,
59
+ b_zeros=None,
60
+ g_idx=None,
61
+ perm=None,
62
+ workspace=workspace,
63
+ b_q_type=scalar_types.float8_e4m3fn,
64
+ size_m=reshaped_x.size(0),
65
+ size_n=size_n,
66
+ size_k=size_k,
67
+ use_atomic_add=use_atomic_add,
68
+ use_fp32_reduce=use_fp32_reduce)
69
+
70
+ if bias is not None:
71
+ output.add_(bias) # In-place add
72
+
73
+ return output.reshape(out_shape)
74
+
75
+ def pack_fp8_to_int32(fp8_tensor: torch.Tensor,
76
+ size_k_first: bool = True) -> torch.Tensor:
77
+ """
78
+ Repack FP8 weights to gptq format (packed int32 elements)
79
+ """
80
+ assert fp8_tensor.dtype == torch.float8_e4m3fn
81
+ assert fp8_tensor.ndim == 2
82
+
83
+ fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
84
+ fp8_tensor = fp8_tensor.contiguous()
85
+ # fp8_tensor is contiguous and has shape (N, K) now;
+ # with `.view(torch.int32)`, it becomes (N, K // 4)
87
+ int32_tensor = fp8_tensor.view(torch.int32)
88
+ return int32_tensor.T.contiguous() if size_k_first else int32_tensor
89
+
90
+
91
+ def marlin_quant_fp8_torch(weight, group_size):
92
+ size_n, size_k = weight.shape
93
+ device = weight.device
94
+
95
+ if group_size != -1:
96
+ scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
97
+ repeated_scales = scales.repeat_interleave(group_size, 1)
98
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
99
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
100
+ else:
101
+ scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
102
+ repeated_scales = scales.repeat_interleave(size_k, 1)
103
+ fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
104
+ weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
105
+
106
+ packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
107
+ marlin_qweight = gptq_marlin_repack(
108
+ b_q_weight=packed_weight,
109
+ perm=torch.empty(0, dtype=torch.int, device=device),
110
+ size_k=size_k,
111
+ size_n=size_n,
112
+ num_bits=8,
113
+ )
114
+
115
+ marlin_scales = marlin_permute_scales(s=scales.T,
116
+ size_k=size_k,
117
+ size_n=size_n,
118
+ group_size=group_size)
119
+
120
+ marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
121
+
122
+ return weight_ref.T, marlin_qweight, marlin_scales
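A sketch of how these FP8 pieces fit together (assumes a CUDA device, the compiled extension, and the import fix above):

    import torch
    from quantization.utils.marlin_utils import marlin_make_workspace_new
    from quantization.utils.marlin_utils_fp8 import (apply_fp8_marlin_linear,
                                                     marlin_quant_fp8_torch)

    size_n, size_k = 256, 512
    w = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")
    w_ref, qweight, scales = marlin_quant_fp8_torch(w, group_size=128)

    x = torch.randn(8, size_k, dtype=torch.half, device="cuda")
    y = apply_fp8_marlin_linear(
        input=x, weight=qweight, weight_scale=scales,
        workspace=marlin_make_workspace_new(w.device),
        size_n=size_n, size_k=size_k, bias=None)
    # y should approximate x @ w_ref (the dequantized reference)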
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,161 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ..scalar_type import ScalarType
9
+ from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
10
+ from .quant_utils import (
11
+ get_pack_factor,
12
+ gptq_quantize_weights,
13
+ quantize_weights,
14
+ sort_weights,
15
+ )
16
+
17
+
18
+ class MarlinWorkspace:
19
+
20
+ def __init__(self, out_features, min_thread_n, max_parallel):
21
+ assert (
22
+ out_features % min_thread_n == 0
23
+ ), "out_features = {} is undivisible by min_thread_n = {}".format(
24
+ out_features, min_thread_n
25
+ )
26
+
27
+ max_workspace_size = (out_features // min_thread_n) * max_parallel
28
+
29
+ self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
30
+
31
+
32
+ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
33
+ assert q_w.shape == (size_k, size_n)
34
+ assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
35
+ assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
36
+
37
+ # Permute weights to 16x64 marlin tiles
38
+ q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
39
+ q_w = q_w.permute((0, 2, 1, 3))
40
+ q_w = q_w.reshape((size_k // tile, size_n * tile))
41
+
42
+ q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
43
+
44
+ return q_w
45
+
46
+
47
+ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
48
+ # Permute
49
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
50
+
51
+ # Pack
52
+ pack_factor = get_pack_factor(num_bits)
53
+ orig_device = q_w.device
54
+
55
+ q_w = q_w.cpu().numpy().astype(np.uint32)
56
+
57
+ q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
58
+ for i in range(pack_factor):
59
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
60
+
61
+ q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
62
+
63
+ return q_packed
64
+
65
+
66
+ def get_weight_perm(num_bits: int):
67
+ perm_list: List[int] = []
68
+ for i in range(32):
69
+ perm1: List[int] = []
70
+ col = i // 4
71
+ for block in [0, 1]:
72
+ for row in [
73
+ 2 * (i % 4),
74
+ 2 * (i % 4) + 1,
75
+ 2 * (i % 4 + 4),
76
+ 2 * (i % 4 + 4) + 1,
77
+ ]:
78
+ perm1.append(16 * row + col + 8 * block)
79
+ for j in range(4):
80
+ perm_list.extend([p + 256 * j for p in perm1])
81
+
82
+ perm = np.array(perm_list)
83
+
84
+ if num_bits == 4:
85
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
86
+ elif num_bits == 8:
87
+ interleave = np.array([0, 2, 1, 3])
88
+ else:
89
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
90
+
91
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
92
+ perm = torch.from_numpy(perm)
93
+ return perm
94
+
95
+
96
+ def marlin_quantize(
97
+ w: torch.Tensor,
98
+ quant_type: ScalarType,
99
+ group_size: int,
100
+ act_order: bool,
101
+ test_perm: Optional[torch.Tensor] = None,
102
+ ):
103
+ size_k, size_n = w.shape
104
+ num_bits = quant_type.size_bits
105
+
106
+ # Normalize group_size
107
+ if group_size == -1:
108
+ group_size = size_k
109
+ assert group_size <= size_k
110
+
111
+ # Quantize (and apply act_order if provided)
112
+ w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
113
+ w, quant_type, group_size, act_order, test_perm
114
+ )
115
+
116
+ # For act_order, sort the "weights" and "g_idx" so that group ids are
117
+ # increasing
118
+ sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
119
+ if act_order:
120
+ q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
121
+
122
+ # Reformat to marlin
123
+ weight_perm = get_weight_perm(num_bits)
124
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
125
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
126
+
127
+ # Create result
128
+ res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
129
+ for i in range(len(res_list)):
130
+ res_list[i] = res_list[i].to(w.device)
131
+
132
+ return res_list
133
+
134
+
135
+ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
136
+ size_k, size_n = w.shape
137
+
138
+ # Normalize group_size
139
+ if group_size == -1:
140
+ group_size = size_k
141
+ assert group_size <= size_k
142
+
143
+ # Detect num groups
144
+ assert size_k % group_size == 0
145
+ num_groups = size_k // group_size
146
+
147
+ # Quantize with zp
148
+ w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
149
+
150
+ # Reformat to marlin
151
+ weight_perm = get_weight_perm(quant_type.size_bits)
152
+ marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
153
+ marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
154
+ marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
155
+
156
+ # Create result
157
+ res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
158
+ for i in range(len(res_list)):
159
+ res_list[i] = res_list[i].to(w.device)
160
+
161
+ return res_list
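For reference, a sketch of driving `marlin_quantize` (GPTQ-style uint4b8, group size 128); this runs on CPU since the helpers only reshape and pack:

    import torch
    from quantization import scalar_types
    from quantization.utils.marlin_utils_test import marlin_quantize

    size_k, size_n = 1024, 256
    w = torch.randn(size_k, size_n, dtype=torch.half)
    # act_order=False keeps g_idx / sort_indices empty
    w_ref, q_w, s, g_idx, sort_idx, rand_perm = marlin_quantize(
        w, scalar_types.uint4b8, group_size=128, act_order=False)
    # w_ref is the dequantized reference; q_w and s are already in Marlin layout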
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,472 @@
1
+ """Utility functions used for tests and benchmarks"""
2
+
3
+ import random
4
+ from typing import List
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from ..scalar_type import ScalarType
10
+ from .marlin_utils_test import marlin_weights
11
+ from .quant_utils import gptq_quantize_weights
12
+
13
+
14
+ # This is PyTorch implementation of main part of reorder_meta()
15
+ # function, from tools/util/include/cutlass/util/host_reorder.h file
16
+ # of CUTLASS source tree. Furthermore, CUTLASS template for sparse
17
+ # GEMM decides upon layout of this matrix, and at the moment for the
18
+ # sparse GEMM executed on tensor cores, this is layout described by
19
+ # ColumnMajorInterleaved<2> data structure, in
20
+ # include/cutlass/layout/matrix.h of CUTLASS source tree. The
21
+ # reordering of meta matrix into meta_reordered matrix calculated
22
+ # according to these segments of CUTLASS code is re-implemented here.
23
+ # Note that this calculation produces offsets for scattering metadata
24
+ # matrix elements into reordered metadata matrix elements (or,
25
+ # equivalently, for gathering reordered metadata matrix element back
26
+ # into metadata matrix elements).
27
+ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
28
+ dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
29
+ dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
30
+
31
+ # Reorder the rows, then swizzle the 2x2 blocks.
32
+ group_x = 64
33
+ group_y = 32 if meta_dtype.itemsize == 2 else 16
34
+
35
+ dst_rows = (
36
+ dst_rows // group_x * group_x
37
+ + (dst_rows % 2) * 2
38
+ + (dst_rows % 8) // 4
39
+ + ((dst_rows % group_y) % 4) // 2 * 32
40
+ + ((dst_rows % group_x) // 8) * 4
41
+ )
42
+
43
+ topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
44
+ bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
45
+ dst_rows += topright - bottomleft
46
+ dst_cols -= topright - bottomleft
47
+
48
+ # Assumed that meta tensor is to be stored in CUTLASS
49
+ # InterleavedColumnMajor layout, and reverse engineered
50
+ # corresponding code to store values into this tensor.
51
+ interleave = 2
52
+ cols_maj = dst_cols // interleave
53
+ cols_min = dst_cols % interleave
54
+ return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
55
+
56
+
57
+ # This function converts dense matrix into sparse semi-structured
58
+ # representation, producing "compressed" matrix, in the layout used by
59
+ # CUTLASS backend, and corresponding metadata matrix.
60
+ def sparse_semi_structured_from_dense_cutlass(dense):
61
+ if dense.dim() != 2:
62
+ raise RuntimeError(
63
+ f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
64
+ )
65
+
66
+ m, k = dense.shape
67
+ device = dense.device
68
+
69
+ meta_dtype = torch.int8
70
+ if dense.dtype == torch.int8:
71
+ meta_dtype = torch.int32
72
+ elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
73
+ meta_dtype = torch.int16
74
+ else:
75
+ raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
76
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
77
+ if quadbits_per_meta_elem not in (4, 8):
78
+ raise RuntimeError("Invalid number of elements per meta element calculated")
79
+
80
+ if meta_dtype == torch.int32:
81
+ if m % 16 != 0:
82
+ raise RuntimeError(
83
+ f"Number of rows of dense matrix {m} must be divisible by 16"
84
+ )
85
+ else:
86
+ if m % 32 != 0:
87
+ raise RuntimeError(
88
+ f"Number of rows of dense matrix {m} must be divisible by 32"
89
+ )
90
+ if k % (4 * quadbits_per_meta_elem) != 0:
91
+ raise RuntimeError(
92
+ f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
93
+ )
94
+
95
+ if dense.dtype != torch.float:
96
+ ksparse = 4
97
+ dense_4 = dense.view(-1, k // ksparse, ksparse)
98
+ m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
99
+ else:
100
+ ksparse = 2
101
+ dense_2 = dense.view(-1, k // ksparse, ksparse)
102
+ m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
103
+ meta_ncols = k // (ksparse * quadbits_per_meta_elem)
104
+
105
+ # Encoding quadruples of True/False values as follows:
106
+ # [True, True, False, False] -> 0b0100
107
+ # [True, False, True, False] -> 0b1000
108
+ # [False, True, True, False] -> 0b1001
109
+ # [True, False, False, True ] -> 0b1100
110
+ # [False, True, False, True ] -> 0b1101
111
+ # [False, False, True, True ] -> 0b1110
112
+ # Thus, lower two bits in the encoding are index of the True value
113
+ # at the lowest index in the quadruple, and the higher two bits in
114
+ # the encoding are index of the other True value in the quadruple.
115
+ # In case there are less than two True values, than False value or
116
+ # values at some index or indices are considered True for the
117
+ # encoding. In case there are more than two True values, then the
118
+ # excess True value(s) at some indices are considered False for
119
+ # the encoding. The exact encodings used for these cases are as
120
+ # follows:
121
+ # [False, False, False, False] -> 0b1110
122
+ # [False, False, False, True ] -> 0b1110
123
+ # [False, False, True, False] -> 0b1110
124
+ # [False, True, False, False] -> 0b1001
125
+ # [False, True, True, True ] -> 0b1101
126
+ # [True, False, False, False] -> 0b1000
127
+ # [True, False, True, True ] -> 0b1100
128
+ # [True, True, False, True ] -> 0b0100
129
+ # [True, True, True, False] -> 0b0100
130
+ # [True, True, True, True ] -> 0b0100
131
+ # These particular encodings are chosen, with the help of Espresso
132
+ # logic minimizer software, for the purpose of minimization of
133
+ # corresponding Boolean functions, that translate non-zero flags
134
+ # into encoding bits. Note also possible choices for the first
135
+ # and last of these encodings were limited only to (0b0100,
136
+ # 0b1110), in order to produce valid encodings for 1:2 sparsity
137
+ # case.
138
+
139
+ expr0 = m0 & m1
140
+ expr1 = ~m0 & m1
141
+ expr2 = ~m0 & ~m1
142
+ bit0 = expr1
143
+ bit1 = expr2
144
+ bit2 = expr0 | expr2 | m3
145
+ bit3 = expr1 | ~m1
146
+ idxs0 = bit0 | (bit1.to(torch.int64) << 1)
147
+ idxs1 = bit2 | (bit3.to(torch.int64) << 1)
148
+
149
+ if dense.dtype != torch.float:
150
+ sparse0 = dense_4.gather(
151
+ -1, idxs0.unsqueeze(-1)
152
+ ) # type: ignore[possibly-undefined]
153
+ sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
154
+ sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
155
+ else:
156
+ sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
157
+ m, k // 2
158
+ ) # type: ignore[possibly-undefined]
159
+
160
+ meta_4 = idxs0 | (idxs1 << 2)
161
+ meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
162
+
163
+ if quadbits_per_meta_elem == 4:
164
+ meta = (
165
+ meta_n[:, :, 0]
166
+ | (meta_n[:, :, 1] << 4)
167
+ | (meta_n[:, :, 2] << 8)
168
+ | (meta_n[:, :, 3] << 12)
169
+ )
170
+ elif quadbits_per_meta_elem == 8:
171
+ meta = (
172
+ meta_n[:, :, 0]
173
+ | (meta_n[:, :, 1] << 4)
174
+ | (meta_n[:, :, 2] << 8)
175
+ | (meta_n[:, :, 3] << 12)
176
+ | (meta_n[:, :, 4] << 16)
177
+ | (meta_n[:, :, 5] << 20)
178
+ | (meta_n[:, :, 6] << 24)
179
+ | (meta_n[:, :, 7] << 28)
180
+ )
181
+
182
+ # Reorder meta tensor elements.
183
+ meta_reordered = meta.new_empty(
184
+ (m * meta_ncols,)
185
+ ) # type: ignore[possibly-undefined]
186
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
187
+ m, meta_ncols, meta_dtype, device
188
+ )
189
+ meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
190
+
191
+ return (sparse, meta_reordered.view(m, meta_ncols))
192
+
193
+
194
+ # This function performs reverse of the function above - it
195
+ # reconstructs dense matrix from a pair of "compressed" matrix, given
196
+ # in the layout used by CUTLASS backend, and accompanying metadata
197
+ # matrix.
198
+ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
199
+ if sparse.dim() != 2:
200
+ raise RuntimeError(
201
+ f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
202
+ )
203
+
204
+ m, k = sparse.shape
205
+ device = sparse.device
206
+
207
+ if meta_reordered.dim() != 2:
208
+ raise RuntimeError(
209
+ f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
210
+ )
211
+ if meta_reordered.device != device:
212
+ raise RuntimeError(
213
+ f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
214
+ )
215
+
216
+ meta_dtype = meta_reordered.dtype
217
+ if meta_dtype not in (torch.int16, torch.int32):
218
+ raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
219
+ quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
220
+
221
+ ksparse = 4 if sparse.dtype != torch.float else 2
222
+
223
+ meta_nrows, meta_ncols = meta_reordered.shape
224
+ if meta_nrows != m:
225
+ raise RuntimeError(
226
+ f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
227
+ )
228
+ if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
229
+ raise RuntimeError(
230
+ f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
231
+ "expected according to the number of columns of meta matrix"
232
+ )
233
+
234
+ # Undo meta tensor elements reordering.
235
+ meta_offsets = _calculate_meta_reordering_scatter_offsets(
236
+ m, meta_ncols, meta_dtype, device
237
+ )
238
+ meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
239
+
240
+ # Unpack sparse tensor back to original dense tensor, using
241
+ # information provided by meta tensor. Note that torch.float
242
+ # datatype is handled pretty much the same as
243
+ # torch.half/torch.bfloat16, as metadata for a pair of torch.float
244
+ # value is encoded as if underlying 8 bytes contain four
245
+ # torch.half/torch.bfloat16 values, where either first two or last
246
+ # two are zeros.
247
+ meta_2 = torch.empty(
248
+ (m, meta_ncols, 2 * quadbits_per_meta_elem),
249
+ dtype=meta_dtype,
250
+ device=device,
251
+ )
252
+ if quadbits_per_meta_elem == 4:
253
+ meta_2[:, :, 0] = meta & 0b11
254
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
255
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
256
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
257
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
258
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
259
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
260
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
261
+ elif quadbits_per_meta_elem == 8:
262
+ meta_2[:, :, 0] = meta & 0b11
263
+ meta_2[:, :, 1] = (meta >> 2) & 0b11
264
+ meta_2[:, :, 2] = (meta >> 4) & 0b11
265
+ meta_2[:, :, 3] = (meta >> 6) & 0b11
266
+ meta_2[:, :, 4] = (meta >> 8) & 0b11
267
+ meta_2[:, :, 5] = (meta >> 10) & 0b11
268
+ meta_2[:, :, 6] = (meta >> 12) & 0b11
269
+ meta_2[:, :, 7] = (meta >> 14) & 0b11
270
+ meta_2[:, :, 8] = (meta >> 16) & 0b11
271
+ meta_2[:, :, 9] = (meta >> 18) & 0b11
272
+ meta_2[:, :, 10] = (meta >> 20) & 0b11
273
+ meta_2[:, :, 11] = (meta >> 22) & 0b11
274
+ meta_2[:, :, 12] = (meta >> 24) & 0b11
275
+ meta_2[:, :, 13] = (meta >> 26) & 0b11
276
+ meta_2[:, :, 14] = (meta >> 28) & 0b11
277
+ meta_2[:, :, 15] = (meta >> 30) & 0b11
278
+
279
+ dense_offsets = meta_2.view(-1) + (
280
+ torch.arange(0, 2 * m * k // ksparse, device=device) * 4
281
+ ).view(-1, 1).repeat(1, 2).view(-1)
282
+
283
+ dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
284
+ if sparse.dtype != torch.float:
285
+ # dense.scatter_(0, dense_offsets, sparse.view(-1))
286
+ dense.scatter_(0, dense_offsets, sparse.reshape(-1))
287
+ else:
288
+ dense.view(torch.half).scatter_(
289
+ 0, dense_offsets, sparse.view(torch.half).view(-1)
290
+ )
291
+
292
+ return dense.view(m, 2 * k)
293
+
294
+
295
+ def mask_creator(tensor):
296
+ """
297
+ Class for creating N:M sparsity masks.
298
+ Masks will be created using the N:M ratio, where for every block of
299
+ M weights, N will be pruned based on ranked weight value. Each mask
300
+ will correspond to the given tensor.
301
+
302
+ :param N: The number of weights in a group to keep
303
+ :param M: The size of a weight group
304
+ """
305
+ N = 2
306
+ M = 4
307
+
308
+ mask = None
309
+ # for i, tensor in enumerate(tensors):
310
+ if tensor.numel() % M != 0:
311
+ raise ValueError(
312
+ f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups"
313
+ )
314
+
315
+ num_groups = tensor.numel() // M
316
+
317
+ # N:M sparsity for linear layers
318
+ tensor_temp = tensor.detach().abs().reshape(num_groups, M)
319
+ index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
320
+
321
+ w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
322
+ mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
323
+
324
+ return mask
325
+
326
+
327
+ def inject_24(w, size_k, size_n):
328
+ assert w.shape == (size_k, size_n)
329
+
330
+ mask = mask_creator(w.t()).t().cuda().bool()
331
+
332
+ return (mask * w).contiguous(), mask.contiguous()
333
+
334
+
335
+ def check_24(w, num_rows_to_sample=50, _verbose=False):
336
+ BLOCK_SIZE = 4
337
+ MAX_NON_ZEROS = 2
338
+
339
+ w = w.t().contiguous()
340
+
341
+ print("check_24: w.shape = {}".format(w.shape))
342
+
343
+ num_rows, num_cols = w.shape
344
+ sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
345
+ if _verbose:
346
+ print(f"Sampled row idxs = {sampled_row_idxs}")
347
+
348
+ total_segments = 0
349
+ non_24_segments = 0
350
+ for i in sampled_row_idxs:
351
+ for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
352
+ total_segments += 1
353
+ block = w[i, j : j + BLOCK_SIZE]
354
+ num_nonzero = torch.count_nonzero(block)
355
+ if num_nonzero > MAX_NON_ZEROS:
356
+ print("i = {} j = {} block = {}".format(i, j, block))
357
+ non_24_segments += 1
358
+
359
+ print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
360
+
361
+
362
+ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
363
+ assert q_24.shape == (size_k, size_n)
364
+
365
+ # Remove bias to normalize over 0
366
+ q_24_no_zp = q_24 - wtype.bias
367
+
368
+ # Compress
369
+ q_24_no_zp = q_24_no_zp.t().contiguous()
370
+ q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
371
+ q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
372
+
373
+ # Restore bias
374
+ q_24_comp = q_24_no_zp_comp + wtype.bias
375
+
376
+ # Resize meta to its actual shape (without moving any data)
377
+ meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
378
+
379
+ return q_24_comp, meta
380
+
381
+
382
+ def get_scale_perms_24():
383
+ scale_perm: List[int] = []
384
+ for i in range(8):
385
+ scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
386
+ scale_perm_single: List[int] = []
387
+ for i in range(8):
388
+ scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
389
+ return scale_perm, scale_perm_single
390
+
391
+
392
+ def get_weight_perm_24(num_bits: int):
393
+ perm_list: List[int] = []
394
+ for i in range(32):
395
+ perm1: List[int] = []
396
+ col = i // 4
397
+ col_o = col // 2
398
+ for block in [0, 1]:
399
+ for row in [
400
+ 2 * (i % 4),
401
+ 2 * (i % 4) + 1,
402
+ 2 * (i % 4 + 4),
403
+ 2 * (i % 4 + 4) + 1,
404
+ ]:
405
+ perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
406
+ for j in range(4):
407
+ perm_list.extend([p + 1 * j for p in perm1])
408
+ perm = numpy.array(perm_list)
409
+
410
+ if num_bits == 4:
411
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
412
+ elif num_bits == 8:
413
+ interleave = numpy.array([0, 2, 1, 3])
414
+ else:
415
+ raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
416
+
417
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
418
+ perm = torch.from_numpy(perm)
419
+ return perm
420
+
421
+
422
+ def marlin_permute_scales_24(
423
+ s: torch.Tensor, size_k: int, size_n: int, group_size: int
424
+ ) -> torch.Tensor:
425
+
426
+ scale_perm, scale_perm_single = get_scale_perms_24()
427
+ if group_size < size_k and group_size != -1:
428
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
429
+ else:
430
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
431
+ s = s.reshape((-1, size_n)).contiguous()
432
+
433
+ return s
434
+
435
+
436
+ def marlin_24_quantize(
437
+ w: torch.Tensor,
438
+ quant_type: ScalarType,
439
+ group_size: int,
440
+ ):
441
+ size_k, size_n = w.shape
442
+
443
+ # Normalize group_size
444
+ if group_size == -1:
445
+ group_size = size_k
446
+ assert group_size <= size_k
447
+
448
+ # Inject 2:4 sparsity
449
+ w_24, mask_24 = inject_24(w, size_k, size_n)
450
+
451
+ # Quantize
452
+ w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
453
+ w_24, quant_type, group_size, act_order=False
454
+ )
455
+
456
+ # Compress quantized weight
457
+ q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
458
+ size_k_comp = size_k // 2
459
+
460
+ # Reformat to marlin
461
+ weight_perm = get_weight_perm_24(quant_type.size_bits)
462
+ marlin_24_q_w_comp = marlin_weights(
463
+ q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
464
+ )
465
+ marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
466
+
467
+ # Create result
468
+ res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
469
+ for i in range(len(res_list)):
470
+ res_list[i] = res_list[i].to(w.device)
471
+
472
+ return res_list
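A small sanity sketch for the 2:4 helpers above (CUDA assumed, since `inject_24` moves the mask to the GPU): prune a dense fp16 matrix, compress it into the CUTLASS layout, then reconstruct it.

    import torch
    from quantization.utils.marlin_utils_test_24 import (
        inject_24,
        sparse_semi_structured_from_dense_cutlass,
        sparse_semi_structured_to_dense_cutlass,
    )

    size_k, size_n = 64, 128
    w = torch.randn(size_k, size_n, dtype=torch.half, device="cuda")
    w_24, mask = inject_24(w, size_k, size_n)  # 2:4-pruned copy plus its mask
    sparse, meta = sparse_semi_structured_from_dense_cutlass(w_24)
    dense_again = sparse_semi_structured_to_dense_cutlass(sparse, meta)
    assert torch.equal(dense_again, w_24)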
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import List
2
+
3
+ import numpy
4
+ import torch
5
+
6
+ from .marlin_utils_test import marlin_permute_weights
7
+ from .quant_utils import get_pack_factor, qqq_quantize_weights
8
+
9
+
10
+ def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
11
+ # Permute
12
+ q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
13
+
14
+ # Pack
15
+ pack_factor = get_pack_factor(num_bits)
16
+ orig_device = q_w.device
17
+
18
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
19
+
20
+ q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
21
+ dtype=numpy.uint32)
22
+ if group_size == size_k:
23
+ for i in range(pack_factor):
24
+ q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
25
+ else:
26
+ for i in range(pack_factor):
27
+ q_packed |= q_w[:, i::pack_factor] << num_bits * i
28
+
29
+ q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
30
+
31
+ return q_packed
32
+
33
+
34
+ def get_qqq_scale_perms():
35
+ scale_perm: List[int] = []
36
+ for i in range(8):
37
+ scale_perm.extend([i + 8 * j for j in range(8)])
38
+ scale_perm_single: List[int] = []
39
+ for i in range(4):
40
+ scale_perm_single.extend(
41
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
42
+ return scale_perm, scale_perm_single
43
+
44
+
45
+ # NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
46
+ def get_qqq_weight_perm(num_bits: int, quant_type: str):
47
+ perm_list: List[int] = []
48
+ for i in range(32):
49
+ perm1: List[int] = []
50
+ col = i // 4
51
+ for block in [0, 1]:
52
+ for row in [
53
+ 4 * (i % 4),
54
+ 4 * (i % 4) + 1,
55
+ 4 * (i % 4) + 2,
56
+ 4 * (i % 4) + 3,
57
+ ]:
58
+ perm1.append(16 * row + col + 8 * block)
59
+ for j in range(4):
60
+ perm_list.extend([p + 256 * j for p in perm1])
61
+
62
+ perm = numpy.array(perm_list)
63
+
64
+ assert quant_type in ["per-channel",
65
+ "per-group"], "not supported quantization type"
66
+ if num_bits == 4:
67
+ if quant_type == "per-channel":
68
+ interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
69
+ else:
70
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
71
+ else:
72
+ raise Exception("num_bits must be 4, got {}".format(num_bits))
73
+
74
+ perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
75
+ perm = torch.from_numpy(perm)
76
+ return perm
77
+
78
+
79
+ def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
80
+ scale_perm, scale_perm_single = get_qqq_scale_perms()
81
+ if group_size < size_k and group_size != -1:
82
+ s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
83
+ s_channel = s_channel.reshape(
84
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
85
+ s_group = s_group.reshape((-1, size_n)).contiguous()
86
+ else:
87
+ s_channel = s_channel.reshape(
88
+ (-1, len(scale_perm_single)))[:, scale_perm_single]
89
+ s_channel = s_channel.reshape((-1, size_n)).contiguous()
90
+
91
+ return s_group, s_channel
92
+
93
+
94
+ def marlin_qqq_quantize(
95
+ w: torch.Tensor,
96
+ num_bits: int,
97
+ group_size: int,
98
+ ):
99
+ size_k, size_n = w.shape
100
+
101
+ # Normalize group_size
102
+ if group_size == -1:
103
+ group_size = size_k
104
+ assert group_size <= size_k
105
+ quant_type = "per-channel" if group_size == size_k else "per-group"
106
+
107
+ # Quantize
108
+ w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
109
+ w, num_bits, group_size)
110
+
111
+ # Reformat to marlin_qqq
112
+ weight_perm = get_qqq_weight_perm(num_bits, quant_type)
113
+ marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
114
+ weight_perm, group_size)
115
+ marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
116
+ s_group, s_channel, size_k, size_n, group_size)
117
+
118
+ # Create result
119
+ res_list = [
120
+ w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
121
+ ]
122
+ for i in range(len(res_list)):
123
+ res_list[i] = res_list[i].to(w.device)
124
+
125
+ return res_list
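A sketch of the two QQQ modes (per-group and per-channel); note that in the per-channel case `s_group` comes back empty and only the int8 channel scales are populated:

    import torch
    from quantization.utils.marlin_utils_test_qqq import marlin_qqq_quantize

    size_k, size_n = 1024, 256
    w = torch.randn(size_k, size_n, dtype=torch.half)

    # per-group (4-bit, group_size < size_k)
    w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(w, num_bits=4, group_size=128)

    # per-channel (group_size == -1 is normalized to size_k)
    w_ref_c, q_w_c, s_group_c, s_channel_c = marlin_qqq_quantize(w, num_bits=4, group_size=-1)
    assert s_group_c.numel() == 0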
build/torch28-cxx11-cu126-aarch64-linux/quantization/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
1
+ """This file is used for /tests and /benchmarks"""
2
+
3
+ from typing import List, Optional
4
+
5
+ import numpy
6
+ import torch
7
+
8
+ from ..scalar_type import ScalarType, scalar_types
9
+
10
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
11
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
12
+
13
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
14
+
15
+ # Note: this is a hack. We should update each model to register the
16
+ # stacked params and get it from there instead in a future PR.
17
+ # fused_name: List[shard_name]
18
+ FUSED_LAYER_NAME_MAPPING = {
19
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
20
+ "gate_up_proj": ["gate_proj", "up_proj"],
21
+ }
22
+
23
+
24
+ def pack_quantized_values_into_int32(
25
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
26
+ ):
27
+ # move dim to pack to the end
28
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
29
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
30
+ w_q_perm = w_q.permute(perm)
31
+
32
+ pack_factor = 32 // wtype.size_bits
33
+ mask = (1 << wtype.size_bits) - 1
34
+
35
+ new_shape_perm = list(w_q_perm.shape)
36
+ assert w_q_perm.shape[-1] % pack_factor == 0
37
+ new_shape_perm[-1] //= pack_factor
38
+
39
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
40
+ for i in range(pack_factor):
41
+ res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
42
+
43
+ return res.permute(inv_perm)
44
+
45
+
46
+ def unpack_quantized_values_into_int32(
47
+ w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
48
+ ):
49
+ # move dim to pack to the end
50
+ perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
51
+ inv_perm = tuple(perm.index(i) for i in range(len(perm)))
52
+ w_q_perm = w_q.permute(perm)
53
+
54
+ pack_factor = 32 // wtype.size_bits
55
+ mask = (1 << wtype.size_bits) - 1
56
+
57
+ new_shape_perm = list(w_q_perm.shape)
58
+ new_shape_perm[-1] *= pack_factor
59
+
60
+ res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
61
+ for i in range(pack_factor):
62
+ res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
63
+
64
+ return res.permute(inv_perm)
65
+
66
+
67
+ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
68
+ # prefix: model.layers.0.self_attn.q_proj
69
+ # proj_name: q_proj
70
+ proj_name = prefix.split(".")[-1]
71
+ if proj_name in FUSED_LAYER_NAME_MAPPING:
72
+ shard_prefixes = [
73
+ prefix.replace(proj_name, shard_proj_name)
74
+ for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
75
+ ]
76
+
77
+ is_skipped = None
78
+ for shard_prefix in shard_prefixes:
79
+ is_shard_skipped = shard_prefix in ignored_layers
80
+
81
+ if is_skipped is None:
82
+ is_skipped = is_shard_skipped
83
+ elif is_shard_skipped != is_skipped:
84
+ raise ValueError(
85
+ f"Detected some but not all shards of {prefix} "
+ "are quantized. All shards of fused layers "
+ "must have the same precision."
88
+ )
89
+ else:
90
+ is_skipped = prefix in ignored_layers
91
+
92
+ assert is_skipped is not None
93
+ return is_skipped
94
+
95
+
96
+ def get_pack_factor(num_bits):
97
+ assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
98
+ return 32 // num_bits
99
+
100
+
101
+ def permute_rows(
102
+ q_w: torch.Tensor,
103
+ w_ref: torch.Tensor,
104
+ group_size: int,
105
+ test_perm: Optional[torch.Tensor] = None,
106
+ ):
107
+ assert q_w.shape == w_ref.shape
108
+
109
+ orig_device = q_w.device
110
+ k_size, _ = q_w.shape
111
+
112
+ g_idx = torch.zeros((k_size,), dtype=torch.int32)
113
+ for i in range(k_size):
114
+ g_idx[i] = i // group_size
115
+
116
+ # Simulate act_order by doing a random permutation on K
117
+ rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
118
+
119
+ g_idx = g_idx[rand_perm].contiguous()
120
+ q_w = q_w[rand_perm, :].contiguous()
121
+ w_ref = w_ref[rand_perm, :].contiguous()
122
+
123
+ return (
124
+ w_ref.to(device=orig_device),
125
+ q_w.to(device=orig_device),
126
+ g_idx.to(device=orig_device),
127
+ rand_perm.to(device=orig_device),
128
+ )
129
+
130
+
131
+ def quantize_weights(
132
+ w: torch.Tensor,
133
+ quant_type: ScalarType,
134
+ group_size: Optional[int],
135
+ zero_points: bool = False,
136
+ ref_zero_points_after_scales: bool = False,
137
+ ):
138
+ assert (
139
+ quant_type.is_integer()
140
+ ), "Floating point quantization may work but has not been tested"
141
+ assert not zero_points or group_size is not None, (
142
+ "to have group zero points, group_size must be provided "
143
+ "(-1 group_size is channelwise)"
144
+ )
145
+
146
+ orig_device = w.device
147
+ orig_type = w.dtype
148
+ size_k, size_n = w.shape
149
+
150
+ assert w.is_floating_point(), "w must be float"
151
+
152
+ if group_size == -1:
153
+ group_size = size_k
154
+
155
+ # Reshape to [groupsize, -1]
156
+ if group_size is not None and group_size < size_k:
157
+ w = w.reshape((-1, group_size, size_n))
158
+ w = w.permute(1, 0, 2)
159
+ w = w.reshape((group_size, -1))
160
+
161
+ # Compute scale for each group
162
+ max_val = torch.max(w, 0, keepdim=True).values
163
+ min_val = torch.min(w, 0, keepdim=True).values
164
+
165
+ max_q_val = quant_type.max()
166
+ min_q_val = quant_type.min()
167
+
168
+ w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
169
+ maybe_w_zp = None
170
+ if group_size is not None:
171
+ if zero_points:
172
+ assert not quant_type.is_signed() and quant_type.max() > 0
173
+ w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
174
+ maybe_w_zp = (
175
+ torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
176
+ )
177
+ else:
178
+ # If the bias is such that there are no possible negative/positive
179
+ # values, set the max value to inf to avoid divide by 0
180
+ w_s = torch.max(
181
+ abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
182
+ abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
183
+ )
184
+
185
+ # Quantize
186
+ w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
187
+ w_q = torch.clamp(w_q, min_q_val, max_q_val)
188
+
189
+ # Compute ref (dequantized)
190
+ # For some kernels (namely Machete) the zero-points are applied after the
191
+ # scales are applied, for this case computing the reference in similar way
192
+ # allows us to use tighter error tolerances in our unit tests.
193
+ if ref_zero_points_after_scales and maybe_w_zp is not None:
194
+ w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
195
+ else:
196
+ w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
197
+
198
+ if quant_type.has_bias():
199
+ w_q += quant_type.bias
200
+
201
+ # Restore original shapes
202
+ if group_size is not None and group_size < size_k:
203
+
204
+ def reshape_w(w):
205
+ w = w.reshape((group_size, -1, size_n))
206
+ w = w.permute(1, 0, 2)
207
+ w = w.reshape((size_k, size_n)).contiguous()
208
+ return w
209
+
210
+ w_q = reshape_w(w_q)
211
+ w_ref = reshape_w(w_ref)
212
+ w_s = w_s.reshape((-1, size_n)).contiguous()
213
+
214
+ if maybe_w_zp is not None:
215
+ maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
216
+ maybe_w_zp = maybe_w_zp.to(device=orig_device)
217
+
218
+ return (
219
+ w_ref.to(device=orig_device),
220
+ w_q.to(device=orig_device),
221
+ w_s if group_size is not None else None,
222
+ maybe_w_zp,
223
+ )
224
+
225
+
226
+ def gptq_quantize_weights(
227
+ w: torch.Tensor,
228
+ quant_type: ScalarType,
229
+ group_size: int,
230
+ act_order: bool,
231
+ test_perm: Optional[torch.Tensor] = None,
232
+ ):
233
+ size_k, _ = w.shape
234
+
235
+ assert w.is_floating_point(), "w must be float"
236
+ assert (
237
+ quant_type in SUPPORTED_GPTQ_QUANT_TYPES
238
+ ), f"Unsupported gptq type = {quant_type}"
239
+ assert group_size in SUPPORTED_GROUP_SIZES + [
240
+ size_k
241
+ ], f"Unsupported groupsize = {group_size}"
242
+
243
+ w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
244
+
245
+ # Apply act_order
246
+ g_idx = torch.empty(0, dtype=torch.int, device=w.device)
247
+ rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
248
+ if act_order:
249
+ assert (
250
+ group_size < size_k
251
+ ), "For act_order, groupsize = {} must be less than size_k = {}".format(
252
+ group_size, size_k
253
+ )
254
+
255
+ w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
256
+
257
+ return w_ref, w_q, w_s, g_idx, rand_perm
258
+
259
+
260
+ # QQQ employs different quant schemes for per-group and
261
+ # per-channel quantization.
262
+ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
263
+ orig_device = w.device
264
+ size_k, size_n = w.shape
265
+
266
+ assert w.is_floating_point(), "w must be float"
267
+ assert (
268
+ num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
269
+ ), f"Unsupported num_bits = {num_bits}"
270
+ assert group_size in SUPPORTED_GROUP_SIZES + [
271
+ size_k
272
+ ], f"Unsupported groupsize = {group_size}"
273
+
274
+ if group_size == -1:
275
+ group_size = size_k
276
+ assert group_size <= size_k
277
+
278
+ if group_size < size_k:
279
+ # Reshape to [groupsize, -1]
280
+ w = w.reshape((-1, group_size, size_n))
281
+ w = w.permute(1, 0, 2)
282
+ w = w.reshape((group_size, -1))
283
+
284
+ max_q_val = 2**num_bits - 1
285
+ half_q_val = (max_q_val + 1) // 2
286
+
287
+ # Compute scale for each group
288
+ s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
289
+ s_group *= 2 / max_q_val # 2 => symmetric
290
+
291
+ # Quantize
292
+ q_w = torch.round(w / s_group).int()
293
+ q_w += half_q_val
294
+ q_w = torch.clamp(q_w, 0, max_q_val)
295
+ # Compute ref (dequantized)
296
+ w_ref = (q_w - half_q_val).half() * s_group
297
+
298
+ # Restore original shapes
299
+ def reshape_w(w):
300
+ w = w.reshape((group_size, -1, size_n))
301
+ w = w.permute(1, 0, 2)
302
+ w = w.reshape((size_k, size_n)).contiguous()
303
+ return w
304
+
305
+ q_w = reshape_w(q_w)
306
+ w_ref = reshape_w(w_ref)
307
+
308
+ # Compute int8 quantization scale for each channel
309
+ s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
310
+ s_channel /= 127.0
311
+ t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
312
+ w_ref = t_int8.half() * s_channel
313
+ s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
314
+
315
+ # Fuse scales
316
+ s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
317
+ dtype=torch.half
318
+ )
319
+ else:
320
+ max_q_val = 2 ** (num_bits - 1) - 1
321
+
322
+ # Compute scale for each channel
323
+ s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
324
+ s_channel /= max_q_val
325
+
326
+ # Quantize
327
+ q_w = torch.round(w / s_channel).int()
328
+ q_w = torch.clamp(q_w, -max_q_val, max_q_val)
329
+ # Compute ref (dequantized)
330
+ w_ref = q_w.half() * s_channel
331
+
332
+ s_group = torch.tensor([], dtype=torch.half)
333
+ # div 2 ** (8 - self.bits)) to offset right shift in unpacking
334
+ s_channel /= 2 ** (8 - num_bits)
335
+ s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
336
+
337
+ return (
338
+ w_ref.to(device=orig_device),
339
+ q_w.to(device=orig_device),
340
+ s_group.to(device=orig_device),
341
+ s_channel.to(device=orig_device),
342
+ )
343
+
344
+
345
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
346
+ orig_device = q_w.device
347
+
348
+ sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx
349
+
350
+ g_idx = g_idx[sort_indices].contiguous()
351
+ q_w = q_w[sort_indices, :].contiguous()
352
+
353
+ return (
354
+ q_w.to(device=orig_device),
355
+ g_idx.to(device=orig_device),
356
+ sort_indices.to(device=orig_device),
357
+ )
358
+
359
+
360
+ def pack_rows(
361
+ q_w: torch.Tensor,
362
+ num_bits: int,
363
+ size_k: int,
364
+ size_n: int,
365
+ ):
366
+ assert q_w.shape == (size_k, size_n)
367
+
368
+ pack_factor = get_pack_factor(num_bits)
369
+ assert size_k % pack_factor == 0
370
+
371
+ orig_device = q_w.device
372
+
373
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
374
+
375
+ q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
376
+
377
+ for i in range(pack_factor):
378
+ q_res |= q_w[i::pack_factor, :] << num_bits * i
379
+
380
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
381
+ return q_res
382
+
383
+
384
+ def pack_cols(
385
+ q_w: torch.Tensor,
386
+ num_bits: int,
387
+ size_k: int,
388
+ size_n: int,
389
+ ):
390
+ assert q_w.shape == (size_k, size_n)
391
+
392
+ pack_factor = get_pack_factor(num_bits)
393
+ assert size_n % pack_factor == 0
394
+
395
+ orig_device = q_w.device
396
+
397
+ q_w = q_w.cpu().numpy().astype(numpy.uint32)
398
+
399
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
400
+
401
+ for i in range(pack_factor):
402
+ q_res |= q_w[:, i::pack_factor] << num_bits * i
403
+
404
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
405
+ q_res = q_res.contiguous()
406
+
407
+ return q_res
408
+
409
+
410
+ def unpack_cols(
411
+ packed_q_w: torch.Tensor,
412
+ num_bits: int,
413
+ size_k: int,
414
+ size_n: int,
415
+ ):
416
+ pack_factor = get_pack_factor(num_bits)
417
+ assert size_n % pack_factor == 0
418
+ assert packed_q_w.shape == (
419
+ size_k,
420
+ size_n // pack_factor,
421
+ ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
422
+ packed_q_w.shape, size_k, size_n, pack_factor
423
+ )
424
+
425
+ orig_device = packed_q_w.device
426
+
427
+ packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
428
+ q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
429
+
430
+ mask = (1 << num_bits) - 1
431
+ for i in range(pack_factor):
432
+ vals = packed_q_w_cpu & mask
433
+ packed_q_w_cpu >>= num_bits
434
+ q_res[:, i::pack_factor] = vals
435
+
436
+ q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
437
+ q_res = q_res.contiguous()
438
+
439
+ return q_res
440
+
441
+
442
+ def gptq_pack(
443
+ q_w: torch.Tensor,
444
+ num_bits: int,
445
+ size_k: int,
446
+ size_n: int,
447
+ ):
448
+ return pack_rows(q_w, num_bits, size_k, size_n)
449
+
450
+
451
+ def awq_pack(
452
+ q_w: torch.Tensor,
453
+ num_bits: int,
454
+ size_k: int,
455
+ size_n: int,
456
+ ):
457
+ assert q_w.shape == (size_k, size_n)
458
+
459
+ # Interleave column dim (for the dequantize code) and pack it to int32
460
+ if num_bits == 4:
461
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
462
+ elif num_bits == 8:
463
+ interleave = numpy.array([0, 2, 1, 3])
464
+ else:
465
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
466
+
467
+ q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
468
+ q_w = q_w.reshape((-1, size_n)).contiguous()
469
+
470
+ return pack_cols(q_w, num_bits, size_k, size_n)
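The column-packing helpers above (`pack_cols` / `unpack_cols`) are exact inverses for values that fit in `num_bits`. A minimal round-trip sketch, assuming the build directory is importable as the `quantization` package (the import path is an assumption; only functions defined in this file are used):

```python
# Round-trip sanity check for pack_cols/unpack_cols (sketch, not part of the build).
import torch
from quantization.utils.quant_utils import pack_cols, unpack_cols  # assumed import path

num_bits, size_k, size_n = 4, 16, 64
q_w = torch.randint(0, 2 ** num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # (size_k, size_n // 8) for 4-bit
restored = unpack_cols(packed, num_bits, size_k, size_n)

assert torch.equal(q_w, restored)  # packing then unpacking recovers the original codes
```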
build/torch28-cxx11-cu128-aarch64-linux/quantization/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
2
+ from .cutlass import (
3
+ cutlass_scaled_mm_supports_block_fp8,
4
+ cutlass_scaled_mm_supports_fp8,
5
+ cutlass_scaled_mm,
6
+ cutlass_scaled_mm_azp,
7
+ )
8
+ from .marlin import (
9
+ awq_marlin_repack,
10
+ gptq_marlin_gemm,
11
+ gptq_marlin_repack,
12
+ gptq_marlin_24_gemm,
13
+ marlin_qqq_gemm,
14
+ marlin_gemm,
15
+ )
16
+ from .scalar_type import (
17
+ ScalarType,
18
+ scalar_types,
19
+ )
20
+ from ._ops import ops
21
+
22
+ from .utils import marlin_utils
23
+ from .utils import marlin_utils_fp4
24
+ from .utils import marlin_utils_fp8
25
+ from .utils import quant_utils
26
+
27
+
28
+ __all__ = [
29
+ "ScalarType",
30
+ "awq_marlin_repack",
31
+ "cutlass_scaled_mm",
32
+ "cutlass_scaled_mm_azp",
33
+ "cutlass_scaled_mm_supports_block_fp8",
34
+ "cutlass_scaled_mm_supports_fp8",
35
+ "gptq_marlin_24_gemm",
36
+ "gptq_marlin_gemm",
37
+ "gptq_marlin_repack",
38
+ "marlin_gemm",
39
+ "marlin_qqq_gemm",
40
+ "marlin_utils",
41
+ "marlin_utils_fp4",
42
+ "marlin_utils_fp8",
43
+ "ops",
44
+ "quant_utils",
45
+ "scalar_types",
46
+ "scaled_fp8_quant",
47
+ "scaled_int8_quant",
48
+ ]
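For orientation, the re-exports above gather the kernels under a single namespace. A small usage sketch, assuming the build directory is on the import path as `quantization` and a CUDA GPU is available (both assumptions):

```python
# Sketch: dynamic per-tensor FP8 quantization through the package's public API.
import torch
from quantization import scaled_fp8_quant, scalar_types  # assumed import path

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
x_fp8, scale = scaled_fp8_quant(x)        # no scale given -> dynamic per-tensor scaling
print(x_fp8.dtype, scale.shape)           # torch.float8_e4m3fn torch.Size([1])
print(scalar_types.uint4b8)               # the standard GPTQ 4-bit type: "uint4b8"
```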
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.02 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (533 Bytes). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/compressed_tensors.cpython-313.pyc ADDED
Binary file (5.2 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/cutlass.cpython-313.pyc ADDED
Binary file (3.87 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/marlin.cpython-313.pyc ADDED
Binary file (7.84 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/platforms.cpython-313.pyc ADDED
Binary file (5.8 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/__pycache__/scalar_type.cpython-313.pyc ADDED
Binary file (14.2 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _quantization_eabe7c2
3
+ ops = torch.ops._quantization_eabe7c2
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_quantization_eabe7c2::{op_name}"
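The hashed suffix in the op namespace (`_quantization_eabe7c2`) ties Python-side registrations to this specific shared object. A tiny sketch of how the helper is used (assumes the extension loads on import):

```python
# Sketch: inspecting the op namespace exposed by the bundled shared object.
from quantization._ops import ops, add_op_namespace_prefix  # assumed import path

print(add_op_namespace_prefix("gptq_marlin_gemm"))  # "_quantization_eabe7c2::gptq_marlin_gemm"
print(hasattr(ops, "gptq_marlin_gemm"))             # True if the kernel was compiled in
```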
build/torch28-cxx11-cu128-aarch64-linux/quantization/_quantization_eabe7c2.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:839e3e27bde010190c86a1186665fdbfcdf72e8b55cb7378e191fd35b6d62808
3
+ size 296553224
build/torch28-cxx11-cu128-aarch64-linux/quantization/compressed_tensors.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+ from .platforms import current_platform
7
+
8
+
9
+ # fp8
10
+ def scaled_fp8_quant(
11
+ input: torch.Tensor,
12
+ scale: Optional[torch.Tensor] = None,
13
+ num_token_padding: Optional[int] = None,
14
+ scale_ub: Optional[torch.Tensor] = None,
15
+ use_per_token_if_dynamic: bool = False,
16
+ output: Optional[torch.Tensor] = None,
17
+ ) -> tuple[torch.Tensor, torch.Tensor]:
18
+ """
19
+ Quantize input tensor to FP8 and return quantized tensor and scale.
20
+
21
+ This function supports both static and dynamic quantization: If you
22
+ provide the scale, it will use static scaling and if you omit it,
23
+ the scale will be determined dynamically. The function also allows
24
+ optional padding of the output tensors for downstream kernels that
25
+ will benefit from padding.
26
+
27
+ Args:
28
+ input: The input tensor to be quantized to FP8
29
+ scale: Optional scaling factor for the FP8 quantization
30
+ scale_ub: Optional upper bound for scaling factor in dynamic
31
+ per token case
32
+ num_token_padding: If specified, pad the first dimension
33
+ of the output to at least this value.
34
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
35
+ in the dynamic quantization case.
36
+
37
+ Returns:
38
+ tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
39
+ scaling factor.
40
+ """
41
+ # This code assumes batch_dim and num_tokens are flattened
42
+ assert (input.ndim == 2)
43
+ shape: Union[tuple[int, int], torch.Size] = input.shape
44
+ # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
45
+ out_dtype: torch.dtype = current_platform.fp8_dtype()
46
+ if num_token_padding:
47
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
48
+ if output is None:
49
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
50
+ else:
51
+ assert num_token_padding is None, \
52
+ "padding not supported if output passed in"
53
+ assert output.dtype == out_dtype
54
+
55
+ if scale is None:
56
+ if use_per_token_if_dynamic:
57
+ scale = torch.empty((shape[0], 1),
58
+ device=input.device,
59
+ dtype=torch.float32)
60
+ ops.dynamic_per_token_scaled_fp8_quant(
61
+ output, input.contiguous(), scale, scale_ub)
62
+ else:
63
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
64
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
65
+ else:
66
+ # num_token_padding not implemented for this case
67
+ assert (scale.numel() == 1 and num_token_padding is None)
68
+ ops.static_scaled_fp8_quant(output, input, scale)
69
+
70
+ return output, scale
71
+
72
+
73
+ # int8
74
+ def scaled_int8_quant(
75
+ input: torch.Tensor,
76
+ scale: Optional[torch.Tensor] = None,
77
+ azp: Optional[torch.Tensor] = None,
78
+ symmetric: bool = True
79
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
80
+ """
81
+ Quantize the input tensor to int8 and return the quantized tensor, the scale, and (optionally) the azp.
82
+
83
+ Args:
84
+ input: The input tensor to be quantized to int8.
85
+ scale: Optional scaling factor for the int8 quantization.
86
+ When not provided, we invoke dynamic-per-token quantization.
87
+ azp: Optional zero-point for the int8 quantization.
88
+ Must be provided for asymmetric quantization if `scale` is provided.
89
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
90
+
91
+ Returns:
92
+ tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
93
+ """
94
+ output = torch.empty_like(input, dtype=torch.int8)
95
+ if scale is not None:
96
+ # static-per-tensor quantization.
97
+ assert symmetric == (
98
+ azp
99
+ is None), "azp must only be provided for asymmetric quantization."
100
+ ops.static_scaled_int8_quant(output, input, scale, azp)
101
+ return output, scale, azp
102
+
103
+ # dynamic-per-token quantization.
104
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
105
+ device=input.device,
106
+ dtype=torch.float32)
107
+ input_azp = None if symmetric else torch.empty_like(input_scales,
108
+ dtype=torch.int32)
109
+ ops.dynamic_scaled_int8_quant(output, input.contiguous(),
110
+ input_scales, input_azp)
111
+ return output, input_scales, input_azp
112
+
113
+
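A hedged usage sketch for the int8 path above: with no scale given, the wrapper allocates per-token scales (and, for asymmetric mode, zero points) and fills them via the dynamic kernel. Assumes the package imports as `quantization` and a CUDA GPU is present:

```python
# Sketch: dynamic per-token asymmetric int8 quantization.
import torch
from quantization import scaled_int8_quant  # assumed import path

x = torch.randn(8, 512, dtype=torch.float16, device="cuda")
q, scales, azp = scaled_int8_quant(x, symmetric=False)
print(q.dtype, scales.shape, azp.shape)  # torch.int8 torch.Size([8, 1]) torch.Size([8, 1])
```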
build/torch28-cxx11-cu128-aarch64-linux/quantization/cutlass.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+ from .platforms import current_platform
7
+
8
+
9
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
10
+ return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
11
+
12
+
13
+ def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
14
+ return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
15
+
16
+
17
+ def cutlass_scaled_mm(
18
+ a: torch.Tensor,
19
+ b: torch.Tensor,
20
+ scale_a: torch.Tensor,
21
+ scale_b: torch.Tensor,
22
+ out_dtype: torch.dtype,
23
+ bias: Optional[torch.Tensor] = None,
24
+ ) -> torch.Tensor:
25
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
26
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
27
+ assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
28
+
29
+ m = a.shape[0]
30
+ n = b.shape[1]
31
+
32
+ cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
33
+ if not cutlass_compatible_b:
34
+ from .triton_scaled_mm import triton_scaled_mm
35
+ return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
36
+
37
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
38
+
39
+ ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
40
+
41
+ return out
42
+
43
+
44
+ def cutlass_scaled_mm_azp(
45
+ a: torch.Tensor,
46
+ b: torch.Tensor,
47
+ scale_a: torch.Tensor,
48
+ scale_b: torch.Tensor,
49
+ out_dtype: torch.dtype,
50
+ azp_adj: torch.Tensor,
51
+ azp: Optional[torch.Tensor] = None,
52
+ bias: Optional[torch.Tensor] = None,
53
+ ) -> torch.Tensor:
54
+ """
55
+ :param azp_adj: In the per-tensor case, this should include the azp.
56
+ Always per-channel.
57
+ :param azp: Only set in the per-token case. Per-token if set.
58
+ """
59
+ assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
60
+ assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
61
+ assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
62
+ assert azp is None or azp.numel() == a.shape[0]
63
+
64
+ m = a.shape[0]
65
+ n = b.shape[1]
66
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
67
+
68
+ ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
69
+ return out
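A hedged sketch of the per-tensor FP8 path in `cutlass_scaled_mm`. The layout and scale-shape conventions here (column-major `b`, float32 scales) follow common usage of this kernel family and are assumptions, as is an FP8-capable GPU:

```python
# Sketch: FP8 GEMM with per-tensor scales via cutlass_scaled_mm.
import torch
from quantization import cutlass_scaled_mm, cutlass_scaled_mm_supports_fp8

cap_major, cap_minor = torch.cuda.get_device_capability()
if cutlass_scaled_mm_supports_fp8(cap_major * 10 + cap_minor):
    a = torch.randn(32, 128, device="cuda").to(torch.float8_e4m3fn)      # row-major activations
    b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn).t()  # column-major weights, shape (128, 64)
    scale_a = torch.ones(1, 1, device="cuda", dtype=torch.float32)
    scale_b = torch.ones(1, 1, device="cuda", dtype=torch.float32)
    out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    print(out.shape)  # torch.Size([32, 64])
```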
build/torch28-cxx11-cu128-aarch64-linux/quantization/marlin.py ADDED
@@ -0,0 +1,174 @@
1
+ from typing import TYPE_CHECKING, Optional
2
+
3
+ import torch
4
+
5
+ # Neuron has a torch version that doesn't even have impl_abstract
6
+ if TYPE_CHECKING:
7
+ def register_fake(fn):
8
+ return lambda name: fn
9
+ else:
10
+ try:
11
+ from torch.library import register_fake
12
+ except ImportError:
13
+ from torch.library import impl_abstract as register_fake
14
+
15
+ try:
16
+ from ._ops import ops, add_op_namespace_prefix
17
+ except ImportError as e:
18
+ # Fallback for local development.
19
+ try:
20
+ import _quantization
21
+
22
+ ops = torch.ops._quantization
23
+
24
+ def add_op_namespace_prefix(op_name: str):
25
+ return f"_quantization::{op_name}"
26
+ except ImportError:
27
+ raise e
28
+
29
+
30
+ from .scalar_type import ScalarType
31
+
32
+
33
+ # gptq_marlin
34
+ def gptq_marlin_gemm(a: torch.Tensor,
35
+ c: Optional[torch.Tensor],
36
+ b_q_weight: torch.Tensor,
37
+ b_scales: torch.Tensor,
38
+ global_scale: Optional[torch.Tensor],
39
+ b_zeros: Optional[torch.Tensor],
40
+ g_idx: Optional[torch.Tensor],
41
+ perm: Optional[torch.Tensor],
42
+ workspace: torch.Tensor,
43
+ b_q_type: ScalarType,
44
+ size_m: int,
45
+ size_n: int,
46
+ size_k: int,
47
+ is_k_full: bool = True,
48
+ use_atomic_add: bool = False,
49
+ use_fp32_reduce: bool = False,
50
+ is_zp_float: bool = False) -> torch.Tensor:
51
+ return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales,
52
+ global_scale, b_zeros, g_idx, perm,
53
+ workspace, b_q_type.id, size_m,
54
+ size_n, size_k, is_k_full,
55
+ use_atomic_add, use_fp32_reduce,
56
+ is_zp_float)
57
+
58
+ # gptq_marlin
59
+ def gptq_marlin_repack(
60
+ b_q_weight: torch.Tensor,
61
+ perm: torch.Tensor,
62
+ size_k: int,
63
+ size_n: int,
64
+ num_bits: int,
65
+ ) -> torch.Tensor:
66
+ return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
67
+
68
+
69
+ # gptq_marlin
70
+ def awq_marlin_repack(
71
+ b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
72
+ ) -> torch.Tensor:
73
+ return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
74
+
75
+
76
+ # marlin
77
+ def marlin_gemm(
78
+ a: torch.Tensor,
79
+ b_q_weight: torch.Tensor,
80
+ b_scales: torch.Tensor,
81
+ workspace: torch.Tensor,
82
+ size_m: int,
83
+ size_n: int,
84
+ size_k: int,
85
+ ) -> torch.Tensor:
86
+ return ops.marlin_gemm(
87
+ a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
88
+ )
89
+
90
+
91
+ # marlin_24
92
+ def gptq_marlin_24_gemm(
93
+ a: torch.Tensor,
94
+ b_q_weight: torch.Tensor,
95
+ b_meta: torch.Tensor,
96
+ b_scales: torch.Tensor,
97
+ workspace: torch.Tensor,
98
+ b_q_type: ScalarType,
99
+ size_m: int,
100
+ size_n: int,
101
+ size_k: int,
102
+ ) -> torch.Tensor:
103
+ return ops.gptq_marlin_24_gemm(
104
+ a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
105
+ )
106
+
107
+
108
+ # qqq ops
109
+ def marlin_qqq_gemm(
110
+ a: torch.Tensor,
111
+ b_q_weight: torch.Tensor,
112
+ s_tok: torch.Tensor,
113
+ s_ch: torch.Tensor,
114
+ s_group: torch.Tensor,
115
+ workspace: torch.Tensor,
116
+ size_m: int,
117
+ size_n: int,
118
+ size_k: int,
119
+ ) -> torch.Tensor:
120
+ return ops.marlin_qqq_gemm(
121
+ a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
122
+ )
123
+
124
+
125
+ # Fake ops
126
+
127
+ if hasattr(ops, "gptq_marlin_24_gemm"):
128
+ @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
129
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
130
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
131
+ workspace: torch.Tensor,
132
+ b_q_type: ScalarType, size_m: torch.SymInt,
133
+ size_n: torch.SymInt,
134
+ size_k: torch.SymInt) -> torch.Tensor:
135
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
136
+
137
+ @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
138
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
139
+ c: Optional[torch.Tensor],
140
+ b_q_weight: torch.Tensor,
141
+ b_scales: torch.Tensor,
142
+ global_scale: Optional[torch.Tensor],
143
+ b_zeros: Optional[torch.Tensor],
144
+ g_idx: Optional[torch.Tensor],
145
+ perm: Optional[torch.Tensor],
146
+ workspace: torch.Tensor,
147
+ b_q_type_id: int,
148
+ size_m: torch.SymInt,
149
+ size_n: torch.SymInt,
150
+ size_k: torch.SymInt,
151
+ is_k_full: bool = True,
152
+ use_atomic_add: bool = False,
153
+ use_fp32_reduce: bool = False,
154
+ is_zp_float: bool = False) -> torch.Tensor:
155
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
156
+
157
+ @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
158
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
159
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
160
+ s_group: torch.Tensor, workspace: torch.Tensor,
161
+ size_m: torch.SymInt, size_n: torch.SymInt,
162
+ size_k: torch.SymInt) -> torch.Tensor:
163
+ return torch.empty((size_m, size_n),
164
+ dtype=torch.float16,
165
+ device=a.device)
166
+
167
+ @register_fake(add_op_namespace_prefix("marlin_gemm"))
168
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
169
+ b_scales: torch.Tensor, workspace: torch.Tensor,
170
+ size_m: torch.SymInt, size_n: torch.SymInt,
171
+ size_k: torch.SymInt) -> torch.Tensor:
172
+ return torch.empty((size_m, size_n),
173
+ dtype=torch.float16,
174
+ device=a.device)
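The `register_fake` blocks above exist so that shape inference works without launching CUDA kernels (e.g. under `torch.compile` or on meta tensors). A hedged sketch; the weight/scale shapes are only placeholders, since the meta implementation looks at `size_m`/`size_n` alone:

```python
# Sketch: the registered fake lets marlin_gemm run on meta tensors for shape inference.
import torch
from quantization import marlin_gemm  # assumed import path

a = torch.empty(16, 256, dtype=torch.float16, device="meta")
b_q = torch.empty(16, 1024, dtype=torch.int32, device="meta")    # placeholder packed weight
scales = torch.empty(1, 64, dtype=torch.float16, device="meta")  # placeholder scales
workspace = torch.empty(64, dtype=torch.int32, device="meta")

out = marlin_gemm(a, b_q, scales, workspace, size_m=16, size_n=64, size_k=256)
print(out.shape, out.device)  # torch.Size([16, 64]) meta
```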
build/torch28-cxx11-cu128-aarch64-linux/quantization/platforms.py ADDED
@@ -0,0 +1,104 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import lru_cache
3
+ from typing import NamedTuple
4
+
5
+ import torch
6
+
7
+ IS_ROCM = torch.version.hip is not None
8
+
9
+
10
+ class DeviceCapability(NamedTuple):
11
+ major: int
12
+ minor: int
13
+
14
+ def as_version_str(self) -> str:
15
+ return f"{self.major}.{self.minor}"
16
+
17
+ def to_int(self) -> int:
18
+ """
19
+ Express device capability as an integer ``<major><minor>``.
20
+
21
+ It is assumed that the minor version is always a single digit.
22
+ """
23
+ assert 0 <= self.minor < 10
24
+ return self.major * 10 + self.minor
25
+
26
+
27
+ class Platform(ABC):
28
+ simple_compile_backend: str = "inductor"
29
+
30
+ @classmethod
31
+ def fp8_dtype(cls) -> torch.dtype:
32
+ """
33
+ Returns the preferred FP8 type on the current platform.
34
+
35
+ See the documentation for is_fp8_fnuz for details.
36
+ """
37
+ return torch.float8_e4m3fn
38
+
39
+ @classmethod
40
+ def is_fp8_fnuz(cls) -> bool:
41
+ """
42
+ Returns whether the preferred FP8 type is FNUZ on the current platform.
43
+
44
+ There are two representations of FP8, OCP FP8 and FNUZ FP8.
45
+ The OCP specification can be found at https://tinyurl.com/b7jvwpft.
46
+ The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
47
+
48
+ AMD's MI300 and MI325 have native hardware support for FNUZ. All other
49
+ hardware has converged on the OCP FP8 standard.
50
+ """
51
+ return False
52
+
53
+ @classmethod
54
+ @abstractmethod
55
+ def get_device_name(cls, device_id: int = 0) -> str: ...
56
+
57
+ @abstractmethod
58
+ def is_rocm(self): ...
59
+
60
+
61
+ class CudaPlatform(Platform):
62
+ @classmethod
63
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
64
+ major, minor = torch.cuda.get_device_capability(device_id)
65
+ return DeviceCapability(major=major, minor=minor)
66
+
67
+ @classmethod
68
+ @lru_cache(maxsize=8)
69
+ def get_device_name(cls, device_id: int = 0) -> str:
70
+ return torch.cuda.get_device_name(device_id)
71
+
72
+ def is_rocm(self):
73
+ return False
74
+
75
+
76
+ class RocmPlatform(Platform):
77
+ @classmethod
78
+ def fp8_dtype(cls) -> torch.dtype:
79
+ if cls.is_fp8_fnuz():
80
+ return torch.float8_e4m3fnuz
81
+ else:
82
+ return torch.float8_e4m3fn
83
+
84
+ @classmethod
85
+ def is_fp8_fnuz(cls) -> bool:
86
+ # only device 0 is checked, this assumes MI300 platforms are homogeneous
87
+ return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
88
+
89
+ @classmethod
90
+ @lru_cache(maxsize=8)
91
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
92
+ major, minor = torch.cuda.get_device_capability(device_id)
93
+ return DeviceCapability(major=major, minor=minor)
94
+
95
+ @classmethod
96
+ @lru_cache(maxsize=8)
97
+ def get_device_name(cls, device_id: int = 0) -> str:
98
+ return torch.cuda.get_device_name(device_id)
99
+
100
+ def is_rocm(self):
101
+ return True
102
+
103
+
104
+ current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
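A short sketch of querying the platform shim (no assumptions beyond a visible GPU for the device queries):

```python
# Sketch: platform queries used by the quantization wrappers above.
import torch
from quantization.platforms import current_platform  # assumed import path

print(current_platform.is_rocm())
if torch.cuda.is_available():
    print(current_platform.fp8_dtype())                       # e4m3fn, or e4m3fnuz on MI300-class ROCm
    print(current_platform.get_device_name())
    print(current_platform.get_device_capability().to_int())  # e.g. 90 on SM90
```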
build/torch28-cxx11-cu128-aarch64-linux/quantization/scalar_type.py ADDED
@@ -0,0 +1,347 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import functools
5
+ import struct
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from typing import Optional, Union
9
+
10
+ _SCALAR_TYPES_ID_MAP = {}
11
+
12
+
13
+ # Mirrors enum in `core/scalar_type.hpp`
14
+ class NanRepr(Enum):
15
+ NONE = 0 # nans are not supported
16
+ IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
17
+ EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
18
+
19
+
20
+ # This ScalarType class is a parallel implementation of the C++ ScalarType
21
+ # class found in csrc/core/scalar_type.hpp. These two classes should be kept
22
+ # in sync until the inductor fully supports custom C++ classes.
23
+ @dataclass(frozen=True)
24
+ class ScalarType:
25
+ """
26
+ ScalarType can represent a wide range of floating point and integer
27
+ types, in particular it can be used to represent sub-byte data types
28
+ (something that torch.dtype currently does not support). It is also
29
+ capable of representing types with a bias, i.e.:
30
+ `stored_value = value + bias`,
31
+ this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
32
+ of 8). The implementation for this class can be found in
33
+ csrc/core/scalar_type.hpp, these type signatures should be kept in sync
34
+ with that file.
35
+ """
36
+
37
+ exponent: int
38
+ """
39
+ Number of bits in the exponent if this is a floating point type
40
+ (zero if this an integer type)
41
+ """
42
+
43
+ mantissa: int
44
+ """
45
+ Number of bits in the mantissa if this is a floating point type,
46
+ or the number bits representing an integer excluding the sign bit if
47
+ this an integer type.
48
+ """
49
+
50
+ signed: bool
51
+ "If the type is signed (i.e. has a sign bit)"
52
+
53
+ bias: int
54
+ """
55
+ bias used to encode the values in this scalar type
56
+ (value = stored_value - bias, default 0) for example if we store the
57
+ type as an unsigned integer with a bias of 128 then the value 0 will be
58
+ stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
59
+ """
60
+
61
+ _finite_values_only: bool = False
62
+ """
63
+ Private: to check if infs are supported, use `has_infs()` instead.
64
+ """
65
+
66
+ nan_repr: NanRepr = NanRepr.IEEE_754
67
+ """
68
+ How NaNs are represented in this scalar type; returns a NanRepr value.
69
+ (not applicable for integer types)
70
+ """
71
+
72
+ def _floating_point_max_int(self) -> int:
73
+ assert (
74
+ self.mantissa <= 52 and self.exponent <= 11
75
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
76
+
77
+ max_mantissa = (1 << self.mantissa) - 1
78
+ if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
79
+ max_mantissa = max_mantissa - 1
80
+
81
+ max_exponent = (1 << self.exponent) - 2
82
+ if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
83
+ or self.nan_repr == NanRepr.NONE):
84
+ assert (
85
+ self.exponent < 11
86
+ ), f"Cannot represent max/min as a double for type {self.__str__()}"
87
+ max_exponent = max_exponent + 1
88
+
89
+ # adjust the exponent to match that of a double
90
+ # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
91
+ # e is the exponent bits), there is some precedent for non-standard
92
+ # biases, example `float8_e4m3b11fnuz` here:
93
+ # https://github.com/jax-ml/ml_dtypes but to avoid premature over
94
+ # complication we are just assuming the standard exponent bias until
95
+ # there is a need to support non-standard biases
96
+ exponent_bias = (1 << (self.exponent - 1)) - 1
97
+ exponent_bias_double = (1 << 10) - 1 # double e = 11
98
+
99
+ max_exponent_double = (max_exponent - exponent_bias +
100
+ exponent_bias_double)
101
+
102
+ # shift the mantissa and exponent into the proper positions for an
103
+ # IEEE double and bitwise-or them together.
104
+ return (max_mantissa <<
105
+ (52 - self.mantissa)) | (max_exponent_double << 52)
106
+
107
+ def _floating_point_max(self) -> float:
108
+ double_raw = self._floating_point_max_int()
109
+ return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
110
+
111
+ def _raw_max(self) -> Union[int, float]:
112
+ if self.is_floating_point():
113
+ return self._floating_point_max()
114
+ else:
115
+ assert (self.size_bits < 64 or self.size_bits == 64
116
+ and self.is_signed()), "Cannot represent max as an int"
117
+ return (1 << self.mantissa) - 1
118
+
119
+ def _raw_min(self) -> Union[int, float]:
120
+ if self.is_floating_point():
121
+ assert self.is_signed(
122
+ ), "We currently assume all floating point types are signed"
123
+ sign_bit_double = 1 << 63
124
+
125
+ max_raw = self._floating_point_max_int()
126
+ min_raw = max_raw | sign_bit_double
127
+ return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
128
+ else:
129
+ assert (not self.is_signed() or self.size_bits
130
+ <= 64), "Cannot represent min as a int64_t"
131
+
132
+ if self.is_signed():
133
+ return -(1 << (self.size_bits - 1))
134
+ else:
135
+ return 0
136
+
137
+ @functools.cached_property
138
+ def id(self) -> int:
139
+ """
140
+ Convert the ScalarType to an int which can be passed to pytorch custom
141
+ ops. This layout of the int must be kept in sync with the C++
142
+ ScalarType's from_id method.
143
+ """
144
+ val = 0
145
+ offset = 0
146
+
147
+ def or_and_advance(member, bit_width):
148
+ nonlocal val
149
+ nonlocal offset
150
+ bit_mask = (1 << bit_width) - 1
151
+ val = val | (int(member) & bit_mask) << offset
152
+ offset = offset + bit_width
153
+
154
+ or_and_advance(self.exponent, 8)
155
+ or_and_advance(self.mantissa, 8)
156
+ or_and_advance(self.signed, 1)
157
+ or_and_advance(self.bias, 32)
158
+ or_and_advance(self._finite_values_only, 1)
159
+ or_and_advance(self.nan_repr.value, 8)
160
+
161
+ assert offset <= 64, \
162
+ f"ScalarType fields too big {offset} to fit into an int64"
163
+
164
+ _SCALAR_TYPES_ID_MAP[val] = self
165
+
166
+ return val
167
+
168
+ @property
169
+ def size_bits(self) -> int:
170
+ return self.exponent + self.mantissa + int(self.signed)
171
+
172
+ def min(self) -> Union[int, float]:
173
+ """
174
+ Min representable value for this scalar type.
175
+ (accounting for bias if there is one)
176
+ """
177
+ return self._raw_min() - self.bias
178
+
179
+ def max(self) -> Union[int, float]:
180
+ """
181
+ Max representable value for this scalar type.
182
+ (accounting for bias if there is one)
183
+ """
184
+ return self._raw_max() - self.bias
185
+
186
+ def is_signed(self) -> bool:
187
+ """
188
+ If the type is signed (i.e. has a sign bit), same as `signed`
189
+ added for consistency with:
190
+ https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
191
+ """
192
+ return self.signed
193
+
194
+ def is_floating_point(self) -> bool:
195
+ "If the type is a floating point type"
196
+ return self.exponent != 0
197
+
198
+ def is_integer(self) -> bool:
199
+ "If the type is an integer type"
200
+ return self.exponent == 0
201
+
202
+ def has_bias(self) -> bool:
203
+ "If the type has a non-zero bias"
204
+ return self.bias != 0
205
+
206
+ def has_infs(self) -> bool:
207
+ "If the type is floating point and supports infinity"
208
+ return not self._finite_values_only
209
+
210
+ def has_nans(self) -> bool:
211
+ return self.nan_repr != NanRepr.NONE
212
+
213
+ def is_ieee_754(self) -> bool:
214
+ """
215
+ If the type is a floating point type that follows IEEE 754
216
+ conventions
217
+ """
218
+ return self.nan_repr == NanRepr.IEEE_754 and \
219
+ not self._finite_values_only
220
+
221
+ def __str__(self) -> str:
222
+ """
223
+ naming generally follows: https://github.com/jax-ml/ml_dtypes
224
+ for floating point types (leading f) the scheme is:
225
+ `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
226
+ flags:
227
+ - no-flags: means it follows IEEE 754 conventions
228
+ - f: means finite values only (no infinities)
229
+ - n: means nans are supported (non-standard encoding)
230
+ for integer types the scheme is:
231
+ `[u]int<size_bits>[b<bias>]`
232
+ - if bias is not present it means its zero
233
+ """
234
+ if self.is_floating_point():
235
+ ret = "float" + str(self.size_bits) + "_e" + str(
236
+ self.exponent) + "m" + str(self.mantissa)
237
+
238
+ if not self.is_ieee_754():
239
+ if self._finite_values_only:
240
+ ret = ret + "f"
241
+ if self.nan_repr != NanRepr.NONE:
242
+ ret = ret + "n"
243
+
244
+ return ret
245
+ else:
246
+ ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
247
+ if self.has_bias():
248
+ ret = ret + "b" + str(self.bias)
249
+ return ret
250
+
251
+ def __repr__(self) -> str:
252
+ return "ScalarType." + self.__str__()
253
+
254
+ # __len__ needs to be defined (and has to throw TypeError) for pytorch's
255
+ # opcheck to work.
256
+ def __len__(self) -> int:
257
+ raise TypeError
258
+
259
+ #
260
+ # Convenience Constructors
261
+ #
262
+
263
+ @classmethod
264
+ def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
265
+ "Create a signed integer scalar type (size_bits includes sign-bit)."
266
+ ret = cls(0, size_bits - 1, True, bias if bias else 0)
267
+ ret.id # noqa B018: make sure the id is cached
268
+ return ret
269
+
270
+ @classmethod
271
+ def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
272
+ """Create an unsigned integer scalar type."""
273
+ ret = cls(0, size_bits, False, bias if bias else 0)
274
+ ret.id # noqa B018: make sure the id is cached
275
+ return ret
276
+
277
+ @classmethod
278
+ def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
279
+ """
280
+ Create a standard floating point type
281
+ (i.e. follows IEEE 754 conventions).
282
+ """
283
+ assert (mantissa > 0 and exponent > 0)
284
+ ret = cls(exponent, mantissa, True, 0)
285
+ ret.id # noqa B018: make sure the id is cached
286
+ return ret
287
+
288
+ @classmethod
289
+ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
290
+ nan_repr: NanRepr) -> 'ScalarType':
291
+ """
292
+ Create a non-standard floating point type
293
+ (i.e. does not follow IEEE 754 conventions).
294
+ """
295
+ assert (mantissa > 0 and exponent > 0)
296
+ assert (nan_repr != NanRepr.IEEE_754), (
297
+ "use `float_IEEE754` constructor for floating point types that "
298
+ "follow IEEE 754 conventions")
299
+ ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
300
+ ret.id # noqa B018: make sure the id is cached
301
+ return ret
302
+
303
+ @classmethod
304
+ def from_id(cls, scalar_type_id: int):
305
+ if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
306
+ raise ValueError(
307
+ f"scalar_type_id {scalar_type_id} doesn't exist.")
308
+ return _SCALAR_TYPES_ID_MAP[scalar_type_id]
309
+
310
+
311
+ # naming generally follows: https://github.com/jax-ml/ml_dtypes
312
+ # for floating point types (leading f) the scheme is:
313
+ # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
314
+ # flags:
315
+ # - no-flags: means it follows IEEE 754 conventions
316
+ # - f: means finite values only (no infinities)
317
+ # - n: means nans are supported (non-standard encoding)
318
+ # for integer types the scheme is:
319
+ # `[u]int<size_bits>[b<bias>]`
320
+ # - if bias is not present it means its zero
321
+
322
+
323
+ class scalar_types:
324
+ int4 = ScalarType.int_(4, None)
325
+ uint4 = ScalarType.uint(4, None)
326
+ int8 = ScalarType.int_(8, None)
327
+ uint8 = ScalarType.uint(8, None)
328
+ float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
329
+ float8_e5m2 = ScalarType.float_IEEE754(5, 2)
330
+ float16_e8m7 = ScalarType.float_IEEE754(8, 7)
331
+ float16_e5m10 = ScalarType.float_IEEE754(5, 10)
332
+
333
+ # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
334
+ float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
335
+
336
+ # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
337
+ float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
338
+
339
+ # "gptq" types
340
+ uint2b2 = ScalarType.uint(2, 2)
341
+ uint3b4 = ScalarType.uint(3, 4)
342
+ uint4b8 = ScalarType.uint(4, 8)
343
+ uint8b128 = ScalarType.uint(8, 128)
344
+
345
+ # colloquial names
346
+ bfloat16 = float16_e8m7
347
+ float16 = float16_e5m10
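To make the encoding concrete, a small sketch that inspects the GPTQ 4-bit type and round-trips it through the integer id consumed by the custom ops (import path assumed):

```python
# Sketch: properties of the biased 4-bit GPTQ type and id round-trip.
from quantization import ScalarType, scalar_types  # assumed import path

t = scalar_types.uint4b8                 # 4-bit unsigned stored with a bias of 8
print(t, t.size_bits, t.min(), t.max())  # uint4b8 4 -8 7
print(scalar_types.float8_e4m3fn)        # float8_e4m3fn (finite values only, non-IEEE NaNs)
assert ScalarType.from_id(t.id) == t     # the packed id decodes back to the same type
```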
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__init__.py ADDED
File without changes
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (186 Bytes). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils.cpython-313.pyc ADDED
Binary file (17.7 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp4.cpython-313.pyc ADDED
Binary file (11.8 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/marlin_utils_fp8.cpython-313.pyc ADDED
Binary file (5.29 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/__pycache__/quant_utils.cpython-313.pyc ADDED
Binary file (19.9 kB). View file
 
build/torch28-cxx11-cu128-aarch64-linux/quantization/utils/marlin_utils.py ADDED
@@ -0,0 +1,451 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ from typing import Optional
5
+
6
+ import numpy
7
+ import torch
8
+
9
+ from .. import ScalarType, gptq_marlin_gemm, scalar_types
10
+
11
+ from .quant_utils import pack_cols, unpack_cols
12
+
13
+ GPTQ_MARLIN_TILE = 16
14
+ GPTQ_MARLIN_MIN_THREAD_N = 64
15
+ GPTQ_MARLIN_MIN_THREAD_K = 128
16
+ GPTQ_MARLIN_MAX_PARALLEL = 16
17
+
18
+ GPTQ_MARLIN_24_TILE = 16
19
+ GPTQ_MARLIN_24_MIN_THREAD_N = 128
20
+ GPTQ_MARLIN_24_MIN_THREAD_K = 128
21
+ GPTQ_MARLIN_24_MAX_PARALLEL = 64
22
+
23
+ GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
24
+ GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
25
+
26
+ MARLIN_QQQ_TILE = 16
27
+ MARLIN_QQQ_MIN_THREAD_N = 64
28
+ MARLIN_QQQ_MIN_THREAD_K = 128
29
+ MARLIN_QQQ_MAX_PARALLEL = 16
30
+
31
+ MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
32
+ MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
33
+ MARLIN_QQQ_SUPPORTED_SYM = [True]
34
+
35
+ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
36
+
37
+ # In case there is a performance issue with Marlin, the variable below can be
38
+ # changed to False, which allows Marlin to perform global reductions in fp16
39
+ # precision (instead of fp32), and therefore, save on some memory movements.
40
+ USE_FP32_REDUCE_DEFAULT = True
41
+
42
+
43
+ # For binary size and compile time, we don't support the same types for with and
44
+ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
45
+ # TODO: we may want to move this into the C++ so its closer to the actual impl
46
+ def query_marlin_supported_quant_types(
47
+ has_zp: Optional[bool] = None,
48
+ include_fp_type: bool = True,
49
+ device_capability: Optional[int] = None,
50
+ ):
51
+ if device_capability is None:
52
+ capability_tuple = torch.cuda.get_device_capability()
53
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
54
+
55
+ if device_capability < 80:
56
+ return []
57
+
58
+ # - has_zp is True: return quant_types that has zero points
59
+ # - has_zp is False: return quant_types that has not zero points
60
+ # - has_zp is None: both
61
+ if has_zp is None:
62
+ types0 = query_marlin_supported_quant_types(False, include_fp_type,
63
+ device_capability)
64
+ types1 = query_marlin_supported_quant_types(True, include_fp_type,
65
+ device_capability)
66
+ return types0 + types1
67
+
68
+ if has_zp:
69
+ # AWQ style, unsigned + runtime zero-point
70
+ return [scalar_types.uint4]
71
+ else:
72
+ # GPTQ style, unsigned + symmetric bias
73
+ res = [scalar_types.uint4b8, scalar_types.uint8b128]
74
+ if include_fp_type:
75
+ res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
76
+ return res
77
+
78
+
79
+ def _check_marlin_supported(
80
+ quant_type: ScalarType,
81
+ group_size: Optional[int],
82
+ has_zp: bool,
83
+ device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]:
84
+
85
+ if device_capability is None:
86
+ capability_tuple = torch.cuda.get_device_capability()
87
+ device_capability = capability_tuple[0] * 10 + capability_tuple[1]
88
+
89
+ supported_types = query_marlin_supported_quant_types(
90
+ has_zp, True, device_capability)
91
+
92
+ if quant_type not in supported_types:
93
+ return (False, f"Marlin does not support weight_bits = {quant_type}. "
94
+ f"Only types = {supported_types} "
95
+ f"are supported (for group_size = {group_size}, "
96
+ f"device_capability = {device_capability}, zp = {has_zp}).")
97
+ if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
98
+ return (False, f"Marlin does not support group_size = {group_size}. "
99
+ f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
100
+ "are supported.")
101
+
102
+ return True, None
103
+
104
+
105
+ def check_marlin_supported(quant_type: ScalarType,
106
+ group_size: int,
107
+ has_zp: bool = False,
108
+ device_capability: Optional[int] = None) -> bool:
109
+ cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
110
+ device_capability)
111
+ return cond
112
+
113
+
114
+ def verify_marlin_supported(quant_type: ScalarType,
115
+ group_size: int,
116
+ has_zp: bool = False) -> None:
117
+ cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
118
+ if not cond:
119
+ assert err_msg is not None
120
+ raise ValueError(err_msg)
121
+
122
+
123
+ def verify_marlin_supports_shape(output_size_per_partition: int,
124
+ input_size_per_partition: int,
125
+ input_size: int, group_size: int) -> None:
126
+
127
+ # Validate output_size_per_partition
128
+ if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
129
+ raise ValueError(f"Weight output_size_per_partition = "
130
+ f"{output_size_per_partition} is not divisible by "
131
+ f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
132
+ "Consider reducing tensor_parallel_size or running "
133
+ "with --quantization gptq.")
134
+
135
+ # Validate input_size_per_partition
136
+ if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
137
+ raise ValueError(f"Weight input_size_per_partition = "
138
+ f"{input_size_per_partition} is not divisible "
139
+ f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
140
+ "Consider reducing tensor_parallel_size or running "
141
+ "with --quantization gptq.")
142
+
143
+ if (group_size < input_size
144
+ and input_size_per_partition % group_size != 0):
145
+ raise ValueError(
146
+ f"Weight input_size_per_partition = {input_size_per_partition}"
147
+ f" is not divisible by group_size = {group_size}. "
148
+ "Consider reducing tensor_parallel_size or running "
149
+ "with --quantization gptq.")
150
+
151
+
152
+ def check_marlin_supports_shape(output_size_per_partition: int,
153
+ input_size_per_partition: int,
154
+ input_size: int, group_size: int) \
155
+ -> tuple[bool, Optional[str]]:
156
+ try:
157
+ verify_marlin_supports_shape(output_size_per_partition,
158
+ input_size_per_partition, input_size,
159
+ group_size)
160
+ except ValueError as e:
161
+ return False, e.__str__()
162
+ return True, None
163
+
164
+
165
+ def marlin_make_workspace(output_size_per_partition: int,
166
+ device: torch.device) -> torch.Tensor:
167
+ max_workspace_size = (output_size_per_partition //
168
+ GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
169
+
170
+ return torch.zeros(max_workspace_size,
171
+ dtype=torch.int,
172
+ device=device,
173
+ requires_grad=False)
174
+
175
+
176
+ def marlin_make_workspace_new(device: torch.device,
177
+ max_blocks_per_sm: int = 1) -> torch.Tensor:
178
+ # In the new marlin kernel, we use the number of threadblocks as the workspace
179
+ # size. The number of threadblocks is sms_count * max_blocks_per_sm.
180
+ sms = torch.cuda.get_device_properties(device).multi_processor_count
181
+ return torch.zeros(sms * max_blocks_per_sm,
182
+ dtype=torch.int,
183
+ device=device,
184
+ requires_grad=False)
185
+
186
+
187
+ def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
188
+ return (not act_order) or (act_order and not is_row_parallel)
189
+
190
+
191
+ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
192
+ is_row_parallel: bool) -> bool:
193
+ # Need to repeat scales on every rank if act_ordering or
194
+ # channelwise and RowParallelLinear
195
+ is_channelwise = group_size == -1
196
+ return act_order or (is_channelwise and is_row_parallel)
197
+
198
+
199
+ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
200
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
201
+ requires_grad=False)
202
+
203
+
204
+ def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
205
+ return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
206
+ requires_grad=False)
207
+
208
+
209
+ def marlin_sort_g_idx(
210
+ g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
211
+ g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
212
+ return g_idx[g_idx_sort_indices], g_idx_sort_indices
213
+
214
+
215
+ def get_scale_perms():
216
+ scale_perm: list[int] = []
217
+ for i in range(8):
218
+ scale_perm.extend([i + 8 * j for j in range(8)])
219
+ scale_perm_single: list[int] = []
220
+ for i in range(4):
221
+ scale_perm_single.extend(
222
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
223
+ return scale_perm, scale_perm_single
224
+
225
+
226
+ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
227
+ group_size: int) -> torch.Tensor:
228
+
229
+ scale_perm, scale_perm_single = get_scale_perms()
230
+ if group_size < size_k and group_size != -1:
231
+ s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
232
+ else:
233
+ s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
234
+ s = s.reshape((-1, size_n)).contiguous()
235
+
236
+ return s
237
+
238
+
239
+ def marlin_moe_permute_scales(
240
+ s: torch.Tensor,
241
+ size_k: int,
242
+ size_n: int,
243
+ group_size: int,
244
+ ):
245
+ num_experts = s.shape[0]
246
+ output = torch.empty(
247
+ (num_experts, s.shape[1], s.shape[2]),
248
+ device=s.device,
249
+ dtype=s.dtype,
250
+ )
251
+
252
+ for e in range(num_experts):
253
+ output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
254
+ return output
255
+
256
+
257
+ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
258
+ num_bits: int) -> torch.Tensor:
259
+ # Permute zero-points in a similar way to scales, but do not use the
260
+ # "single" permutation, since zero-points are applied on every MMA
261
+ scale_perm, _ = get_scale_perms()
262
+ zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
263
+
264
+ # Interleave column dim (for the dequantize code) and pack it to int32
265
+ if num_bits == 4:
266
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
267
+ elif num_bits == 8:
268
+ interleave = numpy.array([0, 2, 1, 3])
269
+ else:
270
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
271
+
272
+ zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
273
+ zp = zp.reshape((-1, size_n)).contiguous()
274
+ zp = pack_cols(zp, num_bits, size_k, size_n)
275
+
276
+ return zp
277
+
278
+
279
+ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
280
+ size_n: int, num_bits: int) -> torch.Tensor:
281
+ # AWQ zero-points are quantized and packed on the column dim.
282
+ # In addition, the values are permuted based on dequantizer.
283
+ # Here we undo both of these, and then apply marlin permutation
284
+ # and pack it back.
285
+ q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
286
+
287
+ # Undo interleaving (use argsort(..) to get inverse perm)
288
+ if num_bits == 4:
289
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
290
+ elif num_bits == 8:
291
+ undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
292
+ else:
293
+ raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
294
+
295
+ q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
296
+ q_zp = q_zp.reshape((-1, size_n)).contiguous()
297
+
298
+ marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
299
+ return marlin_zp
300
+
301
+
302
+ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
303
+ size_n: int, num_bits: int):
304
+ num_experts = q_zp_packed.shape[0]
305
+ output = torch.empty(
306
+ (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
307
+ device=q_zp_packed.device,
308
+ dtype=q_zp_packed.dtype,
309
+ )
310
+ for e in range(num_experts):
311
+ output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
312
+ num_bits)
313
+ return output
314
+
315
+
316
+ def maybe_warn_marlin_atomic_add(device, dtype):
317
+ if torch.compiler.is_dynamo_compiling():
318
+ return
319
+ device_capability = torch.cuda.get_device_capability(device)
320
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
321
+ logger.info_once(
322
+ "You are running Marlin kernel with bf16 on GPUs before SM90. "
323
+ "You can consider changing to fp16 to achieve better performance "
324
+ "if possible.")
325
+
326
+
327
+ def maybe_warn_marlin_atomic_add_env():
328
+ if torch.compiler.is_dynamo_compiling():
329
+ return
330
+ if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
331
+ return
332
+ logger.info_once(
333
+ "Marlin kernel can achieve better performance for small size_n "
334
+ "with experimental use_atomic_add feature. "
335
+ "You can consider setting the environment variable "
336
+ "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
337
+
338
+
339
+ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
340
+ dtype: torch.dtype) -> bool:
341
+
342
+ # the performance of atomicAdd is better than global reduce
343
+ # only when m*n is small and k is large
344
+ if n >= 2048 or k < 2048 or device.type != "cuda":
345
+ return False
346
+
347
+ # disable atomicAdd reduce by default,
348
+ # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
349
+ if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
350
+ maybe_warn_marlin_atomic_add_env()
351
+ return False
352
+
353
+ # sm8x doesn't support atomicAdd + bfloat16 natively
354
+ device_capability = torch.cuda.get_device_capability(device)
355
+ if device_capability[0] < 9 and dtype == torch.bfloat16:
356
+ maybe_warn_marlin_atomic_add(device, dtype)
357
+ return False
358
+
359
+ return True
360
+
361
+
362
+ def apply_gptq_marlin_linear(
363
+ input: torch.Tensor,
364
+ weight: torch.Tensor,
365
+ weight_scale: torch.Tensor,
366
+ weight_zp: torch.Tensor,
367
+ g_idx: torch.Tensor,
368
+ g_idx_sort_indices: torch.Tensor,
369
+ workspace: torch.Tensor,
370
+ wtype: ScalarType,
371
+ output_size_per_partition: int,
372
+ input_size_per_partition: int,
373
+ is_k_full: bool,
374
+ bias: Optional[torch.Tensor] = None,
375
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
376
+ reshaped_x = input.reshape(-1, input.shape[-1])
377
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
378
+
379
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
380
+ n=output_size_per_partition,
381
+ k=reshaped_x.size(1),
382
+ device=input.device,
383
+ dtype=input.dtype)
384
+
385
+ output = gptq_marlin_gemm(reshaped_x,
386
+ None,
387
+ weight,
388
+ weight_scale,
389
+ None,
390
+ weight_zp,
391
+ g_idx,
392
+ g_idx_sort_indices,
393
+ workspace,
394
+ wtype,
395
+ size_m=reshaped_x.shape[0],
396
+ size_n=output_size_per_partition,
397
+ size_k=input_size_per_partition,
398
+ is_k_full=is_k_full,
399
+ use_atomic_add=use_atomic_add,
400
+ use_fp32_reduce=use_fp32_reduce,
401
+ is_zp_float=False)
402
+
403
+ if bias is not None:
404
+ output.add_(bias) # In-place add
405
+
406
+ return output.reshape(out_shape)
407
+
408
+
409
+ def apply_awq_marlin_linear(
410
+ input: torch.Tensor,
411
+ weight: torch.Tensor,
412
+ weight_scale: torch.Tensor,
413
+ weight_zp: torch.Tensor,
414
+ g_idx: torch.Tensor,
415
+ g_idx_sort_indices: torch.Tensor,
416
+ workspace: torch.Tensor,
417
+ quant_type: ScalarType,
418
+ output_size_per_partition: int,
419
+ input_size_per_partition: int,
420
+ bias: Optional[torch.Tensor] = None,
421
+ use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
422
+ reshaped_x = input.reshape(-1, input.shape[-1])
423
+ out_shape = input.shape[:-1] + (output_size_per_partition, )
424
+
425
+ use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
426
+ n=output_size_per_partition,
427
+ k=reshaped_x.size(1),
428
+ device=input.device,
429
+ dtype=input.dtype)
430
+
431
+ output = gptq_marlin_gemm(reshaped_x,
432
+ None,
433
+ weight,
434
+ weight_scale,
435
+ None,
436
+ weight_zp,
437
+ g_idx,
438
+ g_idx_sort_indices,
439
+ workspace,
440
+ quant_type,
441
+ size_m=reshaped_x.shape[0],
442
+ size_n=output_size_per_partition,
443
+ size_k=input_size_per_partition,
444
+ use_atomic_add=use_atomic_add,
445
+ use_fp32_reduce=use_fp32_reduce,
446
+ is_zp_float=False)
447
+
448
+ if bias is not None:
449
+ output.add_(bias) # In-place add
450
+
451
+ return output.reshape(out_shape)
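Finally, a hedged sketch of the feasibility check and workspace allocation a caller would run before taking the Marlin path (assumes a CUDA GPU with compute capability >= 8.0 and the import path used above):

```python
# Sketch: checking Marlin support and allocating its workspace.
import torch
from quantization import scalar_types
from quantization.utils.marlin_utils import (  # assumed import path
    check_marlin_supported, marlin_make_workspace_new)

ok = check_marlin_supported(scalar_types.uint4b8, group_size=128, has_zp=False)
print("marlin usable:", ok)
if ok:
    workspace = marlin_make_workspace_new(torch.device("cuda"))
    print(workspace.shape, workspace.dtype)  # (num_SMs,) torch.int32
```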