File size: 5,243 Bytes
b37c16f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
#include <cuda_fp16.h>
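// IMPORTANT: before compiling, textually replace the placeholders __NSQ__, __B__ and __D__
// in this file with their actual integer values; the kernel uses them as compile-time constants.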
typedef unsigned char uint8_t;
// Assigns each input vector's sub-vector to its nearest codebook centroid,
// using squared Euclidean distance accumulated in float.
//
// Work mapping (as written in the body):
//   * blockIdx.x selects the subquantizer (sq_id); each block handles exactly
//     one subquantizer, so the launch is expected to use one block per
//     subquantizer.
//   * the threads of a block stride over the n vectors by blockDim.x.
//
// Dynamic shared memory: the kernel stores (1 << __B__) * __D__ half values in
// `centroids`, so at least that many bytes * sizeof(half) must be passed as the
// dynamic shared-memory size at launch.
//
// Parameters:
//   codebook  centroids, indexed as [sq_id][centroid][dim]   (nsq x 2^b x d)
//   vectors   input vectors, indexed as [i][sq_id][dim]      (n x (nsq * d))
//   codes     output codes, indexed as [sq_id][i]            (nsq x n)
//   n         number of input vectors
//
// Ties between equally distant centroids resolve to the highest centroid index.
// Codes are stored as uint8_t, so centroid indices above 255 (__B__ > 8) would
// be truncated on store.
extern "C"
__global__ void quantize(
    const half* __restrict__ codebook,  // nsq x 2^b x d
    const half* __restrict__ vectors,   // n x (nsq * d)
    uint8_t* __restrict__ codes,        // nsq x n
    int n
) {
    extern __shared__ volatile half centroids[];  // 2^b x d

    const int sq_id = blockIdx.x;
    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;
    const int n_centroids = 1 << __B__;
    const int n_halfs_per_sq = n_centroids * __D__;

    // Stage this subquantizer's centroids in shared memory. The global offset
    // is computed in size_t so it cannot overflow 32-bit int.
    const size_t codebook_base = (size_t)sq_id * n_halfs_per_sq;
    #pragma unroll
    for (int i = thread_id; i < n_halfs_per_sq; i += n_threads) {
        centroids[i] = codebook[codebook_base + i];
    }
    // Every thread reads all centroids below, so the whole block must have
    // finished writing them first.
    __syncthreads();

    half subvector[__D__];
    for (int i = thread_id; i < n; i += n_threads) {
        // Gather the __D__ components of vector i that belong to this
        // subquantizer; 64-bit offset avoids int overflow for large inputs.
        const size_t vec_base = ((size_t)i * __NSQ__ + sq_id) * __D__;
        #pragma unroll
        for (int j = 0; j < __D__; ++j) {
            subvector[j] = vectors[vec_base + j];
        }

        // Start from +infinity (IEEE-754 bit pattern 0x7f800000) so that any
        // finite distance is accepted, and from index 0 so that the stored
        // code is always a defined value, even if every distance is NaN.
        float min_dist = __int_as_float(0x7f800000);
        uint8_t min_idx = 0;
        #pragma unroll
        for (int j = 0; j < n_centroids; ++j) {
            float dist = 0.0f;
            #pragma unroll
            for (int k = 0; k < __D__; ++k) {
                float tmp = __half2float(subvector[k]) - __half2float(centroids[j * __D__ + k]);
                dist += tmp * tmp;
            }
            // `<=` keeps the last (highest-index) centroid among exact ties.
            if (dist <= min_dist) {
                min_dist = dist;
                min_idx = (uint8_t)j;
            }
        }

        codes[(size_t)sq_id * n + i] = min_idx;
    }
}
// extern "C"
// __global__ void cq_encode(
// const half* __restrict__ codebook, // nsq x 2^b x d
// const half* __restrict__ vectors, // n x (nsq * d)
// uint8_t* __restrict__ codes, // nsq x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile float centroids[]; // 2^b x d
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// centroids[i] = __half2float(codebook[sq_id * n_floats_per_sq + i]);
// }
// __syncthreads();
// float subvector[MAX_DIM];
// for (int i = thread_id; i < n; i += n_threads) {
// for (int j = 0; j < d; ++j) {
// subvector[j] = __half2float(vectors[(i * nsq + sq_id) * d + j]);
// // subvector[j] = __half2float(vectors[sq_id * d + j]);
// }
// float min_dist = 16384;
// uint8_t min_idx;
// for (int j = 0; j < (1 << b); ++j) {
// float dist = 0;
// for (int k = 0; k < d; ++k) {
// dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
// }
// min_dist = (dist <= min_dist) ? dist : min_dist;
// min_idx = (dist == min_dist) ? j : min_idx;
// }
// // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
// codes[sq_id * n + i] = min_idx;
// }
// }
// extern "C"
// __global__ void cq_encode(
// const half* __restrict__ codebook, // nsq x 2^b x d
// const half* __restrict__ vectors, // n x (nsq * d)
// uint8_t* __restrict__ codes, // nsq x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile half centroids[]; // 2^b x d
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// centroids[i] = codebook[sq_id * n_floats_per_sq + i];
// }
// __syncthreads();
// for (int i = thread_id; i < n; i += n_threads) {
// half subvector[MAX_DIM];
// for (int j = 0; j < d; ++j) {
// subvector[j] = vectors[(i * nsq + sq_id) * d + j];
// }
// half min_dist = 16384;
// int min_idx = -1;
// for (int j = 0; j < (1 << b); ++j) {
// half dist = 0;
// for (int k = 0; k < d; ++k) {
// dist += (subvector[k] - centroids[j * d + k]) * (subvector[k] - centroids[j * d + k]);
// }
// min_dist = (dist <= min_dist) ? dist : min_dist;
// min_idx = (dist == min_dist) ? j : min_idx;
// }
// codes[sq_id * n + i] = min_idx;
// }
// }
// extern "C"
// __global__ void cq_decode(
// const float* __restrict__ codebook, // nsq x 2^b x d
// const uint8_t* __restrict__ codes, // nsq x n
// float* __restrict__ result, // (nsq x d) x n
// int nsq, int b, int d, int n
// ) {
// extern __shared__ volatile float centroids[];
// const int sq_id = blockIdx.x;
// const int thread_id = threadIdx.x;
// const int n_threads = blockDim.x;
// const int n_floats_per_sq = (1 << b) * d;
// for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
// // printf("sq_id %d n_floats_per_sq %d i %d\n", sq_id, n_floats_per_sq, i);
// centroids[i] = codebook[sq_id * n_floats_per_sq + i];
// // printf("%d: %f\n", i, centroids[i]);
// }
// __syncthreads();
// for (int i = thread_id; i < n; i += n_threads) {
// uint8_t code = codes[sq_id * n + i];
// for (int dim = 0; dim < d; ++dim) {
// result[(sq_id * d + dim) * n + i] = centroids[d * code + dim];
// // result[dim] = centroids[d * code + dim];
// }
// }
// } |