Spaces:
Running
Running
// Copyright © 2022 Apple Inc.
// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h
// Handle aliases for the Metal and dispatch objects used by the MPS stream
// API. Objective-C++ translation units get the real Metal protocol types;
// plain C++ translation units get untyped pointers instead. The two sets
// must be mutually exclusive: with both active in one compilation,
// MTLCommandQueue_t (and the other *_t names) would be redeclared with a
// different type, which is an error.
#ifdef __OBJC__
// NOTE(review): assumes the Metal framework headers have been imported before
// this point -- no import lines are visible in this chunk; confirm.
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// `nil` is an Objective-C spelling. Member initializers further down in this
// header use it, so plain C++ needs an equivalent null constant.
#ifndef nil
#define nil nullptr
#endif
#endif // __OBJC__
| namespace at { | |
| namespace mps { | |
//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

/// Holds a `Stream` together with a Metal command queue, a Metal command
/// buffer and a dispatch queue, and exposes accessors for each of them.
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Builds an MPSStream from a Stream. The construction is checked: an
  /// error is raised when the Stream is not, in fact, an MPS stream.
  explicit MPSStream(Stream stream);

  ~MPSStream();

  /// Command queue handle stored in this stream.
  MTLCommandQueue_t commandQueue() const {
    return _commandQueue;
  }

  /// Dispatch queue handle stored in this stream.
  dispatch_queue_t queue() const {
    return _serialQueue;
  }

  // The members below are only declared in this header.
  MTLCommandBuffer_t commandBuffer();
  void commit(bool flush);
  void commitAndWait();
  void synchronize();
  void flush();

  /// MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const {
    return _stream.device_index();
  }

  /// Same handle as commandQueue(), available under a second name.
  MTLCommandQueue_t stream() const {
    return commandQueue();
  }

  /// Device reported by the stored command queue.
  MTLDevice_t device() const {
    return [_commandQueue device];
  }

  /// Explicit conversion to Stream.
  Stream unwrap() const {
    return _stream;
  }

 private:
  void _flush(bool commitAndWait) const;

  // The declaration order of the data members is intentionally left as it
  // was: it determines both initialization order and object layout.
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MTLCommandBuffer_t _commandBuffer = nil;
  dispatch_queue_t _serialQueue = nullptr;
};
| /** | |
| * Get the current MPS stream | |
| */ | |
| TORCH_API MPSStream* getCurrentMPSStream(); | |
| /** | |
| * Get the default MPS stream | |
| */ | |
| TORCH_API MPSStream* getDefaultMPSStream(); | |
| //----------------------------------------------------------------- | |
| // MPSStreamImpl | |
| //----------------------------------------------------------------- | |
| class TORCH_API MPSStreamImpl { | |
| public: | |
| /** | |
| * Gets single instance of the MPSStream. | |
| */ | |
| static MPSStream* getInstance(); | |
| private: | |
| static MPSStream* _stream; | |
| MPSStreamImpl(); | |
| }; | |
| //----------------------------------------------------------------- | |
| // MPSEvent | |
| //----------------------------------------------------------------- | |
| struct TORCH_API MPSEvent { | |
| MPSEvent(); | |
| // MPSEvent(id<MTLDevice> device); | |
| ~MPSEvent(); | |
| MTLSharedEvent_t event() const { return _event; } | |
| void recordEvent(MPSStream* stream); | |
| void waitForEvent(MPSStream* queue); // waits on the cpu | |
| bool queryEvent(); | |
| uint64_t getCurrentValue() { return _currentValue; } | |
| void setCurrentValue(uint64_t currValue) { _currentValue = currValue; } | |
| private: | |
| bool _isRecorded = false; | |
| uint64_t _currentValue = 0; | |
| MTLSharedEvent_t _event; | |
| }; | |
| typedef MPSEvent* mpsEvent_t; | |
| } // namespace mps | |
| } // namespace at | |