#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include "type_shim.h"  // assumed local header providing DISPATCH_FLOAT_HALF_AND_BFLOAT
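
// The three gemmex_wrapper overloads below share one pattern: A and B may be
// BF16, FP16, or FP32, but cublasGemmEx is always asked to compute and
// accumulate in FP32 (CUDA_R_32F compute type and an FP32 C matrix), so the
// weight gradient is accumulated at full precision regardless of the
// activation dtype. Overload resolution on the A/B pointer type selects the
// matching cuBLAS data-type enums at compile time. Note that CUDA_R_16BF
// requires CUDA 11 or newer, and CUBLAS_GEMM_DEFAULT_TENSOR_OP lets cuBLAS
// use tensor cores where the hardware supports them.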
// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    float* A,
    int lda,
    float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
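
// How the dimensions below map onto cuBLAS: cuBLAS is column-major, so a
// row-major [hidden_dim, in_dim] input is seen by cuBLAS as a column-major
// [in_dim, hidden_dim] matrix (and likewise for d_output and d_weight).
// With op(A) = N and op(B) = T, the call computes, in column-major terms,
//   C[in_dim, out_dim] += A[in_dim, hidden_dim] * B[out_dim, hidden_dim]^T
// which is exactly d_weight += d_output^T @ input in row-major terms.
// beta = 1 makes the GEMM accumulate into the FP32 d_weight buffer instead
// of overwriting it.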
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    // alpha = beta = 1.0: add the full product to the existing contents of
    // d_weight rather than overwriting them.
    const float alpha = 1.0;
    const float beta = 1.0;
    gemmex_wrapper(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_dim,
        out_dim,
        hidden_dim,
        &alpha,
        input,
        in_dim,
        d_output,
        out_dim,
        &beta,
        d_weight,
        in_dim);
}
// Explicit instantiations for the supported input dtypes.
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
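
// Illustrative (hypothetical) shapes: for a linear layer with in_dim = 1024
// and out_dim = 4096 applied to 8192 tokens, input is [8192, 1024], d_output
// is [8192, 4096], and each call adds d_output^T @ input into the persistent
// FP32 d_weight buffer of shape [4096, 1024], with hidden_dim = 8192 as the
// GEMM reduction dimension.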
void wgrad_gemm_accum_fp32_cuda_stub(
    at::Tensor &input,
    at::Tensor &d_output,
    at::Tensor &d_weight
) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse all leading dims into the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse all leading dims into the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    const int hidden_dim = input_2d.size(0);
    const int in_dim = input_2d.size(1);
    const int out_dim = d_weight.size(0);

    // Dispatch on the input dtype (FP32 / FP16 / BF16); d_weight is always FP32.
    DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
        wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
            input_2d.data_ptr<scalar_t_0>(),
            d_output_2d.data_ptr<scalar_t_0>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}
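
// A minimal sketch of how this stub would typically be exposed to Python
// (assumption: the actual binding lives in a separate .cpp of this extension;
// the exported name below is illustrative, not taken from this file):
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub,
//           "weight-gradient GEMM accumulating into an FP32 buffer (CUDA)");
//   }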