Spaces:
Sleeping
Sleeping
import torch | |
class Event:
    r"""Wrapper around an MPS event.

    MPS events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize MPS streams.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __init__(self, enable_timing: bool = False) -> None:
        # The native layer returns a handle that every other method passes
        # back into torch._C to identify this event.
        self.__eventId = torch._C._mps_acquireEvent(enable_timing)

    def __del__(self) -> None:
        # __del__ also runs on objects whose __init__ raised before the handle
        # was assigned, so read it defensively rather than letting an
        # AttributeError surface as "Exception ignored in __del__". The name is
        # spelled in mangled form because getattr() strings are not mangled.
        event_id = getattr(self, "_Event__eventId", 0)
        # checks if torch._C is already destroyed (presumably during
        # interpreter shutdown) before calling into it
        if hasattr(torch._C, "_mps_releaseEvent") and event_id > 0:
            torch._C._mps_releaseEvent(event_id)
            # Forget the handle so an explicit second __del__() call cannot
            # release the same event twice.
            self.__eventId = 0

    def record(self) -> None:
        r"""Records the event in the default stream."""
        torch._C._mps_recordEvent(self.__eventId)

    def wait(self) -> None:
        r"""Makes all future work submitted to the default stream wait for this event."""
        torch._C._mps_waitForEvent(self.__eventId)

    def query(self) -> bool:
        r"""Returns True if all work currently captured by event has completed."""
        return torch._C._mps_queryEvent(self.__eventId)

    def synchronize(self) -> None:
        r"""Waits until the completion of all work currently captured in this event.

        This prevents the CPU thread from proceeding until the event completes.
        """
        torch._C._mps_synchronizeEvent(self.__eventId)

    def elapsed_time(self, end_event: "Event") -> float:
        r"""Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.

        Args:
            end_event (Event): the event marking the end of the interval. It
                must be an :class:`Event`, because its private handle is read
                directly.
        """
        # NOTE(review): presumably both events must be created with
        # enable_timing=True for this to be meaningful — confirm against the
        # native implementation.
        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)