// Copyright © 2022 Apple Inc.
namespace at::mps { | |
typedef MPSEvent* mpsEvent_t; | |
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl | |
// https://github.com/pytorch/pytorch/issues/77170 | |
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface { | |
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS; | |
// constructor | |
MPSGuardImpl() {} | |
explicit MPSGuardImpl(c10::DeviceType t) { | |
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS); | |
} | |
// returns the type | |
c10::DeviceType type() const override { | |
return c10::DeviceType::MPS; | |
} | |
Device exchangeDevice(Device d) const override { | |
return Device(c10::DeviceType::MPS, 0); | |
} | |
Device getDevice() const override { | |
return Device(c10::DeviceType::MPS, 0); | |
} | |
c10::optional<Device> uncheckedGetDevice() const noexcept { | |
return Device(c10::DeviceType::MPS, 0); | |
} | |
void setDevice(Device d) const override { | |
TORCH_INTERNAL_ASSERT(d.is_mps()); | |
} | |
void uncheckedSetDevice(Device d) const noexcept override { | |
// TODO: Currently setting only device 0 | |
} | |
Stream getStream(Device d) const noexcept override { | |
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); | |
} | |
Stream getDefaultStream(Device d) const override { | |
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); | |
} | |
// NB: These do NOT set the current device | |
Stream exchangeStream(Stream s) const noexcept override { | |
return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); | |
} | |
DeviceIndex deviceCount() const noexcept override { | |
if (at::hasMPS()) { | |
//TODO: extend it for multi-device case | |
return 1; | |
} else { | |
return 0; | |
} | |
} | |
// Event-related functions | |
void createEvent( | |
mpsEvent_t* event, | |
const EventFlag flag) const; | |
void destroyEvent( | |
void* event, | |
const DeviceIndex device_index) const noexcept override; | |
void record( | |
void** event, | |
const Stream& stream, | |
const DeviceIndex device_index, | |
const EventFlag flag) const override; | |
void block( | |
void* event, | |
const Stream& stream) const override; | |
bool queryEvent(void* event) const override; | |
}; | |
/// A variant of OptionalDeviceGuard that is specialized for MPS. | |
struct OptionalMPSGuard { | |
explicit OptionalMPSGuard() : guard_() {} | |
explicit OptionalMPSGuard(c10::optional<Device> device_opt) | |
: guard_(device_opt) {} | |
/// Set the current MPS device to the passed device index, if it is not | |
/// nullopt | |
explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt) | |
: guard_(device_index_opt) {} | |
// Copy is not allowed | |
OptionalMPSGuard(const OptionalMPSGuard&) = delete; | |
OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete; | |
OptionalMPSGuard(OptionalMPSGuard&& other) = delete; | |
OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete; | |
/// Sets the MPS device to the given device, initializing the guard if it | |
/// is not already initialized. Errors if the given device is not a MPS | |
/// device. | |
void set_device(Device device) { | |
guard_.set_device(device); | |
} | |
/// Sets the MPS device to the given device, initializing the guard if it is | |
/// not already initialized. Errors if the given device is not a MPS device. | |
void reset_device(Device device) { | |
guard_.reset_device(device); | |
} | |
/// Sets the MPS device to the given device index, initializing the guard if | |
/// it is not already initialized. | |
void set_index(DeviceIndex device_index) { | |
guard_.set_index(device_index); | |
} | |
/// Returns the device that was set immediately prior to initialization of the | |
/// guard, or nullopt if the guard is uninitialized. | |
c10::optional<Device> original_device() const { | |
return guard_.original_device(); | |
} | |
/// Returns the most recent device that was set using this device guard, | |
/// either from construction, or via set_device, if the guard is initialized, | |
/// or nullopt if the guard is uninitialized. | |
c10::optional<Device> current_device() const { | |
return guard_.current_device(); | |
} | |
/// Restore the original MPS device, resetting this guard to uninitialized | |
/// state. | |
void reset() { | |
guard_.reset(); | |
} | |
private: | |
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_; | |
}; | |
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl); | |
} // namespace at::mps | |