Add full Marlin support and tests for Marlin/CUTLASS
 - build.toml +19 -4
 - ext-torch/__init__.py +30 -177
 - ext-torch/compressed_tensors.py +110 -0
 - ext-torch/cutlass.py +75 -0
 - ext-torch/marlin.py +208 -0
 - ext-torch/scalar_type.py +330 -0
 - ext-torch/torch_binding.cpp +29 -3
 - ext-torch/torch_binding.h +23 -0
 - ext-torch/utils/marlin_utils.py +391 -0
 - ext-torch/utils/marlin_utils_fp8.py +100 -0
 - ext-torch/utils/marlin_utils_test.py +162 -0
 - ext-torch/utils/marlin_utils_test_24.py +473 -0
 - ext-torch/utils/marlin_utils_test_qqq.py +125 -0
 - ext-torch/utils/quant_utils.py +470 -0
 - marlin/dense/LICENSE +209 -0
 - marlin/dense/common/base.h +32 -0
 - marlin/dense/common/mem.h +89 -0
 - marlin/dense/marlin_cuda_kernel.cu +1068 -0
 - marlin/qqq/marlin_qqq_gemm_kernel.cu +1243 -0
 - marlin/sparse/LICENSE +203 -0
 - marlin/sparse/common/base.h +51 -0
 - marlin/sparse/common/mem.h +136 -0
 - marlin/sparse/common/mma.h +191 -0
 - marlin/sparse/marlin_24_cuda_kernel.cu +1140 -0
 - tests/kernels/test_marlin_gemm.py +733 -0
 - tests/kernels/utils.py +38 -27
 
    	
build.toml (CHANGED)

```diff
@@ -10,9 +10,7 @@ src = [
   "ext-torch/torch_binding.h"
 ]
 include = [ "." ]
-pysrc = [
-  "ext-torch/__init__.py"
-]
+pyroot = "ext-torch"
 
 [kernel.cutlass_w8a8]
 capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
@@ -59,7 +57,6 @@ src = [
   "gptq_marlin/marlin.cuh",
   "gptq_marlin/marlin_dtypes.cuh",
 ]
-#include = [ "." ]
 depends = [ "torch" ]
 
 [kernel.int8_common]
@@ -83,3 +80,21 @@ src = [
 ]
 include = [ "." ]
 depends = [ "torch" ]
+
+[kernel.marlin]
+capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+src = [
+  "core/scalar_type.hpp",
+  "marlin/dense/common/base.h",
+  "marlin/dense/common/mem.h",
+  "marlin/dense/marlin_cuda_kernel.cu",
+  "marlin/qqq/marlin_qqq_gemm_kernel.cu",
+  "marlin/sparse/common/base.h",
+  "marlin/sparse/common/mem.h",
+  "marlin/sparse/common/mma.h",
+  "marlin/sparse/marlin_24_cuda_kernel.cu"
+]
+include = [ "." ]
+depends = [ "torch" ]
+
+
```
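The new `[kernel.marlin]` entry restricts the Marlin sources to SM 8.0 and newer, one step stricter than the 7.5+ list used for `cutlass_w8a8`. A minimal runtime guard mirroring that build-time gate (my sketch, not part of this commit) could look like:

```python
import torch


def marlin_supported(device: torch.device = torch.device("cuda")) -> bool:
    """Return True if the device meets the SM 8.0 floor from [kernel.marlin]."""
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) >= (8, 0)
```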
    	
ext-torch/__init__.py (CHANGED)

The module shrinks from 177 lines to 30: the inline implementations of cutlass_scaled_mm, cutlass_scaled_mm_azp, scaled_fp8_quant, scaled_int8_quant, and fp8_marlin_gemm move out into the new cutlass.py, compressed_tensors.py, and marlin.py modules shown below, so __init__.py now only re-exports the public API.

```python
from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant
from .cutlass import (
    cutlass_scaled_mm_supports_fp8,
    cutlass_scaled_mm,
    cutlass_scaled_mm_azp,
)
from .marlin import (
    awq_marlin_repack,
    fp8_marlin_gemm,
    gptq_marlin_gemm,
    gptq_marlin_repack,
    gptq_marlin_24_gemm,
    marlin_qqq_gemm,
    marlin_gemm,
)

__all__ = [
    "awq_marlin_repack",
    "cutlass_scaled_mm",
    "cutlass_scaled_mm_azp",
    "cutlass_scaled_mm_supports_fp8",
    "fp8_marlin_gemm",
    "gptq_marlin_24_gemm",
    "gptq_marlin_gemm",
    "gptq_marlin_repack",
    "marlin_gemm",
    "marlin_qqq_gemm",
    "scaled_fp8_quant",
    "scaled_int8_quant",
]
```
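Because `__init__.py` only re-exports, call sites can import either from the package root or from the submodules. A quick check of that equivalence (assuming the built package imports as `quantization`; the actual distribution name is not shown in this diff):

```python
import quantization
from quantization.compressed_tensors import scaled_fp8_quant

# The root-level name and the submodule name are the same function object.
assert quantization.scaled_fp8_quant is scaled_fp8_quant
```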
    	
ext-torch/compressed_tensors.py (ADDED)

```python
from typing import Optional, Tuple, Union

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: If you
    provide the scale, it will use static scaling and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8
        scale: Optional scaling factor for the FP8 quantization
        scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz:
    # out_dtype: torch.dtype = torch.float8_e4m3fnuz \
    #        if current_platform.is_rocm() else torch.float8_e4m3fn
    out_dtype = torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
            ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            ops.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert scale.numel() == 1 or num_token_padding is None
        ops.static_scaled_fp8_quant(output, input, scale)

    return output, scale


# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor, scale, and optionally azp.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is provided.
        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output int8 tensor, scales, and optionally azp.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (
            azp is None
        ), "azp must only be provided for asymmetric quantization."
        ops.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, azp

    # dynamic-per-token quantization.
    input_scales = torch.empty(
        (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
    )
    input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
    ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
    return output, input_scales, input_azp
```
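A short usage sketch for the two quantizers (hypothetical, assuming a CUDA device and the `quantization` package name used earlier): with no `scale` argument, both functions run in dynamic mode and allocate the scale tensors themselves.

```python
import torch

from quantization import scaled_fp8_quant, scaled_int8_quant

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

# Dynamic FP8 quantization with one scale per token row -> scale shape (16, 1).
x_fp8, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert x_fp8.dtype == torch.float8_e4m3fn and x_scale.shape == (16, 1)

# Symmetric dynamic int8 quantization: azp comes back as None.
x_i8, i8_scales, azp = scaled_int8_quant(x)
assert x_i8.dtype == torch.int8 and azp is None
```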
    	
ext-torch/cutlass.py (ADDED)

```python
from typing import Optional

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        raise e


def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    # if current_platform.is_rocm():
    #    triton_scaled_mm_module = importlib.import_module(
    #        "vllm.model_executor.layers.quantization.compressed_tensors."
    #        "triton_scaled_mm")
    #    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
    #    return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out


def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    :param azp_adj: In the per-tensor case, this should include the azp.
    Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.
    """
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
    assert azp is None or azp.numel() == a.shape[0]

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return out
```
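A hedged usage sketch for the int8 path of `cutlass_scaled_mm` (the column-major construction of `b` follows vLLM's test conventions; the `quantization` package name is an assumption):

```python
import torch

from quantization import cutlass_scaled_mm, scaled_int8_quant

m, k, n = 32, 512, 256
a = torch.randn(m, k, device="cuda")
a_i8, a_scales, _ = scaled_int8_quant(a)  # dynamic per-token scales, shape (m, 1)

# The CUTLASS kernel expects b column-major: build (n, k) and transpose to (k, n).
b_i8 = torch.randint(-8, 8, (n, k), device="cuda", dtype=torch.int8).t()
b_scales = torch.full((1, n), 0.01, device="cuda", dtype=torch.float32)

out = cutlass_scaled_mm(a_i8, b_i8, a_scales, b_scales, out_dtype=torch.bfloat16)
assert out.shape == (m, n) and out.dtype == torch.bfloat16
```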
    	
        ext-torch/marlin.py
    ADDED
    
    | 
         @@ -0,0 +1,208 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import TYPE_CHECKING
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            # neuron has torch version that doesn't even have impl_abstract
         
     | 
| 6 | 
         
            +
            if TYPE_CHECKING:
         
     | 
| 7 | 
         
            +
                def register_fake(fn):
         
     | 
| 8 | 
         
            +
                    return lambda name: fn
         
     | 
| 9 | 
         
            +
            else:
         
     | 
| 10 | 
         
            +
                try:
         
     | 
| 11 | 
         
            +
                    from torch.library import register_fake
         
     | 
| 12 | 
         
            +
                except ImportError:
         
     | 
| 13 | 
         
            +
                    from torch.library import impl_abstract as register_fake
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            try:
         
     | 
| 16 | 
         
            +
                from ._ops import ops, add_op_namespace_prefix
         
     | 
| 17 | 
         
            +
            except ImportError as e:
         
     | 
| 18 | 
         
            +
                # Fallback for local development.
         
     | 
| 19 | 
         
            +
                try:
         
     | 
| 20 | 
         
            +
                    import _quantization
         
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
                    ops = torch.ops._quantization
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
                    def add_op_namespace_prefix(op_name: str):
         
     | 
| 25 | 
         
            +
                        return f"_quantization::{op_name}"
         
     | 
| 26 | 
         
            +
                except ImportError:
         
     | 
| 27 | 
         
            +
                    raise e
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            from .scalar_type import ScalarType
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            # fp8 marlin
         
     | 
| 34 | 
         
            +
            def fp8_marlin_gemm(
         
     | 
| 35 | 
         
            +
                a: torch.Tensor,
         
     | 
| 36 | 
         
            +
                b_q_weight: torch.Tensor,
         
     | 
| 37 | 
         
            +
                b_scales: torch.Tensor,
         
     | 
| 38 | 
         
            +
                workspace: torch.Tensor,
         
     | 
| 39 | 
         
            +
                num_bits: int,
         
     | 
| 40 | 
         
            +
                size_m: int,
         
     | 
| 41 | 
         
            +
                size_n: int,
         
     | 
| 42 | 
         
            +
                size_k: int,
         
     | 
| 43 | 
         
            +
            ) -> torch.Tensor:
         
     | 
| 44 | 
         
            +
                return ops.fp8_marlin_gemm(
         
     | 
| 45 | 
         
            +
                    a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
         
     | 
| 46 | 
         
            +
                )
         

# gptq_marlin
def gptq_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros: torch.Tensor,
    g_idx: torch.Tensor,
    perm: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    has_zp: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    return ops.gptq_marlin_gemm(
        a,
        b_q_weight,
        b_scales,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        has_zp,
        use_fp32_reduce,
        is_zp_float,
    )
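
Note that the wrapper forwards `b_q_type.id` rather than the `ScalarType` object itself: custom-op schemas cannot carry arbitrary Python classes, so the type is marshalled as the packed int64 built by `ScalarType.id` (see ext-torch/scalar_type.py below) and reconstructed on the C++ side by `ScalarType::from_id`. A quick illustration, using the `scalar_types` registry defined in that file (the flat import path is an assumption for illustration):

from scalar_type import scalar_types

wtype = scalar_types.uint4b8      # the standard GPTQ 4-bit weight type
assert isinstance(wtype.id, int)  # a plain int64 that can cross the op boundary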
         

# awq_marlin
def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)


# awq_marlin
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


# marlin
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.marlin_gemm(
        a, b_q_weight, b_scales, workspace, size_m, size_n, size_k
    )
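
For orientation, a dense-marlin call looks like the sketch below. The packed-weight shape (eight 4-bit values per int32, i.e. k/16 by 2n), the group-size-128 scales, and the workspace sizing are illustrative assumptions, not checked values; the real quantize/repack/workspace helpers live in the utils/marlin_utils*.py modules added by this PR and are exercised by tests/kernels/test_marlin_gemm.py.

import torch

m, n, k = 16, 4096, 4096
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
# b_q_weight/b_scales must already be in marlin layout; empty tensors here
# only illustrate plausible shapes, they are not valid kernel inputs.
b_q_weight = torch.empty(k // 16, n * 2, dtype=torch.int32, device="cuda")
b_scales = torch.empty(k // 128, n, dtype=torch.float16, device="cuda")
# scratch space for the kernel's global reduction; must be zero-initialized.
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")

c = marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k)
assert c.shape == (m, n) and c.dtype == torch.float16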
         

# marlin_24
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.gptq_marlin_24_gemm(
        a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
    )


# qqq ops
def marlin_qqq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    s_tok: torch.Tensor,
    s_ch: torch.Tensor,
    s_group: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    return ops.marlin_qqq_gemm(
        a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k
    )
         

# Fake ops

if hasattr(ops, "gptq_marlin_24_gemm"):

    @register_fake(add_op_namespace_prefix("fp8_marlin_gemm"))
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: torch.SymInt,
                              size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm"))
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: torch.SymInt,
                                  size_n: torch.SymInt,
                                  size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("gptq_marlin_gemm"))
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: torch.SymInt,
                               size_n: torch.SymInt,
                               size_k: torch.SymInt,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False,
                               is_zp_float: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @register_fake(add_op_namespace_prefix("marlin_qqq_gemm"))
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: torch.SymInt, size_n: torch.SymInt,
                              size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @register_fake(add_op_namespace_prefix("marlin_gemm"))
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: torch.SymInt, size_n: torch.SymInt,
                          size_k: torch.SymInt) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)
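
The fake ("meta") implementations compute only the output's shape, dtype, and device, which is exactly what torch.compile and torch.library.opcheck need in order to trace through the op without launching a kernel. Each registration is keyed by the fully qualified op name, e.g. `add_op_namespace_prefix("marlin_gemm")` yields `"_quantization::marlin_gemm"` in the local-development fallback. A minimal sketch of what this buys (the meta-device shapes are placeholders, not a valid marlin layout):

import torch

a = torch.empty(16, 4096, dtype=torch.float16, device="meta")
b_q_weight = torch.empty(4096 // 16, 4096 * 2, dtype=torch.int32, device="meta")
b_scales = torch.empty(32, 4096, dtype=torch.float16, device="meta")
workspace = torch.zeros(1024, dtype=torch.int32, device="meta")

# Dispatches to _marlin_gemm_fake: shape propagation only, no CUDA required.
out = marlin_gemm(a, b_q_weight, b_scales, workspace, 16, 4096, 4096)
assert out.shape == (16, 4096) and out.device.type == "meta"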
         
ext-torch/scalar_type.py
ADDED

@@ -0,0 +1,330 @@
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


# Mirrors the enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp.  These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    which is useful for quantized types (e.g. standard GPTQ 4-bit uses a
    bias of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type).
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer (excluding the sign bit)
    if this is an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    Bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0). For example, if we store the
    type as an unsigned integer with a bias of 128, then the value 0 will be
    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
    """
         

    _finite_values_only: bool = False
    """
    Private: whether only finite values are supported; use `has_infs()`
    instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type
    (not applicable for integer types).
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
                or self.nan_repr == NanRepr.NONE):
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) - 1
        # (where e is the number of exponent bits); there is some precedent
        # for non-standard biases, for example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes, but to avoid premature
        # complication we just assume the standard exponent bias until there
        # is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double has e = 11

        max_exponent_double = (max_exponent - exponent_bias +
                               exponent_bias_double)

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa <<
                (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
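
As a worked check of the double-bit-pattern trick above (using the `scalar_types` registry defined at the bottom of this file): for float8_e4m3fn the extended-range NaN encoding costs the top mantissa pattern but gains one exponent step, landing on the familiar maximum of 448, while the IEEE float8_e5m2 tops out at (1 + 3/4) * 2^15:

assert scalar_types.float8_e4m3fn._floating_point_max() == 448.0
assert scalar_types.float8_e5m2._floating_point_max() == 57344.0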
         

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (self.size_bits < 64 or self.size_bits == 64
                    and self.is_signed()), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert self.is_signed(
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
        else:
            assert (not self.is_signed() or
                    self.size_bits <= 64), "Cannot represent min as an int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. The layout of this int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, \
            f"ScalarType fields too big {offset} to fit into an int64"

        return val
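
The field widths above (8 exponent bits, 8 mantissa bits, 1 sign flag, 32 bias bits, 1 finite-only flag, 8 NaN-repr bits, 58 bits in total) are what the C++ from_id unpacks. Decoding a couple of fields by hand shows the layout:

t = scalar_types.uint4b8               # exponent=0, mantissa=4, unsigned, bias=8
assert (t.id >> 0) & 0xFF == 0         # exponent bit count
assert (t.id >> 8) & 0xFF == 4         # mantissa / integer bit count
assert (t.id >> 16) & 0x1 == 0         # sign flag
assert (t.id >> 17) & 0xFFFFFFFF == 8  # bias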
         

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias
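
Combined with the bias, these return the logical value range rather than the raw storage range; for the GPTQ 4-bit type the stored nibbles 0..15 decode to -8..7:

assert (scalar_types.uint4b8.min(), scalar_types.uint4b8.max()) == (-8, 7)
assert (scalar_types.uint4.min(), scalar_types.uint4.max()) == (0, 15)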
         

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit); same as `signed`,
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754 and \
            not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means it's zero
        """
        if self.is_floating_point():
            ret = "float" + str(self.size_bits) + "_e" + str(
                self.exponent) + "m" + str(self.mantissa)

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()
         

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create an unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert (mantissa > 0 and exponent > 0)
        assert (nan_repr != NanRepr.IEEE_754), (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions")
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
#  flags:
#  - no-flags: means it follows IEEE 754 conventions
#  - f: means finite values only (no infinities)
#  - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
#  `[u]int<size_bits>[b<bias>]`
#  - if bias is not present it means it's zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
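
The "gptq" entries pair a small unsigned store with a centering bias, mirroring GPTQ's symmetric zero-point convention; a quick sanity check of the resulting ranges:

for t, lo, hi in [(scalar_types.uint2b2, -2, 1),
                  (scalar_types.uint3b4, -4, 3),
                  (scalar_types.uint4b8, -8, 7),
                  (scalar_types.uint8b128, -128, 127)]:
    assert (t.min(), t.max()) == (lo, hi)

# the colloquial aliases resolve to their IEEE layouts
assert str(scalar_types.bfloat16) == "float16_e8m7"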
         
ext-torch/torch_binding.cpp
CHANGED

@@ -65,16 +65,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
-  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);

   // awq_marlin repack from AWQ.
   ops.def(
       "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
       "SymInt size_n, int num_bits) -> Tensor");
-  ops.impl("awq_marlin_repack", &awq_marlin_repack);

   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.def(
       "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
@@ -86,7 +83,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
       "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
+
+  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
+      "Tensor");
+
+  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "int b_q_type, "
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+
+  // marlin_qqq_gemm for QQQ.
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, SymInt size_m, SymInt size_n, "
+      "SymInt size_k) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
+  ops.impl("awq_marlin_repack", &awq_marlin_repack);
+  ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.impl("marlin_gemm", &marlin_gemm);
+  ops.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
 }
         
     | 
| 117 | 
         | 
| 118 | 
         
             
            TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
         
     | 
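Once built, these registrations are callable through the PyTorch op registry. Below is a minimal sketch of driving the newly added dense marlin_gemm op from Python; it is an illustration, not part of this commit, and assumes the extension is importable as `quantization` (as the utility modules below import it) and that the weights were already packed into Marlin format.

# Hypothetical usage sketch. `b_q_weight` / `b_scales` must already be in
# Marlin layout (see marlin_utils.py); only the workspace sizing is taken
# directly from marlin_make_workspace() below.
import torch
import quantization as ops

m, n, k = 16, 256, 512  # GEMM problem sizes

a = torch.randn(m, k, dtype=torch.half, device="cuda")
b_q_weight = ...  # Marlin-packed 4-bit weights, e.g. from the test utilities
b_scales = ...    # group scales laid out by marlin_permute_scales()

# (n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL int32 slots,
# zero-initialized, exactly as marlin_make_workspace() computes.
workspace = torch.zeros((n // 64) * 16, dtype=torch.int, device="cuda")

c = ops.marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k)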
    	
ext-torch/torch_binding.h
CHANGED

@@ -74,3 +74,26 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                       torch::Tensor& perm, c10::SymInt size_k,
                                       c10::SymInt size_n, int64_t num_bits);
+
+
+// Marlin
+
+torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                          torch::Tensor& b_scales, torch::Tensor& workspace,
+                          int64_t size_m, int64_t size_n, int64_t size_k);
+
+torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                                  torch::Tensor& b_meta,
+                                  torch::Tensor& b_scales,
+                                  torch::Tensor& workspace,
+                                  vllm::ScalarTypeId const b_q_type_id,
+                                  int64_t size_m, int64_t size_n,
+                                  int64_t size_k);
+
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
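Note how these declarations line up with the schemas above: the `int b_q_type` argument of `gptq_marlin_24_gemm` in the schema carries a `vllm::ScalarTypeId`, i.e. the integer id of the weight's ScalarType. A hedged sketch of supplying it at the op level, assuming (based on the bundled scalar_type.py) that `ScalarType` exposes its id via an `.id` property:

# Hypothetical sketch -- not part of this commit. The `.id` accessor and the
# prepacked 2:4 inputs (a, b_q_weight, b_meta, b_scales, workspace) are
# assumptions; see marlin_utils_test_24.py for how such inputs are built.
import quantization as ops
from quantization.scalar_type import scalar_types

type_id = scalar_types.uint4b8.id  # ScalarTypeId for uint4 with bias 8

c = ops.gptq_marlin_24_gemm(
    a, b_q_weight, b_meta, b_scales, workspace,
    type_id, size_m, size_n, size_k,
)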
    	
ext-torch/utils/marlin_utils.py
ADDED

from typing import List, Optional, Tuple

import numpy
import torch

import quantization as ops
from quantization.scalar_type import ScalarType, scalar_types

from .quant_utils import pack_cols, unpack_cols

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same types for with and
#  without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
#  TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
    has_zp: bool, device_capability: Optional[int] = None
):
    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    if device_capability < 80:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
        #  to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]


def _check_marlin_supported(
    quant_type: ScalarType,
    group_size: Optional[int],
    has_zp: bool,
    device_capability: Optional[int] = None,
) -> Tuple[bool, Optional[str]]:

    if device_capability is None:
        capability_tuple = torch.cuda.get_device_capability()
        device_capability = capability_tuple[0] * 10 + capability_tuple[1]

    supported_types = query_marlin_supported_quant_types(has_zp, device_capability)

    if quant_type not in supported_types:
        return (
            False,
            f"Marlin does not support weight_bits = {quant_type}. "
            f"Only types = {supported_types} "
            f"are supported (for group_size = {group_size}, "
            f"device_capability = {device_capability}, zp = {has_zp}).",
        )
    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return (
            False,
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.",
        )

    return True, None


def check_marlin_supported(
    quant_type: ScalarType,
    group_size: int,
    has_zp: bool = False,
    device_capability: Optional[int] = None,
) -> bool:
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
    return cond


def verify_marlin_supported(
    quant_type: ScalarType, group_size: int, has_zp: bool = False
) -> None:
    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"Weight output_size_per_partition = "
            f"{output_size_per_partition} is not divisible by "
            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible "
            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )

    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq."
        )


def check_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> Tuple[bool, Optional[str]]:
    try:
        verify_marlin_supports_shape(
            output_size_per_partition, input_size_per_partition, input_size, group_size
        )
    except ValueError as e:
        return False, e.__str__()
    return True, None


def marlin_make_workspace(
    output_size_per_partition: int, device: torch.device
) -> torch.Tensor:
    max_workspace_size = (
        output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    ) * GPTQ_MARLIN_MAX_PARALLEL

    return torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(
    act_order: bool, group_size: int, is_row_parallel: bool
) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices


def get_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    num_experts = s.shape[0]
    output = torch.empty(
        (num_experts, s.shape[1], s.shape[2]),
        device=s.device,
        dtype=s.dtype,
    )

    for e in range(num_experts):
        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return output


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
    return marlin_zp


def moe_awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
):
    num_experts = q_zp_packed.shape[0]
    output = torch.empty(
        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
        device=q_zp_packed.device,
        dtype=q_zp_packed.dtype,
    )
    for e in range(num_experts):
        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
    return output


def apply_gptq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    wtype: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    is_k_full: bool,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        wtype,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def apply_awq_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_zp: torch.Tensor,
    g_idx: torch.Tensor,
    g_idx_sort_indices: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: Optional[torch.Tensor] = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        weight,
        weight_scale,
        weight_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
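For orientation, here is a short sketch of how these helpers compose on a weight-loading path. It is an illustration, not code from this commit; the `quantization.utils.marlin_utils` import path is an assumption based on `pyroot = "ext-torch"` in build.toml.

# Hypothetical usage sketch of the support checks and workspace helper.
import torch
from quantization.scalar_type import scalar_types            # assumed path
from quantization.utils.marlin_utils import (                 # assumed path
    check_marlin_supported, marlin_make_workspace,
    verify_marlin_supports_shape)

# GPTQ-style 4-bit weights, group size 128, no runtime zero-point.
ok = check_marlin_supported(scalar_types.uint4b8, group_size=128, has_zp=False)

if ok:
    # Raises ValueError if the partition shape violates the kernel's
    # min_thread_n / min_thread_k divisibility constraints.
    verify_marlin_supports_shape(
        output_size_per_partition=4096,
        input_size_per_partition=4096,
        input_size=4096,
        group_size=128,
    )
    workspace = marlin_make_workspace(4096, torch.device("cuda"))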
    	
ext-torch/utils/marlin_utils_fp8.py
ADDED

from typing import Optional

import torch

import quantization as ops

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    capability = torch.cuda.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    return capability >= 80


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, strategy: str = "tensor"
) -> None:
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=pack_fp8_to_int32(layer.weight),
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    scales = layer.weight_scale.to(layer.orig_dtype)
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1
    )
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (
         
     | 
| 94 | 
         
            +
                    byte_tensor[:, 0].to(torch.int32)
         
     | 
| 95 | 
         
            +
                    | (byte_tensor[:, 1].to(torch.int32) << 8)
         
     | 
| 96 | 
         
            +
                    | (byte_tensor[:, 2].to(torch.int32) << 16)
         
     | 
| 97 | 
         
            +
                    | (byte_tensor[:, 3].to(torch.int32) << 24)
         
     | 
| 98 | 
         
            +
                )
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous()
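The packing above is row-interleaved: byte lane i of each packed int32 carries rows i, i + 4, i + 8, ... of the FP8 tensor. A minimal round-trip sketch of that layout (assumes PyTorch >= 2.1 for torch.float8_e4m3fn; the shapes are illustrative):

import torch

w = torch.randn(8, 16).to(torch.float8_e4m3fn)  # size_k x size_n, size_k % 4 == 0
packed = pack_fp8_to_int32(w)                   # (size_k // 4) x size_n, torch.int32

# Byte lane i of each packed word should hold the raw bits of rows i::4.
for i in range(4):
    lane = ((packed >> (8 * i)) & 0xFF).to(torch.uint8)
    assert torch.equal(lane, w[i::4].contiguous().view(torch.uint8))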
         
ext-torch/utils/marlin_utils_test.py ADDED
@@ -0,0 +1,162 @@
            """Utility functions used for tests and benchmarks"""
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from typing import List, Optional
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from quantization.scalar_type import ScalarType
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
         
     | 
| 11 | 
         
            +
            from .quant_utils import (
         
     | 
| 12 | 
         
            +
                get_pack_factor,
         
     | 
| 13 | 
         
            +
                gptq_quantize_weights,
         
     | 
| 14 | 
         
            +
                quantize_weights,
         
     | 
| 15 | 
         
            +
                sort_weights,
         
     | 
| 16 | 
         
            +
            )
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            class MarlinWorkspace:
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
                def __init__(self, out_features, min_thread_n, max_parallel):
         
     | 
| 22 | 
         
            +
                    assert (
         
     | 
| 23 | 
         
            +
                        out_features % min_thread_n == 0
         
     | 
| 24 | 
         
            +
                    ), "out_features = {} is undivisible by min_thread_n = {}".format(
         
     | 
| 25 | 
         
            +
                        out_features, min_thread_n
         
     | 
| 26 | 
         
            +
                    )
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                    max_workspace_size = (out_features // min_thread_n) * max_parallel
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                    self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
            def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
         
     | 
| 34 | 
         
            +
                assert q_w.shape == (size_k, size_n)
         
     | 
| 35 | 
         
            +
                assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
         
     | 
| 36 | 
         
            +
                assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                # Permute weights to 16x64 marlin tiles
         
     | 
| 39 | 
         
            +
                q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
         
     | 
| 40 | 
         
            +
                q_w = q_w.permute((0, 2, 1, 3))
         
     | 
| 41 | 
         
            +
                q_w = q_w.reshape((size_k // tile, size_n * tile))
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                return q_w
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            def marlin_weights(q_w, size_k, size_n, num_bits, perm):
         
     | 
| 49 | 
         
            +
                # Permute
         
     | 
| 50 | 
         
            +
                q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                # Pack
         
     | 
| 53 | 
         
            +
                pack_factor = get_pack_factor(num_bits)
         
     | 
| 54 | 
         
            +
                orig_device = q_w.device
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                q_w = q_w.cpu().numpy().astype(np.uint32)
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
         
     | 
| 59 | 
         
            +
                for i in range(pack_factor):
         
     | 
| 60 | 
         
            +
                    q_packed |= q_w[:, i::pack_factor] << num_bits * i
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                return q_packed
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
            def get_weight_perm(num_bits: int):
         
     | 
| 68 | 
         
            +
                perm_list: List[int] = []
         
     | 
| 69 | 
         
            +
                for i in range(32):
         
     | 
| 70 | 
         
            +
                    perm1: List[int] = []
         
     | 
| 71 | 
         
            +
                    col = i // 4
         
     | 
| 72 | 
         
            +
                    for block in [0, 1]:
         
     | 
| 73 | 
         
            +
                        for row in [
         
     | 
| 74 | 
         
            +
                            2 * (i % 4),
         
     | 
| 75 | 
         
            +
                            2 * (i % 4) + 1,
         
     | 
| 76 | 
         
            +
                            2 * (i % 4 + 4),
         
     | 
| 77 | 
         
            +
                            2 * (i % 4 + 4) + 1,
         
     | 
| 78 | 
         
            +
                        ]:
         
     | 
| 79 | 
         
            +
                            perm1.append(16 * row + col + 8 * block)
         
     | 
| 80 | 
         
            +
                    for j in range(4):
         
     | 
| 81 | 
         
            +
                        perm_list.extend([p + 256 * j for p in perm1])
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                perm = np.array(perm_list)
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                if num_bits == 4:
         
     | 
| 86 | 
         
            +
                    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
         
     | 
| 87 | 
         
            +
                elif num_bits == 8:
         
     | 
| 88 | 
         
            +
                    interleave = np.array([0, 2, 1, 3])
         
     | 
| 89 | 
         
            +
                else:
         
     | 
| 90 | 
         
            +
                    raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
         
     | 
| 93 | 
         
            +
                perm = torch.from_numpy(perm)
         
     | 
| 94 | 
         
            +
                return perm
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
            def marlin_quantize(
         
     | 
| 98 | 
         
            +
                w: torch.Tensor,
         
     | 
| 99 | 
         
            +
                quant_type: ScalarType,
         
     | 
| 100 | 
         
            +
                group_size: int,
         
     | 
| 101 | 
         
            +
                act_order: bool,
         
     | 
| 102 | 
         
            +
                test_perm: Optional[torch.Tensor] = None,
         
     | 
| 103 | 
         
            +
            ):
         
     | 
| 104 | 
         
            +
                size_k, size_n = w.shape
         
     | 
| 105 | 
         
            +
                num_bits = quant_type.size_bits
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                # Normalize group_size
         
     | 
| 108 | 
         
            +
                if group_size == -1:
         
     | 
| 109 | 
         
            +
                    group_size = size_k
         
     | 
| 110 | 
         
            +
                assert group_size <= size_k
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                # Quantize (and apply act_order if provided)
         
     | 
| 113 | 
         
            +
                w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
         
     | 
| 114 | 
         
            +
                    w, quant_type, group_size, act_order, test_perm
         
     | 
| 115 | 
         
            +
                )
         
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
                # For act_order, sort the "weights" and "g_idx" so that group ids are
         
     | 
| 118 | 
         
            +
                # increasing
         
     | 
| 119 | 
         
            +
                sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
         
     | 
| 120 | 
         
            +
                if act_order:
         
     | 
| 121 | 
         
            +
                    q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
         
     | 
| 122 | 
         
            +
             
     | 
| 123 | 
         
            +
                # Reformat to marlin
         
     | 
| 124 | 
         
            +
                weight_perm = get_weight_perm(num_bits)
         
     | 
| 125 | 
         
            +
                marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
         
     | 
| 126 | 
         
            +
                marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                # Create result
         
     | 
| 129 | 
         
            +
                res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
         
     | 
| 130 | 
         
            +
                for i in range(len(res_list)):
         
     | 
| 131 | 
         
            +
                    res_list[i] = res_list[i].to(w.device)
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                return res_list
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
            def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
         
     | 
| 137 | 
         
            +
                size_k, size_n = w.shape
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                # Normalize group_size
         
     | 
| 140 | 
         
            +
                if group_size == -1:
         
     | 
| 141 | 
         
            +
                    group_size = size_k
         
     | 
| 142 | 
         
            +
                assert group_size <= size_k
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                # Detect num groups
         
     | 
| 145 | 
         
            +
                assert size_k % group_size == 0
         
     | 
| 146 | 
         
            +
                num_groups = size_k // group_size
         
     | 
| 147 | 
         
            +
             
     | 
| 148 | 
         
            +
                # Quantize with zp
         
     | 
| 149 | 
         
            +
                w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
                # Reformat to marlin
         
     | 
| 152 | 
         
            +
                weight_perm = get_weight_perm(quant_type.size_bits)
         
     | 
| 153 | 
         
            +
                marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
         
     | 
| 154 | 
         
            +
                marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
         
     | 
| 155 | 
         
            +
                marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
                # Create result
         
     | 
| 158 | 
         
            +
                res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
         
     | 
| 159 | 
         
            +
                for i in range(len(res_list)):
         
     | 
| 160 | 
         
            +
                    res_list[i] = res_list[i].to(w.device)
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
                return res_list
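The `|=` loop in marlin_weights packs several quantized values into each 32-bit word, placing columns i::pack_factor into bit lane i. A self-contained numpy sketch of that invariant (assuming get_pack_factor(num_bits) == 32 // num_bits, as these utilities use; shapes and values are illustrative):

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits  # what get_pack_factor(num_bits) is assumed to return
q_w = np.random.randint(0, 2**num_bits, size=(4, 32), dtype=np.uint32)

q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

# Bit lane i of each packed word recovers columns i::pack_factor.
for i in range(pack_factor):
    assert np.array_equal(
        (q_packed >> num_bits * i) & (2**num_bits - 1), q_w[:, i::pack_factor]
    )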
         
ext-torch/utils/marlin_utils_test_24.py ADDED
@@ -0,0 +1,473 @@
            """Utility functions used for tests and benchmarks"""
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import random
         
     | 
| 4 | 
         
            +
            from typing import List
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import numpy
         
     | 
| 7 | 
         
            +
            import torch
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from quantization.scalar_type import ScalarType
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from .marlin_utils_test import marlin_weights
         
     | 
| 12 | 
         
            +
            from .quant_utils import gptq_quantize_weights
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            # This is PyTorch implementation of main part of reorder_meta()
         
     | 
| 16 | 
         
            +
            # function, from tools/util/include/cutlass/util/host_reorder.h file
         
     | 
| 17 | 
         
            +
            # of CUTLASS source tree.  Furthermore, CUTLASS template for sparse
         
     | 
| 18 | 
         
            +
            # GEMM decides upon layout of this matrix, and at the moment for the
         
     | 
| 19 | 
         
            +
            # sparse GEMM executed on tensor cores, this is layout described by
         
     | 
| 20 | 
         
            +
            # ColumnMajorInterleaved<2> data structure, in
         
     | 
| 21 | 
         
            +
            # include/cutlass/layout/matrix.h of CUTLASS source tree.  The
         
     | 
| 22 | 
         
            +
            # reordering of meta matrix into meta_reordered matrix calculated
         
     | 
| 23 | 
         
            +
            # according to these segments of CUTLASS code is re-implemented here.
         
     | 
| 24 | 
         
            +
            # Note that this calculation produces offsets for scattering metadata
         
     | 
| 25 | 
         
            +
            # matrix elements into reordered metadata matrix elements (or,
         
     | 
| 26 | 
         
            +
            # equivalently, for gathering reordered metadata matrix element back
         
     | 
| 27 | 
         
            +
            # into metadata matrix elements).
         
     | 
| 28 | 
         
            +
            def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
         
     | 
| 29 | 
         
            +
                dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
         
     | 
| 30 | 
         
            +
                dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                # Reorder the rows, then swizzle the 2x2 blocks.
         
     | 
| 33 | 
         
            +
                group_x = 64
         
     | 
| 34 | 
         
            +
                group_y = 32 if meta_dtype.itemsize == 2 else 16
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                dst_rows = (
         
     | 
| 37 | 
         
            +
                    dst_rows // group_x * group_x
         
     | 
| 38 | 
         
            +
                    + (dst_rows % 2) * 2
         
     | 
| 39 | 
         
            +
                    + (dst_rows % 8) // 4
         
     | 
| 40 | 
         
            +
                    + ((dst_rows % group_y) % 4) // 2 * 32
         
     | 
| 41 | 
         
            +
                    + ((dst_rows % group_x) // 8) * 4
         
     | 
| 42 | 
         
            +
                )
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
         
     | 
| 45 | 
         
            +
                bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
         
     | 
| 46 | 
         
            +
                dst_rows += topright - bottomleft
         
     | 
| 47 | 
         
            +
                dst_cols -= topright - bottomleft
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                # Assumed that meta tensor is to be stored in CUTLASS
         
     | 
| 50 | 
         
            +
                # InterleavedColumnMajor layout, and reverse engineered
         
     | 
| 51 | 
         
            +
                # corresponding code to store values into this tensor.
         
     | 
| 52 | 
         
            +
                interleave = 2
         
     | 
| 53 | 
         
            +
                cols_maj = dst_cols // interleave
         
     | 
| 54 | 
         
            +
                cols_min = dst_cols % interleave
         
     | 
| 55 | 
         
            +
                return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            # This function converts dense matrix into sparse semi-structured
         
     | 
| 59 | 
         
            +
            # representation, producing "compressed" matrix, in the layout used by
         
     | 
| 60 | 
         
            +
            # CUTLASS backend, and corresponding metadata matrix.
         
     | 
| 61 | 
         
            +
            def sparse_semi_structured_from_dense_cutlass(dense):
         
     | 
| 62 | 
         
            +
                if dense.dim() != 2:
         
     | 
| 63 | 
         
            +
                    raise RuntimeError(
         
     | 
| 64 | 
         
            +
                        f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"  # noqa: E501
         
     | 
| 65 | 
         
            +
                    )
         
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
                m, k = dense.shape
         
     | 
| 68 | 
         
            +
                device = dense.device
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                meta_dtype = torch.int8
         
     | 
| 71 | 
         
            +
                if dense.dtype == torch.int8:
         
     | 
| 72 | 
         
            +
                    meta_dtype = torch.int32
         
     | 
| 73 | 
         
            +
                elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
         
     | 
| 74 | 
         
            +
                    meta_dtype = torch.int16
         
     | 
| 75 | 
         
            +
                else:
         
     | 
| 76 | 
         
            +
                    raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
         
     | 
| 77 | 
         
            +
                quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
         
     | 
| 78 | 
         
            +
                if quadbits_per_meta_elem not in (4, 8):
         
     | 
| 79 | 
         
            +
                    raise RuntimeError("Invalid number of elements per meta element calculated")
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                if meta_dtype == torch.int32:
         
     | 
| 82 | 
         
            +
                    if m % 16 != 0:
         
     | 
| 83 | 
         
            +
                        raise RuntimeError(
         
     | 
| 84 | 
         
            +
                            f"Number of rows of dense matrix {m} must be divisible by 16"
         
     | 
| 85 | 
         
            +
                        )
         
     | 
| 86 | 
         
            +
                else:
         
     | 
| 87 | 
         
            +
                    if m % 32 != 0:
         
     | 
| 88 | 
         
            +
                        raise RuntimeError(
         
     | 
| 89 | 
         
            +
                            f"Number of rows of dense matrix {m} must be divisible by 32"
         
     | 
| 90 | 
         
            +
                        )
         
     | 
| 91 | 
         
            +
                if k % (4 * quadbits_per_meta_elem) != 0:
         
     | 
| 92 | 
         
            +
                    raise RuntimeError(
         
     | 
| 93 | 
         
            +
                        f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"  # noqa: E501
         
     | 
| 94 | 
         
            +
                    )
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                if dense.dtype != torch.float:
         
     | 
| 97 | 
         
            +
                    ksparse = 4
         
     | 
| 98 | 
         
            +
                    dense_4 = dense.view(-1, k // ksparse, ksparse)
         
     | 
| 99 | 
         
            +
                    m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
         
     | 
| 100 | 
         
            +
                else:
         
     | 
| 101 | 
         
            +
                    ksparse = 2
         
     | 
| 102 | 
         
            +
                    dense_2 = dense.view(-1, k // ksparse, ksparse)
         
     | 
| 103 | 
         
            +
                    m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
         
     | 
| 104 | 
         
            +
                meta_ncols = k // (ksparse * quadbits_per_meta_elem)
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                # Encoding quadruples of True/False values as follows:
         
     | 
| 107 | 
         
            +
                #     [True,  True,  False, False] -> 0b0100
         
     | 
| 108 | 
         
            +
                #     [True,  False, True,  False] -> 0b1000
         
     | 
| 109 | 
         
            +
                #     [False, True,  True,  False] -> 0b1001
         
     | 
| 110 | 
         
            +
                #     [True,  False, False, True ] -> 0b1100
         
     | 
| 111 | 
         
            +
                #     [False, True,  False, True ] -> 0b1101
         
     | 
| 112 | 
         
            +
                #     [False, False, True,  True ] -> 0b1110
         
     | 
| 113 | 
         
            +
                # Thus, lower two bits in the encoding are index of the True value
         
     | 
| 114 | 
         
            +
                # at the lowest index in the quadruple, and the higher two bits in
         
     | 
| 115 | 
         
            +
                # the encoding are index of the other True value in the quadruple.
         
     | 
| 116 | 
         
            +
                # In case there are less than two True values, than False value or
         
     | 
| 117 | 
         
            +
                # values at some index or indices are considered True for the
         
     | 
| 118 | 
         
            +
                # encoding.  In case there are more than two True values, then the
         
     | 
| 119 | 
         
            +
                # excess True value(s) at some indices are considered False for
         
     | 
| 120 | 
         
            +
                # the encoding.  The exact encodings used for these cases are as
         
     | 
| 121 | 
         
            +
                # follows:
         
     | 
| 122 | 
         
            +
                #     [False, False, False, False] -> 0b1110
         
     | 
| 123 | 
         
            +
                #     [False, False, False, True ] -> 0b1110
         
     | 
| 124 | 
         
            +
                #     [False, False, True,  False] -> 0b1110
         
     | 
| 125 | 
         
            +
                #     [False, True,  False, False] -> 0b1001
         
     | 
| 126 | 
         
            +
                #     [False, True,  True,  True ] -> 0b1101
         
     | 
| 127 | 
         
            +
                #     [True,  False, False, False] -> 0b1000
         
     | 
| 128 | 
         
            +
                #     [True,  False, True,  True ] -> 0b1100
         
     | 
| 129 | 
         
            +
                #     [True,  True,  False, True ] -> 0b0100
         
     | 
| 130 | 
         
            +
                #     [True,  True,  True,  False] -> 0b0100
         
     | 
| 131 | 
         
            +
                #     [True,  True,  True,  True ] -> 0b0100
         
     | 
| 132 | 
         
            +
                # These particular encodings are chosen, with the help of Espresso
         
     | 
| 133 | 
         
            +
                # logic minimizer software, for the purpose of minimization of
         
     | 
| 134 | 
         
            +
                # corresponding Boolean functions, that translate non-zero flags
         
     | 
| 135 | 
         
            +
                # into encoding bits.  Note also possible choices for the first
         
     | 
| 136 | 
         
            +
                # and last of these encodings were limited only to (0b0100,
         
     | 
| 137 | 
         
            +
                # 0b1110), in order to produce valid encodings for 1:2 sparsity
         
     | 
| 138 | 
         
            +
                # case.
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                expr0 = m0 & m1
         
     | 
| 141 | 
         
            +
                expr1 = ~m0 & m1
         
     | 
| 142 | 
         
            +
                expr2 = ~m0 & ~m1
         
     | 
| 143 | 
         
            +
                bit0 = expr1
         
     | 
| 144 | 
         
            +
                bit1 = expr2
         
     | 
| 145 | 
         
            +
                bit2 = expr0 | expr2 | m3
         
     | 
| 146 | 
         
            +
                bit3 = expr1 | ~m1
         
     | 
| 147 | 
         
            +
                idxs0 = bit0 | (bit1.to(torch.int64) << 1)
         
     | 
| 148 | 
         
            +
                idxs1 = bit2 | (bit3.to(torch.int64) << 1)
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                if dense.dtype != torch.float:
         
     | 
| 151 | 
         
            +
                    sparse0 = dense_4.gather(
         
     | 
| 152 | 
         
            +
                        -1, idxs0.unsqueeze(-1)
         
     | 
| 153 | 
         
            +
                    )  # type: ignore[possibly-undefined]
         
     | 
| 154 | 
         
            +
                    sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
         
     | 
| 155 | 
         
            +
                    sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
         
     | 
| 156 | 
         
            +
                else:
         
     | 
| 157 | 
         
            +
                    sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(
         
     | 
| 158 | 
         
            +
                        m, k // 2
         
     | 
| 159 | 
         
            +
                    )  # type: ignore[possibly-undefined]
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                meta_4 = idxs0 | (idxs1 << 2)
         
     | 
| 162 | 
         
            +
                meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                if quadbits_per_meta_elem == 4:
         
     | 
| 165 | 
         
            +
                    meta = (
         
     | 
| 166 | 
         
            +
                        meta_n[:, :, 0]
         
     | 
| 167 | 
         
            +
                        | (meta_n[:, :, 1] << 4)
         
     | 
| 168 | 
         
            +
                        | (meta_n[:, :, 2] << 8)
         
     | 
| 169 | 
         
            +
                        | (meta_n[:, :, 3] << 12)
         
     | 
| 170 | 
         
            +
                    )
         
     | 
| 171 | 
         
            +
                elif quadbits_per_meta_elem == 8:
         
     | 
| 172 | 
         
            +
                    meta = (
         
     | 
| 173 | 
         
            +
                        meta_n[:, :, 0]
         
     | 
| 174 | 
         
            +
                        | (meta_n[:, :, 1] << 4)
         
     | 
| 175 | 
         
            +
                        | (meta_n[:, :, 2] << 8)
         
     | 
| 176 | 
         
            +
                        | (meta_n[:, :, 3] << 12)
         
     | 
| 177 | 
         
            +
                        | (meta_n[:, :, 4] << 16)
         
     | 
| 178 | 
         
            +
                        | (meta_n[:, :, 5] << 20)
         
     | 
| 179 | 
         
            +
                        | (meta_n[:, :, 6] << 24)
         
     | 
| 180 | 
         
            +
                        | (meta_n[:, :, 7] << 28)
         
     | 
| 181 | 
         
            +
                    )
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                # Reorder meta tensor elements.
         
     | 
| 184 | 
         
            +
                meta_reordered = meta.new_empty(
         
     | 
| 185 | 
         
            +
                    (m * meta_ncols,)
         
     | 
| 186 | 
         
            +
                )  # type: ignore[possibly-undefined]
         
     | 
| 187 | 
         
            +
                meta_offsets = _calculate_meta_reordering_scatter_offsets(
         
     | 
| 188 | 
         
            +
                    m, meta_ncols, meta_dtype, device
         
     | 
| 189 | 
         
            +
                )
         
     | 
| 190 | 
         
            +
                meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
                return (sparse, meta_reordered.view(m, meta_ncols))
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
             
     | 
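A minimal CPU round-trip sketch for this conversion and its inverse, sparse_semi_structured_to_dense_cutlass, defined just below (the 32x64 fp16 shape satisfies the m % 32 == 0 and k % 16 == 0 constraints checked above; the 2:4 pattern and values are illustrative):

import torch

# 2:4 pattern: keep the first two of every four columns.
mask = torch.tensor([True, True, False, False]).repeat(32, 16)
dense = torch.randint(1, 9, (32, 64)).to(torch.half) * mask

sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
assert sparse.shape == (32, 32) and meta.shape == (32, 4)
assert meta.dtype == torch.int16  # fp16 inputs use int16 metadata

dense_recon = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(dense_recon, dense)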
| 195 | 
         
            +
            # This function performs reverse of the function above - it
         
     | 
| 196 | 
         
            +
            # reconstructs dense matrix from a pair of "compressed" matrix, given
         
     | 
| 197 | 
         
            +
            # in the layout used by CUTLASS backend, and accompanying metadata
         
     | 
| 198 | 
         
            +
            # matrix.
         
     | 
| 199 | 
         
            +
            def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
         
     | 
| 200 | 
         
            +
                if sparse.dim() != 2:
         
     | 
| 201 | 
         
            +
                    raise RuntimeError(
         
     | 
| 202 | 
         
            +
                        f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"  # noqa: E501
         
     | 
| 203 | 
         
            +
                    )
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                m, k = sparse.shape
         
     | 
| 206 | 
         
            +
                device = sparse.device
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                if meta_reordered.dim() != 2:
         
     | 
| 209 | 
         
            +
                    raise RuntimeError(
         
     | 
| 210 | 
         
            +
                        f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"  # noqa: E501
         
     | 
| 211 | 
         
            +
                    )
         
     | 
| 212 | 
         
            +
                if meta_reordered.device != device:
         
     | 
| 213 | 
         
            +
                    raise RuntimeError(
         
     | 
| 214 | 
         
            +
                        f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"  # noqa: E501
         
     | 
| 215 | 
         
            +
                    )
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                meta_dtype = meta_reordered.dtype
         
     | 
| 218 | 
         
            +
                if meta_dtype not in (torch.int16, torch.int32):
         
     | 
| 219 | 
         
            +
                    raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
         
     | 
| 220 | 
         
            +
                quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
                ksparse = 4 if sparse.dtype != torch.float else 2
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
                meta_nrows, meta_ncols = meta_reordered.shape
         
     | 
| 225 | 
         
            +
                if meta_nrows != m:
         
     | 
| 226 | 
         
            +
                    raise RuntimeError(
         
     | 
| 227 | 
         
            +
                        f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"  # noqa: E501
         
     | 
| 228 | 
         
            +
                    )
         
     | 
| 229 | 
         
            +
                if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
         
     | 
| 230 | 
         
            +
                    raise RuntimeError(
         
     | 
| 231 | 
         
            +
                        f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "  # noqa: E501
         
     | 
| 232 | 
         
            +
                        "expected according to the number of columns of meta matrix"
         
     | 
| 233 | 
         
            +
                    )
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
                # Undo meta tensor elements reordering.
         
     | 
| 236 | 
         
            +
                meta_offsets = _calculate_meta_reordering_scatter_offsets(
         
     | 
| 237 | 
         
            +
                    m, meta_ncols, meta_dtype, device
         
     | 
| 238 | 
         
            +
                )
         
     | 
| 239 | 
         
            +
                meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
         
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
                # Unpack sparse tensor back to original dense tensor, using
         
     | 
| 242 | 
         
            +
                # information provided by meta tensor.  Note that torch.float
         
     | 
| 243 | 
         
            +
                # datatype is handled pretty much the same as
         
     | 
| 244 | 
         
            +
                # torch.half/torch.bfloat16, as metadata for a pair of torch.float
         
     | 
| 245 | 
         
            +
                # value is encoded as if underlying 8 bytes contain four
         
     | 
| 246 | 
         
            +
                # torch.half/torch.bfloat16 values, where either first two or last
         
     | 
| 247 | 
         
            +
                # two are zeros.
         
     | 
| 248 | 
         
            +
                meta_2 = torch.empty(
         
     | 
| 249 | 
         
            +
                    (m, meta_ncols, 2 * quadbits_per_meta_elem),
         
     | 
| 250 | 
         
            +
                    dtype=meta_dtype,
         
     | 
| 251 | 
         
            +
                    device=device,
         
     | 
| 252 | 
         
            +
                )
         
     | 
| 253 | 
         
            +
                if quadbits_per_meta_elem == 4:
         
     | 
| 254 | 
         
            +
                    meta_2[:, :, 0] = meta & 0b11
         
     | 
| 255 | 
         
            +
                    meta_2[:, :, 1] = (meta >> 2) & 0b11
         
     | 
| 256 | 
         
            +
                    meta_2[:, :, 2] = (meta >> 4) & 0b11
         
     | 
| 257 | 
         
            +
                    meta_2[:, :, 3] = (meta >> 6) & 0b11
         
     | 
| 258 | 
         
            +
                    meta_2[:, :, 4] = (meta >> 8) & 0b11
         
     | 
| 259 | 
         
            +
                    meta_2[:, :, 5] = (meta >> 10) & 0b11
         
     | 
| 260 | 
         
            +
                    meta_2[:, :, 6] = (meta >> 12) & 0b11
         
     | 
| 261 | 
         
            +
                    meta_2[:, :, 7] = (meta >> 14) & 0b11
         
     | 
| 262 | 
         
            +
                elif quadbits_per_meta_elem == 8:
         
     | 
| 263 | 
         
            +
                    meta_2[:, :, 0] = meta & 0b11
         
     | 
| 264 | 
         
            +
                    meta_2[:, :, 1] = (meta >> 2) & 0b11
         
     | 
| 265 | 
         
            +
                    meta_2[:, :, 2] = (meta >> 4) & 0b11
         
     | 
| 266 | 
         
            +
                    meta_2[:, :, 3] = (meta >> 6) & 0b11
         
     | 
| 267 | 
         
            +
                    meta_2[:, :, 4] = (meta >> 8) & 0b11
         
     | 
| 268 | 
         
            +
                    meta_2[:, :, 5] = (meta >> 10) & 0b11
         
     | 
| 269 | 
         
            +
                    meta_2[:, :, 6] = (meta >> 12) & 0b11
         
     | 
| 270 | 
         
            +
                    meta_2[:, :, 7] = (meta >> 14) & 0b11
         
     | 
| 271 | 
         
            +
                    meta_2[:, :, 8] = (meta >> 16) & 0b11
         
     | 
| 272 | 
         
            +
                    meta_2[:, :, 9] = (meta >> 18) & 0b11
         
     | 
| 273 | 
         
            +
                    meta_2[:, :, 10] = (meta >> 20) & 0b11
         
     | 
| 274 | 
         
            +
                    meta_2[:, :, 11] = (meta >> 22) & 0b11
         
     | 
| 275 | 
         
            +
                    meta_2[:, :, 12] = (meta >> 24) & 0b11
         
     | 
| 276 | 
         
            +
                    meta_2[:, :, 13] = (meta >> 26) & 0b11
         
     | 
| 277 | 
         
            +
                    meta_2[:, :, 14] = (meta >> 28) & 0b11
         
     | 
| 278 | 
         
            +
                    meta_2[:, :, 15] = (meta >> 30) & 0b11
         
     | 
| 279 | 
         
            +
             
     | 
| 280 | 
         
            +
                dense_offsets = meta_2.view(-1) + (
         
     | 
| 281 | 
         
            +
                    torch.arange(0, 2 * m * k // ksparse, device=device) * 4
         
     | 
| 282 | 
         
            +
                ).view(-1, 1).repeat(1, 2).view(-1)
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
         
     | 
| 285 | 
         
            +
                if sparse.dtype != torch.float:
         
     | 
| 286 | 
         
            +
                    # dense.scatter_(0, dense_offsets, sparse.view(-1))
         
     | 
| 287 | 
         
            +
                    dense.scatter_(0, dense_offsets, sparse.reshape(-1))
         
     | 
| 288 | 
         
            +
                else:
         
     | 
| 289 | 
         
            +
                    dense.view(torch.half).scatter_(
         
     | 
| 290 | 
         
            +
                        0, dense_offsets, sparse.view(torch.half).view(-1)
         
     | 
| 291 | 
         
            +
                    )
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                return dense.view(m, 2 * k)
         
     | 
| 294 | 
         
            +
             
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
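For orientation, a minimal round-trip sketch of the two conversion helpers (not part of the file; this assumes the enclosing function here is sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered), mirroring PyTorch's reference conversions, and a CUDA device):

    import torch

    # Force a 2:4 pattern first (mask_creator is defined below) so the
    # compression round trip is lossless.
    dense = torch.randn(128, 128, dtype=torch.half, device="cuda")
    dense = dense * mask_creator(dense).half()

    sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    recovered = sparse_semi_structured_to_dense_cutlass(sparse, meta)
    assert torch.equal(dense, recovered)
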
def mask_creator(tensor):
    """
    Create an N:M sparsity mask for the given tensor.

    The tensor is flattened into consecutive groups of M weights; in each
    group the M - N smallest-magnitude weights are pruned, so N of every
    M survive. N:M is fixed to 2:4 here.
    """
    N = 2
    M = 4

    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups"
        )

    num_groups = tensor.numel() // M

    # N:M sparsity for linear layers
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask

def inject_24(w, size_k, size_n):
    assert w.shape == (size_k, size_n)

    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()

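A quick sketch of what inject_24 guarantees (not part of the file; assumes a CUDA device):

    import torch

    w = torch.randn(128, 256, dtype=torch.half, device="cuda")
    w_24, mask = inject_24(w, size_k=128, size_n=256)

    # The 2:4 groups run along K: every 4 consecutive values in a column
    # keep at most 2 nonzeros.
    groups = w_24.t().reshape(-1, 4)
    assert int(torch.count_nonzero(groups, dim=1).max()) <= 2
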
def check_24(w, num_rows_to_sample=50, _verbose=False):
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # Walk every BLOCK_SIZE-wide segment, including the last one.
        for j in range(0, num_cols, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j : j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")

def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
    assert q_24.shape == (size_k, size_n)

    # Remove bias to normalize over 0
    q_24_no_zp = q_24 - wtype.bias

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore bias
    q_24_comp = q_24_no_zp_comp + wtype.bias

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta

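The bias round trip above matters because a pruned weight quantizes to wtype.bias, not to numeric zero. A small illustration using this repo's scalar_type module (the bias value follows from the type name, shown here as a sanity check):

    from quantization.scalar_type import scalar_types

    # uint4b8 stores 4-bit unsigned values with a bias of 8, so a pruned
    # (zero-valued) weight is stored as 8. Subtracting the bias first lets
    # the 2:4 compressor see literal zeros; it is added back afterwards.
    assert scalar_types.uint4b8.bias == 8
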
def get_scale_perms_24():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return scale_perm, scale_perm_single

def get_weight_perm_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm

def marlin_permute_scales_24(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms_24()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s

def marlin_24_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
        w_24, quant_type, group_size, act_order=False
    )

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(quant_type.size_bits)
    marlin_24_q_w_comp = marlin_weights(
        q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
    )
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
     | 
    	
ext-torch/utils/marlin_utils_test_qqq.py ADDED
@@ -0,0 +1,125 @@
from typing import List

import numpy
import torch

from .marlin_utils_test import marlin_permute_weights
from .quant_utils import get_pack_factor, qqq_quantize_weights

def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    if group_size == size_k:
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed

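The pack loop above uses the usual shift-and-or nibble layout; in plain integers it looks like this (illustrative only):

    num_bits = 4
    pack_factor = get_pack_factor(num_bits)  # 32 // 4 == 8

    word = 0
    for i, v in enumerate(range(pack_factor)):  # values 0..7
        word |= (v & 0xF) << (num_bits * i)
    assert word == 0x76543210  # value j lands in nibble j, lowest bits first
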
def get_qqq_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single

# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    4 * (i % 4),
                    4 * (i % 4) + 1,
                    4 * (i % 4) + 2,
                    4 * (i % 4) + 3,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    assert quant_type in ["per-channel",
                          "per-group"], "unsupported quantization type"
    if num_bits == 4:
        if quant_type == "per-channel":
            interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
        else:
            interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    else:
        raise ValueError("num_bits must be 4, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm

def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    if group_size < size_k and group_size != -1:
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
        s_group = s_group.reshape((-1, size_n)).contiguous()
    else:
        s_channel = s_channel.reshape(
            (-1, len(scale_perm_single)))[:, scale_perm_single]
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel

def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    quant_type = "per-channel" if group_size == size_k else "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                        weight_perm, group_size)
    marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    res_list = [
        w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel
    ]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
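
A usage sketch mirroring the GPTQ 2:4 path above (not part of the file; illustrative shapes, assumes a CUDA device):

    import torch

    w = torch.randn(512, 256, dtype=torch.half, device="cuda")
    w_ref, q_w, s_group, s_channel = marlin_qqq_quantize(
        w, num_bits=4, group_size=128
    )
    # With group_size == size_k (or -1), QQQ switches to the per-channel
    # perm and only s_channel carries meaningful scales.
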
    	
ext-torch/utils/quant_utils.py ADDED
@@ -0,0 +1,470 @@
            """This file is used for /tests and /benchmarks"""
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from typing import List, Optional
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            import numpy
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from quantization.scalar_type import ScalarType, scalar_types
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
         
     | 
| 11 | 
         
            +
            SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            # Note: this is a hack. We should update each model to register the
         
     | 
| 16 | 
         
            +
            # stacked params and get it from there instead in a future PR.
         
     | 
| 17 | 
         
            +
            # fused_name: List[shard_name]
         
     | 
| 18 | 
         
            +
            FUSED_LAYER_NAME_MAPPING = {
         
     | 
| 19 | 
         
            +
                "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         
     | 
| 20 | 
         
            +
                "gate_up_proj": ["gate_proj", "up_proj"],
         
     | 
| 21 | 
         
            +
            }
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
def pack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    # move dim to pack to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    assert w_q_perm.shape[-1] % pack_factor == 0
    new_shape_perm[-1] //= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i

    return res.permute(inv_perm)

def unpack_quantized_values_into_int32(
    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):
    # move the packed dim to the end
    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    new_shape_perm[-1] *= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask

    return res.permute(inv_perm)

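These two helpers are exact inverses for values that fit in wtype.size_bits; a small round-trip check (not part of the file):

    import torch
    from quantization.scalar_type import scalar_types

    wtype = scalar_types.uint4b8          # 4 bits -> 8 values per int32
    w_q = torch.randint(0, 16, (16, 8), dtype=torch.int32)

    packed = pack_quantized_values_into_int32(w_q, wtype, packed_dim=1)
    assert packed.shape == (16, 1)        # 8 // (32 // 4)

    assert torch.equal(
        unpack_quantized_values_into_int32(packed, wtype, packed_dim=1), w_q
    )
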
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision."
                )
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped

def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits

def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size,), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )

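For intuition, g_idx records which scale group each shuffled row of K belongs to; a small sketch (not part of the file):

    import torch

    q_w = torch.randint(0, 16, (8, 4), dtype=torch.int32)
    w_ref = q_w.float()

    w_ref_p, q_w_p, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size=4)
    # k = 8 with group_size = 4 gives groups [0, 0, 0, 0, 1, 1, 1, 1],
    # shuffled by the same permutation that shuffled the rows.
    assert sorted(g_idx.tolist()) == [0, 0, 0, 0, 1, 1, 1, 1]
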
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.tensor([1.0], device=w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales, so for this case computing the reference in a similar way allows
    # us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
         
            +
                        return w
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                    w_q = reshape_w(w_q)
         
     | 
| 211 | 
         
            +
                    w_ref = reshape_w(w_ref)
         
     | 
| 212 | 
         
            +
                    w_s = w_s.reshape((-1, size_n)).contiguous()
         
     | 
| 213 | 
         
            +
             
     | 
| 214 | 
         
            +
                if maybe_w_zp is not None:
         
     | 
| 215 | 
         
            +
                    maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
         
     | 
| 216 | 
         
            +
                    maybe_w_zp = maybe_w_zp.to(device=orig_device)
         
     | 
| 217 | 
         
            +
             
     | 
| 218 | 
         
            +
                return (
         
     | 
| 219 | 
         
            +
                    w_ref.to(device=orig_device),
         
     | 
| 220 | 
         
            +
                    w_q.to(device=orig_device),
         
     | 
| 221 | 
         
            +
                    w_s if group_size is not None else None,
         
     | 
| 222 | 
         
            +
                    maybe_w_zp,
         
     | 
| 223 | 
         
            +
                )
         
     | 
| 224 | 
         
            +
             
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
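The symmetric branch above is easiest to verify on a toy tensor. A minimal, self-contained sketch (not part of the diff; plain torch stands in for ScalarType, with a hypothetical signed 4-bit range [-8, 7]):

import torch

# example only: per-group symmetric scale/round/clamp round trip
w = torch.randn(128, 16)                       # [size_k, size_n]
group_size, min_q, max_q = 32, -8, 7
g = w.reshape(-1, group_size, 16).permute(1, 0, 2).reshape(group_size, -1)
s = torch.max(g.max(0, keepdim=True).values / max_q,
              g.min(0, keepdim=True).values / min_q)
q = torch.clamp(torch.round(g / s), min_q, max_q)
# each element ends up within half a quantization step of its dequantized value
assert torch.all((g - q * s).abs() <= s / 2 + 1e-6)

Because the scale maps each group's extreme value exactly onto the integer range, the clamp is a no-op here and the half-step error bound is tight.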
def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k
        )

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm
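For act_order, the invariant behind g_idx is worth spelling out. A small sketch (not part of the diff; permute_rows is referenced above, and all names here are hypothetical), showing that g_idx records the group of each shuffled row:

import torch

# example only: g_idx[i] is the group id of the i-th (shuffled) row
size_k, group_size = 128, 32
perm = torch.randperm(size_k)                  # a stand-in act_order permutation
g_idx = perm // group_size
# argsort(g_idx) restores group-contiguous rows, which is what
# sort_weights below relies on
assert torch.equal(torch.sort(g_idx).values,
                   torch.arange(size_k) // group_size)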
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
    ), f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size < size_k:
        # Reshape to [group_size, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
            dtype=torch.half
        )
    else:
        max_q_val = 2 ** (num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # Divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
        s_channel /= 2 ** (8 - num_bits)
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
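The per-group branch fuses the fp16 group scale with the int8 channel scale. A small sketch (not part of the diff) of the fusion identity — the fused scale times the channel scale recovers the original group scale up to fp16 precision:

import torch

# example only: fused group scale * channel scale ~= original group scale
s_group = torch.rand(4, 16) + 0.1              # [size_k // group_size, size_n]
s_channel = torch.rand(1, 16) + 0.1
s_fused = (s_group / s_channel).half()
assert torch.allclose(s_fused.float() * s_channel, s_group, rtol=1e-2)

Per the "Fuse scales" comment above, the intent is that only the cheap fp16 fused scale is needed per group, with the fp32 channel scale applied once at the end.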
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )
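A usage sketch (not part of the diff) for sort_weights defined above, showing that rows come back group-contiguous:

import torch

# example only: undo a shuffled act_order layout
q_w = torch.arange(8, dtype=torch.int32).unsqueeze(1).repeat(1, 2)
g_idx = torch.tensor([1, 0, 1, 0, 1, 0, 1, 0], dtype=torch.int32)
q_w_sorted, g_idx_sorted, sort_idx = sort_weights(q_w, g_idx)
assert torch.equal(g_idx_sorted, torch.sort(g_idx).values)
assert torch.equal(q_w_sorted, q_w[sort_idx.long(), :])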
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res
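A tiny worked example (not part of the diff) of pack_rows: with num_bits=4, eight rows collapse into one int32 row, with source row i in nibble i of each word:

import torch

# example only: 8 four-bit rows pack into a single int32 row
q = (torch.arange(16, dtype=torch.int32) % 16).reshape(8, 2)
packed = pack_rows(q, num_bits=4, size_k=8, size_n=2)
assert packed.shape == (1, 2)
assert (packed[0, 0] & 0xF).item() == q[0, 0].item()  # row 0 -> lowest nibble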
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
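pack_cols and unpack_cols defined above are exact inverses for values that fit in num_bits; a quick round-trip check (not part of the diff):

import torch

# example only: column pack/unpack round trip for 4-bit values
q = torch.randint(0, 16, (4, 16), dtype=torch.int32)
packed = pack_cols(q, num_bits=4, size_k=4, size_n=16)
assert packed.shape == (4, 2)                # 16 cols / (32 // 4) per word
assert torch.equal(unpack_cols(packed, 4, 4, 16), q)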
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    return pack_rows(q_w, num_bits, size_k, size_n)


def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
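The interleave in awq_pack determines which nibble each source column lands in. A worked example (not part of the diff) for num_bits=4 and a single 8-column row:

import torch

# example only: nibble i of the packed word holds source column interleave[i]
q = torch.arange(8, dtype=torch.int32).reshape(1, 8)   # column j holds value j
packed = awq_pack(q, num_bits=4, size_k=1, size_n=8)
nibbles = [(packed.item() >> (4 * i)) & 0xF for i in range(8)]
assert nibbles == [0, 2, 4, 6, 1, 3, 5, 7]             # evens first, then odds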
         
marlin/dense/LICENSE ADDED
@@ -0,0 +1,209 @@
Contains code from https://github.com/IST-DASLab/marlin

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
marlin/dense/common/base.h ADDED
@@ -0,0 +1,32 @@
| 1 | 
         
            +
            /*
         
     | 
| 2 | 
         
            +
             * Modified by HandH1998
         
     | 
| 3 | 
         
            +
             * Modified by Neural Magic
         
     | 
| 4 | 
         
            +
             * Copyright (C) Marlin.2024 Elias Frantar
         
     | 
| 5 | 
         
            +
             *
         
     | 
| 6 | 
         
            +
             * Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 7 | 
         
            +
             * you may not use this file except in compliance with the License.
         
     | 
| 8 | 
         
            +
             * You may obtain a copy of the License at
         
     | 
| 9 | 
         
            +
             *
         
     | 
| 10 | 
         
            +
             *         http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 11 | 
         
            +
             *
         
     | 
| 12 | 
         
            +
             * Unless required by applicable law or agreed to in writing, software
         
     | 
| 13 | 
         
            +
             * distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 14 | 
         
            +
             * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 15 | 
         
            +
             * See the License for the specific language governing permissions and
         
     | 
| 16 | 
         
            +
             * limitations under the License.
         
     | 
| 17 | 
         
            +
             */
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            #pragma once
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            // Instances of `Vec` are used to organize groups of >>registers<<, as needed
         
     | 
| 24 | 
         
            +
            // for instance as inputs to tensor core operations. Consequently, all
         
     | 
| 25 | 
         
            +
            // corresponding index accesses must be compile-time constants, which is why we
         
     | 
| 26 | 
         
            +
            // extensively use `#pragma unroll` throughout the kernel code to guarantee
         
     | 
| 27 | 
         
            +
            // this.
         
     | 
| 28 | 
         
            +
            template <typename T, int n>
         
     | 
| 29 | 
         
            +
            struct Vec {
         
     | 
| 30 | 
         
            +
              T elems[n];
         
     | 
| 31 | 
         
            +
              __device__ T& operator[](int i) { return elems[i]; }
         
     | 
| 32 | 
         
            +
            };
         
     | 
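Because `Vec` wraps a plain array that the compiler keeps in registers, every `operator[]` access has to resolve to a compile-time constant, which in practice means fully unrolled loops. A minimal sketch of the resulting idiom (the `scale_frag` helper is illustrative, not part of this diff):

// Fully unrolled loop over a register-resident fragment: after unrolling,
// every frag[i] access names a fixed register; a runtime-variable index
// would force the array to spill to local memory.
// Requires <cuda_fp16.h> for half2 / __hmul2.
__device__ void scale_frag(Vec<half2, 4>& frag, half2 s) {
  #pragma unroll
  for (int i = 0; i < 4; i++)
    frag[i] = __hmul2(frag[i], s);
}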
    	
marlin/dense/common/mem.h ADDED
@@ -0,0 +1,89 @@
/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
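Together these helpers form the building blocks of a multi-stage software pipeline: each iteration issues its `cp.async` copies, commits them as a group with `cp_async_fence()`, and `cp_async_wait<n>()` blocks until at most `n` groups remain in flight. A minimal sketch of that pattern, assuming `STAGES >= 2` shared-memory buffers and one 16-byte chunk per thread per tile (`pipeline_sketch` and its parameter names are illustrative, not part of this diff):

// STAGES-deep global->shared prefetch pipeline built from the helpers above.
// Assumes tile_int4s == blockDim.x, i.e. each thread copies one int4 chunk.
template <int STAGES>
__device__ void pipeline_sketch(int4* smem, const int4* glob, int tiles,
                                int tile_int4s) {
  for (int t = 0; t < tiles; t++) {
    // Issue this thread's copy for tile t into buffer t % STAGES.
    cp_async4(smem + (t % STAGES) * tile_int4s + threadIdx.x,
              glob + t * tile_int4s + threadIdx.x);
    cp_async_fence();  // commit tile t's copies as one group
    if (t >= STAGES - 1) {
      // Keep at most STAGES-2 groups in flight, so the oldest buffered tile
      // has certainly arrived before it is consumed.
      cp_async_wait<STAGES - 2>();
      __syncthreads();
      // ... compute on tile t - (STAGES - 1) from shared memory ...
    }
  }
}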
    	
marlin/dense/marlin_cuda_kernel.cu ADDED
@@ -0,0 +1,1068 @@
/*
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include "common/base.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  #include "common/mem.h"
#endif

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin_dense {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;  // quantization scales
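// Per-thread fragment sizes match the m16n8k16 MMA shape: across the 32
// threads of a warp, FragA covers the 16x16 fp16 A tile (8 halves per
// thread), FragB a 16x8 fp16 B tile (4 halves per thread), and FragC the
// 16x8 fp32 accumulator tile (4 floats per thread).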

// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;
  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
                      *reinterpret_cast<const half2*>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
                      *reinterpret_cast<const half2*>(&MUL),
                      *reinterpret_cast<const half2*>(&ADD));
  return frag_b;
}
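// How the bit tricks above work: 0x6400 is fp16 1024.0 = 2^10, whose 10
// mantissa bits are all zero, so OR-ing a 4-bit nibble q into the low
// mantissa bits yields exactly the fp16 value 1024 + q. Subtracting
// SUB = 0x6408 (fp16 1032.0 = 1024 + 8) then gives q - 8 for the low nibble.
// The high nibble sits 4 bits further up, producing 1024 + 16 * q;
// MUL = 0x2c00 (fp16 1/16) and ADD = 0xd480 (fp16 -72.0 = -(1024/16 + 8))
// rescale it to q - 8 in a single fused multiply-add.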

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ s,  // fp16 quantization scales of shape
                                 // (k/groupsize)xn
    int prob_m,                  // batch dimension m
    int prob_n,                  // output dimension n
    int prob_k,                  // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  // example:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // While this kind of partitioning makes things somewhat more complicated, it
  // ensures good utilization of all SMs for many kinds of shape and GPU
  // configurations, while requiring as few slow global cross-threadblock
  // reductions as possible.
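  // Reading the example: rows are the three k-tiles of a column slice,
  // columns are the three n-slices, and each entry is the id of the SM
  // (threadblock) that owns the tile. With 9 tiles over 5 SMs, each SM gets
  // at most ceildiv(9, 5) = 2 consecutive tiles in column-major order, so a
  // stripe like SM 1's wraps from the bottom of one column slice into the
  // top of the next.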

  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  // better partitioning with fewer reductions.
  int parallel = 1;
  if (prob_m > 16 * thread_m_blocks) {
    parallel = prob_m / (16 * thread_m_blocks);
    prob_m = 16 * thread_m_blocks;
  }

  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  // Ensure that the number of tiles in each stripe is a multiple of the
  // groupsize; this avoids an annoying special case where a stripe starts in
  // the middle of a group.
  if (group_blocks != -1)
    iters = (group_blocks / thread_k_blocks) *
            ceildiv(iters, (group_blocks / thread_k_blocks));

  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  int slice_iters;  // number of threadblock tiles in the current slice
  int slice_count =
      0;          // total number of active threadblocks in the current slice
  int slice_idx;  // index of threadblock in current slice; numbered bottom to
                  // top

  // We can easily implement parallel problem execution by just remapping
  // indices and advancing global pointers.
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for
  // synchronization.
  auto init_slice = [&]() {
    slice_iters =
        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = ceildiv(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 8;
      C += 16 * thread_m_blocks * prob_n / 8;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();

  int a_gl_stride = prob_k / 8;  // stride of the A matrix in global memory
  // We typically use `constexpr` to indicate that this value is a compile-time
  // constant.
  constexpr int a_sh_stride =
      16 * thread_k_blocks / 8;  // stride of an A matrix tile in shared memory
  constexpr int a_gl_rd_delta_o =
      16 * thread_k_blocks /
      8;  // delta between subsequent A tiles in global memory
  int a_gl_rd_delta_i =
      a_gl_stride *
      (threads / a_gl_rd_delta_o);  // between subsequent accesses within a tile
  constexpr int a_sh_wr_delta =
      a_sh_stride *
      (threads / a_gl_rd_delta_o);  // between shared memory writes
  constexpr int a_sh_rd_delta_o =
      2 * ((threads / 32) /
           (thread_n_blocks / 4));  // between shared memory tile reads
  constexpr int a_sh_rd_delta_i =
      a_sh_stride * 16;  // within a shared memory tile
  constexpr int a_sh_stage =
      a_sh_stride * (16 * thread_m_blocks);  // overall size of a tile
  constexpr int a_sh_wr_iters =
      ceildiv(a_sh_stage,
              a_sh_wr_delta);  // number of shared write iterations for a tile

  int b_gl_stride = 16 * prob_n / 32;
  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  constexpr int b_sh_wr_delta = threads;
  constexpr int b_sh_rd_delta = threads;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  int s_gl_stride = prob_n / 8;
  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  constexpr int s_sh_stage = s_sh_stride;
  int s_gl_rd_delta = s_gl_stride;
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
              // Global A read index of current thread.
         
     | 
| 264 | 
         
            +
              int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
         
     | 
| 265 | 
         
            +
                            (threadIdx.x % a_gl_rd_delta_o);
         
     | 
| 266 | 
         
            +
              a_gl_rd += a_gl_rd_delta_o * slice_row;
         
     | 
| 267 | 
         
            +
              // Shared write index of current thread.
         
     | 
| 268 | 
         
            +
              int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
         
     | 
| 269 | 
         
            +
                            (threadIdx.x % a_gl_rd_delta_o);
         
     | 
| 270 | 
         
            +
              // Shared read index.
         
     | 
| 271 | 
         
            +
              int a_sh_rd =
         
     | 
| 272 | 
         
            +
                  a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
         
     | 
| 273 | 
         
            +
              a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
         
     | 
| 274 | 
         
            +
             
     | 
| 275 | 
         
            +
              int b_gl_rd =
         
     | 
| 276 | 
         
            +
                  b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
         
     | 
| 277 | 
         
            +
              b_gl_rd += b_sh_stride * slice_col;
         
     | 
| 278 | 
         
            +
              b_gl_rd += b_gl_rd_delta_o * slice_row;
         
     | 
| 279 | 
         
            +
              int b_sh_wr = threadIdx.x;
         
     | 
| 280 | 
         
            +
              int b_sh_rd = threadIdx.x;
         
     | 
| 281 | 
         
            +
             
     | 
| 282 | 
         
            +
              int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
         
     | 
| 283 | 
         
            +
                            s_sh_stride * slice_col + threadIdx.x;
         
     | 
| 284 | 
         
            +
              int s_sh_wr = threadIdx.x;
         
     | 
| 285 | 
         
            +
              int s_sh_rd;
         
     | 
| 286 | 
         
            +
              // We use a different scale layout for grouped and column-wise quantization as
         
     | 
| 287 | 
         
            +
              // we scale a `half2` tile in column-major layout in the former and in
         
     | 
| 288 | 
         
            +
              // row-major in the latter case.
         
     | 
| 289 | 
         
            +
              if (group_blocks != -1)
         
     | 
| 290 | 
         
            +
                s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
         
     | 
| 291 | 
         
            +
                          (threadIdx.x % 32) / 4;
         
     | 
| 292 | 
         
            +
              else
         
     | 
| 293 | 
         
            +
                s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
         
     | 
| 294 | 
         
            +
                          (threadIdx.x % 32) % 4;
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
              // Precompute which thread should not read memory in which iterations; this is
         
     | 
| 297 | 
         
            +
              // needed if there are more threads than required for a certain tilesize or
         
     | 
| 298 | 
         
            +
              // when the batchsize is not a multiple of 16.
         
     | 
| 299 | 
         
            +
              bool a_sh_wr_pred[a_sh_wr_iters];
         
     | 
| 300 | 
         
            +
              #pragma unroll
         
     | 
| 301 | 
         
            +
              for (int i = 0; i < a_sh_wr_iters; i++)
         
     | 
| 302 | 
         
            +
                a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
         
     | 
| 303 | 
         
            +
              bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
         
     | 
| 304 | 
         
            +
             
     | 
| 305 | 
         
            +
              // To ensure that writing and reading A tiles to/from shared memory, the
         
     | 
| 306 | 
         
            +
              // latter in fragment format, is fully bank conflict free, we need to use a
         
     | 
| 307 | 
         
            +
              // rather fancy XOR-based layout. The key here is that neither reads nor
         
     | 
| 308 | 
         
            +
              // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
         
     | 
| 309 | 
         
            +
              // same shared memory banks. Further, it seems (based on NSight-Compute) that
         
     | 
| 310 | 
         
            +
              // each warp must also write a consecutive memory segment?
         
     | 
| 311 | 
         
            +
              auto transform_a = [&](int i) {
         
     | 
| 312 | 
         
            +
                int row = i / a_gl_rd_delta_o;
         
     | 
| 313 | 
         
            +
                return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
         
     | 
| 314 | 
         
            +
              };
         
     | 
| 315 | 
         
            +
              // Since the computation of this remapping is non-trivial and, due to our main
         
     | 
| 316 | 
         
            +
              // loop unrolls, all shared memory accesses are static, we simply precompute
         
     | 
| 317 | 
         
            +
              // both transformed reads and writes.
         
     | 
| 318 | 
         
            +
              int a_sh_wr_trans[a_sh_wr_iters];
         
     | 
| 319 | 
         
            +
              #pragma unroll
         
     | 
| 320 | 
         
            +
              for (int i = 0; i < a_sh_wr_iters; i++)
         
     | 
| 321 | 
         
            +
                a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
         
     | 
| 322 | 
         
            +
              int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
         
     | 
| 323 | 
         
            +
              #pragma unroll
         
     | 
| 324 | 
         
            +
              for (int i = 0; i < b_sh_wr_iters; i++) {
         
     | 
| 325 | 
         
            +
              #pragma unroll
         
     | 
| 326 | 
         
            +
                for (int j = 0; j < thread_m_blocks; j++)
         
     | 
| 327 | 
         
            +
                  a_sh_rd_trans[i][j] =
         
     | 
| 328 | 
         
            +
                      transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
         
     | 
| 329 | 
         
            +
              }
         
     | 
  // Since B-accesses have non-constant stride they have to be computed at
  // runtime; we break dependencies between subsequent accesses within a tile
  // by maintaining multiple pointers (we have enough registers), a tiny
  // optimization.
  const int4* B_ptr[b_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_s = sh_b + (stages * b_sh_stage);
  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];

  // Zero accumulators.
  auto zero_accums = [&]() {
  #pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast<float*>(frag_c)[i] = 0;
  };

  // Asynchronously fetch the next A, B and s tile from global to the next
  // shared memory pipeline location.
  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
            &sh_a_stage[a_sh_wr_trans[i]],
            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
            a_sh_wr_pred[i]);
      }
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
        B_ptr[i] += b_gl_rd_delta_o;
      }
      // Only fetch scales if this tile starts a new group
      if constexpr (group_blocks != -1) {
        // This assumes group_blocks >= thread_k_blocks
        // and would need to be modified to support smaller groups.
        static_assert(group_blocks >= thread_k_blocks);
        if (pipe % (group_blocks / thread_k_blocks) == 0) {
          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
          s_gl_rd += s_gl_rd_delta;
        }
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that
    // waiting is also correct at this point.
    cp_async_fence();
  };

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the
    // previous shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<stages - 2>();
    __syncthreads();
  };
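
Why `stages - 2` is the right wait depth is easier to see in a host-side model of the ring buffer. The sketch below is illustrative only (a `std::deque` stands in for the hardware's outstanding cp.async groups): after `start_pipes()` issues `stages - 1` fetches, waiting down to `stages - 2` in flight completes exactly the oldest slot, which is the one about to be read.

#include <cstdio>
#include <deque>

int main() {
  const int STAGES = 4;       // matches the kernel's `stages`
  std::deque<int> in_flight;  // fetches issued but not yet waited on
  auto fetch = [&](int pipe) { in_flight.push_back(pipe); };
  auto wait = [&](int n) {    // mirrors cp_async_wait<n>()
    while ((int)in_flight.size() > n) {
      printf("slot %d is now safe to read\n", in_flight.front());
      in_flight.pop_front();
    }
  };

  for (int i = 0; i < STAGES - 1; i++) fetch(i);  // like start_pipes()
  wait(STAGES - 2);  // wait_for_stage(): only the oldest fetch must finish

  for (int iter = 0, pipe = STAGES - 1; iter < 4; iter++) {
    fetch(pipe);            // refill the slot we are about to free
    pipe = (pipe + 1) % STAGES;
    wait(STAGES - 2);       // again: just the next-needed slot becomes ready
  }
}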
         
  // Load the next sub-tile from the current location in the shared memory pipe
  // into the current register buffer.
  auto fetch_to_registers = [&](int k, int pipe) {
    // It may seem inefficient that we reload the groups for every sub-tile;
    // however, this does not seem to be a significant bottleneck, while some
    // theoretically better attempts have led to bad instruction ordering by
    // the compiler and correspondingly a noticeable drop in performance.
    if constexpr (group_blocks != -1) {
      // This assumes group_blocks >= thread_k_blocks
      // and would need to be modified to support smaller groups.
      static_assert(group_blocks >= thread_k_blocks);
      int4* sh_s_stage =
          sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                               (pipe / (group_blocks / thread_k_blocks)));
      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
    }
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
  // We have the m dimension as the inner loop in order to encourage
  // overlapping dequantization and matmul operations.
  #pragma unroll
    for (int j = 0; j < 4; j++) {
      int b_quant = frag_b_quant[k % 2][j];
      int b_quant_shift = b_quant >> 8;
      FragB frag_b0 = dequant(b_quant);
      // If there are no groups, we can just scale the final output once and
      // can avoid doing so for each weight.
      if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0);
      FragB frag_b1 = dequant(b_quant_shift);
      if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1);
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };
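
For reference, each 32-bit `b_quant` word packs eight 4-bit weights, and one word yields two `FragB` halves: `dequant(b_quant)` and `dequant(b_quant >> 8)` each consume four of them. The real `dequant` uses lop3-based bit tricks and relies on Marlin's offline weight repacking to decide which nibbles land where, so the scalar sketch below only illustrates the generic pack/unpack arithmetic under that 8-per-word assumption:

#include <cassert>
#include <cstdint>

uint32_t pack8(const int vals[8]) {
  uint32_t q = 0;
  for (int i = 0; i < 8; i++) q |= (uint32_t)(vals[i] & 0xf) << (4 * i);
  return q;
}

int unpack1(uint32_t q, int i) { return (q >> (4 * i)) & 0xf; }

int main() {
  int vals[8] = {1, 7, 0, 15, 8, 3, 12, 5};
  uint32_t q = pack8(vals);
  for (int i = 0; i < 8; i++) assert(unpack1(q, i) == vals[i]);
  // Analogous to the second dequant call on `q >> 8`: shifting by 8 bits
  // moves every nibble index down by two positions.
  for (int i = 0; i < 6; i++) assert(unpack1(q >> 8, i) == vals[i + 2]);
  return 0;
}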
         
  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we
  // have multiple warps that accumulate their partial sums of the same output
  // location, which we have to reduce over in the end. We do this in shared
  // memory.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride;
      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
                      (threadIdx.x % b_sh_stride);

      // Parallel logarithmic shared memory reduction. We make sure to avoid
      // any unnecessary read or write iterations, e.g., for two warps we
      // write only once by warp 1 and read only once by warp 0.

  #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
  #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd =
                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
  #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd =
                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };
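
The reduction schedule can be checked with a scalar model. In the sketch below (a hypothetical `red_off = 2`, so four warp groups each holding one partial sum), note how the `frag_c += c_rd + c_wr` step folds in both the group's own incoming value and whatever already sits in the destination slot, so group 0 only needs a single final read:

#include <cassert>
#include <vector>

int main() {
  const int red_off = 2;                // = (threads / b_sh_stride) / 2
  std::vector<float> v = {1, 2, 3, 4};  // one partial sum per warp group
  std::vector<float> sh(v.size(), 0);   // stand-in for the shared buffer
  for (int i = red_off; i > 0; i /= 2) {
    for (int g = 0; g < (int)v.size(); g++) {
      if (i <= g && g < 2 * i) {
        // Mirrors `frag_c += c_rd + c_wr` followed by `sh[wr] = frag_c`:
        // fold in this group's incoming value and the destination slot,
        // so no extra round is needed.
        if (i < red_off) v[g] += sh[g] + sh[g - i];
        sh[g - i] = v[g];
      }
    }
    // a __syncthreads() separates the rounds on the GPU
  }
  v[0] += sh[0];  // the `red_idx == 0` epilogue
  assert(v[0] == 1 + 2 + 3 + 4);
}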
         
  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;

      int row = (threadIdx.x % 32) / 4;

      if (!first) {
        // Interestingly, doing direct global accesses here really seems to
        // mess up the compiler and lead to slowdowns, hence we also use
        // async-copies even though these fetches are not actually
        // asynchronous.
  #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(
              &sh[c_sh_wr + c_sh_wr_delta * i],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                 c_gl_wr_delta_i * (i % 2)],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

  #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  __half2float(reinterpret_cast<__half*>(&c_red)[j]);
            }
          }
          if (!last) {
            int4 c;
  #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<__half*>(&c)[j] =
                  __float2half(reinterpret_cast<float*>(
                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
                c;
          }
        }
      }
    }
  };
         
  // Write out the reduced final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta =
        c_sh_stride * (threads / (2 * thread_n_blocks));

    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr =
        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));

    int c_gl_wr_end = c_gl_stride * prob_m;

    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      half2 res = __halves2half2(__float2half(c0), __float2half(c1));
      if (group_blocks ==
          -1)  // for per-column quantization we finally apply the scale here
        res = __hmul2(res, s[0]);
      ((half2*)sh)[idx] = res;
    };
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
  #pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();

  #pragma unroll
    for (int i = 0;
         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };
         
  // Start global fetch and register load pipelines.
  auto start_pipes = [&]() {
  #pragma unroll
    for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
    zero_accums();
    wait_for_stage();
    fetch_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  };
  start_pipes();

  // Main loop.
  while (slice_iters) {
    // We unroll over both the global fetch and the register load pipeline to
    // ensure all shared memory accesses are static. Note that both pipelines
    // have even length meaning that the next iteration will always start at
    // index 0.
  #pragma unroll
    for (int pipe = 0; pipe < stages;) {
  #pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe,
                          slice_iters >= stages);
          pipe++;
          wait_for_stage();
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) break;
    }
    a_gl_rd += a_gl_rd_delta_o * stages;

    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worsen performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step
      // before write-out.
      if (group_blocks == -1 && last) {
        if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
        cp_async_fence();
      }
      thread_block_reduce();
      if (group_blocks == -1 && last) {
        cp_async_wait<0>();
        __syncthreads();
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
        }
      }
      if (slice_count > 1) {  // only globally reduce if there is more than one
                              // block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last)  // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
  #pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
  #pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }
        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
        start_pipes();
      }
    }
  }
}
         
#else

template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ s,  // fp16 quantization scales of shape
                                 // (k/groupsize)xn
    int prob_m,                  // batch dimension m
    int prob_n,                  // output dimension n
    int prob_k,                  // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif
         
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
    256;               // Note: This is only used with user-provided thread_k/n
const int STAGES = 4;  // 4 pipeline stages fit into shared memory
const int SHARED_MEM =
    96 * 1024;  // fits the shared-memory limit of compute capability 8.6
                // (which is lower than on 8.0)

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

static constexpr int pack_factor_4bit =
    8;  // We have 8 4-bit vals inside a 32 bit

#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
                  GROUP_BLOCKS, NUM_THREADS)                                   \
  else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
           thread_n_blocks == THREAD_N_BLOCKS &&                               \
           thread_k_blocks == THREAD_K_BLOCKS &&                               \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
    cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                                THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,        \
                         cudaFuncAttributeMaxDynamicSharedMemorySize,          \
                         SHARED_MEM);                                          \
    Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
           STAGES, GROUP_BLOCKS><<<blocks, NUM_THREADS, SHARED_MEM, stream>>>( \
        A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks);            \
  }
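
`__CALL_IF` contributes one `else if` branch per compiled instantiation, raising the kernel's dynamic shared-memory limit via `cudaFuncSetAttribute` before launching. The self-contained host C++ toy below (with hypothetical tile values and a printf in place of a launch) shows the dispatch pattern such a chain builds, including the `if (false) {}` anchor that the first expansion's `else` attaches to:

#include <cstdio>

template <int TM, int TN>
void kernel(int m, int n) { printf("ran <%d,%d> for %dx%d\n", TM, TN, m, n); }

#define TOY_CALL_IF(TM, TN)        \
  else if (tm == TM && tn == TN) { \
    kernel<TM, TN>(m, n);          \
  }

void dispatch(int m, int n, int tm, int tn) {
  if (false) {
  }  // anchor so the first macro expansion's `else` has something to bind to
  TOY_CALL_IF(1, 8)
  TOY_CALL_IF(2, 8)
  TOY_CALL_IF(4, 16)
  else {
    printf("no compiled config for tm=%d tn=%d\n", tm, tn);
  }
}

int main() { dispatch(16, 4096, 1, 8); }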
         
typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
};

bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                     int prob_k) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can only be 128 or 64 (because it must be no larger than the
  // groupsize, which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  return true;
}
         
     | 
| 839 | 
         
            +
             
     | 
| 840 | 
         
            +
            thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
         
     | 
| 841 | 
         
            +
              if (prob_m <= 16) {
         
     | 
| 842 | 
         
            +
                for (auto th_config : small_batch_thread_configs) {
         
     | 
| 843 | 
         
            +
                  if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
         
     | 
| 844 | 
         
            +
                    return th_config;
         
     | 
| 845 | 
         
            +
                  }
         
     | 
| 846 | 
         
            +
                }
         
     | 
| 847 | 
         
            +
             
     | 
| 848 | 
         
            +
              } else {
         
     | 
| 849 | 
         
            +
                for (auto th_config : large_batch_thread_configs) {
         
     | 
| 850 | 
         
            +
                  if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
         
     | 
| 851 | 
         
            +
                    return th_config;
         
     | 
| 852 | 
         
            +
                  }
         
     | 
| 853 | 
         
            +
                }
         
     | 
| 854 | 
         
            +
              }
         
     | 
| 855 | 
         
            +
             
     | 
| 856 | 
         
            +
              return thread_config_t{-1, -1, -1};
         
     | 
| 857 | 
         
            +
            }
         
     | 
| 858 | 
         
            +
             
     | 
| 859 | 
         
            +
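// Example of the selection above: for a small batch (prob_m <= 16) with
// prob_n = 4096 and prob_k = 4096, the first small-batch entry {128, 128, 256}
// already passes every check in is_valid_config (4096 % 128 == 0 in both K
// and N, and thread_k == 128 is allowed), so, assuming the min_thread_k /
// min_thread_n bounds defined earlier in this file are met, the default
// config is chosen and the fallback entries are never consulted.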
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)

void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m,
                 int prob_n, int prob_k, void* workspace, int groupsize = -1,
                 int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                 int thread_n = -1, int sms = -1, int max_par = 16) {
  int tot_m = prob_m;
  int tot_m_blocks = ceildiv(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1)
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);

  // Set thread config
  thread_config_t th_config;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  } else {
    // Auto config
    th_config = determine_thread_config(prob_m, prob_n, prob_k);
  }

  if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
    throw std::runtime_error(
        "Invalid thread config: thread_k = " + str(th_config.thread_k) +
        ", thread_n = " + str(th_config.thread_n) +
        ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
        str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
  }

  // Uncomment for debug
  // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) +
  //                  ", thread_n = " + str(th_config.thread_n) +
  //                  ", num_threads = " + str(th_config.num_threads) + " for
  //                  MKN = [" + str(prob_m) +
  //                  ", " + str(prob_k) + ", " + str(prob_n) + "]\n";

  int num_threads = th_config.num_threads;
  thread_k = th_config.thread_k;
  thread_n = th_config.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;

  if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
    return;
  }

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  if (group_blocks != -1) {
    TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                " is not divisible by group_blocks = ", group_blocks);
  }

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;

  int* locks = (int*)workspace;

  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / 64;
      if (par > max_par) par = max_par;
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }
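    // Example of the partitioning above: with tot_m = 256 rows, tot_m_blocks
    // = 16 and pad = 0; on the first iteration thread_m_blocks = 16 > 4, so
    // par = 256 / 64 = 4, a single launch covers prob_m = 64 * 4 = 256 rows,
    // and i skips ahead by 4 * (par - 1) = 12 blocks.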
    // For compilation speed, we only define the kernel configurations that have
    // seemed useful (in terms of performance) in our testing, however many more
    // are, in principle, possible.
    if (false) {
    }
    CALL_IF(8, 8, 256)
    CALL_IF(16, 4, 256)
    CALL_IF(8, 4, 128)
    CALL_IF(4, 8, 128)
    else {
      throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                               ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                               ", groupsize = " + str(groupsize) +
                               ", thread_m_blocks = " + str(thread_m_blocks) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

}  // namespace marlin_dense

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k) {
  // Verify M
  TORCH_CHECK(size_m == a.size(0),
              "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                  ", size_m = " + str(size_m));

  // Verify K
  TORCH_CHECK(size_k == a.size(1),
              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                  ", size_k = " + str(size_k));
  TORCH_CHECK(size_k % marlin_dense::tile_size == 0,
              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
                  str(marlin_dense::tile_size));
  TORCH_CHECK((size_k / marlin_dense::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = " +
                  str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                  ", tile_size = " + str(marlin_dense::tile_size));

  // Verify N
  TORCH_CHECK(b_scales.size(1) == size_n,
              "b_scales.size(1) = " + str(b_scales.size(1)) +
                  ", size_n = " + str(size_n));
  TORCH_CHECK(
      b_q_weight.size(1) % marlin_dense::tile_size == 0,
      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
          " is not divisible by tile_size = " + str(marlin_dense::tile_size));

  int actual_size_n = (b_q_weight.size(1) / marlin_dense::tile_size) *
                      marlin_dense::pack_factor_4bit;
  TORCH_CHECK(
      size_n == actual_size_n,
      "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

  // Verify A device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  // Verify B device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  // Verify scales device and strides
  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  // Alloc C matrix
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize
  if (b_scales.size(0) != 1) {
    TORCH_CHECK(size_k % b_scales.size(0) == 0,
                "size_k = " + str(size_k) +
                    ", is not divisible by b_scales.size(0) = " +
                    str(b_scales.size(0)));
  }
  int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0);

  // Verify groupsize
  TORCH_CHECK(groupsize == -1 || groupsize == 128,
              "Unexpected groupsize = " + str(groupsize));

  // Verify workspace size
  TORCH_CHECK(size_n % marlin_dense::min_thread_n == 0,
              "size_n = " + str(size_n) +
                  ", is not divisible by min_thread_n = " +
                  str(marlin_dense::min_thread_n));
  int min_workspace_size =
      (size_n / marlin_dense::min_thread_n) * marlin_dense::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = " + str(workspace.numel()) +
                  " is below min_workspace_size = " + str(min_workspace_size));

  int dev = a.get_device();
  marlin_dense::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(),
                            b_scales.data_ptr(), size_m, size_n, size_k,
                            workspace.data_ptr(), groupsize, dev,
                            at::cuda::getCurrentCUDAStream(dev), thread_k,
                            thread_n, sms, marlin_dense::max_par);

  return c;
}
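For orientation, a minimal shape-level sketch of driving this op from Python, mirroring the checks above. The binding name `ops.marlin_gemm` and the constants tile_size = 16, pack_factor_4bit = 8, and min_thread_n = 64 (with max_par = 16, the default in `marlin_cuda`) are assumptions taken from the upstream Marlin defaults; `b_q_weight` is allocated empty purely to illustrate shapes, since real weights must come from a Marlin-format repack (e.g. the helpers in ext-torch/utils/marlin_utils_test.py):

import torch

# `ops` stands for the compiled extension namespace exposing marlin_gemm
# (name assumed).
m, n, k, groupsize = 16, 4096, 4096, 128
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
# k/16 x (n*16/8) int32 tiles; values must be Marlin-packed in practice.
b_q_weight = torch.empty(k // 16, n * 16 // 8, dtype=torch.int32, device="cuda")
b_scales = torch.rand(k // groupsize, n, dtype=torch.float16, device="cuda")
# Lock array: at least (n / min_thread_n) * max_par int32 zeros.
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")
c = ops.marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k)  # [m, n], fp16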
         
marlin/qqq/marlin_qqq_gemm_kernel.cu ADDED
@@ -0,0 +1,1243 @@
/*
 * Adapted from
 * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu
 * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp
 * Modified by HandH1998
 * Copyright (C) 2024 HandH1998
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include "../dense/common/base.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  #include "../dense/common/mem.h"
#endif

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type
using FragA = Vec<uint32_t, 2>;
using FragB = Vec<uint32_t, 1>;
using FragC = Vec<int, 4>;
using FragS_GROUP = Vec<half2, 1>;  // weight per-group quantization scales
using FragS_CHANNEL =
    Vec<float, 2>;  // weight per-channel quantization scales or activation
                    // per-token quantization scales

// NOTE(HandH1998): cp.async.cg only supports BYTES = 16, while
// cp.async.ca supports BYTES = 4, 8, 16;
// as s_tok's shape is equal to prob_m, we need to set s_tok to float type,
// and cp_size = 1 float, i.e., 4 BYTES
// Asynchronous global->shared copy for activation quantization scales s_tok
__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 4;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.ca.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// m16n8k16 tensor core mma instruction with int8 inputs and int32
// output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
                           FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  int* c = reinterpret_cast<int*>(&frag_c);
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]),
        "r"(c[3]));
}
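// (Per the PTX matrix-fragment doc linked above, each thread of the warp
// holds two 32-bit registers of A (8 int8 values), one 32-bit register of B
// (4 int8 values), and four int32 accumulators, which matches
// FragA = Vec<uint32_t, 2>, FragB = Vec<uint32_t, 1>, and FragC = Vec<int, 4>.)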
            // Instruction for loading a full 16x16 matrix fragment of operand A from shared
         
     | 
| 90 | 
         
            +
            // memory, directly in int8 tensor core layout.
         
     | 
| 91 | 
         
            +
            __device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
         
     | 
| 92 | 
         
            +
              uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
         
     | 
| 93 | 
         
            +
              uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
         
     | 
| 94 | 
         
            +
              asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
         
     | 
| 95 | 
         
            +
                           : "=r"(a[0]), "=r"(a[1])
         
     | 
| 96 | 
         
            +
                           : "r"(smem));
         
     | 
| 97 | 
         
            +
            }
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
            inline __device__ half2 float2_to_half2(float2 f) {
         
     | 
| 100 | 
         
            +
              uint32_t res;
         
     | 
| 101 | 
         
            +
              // NOTE(HandH1998): h0,h1 should be uint16_t, not half
         
     | 
| 102 | 
         
            +
              uint16_t h0, h1;
         
     | 
| 103 | 
         
            +
              asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x));
         
     | 
| 104 | 
         
            +
              asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y));
         
     | 
| 105 | 
         
            +
              asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1));
         
     | 
| 106 | 
         
            +
              return reinterpret_cast<half2&>(res);
         
     | 
| 107 | 
         
            +
            }
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
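// Equivalent reference (an illustrative sketch, not part of the original
// kernel; `float2_to_half2_ref` is a hypothetical name): the PTX above matches
// the stock cuda_fp16.h intrinsic below. The hand-written asm is kept to pin
// the round-to-nearest conversion and the register packing explicitly.
inline __device__ half2 float2_to_half2_ref(float2 f) {
  return __float22half2_rn(f);  // cvt.rn.f16.f32 on both lanes, then pack
}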
         
inline __device__ float int32_to_float(int h) {
  float res;
  asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h));
  return res;
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}
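// Annotation (not in the original source): the `lut` immediate is the truth
// table of the desired boolean function evaluated on the operand patterns
// a=0xf0, b=0xcc, c=0xaa. For the `(a & b) | c` pattern used by the dequant
// helpers below:
static_assert(((0xf0 & 0xcc) | 0xaa) == 0xea,
              "lop3 LUT for (a & b) | c must be 0xea");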
         
// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
// for weight per channel dequant.
__device__ inline FragB dequant_per_channel(int q) {
  static constexpr int MASK = 0xf0f0f0f0;
  FragB frag_b;
  frag_b[0] = (q & MASK);
  return frag_b;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values
// for weight per group dequant.
__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
  static constexpr uint32_t LO = 0x000f000f;
  static constexpr uint32_t HI = 0x00f000f0;
  static constexpr uint32_t EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  static constexpr uint32_t SUB = 0x64086408;
  static constexpr uint32_t MUL = 0x2c002c00;
  static constexpr uint32_t ADD = 0xd480d480;
  *reinterpret_cast<half2*>(&t0) = __hsub2(
      *reinterpret_cast<half2*>(&t0), *reinterpret_cast<const half2*>(&SUB));
  *reinterpret_cast<half2*>(&t1) = __hfma2(
      *reinterpret_cast<half2*>(&t1), *reinterpret_cast<const half2*>(&MUL),
      *reinterpret_cast<const half2*>(&ADD));

  uint16_t s = reinterpret_cast<uint16_t*>(&frag_s)[i];
  uint32_t double_s;
  // pack 2xfp16 to half2
  asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s));
  // dequant and convert the 4 halves to 4 uint8s (placed at the low 8 bits of
  // the 4 halves, respectively)
  static constexpr uint32_t MAGIC_NUM = 0x64806480;
  *reinterpret_cast<half2*>(&t0) = __hfma2(
      *reinterpret_cast<half2*>(&t0), *reinterpret_cast<half2*>(&double_s),
      *reinterpret_cast<const half2*>(&MAGIC_NUM));
  *reinterpret_cast<half2*>(&t1) = __hfma2(
      *reinterpret_cast<half2*>(&t1), *reinterpret_cast<half2*>(&double_s),
      *reinterpret_cast<const half2*>(&MAGIC_NUM));
  // take the 4 uint8s out of the 4 halves, then convert them to 4 int8s and
  // pack them into 1 uint32
  FragB frag_b;
  uint32_t uint8s;
  static constexpr uint32_t MASK_0246 = 0x6420;
  static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
  asm volatile("prmt.b32 %0,%1,%2,%3;\n"
               : "=r"(uint8s)
               : "r"(t0), "r"(t1), "n"(MASK_0246));
  frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK);
  return frag_b;
}
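// Annotation (a worked sketch of the constants above, not from the original
// source): 0x6400 is half(1024.0), and at that exponent the fp16 mantissa
// step is exactly 1, so OR-ing a 4-bit nibble q into the low mantissa bits
// yields half(1024 + q) with no rounding. Subtracting SUB = 0x6408 =
// half(1032.0) therefore recovers q - 8, fusing the symmetric zero point.
// Likewise, after the group scale is applied, adding MAGIC_NUM = 0x6480 =
// half(1152.0) re-biases the value so its low mantissa byte holds the result
// as a bias-128 uint8, which prmt gathers and the final XOR with 0x80808080
// converts to int8.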
         
template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // int8 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // int32 global_reduce buffer of shape
                                 // (max_par*16*4)xn, as the int8 tensor core's
                                 // output is of int32 dtype
    int4* __restrict__ D,              // fp16 output buffer of shape mxn
    const float* __restrict__ s_tok,   // fp32 activation per-token quantization
                                       // scales of shape mx1
    const int4* __restrict__ s_ch,     // fp32 weight per-channel quantization
                                       // scales of shape 1xn
    const int4* __restrict__ s_group,  // fp16 weight per-group quantization
                                       // scales of shape (k/groupsize)xn; when
                                       // group_blocks=-1, it should be nullptr
    int prob_m,                        // batch dimension m
    int prob_n,                        // output dimension n
    int prob_k,                        // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  // example:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // While this kind of partitioning makes things somewhat more complicated, it
  // ensures good utilization of all SMs for many kinds of shape and GPU
  // configurations, while requiring as few slow global cross-threadblock
  // reductions as possible.

  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  // better partitioning with fewer reductions
  int parallel = 1;
  if (prob_m > 16 * thread_m_blocks) {
    parallel = prob_m / (16 * thread_m_blocks);
    prob_m = 16 * thread_m_blocks;
  }
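  // Worked example (annotation, not from the original source): with
  // thread_m_blocks = 4 the kernel natively covers 16 * 4 = 64 rows. For, say,
  // prob_m = 256 the clamp above yields parallel = 256 / 64 = 4 independent
  // batch-64 sub-problems, and prob_m is reset to 64 for the rest of the
  // kernel.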
         
  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
  // Ensure that the number of tiles in each stripe is a multiple of the
  // groupsize; this avoids an annoying special case where a stripe starts in
  // the middle of a group.
  if constexpr (group_blocks != -1)
    iters = (group_blocks / thread_k_blocks) *
            ceildiv(iters, (group_blocks / thread_k_blocks));

  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  int slice_iters;  // number of threadblock tiles in the current slice
  int slice_count =
      0;          // total number of active threadblocks in the current slice
  int slice_idx;  // index of threadblock in current slice; numbered bottom to
                  // top

  // We can easily implement parallel problem execution by just remapping
  // indices and advancing global pointers
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4;
    D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for
  // synchronization.
  auto init_slice = [&]() {
    slice_iters =
        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = ceildiv(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 16;
      C += 16 * thread_m_blocks * prob_n / 4;
      D += 16 * thread_m_blocks * prob_n / 8;
      s_tok += 16 * thread_m_blocks;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();
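  // Worked example (annotation, not in the original file): with k_tiles = 3,
  // n_tiles = 3, parallel = 1 and gridDim.x = 5 we get
  //   iters = ceildiv(3 * 3, 5) = 2,
  // so block 0 starts at (row 0, col 0), block 1 at (row 2, col 0), block 2
  // at (row 1, col 1), block 3 at (row 0, col 2), block 4 at (row 2, col 2),
  // each covering (up to) two consecutive k-tiles, which reproduces the
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // stripe picture in the comment at the top of the kernel.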
         
  int a_gl_stride = prob_k / 16;  // stride of the A matrix in global memory
  // We typically use `constexpr` to indicate that this value is a compile-time
  // constant
  constexpr int a_sh_stride =
      16 * thread_k_blocks / 16;  // stride of an A matrix tile in shared memory
  constexpr int a_gl_rd_delta_o =
      16 * thread_k_blocks /
      16;  // delta between subsequent A tiles in global memory
  int a_gl_rd_delta_i =
      a_gl_stride *
      (threads / a_gl_rd_delta_o);  // between subsequent accesses within a tile
  constexpr int a_sh_wr_delta =
      a_sh_stride *
      (threads / a_gl_rd_delta_o);  // between shared memory writes
  constexpr int a_sh_rd_delta_o =
      1 * ((threads / 32) /
           (thread_n_blocks / 4));  // between shared memory tile reads
  constexpr int a_sh_rd_delta_i =
      a_sh_stride * 16;  // within a shared memory tile
  constexpr int a_sh_stage =
      a_sh_stride * (16 * thread_m_blocks);  // overall size of a tile
  constexpr int a_sh_wr_iters =
      ceildiv(a_sh_stage,
              a_sh_wr_delta);  // number of shared write iterations for a tile

  int b_gl_stride = 16 * prob_n / 32;
  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
  constexpr int b_sh_wr_delta = threads;
  constexpr int b_sh_rd_delta = threads;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  constexpr int s_tok_sh_stride = 16 * thread_m_blocks;

  constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4;

  int s_group_gl_stride = prob_n / 8;
  constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8;
  constexpr int s_group_sh_stage = s_group_sh_stride;
  int s_group_gl_rd_delta = s_group_gl_stride;

  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  a_gl_rd += a_gl_rd_delta_o * slice_row;
  // Shared write index of current thread.
  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  // Shared read index.
  // NOTE(HandH1998): the int8 input a only needs 16 threads to load a 16x16
  // matrix
  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16);
  a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

  int b_gl_rd =
      b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
  b_gl_rd += b_sh_stride * slice_col;
  b_gl_rd += b_gl_rd_delta_o * slice_row;
  int b_sh_wr = threadIdx.x;
  int b_sh_rd = threadIdx.x;

  int s_tok_gl_rd = threadIdx.x;
  // NOTE(HandH1998): the activation scales s_tok need to be shuffled to
  // [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]; for example, the
  // row scales 0, 8 serve threads 0, 1, 2, 3 (for details, refer to the mma
  // operand A layout). As s_tok's size is not fixed, we cannot shuffle it
  // before inference; instead we shuffle it when fetching s_tok from global
  // memory to shared memory, which is why s_tok_sh_wr is computed like this.
  int s_tok_sh_wr =
      (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8;
  int s_tok_sh_rd = (threadIdx.x % 32) / 4;
  bool s_tok_sh_wr_pred = threadIdx.x < prob_m;

  int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
  int s_ch_sh_wr = threadIdx.x;
  int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                   2 * ((threadIdx.x % 32) % 4);
  bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;

  int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd;
  bool s_group_sh_wr_pred;
  if constexpr (group_blocks != -1) {
    s_group_gl_rd =
        s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
        s_group_sh_stride * slice_col + threadIdx.x;
    s_group_sh_wr = threadIdx.x;
    // NOTE(HandH1998): s_group_sh_rd is related to mma output C
    s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                    (threadIdx.x % 32) / 4;
    s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride;
  }
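  // Worked example (annotation, not from the original source): the s_tok
  // write index above interleaves the two 8-row halves of each 16-row group;
  // for threadIdx.x = 0..15 it writes to shared offsets
  //   0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
  // so shared memory ends up holding the scales in the order
  //   [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15]
  // expected by the mma operand-A thread mapping.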
         
  // Precompute which thread should not read memory in which iterations; this is
  // needed if there are more threads than required for a certain tilesize or
  // when the batchsize is not a multiple of 16.
  bool a_sh_wr_pred[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;

  // To ensure that writing and reading A tiles to/from shared memory, the
  // latter in fragment format, is fully bank conflict free, we need to use a
  // rather fancy XOR-based layout. The key here is that neither reads nor
  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  // same shared memory banks. Further, it seems (based on NSight-Compute) that
  // each warp must also write a consecutive memory segment?
  auto transform_a = [&](int i) {
    int row = i / a_gl_rd_delta_o;
    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  };
  // Since the computation of this remapping is non-trivial and, due to our main
  // loop unrolls, all shared memory accesses are static, we simply precompute
  // both transformed reads and writes.
  int a_sh_wr_trans[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++) {
  #pragma unroll
    for (int j = 0; j < thread_m_blocks; j++)
      a_sh_rd_trans[i][j] =
          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  }
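  // Annotation (not in the original file): C++ precedence makes `+` bind
  // tighter than `^`, so `transform_a` returns
  //   (a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o)) ^ row
  // i.e. the whole linear index is XOR-ed with the row. For row indices below
  // a_gl_rd_delta_o this reduces to swizzling the column as `col ^ row`, e.g.
  // row 1 swaps columns 0<->1, 2<->3, ..., so consecutive rows touch rotated
  // sets of shared memory banks and the int4 accesses of 8 consecutive
  // threads stay conflict free.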
         
  // Since B-accesses have non-constant stride they have to be computed at
  // runtime; we break dependencies between subsequent accesses within a tile
  // by maintaining multiple pointers (we have enough registers), a tiny
  // optimization.
  const int4* B_ptr[b_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  // NOTE(HandH1998): stages must be >= 4; otherwise, sh_s_tok = sh + max(stages *
  // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage)
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_s_tok = sh_b + (stages * b_sh_stage);
  int4* sh_s_ch = sh_s_tok + s_tok_sh_stride;
  int4* sh_s_group = sh_s_ch + s_ch_sh_stride;

  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2];
  FragC frag_c[thread_m_blocks][4][2];
  FragS_GROUP frag_s_group[2][4];
  FragS_CHANNEL frag_s_tok[thread_m_blocks];
  FragS_CHANNEL frag_s_ch[2][4];

  // Zero accumulators.
  auto zero_accums = [&]() {
  #pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast<int*>(frag_c)[i] = 0;
  };

  // Asynchronously fetch the next A, B and s tile from global to the next
  // shared memory pipeline location.
  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
            &sh_a_stage[a_sh_wr_trans[i]],
            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
            a_sh_wr_pred[i]);
      }
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
        B_ptr[i] += b_gl_rd_delta_o;
      }
      // Only fetch scales if this tile starts a new group
      if constexpr (group_blocks != -1) {
        if (pipe % (group_blocks / thread_k_blocks) == 0) {
          int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe;
          if (s_group_sh_wr_pred)
            cp_async4(&sh_s_group_stage[s_group_sh_wr],
                      &s_group[s_group_gl_rd]);
          s_group_gl_rd += s_group_gl_rd_delta;
        }
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that
    // waiting is also correct at this point.
    cp_async_fence();
  };

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<stages - 2>();
    __syncthreads();
  };
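  // Annotation (not from the original source): `cp_async_wait<stages - 2>`
  // maps to PTX `cp.async.wait_group stages-2`, which blocks until at most
  // `stages - 2` commit groups are still in flight. Since `fetch_to_shared`
  // issues exactly one fence (commit group) per call, the fetch issued
  // `stages - 1` iterations ago is guaranteed complete, so the pipe slot about
  // to be consumed is ready while newer fetches keep overlapping with compute.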
         
  // Load the next sub-tile from the current location in the shared memory pipe
  // into the current register buffer.
  auto fetch_to_registers = [&](int k, int pipe) {
    // It may seem inefficient that we reload the groups for every sub-tile;
    // however, this does not seem to be a significant bottleneck, while some
    // theoretically better attempts have led to bad instruction ordering by
    // the compiler and correspondingly a noticeable drop in performance.
    if constexpr (group_blocks != -1) {
      int4* sh_s_group_stage =
          sh_s_group +
          s_group_sh_stage * ((group_blocks / thread_k_blocks) *
                              (pipe / (group_blocks / thread_k_blocks)));
      reinterpret_cast<int4*>(&frag_s_group[k % 2])[0] =
          sh_s_group_stage[s_group_sh_rd];
    }
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
    frag_b_quant[k % 2] = *reinterpret_cast<I4*>(
        &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
    // We have the m dimension as the inner loop in order to encourage
    // overlapping dequantization and matmul operations.
  #pragma unroll
    for (int j = 0; j < 4; j++) {
      int b_quant = frag_b_quant[k % 2][j];
      // int b_quant_shift = b_quant << 4;
      FragB frag_b0, frag_b1;
      // If there are no groups, we can just scale the final output once and can
      // avoid doing so for each weight.
      if constexpr (group_blocks != -1) {
        int b_quant_shift = b_quant >> 8;
        frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0);
        frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1);
      } else {
        int b_quant_shift = b_quant << 4;
        frag_b0 = dequant_per_channel(b_quant);
        frag_b1 = dequant_per_channel(b_quant_shift);
      }
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };
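  // Annotation (an interpretation sketch, not in the original file): each
  // 32-bit `b_quant` packs eight 4-bit weights. In the grouped path,
  // dequant_per_group extracts nibbles 0/1 of every 16-bit half via its LO/HI
  // masks, and `b_quant >> 8` exposes nibbles 2/3 for the second fragment. In
  // the per-channel path the 0xf0f0f0f0 mask of dequant_per_channel keeps the
  // high nibble of every byte (the factor 16 is folded into the channel
  // scale), and `b_quant << 4` moves the low nibbles into that position for
  // frag_b1.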
         
  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we have
  // multiple warps that accumulate their partial sums of the same output
  // location, which we have to reduce over in the end. We do this in shared
  // memory.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride;
      constexpr int red_sh_stride = b_sh_stride * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
                      (threadIdx.x % b_sh_stride);

      // Parallel logarithmic shared memory reduction. We make sure to avoid any
      // unnecessary read or write iterations, e.g., for two warps we write only
      // once by warp 1 and read only once by warp 0.

  #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
  #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                int* c_rd =
                    reinterpret_cast<int*>(&sh[red_sh_delta * j + red_sh_rd]);
                int* c_wr = reinterpret_cast<int*>(&sh[red_sh_wr]);
  #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
  #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            int* c_rd =
                reinterpret_cast<int*>(&sh[red_sh_delta * i + red_sh_rd]);
  #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };
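  // Worked example (annotation, not from the original source): for
  // threads = 256 and b_sh_stride = 32 we get red_off = 4 and red_idx = warp
  // id (0..7). At i = 4 warps 4..7 only spill their partials to shared
  // memory; at i = 2 warps 2..3 combine spilled partials with their own and
  // re-spill; at i = 1 warp 1 does the same, and finally warp 0
  // (red_idx == 0) folds in the last spill, halving the number of producing
  // warps in every step.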
         
  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
  // This is why we need another INT32 matrix `C` to reduce instead of the
  // original half matrix `D`.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 4;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 8 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
      c_gl_wr += (4 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads * 2;
      int c_sh_wr = 2 * threadIdx.x;

      int row = (threadIdx.x % 32) / 4;

      if (!first) {
        // Interestingly, doing direct global accesses here really seems to
        // mess up the compiler and lead to slowdowns, hence we also use
        // async-copies even though these fetches are not actually
        // asynchronous.
  #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(
              &sh[c_sh_wr + c_sh_wr_delta * i],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                 c_gl_wr_delta_i * (i % 2)],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
          cp_async4_pred(
              &sh[c_sh_wr + c_sh_wr_delta * i + 1],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                 c_gl_wr_delta_i * (i % 2) + 1],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

  #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
            int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
  #pragma unroll
            for (int j = 0; j < 4; j++) {
              reinterpret_cast<int*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  reinterpret_cast<int*>(&d_red1)[j];
            }
  #pragma unroll
            for (int j = 0; j < 4; j++) {
              reinterpret_cast<int*>(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
                  reinterpret_cast<int*>(&d_red2)[j];
| 672 | 
         
            +
                        }
         
     | 
| 673 | 
         
            +
                      }
         
     | 
| 674 | 
         
            +
                      if (!last) {
         
     | 
| 675 | 
         
            +
                        int4 d1, d2;
         
     | 
| 676 | 
         
            +
              #pragma unroll
         
     | 
| 677 | 
         
            +
                        for (int j = 0; j < 4; j++) {
         
     | 
| 678 | 
         
            +
                          reinterpret_cast<int*>(&d1)[j] = reinterpret_cast<int*>(
         
     | 
| 679 | 
         
            +
                              &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
         
     | 
| 680 | 
         
            +
                        }
         
     | 
| 681 | 
         
            +
              #pragma unroll
         
     | 
| 682 | 
         
            +
                        for (int j = 0; j < 4; j++) {
         
     | 
| 683 | 
         
            +
                          reinterpret_cast<int*>(&d2)[j] = reinterpret_cast<int*>(
         
     | 
| 684 | 
         
            +
                              &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
         
     | 
| 685 | 
         
            +
                        }
         
     | 
| 686 | 
         
            +
                        C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
         
     | 
| 687 | 
         
            +
                            d1;
         
     | 
| 688 | 
         
            +
                        C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) +
         
     | 
| 689 | 
         
            +
                          1] = d2;
         
     | 
| 690 | 
         
            +
                      }
         
     | 
| 691 | 
         
            +
                    }
         
     | 
| 692 | 
         
            +
                  }
         
     | 
| 693 | 
         
            +
                }
         
     | 
| 694 | 
         
            +
              };
         
     | 
| 695 | 
         
            +
             
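// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff.] The serialized
// L2 reduction above is coordinated through the per-column `locks` array: each
// threadblock that owns part of a column slice waits for its turn, folds its
// INT32 partials into `C`, then hands the slice to the next block. Stripped to
// its essentials, the `barrier_acquire`/`barrier_release` pair used in the
// main loop below behaves like the following file-scope helpers (names
// prefixed `sketch_` are hypothetical):

__device__ void sketch_barrier_acquire(int* lock, int turn) {
  if (threadIdx.x == 0) {
    while (atomicAdd(lock, 0) != turn) {}  // spin until it is our turn
  }
  __syncthreads();
}

__device__ void sketch_barrier_release(int* lock, bool last) {
  __syncthreads();
  if (threadIdx.x == 0) {
    __threadfence();        // publish our writes to C before passing the turn
    if (last)
      atomicExch(lock, 0);  // reset the lock for the next kernel launch
    else
      atomicAdd(lock, 1);   // wake the next block in this column slice
  }
}
// ---------------------------------------------------------------------------
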
  // Write out the reduced final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  auto write_result = [&]() {
    int d_gl_stride = prob_n / 8;
    constexpr int d_sh_stride = 2 * thread_n_blocks + 1;
    int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int d_sh_rd_delta =
        d_sh_stride * (threads / (2 * thread_n_blocks));

    int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    d_gl_wr += (2 * thread_n_blocks) * slice_col;
    int d_sh_wr =
        (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
    d_sh_wr += 32 * (threadIdx.x / 32);
    int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));

    int d_gl_wr_end = d_gl_stride * prob_m;

    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) {
      float2 deq_res;
      deq_res.x = int32_to_float(c0) * w_s[0] * a_s;
      deq_res.y = int32_to_float(c1) * w_s[1] * a_s;
      ((half2*)sh)[idx] = float2_to_half2(deq_res);
    };

    if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = d_sh_wr + 8 * j;
          write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s_tok[i][0],
                frag_s_ch[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s_tok[i][1],
                frag_s_ch[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s_tok[i][0],
                frag_s_ch[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s_tok[i][1],
                frag_s_ch[j / 2][2 * (j % 2) + 1]);
        }
        d_sh_wr += 16 * (4 * d_sh_stride);
      }
    }
    __syncthreads();

#pragma unroll
    for (int i = 0;
         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (d_gl_wr < d_gl_wr_end) {
        D[d_gl_wr] = sh[d_sh_rd];
        d_gl_wr += d_gl_wr_delta;
        d_sh_rd += d_sh_rd_delta;
      }
    }
  };

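// [Editor's note -- worked equation, not part of this diff.] The `write`
// helper above is the whole QQQ epilogue in miniature for the per-channel
// case (group_blocks == -1): an INT32 accumulator is dequantized by the
// per-channel weight scale and the per-token activation scale, then narrowed
// to FP16:
//
//   D[m][n] = half( float(acc_int32[m][n]) * w_s[n] * a_s[m] )
//
// which is why `frag_s_tok` is fetched per row fragment and `frag_s_ch` per
// column pair, and why results are staged through shared memory as half2.
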
  // Start global fetch and register load pipelines.
  auto start_pipes = [&]() {
#pragma unroll
    for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
    zero_accums();
    wait_for_stage();
    fetch_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
  };
  start_pipes();

  // Main loop.
  while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines have
// even length meaning that the next iteration will always start at index 0.
#pragma unroll
    for (int pipe = 0; pipe < stages;) {
#pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe,
                          slice_iters >= stages);
          pipe++;
          wait_for_stage();
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) break;
    }
    a_gl_rd += a_gl_rd_delta_o * stages;

    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worsen performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step before
      // write-out
      if (last) {
        if (s_tok_sh_wr_pred) {
          cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]);
        }
        if (s_ch_sh_wr_pred) {
          cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]);
        }
        cp_async_fence();
      }
      thread_block_reduce();
      if (last) {
        cp_async_wait<0>();
        __syncthreads();
        if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
          for (int i = 0; i < thread_m_blocks; i++) {
            frag_s_tok[i][0] =
                *reinterpret_cast<float*>(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]);
            frag_s_tok[i][1] = *reinterpret_cast<float*>(
                &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]);
          }
          reinterpret_cast<int4*>(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0];
          reinterpret_cast<int4*>(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1];
          reinterpret_cast<int4*>(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8];
          reinterpret_cast<int4*>(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9];
        }
      }
      if (slice_count > 1) {  // only globally reduce if there is more than one
                              // block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last)  // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
#pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }
        s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x;
        s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
        start_pipes();
      }
    }
  }
}

#else

template <const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const int group_blocks = -1  // number of consecutive 16x16 blocks
                                       // with a separate quantization scale
          >
__global__ void Marlin(
    const int4* __restrict__ A,  // int8 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // int32 global_reduce buffer of shape
                           // (max_par*16*4)xn, as int8 tensor core's output is
                           // int32 dtype
    int4* __restrict__ D,              // fp16 output buffer of shape mxn
    const float* __restrict__ s_tok,   // fp32 activation per-token quantization
                                       // scales of shape mx1
    const int4* __restrict__ s_ch,     // fp32 weight per-channel quantization
                                       // scales of shape 1xn
    const int4* __restrict__ s_group,  // fp16 weight per-group quantization
                                       // scales of shape (k/groupsize)xn, when
                                       // group_blocks=-1, it should be nullptr
    int prob_m,                        // batch dimension m
    int prob_n,                        // output dimension n
    int prob_k,                        // reduction dimension k
    int* locks  // extra global storage for barrier synchronization
) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif

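// [Editor's note -- illustrative sketch, not part of this diff.] The #else
// branch above keeps the kernel symbol linkable on targets below SM 8.0 while
// making any accidental launch trap in assert(false). A host-side capability
// check of the following shape (helper name hypothetical) is the usual
// companion guard:

inline bool sketch_supports_marlin_qqq(int dev) {
  int major = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev);
  return major >= 8;  // Ampere or newer
}
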
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
const int USER_THREADS =
    256;               // Note: This is only used with user-provided thread_k/n
const int STAGES = 4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

static constexpr int pack_factor_4bit =
    8;  // We have 8 4-bit vals inside a 32-bit word

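// [Editor's note -- illustrative sketch, not part of this diff.]
// pack_factor_4bit = 8 because one 32-bit word stores eight 4-bit values.
// Packing and unpacking look like the following (helpers are hypothetical,
// shown only to make the constant concrete):

inline unsigned int sketch_pack8x4(const unsigned char q[8]) {
  unsigned int w = 0;
  for (int i = 0; i < 8; ++i) w |= (unsigned int)(q[i] & 0xF) << (4 * i);
  return w;
}

inline int sketch_unpack4(unsigned int w, int i) {
  return (w >> (4 * i)) & 0xF;  // i-th nibble, 0 <= i < 8
}
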
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,           \
                  GROUP_BLOCKS, NUM_THREADS)                                   \
  else if (thread_m_blocks == THREAD_M_BLOCKS &&                               \
           thread_n_blocks == THREAD_N_BLOCKS &&                               \
           thread_k_blocks == THREAD_K_BLOCKS &&                               \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {       \
    cudaFuncSetAttribute(Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                                THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,        \
                         cudaFuncAttributeMaxDynamicSharedMemorySize,          \
                         max_shared_mem);                                      \
    Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS,     \
           STAGES, GROUP_BLOCKS>                                               \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                     \
            A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr,      \
            prob_m, prob_n, prob_k, locks);                                    \
  }

typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
};

bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n,
                     int prob_k) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  return true;
}

thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
  if (prob_m <= 16) {
    for (auto th_config : small_batch_thread_configs) {
      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
        return th_config;
      }
    }

  } else {
    for (auto th_config : large_batch_thread_configs) {
      if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
        return th_config;
      }
    }
  }

  return thread_config_t{-1, -1, -1};
}

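// [Editor's worked example, not part of this diff.] For a decode-shaped
// problem (m = 16, n = 4096, k = 4096) the small-batch table is scanned in
// priority order and its first entry already validates: 4096 % 128 == 0 for
// both n and k, thread_k = 128 is an allowed value, and 256 threads >= 128,
// so {128, 128, 256} is chosen. For m = 512 the large-batch table applies and
// {64, 256, 256} wins instead. Only if every entry fails does the function
// return {-1, -1, -1}, which the caller below turns into a runtime error.
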
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS)    \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)  \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS)

void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D,
                     void* s_tok, void* s_ch, void* s_group, int prob_m,
                     int prob_n, int prob_k, void* workspace,
                     int groupsize = -1, int dev = 0, cudaStream_t stream = 0,
                     int thread_k = -1, int thread_n = -1, int sms = -1,
                     int max_par = 16) {
  int tot_m = prob_m;
  int tot_m_blocks = ceildiv(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1)
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  thread_config_t th_config;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    th_config = thread_config_t{thread_k, thread_n, USER_THREADS};
  } else {
    // Auto config
    th_config = determine_thread_config(prob_m, prob_n, prob_k);
  }

  if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) {
    throw std::runtime_error(
        "Invalid thread config: thread_k = " + str(th_config.thread_k) +
        ", thread_n = " + str(th_config.thread_n) +
        ", num_threads = " + str(th_config.num_threads) + " for MKN = [" +
        str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]");
  }

  int num_threads = th_config.num_threads;
  thread_k = th_config.thread_k;
  thread_n = th_config.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;

  if (prob_m == 0 || prob_n == 0 || prob_k == 0) {
    return;
  }

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  if (group_blocks != -1) {
    TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                " is not divisible by group_blocks = ", group_blocks);
  }

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* D_ptr = (int4*)D;
  const float* s_tok_ptr = (const float*)s_tok;
  const int4* s_ch_ptr = (const int4*)s_ch;
  const int4* s_group_ptr = (const int4*)s_group;

  int* locks = (int*)workspace;

  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / 64;
      if (par > max_par) par = max_par;
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }

    // For compilation speed, we only define the kernel configurations that have
    // seemed useful (in terms of performance) in our testing, however many more
    // are, in principle, possible.
    if (false) {
    }
    CALL_IF(8, 8, 256)
    CALL_IF(16, 4, 256)
    CALL_IF(8, 4, 128)
    CALL_IF(4, 8, 128)
    else {
      throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                               ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                               ", groupsize = " + str(groupsize) +
                               ", thread_m_blocks = " + str(thread_m_blocks) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par;
    D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
    s_tok_ptr += 16 * thread_m_blocks * par;
  }
}
}  // anonymous namespace

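// [Editor's note -- shape summary inferred from the checks in the function
// below, not part of this diff.] With tile_size = 16 and pack_factor_4bit = 8,
// the entry point expects:
//
//   a          : [size_m, size_k]               int8 activations
//   b_q_weight : [size_k / 16, size_n * 2]      Marlin-packed 4-bit weights
//   s_tok      : size_m elements                float32 per-token scales
//   s_ch       : size_n elements                float32 per-channel scales
//   s_group    : [size_k / groupsize, size_n]   float16 group scales
//                                               (numel() == 0 => groupsize -1)
//   returns d  : [size_m, size_n]               float16
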
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k) {
  // Verify M
  TORCH_CHECK(size_m == a.size(0),
              "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                  ", size_m = " + str(size_m));
  TORCH_CHECK(size_m == s_tok.numel(),
              "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) +
                  ", size_m = " + str(size_m));

  // Verify K
  TORCH_CHECK(size_k == a.size(1),
              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                  ", size_k = " + str(size_k));
  TORCH_CHECK(size_k % tile_size == 0,
              "size_k = " + str(size_k) +
                  " is not divisible by tile_size = " + str(tile_size));
  TORCH_CHECK(
      (size_k / tile_size) == b_q_weight.size(0),
      "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) +
          ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size));

  int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0);
  // Verify groupsize
  TORCH_CHECK(groupsize == -1 || groupsize == 128,
              "Unexpected groupsize = " + str(groupsize));

  // Verify N
  TORCH_CHECK(s_ch.numel() == size_n,
              "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) +
                  ", size_n = " + str(size_n));
  TORCH_CHECK(b_q_weight.size(1) % tile_size == 0,
              "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
                  " is not divisible by tile_size = " + str(tile_size));
  if (groupsize != -1) {
    TORCH_CHECK(s_group.size(1) == size_n,
                "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) +
                    ", size_n = " + str(size_n));
    TORCH_CHECK(
        size_k % s_group.size(0) == 0,
        "size_k = " + str(size_k) +
            ", is not divisible by s_group.size(0) = " + str(s_group.size(0)));
  }

  int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit;
  TORCH_CHECK(size_n == actual_size_n,
              "Shape mismatch: size_n = " + str(size_n) +
                  ", actual_size_n = " + str(actual_size_n));

  // Verify A device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  // Verify B device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  // Verify s_tok device, strides and dtype
  TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU");
  TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous");
  TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32");

  // Verify s_ch device, strides and dtype
  TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU");
  TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous");
  TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32");

  // Verify s_group device, strides and dtype
  TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU");
  TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous");
  TORCH_CHECK(s_group.dtype() == torch::kFloat16,
              "s_group's dtype is not float16");

  // Verify workspace size
  TORCH_CHECK(size_n % min_thread_n == 0,
              "size_n = " + str(size_n) +
                  ", is not divisible by min_thread_n = " + str(min_thread_n));
  int min_workspace_size = (size_n / min_thread_n) * max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = " + str(workspace.numel()) +
                  " is below min_workspace_size = " + str(min_workspace_size));

     | 
| 1216 | 
         
            +
              // Alloc C matrix
         
     | 
| 1217 | 
         
            +
              const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
         
     | 
| 1218 | 
         
            +
              auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device());
         
     | 
| 1219 | 
         
            +
              torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c);
         
     | 
| 1220 | 
         
            +
             
     | 
| 1221 | 
         
            +
              // Alloc D matrix
         
     | 
| 1222 | 
         
            +
              auto options_d =
         
     | 
| 1223 | 
         
            +
                  torch::TensorOptions().dtype(torch::kFloat16).device(a.device());
         
     | 
| 1224 | 
         
            +
              torch::Tensor d = torch::empty({size_m, size_n}, options_d);
         
     | 
| 1225 | 
         
            +
             
     | 
| 1226 | 
         
            +
              // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
         
     | 
| 1227 | 
         
            +
              // auto -1)
         
     | 
| 1228 | 
         
            +
              int thread_k = -1;
         
     | 
| 1229 | 
         
            +
              // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
         
     | 
| 1230 | 
         
            +
              // auto -1)
         
     | 
| 1231 | 
         
            +
              int thread_n = -1;
         
     | 
| 1232 | 
         
            +
              // sms: number of SMs to use for the kernel (can usually be left as auto -1)
         
     | 
| 1233 | 
         
            +
              int sms = -1;
         
     | 
| 1234 | 
         
            +
             
     | 
| 1235 | 
         
            +
              int dev = a.get_device();
         
     | 
| 1236 | 
         
            +
              marlin_qqq_cuda(
         
     | 
| 1237 | 
         
            +
                  a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(),
         
     | 
| 1238 | 
         
            +
                  s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n,
         
     | 
| 1239 | 
         
            +
                  size_k, workspace.data_ptr(), groupsize, dev,
         
     | 
| 1240 | 
         
            +
                  at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par);
         
     | 
| 1241 | 
         
            +
             
     | 
| 1242 | 
         
            +
              return d;
         
     | 
| 1243 | 
         
            +
            }
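
For illustration: the workspace checked above backs the global lock/counter state the kernel uses to order cooperating threadblocks, so it must be zero-initialized and sized as `(size_n / min_thread_n) * max_par`. A minimal caller-side sketch (not part of this diff; `make_qqq_workspace` is hypothetical, and `min_thread_n = 64` / `max_par = 16` are assumed kernel defaults, not values confirmed here):

    #include <torch/torch.h>

    // Hypothetical helper; mirrors the min_workspace_size check in the binding.
    torch::Tensor make_qqq_workspace(int64_t size_n, torch::Device device) {
      constexpr int64_t min_thread_n = 64;  // assumed kernel default
      constexpr int64_t max_par = 16;       // assumed kernel default
      TORCH_CHECK(size_n % min_thread_n == 0,
                  "size_n must be divisible by min_thread_n");
      // Same formula as the binding: one slot per column block, replicated
      // for each parallel sub-problem.
      const int64_t min_workspace_size = (size_n / min_thread_n) * max_par;
      // Zeros matter: the kernel treats the workspace as lock/counter slots.
      return torch::zeros({min_workspace_size},
                          torch::dtype(torch::kInt).device(device));
    }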
         
    	
marlin/sparse/LICENSE
ADDED
@@ -0,0 +1,203 @@
+Contains code from https://github.com/IST-DASLab/Sparse-Marlin/
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
         
    	
marlin/sparse/common/base.h
ADDED
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace marlin_24 {
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
+
+template <int M_, int N_, int K_>
+struct ShapeBase {
+  static constexpr int M = M_, N = N_, K = K_;
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragM = Vec<uint, 1>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>;  // quantization scales
+
+}  // namespace marlin_24
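
For illustration, why `Vec` indexing must resolve at compile time: a fully unrolled loop turns every `elems[i]` access into a fixed register, which is what tensor core fragments require. A small standalone CUDA sketch of the pattern (not part of this diff; `accumulate` is a hypothetical example, and `Vec` is repeated here only so the sketch is self-contained):

    #include <cuda_fp16.h>

    // Mirrors Vec from base.h so the sketch compiles on its own.
    template <typename T, int n>
    struct Vec {
      T elems[n];
      __device__ T& operator[](int i) { return elems[i]; }
    };

    using FragB = Vec<half2, 2>;  // 4 halves of operand B
    using FragC = Vec<float, 4>;  // 4 float accumulators

    // With #pragma unroll every elems[i] access is a compile-time index,
    // so the fragments stay in registers instead of spilling to local memory.
    __device__ inline void accumulate(FragC& c, const FragB& b) {
    #pragma unroll
      for (int i = 0; i < 2; i++) {
        c[2 * i] += __low2float(b.elems[i]);      // low half of the half2
        c[2 * i + 1] += __high2float(b.elems[i]); // high half of the half2
      }
    }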
         
    	
marlin/sparse/common/mem.h
ADDED
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "base.h"
+
+namespace marlin_24 {
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
+                                            const void* glob_ptr,
+                                            bool pred = true,
+                                            const bool zfill = false) {
+  const int BYTES = 16;
+  int src_in_bytes = (zfill ? 0 : BYTES);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3, %4;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
+}
+
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+               : "r"(smem));
+}
+
+__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
+               : "=r"(a[0]), "=r"(a[1])
+               : "r"(smem));
+}
+
+// Same as ldsm4, but loads the 16x16 fragment through `ldmatrix.trans`, i.e.
+// transposed with respect to the tensor core layout.
+__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+      : "r"(smem));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
+}  // namespace marlin_24
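
The three `cp_async*` helpers above compose into the usual software pipeline: issue the copy for the next tile, `cp_async_fence()` to commit it as a group, then `cp_async_wait<n>()` before reading a tile whose group is known to be complete. A condensed double-buffering sketch of that pattern (illustrative only, not from this diff; the real kernels use more stages and predicate the tail, and `cp.async` needs sm_80+):

    #include "mem.h"  // cp_async4 / cp_async_fence / cp_async_wait from this diff

    using namespace marlin_24;

    // Double-buffered global->shared streaming of 16-byte chunks: one int4 per
    // thread per tile, assuming a single block of 128 threads.
    __global__ void stream_tiles(const int4* __restrict__ in,
                                 int4* __restrict__ out, int tiles) {
      __shared__ int4 buf[2][128];
      int t = threadIdx.x;

      cp_async4(&buf[0][t], &in[t]);  // prologue: start fetching tile 0
      cp_async_fence();

      for (int i = 0; i < tiles; i++) {
        int cur = i & 1;
        if (i + 1 < tiles)  // overlap: issue tile i+1 while tile i lands
          cp_async4(&buf[cur ^ 1][t], &in[(i + 1) * 128 + t]);
        cp_async_fence();
        cp_async_wait<1>();  // all groups except the newest are now complete
        __syncthreads();
        out[i * 128 + t] = buf[cur][t];  // consume the completed stage
        __syncthreads();  // everyone done reading before buf[cur] is reused
      }
    }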
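
Likewise, `barrier_acquire`/`barrier_release` turn one int of the zero-initialized workspace into a ticket lock across threadblocks: slice `i` spins until the counter reads `i`, does its serialized work, then either bumps the counter or resets it when it is the last slice. A hedged sketch of the calling pattern (names, arguments, and the reduction itself are illustrative, not code from this PR):

    #include "mem.h"  // barrier_acquire / barrier_release from this diff

    using namespace marlin_24;

    // `locks` points at one int per output column block, taken from the
    // zero-initialized workspace. Block `slice_idx` of `slice_count`
    // cooperating blocks accumulates into a shared stripe in order.
    __device__ void reduce_stripe(int* locks, int col_block, int slice_idx,
                                  int slice_count, float* global_acc,
                                  float partial) {
      int* lock = &locks[col_block];
      // Wait until all earlier slices of this column block have committed.
      barrier_acquire(lock, slice_idx);
      global_acc[threadIdx.x] += partial;  // serialized cross-block accumulation
      // The last slice resets the lock for reuse; the others bump the count.
      barrier_release(lock, slice_idx == slice_count - 1);
    }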
         
    	
        marlin/sparse/common/mma.h
    ADDED
    
    | 
         @@ -0,0 +1,191 @@ 
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "base.h"
+#include <cudaTypedefs.h>
+
+namespace marlin_24 {
+
+// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
+// is not supported. On later versions of CUDA the version without ordered
+// metadata results in the following warning:
+//  | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
+//  | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
+//  | reduced performance on some future architectures
+#if defined CUDA_VERSION && CUDA_VERSION >= 12050
+  #define MMA_SP_INST \
+    "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#else
+  #define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
+#endif
+
+// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
+// output/accumulation.
+__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
+                              const FragA& frag_b, FragC& frag_c, FragM& frag_m,
+                              const int psel) {
+  const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
+  const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
+
+  float* c = reinterpret_cast<float*>(&frag_c);
+  if (psel == 0) {
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x0;\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
+                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
+                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x0;\n"
+                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
+                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
+                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
+  } else {
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x1;\n"
+                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
+                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
+                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
+    asm volatile(MMA_SP_INST
+                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
+                 "{%12,%13,%14,%15}, %16, 0x1;\n"
+                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
+                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
+                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
+                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
+  }
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
+                                          float c3) {
+  uint2 r;
+  asm("{\n\t"
+      ".reg .f16 a, b, c, d; \n\t"
+      "cvt.rn.f16.f32 a, %2; \n\t"
+      "cvt.rn.f16.f32 b, %3; \n\t"
+      "cvt.rn.f16.f32 c, %4; \n\t"
+      "cvt.rn.f16.f32 d, %5; \n\t"
+      "mov.b32 %0, {a, b};   \n\t"
+      "mov.b32 %1, {c, d};   \n\t"
+      "}"
+      : "=r"(r.x), "=r"(r.y)
+      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
+  return r;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values. We mostly follow the strategy in the link below, with some small
+// changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant_4bit(int q) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+  // directly into `SUB` and `ADD`.
+  const int SUB = 0x64086408;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd480d480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&SUB));
+  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&MUL),
+                      *reinterpret_cast<const half2*>(&ADD));
+  return frag_b;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values. We mostly follow the strategy in the link below, with some small
+// changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant_8bit(int q) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
+
+  FragB frag_b;
+  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  return frag_b;
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
+  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
+                                    FragS& s0, float* c4, float* c5, float* c6,
+                                    float* c7, FragS& s1) {
+  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
+  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
+  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
+  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));
+
+  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
+  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
+  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
+  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
+}
+
+}  // namespace marlin_24
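A reading note on `dequant_4bit` (our derivation, not text from the commit): `0x6400` is the fp16 bit pattern of 2^10 = 1024, so OR-ing a 4-bit nibble into the low mantissa bits yields an exact fp16 integer, and the remaining constants are the fp16 patterns of 1032, 1/16 and -72. Writing fp16(x) for the value of bit pattern x:

\begin{aligned}
\mathrm{lo} &= \mathrm{fp16}\big(\texttt{0x6400} \mid q_{\mathrm{lo}}\big) = 1024 + q_{\mathrm{lo}}, & \mathrm{lo} - 1032 &= q_{\mathrm{lo}} - 8,\\
\mathrm{hi} &= \mathrm{fp16}\big(\texttt{0x6400} \mid 16\,q_{\mathrm{hi}}\big) = 1024 + 16\,q_{\mathrm{hi}}, & \tfrac{1}{16}\,\mathrm{hi} - 72 &= q_{\mathrm{hi}} - 8.
\end{aligned}

Both halves therefore come out as the signed value q - 8, i.e. the symmetric zero point is folded into `SUB`/`ADD` exactly as the comment says. `dequant_8bit` plays the same game per byte: `prmt` splices each byte into a `0x64XX` pattern (value 1024 + byte), and subtracting `0x6480` (the fp16 pattern of 1152 = 1024 + 128) recovers byte - 128.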
    	
marlin/sparse/marlin_24_cuda_kernel.cu ADDED
@@ -0,0 +1,1140 @@
| 1 | 
         
            +
            /*
         
     | 
| 2 | 
         
            +
             * Notice: This file was modified by Neuralmagic inc to include 8-bit support
         
     | 
| 3 | 
         
            +
             *
         
     | 
| 4 | 
         
            +
             * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
         
     | 
| 5 | 
         
            +
             * Rights Reserved.
         
     | 
| 6 | 
         
            +
             *
         
     | 
| 7 | 
         
            +
             * Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 8 | 
         
            +
             * you may not use this file except in compliance with the License.
         
     | 
| 9 | 
         
            +
             * You may obtain a copy of the License at
         
     | 
| 10 | 
         
            +
             *
         
     | 
| 11 | 
         
            +
             *       http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 12 | 
         
            +
             *
         
     | 
| 13 | 
         
            +
             * Unless required by applicable law or agreed to in writing, software
         
     | 
| 14 | 
         
            +
             * distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 15 | 
         
            +
             * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 16 | 
         
            +
             * See the License for the specific language governing permissions and
         
     | 
| 17 | 
         
            +
             * limitations under the License.
         
     | 
| 18 | 
         
            +
             */
         
     | 
| 19 | 
         
            +
            #include <torch/all.h>
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            #include <ATen/cuda/CUDAContext.h>
         
     | 
| 22 | 
         
            +
            #include <c10/cuda/CUDAGuard.h>
         
     | 
| 23 | 
         
            +
            #include <cuda.h>
         
     | 
| 24 | 
         
            +
            #include <cuda_fp16.h>
         
     | 
| 25 | 
         
            +
            #include <cuda_runtime.h>
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            #include <iostream>
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            #include "common/base.h"
         
     | 
| 30 | 
         
            +
            #include "core/scalar_type.hpp"
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
            #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            #else
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
              #include "common/mem.h"
         
     | 
| 37 | 
         
            +
              #include "common/mma.h"
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
            #endif
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            template <typename T>
         
     | 
| 42 | 
         
            +
            inline std::string str(T x) {
         
     | 
| 43 | 
         
            +
              return std::to_string(x);
         
     | 
| 44 | 
         
            +
            }
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            namespace marlin_24 {
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
            // 8 warps are a good choice since every SM has 4 schedulers and having more
         
     | 
| 49 | 
         
            +
            // than 1 warp per schedule allows some more latency hiding. At the same time,
         
     | 
| 50 | 
         
            +
            // we want relatively few warps to have many registers per warp and small tiles.
         
     | 
| 51 | 
         
            +
            static constexpr int THREADS = 256;
         
     | 
| 52 | 
         
            +
            static constexpr int STAGES = 4;
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
            static constexpr int min_thread_n = 128;
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
            static constexpr int tile_size = 16;
         
     | 
| 57 | 
         
            +
            static constexpr int max_par = 64;
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
            template <const int num_bits,         // weight bits
         
     | 
| 62 | 
         
            +
                      const int threads,          // number of threads in a threadblock
         
     | 
| 63 | 
         
            +
                      const int thread_m_blocks,  // number of 16x16 blocks in the m
         
     | 
| 64 | 
         
            +
                                                  // dimension (batchsize) of the
         
     | 
| 65 | 
         
            +
                                                  // threadblock
         
     | 
| 66 | 
         
            +
                      const int thread_n_blocks,  // same for n dimension (output)
         
     | 
| 67 | 
         
            +
                      const int thread_k_blocks,  // same for k dimension (reduction)
         
     | 
| 68 | 
         
            +
                      const int stages,  // number of stages for the async global->shared
         
     | 
| 69 | 
         
            +
                                         // fetch pipeline
         
     | 
| 70 | 
         
            +
                      const int group_blocks = -1  // number of consecutive 16x16 blocks
         
     | 
| 71 | 
         
            +
                                                   // with a separate quantization scale
         
     | 
| 72 | 
         
            +
                      >
         
     | 
| 73 | 
         
            +
            __global__ void Marlin_24(
         
     | 
| 74 | 
         
            +
                const int4* __restrict__ A,     // fp16 input matrix of shape mxk
         
     | 
| 75 | 
         
            +
                const int4* __restrict__ B,     // 4bit quantized weight matrix of shape kxn
         
     | 
| 76 | 
         
            +
                const int4* __restrict__ meta,  // 2bit metadata information about 2:4
         
     | 
| 77 | 
         
            +
                                                // format on B
         
     | 
| 78 | 
         
            +
                int4* __restrict__ C,           // fp16 output buffer of shape mxn
         
     | 
| 79 | 
         
            +
                const int4* __restrict__ s,     // fp16 quantization scales of shape
         
     | 
| 80 | 
         
            +
                                                // (k/groupsize)xn
         
     | 
| 81 | 
         
            +
                int prob_m,                     // batch dimension m
         
     | 
| 82 | 
         
            +
                int prob_n,                     // output dimension n
         
     | 
| 83 | 
         
            +
                int prob_k,                     // reduction dimension k
         
     | 
| 84 | 
         
            +
                int* locks  // extra global storage for barrier synchronization
         
     | 
| 85 | 
         
            +
            ) {}
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
            torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         
     | 
| 88 | 
         
            +
                                              torch::Tensor& b_meta,
         
     | 
| 89 | 
         
            +
                                              torch::Tensor& b_scales,
         
     | 
| 90 | 
         
            +
                                              torch::Tensor& workspace,
         
     | 
| 91 | 
         
            +
                                              vllm::ScalarTypeId const b_q_type_id,
         
     | 
| 92 | 
         
            +
                                              int64_t size_m, int64_t size_n,
         
     | 
| 93 | 
         
            +
                                              int64_t size_k) {
         
     | 
| 94 | 
         
            +
              TORCH_CHECK_NOT_IMPLEMENTED(
         
     | 
| 95 | 
         
            +
                  false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0");
         
     | 
| 96 | 
         
            +
              return torch::empty({1, 1});
         
     | 
| 97 | 
         
            +
            }
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
            #else
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
            template <const int num_bits,         // weight bits
         
     | 
| 102 | 
         
            +
                      const int threads,          // number of threads in a threadblock
         
     | 
| 103 | 
         
            +
                      const int thread_m_blocks,  // number of 16x16 blocks in the m
         
     | 
| 104 | 
         
            +
                                                  // dimension (batchsize) of the
         
     | 
| 105 | 
         
            +
                                                  // threadblock
         
     | 
| 106 | 
         
            +
                      const int thread_n_blocks,  // same for n dimension (output)
         
     | 
| 107 | 
         
            +
                      const int thread_k_blocks,  // same for k dimension (reduction)
         
     | 
| 108 | 
         
            +
                      const int stages,  // number of stages for the async global->shared
         
     | 
| 109 | 
         
            +
                                         // fetch pipeline
         
     | 
| 110 | 
         
            +
                      const int group_blocks = -1  // number of consecutive 16x16 blocks
         
     | 
| 111 | 
         
            +
                                                   // with a separate quantization scale
         
     | 
| 112 | 
         
            +
                      >
         
     | 
| 113 | 
         
            +
            __global__ void Marlin_24(
         
     | 
| 114 | 
         
            +
                const int4* __restrict__ A,     // fp16 input matrix of shape mxk
         
     | 
| 115 | 
         
            +
                const int4* __restrict__ B,     // 4bit quantized weight matrix of shape kxn
         
     | 
| 116 | 
         
            +
                const int4* __restrict__ meta,  // 2bit metadata information about 2:4
         
     | 
| 117 | 
         
            +
                                                // format on B
         
     | 
| 118 | 
         
            +
                int4* __restrict__ C,           // fp16 output buffer of shape mxn
         
     | 
| 119 | 
         
            +
                const int4* __restrict__ s,     // fp16 quantization scales of shape
         
     | 
| 120 | 
         
            +
                                                // (k/groupsize)xn
         
     | 
| 121 | 
         
            +
                int prob_m,                     // batch dimension m
         
     | 
| 122 | 
         
            +
                int prob_n,                     // output dimension n
         
     | 
| 123 | 
         
            +
                int prob_k,                     // reduction dimension k
         
     | 
| 124 | 
         
            +
                int* locks  // extra global storage for barrier synchronization
         
     | 
| 125 | 
         
            +
            ) {
         
     | 
| 126 | 
         
            +
              // Each threadblock processes one "stripe" of the B matrix with (roughly) the
         
     | 
| 127 | 
         
            +
              // same size, which might involve multiple column "slices" (of width 16 *
         
     | 
| 128 | 
         
            +
              // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
         
     | 
| 129 | 
         
            +
              // example:
         
     | 
| 130 | 
         
            +
              //   0 1 3
         
     | 
| 131 | 
         
            +
              //   0 2 3
         
     | 
| 132 | 
         
            +
              //   1 2 4
         
     | 
| 133 | 
         
            +
              // While this kind of partitioning makes things somewhat more complicated, it
         
     | 
| 134 | 
         
            +
              // ensures good utilization of all SMs for many kinds of shape and GPU
         
     | 
| 135 | 
         
            +
              // configurations, while requiring as few slow global cross-threadblock
         
     | 
| 136 | 
         
            +
              // reductions as possible.
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
              // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
         
     | 
| 139 | 
         
            +
              // better partitioning with less reductions
         
     | 
| 140 | 
         
            +
              int parallel = 1;
         
     | 
| 141 | 
         
            +
              if (prob_m > 16 * thread_m_blocks) {
         
     | 
| 142 | 
         
            +
                parallel = prob_m / (16 * thread_m_blocks);
         
     | 
| 143 | 
         
            +
                prob_m = 16 * thread_m_blocks;
         
     | 
| 144 | 
         
            +
              }
         
     | 
| 145 | 
         
            +
             
     | 
| 146 | 
         
            +
              // number of thread_k_blocks in k-dim
         
     | 
| 147 | 
         
            +
              int k_tiles = prob_k / 32 / thread_k_blocks;
         
     | 
| 148 | 
         
            +
              // number of thread_n_blocks in n-dim
         
     | 
| 149 | 
         
            +
              int n_tiles = prob_n / 16 / thread_n_blocks;
         
     | 
| 150 | 
         
            +
              // iters needed to cover all slices
         
     | 
| 151 | 
         
            +
              int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
         
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
              // Ensure that the number of tiles in each stripe is a multiple of the
         
     | 
| 154 | 
         
            +
              // groupsize; this avoids an annoying special case where a stripe starts in
         
     | 
| 155 | 
         
            +
              // the middle of group.
         
     | 
| 156 | 
         
            +
              if (group_blocks != -1)
         
     | 
| 157 | 
         
            +
                iters = (group_blocks / thread_k_blocks) *
         
     | 
| 158 | 
         
            +
                        ceildiv(iters, (group_blocks / thread_k_blocks));
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  // number of threadblock tiles in the current slice
  int slice_iters;
  // total number of active threadblocks in the current slice
  int slice_count = 0;
  // index of threadblock in current slice; numbered bottom to top
  int slice_idx;

  // We can easily implement parallel problem execution by just remapping
  // indices and advancing global pointers
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for
  // synchronization.
  auto init_slice = [&]() {
    slice_iters =
        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = ceildiv(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 8;
      C += 16 * thread_m_blocks * prob_n / 8;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();
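  // Continuing the example above (illustrative, not from the original
  // source): column 0 is then covered by blocks 0 and 1, so both compute
  // slice_count = 2; block 0 gets slice_idx = 1 (top) while block 1, which
  // contributes only the last two k-tiles of the column, gets slice_idx = 0
  // (bottom).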
  // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements
  int a_gl_stride = prob_k / 8;  // stride of the A matrix in global memory

  // stride of an A matrix tile in shared memory
  constexpr int a_sh_stride = 32 * thread_k_blocks / 8;
  // delta between subsequent A tiles in global memory
  constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8;
  // between subsequent accesses within a tile
  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  // between shared memory writes
  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  // between shared memory tile reads //RLC: 2 * #warps k-dim
  constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4));
  // within a shared memory tile
  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  // overall size of a tile
  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  // number of shared write iterations for a tile
  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta);

  constexpr int pack_factor = 32 / num_bits;
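  // Note (illustrative, not part of the original source): pack_factor is the
  // number of quantized weights per 32-bit word, i.e. 8 for num_bits == 4 and
  // 4 for num_bits == 8; the strides above and below are measured in int4
  // (16-byte) units, i.e. 8 fp16 values per element.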
  int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  int m_gl_stride = 2 * prob_n / 8;  // (16*2*4 / 8) = 16
  constexpr int m_sh_stride =
      (16 * thread_n_blocks) / 4;  // #warps n-dim * threads/warp
  int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks;
  int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride);
  constexpr int m_sh_wr_delta = threads / 2;
  constexpr int m_sh_rd_delta = threads / 2;
  constexpr int m_sh_stage = m_sh_stride * thread_k_blocks;
  constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta);

  int s_gl_stride = prob_n / 8;
  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
  constexpr int s_sh_stage = s_sh_stride;
  int s_gl_rd_delta = s_gl_stride;
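  // Worked example (illustrative, not part of the original source): for
  // num_bits = 4 (pack_factor = 8) and prob_n = 256, b_gl_stride =
  // 16 * 256 / (8 * 4) = 128 int4 per row of 16 k-elements, and the scale
  // stride s_gl_stride = 256 / 8 = 32 int4, i.e. one fp16 scale per output
  // column.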
  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  a_gl_rd += a_gl_rd_delta_o * slice_row;
  // Shared write index of current thread.
  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  // Shared read index.
  int a_sh_rd =
      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  b_gl_rd += b_sh_stride * slice_col;
  b_gl_rd += b_gl_rd_delta_o * slice_row;
  int b_sh_wr = threadIdx.x * b_thread_vecs;
  int b_sh_rd = threadIdx.x * b_thread_vecs;

  int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
                (threadIdx.x % (m_sh_stride));
  m_gl_rd += (m_sh_stride)*slice_col;
  m_gl_rd += m_gl_rd_delta_o * slice_row;
  int m_sh_wr = threadIdx.x;
  int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;

  int s_gl_rd;
  if constexpr (group_blocks == -1) {
    s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  } else {
    s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
              s_sh_stride * slice_col + threadIdx.x;
  }

  int s_sh_wr = threadIdx.x;
  int s_sh_rd;
  // We use a different scale layout for grouped and column-wise quantization as
  // we scale a `half2` tile in column-major layout in the former and in
  // row-major in the latter case.
  s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
            (threadIdx.x % 32) / 4;  // Note that in the original Marlin kernel
                                     // this is (threadIdx.x % 32) / 4
  // Precompute which thread should not read memory in which iterations; this is
  // needed if there are more threads than required for a certain tilesize or
  // when the batchsize is not a multiple of 16.
  bool a_sh_wr_pred[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++) {
    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
  }
  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

  // To ensure that writing and reading A tiles to/from shared memory, the
  // latter in fragment format, is fully bank conflict free, we need to use a
  // rather fancy XOR-based layout. The key here is that neither reads nor
  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  // same shared memory banks. Further, it seems (based on NSight-Compute) that
  // each warp must also write a consecutive memory segment?
  auto transform_a = [&](int i) {
    int row = i / a_gl_rd_delta_o;
    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  };
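  // Note (illustrative, not part of the original source): since `^` binds
  // less tightly than `+` in C++, transform_a computes
  // (a_gl_rd_delta_o * row + i % a_gl_rd_delta_o) ^ row. With
  // a_gl_rd_delta_o a power of two and row small enough to fit in its low
  // bits, this is the classic swizzle col -> col ^ row; e.g. row 3 maps
  // columns 0,1,2,3 to 3,2,1,0, so consecutive rows hit different banks.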
  // Since the computation of this remapping is non-trivial and, due to our main
  // loop unrolls, all shared memory accesses are static, we simply precompute
  // both transformed reads and writes.
  int a_sh_wr_trans[a_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++) {
  #pragma unroll
    for (int j = 0; j < thread_m_blocks; j++) {
      a_sh_rd_trans[0][i][j] =
          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
      a_sh_rd_trans[1][i][j] =
          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2);
    }
  }

  // Since B-accesses have non-constant stride they have to be computed at
  // runtime; we break dependencies between subsequent accesses within a tile by
  // maintaining multiple pointers (we have enough registers), a tiny
  // optimization.
  const int4* B_ptr[b_sh_wr_iters];
  #pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
  bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta;
  const int4* meta_ptr[m_sh_iters];
  #pragma unroll
  for (int i = 0; i < m_sh_iters; i++)
    meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_s = sh_b + (stages * b_sh_stage);
  int4* sh_m = sh_s + (stages * s_sh_stage);
  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks][2];
  I4 frag_b_quant[2][b_thread_vecs];
  FragM frag_m[2][2];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];
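  // Layout note (illustrative, not part of the original source): the dynamic
  // shared memory buffer is carved into consecutive regions for A, B, scales
  // and sparsity metadata, each holding `stages` tile copies, so the pipeline
  // can prefetch several stages ahead; the same raw `sh` buffer is also
  // reused as scratch by the reduction and write-out steps further below.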
  // Zero accumulators.
  auto zero_accums = [&]() {
  #pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast<float*>(frag_c)[i] = 0;
  };

  // Asynchronously fetch the next A, B and s tile from global to the next
  // shared memory pipeline location.
  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
            &sh_a_stage[a_sh_wr_trans[i]],
            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
            a_sh_wr_pred[i]);
      }
      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
  #pragma unroll
        for (int j = 0; j < b_thread_vecs; j++) {
          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
        }
        B_ptr[i] += b_gl_rd_delta_o;
      }
      int4* sh_meta_stage = sh_m + m_sh_stage * pipe;
  #pragma unroll
      for (int i = 0; i < m_sh_iters; i++) {
        if (m_sh_wr_pred)
          cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]);
        meta_ptr[i] += m_gl_rd_delta_o;
      }
      // Only fetch scales if this tile starts a new group
      if constexpr (group_blocks != -1) {
        // This assumes group_blocks >= thread_k_blocks
        // and would need to be modified to support smaller groups.
        static_assert(group_blocks >= thread_k_blocks);
        if (pipe % (group_blocks / thread_k_blocks) == 0) {
          int4* sh_s_stage = sh_s + s_sh_stage * pipe;
          if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
          s_gl_rd += s_gl_rd_delta;
        }
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that
    // waiting is also correct at this point.
    cp_async_fence();
  };
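  // For reference, a minimal sketch of what a 16-byte async-copy helper like
  // cp_async4 typically wraps (the actual definition used here lives in
  // marlin/sparse/common/mem.h; treat this sketch as an assumption, not a
  // copy of that code):
  //
  //   __device__ inline void cp_async4_sketch(void* smem, const void* glob) {
  //     unsigned s = static_cast<unsigned>(__cvta_generic_to_shared(smem));
  //     asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(s),
  //                  "l"(glob));
  //   }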
  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<stages - 2>();
    __syncthreads();
  };
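  // Pipeline-depth note (illustrative, not part of the original source): with
  // stages = 4, cp_async_wait<2> blocks until all but the two most recent
  // cp.async groups have landed, so the buffer about to be consumed is
  // guaranteed complete while two prefetches remain in flight.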
  // Load the next sub-tile from the current location in the shared memory pipe
  // into the current register buffer.
  auto fetch_to_registers = [&](int k, int pipe) {
    // It may seem inefficient that we reload the groups for every sub-tile;
    // however, this does not seem to be a significant bottleneck, while some
    // theoretically better attempts have led to bad instruction ordering by
    // the compiler and correspondingly a noticeable drop in performance.
    if constexpr (group_blocks != -1) {
      // This assumes group_blocks >= thread_k_blocks
      // and would need to be modified to support smaller groups.
      static_assert(group_blocks >= thread_k_blocks);
      int4* sh_s_stage =
          sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
                               (pipe / (group_blocks / thread_k_blocks)));
      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
    }
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
  #pragma unroll
    for (int i = 0; i < thread_m_blocks; i++) {
      ldsm4(frag_a[k % 2][i][0],
            &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]);
      ldsm4(frag_a[k % 2][i][1],
            &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]);
    }

    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
  #pragma unroll
    for (int i = 0; i < b_thread_vecs; i++) {
      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
    }

    // Load meta with ldsm4
    int4* sh_m_stage = sh_m + m_sh_stage * pipe;
    ldsm4_m(frag_m[k % 2][0],
            &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]);
  };
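  // For reference, a minimal sketch of what an ldsm4-style helper usually
  // wraps: the ldmatrix PTX instruction loading four 8x8 fp16 tiles from
  // shared memory into per-thread fragment registers (an assumption about
  // marlin/sparse/common/mem.h, not a copy of it):
  //
  //   __device__ inline void ldsm4_sketch(FragA& frag, const void* smem) {
  //     unsigned s = static_cast<unsigned>(__cvta_generic_to_shared(smem));
  //     unsigned* a = reinterpret_cast<unsigned*>(&frag);
  //     asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
  //                  "{%0,%1,%2,%3}, [%4];\n"
  //                  : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
  //                  : "r"(s));
  //   }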
  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
  // We have the m dimension as the inner loop in order to encourage overlapping
  // dequantization and matmul operations.
  #pragma unroll
    for (int j = 0; j < 4; j++) {
      FragB frag_b0;
      FragB frag_b1;

      if constexpr (num_bits == 4) {
        int b_quant = frag_b_quant[k % 2][0][j];
        int b_quant_shift = b_quant >> 8;

        frag_b0 = dequant_4bit(b_quant);
        frag_b1 = dequant_4bit(b_quant_shift);

      } else {
        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];

        frag_b0 = dequant_8bit(b_quant_0);
        frag_b1 = dequant_8bit(b_quant_1);
      }

      // If there are no groups, we can just scale the final output once and can
      // avoid doing so for each weight.
      if constexpr (group_blocks != -1) {
        scale(frag_b0, frag_s[k % 2][j], 0);
      }
      if constexpr (group_blocks != -1) {
        scale(frag_b1, frag_s[k % 2][j], 1);
      }

  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0],
               frag_m[k % 2][j / 2], j % 2);
      }
    }
  };
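  // Note (illustrative, not part of the original source): in the 4-bit path a
  // 32-bit word holds 8 nibbles, and the weights are pre-permuted offline so
  // that dequant_4bit(b_quant) and dequant_4bit(b_quant >> 8) each extract an
  // interleaved group of four values; when group_blocks == -1 the per-column
  // scale is instead applied once in the write_result step below.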
  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we have
  // multiple warps that accumulate their partial sums of the same output
  // location, which we have to reduce over in the end. We do this in shared
  // memory.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride_threads / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride_threads;
      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride_threads;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
                      (threadIdx.x % b_sh_stride_threads);

  // Parallel logarithmic shared memory reduction. We make sure to avoid any
  // unnecessary read or write iterations, e.g., for two warps we write only
  // once by warp 1 and read only once by warp 0.
  #pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
  #pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
  #pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr =
                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd =
                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
  #pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] =
                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
  #pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd =
                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
  #pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
                  c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };
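  // Sizing note (illustrative, not part of the original source): with
  // threads = 256 and b_sh_stride_threads = 64 there are four warp-groups per
  // output column, so red_off = 2 and the loop runs two halving rounds
  // (i = 2, then i = 1) before the red_idx == 0 group folds in the final
  // partials from shared memory.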
  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 2 * 4 * c_gl_stride;
      int c_gl_wr_delta_i =
          c_gl_stride;  // 8 threads (e.g., 0,4,8,12,16,20,24,28)
      int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) +
                    8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;

      int col = 2 * ((threadIdx.x % 32) % 4);

      if (!first) {
  // Interestingly, doing direct global accesses here really seems to mess up
  // the compiler and lead to slowdowns, hence we also use async-copies even
  // though these fetches are not actually asynchronous.
  #pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
                         &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                            c_gl_wr_delta_i * (i % 2)],
                         i < (thread_m_blocks - 1) * 4 ||
                             8 * (i / 2) + col + (i % 2) < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

  #pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 ||
            8 * (i / 2) + col + (i % 2) < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
  #pragma unroll
            for (int j2 = 0; j2 < 2; j2++) {
  #pragma unroll
              for (int j1 = 0; j1 < 4; j1++) {
                reinterpret_cast<float*>(
                    &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
                             4 * ((i % 4) / 2) + i % 2] +=
                    __half2float(
                        reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]);
              }
            }
          }
          if (!last) {
            int4 c;
  #pragma unroll
            for (int j2 = 0; j2 < 2; j2++) {
  #pragma unroll
              for (int j1 = 0; j1 < 4; j1++) {
                reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] =
                    __float2half(reinterpret_cast<float*>(
                        &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 +
                                 4 * ((i % 4) / 2) + i % 2]);
              }
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
                c;
          }
        }
      }
    }
  };
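  // Flow note (illustrative, not part of the original source): the first
  // block of a slice (first == true) skips the read-back and only deposits
  // its FP16 partials into C; later blocks read those partials back,
  // accumulate in FP32, and rewrite C, while the final block (last == true)
  // keeps the sum in registers so that write_result below can emit it in the
  // proper layout.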
  // Write out the reduced final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;

    constexpr int c_sh_stride = 2 * thread_n_blocks;              // RLC:
    constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2;            // RLC:
    constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2;  // RLC:

    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));

    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;

    int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) +
                  ((threadIdx.x % 32) / 4);  // RLC:
    c_sh_wr += 8 * (threadIdx.x / 32);       // 128/4(half4)

    constexpr int c_sh_rd_delta =
        c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks));  // RLC:
    int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * 2 * thread_n_blocks));

    int c_gl_wr_end = c_gl_stride * prob_m;

    auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0,
                     float c4, float c5, float c6, float c7, FragS& s1) {
      uint2 res[2];
      res[0] = to_half4(c0, c1, c2, c3);
      res[1] = to_half4(c4, c5, c6, c7);
      half2* tmp = (half2*)&res;
      // for per-column quantization we finally apply the scale here
      if constexpr (group_blocks == -1 && num_bits == 4) {
        tmp[0] = __hmul2(tmp[0], s0[0]);
        tmp[1] = __hmul2(tmp[1], s0[1]);
        tmp[2] = __hmul2(tmp[2], s1[0]);
        tmp[3] = __hmul2(tmp[3], s1[1]);
      }
      ((int4*)sh)[idx] = *((int4*)&res[0]);
    };
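    // For reference, a minimal sketch of what a to_half4-style helper could
    // look like (an assumption for illustration; the actual definition lives
    // with the kernel's other helpers):
    //
    //   __device__ inline uint2 to_half4_sketch(float a, float b, float c,
    //                                           float d) {
    //     uint2 r;
    //     half2 lo = __floats2half2_rn(a, b);
    //     half2 hi = __floats2half2_rn(c, d);
    //     r.x = *reinterpret_cast<unsigned*>(&lo);
    //     r.y = *reinterpret_cast<unsigned*>(&hi);
    //     return r;
    //   }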
    // RLC:  only warp 0 and 1 baseline example
    if (threadIdx.x / 32 < thread_n_blocks / 4) {
  #pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        int wr = c_sh_wr;
        write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0],
              frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2],
              frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2],
              frag_s[0][2]);
        write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1],
              frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0],
              frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3],
              frag_c[i][3][0][3], frag_s[0][2]);
        write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0],
              frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0],
              frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2],
              frag_c[i][3][1][2], frag_s[0][2]);
        write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1],
              frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1],
              frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3],
              frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]);

        c_sh_wr += 8 * c_sh_stride_2;
      }
    }
    __syncthreads();

  #pragma unroll
    for (int i = 0;
         i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };
              // Start global fetch and register load pipelines.
         
     | 
| 726 | 
         
            +
              auto start_pipes = [&]() {
         
     | 
| 727 | 
         
            +
              #pragma unroll
         
     | 
| 728 | 
         
            +
                for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters);
         
     | 
| 729 | 
         
            +
                zero_accums();
         
     | 
| 730 | 
         
            +
                wait_for_stage();
         
     | 
| 731 | 
         
            +
                fetch_to_registers(0, 0);
         
     | 
| 732 | 
         
            +
                a_gl_rd += a_gl_rd_delta_o * (stages - 1);
         
     | 
| 733 | 
         
            +
              };
         
     | 
| 734 | 
         
            +
              start_pipes();
         
     | 
| 735 | 
         
            +
             
     | 
| 736 | 
         
            +
              // Main loop.
         
     | 
| 737 | 
         
            +
              while (slice_iters) {
         
     | 
| 738 | 
         
            +
              // We unroll over both the global fetch and the register load pipeline to
         
     | 
| 739 | 
         
            +
              // ensure all shared memory accesses are static. Note that both pipelines have
         
     | 
| 740 | 
         
            +
              // even length meaning that the next iteration will always start at index 0.
         
     | 
| 741 | 
         
            +
              #pragma unroll
         
     | 
| 742 | 
         
            +
                for (int pipe = 0; pipe < stages;) {
         
     | 
| 743 | 
         
            +
                  fetch_to_shared((pipe + stages - 1) % stages, pipe,
         
     | 
| 744 | 
         
            +
                                  slice_iters >= stages);
         
     | 
| 745 | 
         
            +
                  matmul(pipe);
         
     | 
| 746 | 
         
            +
                  wait_for_stage();
         
     | 
| 747 | 
         
            +
             
     | 
| 748 | 
         
            +
                  fetch_to_registers(pipe + 1, (pipe + 1) % stages);
         
     | 
| 749 | 
         
            +
             
     | 
| 750 | 
         
            +
                  pipe++;
         
     | 
| 751 | 
         
            +
                  slice_iters--;
         
     | 
| 752 | 
         
            +
                  if (slice_iters == 0) break;
         
     | 
| 753 | 
         
            +
                }
         
     | 
| 754 | 
         
            +
                a_gl_rd += a_gl_rd_delta_o * stages;
         
     | 
| 755 | 
         
            +
             
     | 
| 756 | 
         
            +
                // Process results and, if necessary, proceed to the next column slice.
         
     | 
| 757 | 
         
            +
                // While this pattern may not be the most readable, other ways of writing
         
     | 
| 758 | 
         
            +
                // the loop seemed to noticeably worse performance after compilation.
         
     | 
| 759 | 
         
            +
                if (slice_iters == 0) {
         
     | 
| 760 | 
         
            +
                  cp_async_wait<0>();
         
     | 
| 761 | 
         
            +
                  bool last = slice_idx == slice_count - 1;
         
     | 
| 762 | 
         
            +
                  // For per-column scales, we only fetch them here in the final step before
         
     | 
| 763 | 
         
            +
                  // write-out
         
     | 
| 764 | 
         
            +
                  if constexpr (group_blocks == -1) {
         
     | 
| 765 | 
         
            +
                    if constexpr (num_bits == 8) {
         
     | 
| 766 | 
         
            +
                      if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
         
     | 
| 767 | 
         
            +
                      cp_async_fence();
         
     | 
| 768 | 
         
            +
                    } else {
         
     | 
| 769 | 
         
            +
                      if (last) {
         
     | 
| 770 | 
         
            +
                        if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
         
     | 
| 771 | 
         
            +
                        cp_async_fence();
         
     | 
| 772 | 
         
            +
                      }
         
     | 
| 773 | 
         
            +
                    }
         
     | 
| 774 | 
         
            +
                  }
         
     | 
| 775 | 
         
            +
                  thread_block_reduce();
         
     | 
| 776 | 
         
            +
             
     | 
| 777 | 
         
            +
                  if constexpr (group_blocks == -1) {
         
     | 
| 778 | 
         
            +
                    if constexpr (num_bits == 8) {
         
     | 
| 779 | 
         
            +
                      cp_async_wait<0>();
         
     | 
| 780 | 
         
            +
                      __syncthreads();
         
     | 
| 781 | 
         
            +
                      if (threadIdx.x / 32 < thread_n_blocks / 4) {
         
     | 
| 782 | 
         
            +
                        *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
         
     | 
| 783 | 
         
            +
                      }
         
     | 
| 784 | 
         
            +
                    } else {
         
     | 
| 785 | 
         
            +
                      if (last) {
         
     | 
| 786 | 
         
            +
                        cp_async_wait<0>();
         
     | 
| 787 | 
         
            +
                        __syncthreads();
         
     | 
| 788 | 
         
            +
                        if (threadIdx.x / 32 < thread_n_blocks / 4) {
         
     | 
| 789 | 
         
            +
                          *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]);
         
     | 
| 790 | 
         
            +
                        }
         
     | 
| 791 | 
         
            +
                      }
         
     | 
| 792 | 
         
            +
                    }
         
     | 
| 793 | 
         
            +
                  }
         
     | 
| 794 | 
         
            +
             
     | 
| 795 | 
         
            +
                  // For 8-bit channelwise, we apply the scale before the global reduction
         
     | 
| 796 | 
         
            +
                  // that converts the fp32 results to fp16 (so that we avoid possible
         
     | 
| 797 | 
         
            +
                  // overflow in fp16)
         
     | 
| 798 | 
         
            +
                  if constexpr (group_blocks == -1 && num_bits == 8) {
         
     | 
| 799 | 
         
            +
                    if (threadIdx.x / 32 < thread_n_blocks / 4) {
         
     | 
| 800 | 
         
            +
              #pragma unroll
         
     | 
| 801 | 
         
            +
                      for (int i = 0; i < thread_m_blocks; i++) {
         
     | 
| 802 | 
         
            +
                        scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0],
         
     | 
| 803 | 
         
            +
                                     &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0],
         
     | 
| 804 | 
         
            +
                                     &frag_c[i][0][0][2], &frag_c[i][1][0][2],
         
     | 
| 805 | 
         
            +
                                     &frag_c[i][2][0][2], &frag_c[i][3][0][2],
         
     | 
| 806 | 
         
            +
                                     frag_s[0][2]);
         
     | 
| 807 | 
         
            +
             
     | 
| 808 | 
         
            +
                        scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1],
         
     | 
| 809 | 
         
            +
                                     &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0],
         
     | 
| 810 | 
         
            +
                                     &frag_c[i][0][0][3], &frag_c[i][1][0][3],
         
     | 
| 811 | 
         
            +
                                     &frag_c[i][2][0][3], &frag_c[i][3][0][3],
         
     | 
| 812 | 
         
            +
                                     frag_s[0][2]);
         
     | 
| 813 | 
         
            +
             
     | 
| 814 | 
         
            +
                        scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0],
         
     | 
| 815 | 
         
            +
                                     &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0],
         
     | 
| 816 | 
         
            +
                                     &frag_c[i][0][1][2], &frag_c[i][1][1][2],
         
     | 
| 817 | 
         
            +
                                     &frag_c[i][2][1][2], &frag_c[i][3][1][2],
         
     | 
| 818 | 
         
            +
                                     frag_s[0][2]);
         
     | 
| 819 | 
         
            +
             
     | 
| 820 | 
         
            +
                        scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1],
         
     | 
| 821 | 
         
            +
                                     &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0],
         
     | 
| 822 | 
         
            +
                                     &frag_c[i][0][1][3], &frag_c[i][1][1][3],
         
     | 
| 823 | 
         
            +
                                     &frag_c[i][2][1][3], &frag_c[i][3][1][3],
         
     | 
| 824 | 
         
            +
                                     frag_s[0][2]);
         
     | 
| 825 | 
         
            +
                      }
         
     | 
| 826 | 
         
            +
                    }
         
     | 
| 827 | 
         
            +
                  }
         
     | 
| 828 | 
         
            +
             
     | 
| 829 | 
         
            +
                  if (slice_count > 1) {  // only globally reduce if there is more than one
         
     | 
| 830 | 
         
            +
                                          // block in a slice
         
     | 
| 831 | 
         
            +
                    barrier_acquire(&locks[slice_col], slice_idx);
         
     | 
| 832 | 
         
            +
                    global_reduce(slice_idx == 0, last);
         
     | 
| 833 | 
         
            +
                    barrier_release(&locks[slice_col], last);
         
     | 
| 834 | 
         
            +
                  }
         
     | 
| 835 | 
         
            +
                  if (last)  // only the last block in a slice actually writes the result
         
     | 
| 836 | 
         
            +
                    write_result();
         
     | 
| 837 | 
         
            +
             
     | 
| 838 | 
         
            +
                  slice_row = 0;
         
     | 
| 839 | 
         
            +
                  slice_col_par++;
         
     | 
| 840 | 
         
            +
                  slice_col++;
         
     | 
| 841 | 
         
            +
                  init_slice();
         
     | 
| 842 | 
         
            +
                  if (slice_iters) {
         
     | 
| 843 | 
         
            +
                    a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
         
     | 
| 844 | 
         
            +
                              (threadIdx.x % a_gl_rd_delta_o);
         
     | 
| 845 | 
         
            +
              #pragma unroll
         
     | 
| 846 | 
         
            +
                    for (int i = 0; i < b_sh_wr_iters; i++)
         
     | 
| 847 | 
         
            +
                      B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
         
     | 
| 848 | 
         
            +
              #pragma unroll
         
     | 
| 849 | 
         
            +
                    for (int i = 0; i < m_sh_iters; i++)
         
     | 
| 850 | 
         
            +
                      meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles;
         
     | 
| 851 | 
         
            +
                    if (slice_col == 0) {
         
     | 
| 852 | 
         
            +
              #pragma unroll
         
     | 
| 853 | 
         
            +
                      for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
         
     | 
| 854 | 
         
            +
              #pragma unroll
         
     | 
| 855 | 
         
            +
                      for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride;
         
     | 
| 856 | 
         
            +
                    }
         
     | 
| 857 | 
         
            +
                    s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
         
     | 
| 858 | 
         
            +
                    start_pipes();
         
     | 
| 859 | 
         
            +
                  }
         
     | 
| 860 | 
         
            +
                }
         
     | 
| 861 | 
         
            +
              }
         
     | 
| 862 | 
         
            +
            }
         
     | 
| 863 | 
         
            +
             
     | 
| 864 | 
         
            +
            #endif
         
     | 
| 865 | 
         
            +
             
     | 
| 866 | 
         
            +
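The main loop above is a standard cp.async software pipeline: the prologue prefetches `stages - 1` tiles, and each iteration overlaps the fetch of the tile `stages - 1` ahead (into the circular-buffer slot compute just vacated) with the matmul on the current tile. The Python sketch below is illustrative only and not part of this commit; `fetch_to_shared`, `matmul`, and `wait_for_stage` stand in for the device-side lambdas of the same names:

    def pipelined_main_loop(stages, slice_iters,
                            fetch_to_shared, matmul, wait_for_stage):
        # Prologue: prefetch the first `stages - 1` tiles into buffers
        # 0..stages-2, predicated off when fewer iterations remain.
        for i in range(stages - 1):
            fetch_to_shared(buf=i, tile=i, pred=(i < slice_iters))
        wait_for_stage()
        pipe = 0
        while slice_iters > 0:
            # Issue the fetch for the tile `stages - 1` ahead of the one
            # being computed, reusing the oldest circular-buffer slot.
            fetch_to_shared(buf=(pipe + stages - 1) % stages,
                            tile=pipe % stages,
                            pred=(slice_iters >= stages))
            matmul(pipe % stages)  # compute on the current tile
            wait_for_stage()       # block until the next tile's data lands
            pipe += 1
            slice_iters -= 1

Because both pipelines have even length, the unrolled inner loop always re-enters at buffer index 0, which is what keeps every shared-memory access statically addressable.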
#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,               \
                    THREAD_K_BLOCKS, GROUP_BLOCKS)                            \
  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&      \
           thread_n_blocks == THREAD_N_BLOCKS &&                              \
           thread_k_blocks == THREAD_K_BLOCKS &&                              \
           group_blocks == GROUP_BLOCKS) {                                    \
    cudaFuncSetAttribute(                                                     \
        Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,        \
                  THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,                     \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);         \
    Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,            \
              THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>                          \
        <<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
                                                      C_ptr, s_ptr, prob_n,   \
                                                      prob_m, prob_k, locks); \
  }

void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
                     void* s, int prob_m, int prob_n, int prob_k,
                     void* workspace, int num_bits, int groupsize = -1,
                     int dev = 0, cudaStream_t stream = 0, int thread_k = -1,
                     int thread_m = -1, int sms = -1, int max_par = 16) {
  int tot_n = prob_n;
  int tot_n_blocks = ceildiv(tot_n, 16);
  int pad = 16 * tot_n_blocks - tot_n;

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }
  TORCH_CHECK(sms > 0);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (thread_k == -1 || thread_m == -1) {
    if (prob_n <= 16) {
      // For small batch sizes, better partitioning is slightly more important
      // than better compute utilization
      thread_k = 128;
      thread_m = 128;
    } else {
      thread_k = 64;
      thread_m = 256;
    }
    // Also had
    // if prob_n > 256
    //   thread_k = 32;
    //   thread_m = 512;
    // but this is broken,
    // TODO(Lucas, Alex M): figure out why
  }

  int thread_k_blocks = thread_k / 32;  // 2:4 version with m16n8k32 instruction
  int thread_m_blocks = thread_m / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;

  TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m,
              " is not divisible by thread_m = ", thread_m);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  if (group_blocks != -1) {
    TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2,
                " is not divisible by group_blocks = ", group_blocks);
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  const int4* meta_ptr = (const int4*)meta;
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;

  constexpr int max_m_blocks = 4;

  int* locks = (int*)workspace;
  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
    int thread_n_blocks = tot_n_blocks - i;
    prob_n = tot_n - 16 * i;
    int par = 1;
    if (thread_n_blocks > max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
      if (par > max_par) par = max_par;
      prob_n = (max_m_blocks * 16) * par;
      i += max_m_blocks * (par - 1);
      thread_n_blocks = max_m_blocks;
    }

    // For compilation speed, we only define the kernel configurations that
    // have seemed useful (in terms of performance) in our testing; many more
    // are, in principle, possible.

    // The `false` below starts the chain of `else if`s in the CALL_IF macros.
    if (false) {
    }  //         BMxBNxBK,   group
    // 4-bit
    CALL_IF_2_4(4, 8, 1, 4, -1)  // e.g., 16x128x128
    CALL_IF_2_4(4, 8, 1, 4, 4)   // e.g., 16x128x128, 64

    CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
    CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64,  64
    CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g., 32x256x64
    CALL_IF_2_4(4, 16, 2, 2, 4)
    CALL_IF_2_4(4, 16, 3, 2, -1)
    CALL_IF_2_4(4, 16, 3, 2, 4)
    CALL_IF_2_4(4, 16, 4, 2, -1)
    CALL_IF_2_4(4, 16, 4, 2, 4)

    CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x256x64
    CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x256x64,  64
    CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g., 32x256x64
    CALL_IF_2_4(4, 32, 2, 1, 4)
    CALL_IF_2_4(4, 32, 3, 1, -1)
    CALL_IF_2_4(4, 32, 3, 1, 4)
    CALL_IF_2_4(4, 32, 4, 1, -1)
    CALL_IF_2_4(4, 32, 4, 1, 4)

    // 8-bit
    CALL_IF_2_4(8, 8, 1, 4, -1)  // e.g., 16x128x128
    CALL_IF_2_4(8, 8, 1, 4, 4)   // e.g., 16x128x128, 64

    CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
    CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64,  64
    CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g., 32x256x64
    CALL_IF_2_4(8, 16, 2, 2, 4)
    CALL_IF_2_4(8, 16, 3, 2, -1)
    CALL_IF_2_4(8, 16, 3, 2, 4)
    CALL_IF_2_4(8, 16, 4, 2, -1)
    CALL_IF_2_4(8, 16, 4, 2, 4)

    CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x256x64
    CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x256x64,  64
    CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g., 32x256x64
    CALL_IF_2_4(8, 32, 2, 1, 4)
    CALL_IF_2_4(8, 32, 3, 1, -1)
    CALL_IF_2_4(8, 32, 3, 1, 4)
    CALL_IF_2_4(8, 32, 4, 1, -1)
    CALL_IF_2_4(8, 32, 4, 1, 4)
    else {
      throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                               ", " + str(prob_k) + ", " + str(prob_n) + "]" +
                               ", groupsize = " + str(groupsize) +
                               ", thread_m_blocks = " + str(thread_m_blocks) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
    }

    A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
  }
}

}  // namespace marlin_24
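The launch heuristic in `marlin_cuda_2_4` is compact but easy to misread: small batches favour a finer K-split over per-block compute utilization. The Python restatement below is an illustrative sketch only; it mirrors the branch above and the derived block counts:

    def default_partition(prob_n, thread_k=-1, thread_m=-1, groupsize=-1):
        if thread_k == -1 or thread_m == -1:
            if prob_n <= 16:
                # Small batch: prefer a finer K-split (more blocks in flight)
                # over better per-block compute utilization.
                thread_k, thread_m = 128, 128
            else:
                thread_k, thread_m = 64, 256
        thread_k_blocks = thread_k // 32  # 2:4 kernel tiles K by the m16n8k32 MMA
        thread_m_blocks = thread_m // 16
        group_blocks = -1 if groupsize == -1 else groupsize // 16
        return thread_k_blocks, thread_m_blocks, group_blocks

    # e.g. a decode-style problem with prob_n = 8 picks the 128x128 partition:
    assert default_partition(8) == (4, 8, -1)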
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeId const b_q_type_id,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  // Verify num_bits
  TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
              "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  int pack_factor = 32 / b_q_type.size_bits();

  // Verify M
  TORCH_CHECK(size_m == a.size(0),
              "Shape mismatch: a.size(0) = " + str(a.size(0)) +
                  ", size_m = " + str(size_m));

  // Verify K
  TORCH_CHECK(size_k == a.size(1),
              "Shape mismatch: a.size(1) = " + str(a.size(1)) +
                  ", size_k = " + str(size_k));
  TORCH_CHECK(size_k % marlin_24::tile_size == 0,
              "size_k = " + str(size_k) + " is not divisible by tile_size = " +
                  str(marlin_24::tile_size));
  TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = " +
                  str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
                  ", tile_size = " + str(marlin_24::tile_size));

  // Verify N
  TORCH_CHECK(b_scales.size(1) == size_n,
              "b_scales.size(1) = " + str(b_scales.size(1)) +
                  ", size_n = " + str(size_n));
  TORCH_CHECK(
      b_q_weight.size(1) % marlin_24::tile_size == 0,
      "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
          " is not divisible by tile_size = " + str(marlin_24::tile_size));

  int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor;
  TORCH_CHECK(
      size_n == actual_size_n,
      "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n));

  // Verify meta
  TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2,
              "b_meta.size(0) = ", b_meta.size(0),
              " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2);
  TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1),
              " is not size_n * 2 = ", size_n * 2);

  // Verify A device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
  TORCH_CHECK(a.dtype() == torch::kFloat16,
              "A is not float16, currently only float16 is supported");

  // Verify B device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  // Verify b_meta device and strides
  TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU");
  TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous");

  // Verify scales device and strides
  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
  TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
              "b_scales is not float16, currently only float16 is supported");

  // Alloc C matrix
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);

  int thread_k = -1;
  int thread_m = -1;
  int sms = -1;
  int max_par = marlin_24::max_par;

  int groupsize = -1;
  if (b_scales.size(0) > 1) {
    TORCH_CHECK(size_k % b_scales.size(0) == 0,
                "size_k = " + str(size_k) +
                    ", is not divisible by b_scales.size(0) = " +
                    str(b_scales.size(0)));
    groupsize = size_k / b_scales.size(0);
    groupsize /= 2;  // Because of the 2:4 compression along K
  }

  // Verify groupsize
  TORCH_CHECK(groupsize == -1 || groupsize == 64,
              "Unexpected groupsize = " + str(groupsize));

  // Verify workspace size
  TORCH_CHECK(size_n % marlin_24::min_thread_n == 0,
              "size_n = " + str(size_n) +
                  ", is not divisible by min_thread_n = " +
                  str(marlin_24::min_thread_n));
  int min_workspace_size =
      (size_n / marlin_24::min_thread_n) * marlin_24::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = " + str(workspace.numel()) +
                  " is below min_workspace_size = " + str(min_workspace_size));

  int dev = a.get_device();
  marlin_24::marlin_cuda_2_4(
      a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(),
      b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(),
      b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev),
      thread_k, thread_m, sms, max_par);

  return c;
}
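To make the binding's contract concrete, a minimal call sequence from Python follows. This is an illustrative sketch, not part of the diff: it assumes the extension is importable as `quantization` (as in the test file below), that `marlin_24_quantize` returns `(w_ref, q_w_comp, meta, s)`, and that `MarlinWorkspace` exposes a `.scratch` tensor, matching the vLLM test utilities these helpers derive from.

    import torch
    import quantization  # the compiled extension under test
    from quantization.scalar_type import scalar_types
    from quantization.utils.marlin_utils import (
        GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
    from quantization.utils.marlin_utils_test import MarlinWorkspace
    from quantization.utils.marlin_utils_test_24 import marlin_24_quantize

    m, k, n = 16, 2048, 2048
    a = torch.randn((m, k), dtype=torch.float16, device="cuda")
    b = torch.randn((k, n), dtype=torch.float16, device="cuda")

    # Prune b to 2:4 sparsity, quantize to uint4b8 (channelwise, group_size=-1),
    # and repack into the compressed weight, metadata, and scale tensors.
    w_ref, q_w_comp, meta, s = marlin_24_quantize(b, scalar_types.uint4b8, -1)
    workspace = MarlinWorkspace(n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                GPTQ_MARLIN_24_MAX_PARALLEL)

    c = quantization.gptq_marlin_24_gemm(a, q_w_comp, meta, s,
                                         workspace.scratch,
                                         scalar_types.uint4b8.id, m, n, k)

    # The kernel output should track the dequantized reference closely.
    c_ref = torch.matmul(a, w_ref)
    print((c - c_ref).abs().mean() / c_ref.abs().mean())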
    	
tests/kernels/test_marlin_gemm.py
    ADDED
@@ -0,0 +1,733 @@
"""Tests for the marlin kernel.

Run `pytest tests/kernels/test_marlin_gemm.py`.
"""

import pytest
import torch

import quantization

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

from quantization.utils.marlin_utils import (
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    MARLIN_SUPPORTED_GROUP_SIZES,
    MARLIN_QQQ_MAX_PARALLEL,
    MARLIN_QQQ_MIN_THREAD_N,
    MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
    MARLIN_QQQ_SUPPORTED_NUM_BITS,
    marlin_make_empty_g_idx,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
from quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32,
)
from quantization.utils.quant_utils import (
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from quantization.scalar_type import scalar_types

from quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
from quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize,
)
from quantization.utils.marlin_utils_test_qqq import (  # noqa: E501
    marlin_qqq_quantize,
)


# Avoid torch._dynamo.exc.Unsupported: cache_size_limit reached
torch._dynamo.config.cache_size_limit = 128


capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
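# Encodes (major, minor) as a single int, e.g. (8, 6) -> 86; the skipif
# guards below require SM 8.0 (Ampere) or newer.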


ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]

MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]


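# Relative error metric: mean absolute difference normalized by the mean
# magnitude of the reference output.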
def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
    m_factor, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
    )

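    # opcheck exercises the op registration (schema, fake/meta kernels)
    # before the real kernel runs below.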
    opcheck(
        quantization._ops.ops.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = quantization.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize
    w_ref, q_w, s, zp = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm
    )

    opcheck(
        quantization._ops.ops.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = quantization.awq_marlin_repack(
        q_w_awq,
        size_k,
        size_n,
        quant_type.size_bits,
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, act_order
    )

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)

    workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

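    # The positional tuple mirrors the keyword call below: the trailing args
    # are is_k_full, has_zp=False, use_fp32_reduce, is_zp_float=False.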
    opcheck(
        quantization._ops.ops.gptq_marlin_gemm,
        (
            a_input,
            marlin_q_w,
            marlin_s,
            marlin_zp,
            g_idx,
            sort_indices,
            workspace.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
            is_k_full,
            False,
            use_fp32_reduce,
            False,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = quantization.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        has_zp=False,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


# TODO: find better way to test this?
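# Compiling with fullgraph=True also verifies the op can be traced by
# torch.compile without graph breaks.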
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    return quantization.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

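    # marlin_24_quantize prunes the weight to 2:4 structured sparsity and
    # returns a dense reference plus the compressed weight, metadata and scales.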
    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(
        quantization._ops.ops.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    dtype,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

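    # scaled_fp8_quant with scale=None quantizes dynamically, returning the
    # fp8 weight together with a per-tensor scale.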
    # WEIGHTS
    fp8_weight, weight_scale = quantization.scaled_fp8_quant(b_weight, scale=None)
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
    # Repack weights to marlin format
    marlin_qweight = quantization.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device="cuda"),
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales, size_k=size_k, size_n=size_n, group_size=-1
    )

    workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    opcheck(
        quantization._ops.ops.fp8_marlin_gemm,
        (
            a_input,
            marlin_qweight,
            marlin_scales,
            workspace.scratch,
            num_bits,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
    )

    output = quantization.fp8_marlin_gemm(
        a=a_input,
        b_q_weight=marlin_qweight,
        b_scales=marlin_scales,
        workspace=workspace.scratch,
        num_bits=num_bits,
        size_m=a_input.shape[0],
        size_n=b_weight.shape[1],
        size_k=a_input.shape[1],
    )
    output_ref = torch.matmul(a_input, b_weight)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
        b_weight, quant_type, group_size
    )

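    # AWQ has no activation reordering, so g_idx/sort_indices stay empty,
    # but the zero-points are real this time (has_zp=True).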
    g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
    is_k_full = True
    has_zp = True

    workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    output = quantization.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        has_zp=has_zp,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = quantization.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
        dev
    )
    marlin_s = marlin_permute_scales(
        scale.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)
    marlin_zp = marlin_permute_scales(
        zero.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    output = quantization.gptq_marlin_gemm(
        a_input,
        marlin_w_q,
        marlin_s,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        has_zp=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

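    # Reference path: dequantize each group as (q - zp) * s using the float
    # zero-points (is_zp_float=True above), then matmul against the input.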
    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


@pytest.mark.skipif(
    capability < 80,
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

| 670 | 
         
            +
                # Quantize activations
         
     | 
| 671 | 
         
            +
                s_a = (
         
     | 
| 672 | 
         
            +
                    a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(torch.float)
         
     | 
| 673 | 
         
            +
                )
         
     | 
| 674 | 
         
            +
                q_a = (a_input / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
         
     | 
| 675 | 
         
            +
             
     | 
| 676 | 
         
            +
                # Quantize weights
         
     | 
| 677 | 
         
            +
                w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = (
         
     | 
| 678 | 
         
            +
                    marlin_qqq_quantize(b_weight, num_bits, group_size)
         
     | 
| 679 | 
         
            +
                )
         
     | 
| 680 | 
         
            +
             
     | 
| 681 | 
         
            +
                workspace = MarlinWorkspace(
         
     | 
| 682 | 
         
            +
                    size_n, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_MAX_PARALLEL
         
     | 
| 683 | 
         
            +
                )
         
     | 
| 684 | 
         
            +
             
     | 
| 685 | 
         
            +
                opcheck(
         
     | 
| 686 | 
         
            +
                    quantization._ops.ops.marlin_qqq_gemm,
         
     | 
| 687 | 
         
            +
                    (
         
     | 
| 688 | 
         
            +
                        q_a,
         
     | 
| 689 | 
         
            +
                        marlin_qqq_q_w,
         
     | 
| 690 | 
         
            +
                        s_a,
         
     | 
| 691 | 
         
            +
                        marlin_qqq_s_channel,
         
     | 
| 692 | 
         
            +
                        marlin_qqq_s_group,
         
     | 
| 693 | 
         
            +
                        workspace.scratch,
         
     | 
| 694 | 
         
            +
                        a_input.shape[0],
         
     | 
| 695 | 
         
            +
                        b_weight.shape[1],
         
     | 
| 696 | 
         
            +
                        a_input.shape[1],
         
     | 
| 697 | 
         
            +
                    ),
         
     | 
| 698 | 
         
            +
                )
         
     | 
| 699 | 
         
            +
             
     | 
| 700 | 
         
            +
                output = quantization.marlin_qqq_gemm(
         
     | 
| 701 | 
         
            +
                    q_a,
         
     | 
| 702 | 
         
            +
                    marlin_qqq_q_w,
         
     | 
| 703 | 
         
            +
                    s_a,
         
     | 
| 704 | 
         
            +
                    marlin_qqq_s_channel,
         
     | 
| 705 | 
         
            +
                    marlin_qqq_s_group,
         
     | 
| 706 | 
         
            +
                    workspace.scratch,
         
     | 
| 707 | 
         
            +
                    a_input.shape[0],
         
     | 
| 708 | 
         
            +
                    b_weight.shape[1],
         
     | 
| 709 | 
         
            +
                    a_input.shape[1],
         
     | 
| 710 | 
         
            +
                )
         
     | 
| 711 | 
         
            +
                output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
         
     | 
| 712 | 
         
            +
             
     | 
| 713 | 
         
            +
                torch.cuda.synchronize()
         
     | 
| 714 | 
         
            +
             
     | 
| 715 | 
         
            +
                max_diff = compute_max_diff(output, output_ref)
         
     | 
| 716 | 
         
            +
             
     | 
| 717 | 
         
            +
                assert max_diff < 0.04
         
     | 
| 718 | 
         
            +
             
     | 
| 719 | 
         
            +
             
     | 
def test_marlin_gemm_opcheck():
    size_m = 2048
    size_n = 4096
    size_k = 4096
    a = torch.rand((size_m, size_n), device="cuda", dtype=torch.float16)
    w = torch.randint(-5, 5, (256, 8192), device="cuda", dtype=torch.int32)
    s = torch.full((32, size_k), 0.125, device="cuda", dtype=torch.float16)
    wk = MarlinWorkspace(
        size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    ).scratch
    x = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    y = quantization._ops.ops.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
    torch.testing.assert_close(x, y)
    opcheck(quantization._ops.ops.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
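The shapes here look odd because `w` and `s` are already in Marlin's packed layout rather than a plain k-by-n matrix. Assuming the standard Marlin 4-bit packing (16x16 tiles, eight 4-bit values per int32) and group size 128, the arithmetic works out as follows:

# Shape arithmetic for the packed operands above, assuming standard
# Marlin 4-bit packing (16x16 tiles, 8 nibbles per int32, group size 128).
size_k, size_n, group_size, pack_factor = 4096, 4096, 128, 8
packed_w_shape = (size_k // 16, size_n * 16 // pack_factor)  # (256, 8192)
scales_shape = (size_k // group_size, size_n)                # (32, 4096)
assert packed_w_shape == (256, 8192)
assert scales_shape == (32, 4096)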
         
        tests/kernels/utils.py
    CHANGED
    
@@ -4,13 +4,20 @@ import itertools
 import random
 import unittest
 from numbers import Number
-from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
-                    Union)
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 
 import pytest
 import torch
 from torch._prims_common import TensorLikeType
 
+# For now, disable "test_aot_dispatch_dynamic" since there are some
+# bugs related to this test in PyTorch 2.4.
+DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+)
+
 ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
     "test_schema",
     "test_autograd_registration",
@@ -18,6 +25,7 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
     "test_aot_dispatch_dynamic",
 )
 
+
 # Copied/modified from torch._refs.__init__.py
 def fp8_allclose(
     a: TensorLikeType,
@@ -29,34 +37,37 @@ def fp8_allclose(
     """
     Reference implementation of torch.allclose
     """
-    torch._refs._check_close_args(name="torch.allclose",
-                                  a=a,
-                                  b=b,
-                                  rtol=rtol,
-                                  atol=atol)
+    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
 
     return bool(
         torch.all(
-            torch.isclose(a.double(),
-                          b.double(),
-                          rtol=rtol,
-                          atol=atol,
-                          equal_nan=equal_nan)).item())
+            torch.isclose(
+                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+            )
+        ).item()
+    )
+
 
 # A special version of op check that has a restricted default set of test_utils
 # and a patched version of allclose that supports fp8 types.
-def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
-                      torch._library.custom_ops.CustomOpDef],
-            args: Tuple[Any, ...],
-            kwargs: Optional[Dict[str, Any]] = None,
-            *,
-            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
-            raise_exception: bool = True,
-            cond: bool = True) -> Dict[str, str]:
-    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
-        return torch.library.opcheck(
-            op,
-            args,
-            kwargs,
-            test_utils=test_utils,
-            raise_exception=raise_exception) if cond else {}
+def opcheck(
+    op: Union[
+        torch._ops.OpOverload,
+        torch._ops.OpOverloadPacket,
+        torch._library.custom_ops.CustomOpDef,
+    ],
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+    raise_exception: bool = True,
+    cond: bool = True
+) -> Dict[str, str]:
+    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+        return (
+            torch.library.opcheck(
+                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+            )
+            if cond
+            else {}
+        )
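The `double()` casts in `fp8_allclose` exist because fp8 dtypes have very limited operator support in PyTorch, so values are widened before `torch.isclose` compares them. A hypothetical call to the wrapper, showing the `cond` escape hatch and the restricted test set; the tensor and size arguments are assumed to come from a test like `test_marlin_gemm_opcheck` above:

# Hypothetical usage sketch of the opcheck wrapper defined above; `a`, `w`,
# `s`, `wk` and the sizes are assumed from the tests earlier in this change.
# When `cond` is False the check is skipped and an empty dict is returned.
results = opcheck(
    quantization._ops.ops.marlin_gemm,
    (a, w, s, wk, size_m, size_n, size_k),
    test_utils=DEFAULT_OPCHECK_TEST_UTILS,  # schema/autograd/faketensor only
    cond=(capability >= 80),
)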