// ATen CUDA error-checking utilities: CHECK-style macros and helper
// declarations for cuDNN, cuBLAS, cuSPARSE, cuSOLVER, the CUDA driver API,
// and NVRTC.
namespace c10 { | |
class CuDNNError : public c10::Error { | |
using Error::Error; | |
}; | |
} // namespace c10 | |
// Throws a c10::CuDNNError if a cudnn_frontend call reported failure.
// EXPR must evaluate to a cudnn_frontend status object exposing
// is_good() and get_message().
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                         \
  do {                                                             \
    auto error_object = EXPR;                                      \
    if (!error_object.is_good()) {                                 \
      TORCH_CHECK_WITH(CuDNNError, false,                          \
          "cuDNN Frontend error: ", error_object.get_message());   \
    }                                                              \
  } while (0)
// See Note [CHECK macro]
// Checks a cudnnStatus_t and throws c10::CuDNNError on failure.
// CUDNN_STATUS_NOT_SUPPORTED gets an extra hint, since cuDNN commonly
// returns it for non-contiguous inputs. Extra __VA_ARGS__ are appended
// to the error message.
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define AT_CUDNN_CHECK(EXPR, ...)                                            \
  do {                                                                       \
    cudnnStatus_t status = EXPR;                                             \
    if (status != CUDNN_STATUS_SUCCESS) {                                    \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                            \
        TORCH_CHECK_WITH(CuDNNError, false,                                  \
            "cuDNN error: ",                                                 \
            cudnnGetErrorString(status),                                     \
            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
      } else {                                                               \
        TORCH_CHECK_WITH(CuDNNError, false,                                  \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);    \
      }                                                                      \
    }                                                                        \
  } while (0)
namespace at::cuda::blas { | |
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); | |
} // namespace at::cuda::blas | |
// Checks a cublasStatus_t and raises a TORCH_CHECK failure naming the
// failing expression on error.
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define TORCH_CUDABLAS_CHECK(EXPR)                          \
  do {                                                      \
    cublasStatus_t __err = EXPR;                            \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,             \
                "CUDA error: ",                             \
                at::cuda::blas::_cublasGetErrorEnum(__err), \
                " when calling `" #EXPR "`");               \
  } while (0)
// Human-readable string for a cuSPARSE status; declared here, defined in ATen.
const char* cusparseGetErrorString(cusparseStatus_t status);
// Checks a cusparseStatus_t and raises a TORCH_CHECK failure naming the
// failing expression on error.
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define TORCH_CUDASPARSE_CHECK(EXPR)                 \
  do {                                               \
    cusparseStatus_t __err = EXPR;                   \
    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,    \
                "CUDA error: ",                      \
                cusparseGetErrorString(__err),       \
                " when calling `" #EXPR "`");        \
  } while (0)
// cusolver-related headers are only supported on CUDA builds for now.
namespace at::cuda::solver { | |
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); | |
constexpr const char* _cusolver_backend_suggestion = \ | |
"If you keep seeing this error, you may use " \ | |
"`torch.backends.cuda.preferred_linalg_library()` to try " \ | |
"linear algebra operators with other supported backends. " \ | |
"See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; | |
} // namespace at::cuda::solver | |
// When CUDA < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when the input contains NaN.
// When CUDA >= 11.5, cusolver normally finishes execution and sets the info array to indicate a convergence issue.
// Checks a cusolverStatus_t. The status that signals "input contained NaN /
// did not converge" differs by toolkit version (see comment above), so that
// case gets a dedicated TORCH_CHECK_LINALG with a NaN hint; every other
// failure goes through plain TORCH_CHECK. Both paths append the
// backend-suggestion text.
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define TORCH_CUSOLVER_CHECK(EXPR)                                       \
  do {                                                                   \
    cusolverStatus_t __err = EXPR;                                       \
    if ((CUDA_VERSION < 11500 &&                                         \
         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                   \
        (CUDA_VERSION >= 11500 &&                                        \
         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                      \
      TORCH_CHECK_LINALG(                                                \
          false,                                                         \
          "cusolver error: ",                                            \
          at::cuda::solver::cusolverGetErrorMessage(__err),              \
          ", when calling `" #EXPR "`",                                  \
          ". This error may appear if the input matrix contains NaN. ",  \
          at::cuda::solver::_cusolver_backend_suggestion);               \
    } else {                                                             \
      TORCH_CHECK(                                                       \
          __err == CUSOLVER_STATUS_SUCCESS,                              \
          "cusolver error: ",                                            \
          at::cuda::solver::cusolverGetErrorMessage(__err),              \
          ", when calling `" #EXPR "`. ",                                \
          at::cuda::solver::_cusolver_backend_suggestion);               \
    }                                                                    \
  } while (0)
// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
// Checks a CUresult from the CUDA Driver API. On failure, looks up the error
// string through the dynamically-loaded NVRTC/driver stub (see the note
// above); if even that lookup fails, reports an unknown driver error.
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                                                               \
  do {                                                                                                           \
    CUresult __err = EXPR;                                                                                       \
    if (__err != CUDA_SUCCESS) {                                                                                 \
      const char* err_str;                                                                                       \
      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);  \
      if (get_error_str_err != CUDA_SUCCESS) {                                                                   \
        AT_ERROR("CUDA driver error: unknown error");                                                            \
      } else {                                                                                                   \
        AT_ERROR("CUDA driver error: ", err_str);                                                                \
      }                                                                                                          \
    }                                                                                                            \
  } while (0)
// Fallback driver-API check that reports only the numeric CUresult value
// (no error-string lookup).
// NOTE(review): this looks like an alternative, preprocessor-guarded
// definition of the driver-check macro for builds where the NVRTC error-string
// path is unavailable; the `#define`/`#ifdef` lines were lost in extraction.
// No header is restored here to avoid colliding with the definition above —
// recover the original conditional from version control.
  do {                                                           \
    CUresult __err = EXPR;                                       \
    if (__err != CUDA_SUCCESS) {                                 \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));  \
    }                                                            \
  } while (0)
// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
// Checks an nvrtcResult. Error code 7 (NVRTC_ERROR_BUILTIN_OPERATION_FAILURE)
// is special-cased because nvrtcGetErrorString mislabels it as
// "NVRTC unknown error" (see the note above).
// NOTE(review): the `#define` header line was lost in extraction; restored
// from the body's usage pattern — confirm the macro name against callers.
#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
  do {                                                                                              \
    nvrtcResult __err = EXPR;                                                                       \
    if (__err != NVRTC_SUCCESS) {                                                                   \
      if (static_cast<int>(__err) != 7) {                                                           \
        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err));  \
      } else {                                                                                      \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                        \
      }                                                                                             \
    }                                                                                               \
  } while (0)