#include "common.cuh" | |
#include "mma.cuh" | |
#include "fattn-common.cuh" | |
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup> | |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( | |
const float2 * const __restrict__ Q_f2, | |
const half2 * const __restrict__ K_h2, | |
const half2 * const __restrict__ V_h2, | |
const half * const __restrict__ maskh, | |
float2 * const __restrict__ dstk, | |
float2 * const __restrict__ dstk_fixup, | |
const float scale, | |
const float slope, | |
const float logit_softcap, | |
const int ne00, | |
const int ne01, | |
const int ne02, | |
const int ne03, | |
const int ne10, | |
const int ne11, | |
const int ne12, | |
const int ne13, | |
const int ne31, | |
const int nb31, | |
const int nb01, | |
const int nb02, | |
const int nb03, | |
const int nb11, | |
const int nb12, | |
const int nb13, | |
const int nb21, | |
const int nb22, | |
const int nb23, | |
const int ne0, | |
const int ne1, | |
const int ne2, | |
const int ne3, | |
const int jt, | |
const int kb0_start, | |
const int kb0_stop) { | |
#ifdef NEW_MMA_AVAILABLE
    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    typedef mma_A_I16K8<half2>  mma_A;
    typedef mma_B_J8K8<half2>   mma_B;
    typedef mma_C_I16J8<float>  mma_C_KQ;
    typedef mma_C_I16J8<half2>  mma_C_VKQ;

    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
    constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.

    static_assert(D         % nwarps == 0, "bad D");
    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");

    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
    extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.

    const int stride_Q    = nb01 / sizeof(float2);
    const int stride_KV   = nb11 / sizeof(half2);
    const int stride_mask = nb31 / sizeof(half);

    mma_B Q_B[D/(2*mma_B::K)];
    mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];

    float2 KQ_rowsum    = {0.0f, 0.0f};
    float2 KQ_max       = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
    float2 KQ_max_scale = {0.0f, 0.0f};
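
    // Running state for the online ("streaming") softmax, two Q columns per float2:
    //   m_new   = max(m_old, max_i KQ_i)                -> KQ_max
    //   scale   = exp(m_old - m_new)                    -> KQ_max_scale (flushed to 0 below SOFTMAX_FTZ_THRESHOLD)
    //   l_new   = scale*l_old + sum_i exp(KQ_i - m_new) -> KQ_rowsum
    //   acc_new = scale*acc_old + exp(KQ - m_new) * V   -> VKQ_C
    // The division by the final l (KQ_rowsum) is only applied when the tile is written out.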

    // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
    // The loading is done with decreasing granularity for D for better memory bandwidth.
    const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
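        // Example (assuming WARP_SIZE == 32 and D == 80, i.e. D/2 == 40 half2 per row):
        //   stride_k == 32: covers k = 0..31  (k0_start =  0, k0_stop = 32)
        //   stride_k == 16: covers nothing    (k0_start = 32, k0_stop = 32)
        //   stride_k ==  8: covers k = 32..39 (k0_start = 32, k0_stop = 40)
        // The bulk is loaded with full-warp strides; only the tail uses smaller strides.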
        const int stride_j = WARP_SIZE / stride_k;

        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
            break;
        }

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

            if (jt*ncols + j < ne01) {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
                    tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                }
            } else {
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
                }
            }
        }
    }

    __syncthreads();

    {
        const int j0 = (threadIdx.y / np) * mma_B::J;

#pragma unroll
        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
        }
    }

    __syncthreads();
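
    // Q now lives entirely in registers (Q_B) and is reused for every K/V tile below,
    // so tile_KV is free to be overwritten with K and V data.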

    // Iterate over ne11 == previous tokens:
    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
        const int k_VKQ_0 = kb0*KQ_stride;
        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];

        // Load K data into tile with decreasing granularity for D for better memory bandwidth:
        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
            const int stride_i = WARP_SIZE / stride_k;

#pragma unroll
            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);

#pragma unroll
                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
                }
            }
        }

        __syncthreads();

        // Calculate tile of KQ:
#pragma unroll
        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
                mma_A K_A;
                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
            }
        }

        __syncthreads();

        if (use_logit_softcap) {
            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
                for (int l = 0; l < mma_C_KQ::ne; ++l) {
                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
                }
            }
        }

        if (maskh) {
            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
            static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
                for (int l = 0; l < mma_C_KQ::ne; ++l) {
                    const int i = i0 + mma_C_KQ::get_i(l);
                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);

                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
                }
            }
        }

        // Calculate softmax for each KQ column using the current max. value.
        // The divisor is stored in KQ_rowsum and will be applied at the end.
        float2 KQ_max_new = KQ_max;
        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
            }
        }

        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
        for (int offset = 16; offset > 2; offset >>= 1) {
            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
        }

        {
            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
                KQ_max_scale.x = 0.0f;
            }
            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
                KQ_max_scale.y = 0.0f;
            }
            KQ_max = KQ_max_new;
        }

        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
            for (int l = 0; l < mma_C_KQ::ne; ++l) {
                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
                const float diff = KQ_C[k].x[l] - KQ_max_l;
                KQ_C[k].x[l] = expf(diff);
                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
                    KQ_C[k].x[l] = 0.0f;
                }

                if (l % 2 == 0) {
                    KQ_rowsum_add.x += KQ_C[k].x[l];
                } else {
                    KQ_rowsum_add.y += KQ_C[k].x[l];
                }
            }
        }

        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;

        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
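        // The VKQ accumulators were computed relative to the old maximum, so rescale them by the
        // same factor to keep them consistent with the new maximum before adding this tile's contribution: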
#pragma unroll
        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
                VKQ_C[i].x[l] *= KQ_max_scale_h2;
            }
        }

        // Convert KQ C tiles into B tiles for VKQ calculation:
        mma_B B[KQ_stride/(np*2*mma_B::K)];
        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
            B[k] = KQ_C[k].to_mma_B();
        }
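        // Note: to_mma_B also converts the FP32 softmax weights into half2 B fragments,
        // which is what the FP16 tensor core mma for the V*softmax(KQ) product below expects.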

        // Load V data into tile with decreasing granularity for D for better memory bandwidth:
        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
            const int i0_stop  =                             D/2 - (D/2) % (1*stride_i);
            const int stride_k = WARP_SIZE / stride_i;

#pragma unroll
            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);

#pragma unroll
                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);

                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
                }
            }
        }

        __syncthreads();

        // Calculate VKQ tile:
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;

                mma_A A;
                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
            }
        }

        __syncthreads();
    }

    // Finally, sum up partial KQ rowsums.
    // The partial sums are spread across 8 threads each, does not need full reduce.
#pragma unroll
    for (int offset = 16; offset > 2; offset >>= 1) {
        KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
        KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
    }

    // Write VKQ accumulators to shared memory in column-major format.
    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
    // Also for np > 1 the combination is done via these values in shared memory.
    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
#pragma unroll
    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.

#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int k = k0 + mma_B::get_k(l);

            tile_KV[j_cwd*D2_padded + k] = B.x[l];
        }
    }

    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum

    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
        // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
    }
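    // Layout reminder: each tile_KV row holds D2_padded = D/2 + 4 half2, i.e. D/2 half2 of data plus
    // 16 bytes of padding. Viewed as float2 a row has D/4 + 2 elements, so index D/4 above is the
    // first padding slot of the row.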

    __syncthreads();

    static_assert(np == 1 || np == 2 || np == 4, "bad np");
    if (np == 1) {
        // No combination is needed, the meta data can be directly written from registers to VRAM.
        if (needs_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
            dstk_fixup_meta[j_cwm] = KQ_cmr;
        }
        if (is_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
            dstk_fixup_meta[j_cwm] = KQ_cmr;
        }
    } else if (threadIdx.y % np == 0) {
        // Combine the meta data for parallel warps via shared memory.
        // Warps with threadIdx.y % np != 0 must NOT return early.
        // All threads must return simultaneously to avoid race conditions with work on the next tile.

        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;

        float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            KQ_cm = meta_j[0];
        }

        float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
        }

        const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
        float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            KQ_crs = KQ_cms*meta_j[1];
        }
#pragma unroll
        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
        }

        // Write back combined meta data:
        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
            meta_j[0] = KQ_cmn; // Combined max. KQ values.
            meta_j[1] = KQ_crs; // Combined KQ rowsums.
            meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
        }
        if (needs_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }
        if (is_fixup && threadIdx.x < mma_B::J) {
            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
        }
    }

    if (np > 1) {
        __syncthreads();
    }

    if (np == 1 || threadIdx.y % np == 0) {
        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
        // The values after that are for the partial results of the individual blocks.
        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));

#pragma unroll
        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
            const int stride_j = WARP_SIZE / stride_k;

            if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
                break;
            }

#pragma unroll
            for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
                const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
                const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;

                if (!is_fixup && jt*ncols + j_dst >= ne01) {
                    continue;
                }
                const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
#pragma unroll
                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);

                    float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
                    for (int ip = 0; ip < np; ++ip) {
                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
                        const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
                        dstk_val.x += dstk_val_add.x*KQ_crs;
                        dstk_val.y += dstk_val_add.y*KQ_crs;
                    }

                    if (!needs_fixup && !is_fixup) {
                        const float KQ_rowsum_j = meta_j[1];
                        dstk_val.x /= KQ_rowsum_j;
                        dstk_val.y /= KQ_rowsum_j;
                    }

                    if (is_fixup) {
                        dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
                    } else {
                        dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
                    }
                }
            }
        }
    }

    if (np > 1) {
        __syncthreads();
    }
#else
    NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap> | |
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) | |
__launch_bounds__(nwarps*WARP_SIZE, 2) | |
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) | |
static __global__ void flash_attn_ext_f16( | |
const char * __restrict__ Q, | |
const char * __restrict__ K, | |
const char * __restrict__ V, | |
const char * __restrict__ mask, | |
float * __restrict__ dst, | |
float2 * __restrict__ dst_meta, | |
const float scale, | |
const float max_bias, | |
const float m0, | |
const float m1, | |
const uint32_t n_head_log2, | |
const float logit_softcap, | |
const int ne00, | |
const int ne01, | |
const int ne02, | |
const int ne03, | |
const int ne10, | |
const int ne11, | |
const int ne12, | |
const int ne13, | |
const int ne31, | |
const int nb31, | |
const int nb01, | |
const int nb02, | |
const int nb03, | |
const int nb11, | |
const int nb12, | |
const int nb13, | |
const int nb21, | |
const int nb22, | |
const int nb23, | |
const int ne0, | |
const int ne1, | |
const int ne2, | |
const int ne3) { | |
// Skip unused kernel variants for faster compilation: | |
if (use_logit_softcap && !(D == 128 || D == 256)) { | |
NO_DEVICE_CODE; | |
return; | |
} | |
static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); | |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. | |
const int iter_k = ne11 / KQ_stride; | |
const int iter_j = (ne01 + (ncols - 1)) / ncols; | |
// kbc == k block continuous, current index in continuous ijk space. | |
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; | |
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; | |
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. | |
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). | |
// In the most general case >2 seams can fall into the same tile. | |
// kb0 == k start index when in the output tile. | |
int kb0_start = kbc % iter_k; | |
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); | |
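    // A block writes a tile directly to dst only if it reaches the tile's last k block (kb0_stop == iter_k);
    // if it did not also start the tile (kb0_start != 0) it sets needs_fixup so the missing contribution
    // can be merged in later. A block whose work ends in the middle of a tile leaves the loop below and
    // instead writes its partial result to the fixup buffer (is_fixup).
    // Example (illustrative numbers): with iter_k = 8, iter_j = 1, ne02 = 1 and gridDim.x = 3 the k blocks
    // are split as [0,2), [2,5), [5,8); blocks 0 and 1 end mid-tile and write fixup data, block 2 finishes
    // the tile and writes to dst with needs_fixup set.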
while (kbc < kbc_stop && kb0_stop == iter_k) { | |
const int channel = kbc / (iter_k*iter_j); | |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. | |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); | |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); | |
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape | |
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; | |
float2 * dstk = ((float2 *) dst) + channel*(D/2); | |
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); | |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. | |
if (kb0_start == 0) { | |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. | |
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup> | |
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, | |
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, | |
jt, kb0_start, kb0_stop); | |
} else { | |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. | |
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup> | |
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, | |
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, | |
jt, kb0_start, kb0_stop); | |
} | |
kbc += iter_k; | |
kbc -= kbc % iter_k; | |
kb0_start = 0; | |
kb0_stop = min(iter_k, kbc_stop - kbc); | |
} | |
if (kbc >= kbc_stop) { | |
return; | |
} | |
const int channel = kbc / (iter_k*iter_j); | |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. | |
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); | |
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); | |
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape | |
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; | |
float2 * dstk = ((float2 *) dst) + channel*(D/2); | |
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); | |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. | |
constexpr bool needs_fixup = false; | |
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup> | |
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, | |
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, | |
jt, kb0_start, kb0_stop); | |
} | |
template <int D, int cols_per_block> | |
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |
typedef mma_A_I16K8<half2> mma_A; | |
typedef mma_B_J8K8<half2> mma_B; | |
static_assert(D % mma_B::K == 0, "bad D"); | |
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); | |
const ggml_tensor * KQV = dst; | |
constexpr int KQ_stride = D <= 128 ? 64 : 32; | |
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? | |
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); | |
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); | |
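    // nbytes_shared covers the largest use of tile_KV in the kernel: max(KQ_stride, nwarps*mma_B::J) rows
    // of D2_padded = D/2 + 4 half2 each, i.e. (D + 8) halfs per row, matching the kernel's row padding.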

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    fattn_kernel_t fattn_kernel;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
    } else {
        constexpr bool use_logit_softcap = true;
        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
    }
    launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
}

#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block)                          \
    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
    <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

extern DECL_FATTN_MMA_F16_CASE( 64,  8);
extern DECL_FATTN_MMA_F16_CASE( 80,  8);
extern DECL_FATTN_MMA_F16_CASE( 96,  8);
extern DECL_FATTN_MMA_F16_CASE(112,  8);
extern DECL_FATTN_MMA_F16_CASE(128,  8);
extern DECL_FATTN_MMA_F16_CASE(256,  8);

extern DECL_FATTN_MMA_F16_CASE( 64, 16);
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
extern DECL_FATTN_MMA_F16_CASE(112, 16);
extern DECL_FATTN_MMA_F16_CASE(128, 16);
extern DECL_FATTN_MMA_F16_CASE(256, 16);

extern DECL_FATTN_MMA_F16_CASE( 64, 32);
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
extern DECL_FATTN_MMA_F16_CASE(112, 32);
extern DECL_FATTN_MMA_F16_CASE(128, 32);
extern DECL_FATTN_MMA_F16_CASE(256, 32);

extern DECL_FATTN_MMA_F16_CASE( 64, 64);
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
extern DECL_FATTN_MMA_F16_CASE(112, 64);
extern DECL_FATTN_MMA_F16_CASE(128, 64);
extern DECL_FATTN_MMA_F16_CASE(256, 64);