|
|
|
|
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <c10/cuda/CUDAGuard.h> |
|
#include <torch/extension.h> |
|
#include <vector> |
|
|
|
#include "causal_conv1d.h" |
|
|
|
// Raises a c10::Error unless tensor `x` has exactly the listed sizes. The message
// stringizes both the tensor expression and the expected size list.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// Runtime -> compile-time dtype dispatch for the input (activation) tensor.
// The trailing argument is a callable (a lambda at every call site in this file)
// that is expanded inside a scope where `input_t` names the C++ element type
// matching ITYPE: at::Half, at::BFloat16 or float. Any other dtype raises a
// c10::Error. ITYPE is evaluated once. The do/while(0) wrapper makes the whole
// expansion a single statement, so it is safe inside an unbraced if/else.
// TORCH_CHECK(false, ...) replaces the deprecated AT_ERROR(...).
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                               \
    do {                                                                                       \
        const auto dispatch_itype_ = (ITYPE);                                                  \
        if (dispatch_itype_ == at::ScalarType::Half) {                                         \
            using input_t = at::Half;                                                          \
            __VA_ARGS__();                                                                     \
        } else if (dispatch_itype_ == at::ScalarType::BFloat16) {                              \
            using input_t = at::BFloat16;                                                      \
            __VA_ARGS__();                                                                     \
        } else if (dispatch_itype_ == at::ScalarType::Float) {                                 \
            using input_t = float;                                                             \
            __VA_ARGS__();                                                                     \
        } else {                                                                               \
            TORCH_CHECK(false, #NAME, " not implemented for input type '",                     \
                        toString(dispatch_itype_), "'");                                       \
        }                                                                                      \
    } while (0)

// Same as above for the weight tensor: `weight_t` is bound to at::Half,
// at::BFloat16 or float inside the callable; other dtypes raise a c10::Error.
// Uses a distinct cached-variable name so it can nest inside the ITYPE dispatch.
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                               \
    do {                                                                                       \
        const auto dispatch_wtype_ = (WTYPE);                                                  \
        if (dispatch_wtype_ == at::ScalarType::Half) {                                         \
            using weight_t = at::Half;                                                         \
            __VA_ARGS__();                                                                     \
        } else if (dispatch_wtype_ == at::ScalarType::BFloat16) {                              \
            using weight_t = at::BFloat16;                                                     \
            __VA_ARGS__();                                                                     \
        } else if (dispatch_wtype_ == at::ScalarType::Float) {                                 \
            using weight_t = float;                                                            \
            __VA_ARGS__();                                                                     \
        } else {                                                                               \
            TORCH_CHECK(false, #NAME, " not implemented for weight type '",                    \
                        toString(dispatch_wtype_), "'");                                       \
        }                                                                                      \
    } while (0)
|
|
|
template<typename input_t, typename weight_t> |
|
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); |
|
template <typename input_t, typename weight_t> |
|
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); |
|
|
|
template<typename input_t, typename weight_t> |
|
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); |
|
template<typename input_t, typename weight_t> |
|
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); |
|
|
|
template<typename input_t, typename weight_t> |
|
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); |
|
|
|
void set_conv_params_fwd(ConvParamsBase ¶ms, |
|
|
|
const size_t batch, |
|
const size_t dim, |
|
const size_t seqlen, |
|
const size_t width, |
|
|
|
const at::Tensor x, |
|
const at::Tensor weight, |
|
const at::Tensor out, |
|
void* bias_ptr, |
|
bool silu_activation) { |
|
|
|
|
|
memset(¶ms, 0, sizeof(params)); |
|
|
|
params.batch = batch; |
|
params.dim = dim; |
|
params.seqlen = seqlen; |
|
params.width = width; |
|
|
|
params.silu_activation = silu_activation; |
|
|
|
|
|
params.x_ptr = x.data_ptr(); |
|
params.weight_ptr = weight.data_ptr(); |
|
params.bias_ptr = bias_ptr; |
|
params.out_ptr = out.data_ptr(); |
|
|
|
params.x_batch_stride = x.stride(0); |
|
params.x_c_stride = x.stride(1); |
|
params.x_l_stride = x.stride(-1); |
|
params.weight_c_stride = weight.stride(0); |
|
params.weight_width_stride = weight.stride(1); |
|
params.out_batch_stride = out.stride(0); |
|
params.out_c_stride = out.stride(1); |
|
params.out_l_stride = out.stride(-1); |
|
} |
|
|
|
|
|
void set_conv_params_bwd(ConvParamsBwd ¶ms, |
|
|
|
const size_t batch, |
|
const size_t dim, |
|
const size_t seqlen, |
|
const size_t width, |
|
|
|
const at::Tensor x, |
|
const at::Tensor weight, |
|
void* bias_ptr, |
|
const at::Tensor dout, |
|
const at::Tensor dx, |
|
const at::Tensor dweight, |
|
void* dbias_ptr, |
|
bool silu_activation) { |
|
|
|
set_conv_params_fwd(params, batch, dim, seqlen, width, |
|
x, weight, dout, bias_ptr, silu_activation); |
|
|
|
|
|
params.dout_ptr = dout.data_ptr(); |
|
params.dx_ptr = dx.data_ptr(); |
|
params.dweight_ptr = dweight.data_ptr(); |
|
params.dbias_ptr = dbias_ptr; |
|
|
|
params.dout_batch_stride = dout.stride(0); |
|
params.dout_c_stride = dout.stride(1); |
|
params.dout_l_stride = dout.stride(2); |
|
params.dweight_c_stride = dweight.stride(0); |
|
params.dweight_width_stride = dweight.stride(1); |
|
params.dx_batch_stride = dx.stride(0); |
|
params.dx_c_stride = dx.stride(1); |
|
params.dx_l_stride = dx.stride(2); |
|
} |
|
|
|
// Causal conv1d forward pass with one filter per channel (weight is (dim, width)).
//
// x:      (batch, dim, seqlen) CUDA tensor of dtype float32/float16/bfloat16.
//         Either the length dim (stride(2) == 1) or the channel dim
//         (stride(1) == 1) must be unit-stride. A channel-dim-only unit stride
//         selects the channel-last kernel, which requires dim % 8 == 0.
// weight: (dim, width) CUDA tensor, 2 <= width <= 4, on x's device.
// bias_:  optional (dim,) tensor with weight's dtype, unit stride, on x's device.
// silu_activation: forwarded to the kernel through params.
// Returns a new tensor created with torch::empty_like(x) and filled by the kernel.
// Raises c10::Error on any violated precondition.
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const c10::optional<at::Tensor> &bias_,
                  bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16,
                "causal_conv1d_fwd: x must be float32, float16 or bfloat16");
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16,
                "causal_conv1d_fwd: weight must be float32, float16 or bfloat16");

    TORCH_CHECK(x.is_cuda(), "causal_conv1d_fwd: x must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "causal_conv1d_fwd: weight must be a CUDA tensor");
    // All pointers go to one launch on x's device, so every operand must live there.
    TORCH_CHECK(weight.device() == x.device(), "causal_conv1d_fwd: weight must be on the same device as x");

    // Check the rank before indexing sizes(): IntArrayRef::operator[] is not bounds-checked.
    TORCH_CHECK(x.dim() == 3, "causal_conv1d_fwd: x must have shape (batch, dim, seqlen)");
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1,
                "causal_conv1d_fwd: x must be contiguous along the length or the channel dimension");
    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;

    if (is_channel_last) {
        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
    }
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        const auto &bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type, "causal_conv1d_fwd: bias must have the same dtype as weight");
        TORCH_CHECK(bias.is_cuda(), "causal_conv1d_fwd: bias must be a CUDA tensor");
        TORCH_CHECK(bias.device() == x.device(), "causal_conv1d_fwd: bias must be on the same device as x");
        TORCH_CHECK(bias.stride(-1) == 1, "causal_conv1d_fwd: bias must have unit stride");
        CHECK_SHAPE(bias, dim);
    }

    // Make x's device current for the allocation, the stream lookup and the launch.
    at::cuda::CUDAGuard device_guard{x.device()};

    // empty_like keeps x's layout, so out matches whichever kernel variant runs.
    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
            if (!is_channel_last) {
                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
            } else {
                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
            }
        });
    });
    return out;
}
|
|
|
std::vector<at::Tensor> |
|
causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, |
|
const c10::optional<at::Tensor> &bias_, |
|
at::Tensor &dout, |
|
c10::optional<at::Tensor> &dx_, |
|
bool silu_activation) { |
|
auto input_type = x.scalar_type(); |
|
auto weight_type = weight.scalar_type(); |
|
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); |
|
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); |
|
|
|
TORCH_CHECK(x.is_cuda()); |
|
TORCH_CHECK(weight.is_cuda()); |
|
TORCH_CHECK(dout.is_cuda()); |
|
|
|
const auto sizes = x.sizes(); |
|
const int batch_size = sizes[0]; |
|
const int dim = sizes[1]; |
|
const int seqlen = sizes[2]; |
|
const int width = weight.size(-1); |
|
|
|
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); |
|
|
|
CHECK_SHAPE(x, batch_size, dim, seqlen); |
|
CHECK_SHAPE(weight, dim, width); |
|
CHECK_SHAPE(dout, batch_size, dim, seqlen); |
|
|
|
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); |
|
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; |
|
if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } |
|
if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } |
|
|
|
if (bias_.has_value()) { |
|
auto bias = bias_.value(); |
|
TORCH_CHECK(bias.scalar_type() == weight_type); |
|
TORCH_CHECK(bias.is_cuda()); |
|
TORCH_CHECK(bias.stride(-1) == 1); |
|
CHECK_SHAPE(bias, dim); |
|
} |
|
|
|
at::Tensor dx; |
|
if (dx_.has_value()) { |
|
dx = dx_.value(); |
|
TORCH_CHECK(dx.scalar_type() == input_type); |
|
TORCH_CHECK(dx.is_cuda()); |
|
CHECK_SHAPE(dx, batch_size, dim, seqlen); |
|
if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } |
|
if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } |
|
} else { |
|
dx = torch::empty_like(x); |
|
} |
|
|
|
|
|
|
|
at::cuda::CUDAGuard device_guard{(char)x.get_device()}; |
|
|
|
at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); |
|
at::Tensor dbias; |
|
if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } |
|
|
|
ConvParamsBwd params; |
|
set_conv_params_bwd(params, batch_size, dim, seqlen, width, |
|
x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, |
|
dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr, |
|
silu_activation); |
|
|
|
auto stream = at::cuda::getCurrentCUDAStream().stream(); |
|
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { |
|
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { |
|
if (!is_channel_last) { |
|
causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream); |
|
} else { |
|
causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream); |
|
} |
|
}); |
|
}); |
|
return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; |
|
} |
|
|
|
// Single-step causal conv1d update (seqlen == 1) that also reads conv_state.
//
// x:          (batch, dim) CUDA tensor, float32/float16/bfloat16.
// conv_state: (batch, dim, width) CUDA tensor with x's dtype, on x's device. Its
//             pointer and strides are handed to the kernel. Whether and how the
//             kernel modifies it is not visible in this file.
// weight:     (dim, width), 2 <= width <= 4, on x's device.
// bias_:      optional (dim,) tensor with weight's dtype, unit stride, on x's device.
// Returns a new tensor created with torch::empty_like(x) and filled by the kernel.
at::Tensor
causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const c10::optional<at::Tensor> &bias_,
                     bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16,
                "causal_conv1d_update: x must be float32, float16 or bfloat16");
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16,
                "causal_conv1d_update: weight must be float32, float16 or bfloat16");
    TORCH_CHECK(conv_state.scalar_type() == input_type, "causal_conv1d_update: conv_state must have the same dtype as x");

    TORCH_CHECK(x.is_cuda(), "causal_conv1d_update: x must be a CUDA tensor");
    TORCH_CHECK(conv_state.is_cuda(), "causal_conv1d_update: conv_state must be a CUDA tensor");
    TORCH_CHECK(weight.is_cuda(), "causal_conv1d_update: weight must be a CUDA tensor");
    // All pointers go to one launch on x's device, so every operand must live there.
    TORCH_CHECK(conv_state.device() == x.device(), "causal_conv1d_update: conv_state must be on the same device as x");
    TORCH_CHECK(weight.device() == x.device(), "causal_conv1d_update: weight must be on the same device as x");

    // Check the rank before indexing sizes(): IntArrayRef::operator[] is not bounds-checked.
    TORCH_CHECK(x.dim() == 2, "causal_conv1d_update: x must have shape (batch, dim)");
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim);
    CHECK_SHAPE(conv_state, batch_size, dim, width);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        const auto &bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type, "causal_conv1d_update: bias must have the same dtype as weight");
        TORCH_CHECK(bias.is_cuda(), "causal_conv1d_update: bias must be a CUDA tensor");
        TORCH_CHECK(bias.device() == x.device(), "causal_conv1d_update: bias must be on the same device as x");
        TORCH_CHECK(bias.stride(-1) == 1, "causal_conv1d_update: bias must have unit stride");
        CHECK_SHAPE(bias, dim);
    }

    // Make x's device current for the allocation, the stream lookup and the launch.
    at::cuda::CUDAGuard device_guard{x.device()};

    at::Tensor out = torch::empty_like(x);

    // seqlen is fixed to 1. The fwd helper reads the length stride from the last
    // dimension, so 2-D x/out are accepted.
    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, 1, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);
    params.conv_state_ptr = conv_state.data_ptr();
    // Element strides of conv_state (batch, channel, width position).
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
        });
    });
    return out;
}
|
|
|
// Python bindings. The extension module exposes the three C++ entry points under
// their own names; the trailing string argument of each def() is its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
    // (x, weight, bias, silu_activation) -> out
    ext.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
    // (x, weight, bias, dout, dx, silu_activation) -> [dx, dweight, dbias]
    ext.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
    // (x, conv_state, weight, bias, silu_activation) -> out
    ext.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
}
|
|