// Fused CUDA kernel for 2D rotary position embedding (RoPE): rotates a
// (B, N, H, D) token tensor in place, given one pair of integer position
// coordinates per token (pos has shape (B, N, 2)).

#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

#define CHECK_CUDA(tensor) {\
    TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
    TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }

void CHECK_KERNEL() {
    auto error = cudaGetLastError();
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error));
}

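// Kernel strategy: the launcher below uses one block per (batch, token) pair
// and one thread per channel (blockDim.x == D), so each block rotates a single
// token across all of its H heads. Shared memory holds one head's D channels
// plus a table of D/4 inverse frequencies.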
template <typename scalar_t>
__global__ void rope_2d_cuda_kernel(
        // tokens: (B, N, H, D), rotated in place
        torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
        // pos: contiguous (B, N, 2) int64 position coordinates
        const int64_t* __restrict__ pos,
        // base of the RoPE frequency schedule; fwd scales the rotation angle
        // (typically +1, or -1 to invert the rotation)
        const float base,
        const float fwd )
{
    // tokens shape = (B, N, H, D)
    const int N = tokens.size(1);
    const int H = tokens.size(2);
    const int D = tokens.size(3);

    // each block updates a single token for all heads;
    // each thread takes care of a single channel
    // shared[] layout: D floats for the current head's token,
    // followed by D/4 floats for the inverse-frequency table
    extern __shared__ float shared[];
    float* shared_inv_freq = shared + D;

    // decompose the block index into (batch, token)
    const int b = blockIdx.x / N;
    const int n = blockIdx.x % N;

    const int Q = D / 4;
    // channels are split into four blocks of Q values:
    //   [0..Q) and [Q..2Q)   form (u, v) pairs rotated by the first coordinate,
    //   [2Q..3Q) and [3Q..D) form (u, v) pairs rotated by the second coordinate

    // first Q threads fill the inverse-frequency table,
    // inv_freq[i] = fwd * base^(-i/Q), the standard RoPE schedule
    if (threadIdx.x < Q)
        shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
    __syncthreads();

    // X selects which of the two position coordinates this thread uses
    const int X = threadIdx.x < D/2 ? 0 : 1;
    // m indexes the u component of this thread's (u, v) pair
    const int m = (X*D/2) + (threadIdx.x % Q);

    // rotation angle for this thread's frequency band
    // (pos is laid out as (B*N, 2), and blockIdx.x == b*N + n)
    const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
    const float cos = cosf(freq);
    const float sin = sinf(freq);
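
    // Each (u, v) channel pair is rotated through the angle `freq`:
    //   u' = u*cos(freq) - v*sin(freq)
    //   v' = v*cos(freq) + u*sin(freq)
    // i.e. the standard 2x2 RoPE rotation, applied independently to the two
    // halves of the channel dimension.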
					
					
						
					
						
    for (int h = 0; h < H; h++)
    {
        // stage this head's token in shared memory
        shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
        __syncthreads();

        const float u = shared[m];
        const float v = shared[m+Q];
        // all threads must finish reading u and v before the next iteration
        // overwrites shared[]; without this barrier there is a race between heads
        __syncthreads();

        // write the rotated pair back: u-threads get u', v-threads get v'
        if ((threadIdx.x % (D/2)) < Q)
            tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
        else
            tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
    }
}
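
// Worked example for D = 8 (Q = 2): pairs (c0,c2) and (c1,c3) are rotated by
// angles pos[...,0]*inv_freq[0] and pos[...,0]*inv_freq[1]; pairs (c4,c6) and
// (c5,c7) use pos[...,1] with the same two frequencies. Threads 0,1,4,5 write
// the u' outputs and threads 2,3,6,7 the matching v' outputs.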
					
					
						
void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
{
    const int B = tokens.size(0); // batch size
    const int N = tokens.size(1); // number of tokens
    const int H = tokens.size(2); // number of heads
    const int D = tokens.size(3); // channels per head

    TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
    TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
    TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
    TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");

    // one block per (batch, token) pair, one thread per channel
    const int THREADS_PER_BLOCK = D;
    const int N_BLOCKS = B * N; // each block handles H*D values
    const int SHARED_MEM = sizeof(float) * (D + D/4); // one head's token + inv_freq table

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.scalar_type(), "rope_2d_cuda", ([&] {
        rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
            tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
            pos.data_ptr<int64_t>(),
            base, fwd);
    }));
    CHECK_KERNEL(); // surface any launch error immediately
}
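
// A minimal binding sketch, assuming this translation unit is built as a
// PyTorch extension; the project's actual Python binding may live in a
// separate .cpp file, in which case this block should stay commented out
// to avoid defining the module twice.
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("rope_2d_cuda", &rope_2d_cuda,
//           "apply 2D RoPE in place to a (B, N, H, D) CUDA tensor");
// }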
					
					
						