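// NOTE(review): this copy was scraped from a web page. The three lines that
// stood here ("Spaces:", "Running", "Running") were page chrome, not source.
// The header also appears to be missing its `#pragma once` and the includes
// it needs (it uses std::atomic and std::forward) — confirm against upstream.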
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
//   using fn_type = void(*)(const Tensor& x);
//   DECLARE_DISPATCH(fn_type, stub);
//
// In native/MyKernel.cpp:
//   DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
//   namespace {
//     // use anonymous namespace so that different cpu versions won't conflict
//     void kernel(const Tensor& x) { ... }
//   }
//   REGISTER_DISPATCH(stub, &kernel);
//
// To call:
//   stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere | |
C10_CLANG_DIAGNOSTIC_PUSH() | |
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template") | |
namespace at::native { | |
/// Instruction-set levels a kernel may have been compiled for.
///
/// Exactly one family of vector levels exists per build. VSX is used when
/// HAVE_VSX_CPU_DEFINITION is set. ZVECTOR is used when
/// HAVE_ZVECTOR_CPU_DEFINITION is set. Otherwise AVX2/AVX512 are used.
/// The families share the numeric value 1, so they must stay mutually
/// exclusive for NUM_OPTIONS to equal the number of levels in this build.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};

/// Declared here, defined in another translation unit.
/// NOTE(review): presumably reports the capability level selected for the
/// running CPU — confirm against its definition.
CPUCapability get_cpu_capability();

/// Primary template. Only the partial specialization for function-pointer
/// types `rT (*)(Args...)` is defined. `T` is the concrete stub type that
/// derives from it (the struct generated by DECLARE_DISPATCH).
template <typename FnPtr, typename T>
struct DispatchStub;
/** | |
* The sole purpose of this class is to outline methods that don't need to be | |
* specialized or otherwise inlined and duplicated (by the compiler due to | |
* template expansion), since it causes size bloat if there are a significant | |
* number of specialization of the DispatchStub<> class. | |
*/ | |
struct TORCH_API DispatchStubImpl { | |
void* get_call_ptr( | |
c10::DeviceType device_type | |
, void *DEFAULT | |
#ifdef HAVE_AVX512_CPU_DEFINITION | |
, void *AVX512 | |
#endif | |
#ifdef HAVE_AVX2_CPU_DEFINITION | |
, void *AVX2 | |
#endif | |
#ifdef HAVE_VSX_CPU_DEFINITION | |
, void *VSX | |
#endif | |
#ifdef HAVE_ZVECTOR_CPU_DEFINITION | |
, void *ZVECTOR | |
#endif | |
); | |
/** | |
* The CPU Dispatch actual method is chosen in decreasing order of preference by | |
* DispatchStubImpl::choose_cpu_impl() in case none is found by | |
* DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr. | |
*/ | |
void* choose_cpu_impl( | |
void *DEFAULT | |
#ifdef HAVE_AVX512_CPU_DEFINITION | |
, void *AVX512 | |
#endif | |
#ifdef HAVE_AVX2_CPU_DEFINITION | |
, void *AVX2 | |
#endif | |
#ifdef HAVE_VSX_CPU_DEFINITION | |
, void *VSX | |
#endif | |
#ifdef HAVE_ZVECTOR_CPU_DEFINITION | |
, void *ZVECTOR | |
#endif | |
); | |
// Fixing dispatch error in Windows debug builds. | |
// See https://github.com/pytorch/pytorch/issues/22681 for more details. | |
std::atomic<void*> cpu_dispatch_ptr; | |
void* cuda_dispatch_ptr; | |
void* hip_dispatch_ptr; | |
void* mps_dispatch_ptr; | |
void* privateuse1_dispatch_ptr; | |
std::atomic<void*> cpu_dispatch_ptr{nullptr}; | |
void* cuda_dispatch_ptr = nullptr; | |
void* hip_dispatch_ptr = nullptr; | |
void* mps_dispatch_ptr = nullptr; | |
void* privateuse1_dispatch_ptr = nullptr; | |
}; | |
template <typename rT, typename T, typename... Args> | |
struct DispatchStub<rT (*)(Args...), T> { | |
using FnPtr = rT (*) (Args...); | |
DispatchStub() = default; | |
DispatchStub(const DispatchStub&) = delete; | |
DispatchStub& operator=(const DispatchStub&) = delete; | |
private: | |
FnPtr get_call_ptr(c10::DeviceType device_type) { | |
return reinterpret_cast<FnPtr>( | |
impl.get_call_ptr(device_type | |
, reinterpret_cast<void*>(DEFAULT) | |
, reinterpret_cast<void*>(AVX512) | |
, reinterpret_cast<void*>(AVX2) | |
, reinterpret_cast<void*>(VSX) | |
, reinterpret_cast<void*>(ZVECTOR) | |
) | |
); | |
} | |
public: | |
template <typename... ArgTypes> | |
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) { | |
FnPtr call_ptr = get_call_ptr(device_type); | |
return (*call_ptr)(std::forward<ArgTypes>(args)...); | |
} | |
void set_cuda_dispatch_ptr(FnPtr fn_ptr) { | |
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); | |
} | |
void set_hip_dispatch_ptr(FnPtr fn_ptr) { | |
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); | |
} | |
void set_mps_dispatch_ptr(FnPtr fn_ptr) { | |
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); | |
} | |
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) { | |
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); | |
} | |
static TORCH_API FnPtr DEFAULT; | |
static TORCH_API FnPtr AVX512; | |
static TORCH_API FnPtr AVX2; | |
static TORCH_API FnPtr VSX; | |
static TORCH_API FnPtr ZVECTOR; | |
private: | |
DispatchStubImpl impl; | |
}; | |
namespace {

// Registration helpers. Each constructor hands `value` to one device-slot
// setter of `stub`. The REGISTER_*_DISPATCH macros declare them as `static`
// objects, so the store happens when the registering translation unit's
// statics are initialized.
//
// The template parameter is named StubT (not DispatchStub) so it does not
// shadow the DispatchStub class template.

template <typename StubT>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(StubT &stub, typename StubT::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename StubT>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(StubT &stub, typename StubT::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};

template <typename StubT>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(StubT &stub, typename StubT::FnPtr value) {
    // HIP kernels go into the CUDA slot because the HIPify build presents
    // HIP as CUDA.
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(value);
  }
};

template <typename StubT>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(StubT &stub, typename StubT::FnPtr value) {
    stub.set_privateuse1_dispatch_ptr(value);
  }
};

} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// DECLARE_DISPATCH(fn, name) declares:
//   - a non-copyable stub type `name` deriving from DispatchStub<fn, name>;
//   - an extern object of the same name.
// Define the object with DEFINE_DISPATCH(name) in one .cpp file.
#define DECLARE_DISPATCH(fn, name)         \
  struct name : DispatchStub<fn, name> {   \
    name() = default;                      \
    name(const name&) = delete;            \
    name& operator=(const name&) = delete; \
  };                                       \
  extern TORCH_API struct name name
/// Defines the stub object declared by DECLARE_DISPATCH(fn, name).
#define DEFINE_DISPATCH(name) struct name name

/// Sets the `arch` CPU slot (DEFAULT, AVX512, AVX2, VSX or ZVECTOR) of stub
/// `name` to `fn`. It does this by explicitly specializing that static member.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

// Per-arch registration. Each macro expands to nothing when the build has no
// kernel for that arch. This matches the HAVE_*_CPU_DEFINITION guards on the
// CPU slots that DispatchStubImpl::get_call_ptr accepts.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn)        \
  REGISTER_AVX2_DISPATCH(name, fn)          \
  REGISTER_VSX_DISPATCH(name, fn)           \
  REGISTER_ZVECTOR_DISPATCH(name, fn)
/// Registers nullptr in every CPU arch slot of stub `name`, for kernels that
/// have no CPU implementation.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registrations. Each macro defines a `static` registrar object named
// `<name>__register`. Its constructor passes `fn` to the matching
// Register*Dispatch helper for stub `name`.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  static RegisterPRIVATEUSE1Dispatch<struct name> name ## __register(name, fn);
// REGISTER_DISPATCH(name, fn) registers `fn` for the backend the current
// translation unit is compiled for. It is not defined in translation units
// that are none of CUDA, HIP, Objective-C with USE_MPS, or CPU_CAPABILITY.
//
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
// In the AVX512 variant the slot always receives nullptr. `fn` still appears
// in the expression; NOTE(review): presumably so the kernel stays referenced
// — confirm.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native | |
C10_CLANG_DIAGNOSTIC_POP() | |