#pragma once

#include "common.h"
#include "utility.h"

namespace tensorrt_llm
{
namespace kernels
{

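// Maps an activation type to its cutlass counterpart and to the two-element
// vector type used for the packed half2/bfloat162 arithmetic below.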
template <typename ActType>
struct ActTypeDetails;

template <>
struct ActTypeDetails<half>
{
    using CutlassType = cutlass::half_t;
    using Vec2 = half2;

    __device__ __forceinline__ static Vec2 to_vec2(half v)
    {
        return __half2half2(v);
    }
};

#if defined(ENABLE_BF16)
template <>
struct ActTypeDetails<__nv_bfloat16>
{
    using CutlassType = cutlass::bfloat16_t;
    using Vec2 = __nv_bfloat162;

    __device__ __forceinline__ static Vec2 to_vec2(__nv_bfloat16 v)
    {
        return __bfloat162bfloat162(v);
    }
};
#endif

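// Selects the packed storage type and the cutlass fast converter for each
// quantization width: one converter call covers 8 int4 values or 4 int8 values.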
template <typename ActType, WeightOnlyQuantType QType>
struct ConverterSelector
{
    static_assert(QType == WeightOnlyQuantType::Int4b || QType == WeightOnlyQuantType::Int8b);

    using WeiType = std::conditional_t<QType == WeightOnlyQuantType::Int4b, cutlass::uint4b_t, uint8_t>;
    static constexpr int kConvertCount = QType == WeightOnlyQuantType::Int4b ? 8 : 4;
    using Converter
        = cutlass::FastInterleavedAndBiasedNumericArrayConverter<typename ActTypeDetails<ActType>::CutlassType,
            WeiType, kConvertCount>;
};

template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyDetails;

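// Int4 path (inferred from kInterleave/kStride and the qweight indexing below):
// every kInterleave original rows are interleaved into one packed row in tiles of
// kStride elements, so a single 128-bit access reads 32 int4 values that all
// belong to the same logical row.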
template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int4b>
{
    static constexpr int kElemBits = 4;
    static constexpr int kInterleave = 4;
    static constexpr int kStride = 64;

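    // Geometry of the register-level reordering applied to dequantized values
    // (see the shuffle loops in weight_only_batched_gemv below).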
    static constexpr int kShuffleSize = 32;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided = 4;

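    // Block-wide reduction helper. The xor-shuffles fold lanes so that lanes
    // 0/2/4/6 of each 8-lane group end up holding the per-warp sums for the four
    // interleaved rows; these are then staged in shared memory so the final
    // cross-warp reduction can run.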
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
#pragma unroll
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
        {
#pragma unroll
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 2] = res[i];
            }
        }
        __syncthreads();
    }
};

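// Int8 path (inferred the same way): every two original rows are interleaved into
// one packed row in tiles of kStride elements; a 128-bit access reads 16 int8
// values from one logical row.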
template <typename ActType>
struct WeightOnlyDetails<ActType, WeightOnlyQuantType::Int8b>
{
    static constexpr int kElemBits = 8;
    static constexpr int kInterleave = 2;
    static constexpr int kStride = 64;

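    // Geometry of the register-level reordering for the int8 converter output.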
    static constexpr int kShuffleSize = 16;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 2;
    static constexpr int kShuffleStrided = 4;

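    // Block-wide reduction helper: after the xor-shuffles, lanes 0 and 4 hold the
    // per-warp sums for the two interleaved rows and stage them in shared memory.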
    template <int Num, int WarpSize>
    __device__ __forceinline__ static void sync(float* res, float (*sm)[Num * kInterleave])
    {
#pragma unroll
        for (int i = 0; i < Num; ++i)
        {
            res[i] += __shfl_xor_sync(~0, res[i], 16);
            res[i] += __shfl_xor_sync(~0, res[i], 8);
            res[i] += __shfl_xor_sync(~0, res[i], 2);
            res[i] += __shfl_xor_sync(~0, res[i], 1);
        }
        __syncthreads();
        int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
        if (lane == 0 || lane == 4)
        {
#pragma unroll
            for (int i = 0; i < Num; ++i)
            {
                sm[warp][i * kInterleave + lane / 4] = res[i];
            }
        }
        __syncthreads();
    }
};

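// Aggregates the per-quantization-type constants and derives the per-thread
// workload shared by the kernel and the scale loader.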
template <typename ActType, WeightOnlyQuantType QType>
struct WeightOnlyKernelDetails
{
    using Layout = WeightOnlyDetails<ActType, QType>;

    static constexpr int kElemBits = Layout::kElemBits;
    static constexpr int kInterleave = Layout::kInterleave;
    static constexpr int kStride = Layout::kStride;

    static constexpr int kShuffleSize = Layout::kShuffleSize;
    static constexpr int kShuffleBasicTile = Layout::kShuffleBasicTile;
    static constexpr int kShuffleContinous = Layout::kShuffleContinous;
    static constexpr int kShuffleStrided = Layout::kShuffleStrided;

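    // One converter invocation dequantizes kConvertCount packed values
    // (8 for int4, 4 for int8).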
    static constexpr int kConvertCount = ConverterSelector<ActType, QType>::kConvertCount;
    using Converter = typename ConverterSelector<ActType, QType>::Converter;

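    // All global-memory loads use 128-bit (uint4) vector accesses.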
    static constexpr int kAccessSize = 128;
    using AccessType = uint4;

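    // Per-thread workload implied by the 128-bit access width.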
    static constexpr int kElemsPerByte = 8 / kElemBits;
    static constexpr int kElemsPerThread = kAccessSize / kElemBits;
    static constexpr int kBytePerThread = kElemsPerThread / kElemsPerByte;
    static constexpr int kThreadsNumPerTile = kStride / kElemsPerThread;
    static constexpr int kThreadsNumPerInterleave = kThreadsNumPerTile * kInterleave;

    static constexpr int kConvertIters = kElemsPerThread / kConvertCount;

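    // Activation fragments are also fetched in 128-bit accesses; kActivationAccessNum
    // is how many such accesses cover one thread's kElemsPerThread elements.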
    static constexpr int kActivationElemNumPerAccess = kAccessSize / (sizeof(ActType) * 8);
    static constexpr int kActivationAccessNum = kElemsPerThread / kActivationElemNumPerAccess;
};

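// Compile-time description of the quantization granularity: per-channel scales,
// or group-wise scales with a static group size.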
template <typename WeightOnlyFlag>
struct WeightOnlyProperties;

template <>
struct WeightOnlyProperties<WeightOnlyPerChannel>
{
    static constexpr bool kIsFineGrained = false;
    static constexpr int kGroupSize = 0;
};

template <int GS>
struct WeightOnlyProperties<WeightOnlyGroupWise<GS>>
{
    static constexpr bool kIsFineGrained = true;
    static constexpr int kGroupSize = GS;
};

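// Fetches the dequantization scale (and optional zero point) for each output
// channel; for group-wise quantization the scale row advances with the loader's
// position along k.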
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, bool Zero, int BlockSize>
struct WeightOnlyScaleLoader
{
    using ElemType = ActType;
    using Details = WeightOnlyKernelDetails<ActType, QType>;
    static constexpr bool kIsFineGrained = WeightOnlyProperties<WeightOnlyFlag>::kIsFineGrained;
    static constexpr int kGroupSize = WeightOnlyProperties<WeightOnlyFlag>::kGroupSize;

private:
    const ElemType* _scales;
    const ElemType* _zeros;
    int _stride;
    int _offset;

public:
    __device__ __forceinline__ WeightOnlyScaleLoader(
        const ElemType* scales, const ElemType* zeros, int initial_offset, int stride)
        : _scales(scales)
        , _zeros(zeros)
        , _stride(stride)
    {
        _scales += initial_offset;
        if constexpr (Zero)
        {
            _zeros += initial_offset;
        }

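        // Each tile of kStride contiguous packed elements comes from one original
        // row; _offset tracks this thread's starting position along k in the
        // pre-interleave coordinates (an inference from the constants above).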
        _offset = threadIdx.x / Details::kThreadsNumPerInterleave * Details::kStride
            + (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread;
    }

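    // Loads the scale (and zero point, if Zero is set) for logical output channel
    // nid at the loader's current position along k.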
    __device__ __forceinline__ void load(ElemType& scale, ElemType& zero, int nid)
    {
        int offset = nid * Details::kInterleave;
        if constexpr (kIsFineGrained)
        {
            offset += _offset / kGroupSize * _stride;
        }
        scale = _scales[offset];
        if constexpr (Zero)
        {
            zero = _zeros[offset];
        }
        else
        {
            zero = static_cast<ElemType>(0.f);
        }
    }

    __device__ __forceinline__ void advance()
    {
        _offset += BlockSize * Details::kElemsPerThread / Details::kInterleave;
    }

    __device__ __forceinline__ int offset()
    {
        return _offset;
    }
};

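// Batched GEMV over weight-only quantized weights: each block produces
// NPerBlock * kInterleave output channels for all Batch input rows. qweight holds
// the interleaved quantized weights; scales/zeros hold the dequantization
// parameters.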
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
    using ActType2 = typename ActTypeDetails<ActType>::Vec2;
    using Details = WeightOnlyKernelDetails<ActType, QType>;

    using Converter = typename Details::Converter;
    using AccType = typename Details::AccessType;
    using CvtSrcType = typename Converter::source_type;
    using CvtResType = typename Converter::result_type;
    using ScaleLoader = WeightOnlyScaleLoader<ActType, QType, WeightOnlyFlag, Zero, BlockSize>;

    extern __shared__ uint8_t shmem[];
    constexpr int Interleave = Details::kInterleave;
    constexpr int WarpSize = 32;
    constexpr int Num = Batch * NPerBlock;
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int n_start_id = bid * NPerBlock * Interleave;

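    // Which of the kInterleave rows this thread's tile belongs to.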
    const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave;

    qweight += n_start_id * k / Details::kElemsPerByte;
    ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n);

    float(*sm)[Num * Interleave] = reinterpret_cast<float(*)[Num * Interleave]>(shmem);

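    // Per-thread partial sums, one per (batch row, output channel) pair.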
    ActType accumulator[Num];
    for (int i = 0; i < Num; ++i)
    {
        accumulator[i] = static_cast<ActType>(0.f);
    }

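    // Main loop over k: each iteration dequantizes one 128-bit weight fragment per
    // output channel and multiplies it with the matching activation fragment.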
    for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave;
         local_k += BlockSize * Details::kElemsPerThread)
    {
        ActType weights_f16[Details::kElemsPerThread * NPerBlock];
        ActType scale[NPerBlock], zero[NPerBlock];
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx)
        {
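            // Load a 128-bit packed-weight fragment plus its scale/zero, then
            // dequantize it kConvertCount elements at a time.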
            uint8_t weights_quantized[Details::kBytePerThread];
            load<AccType>(weights_quantized,
                qweight + idx * Interleave * k / Details::kElemsPerByte + local_k / Details::kElemsPerByte);
            scale_loader.load(scale[idx], zero[idx], idx);
            ActType weights_vec[Details::kElemsPerThread];
#pragma unroll
            for (int i = 0; i < Details::kConvertIters; ++i)
            {
                assign<CvtResType>(weights_vec + i * Details::kConvertCount,
                    Converter::convert(*reinterpret_cast<CvtSrcType*>(
                        weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte)));
            }
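            // Scatter the dequantized pairs back into logical order (undoing the
            // converter's interleaved output ordering, per the kShuffle* constants)
            // while applying scale and zero point with a fused multiply-add.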
#pragma unroll
            for (int i = 0; i < Details::kShuffleContinous; ++i)
            {
#pragma unroll
                for (int j = 0; j < Details::kShuffleStrided; ++j)
                {
                    ActType2 v = *reinterpret_cast<ActType2*>(weights_vec + i * Details::kShuffleBasicTile
                        + j * Details::kShuffleContinous * Details::kShuffleBasicTile);
                    v = __hfma2(
                        v, ActTypeDetails<ActType>::to_vec2(scale[idx]), ActTypeDetails<ActType>::to_vec2(zero[idx]));
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 0)
                            * NPerBlock
                        + idx]
                        = v.x;
                    weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile
                                    + j * Details::kShuffleBasicTile + 1)
                            * NPerBlock
                        + idx]
                        = v.y;
                }
            }
        }
        ActType act_scale_v[Details::kElemsPerThread];
        if constexpr (ActScale)
        {
#pragma unroll
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
                    act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
            }
        }
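        // For every batch row: load the activation fragment, optionally multiply it
        // by act_scale, then accumulate the dot product against the dequantized
        // weights.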
#pragma unroll
        for (int b = 0; b < Batch; ++b)
        {
            ActType in_v[Details::kElemsPerThread];
#pragma unroll
            for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
            {
                load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
                    in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
                if constexpr (ActScale)
                {
#pragma unroll
                    for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
                    {
                        *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
                            *reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
                            *reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
                    }
                }
            }
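            // NPerBlock == 1: reduce the fragment with paired FMAs into one scalar;
            // otherwise accumulate into NPerBlock accumulators two channels at a time.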
            if constexpr (NPerBlock == 1)
            {
                ActType2 v = ActTypeDetails<ActType>::to_vec2(static_cast<ActType>(0.f));
#pragma unroll
                for (int y = 0; y < Details::kElemsPerThread; y += 2)
                {
                    v = __hfma2(
                        *reinterpret_cast<ActType2*>(weights_f16 + y), *reinterpret_cast<ActType2*>(in_v + y), v);
                }
                accumulator[b] += __hadd(v.x, v.y);
            }
            else
            {
#pragma unroll
                for (int x = 0; x < NPerBlock / 2; ++x)
                {
#pragma unroll
                    for (int y = 0; y < Details::kElemsPerThread; ++y)
                    {
                        *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2)
                            = __hfma2(*reinterpret_cast<ActType2*>(weights_f16 + y * NPerBlock + x * 2),
                                ActTypeDetails<ActType>::to_vec2(in_v[y]),
                                *reinterpret_cast<ActType2*>(accumulator + b * NPerBlock + x * 2));
                    }
                }
            }
        }
        scale_loader.advance();
    }

    float reses[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i)
    {
        reses[i] = static_cast<float>(accumulator[i]);
    }

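    // Reduce the partial sums across the block: first within each warp, then via
    // shared memory across warps (see WeightOnlyDetails::sync).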
    Details::Layout::template sync<Num, WarpSize>(reses, sm);

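    // Each thread writes one (batch, channel) result, applying bias and activation.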
    for (int i = tid; i < Num * Interleave; i += BlockSize)
    {
        int nid = i % (NPerBlock * Interleave);
        float v = 0.f;
        for (int j = 0; j < BlockSize / WarpSize; ++j)
        {
            v += sm[j][i];
        }
        float bias_v = 0.f;
        if constexpr (Bias)
        {
            bias_v = static_cast<float>(bias[n_start_id + nid]);
        }
        int b = i / NPerBlock / Interleave;
        out[b * n + n_start_id + nid] = static_cast<ActType>(ActOp<float>::apply(v + bias_v));
    }
}

template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
    bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
    const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
    if constexpr (std::is_same_v<ActType, half>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
    else if constexpr (std::is_same_v<ActType, nv_bfloat16>)
    {
        weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
            BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
    }
#endif
}

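// Host-side launcher: dispatches on the runtime activation type, sizes the grid so
// each block covers NPerBlock * kInterleave output channels, and allocates one
// float of dynamic shared memory per (warp, batch, channel, interleave) slot.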
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias,
    int NPerBlock, int Batch, int BlockSize>
struct WeightOnlyBatchedGemvKernelLauncher
{
    static void run(const WeightOnlyParams& params, cudaStream_t stream)
    {
        if (params.act_type == WeightOnlyActivationType::FP16)
        {
            constexpr int kInterleave = WeightOnlyDetails<half, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock,
                    Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
                    Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
                    reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
                    reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
                    params.k);
            }
        }
#if defined(ENABLE_BF16)
        else if (params.act_type == WeightOnlyActivationType::BF16)
        {
            constexpr int kInterleave = WeightOnlyDetails<nv_bfloat16, QType>::kInterleave;
            dim3 grid(params.n / NPerBlock / kInterleave);
            dim3 block(BlockSize);
            int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
            if (params.act_scale != nullptr)
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
                    params.n, params.k);
            }
            else
            {
                weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
                    NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
                    reinterpret_cast<const __nv_bfloat16*>(params.scales),
                    reinterpret_cast<const __nv_bfloat16*>(params.zeros),
                    reinterpret_cast<const __nv_bfloat16*>(params.in),
                    reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
                    reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
                    params.n, params.k);
            }
        }
#endif
    }
};
} // namespace kernels
} // namespace tensorrt_llm