Aston-xMAD's picture
init commit
b37c16f verified
#include <cuda_fp16.h>
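// IMPORTANT: before compiling, textually replace __NSQ__ (number of subquantizers), __B__ (bits per code; 2^__B__ centroids per subquantizer) and __D__ (subvector dimension) with concrete integer values in the kernel below.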
// One-byte unsigned integer used for the stored codes. It is declared locally
// instead of via <cstdint>, presumably so the file compiles without host
// standard headers. NOTE(review): confirm against how this file is compiled.
using uint8_t = unsigned char;
// Nearest-centroid encoder for product quantization.
//
// For every vector i and subquantizer s, the subvector made of elements
// [s*d, (s+1)*d) of row i is compared against the 2^b centroids of
// subquantizer s. The comparison uses squared L2 distance, accumulated in
// float. The index of the closest centroid is written to codes[s * n + i].
// When several centroids tie, the highest index wins.
//
// Placeholders (substituted textually before compilation):
//   __NSQ__  number of subquantizers (nsq)
//   __B__    bits per code; each subquantizer has 2^__B__ centroids. Must be
//            at most 8 because codes are stored as uint8_t.
//   __D__    dimension of each subvector (d)
//
// Launch layout:
//   grid   1-D, one block per subquantizer: blockIdx.x is the subquantizer id.
//          Blocks with blockIdx.x >= __NSQ__ exit without doing anything.
//   block  1-D, any size; the block's threads stride over the n vectors.
//   dynamic shared memory: (1 << __B__) * __D__ * sizeof(half) bytes.
//
// codebook  nsq x 2^b x d, row-major (input)
// vectors   n x (nsq * d), row-major (input)
// codes     nsq x n, row-major (output)
// n         number of vectors; n <= 0 writes nothing
extern "C"
__global__ void quantize(
    const half* __restrict__ codebook, // nsq x 2^b x d
    const half* __restrict__ vectors,  // n x (nsq * d)
    uint8_t* __restrict__ codes,       // nsq x n
    int n
) {
    // Centroid indices are stored in one byte, so at most 256 centroids.
    static_assert(__B__ >= 0 && __B__ <= 8,
                  "code width must be at most 8 bits because codes are stored as uint8_t");
    constexpr int n_centroids = 1 << __B__;
    constexpr int n_halves_per_sq = n_centroids * __D__;

    // Not volatile: the __syncthreads() below already orders the cooperative
    // writes before any thread reads them, so the values may be cached.
    extern __shared__ half centroids[]; // 2^b x d

    const int sq_id = blockIdx.x;
    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;

    // blockIdx.x is uniform across the block, so every thread of an excess
    // block leaves here together and no thread is stranded at the barrier.
    if (sq_id >= __NSQ__) {
        return;
    }

    // Stage this subquantizer's centroids in shared memory.
    const half* sq_codebook = codebook + static_cast<long long>(sq_id) * n_halves_per_sq;
    for (int idx = thread_id; idx < n_halves_per_sq; idx += n_threads) {
        centroids[idx] = sq_codebook[idx];
    }
    __syncthreads();

    // The subvector is converted to float once and then reused for every centroid.
    float subvector[__D__];
    for (int i = thread_id; i < n; i += n_threads) {
        // 64-bit offset: n * nsq * d can exceed INT_MAX for large inputs.
        const half* src = vectors + (static_cast<long long>(i) * __NSQ__ + sq_id) * __D__;
        #pragma unroll
        for (int k = 0; k < __D__; ++k) {
            subvector[k] = __half2float(src[k]);
        }

        // Start from +inf so any finite distance is accepted. The previous
        // sentinel (65536) is below many reachable squared distances and
        // left min_idx unset. If every distance is NaN, index 0 is written.
        float min_dist = __int_as_float(0x7f800000);
        uint8_t min_idx = 0;
        #pragma unroll
        for (int j = 0; j < n_centroids; ++j) {
            float dist = 0.0f;
            #pragma unroll
            for (int k = 0; k < __D__; ++k) {
                const float diff = subvector[k] - __half2float(centroids[j * __D__ + k]);
                dist += diff * diff;
            }
            // '<=' keeps the original rule: on ties the later centroid wins.
            if (dist <= min_dist) {
                min_dist = dist;
                min_idx = static_cast<uint8_t>(j);
            }
        }
        codes[static_cast<long long>(sq_id) * n + i] = min_idx;
    }
}
// extern "C"
// __global__ void cq_encode(
// const half* __restrict__ codebook, // nsq x 2^b x d
// const half* __restrict__ vectors, // n x (nsq * d)
// uint8_t* __restrict__ codes, // nsq x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile float centroids[]; // 2^b x d
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// centroids[i] = __half2float(codebook[sq_id * n_floats_per_sq + i]);
// }
// __syncthreads();
// float subvector[MAX_DIM];
// for (int i = thread_id; i < n; i += n_threads) {
// for (int j = 0; j < d; ++j) {
// subvector[j] = __half2float(vectors[(i * nsq + sq_id) * d + j]);
// // subvector[j] = __half2float(vectors[sq_id * d + j]);
// }
// float min_dist = 16384;
// uint8_t min_idx;
// for (int j = 0; j < (1 << b); ++j) {
// float dist = 0;
// for (int k = 0; k < d; ++k) {
// dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
// }
// min_dist = (dist <= min_dist) ? dist : min_dist;
// min_idx = (dist == min_dist) ? j : min_idx;
// }
// // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
// codes[sq_id * n + i] = min_idx;
// }
// }
// extern "C"
// __global__ void cq_encode(
// const half* __restrict__ codebook, // nsq x 2^b x d
// const half* __restrict__ vectors, // n x (nsq * d)
// uint8_t* __restrict__ codes, // nsq x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile half centroids[]; // 2^b x d
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// centroids[i] = codebook[sq_id * n_floats_per_sq + i];
// }
// __syncthreads();
// for (int i = thread_id; i < n; i += n_threads) {
// half subvector[MAX_DIM];
// for (int j = 0; j < d; ++j) {
// subvector[j] = vectors[(i * nsq + sq_id) * d + j];
// }
// half min_dist = 16384;
// int min_idx = -1;
// for (int j = 0; j < (1 << b); ++j) {
// half dist = 0;
// for (int k = 0; k < d; ++k) {
// dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
// }
// min_dist = (dist <= min_dist) ? dist : min_dist;
// min_idx = (dist == min_dist) ? j : min_idx;
// }
// codes[sq_id * n + i] = min_idx;
// }
// }
// extern "C"
// __global__ void cq_decode(
// const float* __restrict__ codebook, // nsq x 2^b x d
// const uint8_t* __restrict__ codes, // nsq x n
// float* __restrict__ result, // (nsq x d) x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile float centroids[];
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// // printf("sq_id %d n_floats_per_sq %d i %d\n", sq_id, n_floats_per_sq, i);
// centroids[i] = codebook[sq_id * n_floats_per_sq + i];
// // printf("%d: %f\n", i, centroids[i]);
// }
// __syncthreads();
// for (int i = thread_id; i < n; i += n_threads) {
// uint8_t code = codes[sq_id * n + i];
// for (int dim = 0; dim < d; ++dim) {
// result[(sq_id * d + dim) * n + i] = centroids[d * code + dim];
// // result[dim] = centroids[d * code + dim];
// }
// }
// }