danieldk committed
Commit be26e8c · Parent: 5f7337a

Add Triton scaled matmul kernel from vLLM
README.md CHANGED
@@ -1,3 +1,9 @@
 ---
 license: apache-2.0
+tags:
+- kernel
 ---
+
+## triton-scaled-mm
+
+Triton scaled matrix multiplication kernel [from vLLM](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py).
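
As a quick orientation, a minimal usage sketch assuming a CUDA device, int8 inputs, and per-tensor scales; the shapes and output dtype here are illustrative and mirror the tests added below.

```python
import torch
from triton_scaled_mm import triton_scaled_mm

M, K, N = 16, 128, 256
a = torch.randint(-32, 32, (M, K), dtype=torch.int8, device="cuda")  # quantized input, [M, K]
b = torch.randint(-32, 32, (K, N), dtype=torch.int8, device="cuda")  # quantized weight, [K, N]
scale_a = torch.rand((1, 1), device="cuda")  # per-tensor scale for a (or [M, 1] for per-row)
scale_b = torch.rand((1, 1), device="cuda")  # per-tensor scale for b (or [N, 1] for per-column)

# out ≈ (scale_a * (a @ b).float()) * scale_b.T, cast to the requested output dtype
out = triton_scaled_mm(a, b, scale_a, scale_b, torch.float16)
print(out.shape)  # torch.Size([16, 256])
```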
tests/__init__.py ADDED
File without changes
tests/test_triton_scaled_mm.py ADDED
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the triton_scaled_mm kernel.
+
+Run `pytest tests/test_triton_scaled_mm.py`.
+"""
+from typing import Optional
+
+import pytest
+import torch
+
+from triton_scaled_mm import triton_scaled_mm
+
+device = "cuda"
+
+
+def scaled_mm_torch(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+    out_dtype: type[torch.dtype],
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+    out = scale_a * out
+    out = scale_b.T * out
+    out = out.to(out_dtype)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
+def get_8bit_types():
+    types = [torch.int8]
+    major, minor = torch.cuda.get_device_capability()
+    capability = major * 10 + minor
+    supports_fp8 = capability >= 89
+
+    if torch.version.hip is not None:
+        types.append(torch.float8_e4m3fnuz)
+    elif torch.version.cuda is not None and torch.cuda.is_available() and supports_fp8:
+        types.append(torch.float8_e4m3fn)
+    return types
+
+
+@pytest.mark.parametrize("M", [1, 33, 64, 512])
+@pytest.mark.parametrize("N", [256, 971, 20486])
+@pytest.mark.parametrize("K", [128, 496, 1024])
+@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("in_dtype", get_8bit_types())
+@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
+@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_scaled_mm(
+    M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
+):
+    is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()
+
+    torch.manual_seed(0)
+
+    # NOTE: There are cases, where if the matrix is large enough, an output
+    # like 65504.4 can be produced, and can easily turn into inf when
+    # multiplied when using float16/bfloat16. This means one function, e.g.,
+    # testing function, and another function, e.g. golden function, can
+    # produce a non-inf value while the other produces an inf value, and
+    # will cause assert_close/allclose to fail, even though if overflow
+    # wouldn't have occurred, the values would have been "close."
+    #
+    # So, the values here are kept small enough to avoid this situation.
+    if is_floating_point_type(in_dtype):
+        a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
+        b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
+    else:
+        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
+        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
+
+    if use_scalar_scale_a:
+        scale_a = torch.rand((1, 1), device=device)
+    else:
+        scale_a = 0.25 * torch.rand((M, 1), device=device)
+
+    if use_scalar_scale_b:
+        scale_b = torch.rand((1, 1), device=device)
+    else:
+        scale_b = 0.25 * torch.rand((N, 1), device=device)
+
+    bias = None
+    if use_bias:
+        bias = torch.rand((N,), device=device, dtype=out_dtype)
+
+    c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+    a_cpu = a.cpu()
+    b_cpu = b.cpu()
+    scale_a_cpu = scale_a.cpu()
+    scale_b_cpu = scale_b.cpu()
+    bias_cpu = None if bias is None else bias.cpu()
+
+    c_actual = scaled_mm_torch(
+        a_cpu, b_cpu, scale_a_cpu, scale_b_cpu, out_dtype, bias_cpu
+    )
+
+    c_check_cpu = c_check.cpu()
+    torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
torch-ext/triton_scaled_mm/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .triton_scaled_mm import triton_scaled_mm
+
+__all__ = ["triton_scaled_mm"]
torch-ext/triton_scaled_mm/triton_scaled_mm.py ADDED
@@ -0,0 +1,205 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Type
+
+import torch
+import triton
+import triton.language as tl
+
+
+def is_weak_contiguous(x: torch.Tensor):
+    strides = x.stride()
+    sizes = x.shape
+    is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+    is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+    return is_transpose or is_not_transpose
+
+
+@triton.jit
+def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
+                     M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
+                     stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
+                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+                     BLOCK_SIZE_K: tl.constexpr,
+                     BLOCK_SIZE_SCALE_A: tl.constexpr,
+                     BLOCK_SIZE_SCALE_B: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = ACCUMULATOR_DTYPE
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
+                           dtype=accumulator_dtype)
+
+    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+    # eventually occur.
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    masks_bn = offsets_bn < N
+
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = (stride_am * offsets_am[:, None] +
+                 stride_ak * offsets_k[None, :])
+    offsets_b = (stride_bk * offsets_k[:, None] +
+                 stride_bn * offsets_bn[None, :])
+
+    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+    # appropriate offsets and masks for each case. Same goes for
+    # BLOCK_SIZE_SCALE_B.
+    offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
+                        (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
+    masks_scale_am = offsets_scale_am < M
+
+    offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
+                        (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
+    masks_scale_bn = offsets_scale_bn < N
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    scale_a_ptrs = scale_a_ptr + offsets_scale_am
+    scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scale at end.
+    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+    # Need to broadcast to the appropriate size, if scale_a is already
+    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+    # for scale_b below.
+    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+    accumulator = scale_a * accumulator.to(tl.float32)
+
+    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+    accumulator = scale_b.T * accumulator.to(tl.float32)
+
+    # Convert to output format.
+    c = accumulator.to(c_ptr.type.element_ty)
+
+    # Add bias, it's already in output format, so add it after conversion.
+    if bias_ptr:
+        offsets_bias = offsets_bn
+        bias_ptrs = bias_ptr + offsets_bias
+        bias_mask = offsets_bias < N
+        bias = tl.load(bias_ptrs, bias_mask)
+        c += bias
+
+    # Save output
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+    offs_cm = offs_cm.to(tl.int64)
+    offs_cn = offs_cn.to(tl.int64)
+    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
+              stride_cn * offs_cn[None, :])
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# input  - [M, K]
+# weight - [K, N]
+def triton_scaled_mm(input: torch.Tensor,
+                     weight: torch.Tensor,
+                     scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: Type[torch.dtype],
+                     bias: Optional[torch.Tensor] = None,
+                     block_size_m: int = 32,
+                     block_size_n: int = 32,
+                     block_size_k: int = 32,
+                     use_heuristic=True) -> torch.Tensor:
+    M, K = input.shape
+    N = weight.shape[1]
+
+    assert N > 0 and K > 0 and M > 0
+    assert weight.shape[0] == K
+    assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+    assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
+        [M, 1])
+    assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
+        [N, 1])
+    assert out_dtype.is_floating_point
+    assert bias is None or bias.is_floating_point()
+    assert is_weak_contiguous(input)
+    assert is_weak_contiguous(weight)
+
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
+        N, META['BLOCK_SIZE_N']), )
+
+    result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
+    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+    block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+    # A = input, B = weight, C = result
+    # A = M x K, B = K x N, C = M x N
+    scaled_mm_kernel[grid](input,
+                           weight,
+                           scale_a,
+                           scale_b,
+                           result,
+                           bias,
+                           M,
+                           N,
+                           K,
+                           input.stride(0),
+                           input.stride(1),
+                           weight.stride(0),
+                           weight.stride(1),
+                           result.stride(0),
+                           result.stride(1),
+                           accumulator_dtype,
+                           BLOCK_SIZE_M=block_size_m,
+                           BLOCK_SIZE_N=block_size_n,
+                           BLOCK_SIZE_K=block_size_k,
+                           BLOCK_SIZE_SCALE_A=block_size_sa,
+                           BLOCK_SIZE_SCALE_B=block_size_sb)
+
+    return result.to(out_dtype)
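
The assertions in `triton_scaled_mm` pin down the accepted scale layouts: `scale_a` is either `[1, 1]` (per-tensor) or `[M, 1]` (per-row), and `scale_b` is either `[1, 1]` or `[N, 1]` (per-column); the tests pass `bias` with shape `[N]` in the output dtype. A short sketch exercising the per-row/per-column path, with illustrative shapes and values:

```python
import torch
from triton_scaled_mm import triton_scaled_mm

M, K, N = 64, 512, 1024
a = torch.randint(-32, 32, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-32, 32, (K, N), dtype=torch.int8, device="cuda")

scale_a = torch.rand((M, 1), device="cuda")  # per-row scale for the activations
scale_b = torch.rand((N, 1), device="cuda")  # per-column scale for the weights
bias = torch.rand((N,), dtype=torch.bfloat16, device="cuda")

out = triton_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, bias)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```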