Spaces:
Build error
Build error
File size: 11,411 Bytes
5a29263 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable
#include "types.comp"
#include "dequant_funcs_cm2.comp"
// Workgroup size: the x dimension comes from specialization constant 0,
// y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Br: number of Q rows handled per workgroup (tile height of Q, S and O in main).
layout (constant_id = 1) const uint32_t Br = 32;
// Bc: number of K/V rows consumed per loop iteration in main (tile width of S).
layout (constant_id = 2) const uint32_t Bc = 32;
// D: second dimension of the Q, K, V and O tiles (used as the column count of
// every tensor layout in main).
layout (constant_id = 3) const uint32_t D = 32;
// Clamp: clamp mode applied to the K and V tensor layouts. main() treats any
// value other than gl_CooperativeMatrixClampModeConstantNV as the "aligned"
// shader variant and any non-zero value as "tiles may contain padding".
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
// Push constants supplied by the host. The descriptions below state only how
// main() in this file consumes each member.
layout (push_constant) uniform parameter {
// Number of Q rows (first dimension of the Q tensor layout).
uint32_t N;
// Number of K/V rows (first dimension of the K and V tensor layouts).
uint32_t KV;
// ne1, ne2: dimensions of the output tensor layout (ne2 x ne1 x D before the
// permuting view); also used to compute the per-iq3 output offset.
uint32_t ne1;
uint32_t ne2;
// Not read by the code in this file.
uint32_t ne3;
// neq?/nek?/nev?: used only as ratios (neq2/nek2, neq3/nek3, neq2/nev2,
// neq3/nev3) to map the Q indices iq2/iq3 onto K and V indices (broadcast).
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
// First dimension of the mask tensor layout.
uint32_t nem1;
// nb01/nb11/nb21: row strides of Q/K/V, already divided by the type size
// (units of elements, per the comment in main).
// nb02,nb03 / nb12,nb13 / nb22,nb23: multiplied by the dimension-2 and
// dimension-3 indices to form the Q/K/V load offsets.
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
// Not read by the code in this file.
uint32_t nb31;
// Multiplied into Q (after conversion to float16) before Q*K^T.
float scale;
// ALiBi: a value > 0 enables the per-head slope computation.
float max_bias;
// Non-zero enables S = logit_softcap * tanh(S).
float logit_softcap;
// Non-zero enables loading the mask and adding slope * mask to S.
uint32_t mask;
// n_head_log2, m0, m1: inputs to the ALiBi slope (base m0 for heads below
// n_head_log2, base m1 otherwise).
uint32_t n_head_log2;
float m0;
float m1;
} p;
// Input tensors are bound as raw byte arrays; main() reads them through
// coopMatLoadTensorNV with tensor layouts rather than indexing them directly.
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
// Mask; only read when p.mask != 0.
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
// Output, written once at the end of main() via coopMatStoreTensorNV.
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
// Integer division of a by b, rounded up. Arguments are fully parenthesised
// but evaluated more than once, so avoid passing expressions with side effects.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Combine callback for coopMatReduceNV: keeps the larger of the two operands.
// main() uses it to compute the per-row maximum of the score tile.
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    const ACC_TYPE larger = max(x, y);
    return larger;
}
// Combine callback for coopMatReduceNV that ignores its second operand and
// passes the first one through unchanged. main() uses it to "resize" a
// row-constant Br x Bc matrix into a Br x D one (smear/reduce).
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    const ACC_TYPE kept = x;
    return kept;
}
// Per-element callback: yields 'replace' for any element whose row index is
// >= numRows or whose column index is >= numCols, and the original element
// otherwise.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    const bool isPadding = (row >= numRows) || (col >= numCols);
    return isPadding ? replace : elem;
}
// Per-element callback for coopMatPerElementNV: e^elem.
// The row/col coordinates are part of the callback signature but are not used.
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) {
    const ACC_TYPE result = exp(elem);
    return result;
}
// Per-element callback for coopMatPerElementNV: the larger of two
// corresponding elements. The row/col coordinates are part of the callback
// signature but are not used.
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) {
    const ACC_TYPE larger = max(elem0, elem1);
    return larger;
}
// DECODEFUNC is pasted at the end of the coopMatLoadTensorNV argument lists
// for K and V in main(). When BLOCK_SIZE is defined it appends DEQUANTFUNC as
// an extra trailing argument; otherwise it expands to nothing.
// NOTE(review): DEQUANTFUNC is not defined in this file - presumably it comes
// from dequant_funcs_cm2.comp; confirm.
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Flash-attention kernel entry point.
//
// Each workgroup handles one Br-row tile of Q (tile index gl_WorkGroupID.x)
// for the dimension-2 index gl_WorkGroupID.y and dimension-3 index
// gl_WorkGroupID.z. It walks over K/V in tiles of Bc rows, maintaining the
// online-softmax state (running row max M, running row sum L, unnormalised
// output O), and finally stores O / L to data_o.
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    // Number of Bc-row tiles needed to cover all KV rows.
    const uint32_t Tc = CEIL_DIV(KV, Bc);

    const uint32_t i = gl_WorkGroupID.x;
    const uint32_t iq2 = gl_WorkGroupID.y;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // broadcast factors
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;

    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // k indices
    const uint32_t ik3 = iq3 / rk3;
    const uint32_t ik2 = iq2 / rk2;

    // v indices
    const uint32_t iv3 = iq3 / rv3;
    const uint32_t iv2 = iq2 / rv2;

    // Q always uses the constant clamp mode; K and V use the mode selected by
    // the Clamp specialization constant.
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View with the two dimensions swapped, used to load K as K^T.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if defined(BLOCK_SIZE)
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

    // nb?1 are already divided by the type size and are in units of elements
    uint32_t q_stride = p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;
    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if !defined(BLOCK_SIZE)
        k_stride &= ~7;
        v_stride &= ~7;
#endif
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    // Load this workgroup's Q tile, convert it to a float16 A-matrix and
    // pre-multiply by the softmax scale so Q*K^T is already scaled.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // Online-softmax state. L (running row sum) and M (running row max) are
    // kept as Br x Bc matrices with the per-row value repeated along each row.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    // Use -FLT_MAX/2 rather than -inf as the initial running max. With -inf,
    // a row whose current tile is entirely -inf (e.g. fully masked) yields
    // M == -inf, and both S - M and Mold - M become (-inf) - (-inf) = NaN,
    // which then propagates into L and O for that row. A large finite value
    // keeps M finite, so exp(S - M) is 0 and exp(Mold - M) is 1 in that case.
    // NOTE(review): if ACC_TYPE is float16_t this constant presumably
    // overflows to -inf on conversion, giving the previous behaviour - confirm.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFFu);

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(ACC_TYPE(NEG_FLT_MAX_OVER_2));

    ACC_TYPE slope = ACC_TYPE(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        const uint32_t h = iq2;

        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
        const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

        slope = pow(base, ACC_TYPE(exph));
    }

    // Loop-invariant values, computed once instead of on every iteration.
    const uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
    const uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;

    // Row padding depends only on the Q tile index. R is the number of valid
    // rows in this workgroup's tile.
    const bool     rowsPadded = (i + 1) * Br > N;
    const uint32_t R = rowsPadded ? (N % Br) : Br;

    [[dont_unroll]]
    for (uint32_t j = 0; j < Tc; ++j) {

        // S = (scale * Q) * K^T for the current K tile.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;

        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        // NOTE(review): the softcap multiplies tanh(S) without first dividing
        // S by logit_softcap - presumably the host folds 1/logit_softcap into
        // p.scale; confirm against the caller.
        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        if (p.mask != 0) {
            // The mask is read at buffer offset 0 for every iq2/iq3, as a
            // float16 (nem1 x KV) tensor, and added scaled by the ALiBi slope.
            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Whether this tile extends past the valid rows/columns, and the
        // number of valid columns C. Evaluated once and shared by both
        // padding fix-ups below. Padding is only handled when Clamp != 0.
        const bool     colsPadded = (j + 1) * Bc > KV;
        const bool     hasPadding = Clamp != 0 && (colsPadded || rowsPadded);
        const uint32_t C = colsPadded ? (KV % Bc) : Bc;

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (hasPadding) {
            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (hasPadding) {
            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;

        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);

        // Rescale the old row sum to the new max and add this tile's sum.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // O = e^(Mold - M) * O + P * V
        O = eMdiag * O;
        O = coopMatMulAdd(P_A, V, O);
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Final softmax normalisation: O = O / L, done as a multiply by 1/L.
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);

    // permute dimensions
    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
    uint32_t o_offset = iq3*p.ne2*p.ne1;

    // Convert the accumulator to the output element type and store the tile.
    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
}
|