#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable

#include "types.comp"
#include "dequant_funcs_cm2.comp"

// Attention kernel built on cooperative matrices.
//
// Each workgroup handles one Br-row tile of Q (tile index gl_WorkGroupID.x) for
// one (iq2, iq3) pair, and walks over K/V in tiles of Bc rows. It keeps a running
// row maximum (M) and running row sum (L) so the softmax can be normalized once
// after the loop instead of requiring the full score matrix.

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Tile sizes and row width, supplied as specialization constants.
layout (constant_id = 1) const uint32_t Br = 32;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
// Clamp mode used for the K and V tensor layouts.
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;

layout (push_constant) uniform parameter {
    // Number of Q rows and number of K/V rows.
    uint32_t N;
    uint32_t KV;

    // Output dimensions (ne1, ne2 are used to lay out the result; ne3 is not read here).
    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    // Dimension 2/3 extents of Q, K and V; their ratios give the broadcast factors.
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    // Number of rows in the mask.
    uint32_t nem1;

    // Strides for Q (nb0?), K (nb1?), V (nb2?) and the mask (nb31).
    // nb?1 are already divided by the type size and are in units of elements.
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;
    uint32_t nb31;

    float scale;
    float max_bias;
    float logit_softcap;

    // Non-zero when a mask buffer should be applied.
    uint32_t mask;
    // Parameters of the ALiBi slope computation.
    uint32_t n_head_log2;
    float m0;
    float m1;
} p;

layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Reduction callback: keeps the larger of the two values.
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return max(x, y);
}

// Reduction callback that always keeps its first argument. Used with
// coopMatReduceNV to "resize" a matrix whose rows hold a single repeated value.
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return x;
}

// Replace matrix elements >= numRows or numCols with 'replace'
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    if (row >= numRows || col >= numCols) {
        return replace;
    }
    return elem;
}

// Per-element callback: e^elem (row/col are unused).
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    return exp(elem);
}

// Per-element callback: max of the two matrices' elements (row/col are unused).
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    return max(elem0, elem1);
}

// When K/V are block-quantized, loads pass an extra decode callback argument.
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif

// NOTE(review): the coopmat<...> template arguments below were reconstructed.
// Row/column counts follow from the coopMatMulAdd shapes and the tile slices;
// component types follow from ACC_TYPE/D_TYPE/float16_t usage. The use
// qualifiers, gl_ScopeWorkgroup, and the Q_TYPE macro (assumed to be provided by
// the shader generator alongside D_TYPE and ACC_TYPE) should be confirmed.
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    // Number of Bc-row tiles needed to cover all KV rows.
    const uint32_t Tc = CEIL_DIV(KV, Bc);

    const uint32_t i = gl_WorkGroupID.x;
    const uint32_t iq2 = gl_WorkGroupID.y;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // broadcast factors
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;

    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // k indices
    const uint32_t ik3 = iq3 / rk3;
    const uint32_t ik2 = iq2 / rk2;

    // v indices
    const uint32_t iv3 = iq3 / rv3;
    const uint32_t iv2 = iq2 / rv2;

    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View that swaps the two dimensions, used to load K tiles transposed.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if defined(BLOCK_SIZE)
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

    // nb?1 are already divided by the type size and are in units of elements
    uint32_t q_stride = p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;
    // hint to the compiler that strides are aligned for the aligned variant of the shader
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if !defined(BLOCK_SIZE)
        k_stride &= ~7;
        v_stride &= ~7;
#endif
    }
    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    // Load this workgroup's Br x D tile of Q, convert it to fp16 and pre-apply the scale.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // Unnormalized output accumulator.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);

    // L: running row sum of exponentials, M: running row max (starts at -inf).
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);

    ACC_TYPE slope = ACC_TYPE(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        const uint32_t h = iq2;

        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
        const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

        slope = pow(base, ACC_TYPE(exph));
    }

    // Base offsets of the K and V matrices for this workgroup; they do not
    // depend on the tile index, so compute them once.
    const uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
    const uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;

    // Number of valid rows in this workgroup's Q tile (constant across the loop).
    const bool pad_rows = (i + 1) * Br > N;
    const uint R = pad_rows ? (N % Br) : Br;

    [[dont_unroll]]
    for (uint32_t j = 0; j < Tc; ++j) {

        // S = (scale * Q) * K^T for the current tile.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;

        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        // Soft-capping: S = softcap * tanh(S).
        // NOTE(review): presumably the host folds 1/logit_softcap into p.scale - confirm.
        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        // Add the (slope-scaled) mask tile.
        if (p.mask != 0) {
            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // Padding is only patched up when Clamp != 0 and the tile overhangs N or KV.
        // NOTE(review): Clamp == 0 is presumably the aligned shader variant where
        // tiles never overhang - confirm.
        const bool pad_cols = (j + 1) * Bc > KV;
        const bool clear_padding = Clamp != 0 && (pad_cols || pad_rows);
        const uint C = pad_cols ? (KV % Bc) : Bc;

        // Clear padding elements to -inf, so they don't contribute to rowmax
        if (clear_padding) {
            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M = max(rowmax, Mold)
        // P = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (clear_padding) {
            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // compute rowsum by multiplying by matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;

        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);

        // Rescale the previous running sum to the new max, then add this tile's sum.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // Rescale the previous output to the new max, then accumulate P * V.
        O = eMdiag * O;

        O = coopMatMulAdd(P_A, V, O);
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Normalize: O = O / L (done as a multiply by the reciprocal).
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);

    // permute dimensions
    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

    // Each iq3 batch of the output is a packed (ne2, ne1, D) tensor, so the batch
    // stride in elements is ne2*ne1*D. Omitting D makes batches with iq3 > 0
    // overlap the output of the previous batch.
    uint32_t o_offset = iq3*p.ne2*p.ne1*D;

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
}