// unilm/edgelm/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
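
// Assumed includes (not visible in this extract): the device helpers
// zeroSharedMem, divUp, and load_input_to_shared used below come from
// fairseq's companion CUDA sources; these exact include lines are an
// assumption.
#include "dynamicconv_cuda.cuh"
#include "../cuda_utils.cu"

// Overview (explanatory note, not from the original file): both kernels below
// implement the dynamic convolution of fairseq's dynamicconv layer. For batch
// b, channel c, head h = c / numFiltersInBlock, and time step t, the forward
// pass computes
//   output[b][c][t] = sum_k weight[b][h][k][t] * input[b][c][t + k - padding_l]
// with zero padding outside [0, sequenceLength). Input and output are laid
// out as [B, C, T] and weight as [B, H, K, T], as the index arithmetic shows.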
// FS is the filter size; the kernels are specialized (templated) per filter
// size, chunk size SB, and left padding.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* output) {
  assert(blockDim.x == SB);

  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int head = featureIdx / numFiltersInBlock;

  const int IOOffset =
      batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength;
  const scalar_t* inputFeature = &input[IOOffset];
  scalar_t* outputFeature = &output[IOOffset];

  scalar_t filter[FS];

  __shared__ scalar_t tempInput[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  // process the sequence in chunks of SB time steps, staging each chunk plus
  // FS entries of halo/padding in shared memory
  const int numIterations = divUp<int, int>(sequenceLength, SB);

  for (int i = 0; i < numIterations; ++i) {
    __syncthreads();
    const int inputOffset = i * SB;
    load_input_to_shared<FS, SB, padding_l>(
        inputFeature,
        inputOffset,
        sequenceLength,
        i,
        numIterations,
        false,
        tempInput);
    __syncthreads();
    if (inputOffset + tid < sequenceLength) {
      // each thread gathers the FS time-varying filter taps for its own
      // output position (i * SB + tid)
      for (int k = 0; k < FS; ++k) {
        const int filterOffset = batchIdx * numHeads * FS * sequenceLength +
            head * FS * sequenceLength + k * sequenceLength + i * SB + tid;
        filter[k] = weight[filterOffset];
      }

      scalar_t out = scalar_t(0.0);
      for (int k = 0; k < FS; ++k) {
        out += filter[k] * tempInput[tid + k];
      }

      outputFeature[inputOffset + tid] = out;
    }
  }
}
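
// Illustrative reference (an addition, not part of the original fairseq file):
// a plain host-side loop nest computing the same result as
// dynamicconv_forward_kernel above, under the assumed layouts
// input/output = [B, C, T] and weight = [B, H, K, T]. The function name is
// hypothetical; a sketch like this is mainly useful as a unit-test oracle.
template <typename scalar_t>
void dynamicconv_forward_reference(
    const scalar_t* input, // B * C * T
    const scalar_t* weight, // B * H * K * T
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    int filterSize, // K, the FS template parameter of the kernels
    int paddingL, // the padding_l template parameter of the kernels
    scalar_t* output) { // B * C * T
  for (int b = 0; b < minibatch; ++b) {
    for (int c = 0; c < numFeatures; ++c) {
      const int h = c / numFiltersInBlock;
      for (int t = 0; t < sequenceLength; ++t) {
        scalar_t out = scalar_t(0.0);
        for (int k = 0; k < filterSize; ++k) {
          // tempInput[tid + k] in the kernel corresponds to input position
          // t + k - padding_l; out-of-range positions act as zero padding
          const int src = t + k - paddingL;
          if (src >= 0 && src < sequenceLength) {
            out +=
                weight[((b * numHeads + h) * filterSize + k) * sequenceLength + t] *
                input[(b * numFeatures + c) * sequenceLength + src];
          }
        }
        output[(b * numFeatures + c) * sequenceLength + t] = out;
      }
    }
  }
}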

template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_backward_kernel(
    const scalar_t* gradOutput, // B * C * T
    const scalar_t* input, // B * C * T
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* gradWeight, // B * H * k * T
    scalar_t* gradInput) { // B * C * T
  assert(blockDim.x == SB);

  // each block operates on a single batch, filter head, and sequence chunk
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int headIdx = blockIdx.y;
  const int chunkIdx = blockIdx.z;

  const int numChunks = divUp<int, int>(sequenceLength, SB);
  const int inputOffset = chunkIdx * SB;

  // initialize shared memory for output gradient and input
  __shared__ scalar_t tempGradOutput[SB + FS];
  __shared__ scalar_t tempInput[SB + FS];
  const int padding = FS - padding_l - 1;

  zeroSharedMem<FS, SB, padding>(tempGradOutput);
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  // initialize local filter and weight gradient sum arrays
  scalar_t tempGradSum[FS];
  scalar_t bfilter[FS];
  for (int k = 0; k < FS; ++k) {
    tempGradSum[k] = scalar_t(0.0);

    int idxOffset = inputOffset + tid + k - padding;
    if (idxOffset >= 0 && idxOffset < sequenceLength) {
      int bfilterOffset = batchIdx * numHeads * FS * sequenceLength +
          headIdx * FS * sequenceLength + (FS - k - 1) * sequenceLength +
          idxOffset;
      bfilter[k] = weight[bfilterOffset];
    } else {
      bfilter[k] = scalar_t(0.0);
    }
  }
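
  // Note (explanatory addition): bfilter holds the time-reversed filter taps
  // weight[b][headIdx][FS - k - 1][idxOffset]. The input gradient of a
  // convolution is a correlation of the output gradient with the flipped
  // filter; because the filter varies with the time step, each tap is read at
  // the output-gradient position idxOffset that it multiplies below, and taps
  // falling outside [0, sequenceLength) are zeroed.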

  // iterate over filter block
  for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) {
    __syncthreads();

    // load input and output gradient for this channel and chunk
    const int IOOffset = batchIdx * numFeatures * sequenceLength +
        (headIdx * numFiltersInBlock + featureIdx) * sequenceLength;
    const scalar_t* inputFeature = &input[IOOffset];
    const scalar_t* gradOutputFeature = &gradOutput[IOOffset];
    scalar_t* gradInputFeature = &gradInput[IOOffset];

    load_input_to_shared<FS, SB, padding>(
        gradOutputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempGradOutput);
    load_input_to_shared<FS, SB, padding_l>(
        inputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempInput);
    __syncthreads();

    // sum input and weight gradients
    scalar_t out = scalar_t(0.0);
    for (int k = 0; k < FS; ++k) {
      tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding];
      out += bfilter[k] * tempGradOutput[tid + k];
    }

    if (inputOffset + tid < sequenceLength) {
      gradInputFeature[inputOffset + tid] = out;
    }
  }
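
  // At this point tempGradSum[k] holds the weight gradient
  // d loss / d weight[batchIdx][headIdx][k][t] for t = inputOffset + tid,
  // accumulated over the numFiltersInBlock channels that share this head
  // (explanatory addition).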
  const int gradOffset =
      batchIdx * numHeads * FS * sequenceLength + headIdx * FS * sequenceLength;
  scalar_t* gradWeightFeature = &gradWeight[gradOffset];

  // write weight gradient
  if (inputOffset + tid < sequenceLength) {
    for (int k = 0; k < FS; ++k) {
      const int outputOffset = k * sequenceLength + inputOffset + tid;
      gradWeightFeature[outputOffset] = tempGradSum[k];
    }
  }
}
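
// Launch sketch (an assumption for illustration; the real dispatch lives in
// the host wrappers such as dynamicconv_cuda_forward, which also select FS and
// SB at runtime, and the function name here is hypothetical). The index
// arithmetic above implies one block per (batch, feature) pair with SB threads
// for the forward kernel, and one block per (batch, head, chunk) triple with
// SB threads for the backward kernel.
template <int FS, int SB, int padding_l, typename scalar_t>
void launch_dynamicconv_forward_example(
    const scalar_t* input,
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* output,
    cudaStream_t stream) {
  // blockIdx.x indexes the batch and blockIdx.y the feature channel;
  // blockDim.x must equal SB (asserted inside the kernel)
  dim3 blocks(minibatch, numFeatures);
  dynamicconv_forward_kernel<FS, SB, padding_l, scalar_t>
      <<<blocks, SB, 0, stream>>>(
          input,
          weight,
          minibatch,
          sequenceLength,
          numFeatures,
          numFiltersInBlock,
          numHeads,
          output);
}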