|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "common.h" |
|
#include "utility.h" |
|
|
|
namespace tensorrt_llm |
|
{ |
|
namespace kernels |
|
{ |
|
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, bool Zero, bool Bias, |
|
int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> |
|
struct WeightOnlyBatchedGemvKernelLauncher |
|
{ |
|
static void run(const WeightOnlyParams& params, cudaStream_t stream); |
|
}; |
|
|
|
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp, int N_PER_BLOCK, |
|
int BATCH, int BLOCK_SIZE> |
|
void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) |
|
{ |
|
if (params.zeros && params.bias) |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, true, true, N_PER_BLOCK, BATCH, |
|
BLOCK_SIZE>::run(params, stream); |
|
} |
|
else if (params.zeros && !params.bias) |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, true, false, N_PER_BLOCK, BATCH, |
|
BLOCK_SIZE>::run(params, stream); |
|
} |
|
else if (!params.zeros && params.bias) |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, false, true, N_PER_BLOCK, BATCH, |
|
BLOCK_SIZE>::run(params, stream); |
|
} |
|
else |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<QType, WeightOnlyFlag, ActOp, false, false, N_PER_BLOCK, BATCH, |
|
BLOCK_SIZE>::run(params, stream); |
|
} |
|
} |
|
|
|
template <WeightOnlyQuantType QType, typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> |
|
void select_activation(const WeightOnlyParams& params, cudaStream_t stream) |
|
{ |
|
switch (params.act_func_type) |
|
{ |
|
|
|
#if 0 |
|
case WeightOnlyActivationFunctionType::Gelu: |
|
{ |
|
select_zero_bias<QType, WeightOnlyFlag, GeluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
break; |
|
} |
|
case WeightOnlyActivationFunctionType::Relu: |
|
{ |
|
select_zero_bias<QType, WeightOnlyFlag, ReluActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
break; |
|
} |
|
#endif |
|
case WeightOnlyActivationFunctionType::Identity: |
|
{ |
|
select_zero_bias<QType, WeightOnlyFlag, IdentityActivation, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
break; |
|
} |
|
default: |
|
{ |
|
throw std::runtime_error("Use unsupported activation"); |
|
break; |
|
} |
|
} |
|
} |
|
|
|
template <typename WeightOnlyFlag, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> |
|
void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) |
|
{ |
|
if (params.quant_type == WeightOnlyQuantType::Int4b) |
|
{ |
|
select_activation<WeightOnlyQuantType::Int4b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
} |
|
else if (params.quant_type == WeightOnlyQuantType::Int8b) |
|
{ |
|
select_activation<WeightOnlyQuantType::Int8b, WeightOnlyFlag, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
} |
|
else |
|
{ |
|
throw std::runtime_error("Unknown QuantType"); |
|
} |
|
} |
|
|
|
template <int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> |
|
void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream) |
|
{ |
|
if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) |
|
{ |
|
select_quant_type<WeightOnlyGroupWise<64>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
} |
|
else if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 128) |
|
{ |
|
select_quant_type<WeightOnlyGroupWise<128>, N_PER_BLOCK, BATCH, BLOCK_SIZE>(params, stream); |
|
} |
|
else |
|
{ |
|
throw std::runtime_error("Only support groupwise weight only for gs=64/128"); |
|
} |
|
} |
|
|
|
void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream) |
|
{ |
|
assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); |
|
assert(params.weight_only_type == WeightOnlyType::GroupWise |
|
|| (params.weight_only_type == WeightOnlyType::PerChannel && params.bias == nullptr |
|
&& params.zeros == nullptr)); |
|
if (params.weight_only_type == WeightOnlyType::PerChannel) |
|
{ |
|
if (params.quant_type == WeightOnlyQuantType::Int4b) |
|
{ |
|
switch (params.m) |
|
{ |
|
case 1: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 1, 1, 192>::run(params, stream); |
|
break; |
|
} |
|
case 2: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 2, 128>::run(params, stream); |
|
break; |
|
} |
|
case 3: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 3, 256>::run(params, stream); |
|
break; |
|
} |
|
case 4: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int4b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 4, 4, 256>::run(params, stream); |
|
break; |
|
} |
|
default: |
|
{ |
|
throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); |
|
break; |
|
} |
|
} |
|
} |
|
else if (params.quant_type == WeightOnlyQuantType::Int8b) |
|
{ |
|
switch (params.m) |
|
{ |
|
case 1: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 1, 256>::run(params, stream); |
|
break; |
|
} |
|
case 2: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 2, 256>::run(params, stream); |
|
break; |
|
} |
|
case 3: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 3, 256>::run(params, stream); |
|
break; |
|
} |
|
case 4: |
|
{ |
|
WeightOnlyBatchedGemvKernelLauncher<WeightOnlyQuantType::Int8b, WeightOnlyPerChannel, |
|
IdentityActivation, false, false, 2, 4, 256>::run(params, stream); |
|
break; |
|
} |
|
default: |
|
{ |
|
throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); |
|
break; |
|
} |
|
} |
|
} |
|
} |
|
else if (params.weight_only_type == WeightOnlyType::GroupWise) |
|
{ |
|
switch (params.m) |
|
{ |
|
case 1: |
|
{ |
|
select_groupwise_weight_only<2, 1, 256>(params, stream); |
|
break; |
|
} |
|
case 2: |
|
{ |
|
select_groupwise_weight_only<2, 2, 256>(params, stream); |
|
break; |
|
} |
|
case 3: |
|
{ |
|
select_groupwise_weight_only<2, 3, 128>(params, stream); |
|
break; |
|
} |
|
case 4: |
|
{ |
|
select_groupwise_weight_only<2, 4, 128>(params, stream); |
|
break; |
|
} |
|
default: |
|
{ |
|
throw std::runtime_error("Weight only cuda kernel only supported bs <= 4"); |
|
break; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|