|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
#include <cassert> |
|
#include <cmath> |
|
#include <cstdint> |
|
#include <cuda_fp16.h> |
|
#if defined(ENABLE_BF16) |
|
#include <cuda_bf16.h> |
|
#endif |
|
#include <cuda_runtime.h> |
|
#include <cuda_runtime_api.h> |
|
#include <iostream> |
|
|
|
namespace tensorrt_llm
{
namespace kernels
{

// Selects the storage format of the quantized weights.
// NOTE(review): the names suggest 4-bit and 8-bit integer weights; the packing layout is not
// visible in this section - confirm against the kernels that consume it.
// Enumerators rely on implicit numbering (Int4b == 0, Int8b == 1); do not reorder.
enum class WeightOnlyQuantType
{
    Int4b,
    Int8b
};

// Selects the granularity of the quantization parameters.
// NOTE(review): presumably "one scale per output channel" versus "one scale per group of
// `group_size` elements" - confirm against the kernel implementation.
// Enumerators rely on implicit numbering (PerChannel == 0, GroupWise == 1); do not reorder.
enum class WeightOnlyType
{
    PerChannel,
    GroupWise
};

// Tag types that are only forward-declared here; their definitions (if any) live outside this
// section. NOTE(review): GS is presumably the compile-time group size - confirm.
struct WeightOnlyPerChannel;
template <int GS>
struct WeightOnlyGroupWise;

// Identifies the activation function requested for the operation.
// `InvalidType` is the last enumerator; NOTE(review): presumably a sentinel for
// unsupported/unset values - confirm how callers use it.
// Enumerators rely on implicit numbering starting at 0; do not reorder.
enum class WeightOnlyActivationFunctionType
{
    Gelu,
    Relu,
    Identity,
    InvalidType
};

// Identifies the element type of the activation-side buffers.
// NOTE(review): presumably this tells the consumer how to reinterpret the `void` pointers
// held by WeightOnlyParams (half vs. bfloat16) - confirm against the launcher.
enum class WeightOnlyActivationType
{
    FP16,
    BF16
};

// Plain parameter bundle describing one weight-only operation.
//
// The struct owns nothing: it stores the raw pointers and scalar values handed to its
// constructor and performs no validation, allocation or copying of the pointed-to data.
// `m`, `n`, `k` and `group_size` are const members, so an instance is copy-constructible
// but not assignable after construction.
struct WeightOnlyParams
{
    // Activation-side buffers are passed type-erased (ActType is void).
    using ActType = void;
    // Element type through which the quantized weight buffer is addressed.
    using WeiType = uint8_t;

    const WeiType* qweight;   // quantized weight buffer (read-only)
    const ActType* scales;    // NOTE(review): presumably dequantization scales - confirm
    const ActType* zeros;     // NOTE(review): presumably zero-points; may be optional - confirm
    const ActType* in;        // input buffer (read-only)
    const ActType* act_scale; // NOTE(review): presumably per-activation scaling; may be optional - confirm
    const ActType* bias;      // NOTE(review): presumably an additive bias; may be optional - confirm
    ActType* out;             // output buffer (the only non-const pointer)
    const int m;              // NOTE(review): m/n/k are presumably the GEMM problem dimensions - confirm
    const int n;
    const int k;
    const int group_size; // NOTE(review): presumably only meaningful for WeightOnlyType::GroupWise - confirm
    WeightOnlyQuantType quant_type;
    WeightOnlyType weight_only_type;
    WeightOnlyActivationFunctionType act_func_type;
    WeightOnlyActivationType act_type;

    // Stores every argument in the member of the same name (without the leading underscore).
    // No argument is checked; null pointers and non-positive sizes are accepted as-is.
    WeightOnlyParams(const WeiType* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
        const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
        const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
        const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
        : qweight(_qweight)
        , scales(_scales)
        , zeros(_zeros)
        , in(_in)
        , act_scale(_act_scale)
        , bias(_bias)
        , out(_out)
        , m(_m)
        , n(_n)
        , k(_k)
        , group_size(_group_size)
        , quant_type(_quant_type)
        , weight_only_type(_weight_only_type)
        , act_func_type(_act_func_type)
        , act_type(_act_type)
    {
    }
};

} // namespace kernels
} // namespace tensorrt_llm
|
|