// half_matmul.cu — fp16 matmul kernels: naive split-K kernel, cuBLAS wrapper,
// and a small-matrix path that avoids cuBLAS overhead for short inputs.
#include "half_matmul.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif
// Block size
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 8; // Block size and thread count along rows in x and out
const int BLOCKSIZE = 256; // Elements of dim processed per z-block (split-K slice accumulated via atomicAdd)
// Naive split-K half-precision matmul: out += x @ w.
//
// Grid layout: blockIdx.x tiles columns (two columns per thread, via half2),
// blockIdx.y tiles rows, blockIdx.z splits dim into BLOCKSIZE-element slices.
// Each z-slice accumulates its partial dot products into out with atomicAdd,
// so out must be zeroed before launch (unless accumulation is intended).
//
// Preconditions: dim and width even (half2 loads); requires SM60+ for
// atomicAdd on half2 (or the HIP compat shim when USE_ROCM is defined).
__global__ void half_matmul_kernel
(
    const half* __restrict__ x,
    const half* __restrict__ w,
    half* __restrict__ out,
    const int height,
    const int dim,
    const int width
)
{
    const int column = (blockIdx.x * THREADS_X + threadIdx.x) * 2;
    const int row = blockIdx.y * THREADS_Y + threadIdx.y;
    const int k0 = blockIdx.z * BLOCKSIZE;

    if (row >= height) return;
    if (column >= width) return;

    MatrixView_half x_(x, height, dim);
    MatrixView_half w_(w, dim, width);
    MatrixView_half_rw out_(out, height, width);

    half2* x_ptr = (half2*) x_.item_ptr(row, k0);
    half2* w_ptr = (half2*) w_.item_ptr(k0, column);
    half2 acc = {};

    // Clamp the slice to dim so the last z-block does not read past the end of
    // x and w when dim is not a multiple of BLOCKSIZE. (The previous fixed
    // BLOCKSIZE/2-iteration loop read out of bounds in that case.)
    const int k_end = min(k0 + BLOCKSIZE, dim);

    #pragma unroll
    for (int k = k0; k < k_end; k += 2)
    {
        // Two elements of x per iteration, broadcast each across a half2 lane
        // pair and FMA against two consecutive rows of w.
        half2 x_item = *x_ptr++;
        half2 x_item_0 = __half2half2(x_item.x);
        half2 x_item_1 = __half2half2(x_item.y);
        half2 w_item_0 = *w_ptr; w_ptr += w_.width / 2;
        half2 w_item_1 = *w_ptr; w_ptr += w_.width / 2;
        acc = __hfma2(x_item_0, w_item_0, acc);
        acc = __hfma2(x_item_1, w_item_1, acc);
    }

    // Combine this z-slice's partial sum with the other slices' contributions.
    atomicAdd((half2*)out_.item_ptr(row, column), acc);
}
// Launches half_matmul_kernel: out += x @ w, with dim split across blockIdx.z
// and partial sums combined by atomicAdd — out must be zeroed beforehand
// unless accumulation into its existing contents is intended.
void half_matmul_cuda
(
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    cudaStream_t alt_stream
)
{
    dim3 threads(THREADS_X, THREADS_Y, 1);

    // Each block covers THREADS_X * 2 columns (the kernel processes two
    // columns per thread), so blocks.x must be ceil(width / (THREADS_X * 2)).
    // The previous `(width + THREADS_X - 1) / THREADS_X / 2` truncated after
    // the first division: e.g. width = 96 gave 3 / 2 = 1 block instead of 2,
    // leaving columns 64..95 uncomputed.
    dim3 blocks
    (
        (width + THREADS_X * 2 - 1) / (THREADS_X * 2),
        (height + THREADS_Y - 1) / THREADS_Y,
        (dim + BLOCKSIZE - 1) / BLOCKSIZE
    );

    half_matmul_kernel<<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width);
}
// cuBLAS can't be beat for large matrices, probably
const int MAX_DIM_SMALL = 8192; // Max dim for the naive small-matmul fallback; also sizes its shared accumulator (via S_BLOCKSIZE)
// Half-precision matmul via cuBLAS: out = x @ w, or out += x @ w when no_zero
// is set (beta = 1). Very short inputs (height < 4, dim <= MAX_DIM_SMALL) are
// routed to the naive small-matmul kernel instead, where launch overhead beats
// cuBLAS setup cost. If alt_stream is non-null, the cuBLAS handle is switched
// to it for the call and restored afterwards.
void half_matmul_cublas_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    cublasHandle_t handle,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    // Small-matmul fast path to avoid cuBLAS overhead
    bool use_small_path = (height < 4) && (dim <= MAX_DIM_SMALL);
    if (use_small_path)
    {
        half_matmul_small_cuda(tuningParams, x, w, out, height, dim, width, no_zero, alt_stream);
        return;
    }

    const half alpha = __float2half(1.0f);
    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);

    // Temporarily bind the handle to the alternate stream, remembering the
    // stream it was using so it can be restored.
    cudaStream_t saved_stream;
    if (alt_stream)
    {
        cublasGetStream(handle, &saved_stream);
        cublasSetStream(handle, alt_stream);
    }

    // cuBLAS is column-major: passing the row-major operands with OP_N/OP_N
    // and swapped m/n computes out^T = w^T @ x^T, i.e. row-major out = x @ w.
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, w, width, x, dim, &beta, out, width);

    if (alt_stream)
    {
        cublasSetStream(handle, saved_stream);
    }
}
// Alternative to cuBLAS for tall or wide matrices
const int S_THREADS_X = 8; // width
const int S_THREADS_Z = 1; // height
const int S_BLOCKSIZE = MAX_DIM_SMALL / 1024 * S_THREADS_X; // dim: 64 elements per threadIdx.y slice; with dim <= MAX_DIM_SMALL this caps blockDim.y at 128 (1024 threads/block)
// Naive small matmul: out[row][column] = sum_k x[row][k] * w[k][column],
// accumulating into out when no_zero is set, overwriting it otherwise.
//
// Launch layout (see half_matmul_small_cuda): blockIdx.x tiles columns by
// S_THREADS_X, blockIdx.z selects the row, and threadIdx.y splits dim into
// S_BLOCKSIZE-element slices. Each slice computes a partial dot product; the
// partials meet in the shared accum array and threadIdx.y == 0 writes the sum.
//
// use_half2: vectorized x loads — the half2 branch consumes 4 elements of dim
//            per iteration, so the dispatcher only takes it when dim % 4 == 0.
// odd_rank:  fully scalar path (one element per iteration) for dim % 4 != 0.
template<bool use_half2, bool odd_rank>
__global__ void half_matmul_small_kernel
(
const half* __restrict__ x,
const half* __restrict__ w,
half* __restrict__ out,
const int height,
const int dim,
const int width,
bool no_zero
)
{
int column = blockIdx.x * S_THREADS_X + threadIdx.x;
int row = blockIdx.z * S_THREADS_Z + threadIdx.z;
int k = threadIdx.y * S_BLOCKSIZE; // start of this thread's slice of dim
if (row >= height) return;
// NOTE(review): threads with column >= width exit before the __syncthreads()
// below; exited threads are not counted by the barrier, which works in
// practice but is formally a divergent barrier — worth confirming on ROCm.
if (column >= width) return;
// if (k >= dim) return;
// printf("%i, %i, %i\n", row, column, k);
MatrixView_half x_(x, height, dim);
MatrixView_half w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Clamp the slice end to dim — the last y-slice may be partial.
int k_end = k + S_BLOCKSIZE;
if (k_end > dim) k_end = dim;
const half* x_ptr = x_.item_ptr(row, k);
const half* x_ptr_end = x_.item_ptr(row, k_end);
const half* w_ptr = w_.item_ptr(k, column);
half* out_ptr = out_.item_ptr(row, column);
if constexpr (use_half2 && !odd_rank)
{
half2* x_ptr2 = (half2*) x_ptr;
half2* x_ptr2_end = (half2*) x_ptr_end;
half2 r = {};
// Four elements of dim per iteration: x read as two half2, w gathered one
// scalar per row of w (stride = width) and packed to line up with x's lanes.
while(x_ptr2 < x_ptr2_end)
{
half2 x_01 = *x_ptr2++;
half2 x_23 = *x_ptr2++;
half w_0 = *w_ptr; w_ptr += width;
half w_1 = *w_ptr; w_ptr += width;
half w_2 = *w_ptr; w_ptr += width;
half w_3 = *w_ptr; w_ptr += width;
half2 w_01 = __halves2half2(w_0, w_1);
half2 w_23 = __halves2half2(w_2, w_3);
r = __hfma2(x_01, w_01, r);
r = __hfma2(x_23, w_23, r);
}
// Collapse the two half2 lanes into this slice's scalar partial sum
half rh = __hadd(r.x, r.y);
// One partial per (dim-slice, column); reduced below by threadIdx.y == 0
__shared__ half accum[MAX_DIM_SMALL / S_BLOCKSIZE][S_THREADS_X];
accum[threadIdx.y][threadIdx.x] = rh;
__syncthreads();
if (threadIdx.y == 0)
{
// rh is this thread's own partial, i.e. accum[0][threadIdx.x]
half acc = rh;
for (int i = 1; i < blockDim.y; ++i) acc = __hadd(accum[i][threadIdx.x], acc);
if (no_zero) acc = __hadd(acc, *out_ptr); // accumulate onto existing out
*out_ptr = acc;
}
}
else
{
half r = {};
while(x_ptr < x_ptr_end)
{
if constexpr (odd_rank)
{
// One scalar MAC per iteration; safe for any dim
half x_item = *x_ptr++;
half w_item = *w_ptr; w_ptr += width;
r = __hfma(x_item, w_item, r);
}
else
{
// dim % 4 == 0 on this path, so unroll four scalar MACs per iteration
#pragma unroll
for (int i = 0; i < 4; ++i)
{
half x_item = *x_ptr++;
half w_item = *w_ptr; w_ptr += width;
r = __hfma(x_item, w_item, r);
}
}
}
// Same cross-slice reduction as the half2 branch
__shared__ half accum[MAX_DIM_SMALL / S_BLOCKSIZE][S_THREADS_X];
accum[threadIdx.y][threadIdx.x] = r;
__syncthreads();
if (threadIdx.y == 0)
{
half acc = accum[0][threadIdx.x];
for (int i = 1; i < blockDim.y; ++i) acc = __hadd(accum[i][threadIdx.x], acc);
if (no_zero) acc = __hadd(acc, *out_ptr); // accumulate onto existing out
*out_ptr = acc;
}
}
}
// Launches half_matmul_small_kernel with the layout it expects:
//   threads = (S_THREADS_X columns, ceil(dim / S_BLOCKSIZE) dim-slices, 1)
//   blocks  = (ceil(width / S_THREADS_X), 1, height)
// Template selection: the scalar odd_rank path when dim is not a multiple of
// four, otherwise half2 or scalar according to tuningParams->matmul_no_half2.
// When no_zero is set the kernel accumulates onto out instead of overwriting.
void half_matmul_small_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    int dim_slices = (dim + S_BLOCKSIZE - 1) / S_BLOCKSIZE;
    int column_blocks = (width + S_THREADS_X - 1) / S_THREADS_X;

    dim3 threads(S_THREADS_X, dim_slices, 1);
    dim3 blocks(column_blocks, 1, height);

    bool odd_dim = (dim % 4) != 0;
    if (odd_dim)
    {
        half_matmul_small_kernel<false, true> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
    }
    else if (!tuningParams->matmul_no_half2)
    {
        half_matmul_small_kernel<true, false> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
    }
    else
    {
        half_matmul_small_kernel<false, false> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
    }
}