Spaces:
Build error
Build error
// Workgroup size: X is taken from specialization constant 0; Y and Z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Tile shape, overridable at pipeline creation through specialization constants.
layout (constant_id = 1) const uint32_t Br = 32;  // rows of Q processed by one workgroup
layout (constant_id = 2) const uint32_t Bc = 32;  // rows of K/V consumed per loop iteration
layout (constant_id = 3) const uint32_t D = 32;   // column count of the Q, K, V and output tiles

// Clamp mode used for the K and V tensor layouts. main() treats any value other than
// gl_CooperativeMatrixClampModeConstantNV as the stride-aligned variant of the shader.
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
// Push constants supplied by the host for one dispatch.
// Field notes describe how main() in this file uses each value; fields that main()
// does not read are marked as such.
layout (push_constant) uniform parameter {
    uint32_t N;   // row count of the Q tensor layout
    uint32_t KV;  // row count of the K and V tensor layouts

    // Output dimensions; ne1 and ne2 size the output tensor layout.
    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;  // not read in main()

    // Extents used to derive the K/V broadcast factors (neq? / nek?, neq? / nev?).
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;

    uint32_t nem1;  // first dimension of the mask tensor layout

    // Q strides: nb01 is the row stride, nb02/nb03 scale the workgroup y/z indices.
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    // K strides: nb11 is the row stride, nb12/nb13 scale the broadcast-adjusted indices.
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    // V strides: nb21 is the row stride, nb22/nb23 scale the broadcast-adjusted indices.
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;

    uint32_t nb31;  // not read in main(); presumably the mask row stride - TODO confirm

    float scale;          // multiplied into Q before the Q*K^T product
    float max_bias;       // > 0 enables the ALiBi slope computation
    float logit_softcap;  // != 0 enables the tanh soft cap on the scores

    uint32_t mask;         // != 0 means the mask buffer is added to the scores
    uint32_t n_head_log2;  // head index threshold selecting m0 vs. m1 for ALiBi
    float m0;              // ALiBi base for heads below n_head_log2
    float m1;              // ALiBi base for the remaining heads
} p;
// Storage buffers. Q, K, V and the mask are bound as raw bytes and are only ever read
// through coopMatLoadTensorNV in main(); the output is written as D_TYPE elements.
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
// Combine function for coopMatReduceNV: returns the larger of the two operands.
// main() uses it to compute the per-row maximum of the score tile.
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return max(x, y);
}
// Combine function for coopMatReduceNV that returns its first operand and ignores the
// second. main() applies it to matrices whose rows already hold one repeated value, so
// the row reduction only serves to change the matrix width (Bc columns -> D columns).
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return x;
}
// Per-element callback for coopMatPerElementNV.
// Returns 'replace' for any element whose row index is >= numRows or whose column index
// is >= numCols (i.e. an element in the padding region), and 'elem' unchanged otherwise.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    if (row >= numRows || col >= numCols) {
        return replace;
    }
    return elem;
}
// Per-element callback for coopMatPerElementNV: returns e^elem.
// The row/col arguments are part of the callback signature and are not used.
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    return exp(elem);
}
// Per-element callback for coopMatPerElementNV: returns the larger of elem0 and elem1.
// The row/col arguments are part of the callback signature and are not used.
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    return max(elem0, elem1);
}
// Flash-attention entry point.
//
// Each workgroup loads one Br x D tile of Q (tile index = gl_WorkGroupID.x, with
// gl_WorkGroupID.y/z selecting the iq2/iq3 slice), then walks the K/V tensors in tiles
// of Bc rows. Per tile it computes S = (scale*Q) * K^T, optionally applies the soft cap
// and the (slope-scaled) mask, and folds the tile into a running row maximum M, running
// row sum L and unnormalized output O. After the loop O is divided by L and stored.
void main() {
    init_iq_shmem(gl_WorkGroupSize);

    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    // Number of Bc-row K/V tiles needed to cover KV rows.
    const uint32_t Tc = CEIL_DIV(KV, Bc);

    const uint32_t i = gl_WorkGroupID.x;
    const uint32_t iq2 = gl_WorkGroupID.y;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // broadcast factors
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;

    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // k indices
    const uint32_t ik3 = iq3 / rk3;
    const uint32_t ik2 = iq2 / rk2;

    // v indices
    const uint32_t iv3 = iq3 / rv3;
    const uint32_t iv2 = iq2 / rv2;

    // Q always uses the constant clamp mode; K and V use the specialization-constant mode.
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View that swaps the two dimensions, used to load K as K^T.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

    // nb?1 are already divided by the type size and are in units of elements
    uint32_t q_stride = p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;

    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
        k_stride &= ~7;
        v_stride &= ~7;
    }

    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    // Load this workgroup's Q tile, convert it to an fp16 A-operand and pre-apply the scale.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // Running state: O = unnormalized output, L = row sums, M = row maxima (starts at -inf).
    // L and M are Br x Bc matrices whose rows each hold a single repeated value.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);

    ACC_TYPE slope = ACC_TYPE(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        const uint32_t h = iq2;

        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
        const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

        slope = pow(base, ACC_TYPE(exph));
    }

    // Number of valid rows in this workgroup's tile; does not depend on the K/V tile index.
    const uint validRows = ((i + 1) * Br > N) ? (N % Br) : Br;

    [[dont_unroll]]
    for (uint32_t j = 0; j < Tc; ++j) {

        // S = (scale * Q) * K^T for K rows [j*Bc, j*Bc + Bc).
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        if (p.mask != 0) {
            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // The tile needs padding fix-ups when it extends past row N or column KV.
        // The Clamp != 0 guard skips the fix-ups for the variant built with clamp mode 0.
        const bool tileHasPadding = Clamp != 0 &&
                                    ((j + 1) * Bc > KV ||
                                     (i + 1) * Br > N);
        const uint validCols = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (tileHasPadding) {
            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), validRows, validCols);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (tileHasPadding) {
            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), validRows, validCols);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);

        // Rescale the previous row sums to the new maximum, then add this tile's sums.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // Rescale the previous output to the new maximum, then accumulate P * V.
        O = eMdiag * O;

        O = coopMatMulAdd(P_A, V, O);
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Normalize: O = O / L, done as an element-wise multiply by 1/L.
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);

    // permute dimensions
    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
    uint32_t o_offset = iq3*p.ne2*p.ne1;

    // Convert to the output element type and store rows [i*Br, i*Br + Br) of slice iq2.
    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
}