/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
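
// NOTE: this excerpt assumes the surrounding translation unit provides the common weight-only GEMV definitions
// (WeightOnlyQuantType, WeightOnlyParams, WeightOnlyPerChannel, WeightOnlyGroupWise, WeightOnlyActivationType and
// the load<>/assign<> helpers), the CUDA fp16/bf16 headers and the CUTLASS
// FastInterleavedAndBiasedNumericArrayConverter; the actual include list is not shown here.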
namespace tensorrt_llm
{
namespace kernels
{

template <typename ActType>
struct ActTypeDetails;

template <>
struct ActTypeDetails<half>
{
    using CutlassType = cutlass::half_t;
    using Vec2 = half2;

    __device__ __forceinline__ static Vec2 to_vec2(half v)
    {
        return __half2half2(v);
    }
};

template <>
struct ActTypeDetails<__nv_bfloat16>
{
    using CutlassType = cutlass::bfloat16_t;
    using Vec2 = __nv_bfloat162;

    __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v)
    {
        return __bfloat162bfloat162(v);
    }
};

template <typename ActType, WeightOnlyQuantType QType>
struct ConverterSelector
{
    static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b);

    using WeiType = std::conditional_t<QType == WeightOnlyQuantType::Int4b, cutlass::uint4b_t, uint8_t>;
    static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4;
    using Converter
        = cutlass::FastInterleavedAndBiasedNumericArrayConverter<typename ActTypeDetails<ActType>::CutlassType,
            WeiType, kConvertCount>;
};

template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyDetails;

template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int4b>
{
    // Every four rows of the original weight matrix are interleaved into one row with a stride of 64. If each
    // thread processes 32 elements (for int4, a single ldg.128 loads 32 elements), then every group of two
    // adjacent threads processes weights coming from four different original rows in turn.
    // For example, for every 256 consecutive int4 elements [256*i, 256*(i+1)-1] of row N in the interleaved
    // layout, the first 64 come from [64*i, 64*(i+1)-1] of row 4N before interleaving, the next 64 from the same
    // range of row 4N+1, and so on. So if each thread loads 32 int4 elements, each pair of adjacent threads
    // within every 8 consecutive threads holds elements of rows 4N ~ 4N+3, respectively, before interleaving.
    static constexpr int kElemBits = 4;
    static constexpr int kInterleave = 4;
    static constexpr int kStride = 64;
    // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm:
    // input  0 1 2 3 4  5  6  7 8 9 10 11 12 13 14 15 ... 31
    // weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
    static constexpr int kShuffleSize = 32;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided = 4;
    // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding address in shared memory.
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
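        // Butterfly reduction over XOR distances {16, 8, 1}: afterwards each lane L holds the sum over the 8
        // lanes {L, L^1, L^8, L^9, L^16, L^17, L^24, L^25}, which, given the int4 interleave (2 threads per
        // 64-element tile, 4 interleaved rows), are exactly the lanes of the warp working on the same
        // pre-interleave row. Lanes 0/2/4/6 then publish the four per-row partial sums; lane / 2 recovers the
        // interleave index.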
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
        {
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 2] = res[i];
            }
        }
        __syncthreads();
    }
};

template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int8b>
{
    // Every two rows of the original weight matrix are interleaved into one row with a stride of 64. If each
    // thread processes 16 elements (for int8, a single ldg.128 loads 16 elements), then every group of four
    // adjacent threads processes weights coming from two different original rows in turn.
    // For example, for every 128 consecutive int8 elements [128*i, 128*(i+1)-1] of row N in the interleaved
    // layout, the first 64 come from [64*i, 64*(i+1)-1] of row 2N before interleaving, and the last 64 come from
    // [64*i, 64*(i+1)-1] of row 2N+1 before interleaving. So if each thread loads 16 int8 elements, the first
    // four and last four threads of each 8 consecutive threads hold elements of rows 2N and 2N+1, respectively,
    // before interleaving.
    static constexpr int kElemBits = 8;
    static constexpr int kInterleave = 2;
    static constexpr int kStride = 64;
    // The index remapping here counteracts the effect of cutlass::permute_B_rows_for_mixed_gemm:
    // input  0 1 2 3 4 5  6  7 8 9 10 11 12 13 14 15
    // weight 0 1 8 9 2 3 10 11 4 5 12 13  6  7 14 15
    static constexpr int kShuffleSize = 16;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 2;
    static constexpr int kShuffleStrided = 4;

    // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding address in shared memory.
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
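        // Butterfly reduction over XOR distances {16, 8, 2, 1}: afterwards each lane holds the sum over the 16
        // lanes sharing bit 2 of the lane id, which, given the int8 interleave (4 threads per 64-element tile,
        // 2 interleaved rows), are exactly the lanes of the warp working on the same pre-interleave row.
        // Lanes 0 and 4 then publish the two per-row partial sums; lane / 4 recovers the interleave index.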
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 2);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 4)
        {
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 4] = res[i];
            }
        }
        __syncthreads();
    }
};

template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyKernelDetails
{
    using Layout = WeightOnlyDetails<ActType, QType>;

    static constexpr int kElemBits = Layout::kElemBits;
    static constexpr int kInterleave = Layout::kInterleave;
    static constexpr int kStride = Layout::kStride;
    static constexpr int kShuffleSize = Layout::kShuffleSize;
    static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile;
    static constexpr int kShuffleContinous = Layout::kShuffleContinous;
    static constexpr int kShuffleStrided = Layout::kShuffleStrided;
    // The rearrangement here counteracts the effect of cutlass::add_bias_and_interleave_int4/8s_inplace.
    //
    // Input int8 data layout
    //      [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
    // Converted fp16/bf16 data layout
    //      [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
    //
    // Input int4 data layout
    //      [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits)
    // Converted fp16/bf16 data layout
    //      [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits)
    static constexpr int kConvertCount = ConverterSelector<ActType, QType>::kConvertCount;
    using Converter = typename ConverterSelector<ActType, QType>::Converter;

    // Use ldg.128 to load data from global memory.
    static constexpr int kAccessSize = 128;
    using AccessType = uint4;

    static constexpr int kElemsPerByte = 8 / kElemBits;
    static constexpr int kElemsPerThread = kAccessSize / kElemBits;
    static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte;
    static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread;
    static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave;
    static constexpr int kConvertIters = kElemsPerThread / kConvertCount;

    // Each thread loads 16 (int8) / 32 (int4) quantized weight elements per ldg.128, so more ldg.128 accesses are
    // needed to load the same number of fp16/bf16 activation elements.
    static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8);
    static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess;
};
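
// Example instantiation of the constants above (derived directly from the formulas, not extra configuration):
// for ActType = half and QType = Int4b: kElemsPerByte = 2, kElemsPerThread = 32, kBytePerThread = 16,
// kThreadsNumPerTile = 2, kThreadsNumPerInterleave = 8, kConvertIters = 4, kActivationElemNumPerAccess = 8 and
// kActivationAccessNum = 4; for QType = Int8b the corresponding values are 1, 16, 16, 4, 8, 4, 8 and 2.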

template <typename WeightOnlyFlag>
struct WeightOnlyProperties;

template <>
struct WeightOnlyProperties<WeightOnlyPerChannel>
{
    static constexpr bool kIsFineGrained = false;
    static constexpr int kGroupSize = 0;
};

template <int GS>
struct WeightOnlyProperties<WeightOnlyGroupWise<GS>>
{
    static constexpr bool kIsFineGrained = true;
    static constexpr int kGroupSize = GS;
};

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
struct WeightOnlyScaleLoader
{
    using ElemType = ActType;
    using Details = WeightOnlyKernelDetails<ActType, QType>;
    static constexpr bool kIsFineGrained = WeightOnlyProperties<WeightOnlyFlag>::kIsFineGrained;
    static constexpr int kGroupSize = WeightOnlyProperties<WeightOnlyFlag>::kGroupSize;

private:
    const ElemType* _scales;
    const ElemType* _zeros;
    int _stride;
    int _offset;

public:
    __device__ __forceinline__ WeightOnlyScaleLoader(
        const ElemType* scales, const ElemType* zeros, int initial_offset, int stride)
        : _scales(scales)
        , _zeros(zeros)
        , _stride(stride)
    {
        _scales += initial_offset;
        if constexpr (Zero)
        {
            _zeros += initial_offset;
        }
        // Compute the k-dimension index (in the layout before interleaving) of the first element processed by
        // the current thread; it is used to index scales and zeros for group-wise weight-only quantization.
        _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride
            + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread;
    }

    __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid)
    {
        int offset = nid * Details::kInterleave;
        if constexpr (kIsFineGrained)
        {
            offset += _offset / kGroupSize * _stride;
        }
        scale = _scales[offset];
        if constexpr (Zero)
        {
            zero = _zeros[offset];
        }
        else
        {
            zero = static_cast<ElemType>(0.f);
        }
    }

    __device__ __forceinline__ void advance()
    {
        _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave;
    }

    __device__ __forceinline__ int offset()
    {
        return _offset;
    }
};
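
// Indexing sketch for the loader above (derived from the code, assuming the scales/zeros buffers are laid out
// row-major as [k / group_size, n] for group-wise quantization and as a single [n] row per channel): the kernel
// below constructs the loader with initial_offset = n_start_id + interleave_n_id and stride = n, so load() reads
// scales[(_offset / kGroupSize) * n + n_start_id + interleave_n_id + nid * kInterleave] in the group-wise case
// and scales[n_start_id + interleave_n_id + nid * kInterleave] in the per-channel case.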

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
    using ActType2 = typename ActTypeDetails<ActType>::Vec2;
    using Details = WeightOnlyKernelDetails<ActType, QType>;
    using Converter = typename Details::Converter;
    using AccType = typename Details::AccessType;
    using CvtSrcType = typename Converter::source_type;
    using CvtResType = typename Converter::result_type;
    using ScaleLoader = WeightOnlyScaleLoader<ActType, QType, WeightOnlyFlag, Zero, BlockSize>;

    extern __shared__ uint8_t shmem[];
    constexpr int Interleave = Details::kInterleave;
    constexpr int WarpSize = 32;
    constexpr int Num = Batch * NPerBlock;
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int n_start_id = bid * NPerBlock * Interleave;
    // Calculate the n-dimension index of the data processed by the current thread within the interleave tile
    const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave;

    qweight += n_start_id * k / Details::kElemsPerByte;
    ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n);

    float(*sm)[Num * Interleave] = reinterpret_cast<float(*)[Num * Interleave]>(shmem);
    // In order to take advantage of hfma2, we use fp16/bf16 for accumulation within threads and fp32 for
    // accumulation between threads.
    ActType accumulator[Num];
    for (int i = 0; i < Num; ++i)
    {
        accumulator[i] = static_cast<ActType>(0.f);
    }
    // Iterate over the k dimension
    for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave;
         local_k += BlockSize * Details::kElemsPerThread)
    {
        ActType weights_f16[Details::kElemsPerThread * NPerBlock];
        ActType scale[NPerBlock], zero[NPerBlock];
        for (int idx = 0; idx < NPerBlock; ++idx)
        {
            // Load quantized weights and the corresponding scales/zeros
            uint8_t weights_quantized[Details::kBytePerThread];
            load<AccType>(weights_quantized,
                qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte);
            scale_loader.load(scale[idx], zero[idx], idx);
            ActType weights_vec[Details::kElemsPerThread];
            for (int i = 0; i < Details::kConvertIters; ++i)
            {
                // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for the int-to-float conversion
                assign<CvtResType>(weights_vec + i * Details::kConvertCount,
                    Converter::convert(*reinterpret_cast<CvtSrcType*>(
                        weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte)));
            }
            for (int i = 0; i < Details::kShuffleContinous; ++i)
            {
                for (int j = 0; j < Details::kShuffleStrided; ++j)
                {
                    // Dequantize the weights and arrange the shuffled elements back to the correct order in the
                    // register array
                    ActType2 v = *reinterpret_cast<ActType2*>(weights_vec + i * Details::kShuffleBasicTile
                        + j * Details::kShuffleContinous * Details::kShuffleBasicTile);
                    v = __hfma2(
                        v, ActTypeDetails<ActType>::to_vec2(scale[idx]), ActTypeDetails<ActType>::to_vec2(zero[idx]));
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 0)
                            * NPerBlock
                        + idx]
                        = v.x;
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 1)
                            * NPerBlock
                        + idx]
                        = v.y;
                }
            }
        }
        ActType act_scale_v[Details::kElemsPerThread];
        if constexpr (ActScale)
        {
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
                    act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
            }
        }
        for (int b = 0; b < Batch; ++b)
        {
            ActType in_v[Details::kElemsPerThread];
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                // load activation elements
                load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
                    in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
                if constexpr (ActScale)
                {
                    for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
                    {
                        *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
                            *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
                            *reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
                    }
                }
            }
            // Perform vector inner product and accumulate
            if constexpr (NPerBlock == 1)
            {
                ActType2 v = ActTypeDetails<ActType>::to_vec2(static_cast<ActType>(0.f));
                for (int y = 0; y < Details::kElemsPerThread; y += 2)
                {
                    v = __hfma2(
                        *reinterpret_cast<ActType2*>(weights_f16 + y), *reinterpret_cast<ActType2*>(in_v + y), v);
                }
                accumulator[b] += __hadd(v.x, v.y);
            }
            else
            {
                for (int x = 0; x < NPerBlock / 2; ++x)
                {
                    for (int y = 0; y < Details::kElemsPerThread; ++y)
                    {
                        *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2)
                            = __hfma2(*reinterpret_cast<ActType2*>(weights_f16 + y * NPerBlock + x * 2),
                                ActTypeDetails<ActType>::to_vec2(in_v[y]),
                                *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2));
                    }
                }
            }
        }
        scale_loader.advance();
    }

    float reses[Num];
    for (int i = 0; i < Num; ++i)
    {
        reses[i] = static_cast<float>(accumulator[i]);
    }
    // Each warp completes the internal reduce and writes the [Batch * NPerBlock * Interleave] results to the
    // corresponding address in shared memory.
    Details::Layout::sync<Num, WarpSize>(reses, sm);
    // Each thread accumulates the per-warp partial sums of one output element and stores it to global memory.
    for (int i = tid; i < Num * Interleave; i += BlockSize)
    {
        int nid = i % (NPerBlock * Interleave);
        float v = 0.f;
        for (int j = 0; j < BlockSize / WarpSize; ++j)
        {
            v += sm[j][i];
        }
        float bias_v = 0.f;
        if constexpr (Bias)
        {
            bias_v = static_cast<float>(bias[n_start_id + nid]);
        }
        int b = i / NPerBlock / Interleave;
        out[b * n + n_start_id + nid] = static_cast<ActType>(ActOp<float>::apply(v + bias_v));
    }
}

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    if constexpr (std::is_same_v<ActType, half>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
    else if constexpr (std::is_same_v<ActType, __nv_bfloat16>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
}

template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
    int NPerBlock, int Batch, int BlockSize>
struct WeightOnlyBatchedGemvKernelLauncher
{
    static void run(const WeightOnlyParams& params, cudaStream_t stream)
    {
        if (params.act_type == WeightOnlyActivationType::FP16)
        {
            constexpr int kInterleave = WeightOnlyDetails<half, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
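            // Dynamic shared memory: one float per warp (BlockSize / 32) for each of the
            // Batch * NPerBlock * kInterleave partial results written back by Layout::sync.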
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock,
                    Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
                    Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
        }
        else if (params.act_type == WeightOnlyActivationType::BF16)
        {
            constexpr int kInterleave = WeightOnlyDetails<nv_bfloat16, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias),
                    reinterpret_cast<__nv_bfloat16*>(params.out), params.n, params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias),
                    reinterpret_cast<__nv_bfloat16*>(params.out), params.n, params.k);
            }
        }
    }
};
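
// Usage sketch: a host-side caller picks the template configuration and calls run(), for example (the
// pass-through activation functor name below is an assumption about the surrounding code, not something defined
// in this file):
//     WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel,
//         IdentityActivation, /*Zero=*/false, /*Bias=*/false, /*NPerBlock=*/2, /*Batch=*/1,
//         /*BlockSize=*/256>::run(params, stream);
// The grid size params.n / NPerBlock / kInterleave assumes params.n is a multiple of NPerBlock * kInterleave.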

} // namespace kernels
} // namespace tensorrt_llm