File size: 3,647 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//  Copyright © 2023 Apple Inc.

#pragma once

#include <ATen/mps/MPSStream.h>
#include <ctime>
#include <stack>

namespace at::mps {

// NOTE: don't create instances of this class directly.
// Use MPSEventPool to acquire instances of MPSEvent.
class MPSEvent {
public:
  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
  ~MPSEvent();

  // records an event on the stream
  void record(bool needsLock, bool syncEvent = false);
  // makes all future work submitted to the stream wait for this event.
  bool wait(bool needsLock, bool syncEvent = false);
  // schedules a notifyListener callback for the event.
  bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
  // checks if events are already signaled.
  bool query() const;
  // blocks the CPU thread until all the GPU work that were scheduled
  // prior to recording this event are completed.
  bool synchronize();
  // resets this event with new parameters in case it gets reused from the event pool
  void reset(MPSStream* stream, bool enable_timing);
  // returns the unique ID of the event instance
  id_t getID() const { return m_id; }
  // returns the completion timestamp of the event
  uint64_t getCompletionTime() const { return m_completion_time; }
  // if already recorded, waits for cpu_sync_cv to be signaled
  void waitForCpuSync();

private:
  id_t m_id;
  // enables measuring the completion time of the notifyListener of this event
  bool m_enable_timing;
  uint64_t m_signalCounter = 0;
  MPSStream* m_stream = nullptr;
  MTLSharedEvent_t m_event = nullptr;
  MTLSharedEventListener* m_listener = nullptr;
  // used to sync the events created on this Stream with CPU
  std::mutex m_cpu_sync_mutex{};
  std::condition_variable m_cpu_sync_cv{};
  // CondVar predicate to sync the events created on this Stream with CPU
  bool m_cpu_sync_completed = false;
  // used to compute elapsed time
  uint64_t m_completion_time = 0;

  void recordLocked(bool syncEvent);
  bool waitLocked(bool syncEvent);
  bool notifyLocked(MTLSharedEventNotificationBlock block);
  void notifyCpuSync();
  static uint64_t getTime() {
    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  }
};

typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;

class MPSEventPool {
public:
  explicit MPSEventPool(MPSStream* default_stream);
  ~MPSEventPool();

  MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
  void emptyCache();

  // these are mainly used for MPSHooks and torch.mps.Event() bindings
  id_t acquireEvent(bool enable_timing);
  void releaseEvent(id_t event_id);
  void recordEvent(id_t event_id, bool syncEvent);
  void waitForEvent(id_t event_id, bool syncEvent);
  void synchronizeEvent(id_t event_id);
  bool queryEvent(id_t event_id);
  // returns elapsed time between two recorded events in milliseconds
  double elapsedTime(id_t start_event_id, id_t end_event_id);

private:
  MPSStream* m_default_stream = nullptr;
  std::recursive_mutex m_mutex;
  std::stack<std::unique_ptr<MPSEvent>> m_pool{};
  // dictionary to associate event IDs with event objects
  // used to retain in-use events out of the pool
  // for torch.mps.Event() bindings.
  std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
  uint64_t m_event_counter = 0;
  std::function<void(MPSEvent*)> m_default_deleter;

  MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
};

// shared_ptr is used to get MPSEventPool destroyed after dependent instances
std::shared_ptr<MPSEventPool> getMPSEventPool();

} // namespace at::mps