// Copyright © 2023 Apple Inc.
namespace at::mps { | |
// NOTE: don't create instances of this class directly. | |
// Use MPSEventPool to acquire instances of MPSEvent. | |
class MPSEvent { | |
public: | |
explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing); | |
~MPSEvent(); | |
// records an event on the stream | |
void record(bool needsLock, bool syncEvent = false); | |
// makes all future work submitted to the stream wait for this event. | |
bool wait(bool needsLock, bool syncEvent = false); | |
// schedules a notifyListener callback for the event. | |
bool notify(bool needsLock, MTLSharedEventNotificationBlock block); | |
// checks if events are already signaled. | |
bool query() const; | |
// blocks the CPU thread until all the GPU work that were scheduled | |
// prior to recording this event are completed. | |
bool synchronize(); | |
// resets this event with new parameters in case it gets reused from the event pool | |
void reset(MPSStream* stream, bool enable_timing); | |
// returns the unique ID of the event instance | |
id_t getID() const { return m_id; } | |
// returns the completion timestamp of the event | |
uint64_t getCompletionTime() const { return m_completion_time; } | |
// if already recorded, waits for cpu_sync_cv to be signaled | |
void waitForCpuSync(); | |
private: | |
id_t m_id; | |
// enables measuring the completion time of the notifyListener of this event | |
bool m_enable_timing; | |
uint64_t m_signalCounter = 0; | |
MPSStream* m_stream = nullptr; | |
MTLSharedEvent_t m_event = nullptr; | |
MTLSharedEventListener* m_listener = nullptr; | |
// used to sync the events created on this Stream with CPU | |
std::mutex m_cpu_sync_mutex{}; | |
std::condition_variable m_cpu_sync_cv{}; | |
// CondVar predicate to sync the events created on this Stream with CPU | |
bool m_cpu_sync_completed = false; | |
// used to compute elapsed time | |
uint64_t m_completion_time = 0; | |
void recordLocked(bool syncEvent); | |
bool waitLocked(bool syncEvent); | |
bool notifyLocked(MTLSharedEventNotificationBlock block); | |
void notifyCpuSync(); | |
static uint64_t getTime() { | |
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW); | |
} | |
}; | |
typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr; | |
class MPSEventPool { | |
public: | |
explicit MPSEventPool(MPSStream* default_stream); | |
~MPSEventPool(); | |
MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream); | |
void emptyCache(); | |
// these are mainly used for MPSHooks and torch.mps.Event() bindings | |
id_t acquireEvent(bool enable_timing); | |
void releaseEvent(id_t event_id); | |
void recordEvent(id_t event_id, bool syncEvent); | |
void waitForEvent(id_t event_id, bool syncEvent); | |
void synchronizeEvent(id_t event_id); | |
bool queryEvent(id_t event_id); | |
// returns elapsed time between two recorded events in milliseconds | |
double elapsedTime(id_t start_event_id, id_t end_event_id); | |
private: | |
MPSStream* m_default_stream = nullptr; | |
std::recursive_mutex m_mutex; | |
std::stack<std::unique_ptr<MPSEvent>> m_pool{}; | |
// dictionary to associate event IDs with event objects | |
// used to retain in-use events out of the pool | |
// for torch.mps.Event() bindings. | |
std::unordered_map<id_t, MPSEventPtr> m_in_use_events{}; | |
uint64_t m_event_counter = 0; | |
std::function<void(MPSEvent*)> m_default_deleter; | |
MPSEvent* getInUseEvent(id_t event_id, bool locked = true); | |
}; | |
// shared_ptr is used to get MPSEventPool destroyed after dependent instances | |
std::shared_ptr<MPSEventPool> getMPSEventPool(); | |
} // namespace at::mps | |