#pragma once
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <mutex>
namespace at {
struct CUDAGeneratorImpl;
namespace cuda {
// Standalone way to get a unique mempool id usable as a pool=... argument
// to CUDAGraph::capture_begin
TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
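//
// A minimal sketch of pool sharing via graph_pool_handle(); "graph_a" and
// "graph_b" are hypothetical CUDAGraph instances, each capturing some CUDA work:
//
//   auto handle = at::cuda::graph_pool_handle();
//   graph_a.capture_begin(handle);   // allocations during capture come from the shared pool
//   /* ... launch work for graph_a ... */
//   graph_a.capture_end();
//   graph_b.capture_begin(handle);   // reuses the same private mempool
//   /* ... launch work for graph_b ... */
//   graph_b.capture_end();
//   // Safe as long as the graphs are later replayed in capture order.
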
struct TORCH_CUDA_CPP_API CUDAGraph {
  CUDAGraph();
  ~CUDAGraph();

  static void inc_pending_event_queries();
  static void dec_pending_event_queries();
  static int num_pending_event_queries();

  void capture_begin(MempoolId_t pool={0, 0}, cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
  void capture_end();
  void replay();
  void reset();
  MempoolId_t pool();
  void enable_debug_mode();
  void debug_dump(const std::string& debug_path);

 protected:
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
  cudaGraph_t graph_ = NULL;
  cudaGraphExec_t graph_exec_ = NULL;
#endif

  static std::atomic<int> pending_event_queries;

  // internal states so reset() can do its best cleaning up
  // Set to true in capture_end if cudaStreamEndCapture succeeded
  // Set back to false soon after, when graph_ is consumed by cudaGraphInstantiate
  // to create graph_exec_, then graph_ is deleted
  bool has_graph_ = false;

  // Set to true in capture_end if cudaGraphInstantiate succeeded
  bool has_graph_exec_ = false;

  // uuid of this instance's current capture, used to
  // specify the pool.
  CaptureId_t id_;

  // the ID assigned by cuda during graph capture,
  // used to identify when a stream is participating in capture
  CaptureId_t capture_id_ = -1;
  // uuid used to request a particular private mempool from CUDACachingAllocator.
  // By default, this will be set to {id_, 0}.
  //
  // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
  // will be set to the other graph's mempool_id_, and therefore share a mempool with the
  // other graph.
  //
  // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
  // it will share a mempool with any other captures that used "pool=handle".
  //
  // Sharing a mempool across graphs saves memory, and it's safe if you
  // know you'll replay those graphs in the same order you captured them.
  // (See the usage sketch at the end of this file.)
  MempoolId_t mempool_id_;

  // Stream on which capture began
  at::cuda::CUDAStream capture_stream_;

  // Default generator on device where capture began
  at::CUDAGeneratorImpl* capture_gen_;

  // Device where capture occurred. Right now, for simplicity, we require all ops
  // in a capture to run on the same device, but this is a limitation of CUDAGraph,
  // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
  // captures if needed.
  int capture_dev_;

  // RNG state trackers
  at::Tensor seed_extragraph_;
  at::Tensor offset_extragraph_;
  uint64_t wholegraph_increment_;
};
} // namespace cuda
} // namespace at
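
// Illustrative usage sketch for the API above. "run_step" is a hypothetical
// callable that launches the CUDA work to be captured; CUDAStreamGuard comes
// from c10/cuda/CUDAGuard.h (not included by this header):
//
//   at::cuda::CUDAGraph graph;
//   auto stream = c10::cuda::getStreamFromPool();   // capture must run on a non-default stream
//   {
//     c10::cuda::CUDAStreamGuard guard(stream);
//     run_step();               // warm-up / allocations outside capture
//     graph.capture_begin();    // default pool: a fresh private mempool {id_, 0}
//     run_step();               // work recorded into graph_
//     graph.capture_end();      // instantiates graph_exec_
//   }
//   graph.replay();             // re-launches the captured work
//
//   // A second graph can reuse the first graph's mempool:
//   at::cuda::CUDAGraph graph2;
//   {
//     c10::cuda::CUDAStreamGuard guard(stream);
//     graph2.capture_begin(graph.pool());
//     run_step();
//     graph2.capture_end();
//   }
//   // Replay graph and graph2 in the same order they were captured.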