File size: 2,702 Bytes
1dc29e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#if defined(ENABLE_BF16)
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>

namespace tensorrt_llm
{
namespace kernels
{
// Bit width of the packed quantized weights: 4-bit or 8-bit integer.
enum class WeightOnlyQuantType
{
    Int4b,
    Int8b
};
// Quantization granularity: one scale per output channel, or one per group
// of `group_size` elements along K (see WeightOnlyParams::group_size).
enum class WeightOnlyType
{
    PerChannel,
    GroupWise
};

// Tag types used to select the quantization layout at compile time
// (presumably dispatched on by kernel templates elsewhere — not visible here).
struct WeightOnlyPerChannel;
template <int GS> // GS: compile-time group size along the K dimension
struct WeightOnlyGroupWise;

// Activation function optionally fused into the weight-only kernel epilogue.
enum class WeightOnlyActivationFunctionType
{
    Gelu,
    Relu,
    Identity,    // no activation applied
    InvalidType  // sentinel for an unrecognized/unsupported selection
};

// Element type of the activation (and scale/bias/output) buffers.
// BF16 is only usable when the file is built with ENABLE_BF16 (see includes).
enum class WeightOnlyActivationType
{
    FP16,
    BF16
};

// Argument bundle handed to the weight-only-quantized GEMM kernels.
// The activation-side pointers are deliberately type-erased (ActType is an
// alias for void); the runtime field `act_type` records whether they really
// point at fp16 or bf16 data, and kernels cast accordingly.
struct WeightOnlyParams
{
    // ActType is fp16 or bf16
    using ActType = void;
    using WeiType = uint8_t;

    const uint8_t* qweight;   // packed quantized weights (bit width per quant_type)
    const ActType* scales;    // dequantization scales
    const ActType* zeros;     // zero points (groupwise mode) — semantics per kernel, not visible here
    const ActType* in;        // input activations
    const ActType* act_scale; // per-activation scaling factors
    const ActType* bias;      // optional bias — nullability determined by the kernels, not here
    ActType* out;             // output buffer
    const int m;              // GEMM dimensions (const: a params object describes one problem)
    const int n;
    const int k;
    const int group_size;     // elements per quantization group (GroupWise mode)
    WeightOnlyQuantType quant_type;
    WeightOnlyType weight_only_type;
    WeightOnlyActivationFunctionType act_func_type;
    WeightOnlyActivationType act_type;

    // Trivial member-by-member constructor; initializer order matches the
    // declaration order above, as required for const members.
    WeightOnlyParams(const uint8_t* qweight_, const ActType* scales_, const ActType* zeros_, const ActType* in_,
        const ActType* act_scale_, const ActType* bias_, ActType* out_, const int m_, const int n_, const int k_,
        const int group_size_, const WeightOnlyQuantType quant_type_, const WeightOnlyType weight_only_type_,
        const WeightOnlyActivationFunctionType act_func_type_, const WeightOnlyActivationType act_type_)
        : qweight(qweight_)
        , scales(scales_)
        , zeros(zeros_)
        , in(in_)
        , act_scale(act_scale_)
        , bias(bias_)
        , out(out_)
        , m(m_)
        , n(n_)
        , k(k_)
        , group_size(group_size_)
        , quant_type(quant_type_)
        , weight_only_type(weight_only_type_)
        , act_func_type(act_func_type_)
        , act_type(act_type_)
    {
    }
};
} // namespace kernels
} // namespace tensorrt_llm