/* coding=utf-8
 * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {
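
// Forward pass of the fused rotary positional embedding (RoPE) kernel.
// `input` has shape (s, b, h, d) and `freqs` has shape (s, 1, 1, d2). When
// `transpose_output` is true, the output is allocated in (b, s, h, d) memory
// layout and returned as a view with logical shape (s, b, h, d).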
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &freqs,
                       const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);
  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);
  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_forward",
      dispatch_fused_rope_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          freqs.data_ptr<float>(), output.data_ptr<scalar_t_0>()););
  return output;
}
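
// Backward pass of the fused RoPE kernel: given the gradient w.r.t. the
// output (shape (s, b, h, d)) and the same `freqs`, computes the gradient
// w.r.t. the input.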
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &freqs,
                       const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // freqs' shape is always (s, 1, 1, d2), so the strides are the same under
  // different memory formats
  const int d2 = freqs.size(3);
  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);
  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
      dispatch_fused_rope_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}
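
// Two-level type dispatch used by the cached-RoPE entry points below:
// `scalar_t_0` is the dtype of the activations/gradients (float, half, or
// bfloat16) and `scalar_t_1` is the dtype of the cos/sin tensors (float, or
// the same half/bfloat16 type as the activations). Any other combination
// raises a TORCH_CHECK error.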
#define DISPATCH_FUSED_ROPE_TYPES(TYPE1, TYPE2, NAME, ...) \
  switch (TYPE1) { \
    case at::ScalarType::Float: { \
      using scalar_t_0 = float; \
      switch (TYPE2) { \
        case at::ScalarType::Float: { \
          using scalar_t_1 = float; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'"); \
      } \
      break; \
    } \
    case at::ScalarType::Half: { \
      using scalar_t_0 = at::Half; \
      switch (TYPE2) { \
        case at::ScalarType::Float: { \
          using scalar_t_1 = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::Half: { \
          using scalar_t_1 = at::Half; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'"); \
      } \
      break; \
    } \
    case at::ScalarType::BFloat16: { \
      using scalar_t_0 = at::BFloat16; \
      switch (TYPE2) { \
        case at::ScalarType::Float: { \
          using scalar_t_1 = float; \
          __VA_ARGS__; \
          break; \
        } \
        case at::ScalarType::BFloat16: { \
          using scalar_t_1 = at::BFloat16; \
          __VA_ARGS__; \
          break; \
        } \
        default: \
          TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                      "' with '", toString(TYPE2), "'"); \
      } \
      break; \
    } \
    default: \
      TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), \
                  "' with '", toString(TYPE2), "'"); \
  }
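
// Forward pass of the fused RoPE kernel using precomputed `cos` and `sin`
// tensors (each of shape (s, 1, 1, d2)) instead of `freqs`.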
torch::Tensor fwd_cached_cuda(const torch::Tensor &input,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // input sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = input.size(0);
  const int b = input.size(1);
  const int h = input.size(2);
  const int d = input.size(3);
  // input strides
  const int stride_s = input.stride(0);
  const int stride_b = input.stride(1);
  const int stride_h = input.stride(2);
  const int stride_d = input.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
  // under different memory formats
  const int d2 = cos.size(3);
  // output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor output;
  if (transpose_output) {
    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    output = torch::empty({s, b, h, d}, act_options);
  }
  // output strides
  const int o_stride_s = output.stride(0);
  const int o_stride_b = output.stride(1);
  const int o_stride_h = output.stride(2);
  const int o_stride_d = output.stride(3);
  DISPATCH_FUSED_ROPE_TYPES(
      input.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_forward",
      dispatch_fused_rope_cached_forward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cos.data_ptr<scalar_t_1>(), sin.data_ptr<scalar_t_1>(),
          output.data_ptr<scalar_t_0>()););
  return output;
}
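
// Backward pass of the cached variant: computes the gradient w.r.t. the
// input from `output_grads` and the same precomputed `cos`/`sin` tensors.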
torch::Tensor bwd_cached_cuda(const torch::Tensor &output_grads,
                              const torch::Tensor &cos,
                              const torch::Tensor &sin,
                              const bool transpose_output) {
  // output_grads sizes: (s, b, h, d)
  // s: sequence length
  // b: batch size
  // h: head num
  // d: dim of each head
  const int s = output_grads.size(0);
  const int b = output_grads.size(1);
  const int h = output_grads.size(2);
  const int d = output_grads.size(3);
  // output_grads strides
  const int stride_s = output_grads.stride(0);
  const int stride_b = output_grads.stride(1);
  const int stride_h = output_grads.stride(2);
  const int stride_d = output_grads.stride(3);
  // cos/sin's shape is always (s, 1, 1, d2), so the strides are the same
  // under different memory formats
  const int d2 = cos.size(3);
  auto act_options = output_grads.options().requires_grad(false);
  torch::Tensor input_grads;
  if (transpose_output) {
    input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
  } else {
    input_grads = torch::empty({s, b, h, d}, act_options);
  }
  const int o_stride_s = input_grads.stride(0);
  const int o_stride_b = input_grads.stride(1);
  const int o_stride_h = input_grads.stride(2);
  const int o_stride_d = input_grads.stride(3);
  DISPATCH_FUSED_ROPE_TYPES(
      output_grads.scalar_type(), cos.scalar_type(),
      "dispatch_fused_rope_cached_backward",
      dispatch_fused_rope_cached_backward(
          s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
          o_stride_b, o_stride_h, o_stride_d,
          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_1>(),
          sin.data_ptr<scalar_t_1>(), input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}
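
// Forward pass of the fused RoPE kernel for packed (THD) input, where the
// sequences of a batch are concatenated along the first dimension and
// `cu_seqlens` holds the cumulative sequence lengths, so the batch size is
// cu_seqlens.size(0) - 1.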
torch::Tensor fwd_thd_cuda(const torch::Tensor &input,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // input sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = input.size(0);
  const int h = input.size(1);
  const int d = input.size(2);
  // input strides
  const int stride_t = input.stride(0);
  const int stride_h = input.stride(1);
  const int stride_d = input.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);
  // output
  auto act_options = input.options().requires_grad(false);
  auto output = torch::empty({t, h, d}, act_options);
  // output strides
  const int o_stride_t = output.stride(0);
  const int o_stride_h = output.stride(1);
  const int o_stride_d = output.stride(2);
  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_thd_forward",
      dispatch_fused_rope_thd_forward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, input.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          output.data_ptr<scalar_t_0>()););
  return output;
}
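
// Backward pass of the THD variant: computes the gradient w.r.t. the packed
// (t, h, d) input from `output_grads`, `cu_seqlens`, and `freqs`.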
torch::Tensor bwd_thd_cuda(const torch::Tensor &output_grads,
                           const torch::Tensor &cu_seqlens,
                           const torch::Tensor &freqs) {
  // output_grads sizes: (t, h, d)
  // t: cumulative sum of sequence lengths
  // h: head num
  // d: dim of each head
  const int t = output_grads.size(0);
  const int h = output_grads.size(1);
  const int d = output_grads.size(2);
  // output_grads strides
  const int stride_t = output_grads.stride(0);
  const int stride_h = output_grads.stride(1);
  const int stride_d = output_grads.stride(2);
  // batch size
  const int b = cu_seqlens.size(0) - 1;
  // freqs' shape is (max_s, 1, 1, d2)
  const int max_s = freqs.size(0);
  const int d2 = freqs.size(3);
  auto act_options = output_grads.options().requires_grad(false);
  auto input_grads = torch::empty({t, h, d}, act_options);
  const int o_stride_t = input_grads.stride(0);
  const int o_stride_h = input_grads.stride(1);
  const int o_stride_d = input_grads.stride(2);
  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_thd_backward",
      dispatch_fused_rope_thd_backward(
          max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
          o_stride_h, o_stride_d, output_grads.data_ptr<scalar_t_0>(),
          cu_seqlens.data_ptr<int>(), freqs.data_ptr<float>(),
          input_grads.data_ptr<scalar_t_0>()););
  return input_grads;
}

}  // end namespace fused_rope