Spaces:
Running
Running
// Copyright © 2022 Apple Inc.
// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h
// The class definitions below must not be seen twice in one translation unit.
#pragma once

// Handle types used by the declarations in this header.
// Objective-C(++) translation units get the real Metal protocol types; plain
// C++ translation units get opaque void* stand-ins under the same names, so
// the header can be included from either kind of source file. Without this
// split, each name would be typedef'd twice with conflicting types.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// `nil` is an Objective-C keyword. Give plain C++ a null-pointer equivalent
// so default member initializers written as `= nil` still compile.
#ifndef nil
#define nil nullptr
#endif
#endif  // __OBJC__
namespace at { | |
namespace mps { | |
//----------------------------------------------------------------- | |
// MPSStream | |
//----------------------------------------------------------------- | |
/// A Metal command stream wrapper. It holds the c10 Stream it was built from,
/// a Metal command queue, the current command buffer handle, and a dispatch
/// queue handle exposed through queue(). Members that are only declared here
/// are defined out of line.
///
/// NOTE(review): this class stores raw handles and declares a destructor, but
/// it keeps the implicit copy constructor/assignment. If the out-of-line
/// destructor releases those handles, a copy would release them twice.
/// Confirm that instances are only ever used through pointers (as
/// getCurrentMPSStream() returning MPSStream* suggests).
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);
  ~MPSStream();

  /// The Metal command queue held by this stream.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; }
  /// The dispatch queue handle held by this stream.
  dispatch_queue_t queue() const { return _serialQueue; }

  // Command-buffer access and submission; defined out of line.
  MTLCommandBuffer_t commandBuffer();
  void commit(bool flush);
  void commitAndWait();
  void synchronize();
  void flush();

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }
  /// Same handle as commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; }
#ifdef __OBJC__
  /// The Metal device that owns the command queue. Available only to
  /// Objective-C(++) callers: the body is an Objective-C message send, which
  /// is not valid C++ and would break plain C++ includers of this header.
  MTLDevice_t device() const { return [_commandQueue device]; }
#endif
  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

 private:
  Stream _stream;
  // `nullptr` (rather than the Objective-C-only `nil`) compiles in both
  // Objective-C++ and plain C++ translation units.
  MTLCommandQueue_t _commandQueue = nullptr;
  MTLCommandBuffer_t _commandBuffer = nullptr;
  void _flush(bool commitAndWait) const;
  dispatch_queue_t _serialQueue = nullptr;
};
/// Returns the MPS stream that is currently in use.
TORCH_API MPSStream* getCurrentMPSStream();

/// Returns the default MPS stream.
TORCH_API MPSStream* getDefaultMPSStream();
//----------------------------------------------------------------- | |
// MPSStreamImpl | |
//----------------------------------------------------------------- | |
/// Holder of the single shared MPSStream. Callers reach that stream only
/// through getInstance(). The constructor is private, so the holder cannot be
/// created from outside the class.
class TORCH_API MPSStreamImpl {
  // Members before the first access specifier are private (class default).
  MPSStreamImpl();
  // Storage returned by getInstance(); the definition lives out of line.
  static MPSStream* _stream;

 public:
  /// Returns the single MPSStream instance.
  static MPSStream* getInstance();
};
//----------------------------------------------------------------- | |
// MPSEvent | |
//----------------------------------------------------------------- | |
/// Wraps a Metal shared-event handle, a stored counter value and a flag
/// recording whether the event has been recorded. The constructor,
/// destructor, recordEvent(), waitForEvent() and queryEvent() are defined
/// out of line.
///
/// NOTE(review): `_event` is a raw handle and the implicit copy operations
/// are left enabled. If the destructor releases the event, copies would
/// release it twice. Confirm events are only passed by pointer
/// (see mpsEvent_t).
struct TORCH_API MPSEvent {
  MPSEvent();
  ~MPSEvent();

  /// The underlying shared-event handle.
  MTLSharedEvent_t event() const { return _event; }

  void recordEvent(MPSStream* stream);
  void waitForEvent(MPSStream* stream);  // waits on the cpu
  bool queryEvent();

  /// Read/write access to the stored counter value.
  uint64_t getCurrentValue() const { return _currentValue; }
  void setCurrentValue(uint64_t currValue) { _currentValue = currValue; }

 private:
  bool _isRecorded = false;
  uint64_t _currentValue = 0;
  // Default to a null handle like the other members, so the handle is never
  // read indeterminate regardless of what the out-of-line constructor does.
  MTLSharedEvent_t _event = nullptr;
};

/// Pointer alias for MPSEvent.
typedef MPSEvent* mpsEvent_t;
} // namespace mps | |
} // namespace at | |