#ifndef _MPS_UTILS_H_
#define _MPS_UTILS_H_

// Helpers for binding PyTorch tensors and plain values as arguments of a
// Metal compute command encoder.
//
// NOTE(review): the original include targets were lost (angle-bracket text was
// stripped). The headers below are what this file's code needs: at::Tensor,
// Metal protocol types, and std::enable_if_t / std::forward. Confirm they match
// the build.
#include <torch/extension.h>

#include <type_traits>
#include <utility>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

typedef id<MTLBuffer> MTLBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
// setMTLArgs() takes a pipeline state, so its type must be declared here too.
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
#else
// Opaque stand-ins so plain C++ translation units can still name these types.
typedef void* MTLBuffer;
typedef void* MTLBuffer_t;
typedef void* MTLComputeCommandEncoder;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLComputePipelineState;
typedef void* MTLComputePipelineState_t;
#endif

// utils

// Reinterprets the tensor's storage data pointer as the MTLBuffer handle.
// Assumes `tensor` lives on the MPS device, where the storage pointer is the
// MTLBuffer itself. TODO confirm: for non-MPS tensors this would be a host
// pointer.
static inline MTLBuffer_t getMTLBufferStorage(const at::Tensor& tensor) {
  return __builtin_bit_cast(MTLBuffer_t, tensor.storage().data());
}

#ifdef __OBJC__
// Everything below uses Objective-C message sends, which are a syntax error in
// plain C++. Only Objective-C++ translation units may use these helpers.

// Tensor argument: binds the tensor's backing buffer at `index`. The buffer is
// offset to the tensor's first element, so views and slices read their own
// data rather than the start of the shared storage.
template <typename T,
          std::enable_if_t<std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  // storage_offset() counts elements; Metal expects a byte offset.
  const NSUInteger byteOffset = static_cast<NSUInteger>(t.storage_offset()) *
                                static_cast<NSUInteger>(t.element_size());
  [encoder setBuffer:getMTLBufferStorage(t) offset:byteOffset atIndex:index];
}

// Non-tensor argument: passes the value's raw bytes via setBytes:length:atIndex:.
template <typename T,
          std::enable_if_t<!std::is_same<std::decay_t<T>, at::Tensor>::value, bool> = true>
void setMTLArg(MTLComputeCommandEncoder_t encoder, int index, T&& t) {
  // The value is copied bytewise, so only trivially copyable types make sense.
  static_assert(std::is_trivially_copyable<std::decay_t<T>>::value,
                "setMTLArg: non-tensor arguments must be trivially copyable");
  [encoder setBytes:&t length:sizeof(t) atIndex:index];
}

// Recursion terminator: no arguments left to bind.
inline void setMTLArgsImpl(MTLComputeCommandEncoder_t, int) {}

// Binds `t` at `index`, then recurses to bind the remaining arguments at
// consecutive indices.
template <typename T, typename... Args>
void setMTLArgsImpl(MTLComputeCommandEncoder_t encoder, int index, T&& t, Args&&... args) {
  setMTLArg(encoder, index, std::forward<T>(t));
  setMTLArgsImpl(encoder, index + 1, std::forward<Args>(args)...);
}

// Sets `pso` as the encoder's compute pipeline state, then binds `args` to
// argument indices 0, 1, 2, ... in order. Each tensor is bound as a buffer and
// every other value is bound as inline bytes.
template <typename... Args>
void setMTLArgs(MTLComputeCommandEncoder_t encoder, MTLComputePipelineState_t pso,
                Args&&... args) {
  [encoder setComputePipelineState:pso];
  setMTLArgsImpl(encoder, 0, std::forward<Args>(args)...);
}
#endif  // __OBJC__

#endif  // _MPS_UTILS_H_