File size: 2,444 Bytes
b37c16f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <cuda_fp16.h>

typedef unsigned char uint8_t;

// extern "C"
// __global__ void quantize(
//   const half* __restrict__ codebook,   // nsq x 2^b x d
//   const half* __restrict__ vectors,    // n x (nsq * d)
//   uint8_t* __restrict__ codes,         // nsq x n
//   int n
// ) {
//   extern __shared__ volatile half centroids[]; // 2^b x d

//   const int sq_id = blockIdx.x;
//   const int thread_id = threadIdx.x;
//   const int n_threads = blockDim.x;

//   const int n_floats_per_sq = (1 << __B__) * __D__;
// #pragma unroll
//   for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
//     centroids[i] = codebook[sq_id * n_floats_per_sq + i];
//   }
//   __syncthreads();

//   half subvector[__D__];
//   for (int i = thread_id; i < n; i += n_threads) {
// #pragma unroll
//     for (int j = 0; j < __D__; ++j) {
//         subvector[j] = vectors[(i * __NSQ__ + sq_id) * __D__ + j];
//     }
//     float min_dist = 1 << 16;
//     uint8_t min_idx;
// #pragma unroll
//     for (int j = 0; j < (1 << __B__); ++j) {
//         float dist = 0;
// #pragma unroll
//         for (int k = 0; k < __D__; ++k) {
//             float tmp = __half2float(subvector[k]) - __half2float(centroids[j * __D__ + k]);
//             dist += tmp * tmp;
//         }
//         min_dist = (dist <= min_dist) ? dist : min_dist;
//         min_idx = (dist == min_dist) ? j : min_idx;
//     }
//     // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
//     codes[sq_id * n + i] = min_idx;
//   }
// }

// Reconstructs vectors from product-quantization codes by looking up each
// code in its subquantizer's sub-codebook.
//
// Launch layout (as implied by the indexing below):
//   * one block per subquantizer: blockIdx.x is the subquantizer id, so the
//     grid is expected to have nsq blocks (presumably nsq == __NSQ__ -- TODO
//     confirm against the host launcher);
//   * any blockDim.x: each block walks all n codes with a block-stride loop
//     (stride blockDim.x), so correctness does not depend on the block size;
//   * dynamic shared memory of (1 << __B__) * __D__ * sizeof(half) bytes must
//     be passed as the third launch parameter to hold one sub-codebook.
//
// __B__ (bits per code), __D__ (sub-vector dimension) and __NSQ__ (number of
// subquantizers) are not defined in this file; they are assumed to be
// substituted as compile-time constants before compilation -- TODO confirm.
// Codes are stored as uint8_t, so __B__ must be <= 8, and every code is
// assumed to be < (1 << __B__) (no range check is performed).
//
// Parameters:
//   codebook - nsq x 2^b x d centroids (read only)
//   codes    - nsq x n code indices (read only)
//   vectors  - n x (nsq * d) output; block sq_id writes the __D__ entries of
//              sub-vector sq_id for every vector i in [0, n)
//   n        - number of vectors; n <= 0 writes nothing
extern "C"
__global__ void dequantize(
  const half* __restrict__ codebook,   // nsq x 2^b x d
  const uint8_t* __restrict__ codes,   // nsq x n
  half* __restrict__ vectors,          // n x (nsq x d)
  int n
) {
  // Plain (non-volatile) shared memory: the __syncthreads() below already
  // makes the staging stores visible to every thread in the block before
  // they are read, and volatile would only force redundant shared loads.
  extern __shared__ half centroids[];  // 2^b x d

  const int sq_id = blockIdx.x;
  const int thread_id = threadIdx.x;
  const int n_threads = blockDim.x;

  // Stage this subquantizer's sub-codebook into shared memory.
  const int n_halves_per_sq = (1 << __B__) * __D__;
  const half* sq_codebook = codebook + (long long)sq_id * n_halves_per_sq;
  for (int i = thread_id; i < n_halves_per_sq; i += n_threads) {
    centroids[i] = sq_codebook[i];
  }
  __syncthreads();

  // Offsets are computed in 64 bits: nsq * n and n * nsq * d can exceed
  // INT_MAX for large n, which would silently wrap in 32-bit arithmetic.
  const uint8_t* sq_codes = codes + (long long)sq_id * n;
  for (int i = thread_id; i < n; i += n_threads) {
    const int centroid_base = __D__ * (int)sq_codes[i];
    half* out = vectors + ((long long)i * __NSQ__ + sq_id) * __D__;
#pragma unroll
    for (int dim = 0; dim < __D__; ++dim) {
      out[dim] = centroids[centroid_base + dim];
    }
  }
}