File size: 3,512 Bytes
76b423c
 
134a499
9a3f7b4
 
988dbd8
9a3f7b4
 
 
ab5f5f1
 
9a3f7b4
a1135a9
9a3f7b4
5392172
9a3f7b4
 
 
5392172
a1135a9
76b423c
9a3f7b4
3d7033f
a894537
 
 
 
0321f62
a894537
d262fb3
 
 
 
 
 
76b423c
 
 
 
 
 
 
ab5f5f1
 
 
5345cba
ab5f5f1
 
 
 
 
76b423c
 
 
 
 
a1135a9
76b423c
 
 
a1135a9
76b423c
 
 
 
a1135a9
 
76b423c
 
 
 
 
8766911
 
 
 
 
76b423c
5345cba
76b423c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1135a9
4f5bf6c
a1135a9
 
ab5f5f1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from transformers import AutoConfig

# Display labels for model architectures, keyed by the `model_type` string
# that process_architectures() reads from a model's config.
# Emoji that are joined sequences spell the zero-width joiner out as
# "\u200d" so the invisible character cannot be silently lost in an edit.
LLM_MODEL_ARCHS = {
    "stablelm_epoch": "🔴 StableLM-Epoch",
    "stablelm_alpha": "🔴 StableLM-Alpha",
    "mixformer-sequential": "🧑\u200d💻 Phi φ",
    "RefinedWebModel": "🦅 Falcon",
    "gpt_bigcode": "⭐ StarCoder",
    "RefinedWeb": "🦅 Falcon",
    "baichuan": "🌊 Baichuan 百川",  # river
    "internlm": "🧑\u200d🎓 InternLM 书生",  # scholar
    "mistral": "Ⓜ️ Mistral",
    "mixtral": "Ⓜ️ Mixtral",
    "codegen": "♾️ CodeGen",
    "chatglm": "💬 ChatGLM",
    "falcon": "🦅 Falcon",
    "bloom": "🌸 Bloom",
    "llama": "🦙 LLaMA",
    "rwkv": "🐦\u200d⬛ RWKV",
    "deci": "🔵 deci",
    "Yi": "🫂 Yi 人",  # people
    "mpt": "🧱 MPT",
    # no icon yet for the entries below (suggestions welcome)
    "gpt_neox": "GPT-NeoX",
    "gpt_neo": "GPT-Neo",
    "gpt2": "GPT-2",
    "gptj": "GPT-J",
    "bart": "BART",
}


def model_hyperlink(link, model_name):
    """Return an HTML anchor that opens ``link`` in a new tab.

    Args:
        link: Target URL placed in the ``href`` attribute.
        model_name: Text shown as the link label.

    Returns:
        The ``<a ...>`` element as a string. Both arguments are HTML-escaped
        so a quote, ``&`` or ``<`` in either value cannot terminate the
        attribute or inject markup; values without such characters produce
        exactly the same output as plain interpolation.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    # str() keeps non-string arguments working, as f-string interpolation did.
    safe_link = escape(str(link), quote=True)
    # Quotes are harmless in element text, so only &, < and > are escaped.
    safe_name = escape(str(model_name), quote=False)
    return f'<a target="_blank" href="{safe_link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{safe_name}</a>'


def process_architectures(model):
    """Return the display label for the architecture of ``model``.

    The model's config is loaded with ``AutoConfig.from_pretrained`` and its
    ``model_type`` is looked up in ``LLM_MODEL_ARCHS``.

    Returns:
        The matching label, or ``"Unknown"`` when the model type has no entry
        or when anything at all goes wrong while loading the config.
    """
    fallback = "Unknown"
    try:
        # NOTE(review): trust_remote_code=True lets a repository's custom
        # config code be used while loading — confirm this is acceptable for
        # every model name passed in here.
        model_type = AutoConfig.from_pretrained(model, trust_remote_code=True).model_type
        label = LLM_MODEL_ARCHS.get(model_type, fallback)
    except Exception:
        # Deliberately best-effort: a failed load must not break the caller.
        label = fallback
    return label


def process_score(score, quantization):
    """Format ``score`` with two decimals plus a one-character marker.

    The marker is ``*`` when ``quantization`` is anything other than
    ``"Unquantized"``, and a single space otherwise, so both forms have the
    same width.
    """
    marker = "*" if quantization != "Unquantized" else " "
    return f"{score:.2f}{marker}"


def process_quantizations(x):
    """Return a short label for the quantization described by ``x``.

    ``x`` is indexed with string keys (presumably one row of the benchmark
    table — confirm against the caller). The scheme is read first; the
    scheme-specific config keys are only read when the scheme matches.

    Returns:
        One of ``"BnB.4bit"``, ``"BnB.8bit"``, ``"GPTQ.4bit"``,
        ``"AWQ.4bit"``, ``"torchao.4bit"``, or ``"Unquantized"`` when no
        rule applies.
    """
    scheme = x["config.backend.quantization_scheme"]

    if scheme == "bnb":
        # The flags must be the literal True; other truthy values do not count.
        if x["config.backend.quantization_config.load_in_4bit"] is True:
            return "BnB.4bit"
        if x["config.backend.quantization_config.load_in_8bit"] is True:
            return "BnB.8bit"

    if scheme == "gptq":
        if x["config.backend.quantization_config.bits"] == 4:
            return "GPTQ.4bit"

    if scheme == "awq":
        if x["config.backend.quantization_config.bits"] == 4:
            return "AWQ.4bit"

    if scheme == "torchao":
        if x["config.backend.quantization_config.quant_type"] == "int4_weight_only":
            return "torchao.4bit"

    return "Unquantized"


def process_kernels(x):
    """Return a short label for the quantization kernel described by ``x``.

    ``x`` is indexed with string keys (presumably one row of the benchmark
    table — confirm against the caller). The ``version`` config key is only
    read when the scheme is ``"gptq"`` or ``"awq"``.

    Returns:
        ``"GPTQ.ExllamaV1"`` / ``"GPTQ.ExllamaV2"`` for gptq versions 1 / 2,
        ``"AWQ.GEMM"`` / ``"AWQ.GEMV"`` for awq versions ``"gemm"`` /
        ``"gemv"``, and ``"No Kernel"`` in every other case.
    """
    scheme = x["config.backend.quantization_scheme"]

    if scheme == "gptq":
        gptq_version = x["config.backend.quantization_config.version"]
        if gptq_version == 1:
            return "GPTQ.ExllamaV1"
        if gptq_version == 2:
            return "GPTQ.ExllamaV2"

    if scheme == "awq":
        awq_version = x["config.backend.quantization_config.version"]
        if awq_version == "gemm":
            return "AWQ.GEMM"
        if awq_version == "gemv":
            return "AWQ.GEMV"

    return "No Kernel"


# def change_tab(query_param):
#     query_param = query_param.replace("'", '"')
#     query_param = json.loads(query_param)

#     if isinstance(query_param, dict) and "tab" in query_param and query_param["tab"] == "plot":
#         return gr.Tabs.update(selected=1)
#     else:
#         return gr.Tabs.update(selected=0)