// Sync BatchNorm CUDA kernels: Welford mean/variance, forward/backward
// elementwise transforms, and channels-last ("c_last") variants with
// optional cooperative grid reductions.
// Largest power of two strictly below n when n is itself a power of two;
// otherwise the largest power of two not exceeding n.
__device__ __forceinline__ int lastpow2(int n)
{
  // __clz counts leading zeros, so this isolates the highest set bit of n.
  int p = 1 << (31 - __clz(n));
  return (p == n) ? (p >> 1) : p;
}
// Round n up to the next power of two (n itself if already a power of two).
// Classic bit-smearing trick: fill all bits below the highest set bit of
// (n - 1), then add one.
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
  unsigned int v = n - 1;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return v + 1;
}
// Round n down to the largest power of two not exceeding n (0 maps to 0).
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
  unsigned int v = n;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  // v now has every bit below its highest set bit filled; keep only that bit.
  return v - (v >> 1);
}
// Tree sum across a full warp via shuffle-down; the total lands in lane 0
// (other lanes hold partial sums). All 32 lanes must participate (full mask).
template<typename T>
__device__ __forceinline__ T warp_reduce_sum(T val)
{
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}
// Block-wide sum reduction of `val` using shared scratch `x`.
// `x` must hold at least blockDim.x*blockDim.y/WARP_SIZE elements of T.
// The final sum is valid only within the first warp (in particular thread 0).
// NOTE(review): assumes the block size is a multiple of WARP_SIZE and at most
// WARP_SIZE*WARP_SIZE threads -- confirm against the launch configurations.
template<typename T>
__device__ __forceinline__ T reduce_block(T *x, T val)
{
  int tid = threadIdx.y*blockDim.x + threadIdx.x;
  int blockSize = blockDim.x * blockDim.y;
  if (blockSize > 32) {
    // Reduce within each warp; lane 0 of each warp publishes its partial.
    val = warp_reduce_sum(val);
    if (tid % WARP_SIZE == 0)
      x[tid/WARP_SIZE] = val;
    __syncthreads();
    // The first blockSize/WARP_SIZE threads pick up the per-warp partials.
    // For those threads tid < WARP_SIZE, so tid%WARP_SIZE == tid here.
    val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
  }
  // Final reduction within the first warp only.
  if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
  return val;
}
// NOTE: despite the name, this is NOT a plain ceiling division -- it rounds
// x/y up, then rounds that result DOWN to a power of two (via h_last_pow2),
// so the product with y may be smaller than x. Callers rely on the
// power-of-two result for launch configurations.
__host__ int div_ru(int x, int y) {
  int ceil_div = 1 + (x - 1) / y;
  return h_last_pow2(ceil_div);
}
// Picks a 2D launch configuration for a (reduction x stride) problem:
// block.x/grid.x tile the stride (channel) dimension, block.y/grid.y tile the
// reduction dimension assuming ELEMENTS_PER_THREAD items per thread.
// When coop_flag is set (a grid-wide reduction will follow), a small grid_y
// is collapsed to 1 since a grid reduction would not pay off.
__host__ void flexible_launch_configs(
      const int reduction,
      const int stride,
      dim3 &block,
      dim3 &grid,
      const bool coop_flag = false) {
  // Power-of-two tile along the stride dim, capped at OPTIMAL_TILE_W.
  int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W);
  int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)),
                         MAX_BLOCK_SIZE / block_x);
  if (block_x * block_y != MAX_BLOCK_SIZE) {
    // Block not full: let block_x grow into the unused thread budget.
    block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y);
  }
  int grid_x = div_ru(stride, block_x);
  int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK);
  if (coop_flag) {
    // it's not worth having a grid reduction if the reduction dimension is not big enough
    grid_y = grid_y < 8 ? 1 : grid_y;
  }
  block.x = block_x;
  block.y = block_y;
  block.z = 1;
  grid.x = grid_x;
  grid.y = grid_y;
  grid.z = 1;
}
// Parallel-Welford merge (Chan et al.): folds one partial statistic triple
// (num_new, mean_new, m2n_new) into the running (count, mean, m2n) in place.
// The max(1, ...) guards the division when both partials are empty.
// Statement order matters: count must be updated LAST, since factor and the
// mean/m2n updates use the pre-merge count.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_element(C& count,
                                                      T& mean,
                                                      T& m2n,
                                                      const C& num_new,
                                                      const T& mean_new,
                                                      const T& m2n_new) {
  T factor = T(1.0) / max(1, (count + num_new));
  T delta0 = mean - mean_new;
  // Count-weighted average of the two means.
  mean = (mean_new * num_new + mean * count) * factor;
  // Merge second moments plus the between-partials correction term.
  m2n += m2n_new + delta0 * delta0 * num_new * count * factor;
  count += num_new;
}
// Warp-level reduction of Welford partials (count/mean/m2n).
// The fully merged triple is valid in lane 0; all 32 lanes must be active
// (full 0xffffffff mask).
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
{
  for(int i = WARP_SIZE/2; i > 0; i >>= 1) {
    // Pull the partial triple from the lane `i` positions down and merge it.
    auto num_new = __shfl_down_sync(0xffffffff, num, i);
    auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
    auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
    welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
  }
}
// Block-wide Welford reduction. Shared scratch layout:
//   x[2*w]   = mean of warp w,  x[2*w+1] = m2n of warp w
//   count[w] = element count of warp w
// The merged result is valid in thread 0.
// NOTE(review): assumes block_size is a multiple of WARP_SIZE and the block
// has at most WARP_SIZE warps -- confirm at call sites.
template <typename T>
__device__ void welford_reduce_mean_m2n(
      T* __restrict__ x,
      int* __restrict__ count,
      T &mean,
      T &m2n,
      int &num,
      int block_size,
      int thread_id)
{
  int lane = thread_id % WARP_SIZE;
  int wid = thread_id / WARP_SIZE;
  if (block_size > 32) {
    // Per-warp reduction; lane 0 of each warp publishes its partial triple.
    warp_reduce_mean_m2n(mean, m2n, num);
    if (lane == 0) {
      x[wid*2] = mean;
      x[wid*2+1] = m2n;
      count[wid] = num;
    }
    __syncthreads();
    // First warp gathers the per-warp partials (lane w reads warp w's slot);
    // excess lanes get neutral (0,0,0) partials.
    if (wid == 0) {
      mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0);
      m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0);
      num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0);
    }
  }
  // Final merge within the first warp.
  if (wid==0) warp_reduce_mean_m2n(mean, m2n, num);
  return;
}
// Spatial element count for an NC+ tensor: the product of all dims from
// index 2 onward. Requires at least 3 dimensions (size(2) is always read).
__host__ int get_tensor_spatial_size(const at::Tensor& input)
{
  auto space_size = input.size(2);
  for (int d = 3; d < input.ndimension(); ++d)
    space_size *= input.size(d);
  return space_size;
}
// Accumulation scalar type for a tensor: half promotes to float, every
// other dtype accumulates in its own type.
__host__ at::ScalarType promote_scalartype(const at::Tensor& input)
{
  auto stype = input.scalar_type();
  if (stype == at::ScalarType::Half)
    return at::ScalarType::Float;
  return stype;
}
// Size in bytes of one element; with accumulation=true the promoted
// accumulation type (half -> float) is measured instead.
__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false)
{
  if (accumulation)
    return at::elementSize(promote_scalartype(input));
  return at::elementSize(input.scalar_type());
}
// Tree reduction of Welford partials along blockDim.y ("vertical"), one
// independent column per threadIdx.x. Shared buffers must hold
// blockDim.x*blockDim.y elements each. All threads of the block must enter
// (contains __syncthreads). On exit, threads with threadIdx.y == 0 hold
// their column's merged totals in registers.
template<typename T, typename C>
__device__ __forceinline__ void welford_merge_block_vertical(C& count,
                                                             T& mean,
                                                             T& m2n,
                                                             C* shmem_count,
                                                             T* shmem_mean,
                                                             T* shmem_m2n) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_mean[address_base] = mean;
  shmem_m2n[address_base] = m2n;
  shmem_count[address_base] = count;
  // Halve the active rows each pass, merging the upper half into the lower.
  // The barrier is outside the divergent branch, so all threads reach it.
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      // read shared memory back to register for reduction
      auto num_new = shmem_count[address];
      auto mean_new = shmem_mean[address];
      auto m2n_new = shmem_m2n[address];
      welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new);
      // last write is not necessary
      shmem_mean[address_base] = mean;
      shmem_m2n[address_base] = m2n;
      shmem_count[address_base] = count;
    }
  }
}
// Tree sum of (sum_dy, sum_dy_xmu) along blockDim.y, one independent column
// per threadIdx.x. Shared buffers must hold blockDim.x*blockDim.y elements.
// All threads must enter (contains __syncthreads). On exit, threads with
// threadIdx.y == 0 hold their column's totals in registers.
template<typename T>
__device__ __forceinline__ void merge_block_vertical(T& sum_dy,
                                                     T& sum_dy_xmu,
                                                     T* shmem_sum_dy,
                                                     T* shmem_sum_dy_xmu) {
  // write to shared memory
  auto address_base = threadIdx.x + threadIdx.y * blockDim.x;
  shmem_sum_dy[address_base] = sum_dy;
  shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
  // Halve the active rows each pass, accumulating the upper half into the
  // lower. The barrier stays outside the divergent branch.
  for (int offset = blockDim.y/2; offset > 0; offset >>= 1) {
    __syncthreads();
    if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
      auto address = address_base + offset * blockDim.x;
      sum_dy += shmem_sum_dy[address];
      sum_dy_xmu += shmem_sum_dy_xmu[address];
      // last write is not necessary
      shmem_sum_dy[address_base] = sum_dy;
      shmem_sum_dy_xmu[address_base] = sum_dy_xmu;
    }
  }
}
// welford kernel calculating mean/biased_variance/unbiased_variance
// Launch: one block per feature (blockIdx.x == channel). Each thread walks
// the batch (threadIdx.y) and spatial (threadIdx.x) dimensions with a
// sequential Welford update, then partials are merged block-wide.
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__global__ void welford_kernel(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      const int bs,   // batch size
      const int fs,   // feature (channel) count
      const int ss) { // spatial size per (batch, channel)
  int block_size = blockDim.x * blockDim.y;
  int count = 0;
  accscalar_t x_mean = accscalar_t(0);
  accscalar_t m_2_n = accscalar_t(0);
  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    // Base offset of this (batch, channel) plane in NCHW-style layout.
    int input_base = blockIdx.x*ss + batch_id*ss*fs;
    // sequential welford
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      count++;
      auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
      auto d = x_n - x_mean;
      x_mean += d / count;
      m_2_n += d * (x_n - x_mean);
    }
  }
  // 160 ints: first 32 hold per-warp counts, the remaining 128 (512 bytes)
  // are reinterpreted as up to 32 (mean, m2n) accscalar pairs.
  // NOTE(review): assumes sizeof(accscalar_t) <= 8 -- confirm.
  static __shared__ int s_mem[160];
  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
  welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
  if (thread_id == 0) {
    // Thread 0 holds the fully merged statistics for this channel.
    out_mean[blockIdx.x] = static_cast<outscalar_t>(x_mean);
    out_var_biased[blockIdx.x] = static_cast<outscalar_t>(m_2_n/count);
  }
}
// elementwise BN kernel
// out = weight * (input - mean) * inv_std + shift, per channel (blockIdx.x).
// weight/shift may be NULL (identity scale / zero shift). Grid-stride loops
// over batch (y) and spatial (z/x) dimensions.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_forward_kernel(
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int ss,   // spatial size
      const int bs) { // batch size
  // Per-channel constants, loaded once per thread.
  auto m_c = mean[blockIdx.x];
  auto inv_std_c = inv_std[blockIdx.x];
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[blockIdx.x]);
  for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
    // gridDim.x == feature count, so this indexes the (batch, channel) plane.
    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
      out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
    }
  }
}
// Backward BN kernel, calculates grad_bias, grad_weight as well as intermediate
// results to calculating grad_input.
// Breaking the grad_input to two step to support sync BN, which requires all
// reduce of the intermediate results across processes.
// Launch: one block per feature (blockIdx.x); the block cooperatively
// reduces over batch (threadIdx.y) and spatial (threadIdx.x) dimensions.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void reduce_bn_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      const int bs,
      const int fs,
      const int ss) {
  // Scratch reused by both reduce_block calls below (256 bytes = up to 32
  // 8-byte per-warp partials). NOTE(review): assumes sizeof(accscalar_t) <= 8.
  static __shared__ int s_mem[64];
  //int total_item_num = bs * ss;
  int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
  auto r_mean = mean[blockIdx.x];
  auto factor = inv_std[blockIdx.x];
  // Kahan sum
  accscalar_t sum_dy = 0.0;
  accscalar_t sum_dy_xmu = 0.0;
  accscalar_t sum_dy_c = 0.0;      // compensation term for sum_dy
  accscalar_t sum_dy_xmu_c = 0.0;  // compensation term for sum_dy_xmu
  for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) {
    int input_base = blockIdx.x*ss + batch_id*ss*fs;
    for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
      auto e_grad = static_cast<accscalar_t>(grad_output[offset+input_base]);
      auto e_input = static_cast<accscalar_t>(input[offset+input_base]);
      // calculating sum_dy (Kahan-compensated)
      auto sum_dy_y = e_grad - sum_dy_c;
      auto sum_dy_t = sum_dy + sum_dy_y;
      sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y;
      sum_dy = sum_dy_t;
      // calculating sum_dy_xmu (Kahan-compensated)
      auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c;
      auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y;
      sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y;
      sum_dy_xmu = sum_dy_xmu_t;
    }
  }
  sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
  // Barrier before reusing s_mem for the second reduction.
  __syncthreads();
  sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
  if (thread_id == 0) {
    if (grad_bias != NULL) {
      grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
    }
    if (grad_weight != NULL) {
      grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
    }
    //mean_dy[blockIdx.x] = sum_dy / total_item_num;
    //mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num;
    sum_dy_o[blockIdx.x] = sum_dy;
    sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu;
  }
}
// elementwise backward BN kernel
// Computes grad_input from grad_output and the globally reduced statistics.
// `numel` holds the per-process element counts; their sum normalizes the
// sum_dy / sum_dy_xmu terms for sync BN across `world_size` processes.
template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
__global__ void batchnorm_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int ss,
      const int bs) {
  // Total element count across all processes.
  int64_t div = 0;
  for (int i = 0; i < world_size; i++) {
    div += numel[i];
  }
  // Per-channel constants (blockIdx.x == channel).
  auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
  //auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
  auto m_dy_c = static_cast<accscalar_t>(sum_dy[blockIdx.x]) / div;
  auto factor_1_c = inv_std[blockIdx.x];
  auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[blockIdx.x])) * factor_1_c;
  //factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x];
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div;
  for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
    // gridDim.x == feature count, so this indexes the (batch, channel) plane.
    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c;
    }
  }
}
// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance
//
// Channels-last layout: the tensor is viewed as (reduction_size, stride) with
// the channel dimension (stride) contiguous. Each thread keeps PARALLEL_LOADS
// interleaved Welford partials to hide load latency, merges them, reduces
// across the block, and -- when gridDim.y > 1 (cooperative launch) -- finishes
// with a semaphore-gated grid reduction through `staging_data`.
// staging_data layout: [means | m2ns | counts], each region stride*gridDim.y.
template
   <typename scalar_t,
    typename accscalar_t,
    typename outscalar_t,
    int PARALLEL_LOADS>
__global__ void
welford_kernel_c_last(
      const scalar_t* __restrict__ input,
      outscalar_t* __restrict__ out_mean,
      outscalar_t* __restrict__ out_var_biased,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {
  // hide latency with concurrency
  accscalar_t x_mean[PARALLEL_LOADS];
  accscalar_t m_2_n[PARALLEL_LOADS];
  int count[PARALLEL_LOADS];
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    x_mean[i] = accscalar_t(0);
    m_2_n[i] = accscalar_t(0);
    count[i] = 0;  // fixed: was accscalar_t(0) assigned to an int slot
  }
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;
  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;
  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_math[PARALLEL_LOADS];
    accscalar_t x_count_inv[PARALLEL_LOADS];
    accscalar_t is_valid[PARALLEL_LOADS];
    // load multiple data in
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_math[j] = input[address_base];
        count[j]++;
        x_count_inv[j] = accscalar_t(1) / count[j];
        is_valid[j] = accscalar_t(1);
      } else {
        // Out-of-range slot: neutral values so the update below is a no-op.
        x_math[j] = accscalar_t(0);
        x_count_inv[j] = accscalar_t(0);
        is_valid[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
    // calculate mean/m2n with welford
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      accscalar_t delta0 = x_math[j] - x_mean[j];
      x_mean[j] += delta0 * x_count_inv[j];
      accscalar_t delta1 = x_math[j] - x_mean[j];
      m_2_n[j] += delta0 * delta1 * is_valid[j];
    }
  }
  // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]);
  }
  // release x_mean / m_2_n
  auto mean_th = x_mean[0];
  auto m2_th = m_2_n[0];
  auto count_th = count[0];
  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE];
  static __shared__ int shmem_count[MAX_BLOCK_SIZE];
  welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    volatile accscalar_t* staging_mean = staging_data;
    volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y];
    volatile int* staging_count = reinterpret_cast<volatile int*>(&staging_m2n[stride*gridDim.y]);
    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_mean[address_base] = mean_th;
      staging_m2n[address_base] = m2_th;
      staging_count[address_base] = count_th;
    }
    // Publish the staging writes before signalling via the semaphore.
    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks
    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }
    __syncthreads();
    // check that all data is now available in global memory
    if (is_last_block_done) {
      // The last block re-merges every block's staged partials.
      count_th = 0;
      mean_th = accscalar_t(0.0);
      m2_th = accscalar_t(0.0);
      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        int num_new = c_offset < stride ? staging_count[address_base] : 0;
        accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0);
        accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0);
        welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new);
      }
      welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n);
      if (threadIdx.y == 0 && c_offset < stride) {
        out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
        out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
      }
    }
  } else {
    // Single-block-per-channel case: write out directly.
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      out_mean[c_offset] = static_cast<outscalar_t>(mean_th);
      out_var_biased[c_offset] = static_cast<outscalar_t>(m2_th / count_th);
    }
  }
}
// parallel welford kernel to further reduce mean / biased_var
// into mean / unbiased_var / inv_std across multiple processes.
// Inputs are laid out as (world_size, feature_size); one thread handles one
// feature via a grid-stride loop, merging each process's (numel, mean,
// biased_var*numel) triple with Welford.
// NOTE(review): out_var divides by (count - 1); assumes the global element
// count exceeds 1 -- confirm callers guarantee this.
template <typename scalar_t>
__global__ void welford_kernel_parallel(
      const scalar_t* __restrict__ mean,
      const scalar_t* __restrict__ var_biased,
      const int* __restrict__ numel,
      scalar_t* __restrict__ out_mean,
      scalar_t* __restrict__ out_var,
      scalar_t* __restrict__ inv_std,
      const int world_size,
      const int feature_size,
      const float eps) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) {
    // load data;
    int address = i;
    scalar_t x_mean = 0;
    scalar_t m_2_n = 0;
    int count = 0;
    for (int j = 0; j < world_size; j++) {
      // biased_var * numel recovers that process's m2n sum.
      welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]);
      address += feature_size;
    }
    out_mean[i] = x_mean;
    out_var[i] = m_2_n/ (count - 1);   // unbiased (Bessel-corrected) variance
    inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps);
  }
}
// elementwise BN forward kernel for channels-last tensors, with optional
// residual add (z) and optional fused ReLU.
// Tensor viewed as (reduction_size, stride): c_offset indexes the contiguous
// channel dim, m_offset walks the reduction dim in a grid-stride loop.
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride,
      const bool fuse_relu) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;
  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  // Bounds-check the channel BEFORE touching the per-channel arrays: tail
  // threads with c_offset >= stride previously read mean/inv_std/weight/shift
  // out of bounds. Early return is safe -- this kernel uses no shared memory
  // and no __syncthreads().
  if (c_offset >= stride) {
    return;
  }
  // Per-channel constants, loaded once per thread.
  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;
  for (int i = 0; i < loop_count; i++) {
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {  // c_offset already validated above
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}
// Fused BN(+add)+ReLU backward kernel for channels-last tensors: recomputes
// the forward activation and passes grad_output through only where the
// activation was positive. (Header previously mislabeled this as the
// elementwise BN forward kernel.)
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ z,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const layerscalar_t* __restrict__ shift,
      scalar_t* __restrict__ out,
      const int reduction_size,
      const int stride) {
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;
  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  // Bounds-check the channel BEFORE touching the per-channel arrays: tail
  // threads with c_offset >= stride previously read mean/inv_std/weight/shift
  // out of bounds. Early return is safe -- no shared memory, no barriers.
  if (c_offset >= stride) {
    return;
  }
  // Per-channel constants, loaded once per thread.
  auto m_c = mean[c_offset];
  auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
  auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
  auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;
  for (int i = 0; i < loop_count; i++) {
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {  // c_offset already validated above
        // Recompute the forward activation (BN, plus residual z if present).
        auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
        if (z != NULL) {
          tmp += z[address_base];
        }
        // ReLU gate: zero gradient where the activation was clamped.
        out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}
// batchnorm backward kernel for c last tensor
// Reduces grad_output into sum_dy / sum_dy_xmu (and grad_weight / grad_bias)
// per channel. Tensor viewed as (reduction_size, stride) with the channel
// dim contiguous. Uses PARALLEL_LOADS interleaved accumulators, a vertical
// block reduction, and -- when gridDim.y > 1 (cooperative launch) -- a
// semaphore-gated grid reduction through `staging_data`
// (layout: [sum_dy | sum_dy_xmu], each region stride*gridDim.y wide).
template
   <typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void reduce_bn_c_last_kernel(
      const scalar_t* __restrict__ input,
      const scalar_t* __restrict__ grad_output,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      accscalar_t* __restrict__ sum_dy_o,
      accscalar_t* __restrict__ sum_dy_xmu_o,
      layerscalar_t* __restrict__ grad_weight,
      layerscalar_t* __restrict__ grad_bias,
      volatile accscalar_t* staging_data,
      int* semaphores,
      const int reduction_size,
      const int stride) {
  // hide latency with concurrency
  accscalar_t sum_dy[PARALLEL_LOADS];
  accscalar_t sum_dy_xmu[PARALLEL_LOADS];
  for (int i = 0; i < PARALLEL_LOADS; i++) {
    sum_dy[i] = accscalar_t(0);
    sum_dy_xmu[i] = accscalar_t(0);
  }
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;
  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;
  // Guard the per-channel loads: tail threads with c_offset >= stride
  // previously read mean/inv_std out of bounds. They must stay alive (this
  // kernel synchronizes via merge_block_vertical and the semaphore path),
  // so clamp their reads to a neutral value instead of returning early.
  auto r_mean = c_offset < stride ? mean[c_offset] : accscalar_t(0);
  auto factor = c_offset < stride ? inv_std[c_offset] : accscalar_t(0);
  for (int i = 0; i < loop_count; i++) {
    accscalar_t x_input[PARALLEL_LOADS];
    accscalar_t x_grad_output[PARALLEL_LOADS];
    // load multiple data in
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (c_offset < stride && m_offset < reduction_size) {
        x_input[j] = input[address_base];
        x_grad_output[j] = grad_output[address_base];
      } else {
        x_input[j] = accscalar_t(0);
        x_grad_output[j] = accscalar_t(0);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
    // calculate sum_dy / sum_dy_xmu
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      sum_dy[j] += x_grad_output[j];
      sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean);
    }
  }
  // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS
  for (int j = 1; j < PARALLEL_LOADS; j++) {
    sum_dy[0] += sum_dy[j];
    sum_dy_xmu[0] += sum_dy_xmu[j];
  }
  // release array of registers
  auto sum_dy_th = sum_dy[0];
  auto sum_dy_xmu_th = sum_dy_xmu[0];
  // block-wise reduction with shared memory (since reduction cannot be done within a warp)
  static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE];
  static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE];
  merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
  // grid reduction if needed (coop launch used at the first place)
  if (gridDim.y > 1) {
    volatile accscalar_t* staging_sum_dy = staging_data;
    volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y];
    address_base = c_offset + blockIdx.y * stride;
    // write data to staging_data;
    if (threadIdx.y == 0 && c_offset < stride) {
      staging_sum_dy[address_base] = sum_dy_th;
      staging_sum_dy_xmu[address_base] = sum_dy_xmu_th;
    }
    // Publish staging writes before signalling via the semaphore.
    __threadfence();
    __syncthreads(); // ensuring writes to staging_ is visible to all blocks
    __shared__ bool is_last_block_done;
    // mark block done
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int old = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done = (old == (gridDim.y-1));
    }
    __syncthreads();
    // check that all data is now available in global memory
    if (is_last_block_done) {
      // The last block re-reduces every block's staged partials.
      sum_dy_th = accscalar_t(0.0);
      sum_dy_xmu_th = accscalar_t(0.0);
      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
        address_base = c_offset + y * stride;
        sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0));
        sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0));
      }
      merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu);
      if (threadIdx.y == 0 && c_offset < stride) {
        if (grad_bias != NULL) {
          grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
        }
        if (grad_weight != NULL) {
          grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
        }
        //mean_dy[c_offset] = sum_dy_th / reduction_size;
        //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
        sum_dy_o[c_offset] = sum_dy_th;
        sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
      }
    }
  } else {
    // Single-block-per-channel case: write out directly.
    if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) {
      if (grad_bias != NULL) {
        grad_bias[c_offset] = static_cast<layerscalar_t>(sum_dy_th);
      }
      if (grad_weight != NULL) {
        grad_weight[c_offset] = static_cast<layerscalar_t>(sum_dy_xmu_th * factor);
      }
      //mean_dy[c_offset] = sum_dy_th / reduction_size;
      //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size;
      sum_dy_o[c_offset] = sum_dy_th;
      sum_dy_xmu_o[c_offset] = sum_dy_xmu_th;
    }
  }
}
// elementwise backward BN kernel for channels-last tensors: computes
// grad_input from grad_output and the globally (all-process) reduced
// sum_dy / sum_dy_xmu. `numel` holds per-process element counts; their sum
// normalizes the statistics for sync BN.
template <
    typename scalar_t,
    typename accscalar_t,
    typename layerscalar_t,
    int PARALLEL_LOADS>
__global__ void batchnorm_backward_c_last_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ input,
      const accscalar_t* __restrict__ mean,
      const accscalar_t* __restrict__ inv_std,
      const layerscalar_t* __restrict__ weight,
      const accscalar_t* __restrict__ sum_dy,
      const accscalar_t* __restrict__ sum_dy_xmu,
      const int* __restrict__ numel,
      scalar_t* __restrict__ grad_input,
      const int64_t world_size,
      const int reduction_size,
      const int stride) {
  // Total element count across all processes.
  int64_t div = 0;
  for (int i = 0; i < world_size; i++) {
    div += numel[i];
  }
  // tensor dimension (m,c)
  // loop along m dimension
  int inner_loop_stride = blockDim.y * gridDim.y;
  // offset along m dimension
  int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
  int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
  // Bounds-check the channel BEFORE touching the per-channel arrays: tail
  // threads with c_offset >= stride previously read mean/inv_std/weight/
  // sum_dy/sum_dy_xmu out of bounds. Early return is safe -- this kernel
  // uses no shared memory and no __syncthreads().
  if (c_offset >= stride) {
    return;
  }
  // Per-channel constants, loaded once per thread.
  auto m_c = mean[c_offset];
  auto m_dy_c = sum_dy[c_offset] / div;
  auto factor_1_c = inv_std[c_offset];
  auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset])) * factor_1_c;
  factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div;
  int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
  int address_base = m_offset * stride + c_offset;
  int address_increment = inner_loop_stride * stride;
  for (int i = 0; i < loop_count; i++) {
    for (int j = 0; j < PARALLEL_LOADS; j++) {
      if (m_offset < reduction_size) {  // c_offset already validated above
        grad_input[address_base] = static_cast<scalar_t>(
            (static_cast<accscalar_t>(grad_output[address_base]) - m_dy_c -
            (static_cast<accscalar_t>(input[address_base]) - m_c) * factor_1_c)
            * factor_2_c);
      }
      m_offset += inner_loop_stride;
      address_base += address_increment;
    }
  }
}
// Host launcher: per-channel mean and biased variance of an NC+ tensor via
// welford_kernel. One block per feature; accumulation type is the promoted
// scalar type (half accumulates in float). Returns {mean, biased_var}.
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  auto space_size = get_tensor_spatial_size(input);
  auto scalar_type = promote_scalartype(input);
  at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
  at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
  // Power-of-two block: up to MAX_BLOCK_SIZE/32 rows over the batch dim,
  // remaining threads tile the spatial dim.
  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
  auto stream = at::cuda::getCurrentCUDAStream();
  {
    using namespace at;
    // DISPATCH_FLOAT_AND_HALF instantiates the kernel for float/half, binding
    // scalar_t_0; accscalar_t is the accumulation type (float for half input).
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        welford_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
            input.DATA_PTR<scalar_t_0>(),
            out_mean.DATA_PTR<accscalar_t>(),
            out_var_biased.DATA_PTR<accscalar_t>(),
            batch_size,
            feature_size,
            space_size);
    );
  }
  return {out_mean, out_var_biased};
}
// Batch-norm forward for (N, C, *spatial) layout using precomputed per-feature
// statistics `mean` and `inv_std` (shape {C}) and optional affine parameters
// `weight`/`shift`. Launches batchnorm_forward_kernel elementwise over input;
// returns a tensor shaped like `input`.
at::Tensor batchnorm_forward_CUDA(
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  at::Tensor out = at::empty_like(input);
  auto space_size = get_tensor_spatial_size(input);
  // Block tiles (spatial, batch) as (x, y); grid is
  // (feature, batch groups, spatial groups), each grid dim capped at 65535.
  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
  const dim3 grid(feature_size, batch_group_size, grid_z);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float affine params; kernel is
    // instantiated with the accumulation type for the parameter pointers.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  } else {
    // Uniform-precision path: affine params must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          space_size,
          batch_size);
    );
  }
  return out;
}
// First (reduction) stage of batch-norm backward for (N, C, *spatial) layout:
// launches reduce_bn_kernel, which fills the per-feature buffers `sum_dy` and
// `sum_dy_xmu` (and, when `weight` is present, `grad_weight`/`grad_bias`).
// Returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}; the last two are
// empty ({0}) placeholders when `weight` is absent, because an uninitialized
// at::Tensor cannot be returned.
// Fix vs. original: removed the unused local `scalar_type`
// (promote_scalartype result was never read — outputs use mean's options).
std::vector<at::Tensor> reduce_bn_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight)
{
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  // Reduction outputs share mean's dtype/device.
  at::Tensor sum_dy = at::empty({feature_size}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options());
  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({feature_size}, weight.value().options());
    grad_bias = at::empty({feature_size}, weight.value().options());
  } else {
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }
  auto space_size = get_tensor_spatial_size(input);
  // Same per-feature reduction tiling as welford_mean_var_CUDA:
  // one block per feature, (x, y) = (spatial, batch).
  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
  int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
  const dim3 block(block_x, block_y);
  const dim3 grid(feature_size);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float weight; gradient buffers
    // are addressed with the accumulation type.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  } else {
    // Uniform-precision path: weight must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      reduce_bn_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
          batch_size,
          feature_size,
          space_size);
    );
  }
  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
// Second (elementwise) stage of batch-norm backward for (N, C, *spatial)
// layout: produces grad_input from the per-feature reductions `sum_dy` /
// `sum_dy_xmu` and the per-rank element counts in `count` (int tensor;
// count.numel() is forwarded so the kernel can sum the counts into its
// divisor — see the kernel's loop over world_size).
at::Tensor batchnorm_backward_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  const auto batch_size = input.size(0);
  const auto feature_size = input.size(1);
  at::Tensor grad_input = at::empty_like(input);
  auto space_size = get_tensor_spatial_size(input);
  // Same tiling as batchnorm_forward_CUDA: block (x, y) = (spatial, batch),
  // grid = (feature, batch groups, spatial groups), dims capped at 65535.
  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
  const dim3 block(block_x, block_y);
  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
  const dim3 grid(feature_size, batch_group_size, grid_z);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() &&
      weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float weight read as the
    // accumulation type.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, accscalar_t><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  } else {
    // Uniform-precision path: weight must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_kernel<scalar_t_0, accscalar_t, scalar_t_0><<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          space_size,
          batch_size);
    );
  }
  return grad_input;
}
// Combine per-node partial statistics into global per-feature statistics.
// `mean_feature_nodes` / `var_biased` hold one row per node (shape
// (world_size, feature_size)); `numel` holds each node's element count.
// Launches welford_kernel_parallel and returns {mean, var, inv_std}
// (inv_std computed by the kernel from var and `eps`).
std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                              const at::Tensor var_biased,
                                              const at::Tensor numel,
                                              const float eps) {
  const auto world_size = mean_feature_nodes.size(0);
  const auto feature_size = mean_feature_nodes.size(1);

  // All three outputs share var_biased's dtype/device and shape {C}.
  auto out_options = var_biased.options();
  at::Tensor out_var = at::empty({feature_size}, out_options);
  at::Tensor inv_std = at::empty({feature_size}, out_options);
  at::Tensor out_mean = at::empty({feature_size}, out_options);

  // The kernel indexes densely, so make sure inputs are packed.
  at::Tensor mean_contig = mean_feature_nodes.contiguous();
  at::Tensor var_contig = var_biased.contiguous();
  at::Tensor numel_contig = numel.contiguous();

  // TODO(jie): tile this for memory coalescing!
  const int block_size = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
  const int grid_size = std::max<int>(1, feature_size / block_size);

  auto stream = at::cuda::getCurrentCUDAStream();
  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
      welford_kernel_parallel<scalar_t_0><<<grid_size, block_size, 0, stream>>>(
          mean_contig.DATA_PTR<scalar_t_0>(),
          var_contig.DATA_PTR<scalar_t_0>(),
          numel_contig.DATA_PTR<int>(),
          out_mean.DATA_PTR<scalar_t_0>(),
          out_var.DATA_PTR<scalar_t_0>(),
          inv_std.DATA_PTR<scalar_t_0>(),
          world_size,
          feature_size,
          eps);
    );
  }
  return {out_mean, out_var, inv_std};
}
// Channels-last variant of welford_mean_var_CUDA: `input` is viewed as a 2D
// (reduction_size, stride) problem where `stride` is the innermost (channel)
// dimension. When flexible_launch_configs splits the reduction across
// grid.y blocks, `staging_data`/`semaphores` let the kernel combine the
// partial results across blocks.
// Returns {mean, var_biased}, shape {stride}, in the promoted dtype.
std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
  auto scalar_type = promote_scalartype(input);
  auto option = input.options().dtype(scalar_type);
  at::Tensor out_var_biased = at::empty({stride}, option);
  at::Tensor out_mean = at::empty({stride}, option);
  dim3 block;
  dim3 grid;
  // coop_flag=true: configure for the cross-block (staged) reduction path.
  flexible_launch_configs(reduction_size, stride, block, grid, true);
  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    // 4 staged values per (channel, y-block); semaphores must start at zero.
    staging_data = at::empty({4*stride*grid.y}, option);
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  {
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      // Staging buffers are only passed when more than one y-block reduces
      // the same channels.
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      welford_kernel_c_last<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          out_mean.DATA_PTR<accscalar_t>(),
          out_var_biased.DATA_PTR<accscalar_t>(),
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }
  return {out_mean, out_var_biased};
}
// Channels-last batch-norm forward: applies precomputed per-channel `mean` /
// `inv_std` (shape {stride}) with optional affine `weight`/`shift`.
// The optional residual `z` and the `fuse_relu` flag are forwarded to
// batchnorm_forward_c_last_kernel (kernel fuses add/relu when requested —
// see kernel definition). Returns a tensor shaped like `input`.
at::Tensor batchnorm_forward_c_last_CUDA(
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift,
    const bool fuse_relu) {
  // Treat input as (reduction_size, stride) with channels innermost.
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
  at::Tensor out = at::empty_like(input);
  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float affine params.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  } else {
    // Uniform-precision path: affine params must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>(): NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride,
          fuse_relu);
    );
  }
  return out;
}
// Channels-last first (reduction) stage of batch-norm backward: launches
// reduce_bn_c_last_kernel to fill per-channel `sum_dy`/`sum_dy_xmu` (and
// `grad_weight`/`grad_bias` when `weight` is present). Uses the same
// staging-buffer/semaphore scheme as welford_mean_var_c_last_CUDA when the
// reduction spans multiple y-blocks.
// Returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}.
// Fix vs. original: local `sumn_dy` renamed to `sum_dy` for consistency with
// reduce_bn_CUDA (typo), and ternary spacing normalized.
std::vector<at::Tensor> reduce_bn_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight) {
  // Treat input as (reduction_size, stride) with channels innermost.
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
  at::Tensor sum_dy = at::empty({stride}, mean.options());
  at::Tensor sum_dy_xmu = at::empty({stride}, mean.options());
  at::Tensor grad_weight;
  at::Tensor grad_bias;
  if (weight.has_value()) {
    grad_weight = at::empty({stride}, weight.value().options());
    grad_bias = at::empty({stride}, weight.value().options());
  } else {
    // because I cannot return an uninitialized at::Tensor
    grad_weight = at::empty({0}, mean.options());
    grad_bias = at::empty({0}, mean.options());
  }
  dim3 block;
  dim3 grid;
  // coop_flag=true: configure for the cross-block (staged) reduction path.
  flexible_launch_configs(reduction_size, stride, block, grid, true);
  at::Tensor staging_data;
  at::Tensor semaphores;
  if (grid.y > 1) {
    // 2 staged values per (channel, y-block); semaphores must start at zero.
    staging_data = at::empty({2*stride*grid.y}, mean.options());
    semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value()
      && weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float weight; gradient buffers
    // are addressed with the accumulation type.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<accscalar_t>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<accscalar_t>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  } else {
    // Uniform-precision path: weight must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR<accscalar_t>() : nullptr;
      int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR<int>() : nullptr;
      reduce_bn_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          input.DATA_PTR<scalar_t_0>(),
          grad_output.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          weight.has_value() ? grad_weight.DATA_PTR<scalar_t_0>() : NULL,
          weight.has_value() ? grad_bias.DATA_PTR<scalar_t_0>() : NULL,
          staging_data_ptr,
          semaphores_ptr,
          reduction_size,
          stride);
    );
  }
  return {sum_dy, sum_dy_xmu, grad_weight, grad_bias};
}
// Channels-last second (elementwise) stage of batch-norm backward: produces
// grad_input from the per-channel reductions `sum_dy`/`sum_dy_xmu` and the
// per-rank element counts in `count` (int tensor; count.numel() forwarded so
// the kernel can sum the counts into its divisor).
// Fix vs. original: the dispatch labels were copy-pasted as
// "batchnorm_forward"; corrected to "batchnorm_backward" so dispatch errors
// point at the right op.
at::Tensor batchnorm_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::Tensor sum_dy,
    const at::Tensor sum_dy_xmu,
    const at::Tensor count) {
  // Treat input as (reduction_size, stride) with channels innermost.
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
  at::Tensor grad_input = at::empty_like(input);
  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float weight read as the
    // accumulation type.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  } else {
    // Uniform-precision path: weight must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      batchnorm_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          sum_dy.DATA_PTR<accscalar_t>(),
          sum_dy_xmu.DATA_PTR<accscalar_t>(),
          count.DATA_PTR<int>(),
          grad_input.DATA_PTR<scalar_t_0>(),
          count.numel(),
          reduction_size,
          stride);
    );
  }
  return grad_input;
}
// Channels-last backward through the fused relu(+residual) path: forwards
// grad_output and the saved forward-side tensors (input, optional z, stats,
// optional affine params) to relu_backward_c_last_kernel; returns a tensor
// shaped like `input`.
// Fix vs. original: the dispatch labels were copy-pasted as
// "batchnorm_forward"; corrected to "relu_backward" so dispatch errors point
// at the right op.
at::Tensor relu_backward_c_last_CUDA(
    const at::Tensor grad_output,
    const at::Tensor input,
    const at::optional<at::Tensor> z,
    const at::Tensor mean,
    const at::Tensor inv_std,
    const at::optional<at::Tensor> weight,
    const at::optional<at::Tensor> shift) {
  // Treat input as (reduction_size, stride) with channels innermost.
  const auto stride = input.size(input.ndimension()-1);
  const auto reduction_size = input.numel() / stride;
  at::Tensor out = at::empty_like(input);
  dim3 block;
  dim3 grid;
  flexible_launch_configs(reduction_size, stride, block, grid);
  auto stream = at::cuda::getCurrentCUDAStream();
  if (input.scalar_type() == at::ScalarType::Half
      && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
    // Mixed-precision path: half input with float affine params.
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<accscalar_t>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<accscalar_t>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  } else {
    // Uniform-precision path: affine params must match the input dtype.
    if (weight.has_value()) {
      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
    DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "relu_backward",
      using accscalar_t = at::acc_type<scalar_t_0, true>;
      relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
        <<<grid, block, 0, stream>>>(
          grad_output.DATA_PTR<scalar_t_0>(),
          input.DATA_PTR<scalar_t_0>(),
          z.has_value() ? z.value().DATA_PTR<scalar_t_0>() : NULL,
          mean.DATA_PTR<accscalar_t>(),
          inv_std.DATA_PTR<accscalar_t>(),
          weight.has_value() ? weight.value().DATA_PTR<scalar_t_0>() : NULL,
          shift.has_value() ? shift.value().DATA_PTR<scalar_t_0>() : NULL,
          out.DATA_PTR<scalar_t_0>(),
          reduction_size,
          stride);
    );
  }
  return out;
}