danieldk (HF Staff) committed
Commit 9f4b2fb · 1 Parent(s): be26e8c

Fake build

build/torch-noarch/triton_scaled_mm/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .triton_scaled_mm import triton_scaled_mm
+
+ __all__ = ["triton_scaled_mm"]
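(The `__init__.py` above just re-exports the kernel wrapper, so callers import it from the package root. A minimal import sketch, assuming the built package is importable under the name `triton_scaled_mm`:

    from triton_scaled_mm import triton_scaled_mm
)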
build/torch-noarch/triton_scaled_mm/triton_scaled_mm.py ADDED
@@ -0,0 +1,205 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Optional, Type
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ def is_weak_contiguous(x: torch.Tensor):
+     strides = x.stride()
+     sizes = x.shape
+     is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0]))
+     is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1]))
+     return is_transpose or is_not_transpose
+
+
+ @triton.jit
+ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
+                      M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
+                      stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
+                      BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+                      BLOCK_SIZE_K: tl.constexpr,
+                      BLOCK_SIZE_SCALE_A: tl.constexpr,
+                      BLOCK_SIZE_SCALE_B: tl.constexpr):
+     pid = tl.program_id(axis=0)
+
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+     pid_m = pid // num_pid_n
+     pid_n = pid % num_pid_n
+
+     accumulator_dtype = ACCUMULATOR_DTYPE
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
+                            dtype=accumulator_dtype)
+
+     # NOTE: Some tensor inputs are so large, they will cause int32 overflow
+     # so it is necessary to use tl.int64 for all the offsets, else SEGV will
+     # eventually occur.
+
+     # Offsets and masks.
+     offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     masks_am = offsets_am < M
+
+     offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+     masks_bn = offsets_bn < N
+
+     offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+     offsets_a = (stride_am * offsets_am[:, None] +
+                  stride_ak * offsets_k[None, :])
+     offsets_b = (stride_bk * offsets_k[:, None] +
+                  stride_bn * offsets_bn[None, :])
+
+     # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
+     # appropriate offsets and masks for each case. Same goes for
+     # BLOCK_SIZE_SCALE_B.
+     offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
+                         (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
+     masks_scale_am = offsets_scale_am < M
+
+     offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
+                         (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
+     masks_scale_bn = offsets_scale_bn < N
+
+     a_ptrs = a_ptr + offsets_a
+     b_ptrs = b_ptr + offsets_b
+
+     scale_a_ptrs = scale_a_ptr + offsets_scale_am
+     scale_b_ptrs = scale_b_ptr + offsets_scale_bn
+
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         masks_k = offsets_k < K
+         masks_a = masks_am[:, None] & masks_k[None, :]
+         a = tl.load(a_ptrs, mask=masks_a)
+
+         masks_b = masks_k[:, None] & masks_bn[None, :]
+         b = tl.load(b_ptrs, mask=masks_b)
+
+         # Accumulate results.
+         accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+         offsets_k += BLOCK_SIZE_K
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     # Apply scale at end.
+     masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
+     scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
+     # Need to broadcast to the appropriate size, if scale_a is already
+     # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
+     # for scale_b below.
+     scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
+     accumulator = scale_a * accumulator.to(tl.float32)
+
+     masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
+     scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
+     scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
+     accumulator = scale_b.T * accumulator.to(tl.float32)
+
+     # Convert to output format.
+     c = accumulator.to(c_ptr.type.element_ty)
+
+     # Add bias, it's already in output format, so add it after conversion.
+     if bias_ptr:
+         offsets_bias = offsets_bn
+         bias_ptrs = bias_ptr + offsets_bias
+         bias_mask = offsets_bias < N
+         bias = tl.load(bias_ptrs, bias_mask)
+         c += bias
+
+     # Save output
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
+     offs_cm = offs_cm.to(tl.int64)
+     offs_cn = offs_cn.to(tl.int64)
+     c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
+               stride_cn * offs_cn[None, :])
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ # input - [M, K]
+ # weight - [K, N]
+ def triton_scaled_mm(input: torch.Tensor,
+                      weight: torch.Tensor,
+                      scale_a: torch.Tensor,
+                      scale_b: torch.Tensor,
+                      out_dtype: Type[torch.dtype],
+                      bias: Optional[torch.Tensor] = None,
+                      block_size_m: int = 32,
+                      block_size_n: int = 32,
+                      block_size_k: int = 32,
+                      use_heuristic=True) -> torch.Tensor:
+     M, K = input.shape
+     N = weight.shape[1]
+
+     assert N > 0 and K > 0 and M > 0
+     assert weight.shape[0] == K
+     assert input.dtype == weight.dtype
+
+     scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+     scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
+     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
+     assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
+         [M, 1])
+     assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
+         [N, 1])
+     assert out_dtype.is_floating_point
+     assert bias is None or bias.is_floating_point()
+     assert is_weak_contiguous(input)
+     assert is_weak_contiguous(weight)
+
+     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
+         N, META['BLOCK_SIZE_N']), )
+
+     result = torch.empty((M, N), dtype=out_dtype, device=input.device)
+
+     has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
+
+     if use_heuristic:
+         is_small_N = N < 8192
+         next_power_of_2_M = max(32, triton.next_power_of_2(M))
+         if next_power_of_2_M <= 32:
+             tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+         elif next_power_of_2_M <= 64:
+             tile_shape = (64, 64, 256)
+         elif next_power_of_2_M <= 128:
+             tile_shape = (64, 128, 128)
+         else:
+             tile_shape = (128, 128, 128)
+
+         block_size_m, block_size_n, block_size_k = tile_shape
+
+     block_size_sa = 1 if has_scalar(scale_a) else block_size_m
+     block_size_sb = 1 if has_scalar(scale_b) else block_size_n
+
+     accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32
+
+     # A = input, B = weight, C = result
+     # A = M x K, B = K x N, C = M x N
+     scaled_mm_kernel[grid](input,
+                            weight,
+                            scale_a,
+                            scale_b,
+                            result,
+                            bias,
+                            M,
+                            N,
+                            K,
+                            input.stride(0),
+                            input.stride(1),
+                            weight.stride(0),
+                            weight.stride(1),
+                            result.stride(0),
+                            result.stride(1),
+                            accumulator_dtype,
+                            BLOCK_SIZE_M=block_size_m,
+                            BLOCK_SIZE_N=block_size_n,
+                            BLOCK_SIZE_K=block_size_k,
+                            BLOCK_SIZE_SCALE_A=block_size_sa,
+                            BLOCK_SIZE_SCALE_B=block_size_sb)
+
+     return result.to(out_dtype)
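
(The second file adds a Triton GEMM kernel plus the `triton_scaled_mm` wrapper: it multiplies `input` ([M, K]) by `weight` ([K, N]), applies per-tensor or per-row/per-column scales, optionally adds a bias, and writes the result in `out_dtype`. A minimal call sketch, not part of the commit; the int8 inputs, float32 scales, shapes, and CUDA device below are illustrative assumptions:

    import torch
    from triton_scaled_mm import triton_scaled_mm  # assumes the package above is importable

    M, K, N = 16, 64, 32
    a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")  # input, [M, K]
    b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")  # weight, [K, N]
    scale_a = torch.rand(M, 1, dtype=torch.float32, device="cuda")  # per-row scale ([M, 1] or [1, 1])
    scale_b = torch.rand(1, 1, dtype=torch.float32, device="cuda")  # per-tensor scale ([1, 1] or [N, 1])

    out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    print(out.shape)  # torch.Size([16, 32])
)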