/****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #include #include //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BytesToType {}; template<> struct BytesToType<16> { using Type = uint4; static_assert(sizeof(Type) == 16); }; template<> struct BytesToType<8> { using Type = uint64_t; static_assert(sizeof(Type) == 8); }; template<> struct BytesToType<4> { using Type = uint32_t; static_assert(sizeof(Type) == 4); }; template<> struct BytesToType<2> { using Type = uint16_t; static_assert(sizeof(Type) == 2); }; template<> struct BytesToType<1> { using Type = uint8_t; static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { __device__ inline T operator()(T const & x, T const & y) { return x + y; } }; template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template static __device__ inline T run(T x, Operator &op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); } }; template<> struct Allreduce<2> { template static __device__ inline T run(T x, Operator &op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } };