// CUDA Graphs utils used by c10 and aten.
// aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
namespace c10::cuda { | |
using CaptureId_t = unsigned long long; | |
// first is set if the instance is created by CUDAGraph::capture_begin. | |
// second is set if the instance is created by at::cuda::graph_pool_handle. | |
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>; | |
// RAII guard for "cudaStreamCaptureMode", a thread-local value | |
// that controls the error-checking strictness of a capture. | |
struct C10_CUDA_API CUDAStreamCaptureModeGuard { | |
CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) | |
: strictness_(desired) { | |
C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_)); | |
} | |
~CUDAStreamCaptureModeGuard() { | |
C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_)); | |
} | |
private: | |
cudaStreamCaptureMode strictness_; | |
}; | |
// Protects against enum cudaStreamCaptureStatus implementation changes. | |
// Some compilers seem not to like static_assert without the messages. | |
static_assert( | |
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, | |
"unexpected int(cudaStreamCaptureStatusNone) value"); | |
static_assert( | |
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, | |
"unexpected int(cudaStreamCaptureStatusActive) value"); | |
static_assert( | |
int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, | |
"unexpected int(cudaStreamCaptureStatusInvalidated) value"); | |
enum class CaptureStatus : int { | |
None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), | |
Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), | |
Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) | |
None = 0 | |
}; | |
inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { | |
switch (status) { | |
case CaptureStatus::None: | |
os << "cudaStreamCaptureStatusNone"; | |
break; | |
case CaptureStatus::Active: | |
os << "cudaStreamCaptureStatusActive"; | |
break; | |
case CaptureStatus::Invalidated: | |
os << "cudaStreamCaptureStatusInvalidated"; | |
break; | |
default: | |
TORCH_INTERNAL_ASSERT( | |
false, "Unknown CUDA graph CaptureStatus", int(status)); | |
} | |
return os; | |
} | |
// Use this version where you're sure a CUDA context exists already. | |
inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { | |
cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone}; | |
C10_CUDA_CHECK( | |
cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); | |
return CaptureStatus(is_capturing); | |
return CaptureStatus::None; | |
} | |
} // namespace c10::cuda | |