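// Attention path for 4-bit quantized weights: RMS norm, optional LoRA adapters,
// q/k/v projections, rotary position embeddings and KV cache updates.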
#include "q4_mlp.cuh"
#include "q4_matmul.cuh"
#include "rope.cuh"
#include "rms_norm.cuh"
#include "half_matmul.cuh"
#include "../cuda_buffers.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
const int THREADS_X = 32;
const int THREADS_Y = 1;
const int THREADS_Z = 4;
const int BLOCKSIZE_X = 2; // 2*half == 1*uint32_t
const int BLOCKSIZE_Z = 4; // num_kv_heads must be divisible by BLOCKSIZE_Z TODO: Check that this is the case when Llama2-34b releases
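
// Append the projected key/value states for the current q_len tokens to the
// key/value caches at row offset past_len. States are laid out as
// (q_len, num_kv_heads, head_dim), caches as (num_kv_heads, max_seq_len, head_dim).
// Each thread copies one uint32_t (= BLOCKSIZE_X halves) for each of BLOCKSIZE_Z
// consecutive heads.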
__global__ void update_cache_kernel
(
const half* __restrict__ key_states,
const half* __restrict__ value_states,
half* __restrict__ key_cache,
half* __restrict__ value_cache,
const int head_dim,
const int num_kv_heads,
const int q_len,
const int max_seq_len,
const int past_len
)
{
//int state_shape[] = { num_kv_heads, q_len, head_dim };
int state_stride[] = { head_dim, head_dim * num_kv_heads, 1 };
int state_pos[] = { 0, 0, 0 };
//int cache_shape[] = { num_kv_heads, max_seq_len, head_dim };
int cache_stride[] = { max_seq_len * head_dim, head_dim, 1 };
int cache_pos[] = { 0, past_len, 0 };
int size[] = { num_kv_heads, q_len, head_dim };
int x = (blockIdx.x * THREADS_X + threadIdx.x) * BLOCKSIZE_X;
int y = blockIdx.y * THREADS_Y + threadIdx.y;
int z = (blockIdx.z * THREADS_Z + threadIdx.z) * BLOCKSIZE_Z;
if (x >= size[2]) return;
if (y >= size[1]) return;
if (z >= size[0]) return;
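    // Flat element offsets for this thread's first copied element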
int state_offset = (z + state_pos[0]) * state_stride[0] + (y + state_pos[1]) * state_stride[1] + (x + state_pos[2]) * state_stride[2];
int cache_offset = (z + cache_pos[0]) * cache_stride[0] + (y + cache_pos[1]) * cache_stride[1] + (x + cache_pos[2]) * cache_stride[2];
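    // Reinterpret as uint32_t so each copy moves two halves at once; the
    // pointer strides below are divided by BLOCKSIZE_X accordingly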
    const uint32_t* key_ptr = (const uint32_t*) (key_states + state_offset);
    const uint32_t* value_ptr = (const uint32_t*) (value_states + state_offset);
uint32_t* key_cache_ptr = (uint32_t*) (key_cache + cache_offset);
uint32_t* value_cache_ptr = (uint32_t*) (value_cache + cache_offset);
#pragma unroll
for (int k = 0; k < BLOCKSIZE_Z; k++)
{
*key_cache_ptr = *key_ptr;
key_ptr += state_stride[0] / BLOCKSIZE_X;
key_cache_ptr += cache_stride[0] / BLOCKSIZE_X;
}
#pragma unroll
for (int k = 0; k < BLOCKSIZE_Z; k++)
{
*value_cache_ptr = *value_ptr;
value_ptr += state_stride[0] / BLOCKSIZE_X;
value_cache_ptr += cache_stride[0] / BLOCKSIZE_X;
}
}
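
// First half of the attention block: RMS-norm x, apply any LoRA adapters,
// project to query/key/value states, apply rotary embeddings to q and k, and
// append the new k/v rows to the cache. The attention product itself is
// computed by the caller, between q4_attn_cuda and q4_attn_2_cuda.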
void q4_attn_cuda
(
ExLlamaTuning* tuningParams,
cudaStream_t stream,
cublasHandle_t handle,
half* x,
const half* rms_norm_weight, // shape == (x.shape[1],) == (dim,)
float epsilon,
half* query_states,
half* key_states,
half* value_states,
Q4Matrix* q_proj,
Q4Matrix* k_proj,
Q4Matrix* v_proj,
half* sin,
half* cos,
const int bsz,
const int q_len,
const int dim,
const int head_dim,
const int num_heads,
const int num_kv_heads,
const int past_len,
half* key_cache,
half* value_cache,
const half* q_a,
const half* q_b,
const int q_rank,
const half* k_a,
const half* k_b,
const int k_rank,
const half* v_a,
const half* v_b,
const int v_rank,
half* lora_temp,
const int max_seq_len,
const int device_index
)
{
// Cache update grid
dim3 threads(THREADS_X, THREADS_Y, THREADS_Z);
dim3 blocks
(
((head_dim + THREADS_X - 1) / THREADS_X + BLOCKSIZE_X - 1) / BLOCKSIZE_X,
q_len,
((num_kv_heads + THREADS_Z - 1) / THREADS_Z + BLOCKSIZE_Z - 1) / BLOCKSIZE_Z
);
int _rows_per_batch = q_len * num_heads;
int _rows_per_batch_kv = q_len * num_kv_heads;
CudaBuffers* buffers = get_buffers(device_index);
    // RMS norm: normalize x into temp_x, a scratch slab q_len * dim halves into temp_state
half* temp_x = buffers->temp_state + q_len * dim;
rms_norm_cuda(tuningParams, x, rms_norm_weight, temp_x, epsilon, q_len, dim, device_index);
    // LoRA adapters: write the low-rank update (x @ a) @ b to each projection's
    // output buffer; the quantized matmuls below then accumulate on top of it
if (q_a)
{
half_matmul_cublas_cuda(tuningParams, temp_x, q_a, lora_temp, q_len, dim, q_rank, handle);
half_matmul_cublas_cuda(tuningParams, lora_temp, q_b, query_states, q_len, q_rank, dim, handle);
}
if (k_a)
{
half_matmul_cublas_cuda(tuningParams, temp_x, k_a, lora_temp, q_len, dim, k_rank, handle);
half_matmul_cublas_cuda(tuningParams, lora_temp, k_b, key_states, q_len, k_rank, dim, handle);
}
if (v_a)
{
half_matmul_cublas_cuda(tuningParams, temp_x, v_a, lora_temp, q_len, dim, v_rank, handle);
half_matmul_cublas_cuda(tuningParams, lora_temp, v_b, value_states, q_len, v_rank, dim, handle);
}
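    // Run the three projections either sequentially on the default stream, or
    // overlapped across alternate streams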
if (!tuningParams->concurrent_streams)
{
        // Project q, k, v (accumulating if a LoRA update is already in the output buffer)
q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, q_a ? true : false);
q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, k_a ? true : false);
q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, v_a ? true : false);
// Positional embeddings q, k
rope_cuda(tuningParams, query_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len);
rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch_kv, head_dim, num_kv_heads, past_len);
// Update cache tensors with projected k, v
update_cache_kernel<<<blocks, threads>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
}
else
{
// Project q, k, v, add positional embeddings to q, k, update cache tensors with projected k, v
cudaStream_t str_1 = buffers->alt_stream_1;
cudaStream_t str_2 = buffers->alt_stream_2;
cudaStream_t str_3 = buffers->alt_stream_3;
cudaEvent_t sync_1 = buffers->alt_stream_1_done;
cudaEvent_t sync_2 = buffers->alt_stream_2_done;
cudaEvent_t sync_3 = buffers->alt_stream_3_done;
// str_1: project q, positions q, sync
q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, q_a ? true : false, str_1);
        rope_cuda(tuningParams, query_states, sin, cos, bsz, _rows_per_batch, head_dim, num_heads, past_len, str_1);
cudaEventRecord(sync_1, str_1);
// str_2: project k, positions k, sync
q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, k_a ? true : false, str_2);
rope_cuda(tuningParams, key_states, sin, cos, bsz, _rows_per_batch_kv, head_dim, num_kv_heads, past_len, str_2);
cudaEventRecord(sync_2, str_2);
// str_3: project v, wait for str_2, copy (k,v) to cache, sync
        q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, v_a ? true : false, str_3);
cudaStreamWaitEvent(str_3, sync_2, 0);
update_cache_kernel<<<blocks, threads, 0, str_3>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
cudaEventRecord(sync_3, str_3);
// default: wait for str_1 and str_3
cudaStreamWaitEvent(NULL, sync_1, 0);
cudaStreamWaitEvent(NULL, sync_3, 0);
}
}
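
// Second half of the attention block: project attn_output through o_proj and
// accumulate the result into x, applying the optional o LoRA adapter first.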
void q4_attn_2_cuda
(
ExLlamaTuning* tuningParams,
cublasHandle_t handle,
half* x,
half* attn_output,
Q4Matrix* o_proj,
const int height,
const half* o_a,
const half* o_b,
const int o_rank,
half* lora_temp
)
{
if (o_a)
{
int dim = o_proj->height;
half_matmul_cublas_cuda(tuningParams, attn_output, o_a, lora_temp, height, dim, o_rank, handle);
half_matmul_cublas_cuda(tuningParams, lora_temp, o_b, x, height, o_rank, dim, handle, true);
}
q4_matmul_cuda(tuningParams, attn_output, height, o_proj, x, true);
}