// NOTE(review): the three lines previously here ("Spaces:", "Runtime error" x2)
// were paste/tooling artifacts, not source code; replaced with this comment.
| namespace { | |
| void compute_n1_n2( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| int& n1, | |
| int& n2) | |
| { | |
| int idiff = input.ndimension() - normalized_shape.size(); | |
| n2 = 1; | |
| for (int i = 0; i < (int)normalized_shape.size(); ++i) { | |
| assert( input.sizes()[i+idiff] == normalized_shape[i] ); | |
| n2 *= normalized_shape[i]; | |
| } | |
| n1 = 1; | |
| for (int i = 0; i < idiff; ++i) { | |
| n1 *= input.sizes()[i]; | |
| } | |
| } | |
| void check_args( | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| at::Tensor beta | |
| ) | |
| { | |
| TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); | |
| TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); | |
| } | |
| void check_args( | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma | |
| ) | |
| { | |
| TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); | |
| } | |
| void check_args( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| int& n1, | |
| int& n2 | |
| ) | |
| { | |
| int64_t normalized_ndim = normalized_shape.size(); | |
| if (normalized_ndim < 1) { | |
| std::stringstream ss; | |
| ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " | |
| << "containing at least one element, but got normalized_shape=" | |
| << normalized_shape; | |
| throw std::runtime_error(ss.str()); | |
| } | |
| auto input_shape = input.sizes(); | |
| auto input_ndim = input.dim(); | |
| if (input_ndim < normalized_ndim || | |
| !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { | |
| std::stringstream ss; | |
| ss << "Given normalized_shape=" << normalized_shape | |
| << ", expected input with shape [*"; | |
| for (auto size : normalized_shape) { | |
| ss << ", " << size; | |
| } | |
| ss << "], but got input of size" << input_shape; | |
| throw std::runtime_error(ss.str()); | |
| } | |
| compute_n1_n2(input,normalized_shape,n1,n2); | |
| } | |
| void check_args( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| at::Tensor beta, | |
| int& n1, | |
| int& n2 | |
| ) | |
| { | |
| check_args(input,normalized_shape,n1,n2); | |
| check_args(normalized_shape,gamma,beta); | |
| } | |
| void check_args( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| int& n1, | |
| int& n2 | |
| ) | |
| { | |
| check_args(input,normalized_shape,n1,n2); | |
| check_args(normalized_shape,gamma); | |
| } | |
| } | |
// Forward layer-norm kernel launcher, implemented in the companion .cu file.
// Writes the normalized result into *output and per-row statistics into
// *mean and *invvar (each of length n1). gamma/beta may be null for the
// non-affine path. n1/n2 are the flattened outer/inner extents computed by
// check_args.
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);
| std::vector<at::Tensor> layer_norm( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| int n1,n2; | |
| check_args(input,normalized_shape,n1,n2); | |
| at::Tensor output = at::empty_like(input); | |
| at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); | |
| at::Tensor invvar = at::empty_like(mean); | |
| cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, | |
| normalized_shape,NULL,NULL,epsilon); | |
| return {output, mean, invvar}; | |
| } | |
| std::vector<at::Tensor> layer_norm_affine( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| at::Tensor beta, | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| CHECK_INPUT(gamma); | |
| CHECK_INPUT(beta); | |
| int n1,n2; | |
| check_args(input,normalized_shape,gamma,beta,n1,n2); | |
| at::Tensor output = at::empty_like(input); | |
| const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); | |
| at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype)); | |
| at::Tensor invvar = at::empty_like(mean); | |
| cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, | |
| normalized_shape,&gamma,&beta,epsilon); | |
| return {output, mean, invvar}; | |
| } | |
| std::vector<at::Tensor> layer_norm_affine_mixed_dtypes( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| at::Tensor beta, | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| int n1, n2; | |
| check_args(input, normalized_shape, n1, n2); | |
| at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); | |
| at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); | |
| at::Tensor invvar = at::empty_like(mean); | |
| cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, | |
| normalized_shape, &gamma, &beta, epsilon); | |
| return {output, mean, invvar}; | |
| } | |
// Backward layer-norm kernel launcher, implemented in the companion .cu file.
// mean may be null (memory-efficient mode recomputes from input_or_output);
// gamma/beta and grad_gamma/grad_beta may be null for the non-affine path.
// Results are written through the grad_* out-pointers.
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta,
    bool memory_efficient
    );
| at::Tensor layer_norm_gradient( | |
| at::Tensor dout, | |
| c10::optional<at::Tensor> mean_, | |
| at::Tensor invvar, | |
| at::Tensor input_or_output, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| double epsilon, | |
| bool memory_efficient) { | |
| CHECK_INPUT(dout); | |
| CHECK_INPUT(invvar); | |
| CHECK_INPUT(input_or_output); | |
| int n1,n2; | |
| check_args(input_or_output,normalized_shape,n1,n2); | |
| at::Tensor grad_input = at::empty_like(input_or_output); | |
| if (mean_.has_value()) { | |
| cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, | |
| normalized_shape,NULL,NULL,epsilon, | |
| &grad_input,NULL,NULL,memory_efficient); | |
| } else { | |
| cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, | |
| normalized_shape,NULL,NULL,epsilon, | |
| &grad_input,NULL,NULL,memory_efficient); | |
| } | |
| return grad_input; | |
| } | |
| std::vector<at::Tensor> layer_norm_gradient_affine( | |
| at::Tensor dout, | |
| c10::optional<at::Tensor> mean_, | |
| at::Tensor invvar, | |
| at::Tensor input_or_output, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| at::Tensor beta, | |
| double epsilon, | |
| bool memory_efficient) { | |
| CHECK_INPUT(dout); | |
| CHECK_INPUT(invvar); | |
| CHECK_INPUT(input_or_output); | |
| CHECK_INPUT(gamma); | |
| CHECK_INPUT(beta); | |
| int n1,n2; | |
| check_args(input_or_output,normalized_shape,gamma,beta,n1,n2); | |
| at::Tensor grad_input = at::empty_like(input_or_output); | |
| at::Tensor grad_gamma = at::empty_like(gamma); | |
| at::Tensor grad_beta = at::empty_like(beta); | |
| // at::Tensor *mean = mean_.has_value() ? &mean_.value() : NULL; | |
| if (mean_.has_value()) { | |
| cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, | |
| normalized_shape,&gamma,&beta,epsilon, | |
| &grad_input,&grad_gamma,&grad_beta,memory_efficient); | |
| } else { | |
| cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, | |
| normalized_shape,&gamma,&beta,epsilon, | |
| &grad_input,&grad_gamma,&grad_beta,memory_efficient); | |
| } | |
| return {grad_input, grad_gamma, grad_beta}; | |
| } | |
// Forward RMS-norm kernel launcher, implemented in the companion .cu file.
// Like cuda_layer_norm but with no mean/beta (RMS norm only rescales).
// gamma may be null for the non-affine path.
void cuda_rms_norm(
    at::Tensor* output,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon);
| std::vector<at::Tensor> rms_norm( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| int n1,n2; | |
| check_args(input,normalized_shape,n1,n2); | |
| at::Tensor output = at::empty_like(input); | |
| at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); | |
| cuda_rms_norm(&output,&invvar,&input,n1,n2, | |
| normalized_shape,NULL,epsilon); | |
| return {output, invvar}; | |
| } | |
| std::vector<at::Tensor> rms_norm_affine( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| CHECK_INPUT(gamma); | |
| int n1,n2; | |
| check_args(input,normalized_shape,gamma,n1,n2); | |
| at::Tensor output = at::empty_like(input); | |
| const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); | |
| at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype)); | |
| cuda_rms_norm(&output,&invvar,&input,n1,n2, | |
| normalized_shape,&gamma,epsilon); | |
| return {output, invvar}; | |
| } | |
| std::vector<at::Tensor> rms_norm_affine_mixed_dtypes( | |
| at::Tensor input, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| double epsilon) { | |
| CHECK_INPUT(input); | |
| int n1, n2; | |
| check_args(input, normalized_shape, n1, n2); | |
| at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); | |
| at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); | |
| cuda_rms_norm(&output,&invvar, &input, n1, n2, | |
| normalized_shape, &gamma,epsilon); | |
| return {output,invvar}; | |
| } | |
// Backward RMS-norm kernel launcher, implemented in the companion .cu file.
// gamma and grad_gamma may be null for the non-affine path. Results are
// written through the grad_* out-pointers.
void cuda_rms_norm_gradient(
    at::Tensor* dout,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    bool memory_efficient);
| at::Tensor rms_norm_gradient( | |
| at::Tensor dout, | |
| at::Tensor invvar, | |
| at::Tensor input_or_output, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| double epsilon, | |
| bool memory_efficient) { | |
| CHECK_INPUT(dout); | |
| CHECK_INPUT(invvar); | |
| CHECK_INPUT(input_or_output); | |
| int n1,n2; | |
| check_args(input_or_output,normalized_shape,n1,n2); | |
| at::Tensor grad_input = at::empty_like(input_or_output); | |
| cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, | |
| normalized_shape,NULL,epsilon, | |
| &grad_input,NULL,memory_efficient); | |
| return grad_input; | |
| } | |
| std::vector<at::Tensor> rms_norm_gradient_affine( | |
| at::Tensor dout, | |
| at::Tensor invvar, | |
| at::Tensor input_or_output, | |
| #ifdef VERSION_GE_1_1 | |
| at::IntArrayRef normalized_shape, | |
| #else | |
| at::IntList normalized_shape, | |
| #endif | |
| at::Tensor gamma, | |
| double epsilon, | |
| bool memory_efficient) { | |
| CHECK_INPUT(dout); | |
| CHECK_INPUT(invvar); | |
| CHECK_INPUT(input_or_output); | |
| CHECK_INPUT(gamma); | |
| int n1,n2; | |
| check_args(input_or_output,normalized_shape,gamma,n1,n2); | |
| at::Tensor grad_input = at::empty_like(input_or_output); | |
| at::Tensor grad_gamma = at::empty_like(gamma); | |
| cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, | |
| normalized_shape,&gamma,epsilon, | |
| &grad_input,&grad_gamma,memory_efficient); | |
| return {grad_input, grad_gamma}; | |
| } | |
// Python bindings. Layer norm and RMS norm each expose forward/backward in
// plain, affine, and (forward-only) mixed-dtype variants.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Layer norm
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
  m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
  // RMS norm
  m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)");
  m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)");
  m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)");
  m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)");
  m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}