|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
|
|
|
|
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" |
|
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" |
|
#include "cute/tensor.hpp" |
|
|
|
namespace cutlass::epilogue::threadblock { |
|
|
|
using namespace cute; |
|
using namespace detail; |
|
|
|
// Epilogue visitor that broadcasts either a full row vector (one value per
// output column, e.g. a length-N bias) or a single scalar across the whole
// output tile. The mode is selected at runtime by Arguments::row_broadcast.
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  struct Arguments {
    // Row-broadcast mode: pointer to a row of elements (indexed by column).
    // Scalar mode (row_broadcast == false): pointer to a single Element that
    // is replicated across every column (see begin_epilogue below).
    Element const* ptr_row = nullptr;
    // true  -> load a distinct value per column from ptr_row
    // false -> replicate *ptr_row across all columns
    bool row_broadcast = true;
    // Stride of the broadcast row in (M,N,L) form.
    StrideMNL dRow = {};
  };

  // No host-side transformation of arguments is needed.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // No shared memory required.
  struct SharedStorage {};

  // Vectorized-access width: kElementsPerAccess elements, capped at 128 bits
  // (the widest global load issued by cutlass::arch::global_load).
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  // Number of Elements packed into one VecType access.
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;   // this thread's slice of the row in global memory
    RTensor tC_rRow;   // register-backed copy filled in begin_epilogue
    CTensor tC_cRow;   // logical (m,n) coordinates for bounds checking
    Params const* params_ptr;
    int n;             // problem extent N, used to guard out-of-bounds columns

    // Stage the broadcast row (or scalar) into registers once per epilogue.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero-fill first so out-of-bounds (guarded-off) slots hold 0.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // Vectorized, predicated loads of the row from global memory;
        // the guard suppresses loads whose column coordinate is >= n.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
            dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // Scalar mode: build one vector whose lanes all equal *ptr_row ...
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        // ... then copy it into every in-bounds register slot.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Return the register fragment for the requested column; all row/iter
    // indices map to the same data since the row is broadcast down M.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Partition the global row among threads and build the per-thread Callbacks.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // View of the broadcast row in global memory with stride dRow.
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // This thread's slice, recast to vector accesses; trailing modes are
    // sliced to _0 because the row does not vary along them.
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Identity tensor gives each slot its logical (m,n) coordinate; group
    // VecLength consecutive coordinates to match the vectorized accesses.
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }
};
|
|
|
|
|
|
|
|
|
// Epilogue visitor that broadcasts a row vector when ptr_row is non-null, or
// zero-fills otherwise. Unlike VisitorRowOrScalarBroadcast there is no mode
// flag: the null-ness of ptr_row itself selects the behavior (checked on
// device in begin_epilogue).
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  struct Arguments {
    // Non-null: pointer to a row of elements (indexed by column).
    // Null: every column receives Element{0}.
    Element const* ptr_row = nullptr;
    // Stride of the broadcast row in (M,N,L) form.
    StrideMNL dRow = {};
  };

  // No host-side transformation of arguments is needed.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // No shared memory required.
  struct SharedStorage {};

  // Vectorized-access width: kElementsPerAccess elements, capped at 128 bits
  // (the widest global load issued by cutlass::arch::global_load).
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  // Number of Elements packed into one VecType access.
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;   // this thread's slice of the row in global memory
    RTensor tC_rRow;   // register-backed copy filled in begin_epilogue
    CTensor tC_cRow;   // logical (m,n) coordinates for bounds checking
    Params const* params_ptr;
    int n;             // problem extent N, used to guard out-of-bounds columns

    // Stage the broadcast row (or zeros) into registers once per epilogue.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero-fill first so out-of-bounds (guarded-off) slots hold 0.
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // Vectorized, predicated loads of the row from global memory;
        // the guard suppresses loads whose column coordinate is >= n.
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
            dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // Null pointer: broadcast Element{0} into every in-bounds slot.
        // (clear() already zeroed the registers; this writes an explicit
        // zero vector to the in-bounds slots as well.)
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    // Return the register fragment for the requested column; all row/iter
    // indices map to the same data since the row is broadcast down M.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  // Partition the global row among threads and build the per-thread Callbacks.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // View of the broadcast row in global memory with stride dRow.
    // NOTE: constructed unconditionally; it is only dereferenced when
    // ptr_row != nullptr (the load branch in begin_epilogue).
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // This thread's slice, recast to vector accesses; trailing modes are
    // sliced to _0 because the row does not vary along them.
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Identity tensor gives each slot its logical (m,n) coordinate; group
    // VecLength consecutive coordinates to match the vectorized accesses.
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }
};
|
|
|
|
|
|
|
|
|
|
|
// Epilogue visitor that broadcasts either a column vector (one value per
// output row, e.g. a length-M bias) or a single scalar across the output
// tile. The mode is selected at runtime by Arguments::col_broadcast.
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  struct Arguments {
    // Column-broadcast mode: pointer to a column of elements (indexed by row).
    // Scalar mode (col_broadcast == false): pointer to a single Element that
    // is replicated across every row (see begin_epilogue below).
    Element const* ptr_col = nullptr;
    // true  -> load a distinct value per row from ptr_col
    // false -> replicate *ptr_col across all rows
    bool col_broadcast = true;
    // Stride of the broadcast column in (M,N,L) form.
    StrideMNL dCol = {};
  };

  // No host-side transformation of arguments is needed.
  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  // This visitor needs no device workspace.
  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  // No shared memory required.
  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;   // this thread's slice of the column in global memory
    RTensor tC_rCol;   // register-backed copy filled in begin_epilogue
    CTensor tC_cCol;   // logical (m,n) coordinates for bounds checking
    Params const* params_ptr;
    int m;             // problem extent M, used to guard out-of-bounds rows

    // Stage the broadcast column (or scalar) into registers once per epilogue.
    CUTLASS_DEVICE void
    begin_epilogue() {
      // Zero-fill first so out-of-bounds (predicated-off) slots hold 0.
      clear(tC_rCol);

      // Per-slot predicate: true when the row coordinate is in bounds.
      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // Predicated element-wise copy of the column from global memory.
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // Scalar mode: write *ptr_col into every in-bounds register slot.
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    // Return a fragment with every lane set to this row's value; the column
    // value is constant along N, so the whole fragment is one element.
    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  // Partition the global column among threads and build the per-thread
  // Callbacks. Accesses are element-wise (no vectorization) since values
  // along a thread's column slice are not contiguous in general.
  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    // View of the broadcast column in global memory with stride dCol.
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // Slice out the column modes (first two modes fixed to _0) and flatten
    // the remaining iteration modes into one for simple (row, iter) indexing.
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Identity tensor gives each slot its logical (m,n) coordinate,
    // partitioned identically to the data tensor.
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};
|
|
|
} |
|
|