/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 * 
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cuda_runtime.h>
#include <cstdlib>
#include <iostream>

// NOTE: SHFL_MASK (the full-warp shuffle mask, normally 0xffffffff) is
// assumed to be defined by the translation unit that includes this file.


template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  return (a + b - 1) / b;
}
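
// Example: divUp(10, 4) == 3 and divUp(8, 4) == 2, i.e. integer division
// rounded up. It is used below to compute how many SB-sized chunks are
// needed to cover FS elements.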


template<int FS, int SB, int padding_l, typename scalar_t>
__inline__ __device__
void zeroSharedMem(scalar_t* data) {
  /*
    Given an array of length FS + SB, zero out the first padding_l and last
    (FS - padding_l) values in the array
  */

  int tid = threadIdx.x;

  if (FS < SB) {

    // zero all if we have enough threads in a block to do all of them
    if (tid < padding_l || tid > SB - FS + padding_l - 1) {
      data[tid] = scalar_t(0.0);
    }
  } else {

    // otherwise zero out one block at a time
    const int numIterations = divUp<int, int>(FS, SB);
    for (int i = 0; i < numIterations; i++) {
      int offset = i * SB;
      if (tid + offset < padding_l) {
        data[tid + offset] = scalar_t(0.0);
      } else if (tid + offset < FS) {
        data[SB + tid + offset] = scalar_t(0.0);
      }
    }
  }
}
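
// Illustrative (hypothetical) usage inside a kernel, assuming each block owns
// a shared buffer of SB + FS elements:
//
//   __shared__ scalar_t tmp[SB + FS];
//   zeroSharedMem<FS, SB, padding_l, scalar_t>(tmp);
//   __syncthreads();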

template<typename scalar_t>
__inline__ __device__
scalar_t warpReduce(scalar_t data) {
  /*
    Reduce a value across a warp. After processing, every lane in the warp
    holds the sum of the original values from all lanes in that warp.

    data - this lane's value to reduce
  */
  data += __shfl_xor_sync(SHFL_MASK, data, 16);
  data += __shfl_xor_sync(SHFL_MASK, data, 8);
  data += __shfl_xor_sync(SHFL_MASK, data, 4);
  data += __shfl_xor_sync(SHFL_MASK, data, 2);
  data += __shfl_xor_sync(SHFL_MASK, data, 1);
  return data;
}
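
// Worked example: if the 32 lanes of a warp pass in the values 0, 1, ..., 31,
// every lane receives 496 (the sum of 0..31) from warpReduce.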

template<typename scalar_t>
__inline__ __device__
scalar_t blockReduce(scalar_t data) {
  /*
     Reduce a value across an entire thread block. After processing, every
     thread in the first warp (warp 0) returns the block-wide sum; threads in
     the remaining warps return zero.

     data - this thread's value to reduce
  */

  static __shared__ scalar_t warpSum[32];
  const int tid = threadIdx.x;
  int wid = tid / 32;
  int lane = tid % 32;

  __syncthreads();

  // reduce each warp then write to shared memory
  scalar_t sum = warpReduce(data);
  if (lane == 0) {
    warpSum[wid] = sum;
  }
  
  __syncthreads();

  scalar_t v;
  // perform final sum of partial warp sums
  if (tid < blockDim.x / 32) {
    v = warpSum[lane];
  } else {
    v = scalar_t(0.0);
  }

  if (wid == 0) {
    v = warpReduce(v);
  }
  __syncthreads();

  return v;
}
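
// Minimal sketch (not part of the original API) showing how blockReduce can
// be used to sum an array with a single block. It assumes blockDim.x is a
// multiple of 32 and at most 1024; only warp 0 receives the final sum.
__global__ void blockSumExample(const float* in, float* out, int n) {
  float local = 0.0f;
  // each thread accumulates a strided slice of the input
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    local += in[i];
  }
  float sum = blockReduce<float>(local);
  if (threadIdx.x == 0) {
    *out = sum;
  }
}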

void checkCudaStatus(cudaError_t status, int lineNumber = -1) {

  if (status != cudaSuccess) {
    std::cout << cudaGetErrorString(status)
              << " at line " << lineNumber << std::endl;
    std::cout << "Exiting" << std::endl;
    exit(1);
  }
}
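
// Typical (illustrative) call site, checking the status of the most recent
// CUDA runtime call:
//
//   checkCudaStatus(cudaGetLastError(), __LINE__);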

template<int FS, int SB, int padding_l, typename scalar_t>
__device__
void load_input_to_shared(const scalar_t* input, // global memory
                          int inputOffset, int sequenceLength,
                          int iteration, int numIterations,
                          bool no_prev, scalar_t* output /* shared memory */) {
  /*
    Load one block (SB elements) of input into shared memory, together with a
    left and right overhang whose combined size is FS. If a previous block was
    already loaded, its overlapping values are shifted over and reused instead
    of being re-read from global memory.

    input - pointer to start of channel sequence
    inputOffset - how far into the sequence to start loading
    sequenceLength - total length of the sequence
    iteration - which block of the sequence we are loading
    numIterations - total number of blocks to load
    no_prev - whether to load the whole left overhang from global memory
              because the previous block was not loaded
    output - shared memory to write the input to
  */

  const int tid = threadIdx.x;

  // Load the left "overhang" of input
  if (iteration > 0) {
    if (padding_l < SB) {

      // load all at once
      if (tid < padding_l) {
        output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB];
      }
    } else {

      // load in chunks of size SB
      int numIterations = divUp<int, int>(padding_l, SB);
      for (int i = 0; i < numIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < padding_l) {
          output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB];
        }
      }
    }
  }

  // Load the right "overhang" of input
  if (iteration < (numIterations - 1)) {
    const int elementsLeft = sequenceLength - (iteration+1) * SB;

    if ((FS - padding_l) < SB) {

      // load all at once
      if (tid < (FS - padding_l)) {
          output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0);
      }
    } else {

      // load in chunks of size SB
      int numIterations = divUp<int, int>(FS - padding_l, SB);
      for (int i = 0; i < numIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < (FS - padding_l)) {
          output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0);
        }
      }
    }
  }

  // On the last iteration, also clear out the right "overhang"
  if (iteration == (numIterations - 1)) {
    if ((FS - padding_l) < SB) {

      // clear out all at once
      if (tid < (FS - padding_l)) {
          output[padding_l + SB + tid] = scalar_t(0.0);
      }
    } else {

      // clear in chunks of size SB
      int numIterations = divUp<int, int>(FS - padding_l, SB);
      for (int i = 0; i < numIterations; i++) {
        int offset = i * SB;
        if ((tid + offset) < (FS - padding_l)) {
          output[padding_l + SB + tid + offset] = scalar_t(0.0);
        }
      }
    }
  }
  output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0);
}