danieldk committed
Commit c5018b2 · 1 Parent(s): 5c6fb68
Files changed (24)
  1. build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py +107 -1
  2. build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  3. build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py +107 -1
  4. build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  5. build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py +107 -1
  6. build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  7. build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py +107 -1
  8. build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  9. build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py +107 -1
  10. build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  11. build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py +107 -1
  12. build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  13. build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py +107 -1
  14. build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  15. build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py +107 -1
  16. build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  17. build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py +107 -1
  18. build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  19. build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py +107 -1
  20. build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  21. build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py +107 -1
  22. build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
  23. build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py +107 -1
  24. build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +2 -2
build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -42,3 +42,109 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     return out
 
+# fp8
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: if you
+    provide the scale, it will use static scaling, and if you omit it,
+    the scale will be determined dynamically. The function also allows
+    optional padding of the output tensors for downstream kernels that
+    will benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8.
+        scale: Optional scaling factor for the FP8 quantization.
+        scale_ub: Optional upper bound for the scaling factor in the
+            dynamic per-token case.
+        num_token_padding: If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per-tensor or per-token
+            scaling in the dynamic quantization case.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+        the scaling factor.
+    """
+    # This code assumes batch_dim and num_tokens are flattened
+    assert input.ndim == 2
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz:
+    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
+    #     if current_platform.is_rocm() else torch.float8_e4m3fn
+    out_dtype = torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            ops.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale, scale_ub)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        # num_token_padding not implemented for this case
+        assert scale.numel() == 1 or num_token_padding is None
+        ops.static_scaled_fp8_quant(output, input, scale)
+
+    return output, scale
+
+# int8
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    azp: Optional[torch.Tensor] = None,
+    symmetric: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """
+    Quantize the input tensor to int8 and return the quantized tensor,
+    the scale, and optionally the azp.
+
+    Args:
+        input: The input tensor to be quantized to int8.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
+        azp: Optional zero point for the int8 quantization.
+            Must be provided for asymmetric quantization if `scale` is provided.
+        symmetric: Whether to use symmetric quantization (scale only, azp ignored).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: The output
+        int8 tensor, the scales, and optionally the azp.
+    """
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static per-tensor quantization
+        assert symmetric == (
+            azp is None
+        ), "azp must only be provided for asymmetric quantization."
+        ops.static_scaled_int8_quant(output, input, scale, azp)
+        return output, scale, azp
+
+    # dynamic per-token quantization
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                        dtype=torch.int32)
+    ops.dynamic_scaled_int8_quant(output, input, input_scales,
+                                  input_azp)
+    return output, input_scales, input_azp
+
+# fp8 marlin
+def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    b_scales: torch.Tensor, workspace: torch.Tensor,
+                    num_bits: int, size_m: int, size_n: int,
+                    size_k: int) -> torch.Tensor:
+    return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                               num_bits, size_m, size_n, size_k)
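
For orientation, here is a minimal usage sketch of the two quantization helpers added above. It is not part of the commit: it assumes one of the build directories above is on sys.path so that the quantization package imports, that a CUDA device is available, and the shapes and the 0.02 scale are illustrative only.

import torch
import quantization  # assumes a build dir above is on sys.path

# Illustrative activations: 16 tokens, hidden size 4096.
x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

# Dynamic per-token FP8: no scale is given, so the kernel computes one
# float32 scale per row.
x_fp8, x_scale = quantization.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert x_fp8.dtype == torch.float8_e4m3fn and x_scale.shape == (16, 1)

# Static per-tensor FP8 with a precomputed (illustrative) scale.
s = torch.tensor([0.02], device="cuda", dtype=torch.float32)
y_fp8, y_scale = quantization.scaled_fp8_quant(x, scale=s)

# Symmetric dynamic per-token int8; the returned azp is None.
x_i8, i8_scale, azp = quantization.scaled_int8_quant(x)
assert x_i8.dtype == torch.int8 and azp is None

Per-token scaling trades a slightly larger scale tensor for better accuracy on activations with uneven row magnitudes; per-tensor scaling keeps a single scalar.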
build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c9343c97509a78e62cf1f87abbf3bc426f8f85e0c95694b3b2b80740d3cbf280
-size 30943736
+oid sha256:1ff4df94b3de0caab5c2c7584f21eb3898d495bcf92a731fb1fd9a46ba0dff50
+size 39178896
build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:626d8e62d5801cdca8869b45e4f79893de3ee6637b86f2647e2b1d1bb1452020
-size 36253328
+oid sha256:0dd82ba302527185d214e508371d75778524c412934d0f2399bd3d00402b89c5
+size 46540064
build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1b1e9f1b5beb9f5de558dcfcf2d61ecc4e207a723ba810605d70d9aa31e65df5
-size 37028144
+oid sha256:0cac14195f6181d145e9f1cc0c1e532f8cfa2914fe7ff59fdf3194c85fd28b9c
+size 47413592
build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6d5fa1dba02499bec6e2225c83e460ae0d8e7ca396763d2851f7e000be88e675
-size 30940256
+oid sha256:0054660979a4c0a273f32b6293294d02048b14a90af3b1f3e5cb226504cffe15
+size 39166248
build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1bd3556b97cd85c7282ded2057e59207575b92107e0a7e17fcffa82d27f42d26
-size 36256840
+oid sha256:735a3a1fe3aea0065a7378f11611378923331d7633000441fa5d0d5e03d0d481
+size 46534608
build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f7aa8e42070dfb49e64cfa204bbb81ab2750c2cfafacac6f8d4ae303285a735a
-size 37027640
+oid sha256:0b30e0d4e60669a253ed18f92b5219fdc3b40c2f0e792bf275c62a058998ebad
+size 47404040
build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2a9874fefa371528c54a12851771476c3baf9356536b88c960d7ec3dcf293469
-size 30943736
+oid sha256:cd1b0623532cab4059a48c8e8e7417df9ba309968adcd03705012fa79b04776d
+size 39178896
build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:554ac8e33120544e28fb91abab1f03b29e6665256c5acfec69137072575b7945
-size 36253328
+oid sha256:c129629d41aedd275b5a5545f86d968f0e9cc63085e54b09f0daff930af8f48c
+size 46540064
build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:373d2248e3ffc236b6b86521dabe2604609e8646107eb3c2d74a9378490a8878
-size 37028144
+oid sha256:d691489858bbbac2d9031f49378d34d6c1ec3ed9592c933916cdb0fd470b4e54
+size 47413592
build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py CHANGED
(identical diff to build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py above)
build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7bf3ca9934fb5cf098acbd3e41eab1004f358b49c7c459b01649753dca258d97
-size 30940256
+oid sha256:bc0a49e4a96613598d16cb392b3e5580c1461e2cb6ca291876aeeb4c1afeabf7
+size 39166248
build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional
2
 
3
  import torch
4
 
@@ -42,3 +42,109 @@ def cutlass_scaled_mm(a: torch.Tensor,
42
 
43
  return out
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
 
3
  import torch
4
 
 
42
 
43
  return out
44
 
45
+ # fp8
46
+ def scaled_fp8_quant(
47
+ input: torch.Tensor,
48
+ scale: Optional[torch.Tensor] = None,
49
+ num_token_padding: Optional[int] = None,
50
+ scale_ub: Optional[torch.Tensor] = None,
51
+ use_per_token_if_dynamic: bool = False,
52
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
53
+ """
54
+ Quantize input tensor to FP8 and return quantized tensor and scale.
55
+
56
+ This function supports both static and dynamic quantization: If you
57
+ provide the scale, it will use static scaling and if you omit it,
58
+ the scale will be determined dynamically. The function also allows
59
+ optional padding of the output tensors for downstream kernels that
60
+ will benefit from padding.
61
+
62
+ Args:
63
+ input: The input tensor to be quantized to FP8
64
+ scale: Optional scaling factor for the FP8 quantization
65
+ scale_ub: Optional upper bound for scaling factor in dynamic
66
+ per token case
67
+ num_token_padding: If specified, pad the first dimension
68
+ of the output to at least this value.
69
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
70
+ in the dynamic quantization case.
71
+
72
+ Returns:
73
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
74
+ scaling factor.
75
+ """
76
+ # This code assumes batch_dim and num_tokens are flattened
77
+ assert (input.ndim == 2)
78
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
79
+ # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz
80
+ #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
81
+ # if current_platform.is_rocm() else torch.float8_e4m3fn
82
+ out_dtype = torch.float8_e4m3fn
83
+ if num_token_padding:
84
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
85
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
86
+
87
+ if scale is None:
88
+ if use_per_token_if_dynamic:
89
+ scale = torch.empty((shape[0], 1),
90
+ device=input.device,
91
+ dtype=torch.float32)
92
+ ops.dynamic_per_token_scaled_fp8_quant(
93
+ output, input, scale, scale_ub)
94
+ else:
95
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
96
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
97
+ else:
98
+ # num_token_padding not implemented for this case
99
+ assert (scale.numel() == 1 or num_token_padding is None)
100
+ ops.static_scaled_fp8_quant(output, input, scale)
101
+
102
+ return output, scale
103
+
104
+ # int8
105
+ def scaled_int8_quant(
106
+ input: torch.Tensor,
107
+ scale: Optional[torch.Tensor] = None,
108
+ azp: Optional[torch.Tensor] = None,
109
+ symmetric: bool = True
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
111
+ """
112
+ Quantize the input tensor to int8 and return the quantized tensor, scale, and optionally azp.
113
+
114
+ Args:
115
+ input: The input tensor to be quantized to int8.
116
+ scale: Optional scaling factor for the int8 quantization.
117
+ When not provided, we invoke dynamic-per-token quantization.
118
+ azp: Optional zero-point for the int8 quantization.
119
+ Must be provided for asymmetric quantization if `scale` is provided.
120
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
121
+
122
+ Returns:
123
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output int8 tensor, scales, and optionally azp.
124
+ """
125
+ output = torch.empty_like(input, dtype=torch.int8)
126
+ if scale is not None:
127
+ # static-per-tensor quantization.
128
+ assert symmetric == (
129
+ azp is None
130
+ ), "azp must be provided if and only if the quantization is asymmetric."
131
+ ops.static_scaled_int8_quant(output, input, scale, azp)
132
+ return output, scale, azp
133
+
134
+ # dynamic-per-token quantization.
135
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
136
+ device=input.device,
137
+ dtype=torch.float32)
138
+ input_azp = None if symmetric else torch.empty_like(input_scales,
139
+ dtype=torch.int32)
140
+ ops.dynamic_scaled_int8_quant(output, input, input_scales,
141
+ input_azp)
142
+ return output, input_scales, input_azp
143
+
144
+ # fp8 marlin
145
+ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
146
+ b_scales: torch.Tensor, workspace: torch.Tensor,
147
+ num_bits: int, size_m: int, size_n: int,
148
+ size_k: int) -> torch.Tensor:
149
+ return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
150
+ num_bits, size_m, size_n, size_k)
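A similar hedged sketch for the fp8 entry point defined above (same import assumption; the output dtype on CUDA is torch.float8_e4m3fn):

import torch
from quantization import scaled_fp8_quant

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
# Dynamic per-tensor: the kernel computes the scale and returns it.
q_dyn, s_dyn = scaled_fp8_quant(x)
assert q_dyn.dtype == torch.float8_e4m3fn and s_dyn.numel() == 1

# Static per-tensor: reuse a precomputed scale; only the quantization kernel runs.
q_stat, s_stat = scaled_fp8_quant(x, scale=s_dyn)
assert s_stat is s_dyn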
build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5582e29a80ea71e618ccdeb9ad49456d3ebf130ee4c3001b1904e73094460455
3
- size 36256840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8b055cc65c6680f15caf698d12f7e0b87332132e7489b689381643873a518a0
3
+ size 46534608
build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional
2
 
3
  import torch
4
 
@@ -42,3 +42,109 @@ def cutlass_scaled_mm(a: torch.Tensor,
42
 
43
  return out
44
 
1
+ from typing import Optional, Tuple, Union
2
 
3
  import torch
4
 
 
42
 
43
  return out
44
 
45
+ # fp8
46
+ def scaled_fp8_quant(
47
+ input: torch.Tensor,
48
+ scale: Optional[torch.Tensor] = None,
49
+ num_token_padding: Optional[int] = None,
50
+ scale_ub: Optional[torch.Tensor] = None,
51
+ use_per_token_if_dynamic: bool = False,
52
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
53
+ """
54
+ Quantize input tensor to FP8 and return quantized tensor and scale.
55
+
56
+ This function supports both static and dynamic quantization: if you
57
+ provide the scale, static scaling is used; if you omit it,
58
+ the scale is determined dynamically. The function also allows
59
+ optional padding of the output tensors for downstream kernels that
60
+ will benefit from padding.
61
+
62
+ Args:
63
+ input: The input tensor to be quantized to FP8
64
+ scale: Optional scaling factor for the FP8 quantization
65
+ scale_ub: Optional upper bound for scaling factor in dynamic
66
+ per token case
67
+ num_token_padding: If specified, pad the first dimension
68
+ of the output to at least this value.
69
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token
70
+ in the dynamic quantization case.
71
+
72
+ Returns:
73
+ Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
74
+ scaling factor.
75
+ """
76
+ # This code assumes batch_dim and num_tokens are flattened
77
+ assert (input.ndim == 2)
78
+ shape: Union[Tuple[int, int], torch.Size] = input.shape
79
+ # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz
80
+ #out_dtype: torch.dtype = torch.float8_e4m3fnuz \
81
+ # if current_platform.is_rocm() else torch.float8_e4m3fn
82
+ out_dtype = torch.float8_e4m3fn
83
+ if num_token_padding:
84
+ shape = (max(num_token_padding, input.shape[0]), shape[1])
85
+ output = torch.empty(shape, device=input.device, dtype=out_dtype)
86
+
87
+ if scale is None:
88
+ if use_per_token_if_dynamic:
89
+ scale = torch.empty((shape[0], 1),
90
+ device=input.device,
91
+ dtype=torch.float32)
92
+ ops.dynamic_per_token_scaled_fp8_quant(
93
+ output, input, scale, scale_ub)
94
+ else:
95
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
96
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
97
+ else:
98
+ # num_token_padding not implemented for this case
99
+ assert (scale.numel() == 1 or num_token_padding is None)
100
+ ops.static_scaled_fp8_quant(output, input, scale)
101
+
102
+ return output, scale
103
+
104
+ # int8
105
+ def scaled_int8_quant(
106
+ input: torch.Tensor,
107
+ scale: Optional[torch.Tensor] = None,
108
+ azp: Optional[torch.Tensor] = None,
109
+ symmetric: bool = True
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
111
+ """
112
+ Quantize the input tensor to int8 and return the quantized tensor, scale, and optionally azp.
113
+
114
+ Args:
115
+ input: The input tensor to be quantized to int8.
116
+ scale: Optional scaling factor for the int8 quantization.
117
+ When not provided, we invoke dynamic-per-token quantization.
118
+ azp: Optional zero-point for the int8 quantization.
119
+ Must be provided for asymmetric quantization if `scale` is provided.
120
+ symmetric: Whether to use symmetric quantization (scale only, azp ignored).
121
+
122
+ Returns:
123
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output int8 tensor, scales, and optionally azp.
124
+ """
125
+ output = torch.empty_like(input, dtype=torch.int8)
126
+ if scale is not None:
127
+ # static-per-tensor quantization.
128
+ assert symmetric == (
129
+ azp is None
130
+ ), "azp must be provided if and only if the quantization is asymmetric."
131
+ ops.static_scaled_int8_quant(output, input, scale, azp)
132
+ return output, scale, azp
133
+
134
+ # dynamic-per-token quantization.
135
+ input_scales = torch.empty((input.numel() // input.shape[-1], 1),
136
+ device=input.device,
137
+ dtype=torch.float32)
138
+ input_azp = None if symmetric else torch.empty_like(input_scales,
139
+ dtype=torch.int32)
140
+ ops.dynamic_scaled_int8_quant(output, input, input_scales,
141
+ input_azp)
142
+ return output, input_scales, input_azp
143
+
144
+ # fp8 marlin
145
+ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
146
+ b_scales: torch.Tensor, workspace: torch.Tensor,
147
+ num_bits: int, size_m: int, size_n: int,
148
+ size_k: int) -> torch.Tensor:
149
+ return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
150
+ num_bits, size_m, size_n, size_k)
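Finally, a sketch of the dynamic per-token fp8 variant with output padding (note that the scale tensor is allocated with the padded row count, per the code above):

import torch
from quantization import scaled_fp8_quant

x = torch.randn(5, 512, device="cuda", dtype=torch.bfloat16)
# Per-token scales; the first output dimension is padded to at least 8 rows.
q, s = scaled_fp8_quant(x, num_token_padding=8, use_per_token_if_dynamic=True)
assert q.shape == (8, 512) and s.shape == (8, 1)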
build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0d16013abe29e9d3914a19078bf6195bf6a9402222fe48dd75e5a3827e081f49
3
- size 37027640
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e9f624b3f8585ca9c688a08b7191f1a2212cd1b64701ba9b0075deedd3fe3d4
3
+ size 47404040