// Copyright © 2022 Apple Inc.

#pragma once

#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

#include <cstddef>
#include <cstdint>
#include <string>
namespace at { | |
struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { | |
// this fails the implementation if MPSHooks functions are called, but | |
// MPS backend is not present. | |
TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend."); | |
virtual ~MPSHooksInterface() override = default; | |
// Initialize the MPS library state | |
virtual void initMPS() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual bool hasMPS() const { | |
return false; | |
} | |
virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual const Generator& getDefaultMPSGenerator() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual Allocator* getMPSDeviceAllocator() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void deviceSynchronize() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void commitStream() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void* getCommandBuffer() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void* getDispatchQueue() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void emptyCache() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual size_t getCurrentAllocatedMemory() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual size_t getDriverAllocatedMemory() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void setMemoryFraction(double /*ratio*/) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void profilerStopTrace() const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual uint32_t acquireEvent(bool enable_timing) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void releaseEvent(uint32_t event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void recordEvent(uint32_t event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void waitForEvent(uint32_t event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual void synchronizeEvent(uint32_t event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual bool queryEvent(uint32_t event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
virtual bool hasPrimaryContext(DeviceIndex device_index) const override { | |
FAIL_MPSHOOKS_FUNC(__func__); | |
} | |
}; | |
struct TORCH_API MPSHooksArgs {}; | |
TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); | |
C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) | |
namespace detail { | |
TORCH_API const MPSHooksInterface& getMPSHooks(); | |
} // namespace detail | |
} // namespace at | |