diff --git "a/apex/apex/contrib/csrc/bottleneck/bottleneck.cpp" "b/apex/apex/contrib/csrc/bottleneck/bottleneck.cpp"
new file mode 100644
--- /dev/null
+++ "b/apex/apex/contrib/csrc/bottleneck/bottleneck.cpp"
@@ -0,0 +1,4073 @@
+#include <ATen/ATen.h>
+#include <ATen/cudnn/Handle.h>  // for getcudnnhandle
+#include <torch/extension.h>
+#include <torch/torch.h>
+#include <vector>
+#include <cudnn_frontend.h>
+
+#include <iostream>
+
+#ifdef DEBUG
+#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
+#else
+#define DEBUG_MSG(str) do { } while ( false )
+#endif
+
+#ifdef DEBUG_CUDNN
+#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
+#else
+#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
+#endif
+
+#define checkCudnnErr(...)                                                        \
+    do {                                                                          \
+        int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
+        if (err) {                                                                \
+            return;                                                               \
+        }                                                                         \
+    } while (0)
+
+
+int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
+    if (code) {
+        printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
+        return 1;
+    }
+    return 0;
+}
+
+void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
+#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); }  // inline regular function
+
+void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort)
+{
+    if (code != cudaSuccess)
+    {
+        const char * errorMessage = cudaGetErrorString(code);
+        fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
+        if (abort){
+            cudaDeviceReset();
+            exit(code);
+        }
+    }
+}
+
+void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
+    // For INT8x4 and INT8x32 we still compute standard strides here to input
+    // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
+    if (filterFormat == CUDNN_TENSOR_NCHW) {
+        strideA[nbDims - 1] = 1;
+        for (int64_t d = nbDims - 2; d >= 0; d--) {
+            strideA[d] = strideA[d + 1] * dimA[d + 1];
+        }
+    } else {
+        // Here we assume that the format is CUDNN_TENSOR_NHWC
+        strideA[1] = 1;
+        strideA[nbDims - 1] = strideA[1] * dimA[1];
+        for (int64_t d = nbDims - 2; d >= 2; d--) {
+            strideA[d] = strideA[d + 1] * dimA[d + 1];
+        }
+        strideA[0] = strideA[2] * dimA[2];
+    }
+}
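+
+// Worked example (values illustrative only): for an NHWC activation with
+// dimA = {1, 64, 56, 56} in n,c,h,w order, the NHWC branch above yields
+// strideA = {200704, 1, 3584, 64}: channels are innermost (stride 1), then
+// w = 64, h = 64*56 = 3584, and n = 3584*56 = 200704.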
+
+
+int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
+    return ((filterDim - 1) * dilation) + 1;
+}
+
+int getFwdConvPaddedImageDim(int tensorDim, int pad) {
+    return tensorDim + (2 * pad);
+}
+
+int getFwdConvOutputDim(
+    int tensorDim,
+    int pad,
+    int filterDim,
+    int stride,
+    int dilation)
+{
+    int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
+    return (p);
+}
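+
+// Example: a 3x3 filter over a 56-wide input with pad=1, stride=1, dilation=1
+// preserves the size: (56 + 2 - 3)/1 + 1 = 56; with stride=2 the integer
+// division gives (56 + 2 - 3)/2 + 1 = 28.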
+
+enum {
+    X_TENSOR,
+    Y_TENSOR,
+    W_TENSOR,
+    Z_TENSOR,
+    B_TENSOR,
+    AFTERADD_TENSOR,
+    AFTERBIAS_TENSOR,
+    AFTERCONV_TENSOR,
+    OPTIONAL,
+    AFTEROPT_TENSOR,
+};
+
+using common_conv_descriptors =
+    std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;
+
+
+common_conv_descriptors
+create_common_descriptors(int64_t* x_dim_padded,
+                          int64_t* padA,
+                          int64_t* convstrideA,
+                          int64_t* dilationA,
+                          int64_t* w_dim_padded,
+                          int64_t* y_dim_padded,
+                          cudnnDataType_t dataType,
+                          cudnnConvolutionMode_t mode) {
+    const int convDim = 2;
+
+    int64_t strideA_padded[4];
+    int64_t outstrideA_padded[4];
+    int64_t filterstrideA_padded[4];
+
+    generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);
+
+    return common_conv_descriptors(cudnn_frontend::TensorBuilder()
+                                       .setDim(4, x_dim_padded)
+                                       .setStrides(4, strideA_padded)
+                                       .setId('x')
+                                       .setAlignment(16)
+                                       .setDataType(dataType)
+                                       .build(),
+                                   cudnn_frontend::TensorBuilder()
+                                       .setDim(4, y_dim_padded)
+                                       .setStrides(4, outstrideA_padded)
+                                       .setId('y')
+                                       .setAlignment(16)
+                                       .setDataType(dataType)
+                                       .build(),
+                                   cudnn_frontend::TensorBuilder()
+                                       .setDim(4, w_dim_padded)
+                                       .setStrides(4, filterstrideA_padded)
+                                       .setId('w')
+                                       .setAlignment(16)
+                                       .setDataType(dataType)
+                                       .build(),
+                                   cudnn_frontend::ConvDescBuilder()
+                                       .setDataType(CUDNN_DATA_FLOAT)
+                                       .setMathMode(mode)
+                                       .setNDims(convDim)
+                                       .setStrides(convDim, convstrideA)
+                                       .setPrePadding(convDim, padA)
+                                       .setPostPadding(convDim, padA)
+                                       .setDilation(convDim, dilationA)
+                                       .build());
+}
+
+using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor>;
+
+common_convbias_descriptors
+create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
+                                     int64_t* padA,
+                                     int64_t* convstrideA,
+                                     int64_t* dilationA,
+                                     int64_t* w_dim_padded,
+                                     int64_t* y_dim_padded,
+                                     cudnnDataType_t dataType) {
+    const int convDim = 2;
+
+    int64_t b_dim_padded[4];
+    b_dim_padded[0] = 1;
+    b_dim_padded[1] = y_dim_padded[1];
+    b_dim_padded[2] = 1;
+    b_dim_padded[3] = 1;
+
+    int64_t x_stride_padded[4];
+    int64_t y_stride_padded[4];
+    int64_t w_stride_padded[4];
+    int64_t b_stride_padded[4];
+
+    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+
+    return common_convbias_descriptors(cudnn_frontend::TensorBuilder()
+                                           .setDim(4, x_dim_padded)
+                                           .setStrides(4, x_stride_padded)
+                                           .setId('x')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('y')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, w_dim_padded)
+                                           .setStrides(4, w_stride_padded)
+                                           .setId('w')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, b_dim_padded)
+                                           .setStrides(4, b_stride_padded)
+                                           .setId('z')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, b_dim_padded)
+                                           .setStrides(4, b_stride_padded)
+                                           .setId('b')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setVirtual()
+                                           .setId('A')  // after add
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setVirtual()
+                                           .setId('B')  // after bias
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('C')  // after conv
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('i')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('D')  // after optional add
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build());
+}
+
+// tensor descriptors used for dgrad
+enum {
+    X_OR_DX_TENSOR,
+    DY_TENSOR,
+    W_OR_DW_TENSOR,
+    SCALE_TENSOR,
+    RELU_TENSOR,
+    AFTER_DCONV_TENSOR,
+    AFTER_DRELU_TENSOR,
+};
+
+using dconv_descriptors = std::tuple<cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor,
+                                     cudnn_frontend::Tensor>;
+
+dconv_descriptors
+create_dconv_descriptors(int64_t* x_dim_padded,
+                         int64_t* padA,
+                         int64_t* convstrideA,
+                         int64_t* dilationA,
+                         int64_t* w_dim_padded,
+                         int64_t* y_dim_padded,
+                         cudnnDataType_t dataType) {
+    const int convDim = 2;
+
+    int64_t b_dim_padded[4];
+    b_dim_padded[0] = 1;
+    b_dim_padded[1] = x_dim_padded[1];
+    b_dim_padded[2] = 1;
+    b_dim_padded[3] = 1;
+
+    int64_t x_stride_padded[4];
+    int64_t y_stride_padded[4];
+    int64_t w_stride_padded[4];
+    int64_t b_stride_padded[4];
+
+    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+
+    return dconv_descriptors(cudnn_frontend::TensorBuilder()
+                                 .setDim(4, x_dim_padded)
+                                 .setStrides(4, x_stride_padded)
+                                 .setId('x')
+                                 .setAlignment(16)
+                                 .setDataType(dataType)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, y_dim_padded)
+                                 .setStrides(4, y_stride_padded)
+                                 .setId('y')
+                                 .setAlignment(16)
+                                 .setDataType(dataType)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, w_dim_padded)
+                                 .setStrides(4, w_stride_padded)
+                                 .setId('w')
+                                 .setAlignment(16)
+                                 .setDataType(dataType)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, b_dim_padded)
+                                 .setStrides(4, b_stride_padded)
+                                 .setId('s')
+                                 .setAlignment(16)
+                                 .setDataType(dataType)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, x_dim_padded)
+                                 .setStrides(4, x_stride_padded)
+                                 .setId('r')
+                                 .setAlignment(16)
+                                 .setDataType(dataType)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, x_dim_padded)
+                                 .setStrides(4, x_stride_padded)
+                                 .setVirtual()
+                                 .setId('A')  // after dconv
+                                 .setAlignment(16)
+                                 .setDataType(CUDNN_DATA_FLOAT)
+                                 .build(),
+                             cudnn_frontend::TensorBuilder()
+                                 .setDim(4, x_dim_padded)
+                                 .setStrides(4, x_stride_padded)
+                                 .setVirtual()
+                                 .setId('B')  // after drelu
+                                 .setAlignment(16)
+                                 .setDataType(CUDNN_DATA_FLOAT)
+                                 .build());
+}
+
+// create a cache for plan
+std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
+
+// TODO: better name
+std::string getConvFusionString(int64_t* x_dim_padded,
+                                int64_t* padA,
+                                int64_t* convstrideA,
+                                int64_t* dilationA,
+                                int64_t* w_dim_padded,
+                                cudnnDataType_t dataType,
+                                std::string fusion_string) {
+
+    for(int i=0;i<4;i++) {
+        fusion_string += 'X';
+        fusion_string += std::to_string(x_dim_padded[i]);
+    }
+    for(int i=0;i<4;i++) {
+        fusion_string += 'W';
+        fusion_string += std::to_string(w_dim_padded[i]);
+    }
+    for(int i=0;i<2;i++) {
+        fusion_string += 'P';
+        fusion_string += std::to_string(padA[i]);
+    }
+    for(int i=0;i<2;i++) {
+        fusion_string += 'S';
+        fusion_string += std::to_string(convstrideA[i]);
+    }
+    for(int i=0;i<2;i++) {
+        fusion_string += 'D';
+        fusion_string += std::to_string(dilationA[i]);
+    }
+    fusion_string += 'T';
+    fusion_string += std::to_string(dataType);
+    return fusion_string;
+}
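+
+// For illustration (shapes assumed): a 1x1 convolution with x = {1,64,56,56},
+// w = {256,64,1,1}, pad 0, stride 1, dilation 1 appends
+// "X1X64X56X56W256W64W1W1P0P0S1S1D1D1T<dataType>" to the operation graph tag,
+// where <dataType> is the numeric value of the cudnnDataType_t enum.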
+
+cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
+                                               std::stringstream& log_buf,
+                                               cudnn_frontend::OperationGraph& opGraph,
+                                               std::string cache_string,
+                                               bool use_heuristic = true){
+    auto it = plan_cache.find(cache_string);
+    if (it != plan_cache.end()) {
+        DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
+        return it->second;
+    } else {
+        if (use_heuristic){
+            // TODO: confirm which mode to use
+            auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
+                                  .setOperationGraph(opGraph)
+                                  .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
+                                  .build();
+            // try 3 times for now as WAR for no heuristic training
+            int max_tries = 3, count = 0;
+            auto& engine_configs = heuristics.getEngineConfig(max_tries);
+            while(true) {
+                try {
+                    plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
+                                                                   .setHandle(handle_)
+                                                                   .setEngineConfig(engine_configs[count], opGraph.getTag())
+                                                                   .build()));
+                    break;
+                } catch (cudnn_frontend::cudnnException e) {
+                    if (++count == max_tries) throw e;
+                }
+            }
+        }else{
+            DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
+            // How many engines support this operation graph ?
+            auto total_engines = opGraph.getEngineCount();
+            DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
+            // We have to randomly pick one engine from [0, total_engines)
+            // Selecting "0" by default
+            auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
+            DEBUG_CUDNN_MSG(log_buf, engine.describe());
+            auto& knobs = engine.getSupportedKnobs();
+            for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
+                DEBUG_CUDNN_MSG(log_buf, it->describe());
+            }
+            if (knobs.begin() != knobs.end()) {
+                DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
+                knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
+                DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
+            }
+
+            // Create the requisite engine config
+            auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
+            DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
+            plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
+        }
+
+        return plan_cache.find(cache_string)->second;
+    }
+}
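+
+// Call pattern used by all run_* helpers below (sketch):
+//   auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+//   cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+// Only the first call for a given cache_string pays the engine-selection cost;
+// subsequent calls reuse the cached ExecutionPlan.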
+
+void
+run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
+                                   int64_t* pad,
+                                   int64_t* convstride,
+                                   int64_t* dilation,
+                                   int64_t* w_dim_padded,
+                                   int64_t* y_dim_padded,
+                                   cudnnDataType_t dataType,
+                                   at::Half* devPtrX,
+                                   at::Half* devPtrW,
+                                   at::Half* devPtrY,
+                                   at::Half* devPtrZ,
+                                   at::Half* devPtrB,
+                                   at::Half* devPtrI) {
+    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+    std::stringstream log_buf;
+    try {
+        int convDim = 2;
+
+        // Creates the necessary tensor descriptors
+        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
+            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
+
+        // Define the scale operation
+        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
+                             .setMode(CUDNN_POINTWISE_MUL)
+                             .setMathPrecision(CUDNN_DATA_FLOAT)
+                             .build();
+        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
+
+        // Define the bias operation
+        auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
+                            .setMode(CUDNN_POINTWISE_ADD)
+                            .setMathPrecision(CUDNN_DATA_FLOAT)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
+
+        // optional add
+        auto addDesc = cudnn_frontend::PointWiseDescBuilder()
+                           .setMode(CUDNN_POINTWISE_ADD)
+                           .setMathPrecision(CUDNN_DATA_FLOAT)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
+
+        // Define the activation operation
+        auto actDesc = cudnn_frontend::PointWiseDescBuilder()
+                           .setMode(CUDNN_POINTWISE_RELU_FWD)
+                           .setMathPrecision(CUDNN_DATA_FLOAT)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
+
+        // Define the convolution problem
+        auto convDesc = cudnn_frontend::ConvDescBuilder()
+                            .setDataType(CUDNN_DATA_FLOAT)
+                            .setMathMode(CUDNN_CROSS_CORRELATION)
+                            .setNDims(convDim)
+                            .setStrides(convDim, convstride)
+                            .setPrePadding(convDim, pad)
+                            .setPostPadding(convDim, pad)
+                            .setDilation(convDim, dilation)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+
+        // Create a convolution Node
+        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+                           .setxDesc(std::get<X_TENSOR>(tensors))
+                           .setwDesc(std::get<W_TENSOR>(tensors))
+                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
+                           .setcDesc(convDesc)
+                           .setAlpha(alpha)
+                           .setBeta(beta)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+        // Create a Multiplication Node with scaling parameters.
+        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                            .setxDesc(conv_op.getOutputTensor())
+                            .setbDesc(std::get<Z_TENSOR>(tensors))
+                            .setyDesc(std::get<AFTERADD_TENSOR>(tensors))
+                            .setpwDesc(scaleDesc)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
+
+        // Create a Bias Node.
+        auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                           .setxDesc(scale_op.getOutputTensor())
+                           .setbDesc(std::get<B_TENSOR>(tensors))
+                           .setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
+                           .setpwDesc(biasDesc)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
+
+        // Create an optional add Node.
+        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                          .setxDesc(bias_op.getOutputTensor())
+                          .setbDesc(std::get<OPTIONAL>(tensors))
+                          .setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
+                          .setpwDesc(addDesc)
+                          .build();
+        DEBUG_CUDNN_MSG(log_buf, add_op.describe());
+
+
+        // Create an Activation Node.
+        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                          .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
+                          .setyDesc(std::get<Y_TENSOR>(tensors))
+                          .setpwDesc(actDesc)
+                          .build();
+        DEBUG_CUDNN_MSG(log_buf, act_op.describe());
+
+        // Create an Operation Graph. In this case it is convolution scale bias (optional add) activation
+        std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op};
+
+        auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                           .setHandle(handle_)
+                           .setOperationGraph(devPtrI ? ops.size() : 4, ops.data())
+                           .build();
+
+        // Create string encoding for plan caching
+        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+        auto workspace_size = plan.getWorkspaceSize();
+        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+        void* workspace_ptr = nullptr;
+        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+        if (workspace_size > 0) {
+            workspace_ptr = workspace_tensor.data_ptr<float>();
+        }
+        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
+        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b', 'i'};
+        auto variantPack = cudnn_frontend::VariantPackBuilder()
+                               .setWorkspacePointer(workspace_ptr)
+                               .setDataPointers(devPtrI ? 6 : 5, data_ptrs)
+                               .setUids(devPtrI ? 6 : 5, uids)
+                               .build();
+        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+        checkCudnnErr(status);
+        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } catch (cudnn_frontend::cudnnException e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
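+
+// Dataflow of the fused graph above (virtual tensors in parentheses):
+//   x --conv(w)--> (C) --mul z--> (A) --add b--> (B) [--add i--> (D)] --relu--> y
+// The residual input i is only wired in when devPtrI != nullptr.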
+
+void
+run_conv_scale_bias(int64_t* x_dim_padded,
+                    int64_t* pad,
+                    int64_t* convstride,
+                    int64_t* dilation,
+                    int64_t* w_dim_padded,
+                    int64_t* y_dim_padded,
+                    cudnnDataType_t dataType,
+                    at::Half* devPtrX,
+                    at::Half* devPtrW,
+                    at::Half* devPtrY,
+                    at::Half* devPtrZ,
+                    at::Half* devPtrB) {
+    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+    std::stringstream log_buf;
+    try {
+        int convDim = 2;
+
+        // Creates the necessary tensor descriptors
+        common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
+            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+        DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
+
+        // Define the scale operation
+        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
+                             .setMode(CUDNN_POINTWISE_MUL)
+                             .setMathPrecision(CUDNN_DATA_FLOAT)
+                             .build();
+        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
+
+        // Define the bias operation
+        auto addDesc = cudnn_frontend::PointWiseDescBuilder()
+                           .setMode(CUDNN_POINTWISE_ADD)
+                           .setMathPrecision(CUDNN_DATA_FLOAT)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
+
+        // Define the convolution problem
+        auto convDesc = cudnn_frontend::ConvDescBuilder()
+                            .setDataType(CUDNN_DATA_FLOAT)
+                            .setMathMode(CUDNN_CROSS_CORRELATION)
+                            .setNDims(convDim)
+                            .setStrides(convDim, convstride)
+                            .setPrePadding(convDim, pad)
+                            .setPostPadding(convDim, pad)
+                            .setDilation(convDim, dilation)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+
+        // Create a convolution Node
+        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
+                           .setxDesc(std::get<X_TENSOR>(tensors))
+                           .setwDesc(std::get<W_TENSOR>(tensors))
+                           .setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
+                           .setcDesc(convDesc)
+                           .setAlpha(alpha)
+                           .setBeta(beta)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+        // Create a Multiplication Node with scaling parameters.
+        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                            .setxDesc(conv_op.getOutputTensor())
+                            .setbDesc(std::get<Z_TENSOR>(tensors))
+                            .setyDesc(std::get<AFTERADD_TENSOR>(tensors))  // TODO: change enum to aftermul
+                            .setpwDesc(scaleDesc)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
+
+        // Create a Bias Node.
+        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                          .setxDesc(scale_op.getOutputTensor())
+                          .setbDesc(std::get<B_TENSOR>(tensors))
+                          .setyDesc(std::get<Y_TENSOR>(tensors))
+                          .setpwDesc(addDesc)
+                          .build();
+        DEBUG_CUDNN_MSG(log_buf, add_op.describe());
+
+        // Create an Operation Graph. In this case it is convolution scale bias
+        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};
+
+        auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                           .setHandle(handle_)
+                           .setOperationGraph(ops.size(), ops.data())
+                           .build();
+
+        // Create string encoding for plan caching
+        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+        auto workspace_size = plan.getWorkspaceSize();
+        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+        void* workspace_ptr = nullptr;
+        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+        if (workspace_size > 0) {
+            workspace_ptr = workspace_tensor.data_ptr<float>();
+        }
+        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};
+        int64_t uids[]    = {'x', 'y', 'w', 'z', 'b'};
+        auto variantPack = cudnn_frontend::VariantPackBuilder()
+                               .setWorkspacePointer(workspace_ptr)
+                               .setDataPointers(5, data_ptrs)
+                               .setUids(5, uids)
+                               .build();
+        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+        checkCudnnErr(status);
+        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } catch (cudnn_frontend::cudnnException e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
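+
+// Dataflow: x --conv(w)--> (C) --mul z--> (A) --add b--> y. No activation here;
+// this variant serves the downsample path in bottleneck_forward below.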
+
+
+void
+run_dconv_drelu_dscale(int64_t* x_dim_padded,
+                       int64_t* pad,
+                       int64_t* convstride,
+                       int64_t* dilation,
+                       int64_t* w_dim_padded,
+                       int64_t* y_dim_padded,
+                       cudnnDataType_t dataType,
+                       at::Half* devPtrX,
+                       at::Half* devPtrW,
+                       at::Half* devPtrY,
+                       at::Half* devPtrZ,
+                       at::Half* devPtrR) {
+    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+    std::stringstream log_buf;
+    try {
+        int convDim = 2;
+
+        // Creates the necessary tensor descriptors
+        dconv_descriptors tensors = create_dconv_descriptors(
+            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
+
+        // Define the convolution problem
+        auto convDesc = cudnn_frontend::ConvDescBuilder()
+                            .setDataType(CUDNN_DATA_FLOAT)
+                            .setMathMode(CUDNN_CROSS_CORRELATION)
+                            .setNDims(convDim)
+                            .setStrides(convDim, convstride)
+                            .setPrePadding(convDim, pad)
+                            .setPostPadding(convDim, pad)
+                            .setDilation(convDim, dilation)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+        // Define the activation backward operation
+        auto actDesc = cudnn_frontend::PointWiseDescBuilder()
+                           .setMode(CUDNN_POINTWISE_RELU_BWD)
+                           .setMathPrecision(CUDNN_DATA_FLOAT)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
+
+        // Define the scale backward operation
+        auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
+                             .setMode(CUDNN_POINTWISE_MUL)
+                             .setMathPrecision(CUDNN_DATA_FLOAT)
+                             .build();
+        DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+
+        // Create a convolution Node
+        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
+                           .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
+                           .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
+                           .setdyDesc(std::get<DY_TENSOR>(tensors))
+                           .setcDesc(convDesc)
+                           .setAlpha(alpha)
+                           .setBeta(beta)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+        // TODO: do we need getOutputTensor(), and what it returns in backward case?
+        // Create a relu backward Node.
+        auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                          .setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
+                          .setxDesc(std::get<RELU_TENSOR>(tensors))
+                          .setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
+                          .setpwDesc(actDesc)
+                          .build();
+        DEBUG_CUDNN_MSG(log_buf, act_op.describe());
+
+        // Create a Scale Node.
+        auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                            .setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
+                            .setbDesc(std::get<SCALE_TENSOR>(tensors))
+                            .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
+                            .setpwDesc(scaleDesc)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
+
+        // Create an Operation Graph. In this case it is dgrad drelu dscale
+        std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};
+
+        auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                           .setHandle(handle_)
+                           .setOperationGraph(ops.size(), ops.data())
+                           .build();
+
+        // Create string encoding for plan caching
+        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+        auto workspace_size = plan.getWorkspaceSize();
+        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+        void* workspace_ptr = nullptr;
+        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+        if (workspace_size > 0) {
+            workspace_ptr = workspace_tensor.data_ptr<float>();
+        }
+        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};
+        int64_t uids[]    = {'x', 'y', 'w', 's', 'r'};
+        auto variantPack = cudnn_frontend::VariantPackBuilder()
+                               .setWorkspacePointer(workspace_ptr)
+                               .setDataPointers(5, data_ptrs)
+                               .setUids(5, uids)
+                               .build();
+        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+        checkCudnnErr(status);
+        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } catch (cudnn_frontend::cudnnException e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
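+
+// Dataflow of the fused dgrad above: dy --dgrad(w)--> (A) --drelu vs. r--> (B)
+// --mul s--> dx, where r is the saved forward relu output and s the per-channel scale.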
+
+void
+run_dconv(int64_t* x_dim_padded,
+          int64_t* pad,
+          int64_t* convstride,
+          int64_t* dilation,
+          int64_t* w_dim_padded,
+          int64_t* y_dim_padded,
+          cudnnDataType_t dataType,
+          at::Half* devPtrX,
+          at::Half* devPtrW,
+          at::Half* devPtrY,
+          cudnnBackendDescriptorType_t mode) {
+    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+    std::stringstream log_buf;
+    try {
+        int convDim = 2;
+
+        // Creates the necessary tensor descriptors
+        dconv_descriptors tensors = create_dconv_descriptors(
+            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
+
+        // Define the convolution problem
+        auto convDesc = cudnn_frontend::ConvDescBuilder()
+                            .setDataType(CUDNN_DATA_FLOAT)
+                            .setMathMode(CUDNN_CROSS_CORRELATION)
+                            .setNDims(convDim)
+                            .setStrides(convDim, convstride)
+                            .setPrePadding(convDim, pad)
+                            .setPostPadding(convDim, pad)
+                            .setDilation(convDim, dilation)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+
+        // Create a convolution Node
+        // mode should be one of following
+        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
+        // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
+        auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
+        if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
+            conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))
+                .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
+                .setdyDesc(std::get<DY_TENSOR>(tensors))
+                .setcDesc(convDesc)
+                .setAlpha(alpha)
+                .setBeta(beta);
+        }
+        else {
+            conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))
+                .setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))
+                .setdyDesc(std::get<DY_TENSOR>(tensors))
+                .setcDesc(convDesc)
+                .setAlpha(alpha)
+                .setBeta(beta);
+        }
+        auto conv_op = conv_op_builder.build();
+        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+        // Create an Operation Graph. In this case it is a single convolution gradient
+        std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
+
+        auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                           .setHandle(handle_)
+                           .setOperationGraph(ops.size(), ops.data())
+                           .build();
+
+        // Create string encoding for plan caching
+        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+        auto workspace_size = plan.getWorkspaceSize();
+        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+        void* workspace_ptr = nullptr;
+        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+        if (workspace_size > 0) {
+            workspace_ptr = workspace_tensor.data_ptr<float>();
+        }
+        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};
+        int64_t uids[]    = {'x', 'y', 'w'};
+        auto variantPack = cudnn_frontend::VariantPackBuilder()
+                               .setWorkspacePointer(workspace_ptr)
+                               .setDataPointers(3, data_ptrs)
+                               .setUids(3, uids)
+                               .build();
+        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+        checkCudnnErr(status);
+        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } catch (cudnn_frontend::cudnnException e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
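+
+// Both unfused backward passes funnel through this helper; e.g. a
+// weight-gradient call (sketch of the pattern used in bottleneck_backward below):
+//   run_dconv(x_dim, pad, stride, dilation, w_dim, y_dim, CUDNN_DATA_HALF,
+//             x, dw, dy, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);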
+
+void
+run_dconv_add(int64_t* x_dim_padded,
+              int64_t* pad,
+              int64_t* convstride,
+              int64_t* dilation,
+              int64_t* w_dim_padded,
+              int64_t* y_dim_padded,
+              cudnnDataType_t dataType,
+              at::Half* devPtrX,
+              at::Half* devPtrW,
+              at::Half* devPtrY,
+              at::Half* devPtrR) {
+    cudnnHandle_t handle_ = torch::native::getCudnnHandle();
+    std::stringstream log_buf;
+    try {
+        int convDim = 2;
+
+        // Creates the necessary tensor descriptors
+        dconv_descriptors tensors = create_dconv_descriptors(
+            x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
+        DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
+        DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
+
+        // Define the convolution problem
+        auto convDesc = cudnn_frontend::ConvDescBuilder()
+                            .setDataType(CUDNN_DATA_FLOAT)
+                            .setMathMode(CUDNN_CROSS_CORRELATION)
+                            .setNDims(convDim)
+                            .setStrides(convDim, convstride)
+                            .setPrePadding(convDim, pad)
+                            .setPostPadding(convDim, pad)
+                            .setDilation(convDim, dilation)
+                            .build();
+        DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
+
+        // Define the add backward operation
+        auto addDesc = cudnn_frontend::PointWiseDescBuilder()
+                           .setMode(CUDNN_POINTWISE_ADD)
+                           .setMathPrecision(CUDNN_DATA_FLOAT)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
+
+        float alpha = 1.0f;
+        float beta = 0.0f;
+
+        // Create a convolution Node
+        auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
+                           .setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
+                           .setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
+                           .setdyDesc(std::get<DY_TENSOR>(tensors))
+                           .setcDesc(convDesc)
+                           .setAlpha(alpha)
+                           .setBeta(beta)
+                           .build();
+        DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
+
+        // TODO: do we need getOutputTensor(), and what it returns in backward case?
+        // Create add Node.
+        auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
+                          .setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
+                          .setbDesc(std::get<RELU_TENSOR>(tensors))
+                          .setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
+                          .setpwDesc(addDesc)
+                          .build();
+        DEBUG_CUDNN_MSG(log_buf, add_op.describe());
+
+        // Create an Operation Graph. In this case it is dgrad add
+        std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};
+
+        auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                           .setHandle(handle_)
+                           .setOperationGraph(ops.size(), ops.data())
+                           .build();
+
+        // Create string encoding for plan caching
+        auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+        DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+        auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+        DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+        auto workspace_size = plan.getWorkspaceSize();
+        DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+        void* workspace_ptr = nullptr;
+        auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+        if (workspace_size > 0) {
+            workspace_ptr = workspace_tensor.data_ptr<float>();
+        }
+        void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};
+        int64_t uids[]    = {'x', 'y', 'w', 'r'};
+        auto variantPack = cudnn_frontend::VariantPackBuilder()
+                               .setWorkspacePointer(workspace_ptr)
+                               .setDataPointers(4, data_ptrs)
+                               .setUids(4, uids)
+                               .build();
+        DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+        cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+        checkCudnnErr(status);
+        cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+    } catch (cudnn_frontend::cudnnException e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
+
+
+// inputs contains x,w,z,b,(i)
+std::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
+
+    std::cout << std::fixed;
+    // create output vector
+    std::vector<at::Tensor> outputs;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // setup dimensions
+    int64_t dimA[]        = {0, 0, 0, 0};
+    int64_t filterdimA1[] = {0, 0, 0, 0};
+    int64_t filterdimA2[] = {0, 0, 0, 0};
+    int64_t filterdimA3[] = {0, 0, 0, 0};
+    int64_t filterdimA4[] = {0, 0, 0, 0};
+
+    // All dim calculation after this order of n,c,h,w
+    int axis[] {0,1,2,3};
+    if (explicit_nhwc) {
+        axis[0] = 0;
+        axis[1] = 3;
+        axis[2] = 1;
+        axis[3] = 2;
+    }
+    for (int dim=0;dim<4;dim++) {
+        dimA[dim]        = inputs[0].size(axis[dim]);
+        filterdimA1[dim] = inputs[1].size(axis[dim]);
+        filterdimA2[dim] = inputs[2].size(axis[dim]);
+        filterdimA3[dim] = inputs[3].size(axis[dim]);
+    }
+    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
+        for (int dim=0;dim<4;dim++) {
+            filterdimA4[dim] = inputs[10].size(axis[dim]);
+        }
+    }
+
+    // output dim in n,c,h,w used by backend
+    int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
+    int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
+    int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below
+
+    // use these fixed values for test run
+    int64_t padA[]          = {0, 0};
+    int64_t padA1[]         = {1, 1};
+    int64_t dilationA[]     = {1, 1};
+    int64_t convstrideA[]   = {1, 1};
+    int64_t convstride1X1[] = {stride_1X1, stride_1X1};
+
+    // compute output from pad/stride/dilation
+    outdimA1[0] = dimA[0];
+    outdimA1[1] = filterdimA1[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
+    }
+
+    outdimA2[0] = outdimA1[0];
+    outdimA2[1] = filterdimA2[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
+    }
+
+    outdimA3[0] = outdimA2[0];
+    outdimA3[1] = filterdimA3[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
+    }
+
+    // Create output tensor in the correct shape in pytorch's view
+    int64_t outdim1[] = {0, 0, 0, 0};
+    int64_t outdim2[] = {0, 0, 0, 0};
+    int64_t outdim3[] = {0, 0, 0, 0};
+    if (explicit_nhwc) {
+        axis[0] = 0;
+        axis[1] = 2;
+        axis[2] = 3;
+        axis[3] = 1;
+    }
+    for (int dim=0;dim<4;dim++) {
+        outdim1[dim] = outdimA1[axis[dim]];
+        outdim2[dim] = outdimA2[axis[dim]];
+        outdim3[dim] = outdimA3[axis[dim]];
+    }
+
+    // run
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+    at::Half* w = inputs[1].data_ptr<at::Half>();
+    at::Half* z = inputs[4].data_ptr<at::Half>();
+    at::Half* b = inputs[7].data_ptr<at::Half>();
+    auto out1 = at::empty(outdim1, inputs[0].type(), output_format);
+    at::Half* y1 = out1.data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(dimA,
+                                       padA,
+                                       convstride1X1,
+                                       dilationA,
+                                       filterdimA1,
+                                       outdimA1,
+                                       CUDNN_DATA_HALF,
+                                       x,
+                                       w,
+                                       y1,
+                                       z,
+                                       b,
+                                       nullptr);
+
+    DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
+
+    w = inputs[2].data_ptr<at::Half>();
+    z = inputs[5].data_ptr<at::Half>();
+    b = inputs[8].data_ptr<at::Half>();
+    auto out2 = at::empty(outdim2, inputs[0].type(), output_format);
+    at::Half* y2 = out2.data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(outdimA1,
+                                       padA1,
+                                       convstrideA,
+                                       dilationA,
+                                       filterdimA2,
+                                       outdimA2,
+                                       CUDNN_DATA_HALF,
+                                       y1,
+                                       w,
+                                       y2,
+                                       z,
+                                       b,
+                                       nullptr);
+    DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+
+    // create output of conv3
+    auto out3 = at::empty(outdim3, inputs[0].type(), output_format);
+    at::Half* y3 = out3.data_ptr<at::Half>();
+
+    // create output of conv4 that may exist
+    auto identity = at::empty_like(out3);
+    at::Half* yi = identity.data_ptr<at::Half>();
+
+    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
+
+        w = inputs[10].data_ptr<at::Half>();
+        z = inputs[11].data_ptr<at::Half>();
+        b = inputs[12].data_ptr<at::Half>();
+        run_conv_scale_bias(dimA,
+                            padA,
+                            convstride1X1,
+                            dilationA,
+                            filterdimA4,
+                            outdimA3,
+                            CUDNN_DATA_HALF,
+                            x,
+                            w,
+                            yi,
+                            z,
+                            b);
+        DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
+    }
+    else {
+        yi = x;
+    }
+
+    w = inputs[3].data_ptr<at::Half>();
+    z = inputs[6].data_ptr<at::Half>();
+    b = inputs[9].data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(outdimA2,
+                                       padA,
+                                       convstrideA,
+                                       dilationA,
+                                       filterdimA3,
+                                       outdimA3,
+                                       CUDNN_DATA_HALF,
+                                       y2,
+                                       w,
+                                       y3,
+                                       z,
+                                       b,
+                                       yi);
+    DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
+
+    outputs.push_back(out1);
+    outputs.push_back(out2);
+    outputs.push_back(out3);
+
+    return outputs;
+}
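+
+// Layout of `inputs` as used above: [0]=x, [1..3]=conv1..conv3 weights,
+// [4..6]=scales z1..z3, [7..9]=biases b1..b3; when the block downsamples
+// (stride != 1 or channel change), [10]=conv4 weight, [11]=z4, [12]=b4.
+// Returned tensors are {relu1, relu2, relu3}.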
+
+std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
+
+    bool requires_grad = inputs[0].requires_grad();
+
+    std::cout << std::fixed;
+    // create output vector
+    std::vector<at::Tensor> outputs;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // setup dimensions
+    int64_t dimA[]        = {0, 0, 0, 0};
+    int64_t filterdimA1[] = {0, 0, 0, 0};
+    int64_t filterdimA2[] = {0, 0, 0, 0};
+    int64_t filterdimA3[] = {0, 0, 0, 0};
+    int64_t filterdimA4[] = {0, 0, 0, 0};
+
+    // All dim calculation after this order of n,c,h,w
+    int axis[] {0,1,2,3};
+    if (explicit_nhwc) {
+        axis[0] = 0;
+        axis[1] = 3;
+        axis[2] = 1;
+        axis[3] = 2;
+    }
+    for (int dim=0;dim<4;dim++) {
+        dimA[dim]        = inputs[0].size(axis[dim]);
+        filterdimA1[dim] = inputs[1].size(axis[dim]);
+        filterdimA2[dim] = inputs[2].size(axis[dim]);
+        filterdimA3[dim] = inputs[3].size(axis[dim]);
+    }
+    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
+        for (int dim=0;dim<4;dim++) {
+            filterdimA4[dim] = inputs[14].size(axis[dim]);
+        }
+    }
+
+    // output dim in n,c,h,w used by backend
+    int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
+    int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
+    int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below
+
+    // use these fixed values for test run
+    int64_t padA[]          = {0, 0};
+    int64_t padA1[]         = {1, 1};
+    int64_t dilationA[]     = {1, 1};
+    int64_t convstrideA[]   = {1, 1};
+    int64_t convstride1X1[] = {stride_1X1, stride_1X1};
+
+    // compute output from pad/stride/dilation
+    outdimA1[0] = dimA[0];
+    outdimA1[1] = filterdimA1[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
+    }
+
+    outdimA2[0] = outdimA1[0];
+    outdimA2[1] = filterdimA2[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
+    }
+
+    outdimA3[0] = outdimA2[0];
+    outdimA3[1] = filterdimA3[0];
+    for (int dim = 0; dim < 2; dim++) {
+        outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
+    }
+
+    // Create output tensor in the correct shape in pytorch's view
+    int64_t outdim1[] = {0, 0, 0, 0};
+    int64_t outdim2[] = {0, 0, 0, 0};
+    int64_t outdim3[] = {0, 0, 0, 0};
+    if (explicit_nhwc) {
+        axis[0] = 0;
+        axis[1] = 2;
+        axis[2] = 3;
+        axis[3] = 1;
+    }
+    for (int dim=0;dim<4;dim++) {
+        outdim1[dim] = outdimA1[axis[dim]];
+        outdim2[dim] = outdimA2[axis[dim]];
+        outdim3[dim] = outdimA3[axis[dim]];
+    }
+
+    // dconv3+drelu2+dscale2
+    at::Half* conv_in = inputs[13].data_ptr<at::Half>();
+    at::Half* dy3 = inputs[10].data_ptr<at::Half>();
+
+    DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
+
+    // wgrad
+    auto wgrad3 = at::empty_like(inputs[3]);
+    at::Half* dw3 = wgrad3.data_ptr<at::Half>();
+    run_dconv(outdimA2,
+              padA,
+              convstrideA,
+              dilationA,
+              filterdimA3,
+              outdimA3,
+              CUDNN_DATA_HALF,
+              conv_in,
+              dw3,
+              dy3,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+
+    // dgrad
+    auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);
+    at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+    at::Half* w = inputs[3].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+
+    at::Half* relu2 = inputs[13].data_ptr<at::Half>();
+
+    run_dconv_drelu_dscale(outdimA2,
+                           padA,
+                           convstrideA,
+                           dilationA,
+                           filterdimA3,
+                           outdimA3,
+                           CUDNN_DATA_HALF,
+                           dy2,
+                           w,
+                           dy3,
+                           z,
+                           relu2);
+
+    DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
+
+    // dconv2+drelu1+dscale1
+    conv_in = inputs[12].data_ptr<at::Half>();
+
+    // wgrad
+    auto wgrad2 = at::empty_like(inputs[2]);
+    at::Half* dw2 = wgrad2.data_ptr<at::Half>();
+    run_dconv(outdimA1,
+              padA1,
+              convstrideA,
+              dilationA,
+              filterdimA2,
+              outdimA2,
+              CUDNN_DATA_HALF,
+              conv_in,
+              dw2,
+              dy2,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+
+    // dgrad
+    auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);
+    at::Half* dy1 = grad_out1.data_ptr<at::Half>();
+    w = inputs[2].data_ptr<at::Half>();
+    z = inputs[4].data_ptr<at::Half>();
+
+    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
+    // fused dgrad
+    run_dconv_drelu_dscale(outdimA1,
+                           padA1,
+                           convstrideA,
+                           dilationA,
+                           filterdimA2,
+                           outdimA2,
+                           CUDNN_DATA_HALF,
+                           dy1,
+                           w,
+                           dy2,
+                           z,
+                           relu1);
+
+/*
+    // backward strided conv cannot be fused
+    // if stride == 1 but channel changes, we can fuse here
+    if (stride_1X1 != 1){
+        // dgrad
+        run_dconv(outdimA1,
+                  padA1,
+                  convstride1X1,
+                  dilationA,
+                  filterdimA2,
+                  outdimA2,
+                  CUDNN_DATA_HALF,
+                  dy1,
+                  w,
+                  dy2,
+                  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+
+        // mul fused mask
+        grad_out1.mul_(inputs[15]);
+    }
+    else {
+        at::Half* relu1 = inputs[12].data_ptr<at::Half>();
+        // fused dgrad
+        run_dconv_drelu_dscale(outdimA1,
+                               padA1,
+                               convstride1X1,
+                               dilationA,
+                               filterdimA2,
+                               outdimA2,
+                               CUDNN_DATA_HALF,
+                               dy1,
+                               w,
+                               dy2,
+                               z,
+                               relu1);
+    }
+*/
+    DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
+
+    // create grads of conv4 that may exist
+    auto grad_x_conv4 = at::empty_like(inputs[0]);
+    at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
+    at::Tensor wgrad4;
+
+    // x used for dconv1 and dconv4 wgrad
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+
+    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
+        w = inputs[14].data_ptr<at::Half>();
+        at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
+        if (requires_grad) {
+            run_dconv(dimA,
+                      padA,
+                      convstride1X1,
+                      dilationA,
+                      filterdimA4,
+                      outdimA3,
+                      CUDNN_DATA_HALF,
+                      dx_conv4,
+                      w,
+                      dy_conv4,
+                      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+            // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
+            // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
+        }
+        // wgrad
+        wgrad4 = at::empty_like(inputs[14]);
+        at::Half* dw4 = wgrad4.data_ptr<at::Half>();
+        run_dconv(dimA,
+                  padA,
+                  convstride1X1,
+                  dilationA,
+                  filterdimA4,
+                  outdimA3,
+                  CUDNN_DATA_HALF,
+                  x,
+                  dw4,
+                  dy_conv4,
+                  CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+    }
+    else {
+        // if there is no downsample, dx_conv4 is fork of drelu3
+        dx_conv4 = inputs[11].data_ptr<at::Half>();
+    }
+
+    // dconv1+add
+    // wgrad
+    auto wgrad1 = at::empty_like(inputs[1]);
+    at::Half* dw1 = wgrad1.data_ptr<at::Half>();
+    run_dconv(dimA,
+              padA,
+              convstride1X1,
+              dilationA,
+              filterdimA1,
+              outdimA1,
+              CUDNN_DATA_HALF,
+              x,
+              dw1,
+              dy1,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+
+    // dgrad
+    w = inputs[1].data_ptr<at::Half>();
+    auto grad_x = at::empty_like(inputs[0]);
+    at::Half* dx = grad_x.data_ptr<at::Half>();
+
+    // backward strided conv cannot be fused
+    // if stride == 1 but channel changes, we can fuse here
+    if (requires_grad){
+        if (stride_1X1 != 1){
+            run_dconv(dimA,
+                      padA,
+                      convstride1X1,
+                      dilationA,
+                      filterdimA1,
+                      outdimA1,
+                      CUDNN_DATA_HALF,
+                      dx,
+                      w,
+                      dy1,
+                      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+            // add 2 together
+            grad_x.add_(grad_x_conv4);
+        }
+        else {
+            run_dconv_add(dimA,
+                          padA,
+                          convstride1X1,
+                          dilationA,
+                          filterdimA1,
+                          outdimA1,
+                          CUDNN_DATA_HALF,
+                          dx,
+                          w,
+                          dy1,
+                          dx_conv4);
+        }
+    }
+
+    DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
+    DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
+    DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
+    DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
+    outputs.push_back(grad_x);
+    outputs.push_back(wgrad1);
+    outputs.push_back(wgrad2);
+    outputs.push_back(wgrad3);
+
+    if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
+        DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
+        outputs.push_back(wgrad4);
+    }
+
+    return outputs;
+}
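+
+// Layout of backward `inputs` as used above: [0]=x, [1..3]=w1..w3, [4..6]=z1..z3,
+// [10]=dy (grad of relu3), [11]=gradient entering the residual branch, [12]=relu1,
+// [13]=relu2, and [14]=conv4 weight when the block downsamples. Returned grads
+// are {dx, dw1, dw2, dw3} plus dw4 in the downsample case.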
+
+namespace {
+
+enum {
+    X_TENSOR,
+    Y_TENSOR,
+    W_TENSOR,
+    Z_TENSOR,
+    B_TENSOR,
+    AFTERADD_TENSOR,
+    AFTERBIAS_TENSOR,
+    AFTERCONV_TENSOR,
+    OPTIONAL,
+    AFTEROPT_TENSOR,
+    AFTERACT_TENSOR,
+    GEN_INDEX_TENSOR,
+    MASK_TOP_TENSOR,
+    MASK_BOTTOM_TENSOR,
+    MASK_TENSOR,
+    THRESHOLD_TOP_TENSOR,
+    THRESHOLD_BOTTOM_TENSOR,
+};
+
+using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor,
+                                               cudnn_frontend::Tensor>;
+
+masked_convbias_descriptors
+create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded,
+                                          int64_t* padA,
+                                          int64_t* convstrideA,
+                                          int64_t* dilationA,
+                                          int64_t* w_dim_padded,
+                                          int64_t* y_dim_padded,
+                                          int64_t* threshold_dim,
+                                          cudnnDataType_t dataType) {
+    const int convDim = 2;
+
+    int64_t b_dim_padded[4];
+    b_dim_padded[0] = 1;
+    b_dim_padded[1] = y_dim_padded[1];
+    b_dim_padded[2] = 1;
+    b_dim_padded[3] = 1;
+
+    int64_t x_stride_padded[4];
+    int64_t y_stride_padded[4];
+    int64_t w_stride_padded[4];
+    int64_t b_stride_padded[4];
+    int64_t threshold_stride[4];
+
+    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
+
+    return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()
+                                           .setDim(4, x_dim_padded)
+                                           .setStrides(4, x_stride_padded)
+                                           .setId('x')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('y')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, w_dim_padded)
+                                           .setStrides(4, w_stride_padded)
+                                           .setId('w')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, b_dim_padded)
+                                           .setStrides(4, b_stride_padded)
+                                           .setId('z')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, b_dim_padded)
+                                           .setStrides(4, b_stride_padded)
+                                           .setId('b')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setVirtual()
+                                           .setId('A')  // after add
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setVirtual()
+                                           .setId('B')  // after bias
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('C')  // after conv
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('i')
+                                           .setAlignment(16)
+                                           .setDataType(dataType)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('D')  // after optional add
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('E')  // after act for masked
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_FLOAT)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('I')  // output of the gen index operation
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_INT32)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('m')  // top half of the mask created after the less than
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_BOOLEAN)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('n')  // bottom half of the mask
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_BOOLEAN)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, y_dim_padded)
+                                           .setStrides(4, y_stride_padded)
+                                           .setId('M')  // OR of the top and bottom masks
+                                           .setAlignment(16)
+                                           .setVirtual()
+                                           .setDataType(CUDNN_DATA_BOOLEAN)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, threshold_dim)
+                                           .setStrides(4, threshold_stride)
+                                           .setId('t')  // threshold for creating the top mask
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_INT32)
+                                           .build(),
+                                       cudnn_frontend::TensorBuilder()
+                                           .setDim(4, threshold_dim)
+                                           .setStrides(4, threshold_stride)
+                                           .setId('u')  // threshold for creating the bottom mask
+                                           .setAlignment(16)
+                                           .setDataType(CUDNN_DATA_INT32)
+                                           .build());
+}
+
+// tensor descriptors used for dgrad
+enum {
+    X_OR_DX_TENSOR,
+    DY_TENSOR,
+    W_OR_DW_TENSOR,
+    SCALE_TENSOR,
+    RELU_TENSOR,
+    AFTER_DCONV_TENSOR,
+    AFTER_DRELU_TENSOR,
+    DGRAD_INPUT_TENSOR,
+    DGRAD_OPTIONAL_TENSOR,
+    DGRAD_GEN_INDEX_TENSOR,
+    DGRAD_MASK_TOP_TENSOR,
+    DGRAD_MASK_BOTTOM_TENSOR,
+    DGRAD_MASK_TENSOR,
+    DGRAD_THRESHOLD_TOP_TENSOR,
+    DGRAD_THRESHOLD_BOTTOM_TENSOR,
+};
+
+using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor,
+                                         cudnn_frontend::Tensor>;
+
+dconv_add_descriptors
+create_dconv_add_descriptors(int64_t* x_dim_padded,
+                             int64_t* padA,
+                             int64_t* convstrideA,
+                             int64_t* dilationA,
+                             int64_t* w_dim_padded,
+                             int64_t* y_dim_padded,
+                             cudnnDataType_t dataType) {
+    const int convDim = 2;
+
+    int64_t b_dim_padded[4];
+    b_dim_padded[0] = 1;
+    b_dim_padded[1] = x_dim_padded[1];
+    b_dim_padded[2] = 1;
+    b_dim_padded[3] = 1;
+
+    int64_t x_stride_padded[4];
+    int64_t y_stride_padded[4];
+    int64_t w_stride_padded[4];
+    int64_t b_stride_padded[4];
+
+    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+
+    return dconv_add_descriptors(cudnn_frontend::TensorBuilder()
+                                     .setDim(4, x_dim_padded)
+                                     .setStrides(4, x_stride_padded)
+                                     .setId('x')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, y_dim_padded)
+                                     .setStrides(4, y_stride_padded)
+                                     .setId('y')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, w_dim_padded)
+                                     .setStrides(4, w_stride_padded)
+                                     .setId('w')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, b_dim_padded)
+                                     .setStrides(4, b_stride_padded)
+                                     .setId('s')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, x_dim_padded)
+                                     .setStrides(4, x_stride_padded)
+                                     .setId('r')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, x_dim_padded)
+                                     .setStrides(4, x_stride_padded)
+                                     .setVirtual()
+                                     .setId('A')  // after dconv
+                                     .setAlignment(16)
+                                     .setDataType(CUDNN_DATA_FLOAT)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, x_dim_padded)
+                                     .setStrides(4, x_stride_padded)
+                                     .setVirtual()
+                                     .setId('B')  // after drelu
+                                     .setAlignment(16)
+                                     .setDataType(CUDNN_DATA_FLOAT)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, y_dim_padded)
+                                     .setStrides(4, y_stride_padded)
+                                     .setId('i')
+                                     .setAlignment(16)
+                                     .setDataType(dataType)
+                                     .build(),
+                                 cudnn_frontend::TensorBuilder()
+                                     .setDim(4, y_dim_padded)
+                                     .setStrides(4, y_stride_padded)
+                                     .setId('D')  // after optional add
+                                     .setAlignment(16)
+                                     .setVirtual()
+                                     .setDataType(CUDNN_DATA_FLOAT)
+                                     .build());
+}
+
+using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor,
+                                          cudnn_frontend::Tensor>;
+
+dconv_mask_descriptors
+create_dconv_mask_descriptors(int64_t* x_dim_padded,
+                              int64_t* padA,
+                              int64_t* convstrideA,
+                              int64_t* dilationA,
+                              int64_t* w_dim_padded,
+                              int64_t* y_dim_padded,
+                              int64_t* threshold_dim,
+                              cudnnDataType_t dataType) {
+    const int convDim = 2;
+
+    int64_t b_dim_padded[4];
+    b_dim_padded[0] = 1;
+    b_dim_padded[1] = x_dim_padded[1];
+    b_dim_padded[2] = 1;
+    b_dim_padded[3] = 1;
+
+    int64_t x_stride_padded[4];
+    int64_t y_stride_padded[4];
+    int64_t w_stride_padded[4];
+    int64_t b_stride_padded[4];
+    int64_t threshold_stride[4];
+
+    generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
+    generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
+
+    return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
+                                      .setDim(4, x_dim_padded)
+                                      .setStrides(4, x_stride_padded)
+                                      .setId('x')
+                                      .setAlignment(16)
+                                      .setDataType(dataType)
+                                      .build(),
+                                  cudnn_frontend::TensorBuilder()
+                                      .setDim(4, y_dim_padded)
+                                      .setStrides(4, y_stride_padded)
+                                      .setId('y')
+                                      .setAlignment(16)
+                                      .setDataType(dataType)
+                                      .build(),
+                                  cudnn_frontend::TensorBuilder()
+                                      .setDim(4, w_dim_padded)
+                                      .setStrides(4, w_stride_padded)
+                                      .setId('w')
+                                      .setAlignment(16)
+                                      .setDataType(dataType)
+                                      .build(),
+ .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, b_dim_padded) + .setStrides(4, b_stride_padded) + .setId('s') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setId('r') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('A') // after dconv + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, x_dim_padded) + .setStrides(4, x_stride_padded) + .setVirtual() + .setId('B') // after drelu + .setAlignment(16) + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('i') + .setAlignment(16) + .setDataType(dataType) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('D') // after optional add + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_FLOAT) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('I') // output of the gen index operation + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('m') // top half of the mask created after the less than + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('n') // bottom half of the mask + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, y_dim_padded) + .setStrides(4, y_stride_padded) + .setId('M') // OR of the top and bottom masks + .setAlignment(16) + .setVirtual() + .setDataType(CUDNN_DATA_BOOLEAN) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('t') // threshold for creating the top mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build(), + cudnn_frontend::TensorBuilder() + .setDim(4, threshold_dim) + .setStrides(4, threshold_stride) + .setId('u') // threshold for creating the bottom mask + .setAlignment(16) + .setDataType(CUDNN_DATA_INT32) + .build()); +} + +void +run_conv_add_scale_bias_activation(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrB, + at::Half* devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + 
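+    // Summary of the fused graph assembled below (element-wise, with `prev`
+    // bound to the OPTIONAL 'i' tensor):
+    //   y = ReLU( (conv(x, w) + prev) * z + b )
+    // The conv/add/scale/bias intermediates are virtual tensors, so only
+    // x, w, z, b, prev and y ever touch global memory.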
DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // create an add node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(add_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create an Operation Graph. 
In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(6, data_ptrs) + .setUids(6, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +void +run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + int64_t* threshold_dim, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrB, + at::Half* devPtrI, + int* devPtrT, + int* devPtrU, + int axis) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the add operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + 
.setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the bias operation + auto biasDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); + + // optional add + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_FWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the genIndex descriptor + auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setMathPrecision(CUDNN_DATA_FLOAT) + .setAxis(axis) + .build(); + DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); + + // Define the lessThan descriptor + auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_LT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); + + // Define the greaterThan descriptor + auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_GT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); + + // Define the logical_or descriptor + auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_LOGICAL_OR) + .setMathPrecision(CUDNN_DATA_BOOLEAN) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); + + // Define the binary_selection descriptor + auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_BINARY_SELECT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create a Add Node with scaling parameters. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(conv_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Bias Node. + auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(scale_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(biasDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); + + // Create a optional add Node. 
+ auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(bias_op.getOutputTensor()) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + + // Create an Activation Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) + .setyDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Gen_Index Node. + auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(genIndexDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); + + // Create a LessThan Node. + auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(lessThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); + + // Create a GreaterThan Node. + auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(greaterThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); + + // Create a LogicalOr Node. + auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(logicalOrDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); + + // Create a Binary_Selection Node. + auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .settDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(selectionDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); + + // Create an Operation Graph. 
+        // In this case the graph is conv -> scale -> bias -> (optional add) -> ReLU,
+        // followed by the gen-index/compare/select ops that build the halo mask.
+        if (devPtrI) {
+
+            std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
+
+            auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                .setHandle(handle_)
+                .setOperationGraph(ops.size(), ops.data())
+                .build();
+
+            // Create string encoding for plan caching
+            auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+            DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+            auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+            DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+            auto workspace_size = plan.getWorkspaceSize();
+            DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+            void* workspace_ptr = nullptr;
+            auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+            if (workspace_size > 0) {
+                workspace_ptr = workspace_tensor.data_ptr();
+            }
+            void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
+            int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
+            auto variantPack = cudnn_frontend::VariantPackBuilder()
+                .setWorkspacePointer(workspace_ptr)
+                .setDataPointers(8, data_ptrs)
+                .setUids(8, uids)
+                .build();
+            DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+            cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+            checkCudnnErr(status);
+            cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+        } else {
+
+            std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
+
+            auto opGraph = cudnn_frontend::OperationGraphBuilder()
+                .setHandle(handle_)
+                .setOperationGraph(ops.size(), ops.data())
+                .build();
+
+            // Create string encoding for plan caching
+            auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
+            DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
+
+            auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
+            DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
+
+            auto workspace_size = plan.getWorkspaceSize();
+            DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
+
+            void* workspace_ptr = nullptr;
+            auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
+            if (workspace_size > 0) {
+                workspace_ptr = workspace_tensor.data_ptr();
+            }
+            void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
+            int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
+            auto variantPack = cudnn_frontend::VariantPackBuilder()
+                .setWorkspacePointer(workspace_ptr)
+                .setDataPointers(7, data_ptrs)
+                .setUids(7, uids)
+                .build();
+            DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
+            cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
+            checkCudnnErr(status);
+            cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
+        }
+    } catch (cudnn_frontend::cudnnException& e) {
+        std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
+    }
+}
+
+void
+run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrR, + at::Half* devPtrI) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_add_descriptors tensors = create_dconv_add_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // optional add + auto addDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_ADD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // Create add Node. + auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(addDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, add_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. 
+ auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &add_op, &act_op, &scale_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(6, data_ptrs) + .setUids(6, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +void +run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, + int64_t* pad, + int64_t* convstride, + int64_t* dilation, + int64_t* w_dim_padded, + int64_t* y_dim_padded, + int64_t* threshold_dim, + cudnnDataType_t dataType, + at::Half* devPtrX, + at::Half* devPtrW, + at::Half* devPtrY, + at::Half* devPtrZ, + at::Half* devPtrR, + int* devPtrT, + int* devPtrU, + int axis) { + cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + std::stringstream log_buf; + try { + int convDim = 2; + + // Creates the necessary tensor descriptors + dconv_mask_descriptors tensors = create_dconv_mask_descriptors( + x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); + + // Define 
the convolution problem + auto convDesc = cudnn_frontend::ConvDescBuilder() + .setDataType(CUDNN_DATA_FLOAT) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(convDim) + .setStrides(convDim, convstride) + .setPrePadding(convDim, pad) + .setPostPadding(convDim, pad) + .setDilation(convDim, dilation) + .build(); + DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); + + // Define the activation backward operation + auto actDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_RELU_BWD) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); + + // Define the scale backward operation + auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_MUL) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); + + // Define the genIndex descriptor + auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_GEN_INDEX) + .setMathPrecision(CUDNN_DATA_FLOAT) + .setAxis(axis) + .build(); + DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); + + // Define the lessThan descriptor + auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_LT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); + + // Define the greaterThan descriptor + auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_CMP_GT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); + + // Define the logical_or descriptor + auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_LOGICAL_OR) + .setMathPrecision(CUDNN_DATA_BOOLEAN) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); + + // Define the binary_selection descriptor + auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() + .setMode(CUDNN_POINTWISE_BINARY_SELECT) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); + + float alpha = 1.0f; + float beta = 0.0f; + + // Create a convolution Node + auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) + .setdxDesc(std::get(tensors)) + .setwDesc(std::get(tensors)) + .setdyDesc(std::get(tensors)) + .setcDesc(convDesc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); + + // TODO: do we need getOutputTensor(), and what it returns in backward case? + // Create an relu backward Node. + auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setdyDesc(std::get(tensors)) + .setxDesc(std::get(tensors)) + .setdxDesc(std::get(tensors)) + .setpwDesc(actDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, act_op.describe()); + + // Create a Scale Node. + auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(scaleDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); + + // Create a Gen_Index Node. + auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(genIndexDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); + + // Create a LessThan Node. 
+ auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(lessThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); + + // Create a GreaterThan Node. + auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(greaterThanDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); + + // Create a LogicalOr Node. + auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(logicalOrDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); + + // Create a Binary_Selection Node. + auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(std::get(tensors)) + .setbDesc(std::get(tensors)) + .settDesc(std::get(tensors)) + .setyDesc(std::get(tensors)) + .setpwDesc(selectionDesc) + .build(); + DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); + + // Create an Operation Graph. In this case it is convolution add bias activation + std::array ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; + + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(handle_) + .setOperationGraph(ops.size(), ops.data()) + .build(); + + // Create string encoding for plan caching + auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); + DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); + + auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); + DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); + + auto workspace_size = plan.getWorkspaceSize(); + DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); + + void* workspace_ptr = nullptr; + auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); + if (workspace_size > 0) { + workspace_ptr = workspace_tensor.data_ptr(); + } + void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU}; + int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace_ptr) + .setDataPointers(7, data_ptrs) + .setUids(7, uids) + .build(); + DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); + cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); + checkCudnnErr(status); + cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); + } catch (cudnn_frontend::cudnnException e) { + std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; + } +} + +struct bottleneck_forward_status { + + int64_t dimA[4]; + int64_t filterdimA1[4]; + int64_t filterdimA2[4]; + int64_t filterdimA2hh[4]; + int64_t filterdimA3[4]; + int64_t filterdimA4[4]; + + int64_t threshdim[4]; + + int axis[4]; + + int64_t outdimA0[4]; + int64_t outdimA1[4]; + int64_t outdimA1b[4]; // out1_pad + int64_t outdimA2[4]; + int64_t outdimA3[4]; + int64_t outdimA4[4]; + + int64_t padA[2]; + int64_t padA1[2]; + 
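+    // Halo bookkeeping for the halo-exchange path: outdimA0 is the H=3 "fat halo"
+    // input slab fed to conv2 and outdimA4 the matching H=1 halo output row;
+    // padA2 = {0,1} pads W only, since the halo rows in H are materialized
+    // explicitly by the caller.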
int64_t padA2[2]; // halo padding + int64_t dilationA[2]; + int64_t convstrideA[2]; + int64_t convstride1X1[2]; + + int64_t outdim0[4]; // halo input shape + int64_t outdim1[4]; + int64_t outdim1b[4]; + int64_t outdim2[4]; + int64_t outdim3[4]; + int64_t outdim4[4]; // halo output shape + + void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { + dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; + filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; + filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; + filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; + filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; + filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; + threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; + + // All dim calculation after this order of n,c,h,w + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 3; + axis[2] = 1; + axis[3] = 2; + } else { + axis[0] = 0; + axis[1] = 1; + axis[2] = 2; + axis[3] = 3; + } + + for (int dim=0;dim<4;dim++) { + dimA[dim] = inputs[0].size(axis[dim]); + filterdimA1[dim] = inputs[1].size(axis[dim]); + filterdimA2[dim] = inputs[2].size(axis[dim]); + filterdimA3[dim] = inputs[3].size(axis[dim]); + } + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { + for (int dim=0;dim<4;dim++) { + filterdimA4[dim] = inputs[10].size(axis[dim]); + } + } + for (int dim=0;dim<4;dim++) { + if (dim == 2) { + filterdimA2hh[dim] = 1; + } else { + filterdimA2hh[dim] = filterdimA2[dim]; + } + } + + // output dim in n,c,h,w used by backend + outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0; + outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; + outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; + outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; + outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; + outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0; + + // use these fixed value for test run + padA[0] = 0; padA[1] = 0; + padA1[0] = 1; padA1[1] = 1; + padA2[0] = 0; padA2[1] = 1; + dilationA[0] = 1; dilationA[1] = 1; + convstrideA[0] = 1; convstrideA[1] = 1; + convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; + + // compute output from pad/stride/dilation + outdimA1[0] = dimA[0]; + outdimA1[1] = filterdimA1[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + } + for (int dim = 0; dim < 4; dim++) { + if (dim == 2) { + outdimA1b[dim] = outdimA1[dim] + 2; + } else { + outdimA1b[dim] = outdimA1[dim]; + } + } + + outdimA2[0] = outdimA1[0]; + outdimA2[1] = filterdimA2[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); + } + + for (int dim = 0; dim < 4; dim++) { + if (dim == 2) { + outdimA0[dim] = 3; + outdimA4[dim] = 1; + } else { + outdimA0[dim] = outdimA1[dim]; + outdimA4[dim] = outdimA2[dim]; + } + } + + outdimA3[0] = outdimA2[0]; + outdimA3[1] = filterdimA3[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); + } + + // Create output tensor in the correct shape in pytorch's view + outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0; + outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0; + outdim2[0] = outdim2[1] = outdim2[2] = 
outdim2[3] = 0;
+        outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
+        if (explicit_nhwc) {
+            axis[0] = 0;
+            axis[1] = 2;
+            axis[2] = 3;
+            axis[3] = 1;
+        }
+        for (int dim=0;dim<4;dim++) {
+            outdim0[dim] = outdimA0[axis[dim]];
+            outdim1[dim] = outdimA1[axis[dim]];
+            outdim1b[dim] = outdimA1b[axis[dim]];
+            outdim2[dim] = outdimA2[axis[dim]];
+            outdim3[dim] = outdimA3[axis[dim]];
+            outdim4[dim] = outdimA4[axis[dim]];
+        }
+    }
+};
+
+bottleneck_forward_status forward_state;
+
+} // end of anonymous namespace
+
+std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
+    // NB! bottleneck_forward and bottleneck_backward are NOT thread-safe methods.
+    // NB! We use a global object to store state.
+    forward_state.init(explicit_nhwc, stride_1X1, inputs);
+
+    // create output vector
+    std::vector<at::Tensor> outputs;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    //printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
+    auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
+    auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
+    auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);
+
+    outputs.push_back(out1);
+    outputs.push_back(out2);
+    outputs.push_back(out3);
+
+    return outputs;
+}
+
+// inputs contains x,w,z,b,(i)
+void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
+
+    std::cout << std::fixed;
+
+    // run
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+    at::Half* w = inputs[1].data_ptr<at::Half>();
+    at::Half* z = inputs[4].data_ptr<at::Half>();
+    at::Half* b = inputs[7].data_ptr<at::Half>();
+    auto out1 = outputs[0];
+    at::Half* y1 = out1.data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(forward_state.dimA,
+                                       forward_state.padA,
+                                       forward_state.convstride1X1,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA1,
+                                       forward_state.outdimA1,
+                                       CUDNN_DATA_HALF,
+                                       x, w, y1, z, b, nullptr);
+
+    DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
+}
+
+// Computes a halo (top or bottom) from a fat halo input.
+// The fat halo input is 3 pixels wide in H.
+at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {
+
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // run
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+    at::Half* b = inputs[8].data_ptr<at::Half>();
+
+    at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();
+
+    auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
+    at::Half* y2 = halo_y2.data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(forward_state.outdimA0,
+                                       forward_state.padA2,
+                                       forward_state.convstrideA,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA2,
+                                       forward_state.outdimA4,
+                                       CUDNN_DATA_HALF,
+                                       y1, w, y2, z, b, nullptr);
+
+    return halo_y2;
+}
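+// Illustrative helper (hypothetical; not part of the bindings): one way a caller
+// could assemble the H=3 fat-halo input consumed above, assuming explicit NHWC
+// storage ([N,H,W,C]) and that `neighbor_row` is the [N,1,W,C] boundary row
+// received from the adjacent rank.
+static at::Tensor make_top_fat_halo(at::Tensor neighbor_row, at::Tensor out1) {
+    // neighbor's last row followed by our first two rows -> H = 3
+    return at::cat({neighbor_row, out1.narrow(1, 0, 2)}, 1);
+}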
+// Computes the halo correction term (top or bottom) from a slim halo input (N,C,1,W).
+// The slim halo input is 1 pixel wide in H.
+at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
+
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // run
+    at::Half* w = w1by3.data_ptr<at::Half>(); // C,C,1,3
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+    at::Half* b = inputs[8].data_ptr<at::Half>();
+
+    at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
+
+    at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
+
+    auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
+    at::Half* y2 = halo_y2.data_ptr<at::Half>();
+
+    run_conv_add_scale_bias_activation(forward_state.outdimA4,
+                                       forward_state.padA2,
+                                       forward_state.convstrideA,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA2hh,
+                                       forward_state.outdimA4,
+                                       CUDNN_DATA_HALF,
+                                       y1, w, y2, z, b, prev_out2);
+
+    return halo_y2;
+}
+
+void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
+
+    std::cout << std::fixed;
+
+    // from _out1 method
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+    auto out1 = outputs[0];
+    at::Half* y1 = out1.data_ptr<at::Half>();
+
+    // run
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+    at::Half* b = inputs[8].data_ptr<at::Half>();
+    auto out2 = outputs[1];
+    at::Half* y2 = out2.data_ptr<at::Half>();
+
+    //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+    //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+    //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+    //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+    //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+    //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+    run_conv_scale_bias_add_activation(forward_state.outdimA1,
+                                       forward_state.padA1,
+                                       forward_state.convstrideA,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA2,
+                                       forward_state.outdimA2,
+                                       CUDNN_DATA_HALF,
+                                       y1, w, y2, z, b, nullptr);
+    DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
+void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
+
+    std::cout << std::fixed;
+
+    // from _out1 method
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+    auto out1 = outputs[0];
+    at::Half* y1 = out1.data_ptr<at::Half>();
+
+    // run
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+    at::Half* b = inputs[8].data_ptr<at::Half>();
+    auto out2 = outputs[1];
+    at::Half* y2 = out2.data_ptr<at::Half>();
+
+    //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+    //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+    //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+    //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+    //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+    //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
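+    // A reading of the masked variant's wiring (see run_conv_scale_bias_add_activation_mask):
+    // a gen-index op tags each output element with its coordinate along the given
+    // axis, one comparison marks elements with index < thresholdTop, another marks
+    // index > thresholdBottom, the two flags are OR-ed into a boolean mask, and a
+    // binary-select uses that mask to decide which elements receive the freshly
+    // computed conv/scale/bias/ReLU result. The apparent intent is to leave halo
+    // rows to be filled in from neighbor data instead.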
+    run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
+                                            forward_state.padA1,
+                                            forward_state.convstrideA,
+                                            forward_state.dilationA,
+                                            forward_state.filterdimA2,
+                                            forward_state.outdimA2,
+                                            forward_state.threshdim,
+                                            CUDNN_DATA_HALF,
+                                            y1, w, y2, z, b, nullptr,
+                                            thresholdTop.data_ptr<int>(),
+                                            thresholdBottom.data_ptr<int>(),
+                                            2); // axis 2 is H in the backend's NCHW ordering -> does this assume explicit NHWC?
+    DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
+void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
+
+    std::cout << std::fixed;
+
+    // from _out1 method
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+    auto out1 = outputs[0];
+    at::Half* y1 = out1_pad.data_ptr<at::Half>();
+
+    // run
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+    at::Half* b = inputs[8].data_ptr<at::Half>();
+    auto out2 = outputs[1];
+    at::Half* y2 = out2.data_ptr<at::Half>();
+
+    //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
+    //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
+    //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
+    //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
+    //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
+    //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
+    run_conv_scale_bias_add_activation(forward_state.outdimA1b,
+                                       forward_state.padA2,
+                                       forward_state.convstrideA,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA2,
+                                       forward_state.outdimA2,
+                                       CUDNN_DATA_HALF,
+                                       y1, w, y2, z, b, nullptr);
+    DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
+}
+
+void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
+
+    std::cout << std::fixed;
+
+    // from _out1 method
+    at::Half* x = inputs[0].data_ptr<at::Half>();
+
+    // create output of conv3
+    auto out3 = outputs[2];
+    at::Half* y3 = out3.data_ptr<at::Half>();
+
+    // create output of conv4 that may exist
+    auto identity = at::empty_like(out3);
+    at::Half* yi = identity.data_ptr<at::Half>();
+
+    at::Half *w, *z, *b;
+
+    if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){
+
+        w = inputs[10].data_ptr<at::Half>();
+        z = inputs[11].data_ptr<at::Half>();
+        b = inputs[12].data_ptr<at::Half>();
+        run_conv_scale_bias(forward_state.dimA,
+                            forward_state.padA,
+                            forward_state.convstride1X1,
+                            forward_state.dilationA,
+                            forward_state.filterdimA4,
+                            forward_state.outdimA3,
+                            CUDNN_DATA_HALF,
+                            x, w, yi, z, b);
+        DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
+    }
+    else {
+        yi = x;
+    }
+
+    auto out2 = outputs[1];
+    at::Half* y2 = out2.data_ptr<at::Half>();
+
+    w = inputs[3].data_ptr<at::Half>();
+    z = inputs[6].data_ptr<at::Half>();
+    b = inputs[9].data_ptr<at::Half>();
+
+    run_conv_scale_bias_add_activation(forward_state.outdimA2,
+                                       forward_state.padA,
+                                       forward_state.convstrideA,
+                                       forward_state.dilationA,
+                                       forward_state.filterdimA3,
+                                       forward_state.outdimA3,
+                                       CUDNN_DATA_HALF,
+                                       y2, w, y3, z, b, yi);
DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); +} + +namespace { + +struct bottleneck_backward_state { + + int64_t dimA[4]; + int64_t filterdimA1[4]; + int64_t filterdimA2[4]; + int64_t filterdimA3[4]; + int64_t filterdimA4[4]; + int64_t filterdimA2hh[4]; // Cin,Cout,1,3 + int64_t threshdim[4]; + + int axis[4]; + + int64_t outdimA1[4]; // grad_out1 + int64_t outdimA1b[4]; // out1_pad + int64_t outdimA2[4]; // grad_out2 + int64_t outdimA3[4]; + int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) + int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3) + int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1) + int64_t outdimA2hh[4]; // input: out1 halo (H=1) + + int64_t padA[2]; + int64_t padA1[2]; + int64_t padA2[2]; + int64_t dilationA[2]; + int64_t convstrideA[2]; + int64_t convstride1X1[2]; + + int64_t filterdim2hh[4]; // Cin,1,3,Cout + + int64_t outdim1[4]; + int64_t outdim1b[4]; + int64_t outdim2[4]; + int64_t outdim3[4]; + int64_t outdim1h[4]; + int64_t outdim1hh[4]; + + void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { + // setup dimensions + dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; + filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; + filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; + filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; + filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; + filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; + threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; + + // All dim calculation after this order of n,c,h,w + if (explicit_nhwc) { + axis[0] = 0; + axis[1] = 3; + axis[2] = 1; + axis[3] = 2; + } else { + axis[0] = 0; + axis[1] = 1; + axis[2] = 2; + axis[3] = 3; + } + + for (int dim=0;dim<4;dim++) { + dimA[dim] = inputs[0].size(axis[dim]); + filterdimA1[dim] = inputs[1].size(axis[dim]); + filterdimA2[dim] = inputs[2].size(axis[dim]); + filterdimA3[dim] = inputs[3].size(axis[dim]); + } + if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { + for (int dim=0;dim<4;dim++) { + filterdimA4[dim] = inputs[14].size(axis[dim]); + } + } + + for (int dim=0;dim<4;dim++) { + if (dim == 2) { + filterdimA2hh[dim] = 1; + } else { + filterdimA2hh[dim] = filterdimA2[dim]; + } + } + + // output dim in n,c,h,w used by backend + outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; + outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; + outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; + outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; + outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0; + outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0; + outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0; + outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0; + + // use these fixed value for test run + padA[0] = 0; padA[1] = 0; + padA1[0] = 1; padA1[1] = 1; + padA2[0] = 0; padA2[1] = 1; + dilationA[0] = 1; dilationA[1] = 1; + convstrideA[0] = 1; convstrideA[1] = 1; + convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; + + // compute output from pad/stride/dilation + outdimA1[0] = dimA[0]; + outdimA1[1] = filterdimA1[0]; + for (int dim = 0; dim < 2; dim++) { + outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); + } + for (int dim = 0; dim < 4; dim++) { + if (dim == 2) { + outdimA1b[dim] = outdimA1[dim] + 2; + } else { + outdimA1b[dim] = 
outdimA1[dim];
+            }
+        }
+
+        outdimA2[0] = outdimA1[0];
+        outdimA2[1] = filterdimA2[0];
+        for (int dim = 0; dim < 2; dim++) {
+            outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
+        }
+
+        outdimA3[0] = outdimA2[0];
+        outdimA3[1] = filterdimA3[0];
+        for (int dim = 0; dim < 2; dim++) {
+            outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
+        }
+
+        for (int dim = 0; dim < 4; dim++) {
+            if (dim == 2) {
+                outdimA1h[dim] = 3;
+                outdimA2h[dim] = 3;
+                outdimA1hh[dim] = 1;
+                outdimA2hh[dim] = 1;
+            } else {
+                outdimA1h[dim] = outdimA1[dim];
+                outdimA2h[dim] = outdimA2[dim];
+                outdimA1hh[dim] = outdimA1[dim];
+                outdimA2hh[dim] = outdimA2[dim];
+            }
+        }
+
+        // Create output tensors in the correct shape in PyTorch's view
+        outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
+        outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
+        outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
+        outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
+        outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
+        outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
+        filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
+        if (explicit_nhwc) {
+            axis[0] = 0;
+            axis[1] = 2;
+            axis[2] = 3;
+            axis[3] = 1;
+        }
+        for (int dim=0;dim<4;dim++) {
+            outdim1[dim] = outdimA1[axis[dim]];
+            outdim1b[dim] = outdimA1b[axis[dim]];
+            outdim2[dim] = outdimA2[axis[dim]];
+            outdim3[dim] = outdimA3[axis[dim]];
+            outdim1h[dim] = outdimA1h[axis[dim]];
+            outdim1hh[dim] = outdimA1hh[axis[dim]];
+            filterdim2hh[dim] = filterdimA2hh[axis[dim]];
+        }
+    }
+};
+
+bottleneck_backward_state backward_state;
+
+} // end of anonymous namespace
+
+std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
+
+    std::cout << std::fixed;
+
+    backward_state.init(explicit_nhwc, stride_1X1, inputs);
+
+    // create output vector
+    std::vector<at::Tensor> outputs;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    auto grad_x = at::empty_like(inputs[0]);
+    auto wgrad1 = at::empty_like(inputs[1]);
+    auto wgrad2 = at::empty_like(inputs[2]);
+    auto wgrad3 = at::empty_like(inputs[3]);
+
+    outputs.push_back(grad_x);
+    outputs.push_back(wgrad1);
+    outputs.push_back(wgrad2);
+    outputs.push_back(wgrad3);
+    if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
+        auto wgrad4 = at::empty_like(inputs[14]);
+        outputs.push_back(wgrad4);
+    }
+
+    return outputs;
+}
+
+void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
+
+    // dconv3+drelu2+dscale2
+    at::Half* conv_in = inputs[13].data_ptr<at::Half>();
+    at::Half* dy3 = inputs[10].data_ptr<at::Half>();
+
+    // wgrad
+    auto wgrad3 = outputs[3];
+    at::Half* dw3 = wgrad3.data_ptr<at::Half>();
+    run_dconv(backward_state.outdimA2,
+              backward_state.padA,
+              backward_state.convstrideA,
+              backward_state.dilationA,
+              backward_state.filterdimA3,
+              backward_state.outdimA3,
+              CUDNN_DATA_HALF,
+              conv_in, dw3, dy3,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+    DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
+}
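+// Hypothetical debug-only cross-check (not called anywhere): for the 1x1 conv3,
+// the weight gradient computed by run_dconv above reduces to a plain contraction
+// over batch and spatial dims. Assumes NCHW float copies `dy` ([N,K,P,Q], from
+// dy3) and `x` ([N,C,P,Q], from relu2).
+static at::Tensor wgrad_1x1_reference(const at::Tensor& dy, const at::Tensor& x) {
+    auto k = dy.flatten(2);   // [N, K, P*Q]
+    auto c = x.flatten(2);    // [N, C, P*Q]
+    // dw[k][c] = sum_{n,p} dy[n,k,p] * x[n,c,p]; reshape to [K, C, 1, 1]
+    return at::einsum("nkp,ncp->kc", {k, c}).unsqueeze(-1).unsqueeze(-1);
+}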
+at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
+
+    bool requires_grad = inputs[0].requires_grad();
+
+    std::cout << std::fixed;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // dconv3+drelu2+dscale2
+    at::Half* conv_in = inputs[13].data_ptr<at::Half>();
+    at::Half* dy3 = inputs[10].data_ptr<at::Half>();
+
+    DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
+
+    // dgrad output
+    auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
+    at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+    at::Half* w = inputs[3].data_ptr<at::Half>();
+    at::Half* z = inputs[5].data_ptr<at::Half>();
+
+    at::Half* relu2 = inputs[13].data_ptr<at::Half>();
+
+    run_dconv_drelu_dscale(backward_state.outdimA2,
+                           backward_state.padA,
+                           backward_state.convstrideA,
+                           backward_state.dilationA,
+                           backward_state.filterdimA3,
+                           backward_state.outdimA3,
+                           CUDNN_DATA_HALF,
+                           dy2, w, dy3, z, relu2);
+
+    // do halo exchange of dy2 here
+
+    DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
+
+    return grad_out2;
+}
+
+at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
+
+    bool requires_grad = inputs[0].requires_grad();
+
+    std::cout << std::fixed;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // incoming dgrad
+    at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+
+    // dgrad output
+    auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
+    at::Half* dy1 = grad_out1.data_ptr<at::Half>();
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[4].data_ptr<at::Half>();
+
+    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
+    //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
+
+    // fused dgrad
+    //printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
+    run_dconv_drelu_dscale(backward_state.outdimA1,
+                           backward_state.padA1,
+                           backward_state.convstrideA,
+                           backward_state.dilationA,
+                           backward_state.filterdimA2,
+                           backward_state.outdimA2,
+                           CUDNN_DATA_HALF,
+                           dy1, w, dy2, z, relu1);
+
+    return grad_out1;
+}
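+// Per-element form of the fused dgrad above (a reading of the graph, with
+// (*)^T denoting the data-grad / transposed convolution of conv2):
+//   dy1 = (relu1 > 0) ? (dy2 (*)^T w2) * z1 : 0
+// i.e. conv2's data gradient gated by out1's ReLU mask and scaled by z1.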
+at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
+
+    bool requires_grad = inputs[0].requires_grad();
+
+    std::cout << std::fixed;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // incoming dgrad
+    at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+
+    // dgrad output
+    auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
+    at::Half* dy1 = grad_out1.data_ptr<at::Half>();
+    at::Half* w = inputs[2].data_ptr<at::Half>();
+    at::Half* z = inputs[4].data_ptr<at::Half>();
+
+    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
+    //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
+
+    // fused dgrad
+    run_dconv_drelu_dscale_mask(backward_state.outdimA1,
+                                backward_state.padA1,
+                                backward_state.convstrideA,
+                                backward_state.dilationA,
+                                backward_state.filterdimA2,
+                                backward_state.outdimA2,
+                                backward_state.threshdim,
+                                CUDNN_DATA_HALF,
+                                dy1, w, dy2, z, relu1,
+                                thresholdTop.data_ptr<int>(),
+                                thresholdBottom.data_ptr<int>(),
+                                2);
+
+    return grad_out1;
+}
+
+// Performs a backward-data 1x3 convolution (grad_out * w_rot180) on a grad_out2 input of shape
+// [N,1,W,C] with padding=(0,1) to produce an output of shape [N,1,W,C].
+at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
+
+    bool requires_grad = inputs[0].requires_grad();
+
+    std::cout << std::fixed;
+    auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+    // incoming dgrad
+    at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
+
+    // dgrad output
+    auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
+    at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
+    //at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
+    at::Half* w = w1by3.data_ptr<at::Half>();
+    at::Half* z = inputs[4].data_ptr<at::Half>();
+    at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
+    at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
+
+    //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
+    // fused dgrad
+    //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
+    //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
+    //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
+    run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
+                               backward_state.padA2, // 0,1
+                               backward_state.convstrideA,
+                               backward_state.dilationA,
+                               backward_state.filterdimA2hh, // C,1,3,C
+                               backward_state.outdimA2hh,
+                               CUDNN_DATA_HALF,
+                               dy1h, w, dy2h, z, relu1h, pdy1h);
+
+    return grad_out1_halo;
+}
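+// Two halo recompute strategies coexist in the backward pass: the "fat" H=3 path
+// below redoes the full 3x3 data-grad on an [N,3,W,C] slab, while the H=1
+// correction path above applies only the 1x3 filter slice (w1by3) to a single
+// boundary row and adds it onto a partial result computed without neighbor
+// contributions.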
+// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
+at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
+
+  bool requires_grad = inputs[0].requires_grad();
+
+  std::cout << std::fixed;
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
+
+  // dgrad
+  auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
+  at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
+  at::Half* w = inputs[2].data_ptr<at::Half>();
+  at::Half* z = inputs[4].data_ptr<at::Half>();
+
+  at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
+  //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
+  // fused dgrad
+  //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
+  //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
+  //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
+  run_dconv_drelu_dscale(backward_state.outdimA1h,
+                         backward_state.padA1,
+                         backward_state.convstrideA,
+                         backward_state.dilationA,
+                         backward_state.filterdimA2,
+                         backward_state.outdimA2h,
+                         CUDNN_DATA_HALF,
+                         dy1h,
+                         w,
+                         dy2h,
+                         z,
+                         relu1h);
+
+  return grad_out1_halo;
+}
+
+void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
+
+  std::cout << std::fixed;
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+
+  // dconv2+drelu1+dscale1
+  at::Half* conv_in = input.data_ptr<at::Half>();
+
+  // wgrad
+  auto wgrad2 = outputs[2];
+  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
+
+  //printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
+  //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
+  run_dconv(backward_state.outdimA1b,   // conv_in.shape (including H halos)
+            backward_state.padA2,       // 0, 1
+            backward_state.convstrideA,
+            backward_state.dilationA,
+            backward_state.filterdimA2, // dw2.shape
+            backward_state.outdimA2,    // dy2.shape
+            CUDNN_DATA_HALF,
+            conv_in,
+            dw2,
+            dy2,
+            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
+}
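+
+// [Reference sketch, not part of the original patch] What the BACKWARD_FILTER runs
+// above compute, reduced to the 1-D (kernel 3, pad 1, stride 1) single-channel case:
+// each filter tap accumulates the correlation of the (implicitly zero-padded) input
+// with grad_out. Naive float code for intuition only; the real path runs fused fp16
+// NHWC kernels on the GPU.
+static void wgrad_1x3_ref(const float* x, const float* dy, int W, float dw[3]) {
+  for (int t = 0; t < 3; ++t) {        // filter tap
+    dw[t] = 0.f;
+    for (int ox = 0; ox < W; ++ox) {   // output position
+      int ix = ox + t - 1;             // pad = 1
+      if (ix >= 0 && ix < W) dw[t] += x[ix] * dy[ox];
+    }
+  }
+}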
+void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
+
+  bool requires_grad = inputs[0].requires_grad();
+
+  std::cout << std::fixed;
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+
+  // dconv2+drelu1+dscale1
+  at::Half* conv_in = inputs[12].data_ptr<at::Half>();
+
+  // wgrad
+  auto wgrad2 = outputs[2];
+  at::Half* dw2 = wgrad2.data_ptr<at::Half>();
+
+  //printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
+  run_dconv(backward_state.outdimA1,
+            backward_state.padA1,
+            backward_state.convstrideA,
+            backward_state.dilationA,
+            backward_state.filterdimA2,
+            backward_state.outdimA2,
+            CUDNN_DATA_HALF,
+            conv_in,
+            dw2,
+            dy2,
+            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+  DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
+}
+
+// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C].
+// input and grad_out2_halo tensors are of the same shape.
+// output tensor is of shape [Cin,1,3,Cout] (regular filter dims are [Cin,3,3,Cout]).
+at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2_halo) {
+
+  bool requires_grad = inputs[0].requires_grad();
+
+  std::cout << std::fixed;
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2 = grad_out2_halo.data_ptr<at::Half>();
+
+  // dconv2+drelu1+dscale1
+  at::Half* conv_in = input.data_ptr<at::Half>();
+
+  // wgrad
+  auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format);
+  at::Half* dw2 = wgrad2_halo.data_ptr<at::Half>();
+
+  //printf("backward_state.outdimA1hh = {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]);
+  //printf("backward_state.outdimA2hh = {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]);
+  //printf("backward_state.filterdim2hh = {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]);
+  //printf("backward_state.filterdimA2hh = {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]);
+  //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
+  run_dconv(backward_state.outdimA1hh,    // N,C,1,W
+            backward_state.padA2,         // 0, 1
+            backward_state.convstrideA,
+            backward_state.dilationA,
+            backward_state.filterdimA2hh, // Cin,Cout,1,3
+            backward_state.outdimA2hh,    // N,C,1,W
+            CUDNN_DATA_HALF,
+            conv_in,
+            dw2,
+            dy2,
+            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+
+  return wgrad2_halo;
+}
+
+void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
+
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
+
+  // dconv1+add
+  // wgrad
+  auto wgrad1 = outputs[1];
+  at::Half* dw1 = wgrad1.data_ptr<at::Half>();
+  run_dconv(backward_state.dimA,
+            backward_state.padA,
+            backward_state.convstride1X1,
+            backward_state.dilationA,
+            backward_state.filterdimA1,
+            backward_state.outdimA1,
+            CUDNN_DATA_HALF,
+            x,
+            dw1,
+            dy1,
+            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+}
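+
+// [Illustrative sketch, not part of the original patch] The [Cin,1,3,Cout] tensor
+// returned by bottleneck_backward_wgrad2_halo is the contribution of one halo row to
+// the full [Cin,3,3,Cout] wgrad2; the caller is expected to accumulate it into the
+// matching H row of the filter (which row depends on whether the halo came from the
+// neighbor above or below). A minimal ATen sketch, assuming dim 1 indexes the
+// filter's H extent:
+static void sketch_accumulate_halo_wgrad(at::Tensor wgrad2,       // [Cin,3,3,Cout]
+                                         at::Tensor wgrad2_halo,  // [Cin,1,3,Cout]
+                                         int64_t h_row) {         // 0 = top, 2 = bottom
+  wgrad2.narrow(1, h_row, 1).add_(wgrad2_halo);
+}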
+void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
+
+  bool requires_grad = inputs[0].requires_grad();
+
+  std::cout << std::fixed;
+  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+
+  // dgrad
+  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
+  at::Half* dy1 = grad_out1.data_ptr<at::Half>();
+
+/*
+  // backward strided conv cannot be fused
+  // if stride == 1 but channel changes, we can fuse here
+  if (stride_1X1 != 1){
+    // dgrad
+    run_dconv(outdimA1,
+              padA1,
+              convstride1X1,
+              dilationA,
+              filterdimA2,
+              outdimA2,
+              CUDNN_DATA_HALF,
+              dy1,
+              w,
+              dy2,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+
+    // mul fused mask
+    grad_out1.mul_(inputs[15]);
+  }
+  else {
+    at::Half* relu1 = inputs[12].data_ptr<at::Half>();
+    // fused dgrad
+    run_dconv_drelu_dscale(outdimA1,
+                           padA1,
+                           convstride1X1,
+                           dilationA,
+                           filterdimA2,
+                           outdimA2,
+                           CUDNN_DATA_HALF,
+                           dy1,
+                           w,
+                           dy2,
+                           z,
+                           relu1);
+  }
+*/
+  DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
+
+  // create grads of conv4 that may exist
+  auto grad_x_conv4 = at::empty_like(inputs[0]);
+  at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
+  at::Tensor wgrad4;
+
+  // x used for dconv1 and dconv4 wgrad
+  at::Half* x = inputs[0].data_ptr<at::Half>();
+
+  at::Half* w = NULL;
+
+  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
+    w = inputs[14].data_ptr<at::Half>();
+    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
+    if (requires_grad) {
+      run_dconv(backward_state.dimA,
+                backward_state.padA,
+                backward_state.convstride1X1,
+                backward_state.dilationA,
+                backward_state.filterdimA4,
+                backward_state.outdimA3,
+                CUDNN_DATA_HALF,
+                dx_conv4,
+                w,
+                dy_conv4,
+                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+      // we don't print here, since this grad cannot be hooked out of pytorch alone for comparison: it is added to dx
+      // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
+    }
+    // wgrad
+    wgrad4 = outputs[4];
+    at::Half* dw4 = wgrad4.data_ptr<at::Half>();
+    run_dconv(backward_state.dimA,
+              backward_state.padA,
+              backward_state.convstride1X1,
+              backward_state.dilationA,
+              backward_state.filterdimA4,
+              backward_state.outdimA3,
+              CUDNN_DATA_HALF,
+              x,
+              dw4,
+              dy_conv4,
+              CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
+  }
+  else {
+    // if there is no downsample, dx_conv4 is a fork of drelu3
+    dx_conv4 = inputs[11].data_ptr<at::Half>();
+  }
+
+  // dgrad
+  w = inputs[1].data_ptr<at::Half>();
+  auto grad_x = outputs[0];
+  at::Half* dx = grad_x.data_ptr<at::Half>();
+
+  // backward strided conv cannot be fused
+  // if stride == 1 but channel changes, we can fuse here
+  if (requires_grad){
+    if (stride_1X1 != 1){
+      run_dconv(backward_state.dimA,
+                backward_state.padA,
+                backward_state.convstride1X1,
+                backward_state.dilationA,
+                backward_state.filterdimA1,
+                backward_state.outdimA1,
+                CUDNN_DATA_HALF,
+                dx,
+                w,
+                dy1,
+                CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
+      // add the two together
+      grad_x.add_(grad_x_conv4);
+    }
+    else {
+      run_dconv_add(backward_state.dimA,
+                    backward_state.padA,
+                    backward_state.convstride1X1,
+                    backward_state.dilationA,
+                    backward_state.filterdimA1,
+                    backward_state.outdimA1,
+                    CUDNN_DATA_HALF,
+                    dx,
+                    w,
+                    dy1,
+                    dx_conv4);
+    }
+  }
+
+  DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
+  // wgrad1 is produced by bottleneck_backward_wgrad1 into outputs[1]
+  DEBUG_MSG("[DEBUG] new wgrad1 : " << outputs[1].to(at::kFloat).sum().item<float>());
+
+  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
+    DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
+  }
+}
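+
+// [Usage sketch, not part of the original patch] The split backward entry points are
+// intended to be sequenced by the caller, with halo exchanges interleaved between the
+// pieces; the data dependencies visible above suggest roughly this order (pseudocode,
+// since the exact signatures of the earlier entry points live outside this excerpt):
+//
+//   bottleneck_backward_init(...);                   // populates backward_state
+//   grad_out2 = backward_grad_out2(...);             // fused dconv3+drelu2+dscale2
+//   ... halo exchange of grad_out2 boundary rows ...
+//   grad_out1 = backward_grad_out1(..., grad_out2);  // fused dconv2+drelu1+dscale1
+//   backward_wgrad2(..., grad_out2);                 // fills outputs[2]
+//   backward_wgrad1(..., grad_out1);                 // fills outputs[1]
+//   backward_rest(..., grad_out2, grad_out1);        // dx and conv4 grads into outputs[0]/[4]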
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &bottleneck_forward, "Bottleneck block forward");
+  m.def("backward", &bottleneck_backward, "Bottleneck block backward");
+  m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
+  m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
+  m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
+  m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
+  m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
+  m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
+  m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
+  m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
+  m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
+  m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
+  m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
+  m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
+  m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
+  m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
+  m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
+  m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
+  m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
+  m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
+  m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
+  m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
+}