#include <cuda_fp16.h>

// IMPORTANT: replace __NSQ__ __B__ __D__ with actual values for kernel
//   __NSQ__ : number of subquantizers (sub-vectors per input vector)
//   __B__   : bits per code; each subquantizer has 2^__B__ centroids.
//             Codes are stored as uint8_t, so __B__ must not exceed 8.
//   __D__   : dimensionality of each sub-vector / centroid
typedef unsigned char uint8_t;

// Assigns every input vector's sub-vector to its nearest centroid (squared
// Euclidean distance, accumulated in float) for one subquantizer.
//
// Launch layout:
//   * blockIdx.x selects the subquantizer handled by the block
//     (presumably the grid has __NSQ__ blocks -- confirm against the caller).
//   * The threads of a block stride over the n vectors in steps of blockDim.x.
//   * Dynamic shared memory must hold (1 << __B__) * __D__ half values
//     (the centroids of one subquantizer).
//
// Parameters:
//   codebook : nsq x 2^b x d centroids
//   vectors  : n x (nsq * d) input vectors
//   codes    : nsq x n output; codes[sq * n + i] is the centroid index chosen
//              for vector i in subquantizer sq
//   n        : number of input vectors
//
// When several centroids are at the same minimal distance, the one with the
// highest index is chosen. If no distance compares as <= +infinity (i.e. all
// distances are NaN), code 0 is written.
extern "C"
__global__ void quantize(
    const half* __restrict__ codebook,  // nsq x 2^b x d
    const half* __restrict__ vectors,   // n x (nsq * d)
    uint8_t* __restrict__ codes,        // nsq x n
    int n
) {
    extern __shared__ volatile half centroids[];  // 2^b x d

    const int sq_id = blockIdx.x;
    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_centroids = 1 << __B__;
    const int n_halfs_per_sq = n_centroids * __D__;

    // Offsets into the global arrays are computed in 64 bits so that large
    // inputs do not overflow 32-bit int arithmetic.
    const long long codebook_base = (long long)sq_id * n_halfs_per_sq;

    // Cooperatively stage this subquantizer's centroids in shared memory.
    #pragma unroll
    for (int i = thread_id; i < n_halfs_per_sq; i += n_threads) {
        centroids[i] = codebook[codebook_base + i];
    }
    // All centroids must be visible before any thread starts searching.
    __syncthreads();

    // Bit pattern of +infinity: the first non-NaN distance always replaces it.
    const float inf_dist = __int_as_float(0x7f800000);

    half subvector[__D__];
    for (int i = thread_id; i < n; i += n_threads) {
        // Start of the sq_id-th sub-vector of input vector i.
        const long long vec_base = ((long long)i * __NSQ__ + sq_id) * __D__;
        #pragma unroll
        for (int j = 0; j < __D__; ++j) {
            subvector[j] = vectors[vec_base + j];
        }

        float min_dist = inf_dist;
        uint8_t min_idx = 0;  // defined result even if nothing is selected
        #pragma unroll
        for (int j = 0; j < n_centroids; ++j) {
            float dist = 0.0f;
            #pragma unroll
            for (int k = 0; k < __D__; ++k) {
                const float diff = __half2float(subvector[k])
                                 - __half2float(centroids[j * __D__ + k]);
                dist += diff * diff;
            }
            // "<=" keeps the original tie-breaking: the last minimum wins.
            if (dist <= min_dist) {
                min_dist = dist;
                min_idx = (uint8_t)j;
            }
        }
        // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
        codes[(long long)sq_id * n + i] = min_idx;
    }
}

// ---------------------------------------------------------------------------
// Disabled kernel variants (commented out in the original file).
// ---------------------------------------------------------------------------

// extern "C"
// __global__ void cq_encode(
//     const half* __restrict__ codebook, // nsq x 2^b x d
//     const half* __restrict__ vectors, // n x (nsq * d)
//     uint8_t* __restrict__ codes, // nsq x n
//     int nsq, int b, int d, int n
// ) {
//     extern __shared__ volatile float centroids[]; // 2^b x d
//     const int sq_id = blockIdx.x;
//     const int thread_id = threadIdx.x;
//     const int n_threads = blockDim.x;
//     const int n_floats_per_sq = (1 << b) * d;
//     for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
//         centroids[i] = __half2float(codebook[sq_id * n_floats_per_sq + i]);
//     }
//     __syncthreads();
//     float subvector[MAX_DIM];
//     for (int i = thread_id; i < n; i += n_threads) {
//         for (int j = 0; j < d; ++j) {
//             subvector[j] = __half2float(vectors[(i * nsq + sq_id) * d + j]);
//             // subvector[j] = __half2float(vectors[sq_id * d + j]);
//         }
//         float min_dist = 16384;
//         uint8_t min_idx;
//         for (int j = 0; j < (1 << b); ++j) {
//             float dist = 0;
//             for (int k = 0; k < d; ++k) {
//                 dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
//             }
//             min_dist = (dist <= min_dist) ? dist : min_dist;
//             min_idx = (dist == min_dist) ? j : min_idx;
//         }
//         // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
//         codes[sq_id * n + i] = min_idx;
//     }
// }

// extern "C"
// __global__ void cq_encode(
//     const half* __restrict__ codebook, // nsq x 2^b x d
//     const half* __restrict__ vectors, // n x (nsq * d)
//     uint8_t* __restrict__ codes, // nsq x n
//     int nsq, int b, int d, int n
// ) {
//     extern __shared__ volatile half centroids[]; // 2^b x d
//     const int sq_id = blockIdx.x;
//     const int thread_id = threadIdx.x;
//     const int n_threads = blockDim.x;
//     const int n_floats_per_sq = (1 << b) * d;
//     for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
//         centroids[i] = codebook[sq_id * n_floats_per_sq + i];
//     }
//     __syncthreads();
//     for (int i = thread_id; i < n; i += n_threads) {
//         half subvector[MAX_DIM];
//         for (int j = 0; j < d; ++j) {
//             subvector[j] = vectors[(i * nsq + sq_id) * d + j];
//         }
//         half min_dist = 16384;
//         int min_idx = -1;
//         for (int j = 0; j < (1 << b); ++j) {
//             half dist = 0;
//             for (int k = 0; k < d; ++k) {
//                 dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
//             }
//             min_dist = (dist <= min_dist) ? dist : min_dist;
//             min_idx = (dist == min_dist) ? j : min_idx;
//         }
//         codes[sq_id * n + i] = min_idx;
//     }
// }

// extern "C"
// __global__ void cq_decode(
//     const float* __restrict__ codebook, // nsq x 2^b x d
//     const uint8_t* __restrict__ codes, // nsq x n
//     float* __restrict__ result, // (nsq x d) x n
//     int nsq, int b, int d, int n
// ) {
//     extern __shared__ volatile float centroids[];
//     const int sq_id = blockIdx.x;
//     const int thread_id = threadIdx.x;
//     const int n_threads = blockDim.x;
//     const int n_floats_per_sq = (1 << b) * d;
//     for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
//         // printf("sq_id %d n_floats_per_sq %d i %d\n", sq_id, n_floats_per_sq, i);
//         centroids[i] = codebook[sq_id * n_floats_per_sq + i];
//         // printf("%d: %f\n", i, centroids[i]);
//     }
//     __syncthreads();
//     for (int i = thread_id; i < n; i += n_threads) {
//         uint8_t code = codes[sq_id * n + i];
//         for (int dim = 0; dim < d; ++dim) {
//             result[(sq_id * d + dim) * n + i] = centroids[d * code + dim];
//             // result[dim] = centroids[d * code + dim];
//         }
//     }
// }