// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// CUDA kernels that build (input, output) index pairs ("indice pairs") for
// sparse convolution, plus kernels that scatter into / reset a dense lookup
// grid (`gridsOut`).
//
// Conventions visible in the code below (inferred from the indexing; confirm
// against the host-side callers):
//   * An index row has NDim + 1 entries. Column 0 is multiplied by the total
//     output spatial volume when linearising; the remaining NDim columns are
//     spatial coordinates. Column 0 is presumably the batch index.
//   * A "linear index" is
//       tv::rowArrayIdx(coords, outSpatialShape) + spatialVolume * row[0],
//     where spatialVolume is the product of the NDim entries of
//     outSpatialShape.
//   * indicePairs is addressed as (offset, side, slot). Side 0 holds an input
//     row number. Side 1 holds either a linear output index, or a gridsOut
//     value (see the individual kernels). Dim 0 is called kernelVolume in
//     assignIndicePairsKernel, so `offset` is presumably a kernel-offset id.
//   * A gridsOut value of -1 means "empty": the reset kernels write -1, and
//     the SubM kernel only uses entries > -1.
//
// NOTE(review): in this copy of the file, all `<...>` text appears to have
// been stripped. This covers the #include targets, every `template <...>`
// parameter list, and the explicit template arguments (e.g. on
// tv::TensorView, tv::SimpleVector, tv::KernelLoopX, getValidOutPos,
// tv::rowArrayIdx). The bodies use Index, NDim and KernelMaxVolume, which are
// not declared anywhere in the text as shown. Restore from upstream before
// building.
#ifndef INDICE_CU_H_
#define INDICE_CU_H_
#include
#include
#include

// For every active input row ix of indicesIn, collects candidate output
// positions and appends one pair per position.
//
// getValidOutPos (defined elsewhere) fills validPoints with numValidPoints
// records of NDim + 1 values each: NDim output coordinates followed by an
// `offset` value. For each record:
//   slot = atomicAdd(&indiceNum[offset], 1)
//   indicePairs(offset, 0, slot)                  = ix
//   indicePairs(offset, 1, slot)                  = linear output index
//   indicePairUnique[offset * dim(2) + slot]      = same linear output index
// If the caller zeroes indiceNum first, indiceNum[offset] ends up holding the
// number of pairs written for that offset.
//
// Iteration: ix covers [0, indicesIn.dim(0)) via tv::KernelLoopX, presumably
// a grid-stride range helper from tensorview.
//
// NOTE(review): the kernelVolume local is computed but never read, and
// indicesOut / gridsOut are not used by this body.
// NOTE(review): validPoints has room for KernelMaxVolume records, and `slot`
// is not checked against indicePairs.dim(2). This kernel relies on its
// inputs to keep both in range.
template __global__ void prepareIndicePairsKernel(
    tv::TensorView indicesIn, tv::TensorView indicesOut,
    tv::TensorView gridsOut, tv::TensorView indicePairs,
    tv::TensorView indiceNum, tv::TensorView indicePairUnique,
    const tv::SimpleVector kernelSize, const tv::SimpleVector stride,
    const tv::SimpleVector padding, const tv::SimpleVector dilation,
    const tv::SimpleVector outSpatialShape) {
  auto numActIn = indicesIn.dim(0);
  // Number of cells in one batch item's output grid; used as the stride of
  // column 0 when linearising.
  Index spatialVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    spatialVolume *= outSpatialShape[i];
  }
  // Dead: kernelVolume is not read below.
  Index kernelVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    kernelVolume *= kernelSize[i];
  }
  Index numValidPoints = 0;
  // Per-thread scratch: up to KernelMaxVolume records of (NDim coords,
  // offset).
  Index validPoints[KernelMaxVolume * (NDim + 1)];
  Index *pointPtr = nullptr;
  // Row length used to flatten (offset, slot) into indicePairUnique.
  auto indicePairsDim2 = indicePairs.dim(2);
  Index index;
  for (int ix : tv::KernelLoopX(numActIn)) {
    // Skip column 0 of row ix; pass only its NDim spatial coordinates.
    numValidPoints = getValidOutPos(
        indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
        stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
        validPoints);
    for (Index i = 0; i < numValidPoints; ++i) {
      pointPtr = validPoints + i * (NDim + 1);
      auto offset = pointPtr[NDim];
      // Reserve a unique slot for this offset; other threads may append to
      // the same offset concurrently.
      auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
      indicePairs(offset, 0, oldNum) = ix;
      index = tv::rowArrayIdx(pointPtr, outSpatialShape.data()) +
              spatialVolume * indicesIn(ix, 0);
      indicePairs(offset, 1, oldNum) = index;
      indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
    }
  }
}

// Same as prepareIndicePairsKernel, except that candidate positions come from
// getValidOutPosTranspose (defined elsewhere; presumably the transposed /
// deconvolution mapping). Writes the same three outputs with the same layout.
//
// NOTE(review): kernelVolume is computed but never read, and indicesOut /
// gridsOut are not used by this body.
template __global__ void prepareDeConvIndicePairsKernel(
    tv::TensorView indicesIn, tv::TensorView indicesOut,
    tv::TensorView gridsOut, tv::TensorView indicePairs,
    tv::TensorView indiceNum, tv::TensorView indicePairUnique,
    const tv::SimpleVector kernelSize, const tv::SimpleVector stride,
    const tv::SimpleVector padding, const tv::SimpleVector dilation,
    const tv::SimpleVector outSpatialShape) {
  auto numActIn = indicesIn.dim(0);
  Index spatialVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    spatialVolume *= outSpatialShape[i];
  }
  // Dead: kernelVolume is not read below.
  Index kernelVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    kernelVolume *= kernelSize[i];
  }
  Index numValidPoints = 0;
  Index validPoints[KernelMaxVolume * (NDim + 1)];
  Index *pointPtr = nullptr;
  auto indicePairsDim2 = indicePairs.dim(2);
  Index index;
  for (int ix : tv::KernelLoopX(numActIn)) {
    numValidPoints = getValidOutPosTranspose(
        indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
        stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
        validPoints);
    for (Index i = 0; i < numValidPoints; ++i) {
      pointPtr = validPoints + i * (NDim + 1);
      auto offset = pointPtr[NDim];
      auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
      indicePairs(offset, 0, oldNum) = ix;
      index = tv::rowArrayIdx(pointPtr, outSpatialShape.data()) +
              spatialVolume * indicesIn(ix, 0);
      indicePairs(offset, 1, oldNum) = index;
      indicePairUnique[offset * indicePairsDim2 + oldNum] = index;
    }
  }
}

// Materialises the output index list and the dense grid, for ix in
// [0, numAct):
//   * index = indicePairUnique[ix] (a linear output index);
//   * gridsOut[index] = ix, so the grid maps that location to output row ix;
//   * tv::rowArrayIdxInv (defined elsewhere) receives the linear index and a
//     pointer to columns 1..NDim of indicesOut row ix. Its return value,
//     taken modulo batchSize, is stored in column 0 of that row.
//
// Presumably the caller has sorted/deduplicated indicePairUnique so that its
// first numAct entries are distinct locations. Otherwise several threads
// would write the same gridsOut cell. Confirm at the call site.
// indicePairs is not used by this body.
template __global__ void assignGridAndIndiceOutKernel(
    tv::TensorView indicesOut, tv::TensorView gridsOut, int numAct,
    tv::TensorView indicePairs, tv::TensorView indicePairUnique,
    const tv::SimpleVector outSpatialShape, int batchSize) {
  Index index;
  auto indicesOutPtr = indicesOut.data();
  for (int ix : tv::KernelLoopX(numAct)) {
    index = indicePairUnique[ix];
    gridsOut[index] = ix;
    // Write the NDim coordinates into columns 1..NDim of row ix; keep the
    // leftover value for column 0.
    index = tv::rowArrayIdxInv(
        index, indicesOutPtr + ix * (NDim + 1) + 1, outSpatialShape.data());
    indicesOut[ix * (NDim + 1)] = index % batchSize;
  }
}

// Rewrites side 1 of indicePairs in place. For ix in [0, numActIn) and
// i in [0, indicePairs.dim(0)), an entry indicePairs(i, 1, ix) that is > -1
// is treated as a linear output index and replaced by gridsOut[that index].
// Presumably that is the output row number stored by
// assignGridAndIndiceOutKernel. Entries <= -1 (presumably unused slots
// initialised by the caller) are left untouched.
//
// indicesOut, indicePairUnique and outSpatialShape are not used by this body.
template __global__ void assignIndicePairsKernel(
    tv::TensorView indicesOut, tv::TensorView gridsOut, int numActIn,
    tv::TensorView indicePairs, tv::TensorView indicePairUnique,
    const tv::SimpleVector outSpatialShape) {
  Index index;
  int kernelVolume = indicePairs.dim(0);
  for (int ix : tv::KernelLoopX(numActIn)) {
    for (int i = 0; i < kernelVolume; ++i) {
      index = indicePairs(i, 1, ix);
      if (index > -1) {
        indicePairs(i, 1, ix) = gridsOut[index];
      }
    }
  }
}

// SubM path (presumably submanifold convolution): scatters every input row
// number into the dense grid. For ix in [0, indicesIn.dim(0)) it sets
//   gridsOut[linear index of indicesIn row ix] = ix.
template __global__ void prepareSubMGridKernel(
    tv::TensorView indicesIn, tv::TensorView gridsOut,
    const tv::SimpleVector outSpatialShape) {
  auto numActIn = indicesIn.dim(0);
  Index spatialVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    spatialVolume *= outSpatialShape[i];
  }
  Index index = 0;
  for (int ix : tv::KernelLoopX(numActIn)) {
    index = tv::rowArrayIdx(indicesIn.data() + ix * (NDim + 1) + 1,
                            outSpatialShape.data()) +
            spatialVolume * indicesIn(ix, 0);
    gridsOut[index] = ix;
  }
}

// SubM pair builder. For every input row ix, it collects candidate positions
// with getValidOutPos and keeps only those whose gridsOut entry is > -1,
// i.e. occupied (presumably by prepareSubMGridKernel). For each kept
// position it reserves a slot with atomicAdd on indiceNum[offset] and writes
//   indicePairs(offset, 0, slot) = ix
//   indicePairs(offset, 1, slot) = gridsOut[index]
// Unlike prepareIndicePairsKernel, side 1 receives the grid value directly,
// not a linear index, and no indicePairUnique entry is produced.
//
// NOTE(review): validPoints has room for KernelMaxVolume records; this kernel
// does not itself check that bound.
template __global__ void getSubMIndicePairsKernel(
    tv::TensorView indicesIn, tv::TensorView gridsOut,
    tv::TensorView indicePairs, tv::TensorView indiceNum,
    const tv::SimpleVector kernelSize, const tv::SimpleVector stride,
    const tv::SimpleVector padding, const tv::SimpleVector dilation,
    const tv::SimpleVector outSpatialShape) {
  auto numActIn = indicesIn.dim(0);
  Index spatialVolume = 1;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    spatialVolume *= outSpatialShape[i];
  }
  Index numValidPoints = 0;
  Index validPoints[KernelMaxVolume * (NDim + 1)];
  Index *pointPtr = nullptr;
  Index index = 0;
  for (int ix : tv::KernelLoopX(numActIn)) {
    numValidPoints = getValidOutPos(
        indicesIn.data() + ix * (NDim + 1) + 1, kernelSize.data(),
        stride.data(), padding.data(), dilation.data(), outSpatialShape.data(),
        validPoints);
    for (int i = 0; i < numValidPoints; ++i) {
      pointPtr = validPoints + i * (NDim + 1);
      auto offset = pointPtr[NDim];
      index = tv::rowArrayIdx(pointPtr, outSpatialShape.data()) +
              spatialVolume * indicesIn(ix, 0);
      // Only pair with positions that hold an active entry (-1 = empty).
      if (gridsOut[index] > -1) {
        auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
        indicePairs(offset, 1, oldNum) = gridsOut[index];
        indicePairs(offset, 0, oldNum) = ix;
      }
    }
  }
}

// Clears grid entries back to -1: gridsOut[indicePairUnique[ix]] = -1 for ix
// in [0, numAct). These are the same cells that assignGridAndIndiceOutKernel
// writes when given the same indicePairUnique and numAct.
template __global__ void resetGridKernel(const Index *indicePairUnique,
                                         tv::TensorView gridsOut,
                                         int numAct) {
  for (int ix : tv::KernelLoopX(numAct)) {
    gridsOut[indicePairUnique[ix]] = -1;
  }
}

// Clears grid entries back to -1 for numAct index rows read from the raw
// pointer `indices` (row stride NDim + 1). It recomputes each row's linear
// index and sets that gridsOut cell to -1. These are the cells
// prepareSubMGridKernel writes when given the same rows.
//
// The shape is first copied into the local array outSpatialShapeReg
// (declared `int`, whereas other kernels pass outSpatialShape.data()).
// tv::rowArrayIdx receives that copy; spatialVolume is still computed from
// outSpatialShape.
template __global__ void resetGridSubMKernel(
    const Index *indices, tv::TensorView gridsOut,
    const tv::SimpleVector outSpatialShape, int numAct) {
  int outSpatialShapeReg[NDim];
  for (int i = 0; i < NDim; ++i) {
    outSpatialShapeReg[i] = outSpatialShape[i];
  }
  Index spatialVolume = 1;
  auto indsPtr = indices;
#pragma unroll
  for (int i = 0; i < NDim; ++i) {
    spatialVolume *= outSpatialShape[i];
  }
  Index index;
  for (int ix : tv::KernelLoopX(numAct)) {
    indsPtr = indices + ix * (NDim + 1);
    // Spatial part from columns 1..NDim; column 0 scaled by spatialVolume.
    index = tv::rowArrayIdx(indsPtr + 1, outSpatialShapeReg);
    gridsOut[index + spatialVolume * indsPtr[0]] = -1;
  }
}

#endif  // INDICE_CU_H_