|
|
|
|
|
// 8-bit code type used by the kernels below (one code per sub-vector).
// Defined locally rather than via <cstdint>; presumably this source is
// compiled without standard headers (e.g. runtime compilation) —
// NOTE(review): confirm with the host-side build before replacing.
typedef unsigned char uint8_t;
|
|
|
// extern "C" |
|
// __global__ void quantize( |
|
// const half* __restrict__ codebook, // nsq x 2^b x d |
|
// const half* __restrict__ vectors, // n x (nsq * d) |
|
// uint8_t* __restrict__ codes, // nsq x n |
|
// int n |
|
// ) { |
|
// extern __shared__ volatile half centroids[]; // 2^b x d |
|
|
|
// const int sq_id = blockIdx.x; |
|
// const int thread_id = threadIdx.x; |
|
// const int n_threads = blockDim.x; |
|
|
|
// const int n_floats_per_sq = (1 << __B__) * __D__; |
|
// |
|
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) { |
|
// centroids[i] = codebook[sq_id * n_floats_per_sq + i]; |
|
// } |
|
// __syncthreads(); |
|
|
|
// half subvector[__D__]; |
|
// for (int i = thread_id; i < n; i += n_threads) { |
|
// |
|
// for (int j = 0; j < __D__; ++j) { |
|
// subvector[j] = vectors[(i * __NSQ__ + sq_id) * __D__ + j]; |
|
// } |
|
// float min_dist = 1 << 16; |
|
// uint8_t min_idx; |
|
// |
|
// for (int j = 0; j < (1 << __B__); ++j) { |
|
// float dist = 0; |
|
// |
|
// for (int k = 0; k < __D__; ++k) { |
|
// float tmp = __half2float(subvector[k]) - __half2float(centroids[j * __D__ + k]); |
|
// dist += tmp * tmp; |
|
// } |
|
// min_dist = (dist <= min_dist) ? dist : min_dist; |
|
// min_idx = (dist == min_dist) ? j : min_idx; |
|
// } |
|
// // printf("%d %d %d %d\n", sq_id, n, i, min_idx); |
|
// codes[sq_id * n + i] = min_idx; |
|
// } |
|
// } |
|
|
|
// Reconstruct vectors from product-quantization codes.
//
// Layouts (as implied by the indexing below):
//   codebook : __NSQ__ x 2^__B__ x __D__ halfs. Block `sq_id` uses the slice
//              of 2^__B__ * __D__ halfs starting at sq_id * 2^__B__ * __D__.
//   codes    : __NSQ__ x n bytes. codes[sq_id * n + i] selects the centroid
//              for sub-vector `sq_id` of vector `i`.
//   vectors  : n x (__NSQ__ * __D__) halfs (output). Sub-vector `sq_id` of
//              vector `i` occupies __D__ consecutive halfs at
//              (i * __NSQ__ + sq_id) * __D__.
//
// Launch contract:
//   - one block per sub-quantizer: gridDim.x is expected to be __NSQ__
//     (surplus blocks exit immediately);
//   - any blockDim.x >= 1; each thread handles vectors
//     threadIdx.x, threadIdx.x + blockDim.x, ...;
//   - dynamic shared memory of (1 << __B__) * __D__ * sizeof(half) bytes
//     (third launch argument).
// __B__, __D__ and __NSQ__ are compile-time placeholders, presumably
// substituted into the source text before compilation — NOTE(review): confirm.
// Precondition: every code is < (1 << __B__); larger values would read past
// the shared-memory codebook.
extern "C"
__global__ void dequantize(
    const half* __restrict__ codebook,  // nsq x 2^b x d
    const uint8_t* __restrict__ codes,  // nsq x n
    half* __restrict__ vectors,         // n x (nsq x d)
    int n
) {
    // This block's codebook slice, staged in shared memory (2^b x d).
    // `volatile` is not needed for correctness (the __syncthreads() below
    // orders the writes before the reads). It is kept so this declaration
    // stays identical to the one in the commented-out quantize kernel.
    // Two extern __shared__ arrays with the same name but different types
    // in one translation unit fail to compile.
    extern __shared__ volatile half centroids[];

    const int sq_id = blockIdx.x;
    if (sq_id >= __NSQ__) {
        // Uniform across the whole block, so no thread skips the barrier
        // below while others wait on it.
        return;
    }
    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;

    const int n_halfs_per_sq = (1 << __B__) * __D__;

    // Stage this sub-quantizer's centroids into shared memory.
    const half* sq_codebook = codebook + (long long)sq_id * n_halfs_per_sq;
    for (int i = thread_id; i < n_halfs_per_sq; i += n_threads) {
        centroids[i] = sq_codebook[i];
    }
    __syncthreads();

    // 64-bit offsets: sq_id * n and n * __NSQ__ * __D__ can exceed INT_MAX
    // for large n, which with int arithmetic would index out of bounds.
    const uint8_t* sq_codes = codes + (long long)sq_id * n;
    for (long long i = thread_id; i < n; i += n_threads) {
        const int centroid_base = __D__ * (int)sq_codes[i];
        half* out = vectors + (i * __NSQ__ + sq_id) * __D__;
        for (int dim = 0; dim < __D__; ++dim) {
            out[dim] = centroids[centroid_base + dim];
        }
    }
}