uukuguy commited on
Commit
11fb0af
·
1 Parent(s): 5e1fb25
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ pytorch_model-00002-of-00003.bin filter=lfs diff=lfs merge=lfs -text
37
+ pytorch_model-00003-of-00003.bin filter=lfs diff=lfs merge=lfs -text
38
+ pytorch_model-00001-of-00003.bin filter=lfs diff=lfs merge=lfs -text
39
+ tokenizer.model filter=lfs diff=lfs merge=lfs -text
Community License for Baichuan2 Model.pdf ADDED
Binary file (203 kB). View file
 
README.md CHANGED
@@ -1,3 +1,175 @@
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language:
3
+ - en
4
+ - zh
5
+ license: apache2
6
+ tasks:
7
+ - text-generation
8
+ datasets:
9
+ - ehartford/dolphin
10
+ - Open-Orca/OpenOrca
11
+ - garage-bAInd/Open-Platypus
12
  ---
13
+
14
+ <p><h1> speechless-baichuan2-dolphin-orca-platypus-13b </h1></p>
15
+ Fine-tune the baichuan-inc/Baichuan2-13B-Base with Dolphin, Orca and Platypus datasets.
16
+
17
+ | Metric | Value |
18
+ | --- | --- |
19
+ | ARC | |
20
+ | HellaSwag | |
21
+ | MMLU | |
22
+ | TruthfulQA | |
23
+ | Average | |
24
+
25
+ <!-- markdownlint-disable first-line-h1 -->
26
+ <!-- markdownlint-disable html -->
27
+ <div align="center">
28
+ <h1>
29
+ Baichuan 2
30
+ </h1>
31
+ </div>
32
+
33
+ <div align="center">
34
+ <a href="https://github.com/baichuan-inc/Baichuan2" target="_blank">🦉GitHub</a> | <a href="https://github.com/baichuan-inc/Baichuan-7B/blob/main/media/wechat.jpeg?raw=true" target="_blank">💬WeChat</a>
35
+ </div>
36
+ <div align="center">
37
+ 🚀 <a href="https://www.baichuan-ai.com/" target="_blank">百川大模型在线对话平台</a> 已正式向公众开放 🎉
38
+ </div>
39
+
40
+ # 目录/Table of Contents
41
+
42
+ - [📖 模型介绍/Introduction](#Introduction)
43
+ - [⚙️ 快速开始/Quick Start](#Start)
44
+ - [📊 Benchmark评估/Benchmark Evaluation](#Benchmark)
45
+ - [📜 声明与协议/Terms and Conditions](#Terms)
46
+
47
+
48
+ # <span id="Introduction">模型介绍/Introduction</span>
49
+
50
+ Baichuan 2 是[百川智能]推出的新一代开源大语言模型,采用 **2.6 万亿** Tokens 的高质量语料训练,在权威的中文和英文 benchmark
51
+ 上均取得同尺寸最好的效果。本次发布包含有 7B、13B 的 Base 和 Chat 版本,并提供了 Chat 版本的 4bits
52
+ 量化,所有版本不仅对学术研究完全开放,开发者也仅需[邮件申请]并获得官方商用许可后,即可以免费商用。具体发布版本和下载见下表:
53
+
54
+ Baichuan 2 is the new generation of large-scale open-source language models launched by [Baichuan Intelligence inc.](https://www.baichuan-ai.com/).
55
+ It is trained on a high-quality corpus with 2.6 trillion tokens and has achieved the best performance in authoritative Chinese and English benchmarks of the same size.
56
+ This release includes 7B and 13B versions for both Base and Chat models, along with a 4bits quantized version for the Chat model.
57
+ All versions are fully open to academic research, and developers can also use them for free in commercial applications after obtaining an official commercial license through [email request](mailto:[email protected]).
58
+ The specific release versions and download links are listed in the table below:
59
+
60
+ | | Base Model | Chat Model | 4bits Quantized Chat Model |
61
+ |:---:|:--------------------:|:--------------------:|:--------------------------:|
62
+ | 7B | [Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) | [Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) | [Baichuan2-7B-Chat-4bits](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base-4bits) |
63
+ | 13B | [Baichuan2-13B-Base](https://huggingface.co/baichuan-inc/Baichuan2-13B-Base) | [Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | [Baichuan2-13B-Chat-4bits](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits) |
64
+
65
+ # <span id="Start">快速开始/Quick Start</span>
66
+
67
+ 在Baichuan2系列模型中,我们为了加快推理速度使用了Pytorch2.0加入的新功能F.scaled_dot_product_attention,因此模型需要在Pytorch2.0环境下运行。
68
+
69
+ In the Baichuan 2 series models, we have utilized the new feature `F.scaled_dot_product_attention` introduced in PyTorch 2.0 to accelerate inference speed. Therefore, the model needs to be run in a PyTorch 2.0 environment.
70
+
71
+
72
+ ```python
73
+ import torch
74
+ from transformers import AutoModelForCausalLM, AutoTokenizer
75
+ tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Base", use_fast=False, trust_remote_code=True)
76
+ model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Base", device_map="auto", trust_remote_code=True)
77
+ inputs = tokenizer('登鹳雀楼->王之涣\n夜雨寄北->', return_tensors='pt')
78
+ inputs = inputs.to('cuda:0')
79
+ pred = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
80
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
81
+ ```
82
+
83
+ # <span id="Benchmark">Benchmark 结果/Benchmark Evaluation</span>
84
+
85
+ 我们在[通用]、[法律]、[医疗]、[数学]、[代码]和[多语言翻译]六个领域的中英文权威数据集上对模型进行了广泛测试,更多详细测评结果可查看[GitHub]。
86
+
87
+ We have extensively tested the model on authoritative Chinese-English datasets across six domains: [General](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#general-domain), [Legal](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#law-and-medicine), [Medical](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#law-and-medicine), [Mathematics](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#mathematics-and-code), [Code](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#mathematics-and-code), and [Multilingual Translation](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md#multilingual-translation). For more detailed evaluation results, please refer to [GitHub](https://github.com/baichuan-inc/Baichuan2/blob/main/README_EN.md).
88
+
89
+ ### 7B Model Results
90
+
91
+ | | **C-Eval** | **MMLU** | **CMMLU** | **Gaokao** | **AGIEval** | **BBH** |
92
+ |:-----------------------:|:----------:|:--------:|:---------:|:----------:|:-----------:|:-------:|
93
+ | | 5-shot | 5-shot | 5-shot | 5-shot | 5-shot | 3-shot |
94
+ | **GPT-4** | 68.40 | 83.93 | 70.33 | 66.15 | 63.27 | 75.12 |
95
+ | **GPT-3.5 Turbo** | 51.10 | 68.54 | 54.06 | 47.07 | 46.13 | 61.59 |
96
+ | **LLaMA-7B** | 27.10 | 35.10 | 26.75 | 27.81 | 28.17 | 32.38 |
97
+ | **LLaMA2-7B** | 28.90 | 45.73 | 31.38 | 25.97 | 26.53 | 39.16 |
98
+ | **MPT-7B** | 27.15 | 27.93 | 26.00 | 26.54 | 24.83 | 35.20 |
99
+ | **Falcon-7B** | 24.23 | 26.03 | 25.66 | 24.24 | 24.10 | 28.77 |
100
+ | **ChatGLM2-6B** | 50.20 | 45.90 | 49.00 | 49.44 | 45.28 | 31.65 |
101
+ | **[Baichuan-7B]** | 42.80 | 42.30 | 44.02 | 36.34 | 34.44 | 32.48 |
102
+ | **[Baichuan2-7B-Base]** | 54.00 | 54.16 | 57.07 | 47.47 | 42.73 | 41.56 |
103
+
104
+ ### 13B Model Results
105
+
106
+ | | **C-Eval** | **MMLU** | **CMMLU** | **Gaokao** | **AGIEval** | **BBH** |
107
+ |:---------------------------:|:----------:|:--------:|:---------:|:----------:|:-----------:|:-------:|
108
+ | | 5-shot | 5-shot | 5-shot | 5-shot | 5-shot | 3-shot |
109
+ | **GPT-4** | 68.40 | 83.93 | 70.33 | 66.15 | 63.27 | 75.12 |
110
+ | **GPT-3.5 Turbo** | 51.10 | 68.54 | 54.06 | 47.07 | 46.13 | 61.59 |
111
+ | **LLaMA-13B** | 28.50 | 46.30 | 31.15 | 28.23 | 28.22 | 37.89 |
112
+ | **LLaMA2-13B** | 35.80 | 55.09 | 37.99 | 30.83 | 32.29 | 46.98 |
113
+ | **Vicuna-13B** | 32.80 | 52.00 | 36.28 | 30.11 | 31.55 | 43.04 |
114
+ | **Chinese-Alpaca-Plus-13B** | 38.80 | 43.90 | 33.43 | 34.78 | 35.46 | 28.94 |
115
+ | **XVERSE-13B** | 53.70 | 55.21 | 58.44 | 44.69 | 42.54 | 38.06 |
116
+ | **[Baichuan-13B-Base]** | 52.40 | 51.60 | 55.30 | 49.69 | 43.20 | 43.01 |
117
+ | **[Baichuan2-13B-Base]** | 58.10 | 59.17 | 61.97 | 54.33 | 48.17 | 48.78 |
118
+
119
+
120
+ ## 训练过程模型/Training Dynamics
121
+
122
+ 除了训练了 2.6 万亿 Tokens 的 [Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) 模型,我们还提供了在此之前的另外 11 个中间过程的模型(分别对应训练了约 0.2 ~ 2.4 万亿 Tokens)供社区研究使用
123
+ ([训练过程checkpoint下载](https://huggingface.co/baichuan-inc/Baichuan2-7B-Intermediate-Checkpoints))。下图给出了这些 checkpoints 在 C-Eval、MMLU、CMMLU 三个 benchmark 上的效果变化:
124
+
125
+ In addition to the [Baichuan2-7B-Base](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base) model trained on 2.6 trillion tokens, we also offer 11 additional intermediate-stage models for community research, corresponding to training on approximately 0.2 to 2.4 trillion tokens each ([Intermediate Checkpoints Download](https://huggingface.co/baichuan-inc/Baichuan2-7B-Intermediate-Checkpoints)). The graph below shows the performance changes of these checkpoints on three benchmarks: C-Eval, MMLU, and CMMLU.
126
+
127
+ ![checkpoint](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/checkpoints.jpeg)
128
+
129
+ # <span id="Terms">声明与协议/Terms and Conditions</span>
130
+
131
+ ## 声明
132
+
133
+ 我们在此声明,我们的开发团队并未基于 Baichuan 2 模型开发任何应用,无论是在 iOS、Android、网页或任何其他平台。我们强烈呼吁所有使用者,不要利用
134
+ Baichuan 2 模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Baichuan 2
135
+ 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
136
+
137
+ 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用
138
+ Baichuan 2 开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
139
+
140
+ We hereby declare that our team has not developed any applications based on Baichuan 2 models, not on iOS, Android, the web, or any other platform. We strongly call on all users not to use Baichuan 2 models for any activities that harm national / social security or violate the law. Also, we ask users not to use Baichuan 2 models for Internet services that have not undergone appropriate security reviews and filings. We hope that all users can abide by this principle and ensure that the development of technology proceeds in a regulated and legal environment.
141
+
142
+ We have done our best to ensure the compliance of the data used in the model training process. However, despite our considerable efforts, there may still be some unforeseeable issues due to the complexity of the model and data. Therefore, if any problems arise due to the use of Baichuan 2 open-source models, including but not limited to data security issues, public opinion risks, or any risks and problems brought about by the model being misled, abused, spread or improperly exploited, we will not assume any responsibility.
143
+
144
+ ## 协议
145
+
146
+ Baichuan 2 模型的社区使用需遵循[《Baichuan 2 模型社区许可协议》]。Baichuan 2 支持商用。如果将 Baichuan 2 模型或其衍生品用作商业用途,请您按照如下方式联系许可方,以进行登记并向许可方申请书面授权:联系邮箱 [[email protected]]。
147
+
148
+ The use of the source code in this repository follows the open-source license Apache 2.0. Community use of the Baichuan 2 model must adhere to the [Community License for Baichuan 2 Model](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf). Baichuan 2 supports commercial use. If you are using the Baichuan 2 models or their derivatives for commercial purposes, please contact the licensor in the following manner for registration and to apply for written authorization: Email [email protected].
149
+
150
+ [GitHub]:https://github.com/baichuan-inc/Baichuan2
151
+ [Baichuan2]:https://github.com/baichuan-inc/Baichuan2
152
+
153
+ [Baichuan-7B]:https://huggingface.co/baichuan-inc/Baichuan-7B
154
+ [Baichuan2-7B-Base]:https://huggingface.co/baichuan-inc/Baichuan2-7B-Base
155
+ [Baichuan2-7B-Chat]:https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
156
+ [Baichuan2-7B-Chat-4bits]:https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat-4bits
157
+ [Baichuan-13B-Base]:https://huggingface.co/baichuan-inc/Baichuan-13B-Base
158
+ [Baichuan2-13B-Base]:https://huggingface.co/baichuan-inc/Baichuan2-13B-Base
159
+ [Baichuan2-13B-Chat]:https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
160
+ [Baichuan2-13B-Chat-4bits]:https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat-4bits
161
+
162
+ [通用]:https://github.com/baichuan-inc/Baichuan2#%E9%80%9A%E7%94%A8%E9%A2%86%E5%9F%9F
163
+ [法律]:https://github.com/baichuan-inc/Baichuan2#%E6%B3%95%E5%BE%8B%E5%8C%BB%E7%96%97
164
+ [医疗]:https://github.com/baichuan-inc/Baichuan2#%E6%B3%95%E5%BE%8B%E5%8C%BB%E7%96%97
165
+ [数学]:https://github.com/baichuan-inc/Baichuan2#%E6%95%B0%E5%AD%A6%E4%BB%A3%E7%A0%81
166
+ [代码]:https://github.com/baichuan-inc/Baichuan2#%E6%95%B0%E5%AD%A6%E4%BB%A3%E7%A0%81
167
+ [多语言翻译]:https://github.com/baichuan-inc/Baichuan2#%E5%A4%9A%E8%AF%AD%E8%A8%80%E7%BF%BB%E8%AF%91
168
+
169
+ [《Baichuan 2 模型社区许可协议》]:https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf
170
+
171
+ [邮件申请]: mailto:[email protected]
172
+ [Email]: mailto:[email protected]
173
174
+ [训练过程heckpoint下载]: https://huggingface.co/baichuan-inc/Baichuan2-7B-Intermediate-Checkpoints
175
+ [百川智能]: https://www.baichuan-ai.com
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "_name_or_path": "/opt/local/llm_models/huggingface.co/baichuan-inc/Baichuan2-13B-Base",
4
+ "architectures": [
5
+ "BaichuanForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_baichuan.BaichuanConfig",
9
+ "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 2,
13
+ "hidden_act": "silu",
14
+ "hidden_size": 5120,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 13696,
17
+ "model_max_length": 4096,
18
+ "model_type": "baichuan",
19
+ "num_attention_heads": 40,
20
+ "num_hidden_layers": 40,
21
+ "pad_token_id": 0,
22
+ "rms_norm_eps": 1e-06,
23
+ "tie_word_embeddings": false,
24
+ "torch_dtype": "float16",
25
+ "transformers_version": "4.32.1",
26
+ "use_cache": true,
27
+ "vocab_size": 125696,
28
+ "z_loss_weight": 0
29
+ }
configuration_baichuan.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class BaichuanConfig(PretrainedConfig):
7
+ model_type = "baichuan"
8
+ keys_to_ignore_at_inference = ["past_key_values"]
9
+
10
+ def __init__(
11
+ self,
12
+ vocab_size=64000,
13
+ hidden_size=5120,
14
+ intermediate_size=13696,
15
+ num_hidden_layers=40,
16
+ num_attention_heads=40,
17
+ hidden_act="silu",
18
+ model_max_length=4096,
19
+ initializer_range=0.02,
20
+ rms_norm_eps=1e-6,
21
+ use_cache=True,
22
+ pad_token_id=0,
23
+ bos_token_id=1,
24
+ eos_token_id=2,
25
+ tie_word_embeddings=False,
26
+ gradient_checkpointing=False,
27
+ z_loss_weight=0,
28
+ **kwargs,
29
+ ):
30
+ self.vocab_size = vocab_size
31
+ self.model_max_length = model_max_length
32
+ self.hidden_size = hidden_size
33
+ self.intermediate_size = intermediate_size
34
+ self.num_hidden_layers = num_hidden_layers
35
+ self.num_attention_heads = num_attention_heads
36
+ self.hidden_act = hidden_act
37
+ self.initializer_range = initializer_range
38
+ self.rms_norm_eps = rms_norm_eps
39
+ self.use_cache = use_cache
40
+ self.z_loss_weight = z_loss_weight
41
+ self.gradient_checkpointing = (gradient_checkpointing,)
42
+ super().__init__(
43
+ pad_token_id=pad_token_id,
44
+ bos_token_id=bos_token_id,
45
+ eos_token_id=eos_token_id,
46
+ tie_word_embeddings=tie_word_embeddings,
47
+ **kwargs,
48
+ )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.32.1"
7
+ }
generation_utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from queue import Queue
3
+
4
+ import torch
5
+
6
+
7
+ def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
8
+ def _parse_messages(messages, split_role="user"):
9
+ system, rounds = "", []
10
+ round = []
11
+ for i, message in enumerate(messages):
12
+ if message["role"] == "system":
13
+ assert i == 0
14
+ system = message["content"]
15
+ continue
16
+ if message["role"] == split_role and round:
17
+ rounds.append(round)
18
+ round = []
19
+ round.append(message)
20
+ if round:
21
+ rounds.append(round)
22
+ return system, rounds
23
+
24
+ max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
25
+ max_input_tokens = model.config.model_max_length - max_new_tokens
26
+ system, rounds = _parse_messages(messages, split_role="user")
27
+ system_tokens = tokenizer.encode(system)
28
+ max_history_tokens = max_input_tokens - len(system_tokens)
29
+
30
+ history_tokens = []
31
+ for round in rounds[::-1]:
32
+ round_tokens = []
33
+ for message in round:
34
+ if message["role"] == "user":
35
+ round_tokens.append(model.generation_config.user_token_id)
36
+ else:
37
+ round_tokens.append(model.generation_config.assistant_token_id)
38
+ round_tokens.extend(tokenizer.encode(message["content"]))
39
+ if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
40
+ history_tokens = round_tokens + history_tokens # concat left
41
+ if len(history_tokens) < max_history_tokens:
42
+ continue
43
+ break
44
+
45
+ input_tokens = system_tokens + history_tokens
46
+ if messages[-1]["role"] != "assistant":
47
+ input_tokens.append(model.generation_config.assistant_token_id)
48
+ input_tokens = input_tokens[-max_input_tokens:] # truncate left
49
+ return torch.LongTensor([input_tokens]).to(model.device)
50
+
51
+
52
+ class TextIterStreamer:
53
+ def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
54
+ self.tokenizer = tokenizer
55
+ self.skip_prompt = skip_prompt
56
+ self.skip_special_tokens = skip_special_tokens
57
+ self.tokens = []
58
+ self.text_queue = Queue()
59
+ self.next_tokens_are_prompt = True
60
+
61
+ def put(self, value):
62
+ if self.skip_prompt and self.next_tokens_are_prompt:
63
+ self.next_tokens_are_prompt = False
64
+ else:
65
+ if len(value.shape) > 1:
66
+ value = value[0]
67
+ self.tokens.extend(value.tolist())
68
+ self.text_queue.put(
69
+ self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
70
+
71
+ def end(self):
72
+ self.text_queue.put(None)
73
+
74
+ def __iter__(self):
75
+ return self
76
+
77
+ def __next__(self):
78
+ value = self.text_queue.get()
79
+ if value is None:
80
+ raise StopIteration()
81
+ else:
82
+ return value
83
+
modeling_baichuan.py ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
2
+
3
+ from .configuration_baichuan import BaichuanConfig
4
+ from .generation_utils import build_chat_input, TextIterStreamer
5
+
6
+ import math
7
+ from threading import Thread
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import CrossEntropyLoss
13
+ from torch.nn import functional as F
14
+ from transformers import PreTrainedModel, PretrainedConfig
15
+ from transformers.activations import ACT2FN
16
+ from transformers.generation.utils import GenerationConfig
17
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
18
+ from transformers.utils import logging, ContextManagers
19
+
20
+ import os
21
+ from contextlib import contextmanager
22
+ from accelerate import init_empty_weights
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ try:
27
+ from xformers import ops as xops
28
+ except ImportError:
29
+ xops = None
30
+ logger.warning(
31
+ "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
32
+ )
33
+
34
+
35
+ def _get_interleave(n):
36
+ def _get_interleave_power_of_2(n):
37
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
38
+ ratio = start
39
+ return [start * ratio**i for i in range(n)]
40
+
41
+ if math.log2(n).is_integer():
42
+ return _get_interleave_power_of_2(n)
43
+ else:
44
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
45
+ return (
46
+ _get_interleave_power_of_2(closest_power_of_2)
47
+ + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
48
+ )
49
+
50
+
51
+ def _fill_with_neg_inf(t):
52
+ """FP16-compatible function that fills a tensor with -inf."""
53
+ return t.float().fill_(float("-inf")).type_as(t)
54
+
55
+
56
+ def _buffered_future_mask(tensor, maxpos, alibi, attn_heads):
57
+ _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1)
58
+ _future_mask = _future_mask.unsqueeze(0) + alibi
59
+ new_future_mask = _future_mask.to(tensor)
60
+ return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos]
61
+
62
+
63
+ def _gen_alibi_mask(tensor, n_head, max_pos):
64
+ slopes = torch.Tensor(_get_interleave(n_head))
65
+ position_point = torch.arange(max_pos) - max_pos + 1
66
+ position_point = position_point.unsqueeze(0).unsqueeze(0).expand(n_head, -1, -1)
67
+ diag = torch.diag(position_point[0])
68
+ position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
69
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
70
+ alibi = alibi.view(n_head, 1, max_pos)
71
+ alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1)
72
+ alibi_mask = alibi_mask.unsqueeze(0) + alibi
73
+ return alibi_mask
74
+
75
+
76
+ class RMSNorm(torch.nn.Module):
77
+ def __init__(self, hidden_size, epsilon=1e-6):
78
+ super().__init__()
79
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size))
80
+ self.epsilon = epsilon
81
+
82
+ def forward(self, hidden_states):
83
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
84
+ hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
85
+
86
+ # convert into half-precision
87
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
88
+ hidden_states = hidden_states.to(self.weight.dtype)
89
+
90
+ return self.weight * hidden_states
91
+
92
+
93
+ class MLP(torch.nn.Module):
94
+ def __init__(
95
+ self,
96
+ hidden_size: int,
97
+ intermediate_size: int,
98
+ hidden_act: str,
99
+ ):
100
+ super().__init__()
101
+ self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
102
+ self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
103
+ self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
104
+ self.act_fn = ACT2FN[hidden_act]
105
+
106
+ def forward(self, x):
107
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
108
+
109
+
110
+ class BaichuanAttention(torch.nn.Module):
111
+ def __init__(self, config: BaichuanConfig):
112
+ super().__init__()
113
+ self.config = config
114
+ self.hidden_size = config.hidden_size
115
+ self.num_heads = config.num_attention_heads
116
+ self.head_dim = self.hidden_size // self.num_heads
117
+ self.max_position_embeddings = config.model_max_length
118
+
119
+ if (self.head_dim * self.num_heads) != self.hidden_size:
120
+ raise ValueError(
121
+ f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
122
+ )
123
+ self.W_pack = torch.nn.Linear(
124
+ self.hidden_size, 3 * self.hidden_size, bias=False
125
+ )
126
+ self.o_proj = torch.nn.Linear(
127
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
128
+ )
129
+
130
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
131
+ return (
132
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
133
+ .transpose(1, 2)
134
+ .contiguous()
135
+ )
136
+
137
+ def forward(
138
+ self,
139
+ hidden_states: torch.Tensor,
140
+ attention_mask: Optional[torch.Tensor] = None,
141
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
142
+ output_attentions: bool = False,
143
+ use_cache: bool = False,
144
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
145
+ bsz, q_len, _ = hidden_states.size()
146
+
147
+ proj = self.W_pack(hidden_states)
148
+ proj = (
149
+ proj.unflatten(-1, (3, self.hidden_size))
150
+ .unsqueeze(0)
151
+ .transpose(0, -2)
152
+ .squeeze(-2)
153
+ )
154
+ query_states = (
155
+ proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
156
+ )
157
+ key_states = (
158
+ proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
159
+ )
160
+ value_states = (
161
+ proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
162
+ )
163
+
164
+ kv_seq_len = key_states.shape[-2]
165
+ if past_key_value is not None:
166
+ kv_seq_len += past_key_value[0].shape[-2]
167
+
168
+ if past_key_value is not None:
169
+ # reuse k, v, self_attention
170
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
171
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
172
+
173
+ past_key_value = (key_states, value_states) if use_cache else None
174
+ if xops is not None and self.training:
175
+ attn_weights = None
176
+ # query_states = query_states.transpose(1, 2)
177
+ # key_states = key_states.transpose(1, 2)
178
+ # value_states = value_states.transpose(1, 2)
179
+ # attn_output = xops.memory_efficient_attention(
180
+ # query_states, key_states, value_states, attn_bias=attention_mask
181
+ # )
182
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
183
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
184
+ attn_output = attn_output.transpose(1, 2)
185
+ else:
186
+ attn_weights = torch.matmul(
187
+ query_states, key_states.transpose(2, 3)
188
+ ) / math.sqrt(self.head_dim)
189
+
190
+ if attention_mask is not None:
191
+ if q_len == 1: # inference with cache
192
+ if len(attention_mask.size()) == 4:
193
+ attention_mask = attention_mask[:, :, -1:, :]
194
+ else:
195
+ attention_mask = attention_mask[:, -1:, :]
196
+ attn_weights = attn_weights + attention_mask
197
+ attn_weights = torch.max(
198
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
199
+ )
200
+
201
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
202
+ attn_output = torch.matmul(attn_weights, value_states)
203
+
204
+ attn_output = attn_output.transpose(1, 2)
205
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
206
+ attn_output = self.o_proj(attn_output)
207
+
208
+ if not output_attentions:
209
+ attn_weights = None
210
+
211
+ return attn_output, attn_weights, past_key_value
212
+
213
+
214
+ class BaichuanLayer(torch.nn.Module):
215
+ def __init__(self, config: BaichuanConfig):
216
+ super().__init__()
217
+ self.hidden_size = config.hidden_size
218
+ self.self_attn = BaichuanAttention(config=config)
219
+ self.mlp = MLP(
220
+ hidden_size=self.hidden_size,
221
+ intermediate_size=config.intermediate_size,
222
+ hidden_act=config.hidden_act,
223
+ )
224
+ self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
225
+ self.post_attention_layernorm = RMSNorm(
226
+ config.hidden_size, epsilon=config.rms_norm_eps
227
+ )
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states: torch.Tensor,
232
+ attention_mask: Optional[torch.Tensor] = None,
233
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
234
+ output_attentions: Optional[bool] = False,
235
+ use_cache: Optional[bool] = False,
236
+ ) -> Tuple[
237
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
238
+ ]:
239
+ residual = hidden_states
240
+
241
+ hidden_states = self.input_layernorm(hidden_states)
242
+
243
+ # Self Attention
244
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
245
+ hidden_states=hidden_states,
246
+ attention_mask=attention_mask,
247
+ past_key_value=past_key_value,
248
+ output_attentions=output_attentions,
249
+ use_cache=use_cache,
250
+ )
251
+ hidden_states = residual + hidden_states
252
+
253
+ # Fully Connected
254
+ residual = hidden_states
255
+ hidden_states = self.post_attention_layernorm(hidden_states)
256
+ hidden_states = self.mlp(hidden_states)
257
+ hidden_states = residual + hidden_states
258
+
259
+ outputs = (hidden_states,)
260
+
261
+ if use_cache:
262
+ outputs += (present_key_value,)
263
+
264
+ return outputs
265
+
266
+
267
+ class BaichuanPreTrainedModel(PreTrainedModel):
268
+ config_class = BaichuanConfig
269
+ base_model_prefix = "model"
270
+ supports_gradient_checkpointing = True
271
+ _no_split_modules = ["BaichuanLayer"]
272
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
273
+
274
+ def _init_weights(self, module):
275
+ std = self.config.initializer_range
276
+ if isinstance(module, torch.nn.Linear):
277
+ module.weight.data.normal_(mean=0.0, std=std)
278
+ if module.bias is not None:
279
+ module.bias.data.zero_()
280
+ elif isinstance(module, torch.nn.Embedding):
281
+ module.weight.data.normal_(mean=0.0, std=std)
282
+ if module.padding_idx is not None:
283
+ module.weight.data[module.padding_idx].zero_()
284
+
285
+ def _set_gradient_checkpointing(self, module, value=False):
286
+ if isinstance(module, BaichuanModel):
287
+ module.gradient_checkpointing = value
288
+
289
+
290
+ class BaichuanModel(BaichuanPreTrainedModel):
291
+ def __init__(self, config: BaichuanConfig):
292
+ super().__init__(config)
293
+ self.padding_idx = config.pad_token_id
294
+ self.vocab_size = config.vocab_size
295
+ self.n_head = config.num_attention_heads
296
+ self.embed_tokens = torch.nn.Embedding(
297
+ config.vocab_size, config.hidden_size, self.padding_idx
298
+ )
299
+ self.layers = torch.nn.ModuleList(
300
+ [BaichuanLayer(config) for _ in range(config.num_hidden_layers)]
301
+ )
302
+ self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
303
+
304
+ self.gradient_checkpointing = config.gradient_checkpointing
305
+ self.post_init()
306
+ self.max_cache_pos = config.model_max_length
307
+ self.first_run = True
308
+ self.alibi_mask = None
309
+
310
+ def get_input_embeddings(self):
311
+ return self.embed_tokens
312
+
313
+ def set_input_embeddings(self, value):
314
+ self.embed_tokens = value
315
+
316
+ def get_alibi_mask(self, tensor, seq_length_with_past):
317
+ if self.training:
318
+ slopes = torch.Tensor(_get_interleave(self.n_head))
319
+ position_point = (
320
+ torch.arange(seq_length_with_past) - seq_length_with_past + 1
321
+ )
322
+ position_point = (
323
+ position_point.unsqueeze(0)
324
+ .unsqueeze(0)
325
+ .expand(self.n_head, seq_length_with_past, -1)
326
+ )
327
+ diag = torch.diag(position_point[0])
328
+ position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose(
329
+ -1, -2
330
+ )
331
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point
332
+ mask = _buffered_future_mask(
333
+ tensor, seq_length_with_past, alibi, self.n_head
334
+ )
335
+ else:
336
+ if self.first_run:
337
+ self.first_run = False
338
+ self.register_buffer(
339
+ "future_mask",
340
+ _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to(
341
+ tensor
342
+ ),
343
+ persistent=False,
344
+ )
345
+ if seq_length_with_past > self.max_cache_pos:
346
+ self.max_cache_pos = seq_length_with_past
347
+ self.register_buffer(
348
+ "future_mask",
349
+ _gen_alibi_mask(tensor, self.n_head, self.max_cache_pos).to(
350
+ tensor
351
+ ),
352
+ persistent=False,
353
+ )
354
+ mask = self.future_mask[
355
+ : self.n_head, :seq_length_with_past, :seq_length_with_past
356
+ ]
357
+ return mask
358
+
359
+ def forward(
360
+ self,
361
+ input_ids: torch.LongTensor = None,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
364
+ inputs_embeds: Optional[torch.FloatTensor] = None,
365
+ use_cache: Optional[bool] = False,
366
+ output_attentions: Optional[bool] = False,
367
+ output_hidden_states: Optional[bool] = False,
368
+ return_dict: Optional[bool] = True,
369
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
370
+ if input_ids is not None and inputs_embeds is not None:
371
+ raise ValueError(
372
+ "You cannot provide both input_ids and inputs_embeds simultaneously"
373
+ )
374
+ elif input_ids is not None:
375
+ batch_size, seq_length = input_ids.shape
376
+ elif inputs_embeds is not None:
377
+ batch_size, seq_length, _ = inputs_embeds.shape
378
+ else:
379
+ raise ValueError("You need to provide input_ids or inputs_embeds")
380
+
381
+ return_dict = (
382
+ return_dict if return_dict is not None else self.config.use_return_dict
383
+ )
384
+
385
+ seq_length_with_past = seq_length
386
+
387
+ if past_key_values is not None:
388
+ past_key_values_length = past_key_values[0][0].shape[2]
389
+ seq_length_with_past = seq_length_with_past + past_key_values_length
390
+
391
+ if inputs_embeds is None:
392
+ inputs_embeds = self.embed_tokens(input_ids)
393
+
394
+ if self.training:
395
+ if (
396
+ self.alibi_mask is None
397
+ or self.alibi_mask.shape[-1] != seq_length_with_past
398
+ ):
399
+ self.alibi_mask = self.get_alibi_mask(
400
+ inputs_embeds, seq_length_with_past
401
+ )
402
+ alibi_mask = self.alibi_mask
403
+ else:
404
+ alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
405
+
406
+ if attention_mask is not None:
407
+ if len(attention_mask.shape) == 2:
408
+ expanded_mask = attention_mask.to(alibi_mask.dtype)
409
+ expanded_mask = torch.tril(
410
+ torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
411
+ ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
412
+ else:
413
+ expanded_mask = attention_mask
414
+ bsz = inputs_embeds.size(0)
415
+ src_len, tgt_len = alibi_mask.size()[-2:]
416
+ expanded_mask = (
417
+ expanded_mask.unsqueeze(1)
418
+ .expand(bsz, 1, src_len, tgt_len)
419
+ .to(alibi_mask.dtype)
420
+ )
421
+ inverted_mask = 1.0 - expanded_mask
422
+ inverted_mask = inverted_mask.masked_fill(
423
+ inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min
424
+ )
425
+ attention_mask = inverted_mask + alibi_mask.unsqueeze(0)
426
+ else:
427
+ attention_mask = alibi_mask
428
+
429
+ hidden_states = inputs_embeds
430
+
431
+ if self.gradient_checkpointing and self.training:
432
+ if use_cache:
433
+ logger.warning_once(
434
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
435
+ )
436
+ use_cache = False
437
+
438
+ # decoder layers
439
+ all_hidden_states = () if output_hidden_states else None
440
+ all_self_attns = () if output_attentions else None
441
+ next_decoder_cache = () if use_cache else None
442
+
443
+ for idx, decoder_layer in enumerate(self.layers):
444
+ if output_hidden_states:
445
+ all_hidden_states += (hidden_states,)
446
+
447
+ past_key_value = (
448
+ past_key_values[idx] if past_key_values is not None else None
449
+ )
450
+
451
+ if self.gradient_checkpointing and self.training:
452
+
453
+ def create_custom_forward(module):
454
+ def custom_forward(*inputs):
455
+ # None for past_key_value
456
+ return module(*inputs, output_attentions, None)
457
+
458
+ return custom_forward
459
+
460
+ layer_outputs = torch.utils.checkpoint.checkpoint(
461
+ create_custom_forward(decoder_layer),
462
+ hidden_states,
463
+ attention_mask,
464
+ None,
465
+ )
466
+ else:
467
+ layer_outputs = decoder_layer(
468
+ hidden_states,
469
+ attention_mask=attention_mask,
470
+ past_key_value=past_key_value,
471
+ output_attentions=output_attentions,
472
+ use_cache=use_cache,
473
+ )
474
+
475
+ hidden_states = layer_outputs[0]
476
+
477
+ if use_cache:
478
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
479
+
480
+ if output_attentions:
481
+ all_self_attns += (layer_outputs[1],)
482
+
483
+ hidden_states = self.norm(hidden_states)
484
+
485
+ # add hidden states from the last decoder layer
486
+ if output_hidden_states:
487
+ all_hidden_states += (hidden_states,)
488
+
489
+ next_cache = next_decoder_cache if use_cache else None
490
+ if not return_dict:
491
+ return tuple(
492
+ v
493
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
494
+ if v is not None
495
+ )
496
+ return BaseModelOutputWithPast(
497
+ last_hidden_state=hidden_states,
498
+ past_key_values=next_cache,
499
+ hidden_states=all_hidden_states,
500
+ attentions=all_self_attns,
501
+ )
502
+
503
+
504
+ class NormHead(nn.Module):
505
+ def __init__(self, hidden_size, vocab_size, bias=False):
506
+ super().__init__()
507
+ self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
508
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
509
+ self.first_flag = True
510
+
511
+ def forward(self, hidden_states):
512
+ if self.training:
513
+ norm_weight = nn.functional.normalize(self.weight)
514
+ elif self.first_flag:
515
+ self.first_flag = False
516
+ self.weight = nn.Parameter(nn.functional.normalize(self.weight))
517
+ norm_weight = self.weight
518
+ else:
519
+ norm_weight = self.weight
520
+ return nn.functional.linear(hidden_states, norm_weight)
521
+
522
+ _init_weights = True
523
+ @contextmanager
524
+ def no_init_weights(_enable=True):
525
+ global _init_weights
526
+ old_init_weights = _init_weights
527
+ if _enable:
528
+ _init_weights = False
529
+ try:
530
+ yield
531
+ finally:
532
+ _init_weights = old_init_weights
533
+
534
+
535
+ class BaichuanForCausalLM(BaichuanPreTrainedModel):
536
+ def __init__(self, config, *model_args, **model_kwargs):
537
+ super().__init__(config, *model_args, **model_kwargs)
538
+ self.model = BaichuanModel(config)
539
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
540
+ #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
541
+ if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
542
+ try:
543
+ from .quantizer import quantize_offline, init_model_weight_int4
544
+ except ImportError:
545
+ raise ImportError(f"Needs quantize_offline to run quantize.")
546
+ quantize_offline(self, 4)
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self):
551
+ return self.model.embed_tokens
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.model.embed_tokens = value
555
+
556
+ def get_output_embeddings(self):
557
+ return self.lm_head
558
+
559
+ def set_output_embeddings(self, new_embeddings):
560
+ self.lm_head = new_embeddings
561
+
562
+ def set_decoder(self, decoder):
563
+ self.model = decoder
564
+
565
+ def get_decoder(self):
566
+ return self.model
567
+
568
+ @classmethod
569
+ def from_pretrained(
570
+ cls,
571
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
572
+ *model_args,
573
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
574
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
575
+ ignore_mismatched_sizes: bool = False,
576
+ force_download: bool = False,
577
+ local_files_only: bool = False,
578
+ token: Optional[Union[str, bool]] = None,
579
+ revision: str = "main",
580
+ use_safetensors: bool = None,
581
+ **kwargs,
582
+ ):
583
+
584
+ # Load config if we don't provide a configuration
585
+ if not isinstance(config, PretrainedConfig):
586
+ config_path = config if config is not None else pretrained_model_name_or_path
587
+ config, model_kwargs = cls.config_class.from_pretrained(
588
+ config_path,
589
+ cache_dir=cache_dir,
590
+ return_unused_kwargs=True,
591
+ force_download=force_download,
592
+ resume_download=False,
593
+ proxies=None,
594
+ local_files_only=local_files_only,
595
+ token=token,
596
+ revision=revision,
597
+ subfolder="",
598
+ _from_auto=False,
599
+ _from_pipeline=None,
600
+ **kwargs,
601
+ )
602
+ else:
603
+ model_kwargs = kwargs
604
+
605
+ if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
606
+ try:
607
+ from .quantizer import init_model_weight_int4
608
+ from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map
609
+ from accelerate.utils import CustomDtype
610
+ from accelerate.utils import get_balanced_memory
611
+ except ImportError:
612
+ raise ImportError(f"Needs import model weight init func to run quantize.")
613
+ # Instantiate model.
614
+ init_contexts = [no_init_weights(_enable=True)]
615
+ init_contexts.append(init_empty_weights())
616
+ with ContextManagers(init_contexts):
617
+ model = cls(config)
618
+
619
+ model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
620
+ state_dict = torch.load(model_file, map_location="cpu")
621
+ model.is_quantized = True
622
+
623
+ device_map = kwargs.pop("device_map", None)
624
+ torch_dtype = kwargs.pop("torch_dtype", None)
625
+ if device_map is not None:
626
+ kwargs = {"no_split_module_classes": model._no_split_modules}
627
+ target_dtype = CustomDtype.INT4
628
+ max_memory = get_balanced_memory(
629
+ model,
630
+ dtype=target_dtype,
631
+ low_zero=(device_map == "balanced_low_0"),
632
+ max_memory=None,
633
+ **kwargs,
634
+ )
635
+ kwargs["max_memory"] = max_memory
636
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
637
+ model = init_model_weight_int4(config, model, state_dict)
638
+
639
+ # Set model in evaluation mode to deactivate DropOut modules by default
640
+ model.eval()
641
+ # If it is a model with generation capabilities, attempt to load the generation config
642
+ if model.can_generate():
643
+ try:
644
+ model.generation_config = GenerationConfig.from_pretrained(
645
+ pretrained_model_name_or_path,
646
+ cache_dir=cache_dir,
647
+ force_download=force_download,
648
+ resume_download=False,
649
+ proxies=None,
650
+ local_files_only=local_files_only,
651
+ token=token,
652
+ revision=revision,
653
+ subfolder="",
654
+ _from_auto=False,
655
+ _from_pipeline=None,
656
+ **kwargs,
657
+ )
658
+ except (OSError, TypeError):
659
+ logger.info(
660
+ "Generation config file not found, using a generation config created from the model config."
661
+ )
662
+ pass
663
+
664
+ if device_map is not None:
665
+ dispatch_model(model, device_map=device_map)
666
+
667
+ return model
668
+
669
+ return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args,
670
+ config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes,
671
+ force_download=force_download, local_files_only=local_files_only, token=token, revision=revision,
672
+ use_safetensors=use_safetensors, **kwargs)
673
+
674
+ def forward(
675
+ self,
676
+ input_ids: torch.LongTensor = None,
677
+ attention_mask: Optional[torch.Tensor] = None,
678
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
679
+ inputs_embeds: Optional[torch.FloatTensor] = None,
680
+ labels: Optional[torch.LongTensor] = None,
681
+ use_cache: Optional[bool] = None,
682
+ output_attentions: Optional[bool] = False,
683
+ output_hidden_states: Optional[bool] = False,
684
+ return_dict: Optional[bool] = True,
685
+ **kwargs,
686
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
687
+ return_dict = (
688
+ return_dict if return_dict is not None else self.config.use_return_dict
689
+ )
690
+
691
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
692
+ outputs = self.model(
693
+ input_ids=input_ids,
694
+ attention_mask=attention_mask,
695
+ past_key_values=past_key_values,
696
+ inputs_embeds=inputs_embeds,
697
+ use_cache=use_cache,
698
+ output_attentions=output_attentions,
699
+ output_hidden_states=output_hidden_states,
700
+ return_dict=return_dict,
701
+ )
702
+
703
+ hidden_states = outputs[0]
704
+ logits = self.lm_head(hidden_states)
705
+ loss = None
706
+ if labels is not None:
707
+ # Shift so that tokens < n predict n
708
+ shift_logits = logits[..., :-1, :].contiguous()
709
+ shift_labels = labels[..., 1:].contiguous()
710
+ # Flatten the tokens
711
+ loss_fct = CrossEntropyLoss()
712
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
713
+ shift_labels = shift_labels.view(-1)
714
+ softmax_normalizer = shift_logits.max(-1).values ** 2
715
+ z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
716
+ # Enable model parallelism
717
+ shift_labels = shift_labels.to(shift_logits.device)
718
+ loss = loss_fct(shift_logits, shift_labels) + z_loss
719
+
720
+ if not return_dict:
721
+ output = (logits,) + outputs[1:]
722
+ return (loss,) + output if loss is not None else output
723
+
724
+ return CausalLMOutputWithPast(
725
+ loss=loss,
726
+ logits=logits,
727
+ past_key_values=outputs.past_key_values,
728
+ hidden_states=outputs.hidden_states,
729
+ attentions=outputs.attentions,
730
+ )
731
+
732
+ def quantize(self, bits: int):
733
+ try:
734
+ from .quantizer import quantize_online
735
+ except ImportError:
736
+ raise ImportError(f"Needs QLinear to run quantize.")
737
+ return quantize_online(self, bits)
738
+
739
+ def prepare_inputs_for_generation(
740
+ self,
741
+ input_ids: torch.LongTensor,
742
+ past_key_values: Optional[torch.Tensor] = None,
743
+ attention_mask: Optional[torch.Tensor] = None,
744
+ inputs_embeds: Optional[torch.Tensor] = None,
745
+ **kwargs,
746
+ ):
747
+ if past_key_values:
748
+ input_ids = input_ids[:, -1:]
749
+
750
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
751
+ if inputs_embeds is not None and past_key_values is None:
752
+ model_inputs = {"inputs_embeds": inputs_embeds}
753
+ else:
754
+ model_inputs = {"input_ids": input_ids}
755
+
756
+ model_inputs.update(
757
+ {
758
+ "past_key_values": past_key_values,
759
+ "use_cache": kwargs.get("use_cache"),
760
+ "attention_mask": attention_mask,
761
+ }
762
+ )
763
+ return model_inputs
764
+
765
+ @staticmethod
766
+ def _reorder_cache(past_key_values, beam_idx):
767
+ return tuple(
768
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
769
+ for layer_past in past_key_values
770
+ )
771
+
772
+ def _build_chat_input(
773
+ self, tokenizer, messages: List[dict], max_new_tokens: int = 0
774
+ ):
775
+ max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
776
+ max_input_tokens = self.config.model_max_length - max_new_tokens
777
+ max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
778
+ total_input, round_input = [], []
779
+ for i, message in enumerate(messages[::-1]):
780
+ content_tokens = tokenizer.encode(message["content"])
781
+ if message["role"] == "user":
782
+ round_input = (
783
+ [self.generation_config.user_token_id]
784
+ + content_tokens
785
+ + round_input
786
+ )
787
+ if (
788
+ total_input
789
+ and len(total_input) + len(round_input) > max_input_tokens
790
+ ):
791
+ break
792
+ else:
793
+ total_input = round_input + total_input
794
+ if len(total_input) >= max_input_tokens:
795
+ break
796
+ else:
797
+ round_input = []
798
+ elif message["role"] == "assistant":
799
+ round_input = (
800
+ [self.generation_config.assistant_token_id]
801
+ + content_tokens
802
+ + [self.generation_config.eos_token_id]
803
+ + round_input
804
+ )
805
+ else:
806
+ raise ValueError(f"message role not supported yet: {message['role']}")
807
+ total_input = total_input[-max_input_tokens:] # truncate left
808
+ total_input.append(self.generation_config.assistant_token_id)
809
+ total_input = torch.LongTensor([total_input]).to(self.device)
810
+ return total_input
811
+
812
+ def chat(self, tokenizer, messages: List[dict], stream=False,
813
+ generation_config: Optional[GenerationConfig]=None):
814
+ generation_config = generation_config or self.generation_config
815
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
816
+ if stream:
817
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
818
+ Thread(target=self.generate, kwargs=dict(
819
+ inputs=input_ids, streamer=streamer,
820
+ generation_config=generation_config,
821
+ )).start()
822
+ return streamer
823
+ else:
824
+ outputs = self.generate(input_ids, generation_config=generation_config)
825
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
826
+ return response
pytorch_model-00001-of-00003.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18f1e36ea564699ae87a4e2a8c222d90d5ea8b9b6c2dead08de1905a5d9e90db
3
+ size 9973567639
pytorch_model-00002-of-00003.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0416025049354d893293fa66c702c3e30cf2a3c2fdfbe8c295b34a25b8ac879b
3
+ size 9947419824
pytorch_model-00003-of-00003.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae92cd8d811136bb6b7879894397a63cd196a83d8da0d4cfeaa30d0d759ed278
3
+ size 7872445619
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 27793336320
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00003-of-00003.bin",
7
+ "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin",
8
+ "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
9
+ "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
10
+ "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
11
+ "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
12
+ "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
13
+ "model.layers.0.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
14
+ "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
15
+ "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
16
+ "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
17
+ "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
18
+ "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
19
+ "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
20
+ "model.layers.1.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
21
+ "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
22
+ "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
23
+ "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
24
+ "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
25
+ "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
26
+ "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
27
+ "model.layers.10.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
28
+ "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
29
+ "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
30
+ "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
31
+ "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
32
+ "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
33
+ "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
34
+ "model.layers.11.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
35
+ "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
36
+ "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
37
+ "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
38
+ "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
39
+ "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
40
+ "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
41
+ "model.layers.12.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
42
+ "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
43
+ "model.layers.13.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
44
+ "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
45
+ "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
46
+ "model.layers.13.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
47
+ "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
48
+ "model.layers.13.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
49
+ "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
50
+ "model.layers.14.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
51
+ "model.layers.14.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
52
+ "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
53
+ "model.layers.14.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
54
+ "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
55
+ "model.layers.14.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
56
+ "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
57
+ "model.layers.15.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
58
+ "model.layers.15.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
59
+ "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
60
+ "model.layers.15.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
61
+ "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
62
+ "model.layers.15.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
63
+ "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
64
+ "model.layers.16.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
65
+ "model.layers.16.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
66
+ "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
67
+ "model.layers.16.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
68
+ "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
69
+ "model.layers.16.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
70
+ "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
71
+ "model.layers.17.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
72
+ "model.layers.17.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
73
+ "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
74
+ "model.layers.17.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
75
+ "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
76
+ "model.layers.17.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
77
+ "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
78
+ "model.layers.18.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
79
+ "model.layers.18.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
80
+ "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
81
+ "model.layers.18.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
82
+ "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
83
+ "model.layers.18.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
84
+ "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
85
+ "model.layers.19.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
86
+ "model.layers.19.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
87
+ "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
88
+ "model.layers.19.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
89
+ "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
90
+ "model.layers.19.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
91
+ "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
92
+ "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
93
+ "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
94
+ "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
95
+ "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
96
+ "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
97
+ "model.layers.2.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
98
+ "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
99
+ "model.layers.20.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
100
+ "model.layers.20.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
101
+ "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
102
+ "model.layers.20.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
103
+ "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
104
+ "model.layers.20.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
105
+ "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
106
+ "model.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
107
+ "model.layers.21.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
108
+ "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
109
+ "model.layers.21.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
110
+ "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
111
+ "model.layers.21.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
112
+ "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
113
+ "model.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
114
+ "model.layers.22.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
115
+ "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
116
+ "model.layers.22.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
117
+ "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
118
+ "model.layers.22.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
119
+ "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
120
+ "model.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
121
+ "model.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
122
+ "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
123
+ "model.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
124
+ "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
125
+ "model.layers.23.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
126
+ "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
127
+ "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
128
+ "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
129
+ "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
130
+ "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
131
+ "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
132
+ "model.layers.24.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
133
+ "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
134
+ "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
135
+ "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
136
+ "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
137
+ "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
138
+ "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
139
+ "model.layers.25.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
140
+ "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
141
+ "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
142
+ "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
143
+ "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
144
+ "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
145
+ "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
146
+ "model.layers.26.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
147
+ "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
148
+ "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
149
+ "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
150
+ "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
151
+ "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
152
+ "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
153
+ "model.layers.27.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
154
+ "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
155
+ "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00003.bin",
156
+ "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin",
157
+ "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
158
+ "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin",
159
+ "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
160
+ "model.layers.28.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
161
+ "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
162
+ "model.layers.29.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
163
+ "model.layers.29.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
164
+ "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin",
165
+ "model.layers.29.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
166
+ "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
167
+ "model.layers.29.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin",
168
+ "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin",
169
+ "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
170
+ "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
171
+ "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
172
+ "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
173
+ "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
174
+ "model.layers.3.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
175
+ "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
176
+ "model.layers.30.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
177
+ "model.layers.30.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
178
+ "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
179
+ "model.layers.30.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
180
+ "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
181
+ "model.layers.30.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
182
+ "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
183
+ "model.layers.31.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
184
+ "model.layers.31.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
185
+ "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
186
+ "model.layers.31.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
187
+ "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
188
+ "model.layers.31.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
189
+ "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
190
+ "model.layers.32.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
191
+ "model.layers.32.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
192
+ "model.layers.32.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
193
+ "model.layers.32.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
194
+ "model.layers.32.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
195
+ "model.layers.32.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
196
+ "model.layers.32.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
197
+ "model.layers.33.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
198
+ "model.layers.33.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
199
+ "model.layers.33.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
200
+ "model.layers.33.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
201
+ "model.layers.33.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
202
+ "model.layers.33.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
203
+ "model.layers.33.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
204
+ "model.layers.34.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
205
+ "model.layers.34.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
206
+ "model.layers.34.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
207
+ "model.layers.34.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
208
+ "model.layers.34.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
209
+ "model.layers.34.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
210
+ "model.layers.34.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
211
+ "model.layers.35.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
212
+ "model.layers.35.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
213
+ "model.layers.35.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
214
+ "model.layers.35.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
215
+ "model.layers.35.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
216
+ "model.layers.35.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
217
+ "model.layers.35.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
218
+ "model.layers.36.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
219
+ "model.layers.36.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
220
+ "model.layers.36.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
221
+ "model.layers.36.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
222
+ "model.layers.36.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
223
+ "model.layers.36.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
224
+ "model.layers.36.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
225
+ "model.layers.37.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
226
+ "model.layers.37.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
227
+ "model.layers.37.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
228
+ "model.layers.37.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
229
+ "model.layers.37.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
230
+ "model.layers.37.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
231
+ "model.layers.37.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
232
+ "model.layers.38.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
233
+ "model.layers.38.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
234
+ "model.layers.38.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
235
+ "model.layers.38.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
236
+ "model.layers.38.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
237
+ "model.layers.38.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
238
+ "model.layers.38.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
239
+ "model.layers.39.input_layernorm.weight": "pytorch_model-00003-of-00003.bin",
240
+ "model.layers.39.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin",
241
+ "model.layers.39.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin",
242
+ "model.layers.39.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin",
243
+ "model.layers.39.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
244
+ "model.layers.39.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin",
245
+ "model.layers.39.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin",
246
+ "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
247
+ "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
248
+ "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
249
+ "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
250
+ "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
251
+ "model.layers.4.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
252
+ "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
253
+ "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
254
+ "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
255
+ "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
256
+ "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
257
+ "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
258
+ "model.layers.5.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
259
+ "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
260
+ "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
261
+ "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
262
+ "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
263
+ "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
264
+ "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
265
+ "model.layers.6.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
266
+ "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
267
+ "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
268
+ "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
269
+ "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
270
+ "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
271
+ "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
272
+ "model.layers.7.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
273
+ "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
274
+ "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
275
+ "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
276
+ "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
277
+ "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
278
+ "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
279
+ "model.layers.8.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
280
+ "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
281
+ "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00003.bin",
282
+ "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin",
283
+ "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin",
284
+ "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin",
285
+ "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
286
+ "model.layers.9.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin",
287
+ "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin",
288
+ "model.norm.weight": "pytorch_model-00003-of-00003.bin"
289
+ }
290
+ }
quantizer.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bitsandbytes as bnb
2
+ from accelerate import init_empty_weights
3
+ from bitsandbytes.nn.modules import Params4bit, Int8Params
4
+ import torch
5
+
6
+ def Params4bitCuda(self, device):
7
+ self.data = self.data.cuda(device)
8
+ self.quant_state[0] = self.quant_state[0].cuda(device)
9
+ self.quant_state[4][0] = self.quant_state[4][0].cuda(device)
10
+ self.quant_state[4][1][0] = self.quant_state[4][1][0].cuda(device)
11
+ self.quant_state[4][1][1] = self.quant_state[4][1][1].cuda(device)
12
+
13
+ self.quant_state[6] = self.quant_state[6].cuda(device)
14
+ return self
15
+
16
+ class Linear4bitOnline(torch.nn.Module):
17
+ def __init__(self, weight, bias, quant_type):
18
+ super().__init__()
19
+ self.weight = Params4bit(
20
+ weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
21
+ )
22
+ self.compute_dtype = None
23
+ #self.weight.cuda(weight.device)
24
+ self.bias = bias
25
+
26
+ def forward(self, x: torch.Tensor):
27
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
28
+ if self.bias is not None and self.bias.dtype != x.dtype:
29
+ self.bias.data = self.bias.data.to(x.dtype)
30
+
31
+ if getattr(self.weight, "quant_state", None) is None:
32
+ print(
33
+ "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
34
+ )
35
+ inp_dtype = x.dtype
36
+ if self.compute_dtype is not None:
37
+ x = x.to(self.compute_dtype)
38
+
39
+ bias = None if self.bias is None else self.bias.to(self.compute_dtype)
40
+ out = bnb.matmul_4bit(
41
+ x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
42
+ )
43
+
44
+ out = out.to(inp_dtype)
45
+
46
+ return out
47
+
48
+ class Linear8bitLtOnline(torch.nn.Module):
49
+ def __init__(
50
+ self,
51
+ weight,
52
+ bias,
53
+ has_fp16_weights=True,
54
+ memory_efficient_backward=False,
55
+ threshold=0.0,
56
+ index=None,
57
+ ):
58
+ super().__init__()
59
+ assert (
60
+ not memory_efficient_backward
61
+ ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
62
+ self.state = bnb.MatmulLtState()
63
+ self.index = index
64
+
65
+ # Necessary for stacked layers
66
+ self.state.threshold = threshold
67
+ self.state.has_fp16_weights = has_fp16_weights
68
+ self.state.memory_efficient_backward = memory_efficient_backward
69
+ if threshold > 0.0 and not has_fp16_weights:
70
+ self.state.use_pool = True
71
+
72
+ self.weight = Int8Params(
73
+ weight.data,
74
+ has_fp16_weights=has_fp16_weights,
75
+ requires_grad=has_fp16_weights,
76
+ )
77
+ self.bias = bias
78
+
79
+ def init_8bit_state(self):
80
+ self.state.CB = self.weight.CB
81
+ self.state.SCB = self.weight.SCB
82
+ self.weight.CB = None
83
+ self.weight.SCB = None
84
+
85
+ def forward(self, x: torch.Tensor):
86
+ self.state.is_training = self.training
87
+ if self.weight.CB is not None:
88
+ self.init_8bit_state()
89
+
90
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
91
+ if self.bias is not None and self.bias.dtype != x.dtype:
92
+ self.bias.data = self.bias.data.to(x.dtype)
93
+
94
+ out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
95
+
96
+ if not self.state.has_fp16_weights:
97
+ if self.state.CB is not None and self.state.CxB is not None:
98
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
99
+ # we no longer need the row-major weight
100
+ del self.state.CB
101
+ self.weight.data = self.state.CxB
102
+ return out
103
+
104
+ def quantize_offline(model, bits: int):
105
+ assert (bits == 4), f'bits: {bits} is not supported'
106
+
107
+ for i, layer in enumerate(model.model.layers):
108
+ layer.self_attn.W_pack = bnb.nn.Linear4bit(
109
+ layer.self_attn.W_pack.weight.shape[1],
110
+ layer.self_attn.W_pack.weight.shape[0],
111
+ False,
112
+ torch.float16,
113
+ compress_statistics=True,
114
+ quant_type="nf4",
115
+ )
116
+ layer.self_attn.o_proj = bnb.nn.Linear4bit(
117
+ layer.self_attn.o_proj.weight.shape[1],
118
+ layer.self_attn.o_proj.weight.shape[0],
119
+ False,
120
+ torch.float16,
121
+ compress_statistics=True,
122
+ quant_type="nf4",
123
+ )
124
+
125
+ layer.mlp.gate_proj = bnb.nn.Linear4bit(
126
+ layer.mlp.gate_proj.weight.shape[1],
127
+ layer.mlp.gate_proj.weight.shape[0],
128
+ False,
129
+ torch.float16,
130
+ compress_statistics=True,
131
+ quant_type="nf4",
132
+ )
133
+ layer.mlp.down_proj = bnb.nn.Linear4bit(
134
+ layer.mlp.down_proj.weight.shape[1],
135
+ layer.mlp.down_proj.weight.shape[0],
136
+ False,
137
+ torch.float16,
138
+ compress_statistics=True,
139
+ quant_type="nf4",
140
+ )
141
+ layer.mlp.up_proj = bnb.nn.Linear4bit(
142
+ layer.mlp.up_proj.weight.shape[1],
143
+ layer.mlp.up_proj.weight.shape[0],
144
+ False,
145
+ torch.float16,
146
+ compress_statistics=True,
147
+ quant_type="nf4",
148
+ )
149
+ return model
150
+
151
+ def quantize_online(model, bits: int):
152
+ def quant(weight, bias=None):
153
+ if bits == 8:
154
+ linear = Linear8bitLtOnline(
155
+ weight,
156
+ bias,
157
+ has_fp16_weights=False,
158
+ threshold=6.0,
159
+ )
160
+ if bias is not None:
161
+ linear.bias = torch.nn.Parameter(bias)
162
+ elif bits == 4:
163
+ linear = Linear4bitOnline(
164
+ weight,
165
+ bias,
166
+ quant_type="nf4", #fp4/nf4
167
+ )
168
+ else:
169
+ raise ValueError("quantize only support 4/8 bit")
170
+ return linear
171
+
172
+ for i, layer in enumerate(model.model.layers):
173
+ layer.self_attn.W_pack = quant(layer.self_attn.W_pack.weight)
174
+ layer.self_attn.o_proj = quant(layer.self_attn.o_proj.weight)
175
+ layer.mlp.gate_proj = quant(layer.mlp.gate_proj.weight)
176
+ layer.mlp.down_proj = quant(layer.mlp.down_proj.weight)
177
+ layer.mlp.up_proj = quant(layer.mlp.up_proj.weight)
178
+ return model
179
+
180
+ def init_model_weight_int4(config, model, state_dict):
181
+ #replace Params4bit.cuda with Params4bitCuda
182
+ Params4bit.cuda = Params4bitCuda
183
+
184
+ for i in range(config.num_hidden_layers):
185
+ weight_data = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.data']
186
+ weight_quant_state = state_dict[f'model.layers.{i}.self_attn.W_pack.weight.quant_state']
187
+ model.model.layers[i].self_attn.W_pack.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)
188
+
189
+ weight_data = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.data']
190
+ weight_quant_state = state_dict[f'model.layers.{i}.self_attn.o_proj.weight.quant_state']
191
+ model.model.layers[i].self_attn.o_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)
192
+
193
+ weight_data = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.data']
194
+ weight_quant_state = state_dict[f'model.layers.{i}.mlp.gate_proj.weight.quant_state']
195
+ model.model.layers[i].mlp.gate_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)
196
+
197
+ weight_data = state_dict[f'model.layers.{i}.mlp.up_proj.weight.data']
198
+ weight_quant_state = state_dict[f'model.layers.{i}.mlp.up_proj.weight.quant_state']
199
+ model.model.layers[i].mlp.up_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)
200
+
201
+ weight_data = state_dict[f'model.layers.{i}.mlp.down_proj.weight.data']
202
+ weight_quant_state = state_dict[f'model.layers.{i}.mlp.down_proj.weight.quant_state']
203
+ model.model.layers[i].mlp.down_proj.weight = Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)
204
+
205
+ model.model.layers[i].input_layernorm.weight = state_dict[f'model.layers.{i}.input_layernorm.weight']
206
+ model.model.layers[i].post_attention_layernorm.weight = state_dict[f'model.layers.{i}.post_attention_layernorm.weight']
207
+
208
+ model.model.embed_tokens.weight = state_dict['model.embed_tokens.weight']
209
+ model.model.norm.weight = state_dict['model.norm.weight']
210
+ model.lm_head.weight = state_dict['lm_head.weight']
211
+ return model
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": true
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": true
15
+ },
16
+ "pad_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": true
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": true
29
+ }
30
+ }
tokenization_baichuan.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
2
+
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import sentencepiece as spm
8
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
9
+ from transformers.utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
15
+
16
+ PRETRAINED_VOCAB_FILES_MAP = {
17
+ "vocab_file": {},
18
+ "tokenizer_file": {},
19
+ }
20
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
21
+
22
+
23
+ class BaichuanTokenizer(PreTrainedTokenizer):
24
+ """
25
+ Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.
26
+
27
+ Args:
28
+ vocab_file (`str`):
29
+ Path to the vocabulary file.
30
+ """
31
+
32
+ vocab_files_names = VOCAB_FILES_NAMES
33
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
34
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
35
+ model_input_names = ["input_ids", "attention_mask"]
36
+
37
+ def __init__(
38
+ self,
39
+ vocab_file,
40
+ unk_token="<unk>",
41
+ bos_token="<s>",
42
+ eos_token="</s>",
43
+ pad_token=None,
44
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
45
+ add_bos_token=True,
46
+ add_eos_token=False,
47
+ clean_up_tokenization_spaces=False,
48
+ **kwargs,
49
+ ):
50
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
51
+ bos_token = (
52
+ AddedToken(bos_token, lstrip=False, rstrip=False)
53
+ if isinstance(bos_token, str)
54
+ else bos_token
55
+ )
56
+ eos_token = (
57
+ AddedToken(eos_token, lstrip=False, rstrip=False)
58
+ if isinstance(eos_token, str)
59
+ else eos_token
60
+ )
61
+ unk_token = (
62
+ AddedToken(unk_token, lstrip=False, rstrip=False)
63
+ if isinstance(unk_token, str)
64
+ else unk_token
65
+ )
66
+ pad_token = (
67
+ AddedToken(pad_token, lstrip=False, rstrip=False)
68
+ if isinstance(pad_token, str)
69
+ else pad_token
70
+ )
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ add_bos_token=add_bos_token,
77
+ add_eos_token=add_eos_token,
78
+ sp_model_kwargs=self.sp_model_kwargs,
79
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
80
+ **kwargs,
81
+ )
82
+ self.vocab_file = vocab_file
83
+ self.add_bos_token = add_bos_token
84
+ self.add_eos_token = add_eos_token
85
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
86
+ self.sp_model.Load(vocab_file)
87
+
88
+ def __getstate__(self):
89
+ state = self.__dict__.copy()
90
+ state["sp_model"] = None
91
+ return state
92
+
93
+ def __setstate__(self, d):
94
+ self.__dict__ = d
95
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
96
+ self.sp_model.Load(self.vocab_file)
97
+
98
+ @property
99
+ def vocab_size(self):
100
+ """Returns vocab size"""
101
+ return self.sp_model.get_piece_size()
102
+
103
+ def get_vocab(self):
104
+ """Returns vocab as a dict"""
105
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
106
+ vocab.update(self.added_tokens_encoder)
107
+ return vocab
108
+
109
+ def _tokenize(self, text):
110
+ """Returns a tokenized string."""
111
+ return self.sp_model.encode(text, out_type=str)
112
+
113
+ def _convert_token_to_id(self, token):
114
+ """Converts a token (str) in an id using the vocab."""
115
+ return self.sp_model.piece_to_id(token)
116
+
117
+ def _convert_id_to_token(self, index):
118
+ """Converts an index (integer) in a token (str) using the vocab."""
119
+ token = self.sp_model.IdToPiece(index)
120
+ return token
121
+
122
+ def convert_tokens_to_string(self, tokens):
123
+ """Converts a sequence of tokens (string) in a single string."""
124
+ current_sub_tokens = []
125
+ out_string = ""
126
+ prev_is_special = False
127
+ for i, token in enumerate(tokens):
128
+ # make sure that special tokens are not decoded using sentencepiece model
129
+ if token in self.all_special_tokens:
130
+ if not prev_is_special and i != 0:
131
+ out_string += " "
132
+ out_string += self.sp_model.decode(current_sub_tokens) + token
133
+ prev_is_special = True
134
+ current_sub_tokens = []
135
+ else:
136
+ current_sub_tokens.append(token)
137
+ prev_is_special = False
138
+ out_string += self.sp_model.decode(current_sub_tokens)
139
+ return out_string
140
+
141
+ def save_vocabulary(
142
+ self, save_directory, filename_prefix: Optional[str] = None
143
+ ) -> Tuple[str]:
144
+ """
145
+ Save the vocabulary and special tokens file to a directory.
146
+
147
+ Args:
148
+ save_directory (`str`):
149
+ The directory in which to save the vocabulary.
150
+
151
+ Returns:
152
+ `Tuple(str)`: Paths to the files saved.
153
+ """
154
+ if not os.path.isdir(save_directory):
155
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
156
+ return
157
+ out_vocab_file = os.path.join(
158
+ save_directory,
159
+ (filename_prefix + "-" if filename_prefix else "")
160
+ + VOCAB_FILES_NAMES["vocab_file"],
161
+ )
162
+
163
+ if os.path.abspath(self.vocab_file) != os.path.abspath(
164
+ out_vocab_file
165
+ ) and os.path.isfile(self.vocab_file):
166
+ copyfile(self.vocab_file, out_vocab_file)
167
+ elif not os.path.isfile(self.vocab_file):
168
+ with open(out_vocab_file, "wb") as fi:
169
+ content_spiece_model = self.sp_model.serialized_model_proto()
170
+ fi.write(content_spiece_model)
171
+
172
+ return (out_vocab_file,)
173
+
174
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
175
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
176
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
177
+
178
+ output = bos_token_id + token_ids_0 + eos_token_id
179
+
180
+ if token_ids_1 is not None:
181
+ output = output + bos_token_id + token_ids_1 + eos_token_id
182
+
183
+ return output
184
+
185
+ def get_special_tokens_mask(
186
+ self,
187
+ token_ids_0: List[int],
188
+ token_ids_1: Optional[List[int]] = None,
189
+ already_has_special_tokens: bool = False,
190
+ ) -> List[int]:
191
+ """
192
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
193
+ special tokens using the tokenizer `prepare_for_model` method.
194
+
195
+ Args:
196
+ token_ids_0 (`List[int]`):
197
+ List of IDs.
198
+ token_ids_1 (`List[int]`, *optional*):
199
+ Optional second list of IDs for sequence pairs.
200
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
201
+ Whether or not the token list is already formatted with special tokens for the model.
202
+
203
+ Returns:
204
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
205
+ """
206
+ if already_has_special_tokens:
207
+ return super().get_special_tokens_mask(
208
+ token_ids_0=token_ids_0,
209
+ token_ids_1=token_ids_1,
210
+ already_has_special_tokens=True,
211
+ )
212
+
213
+ bos_token_id = [1] if self.add_bos_token else []
214
+ eos_token_id = [1] if self.add_eos_token else []
215
+
216
+ if token_ids_1 is None:
217
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
218
+ return (
219
+ bos_token_id
220
+ + ([0] * len(token_ids_0))
221
+ + eos_token_id
222
+ + bos_token_id
223
+ + ([0] * len(token_ids_1))
224
+ + eos_token_id
225
+ )
226
+
227
+ def create_token_type_ids_from_sequences(
228
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
229
+ ) -> List[int]:
230
+ """
231
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
232
+ sequence pair mask has the following format:
233
+
234
+ ```
235
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
236
+ | first sequence | second sequence |
237
+ ```
238
+
239
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
240
+
241
+ Args:
242
+ token_ids_0 (`List[int]`):
243
+ List of ids.
244
+ token_ids_1 (`List[int]`, *optional*):
245
+ Optional second list of IDs for sequence pairs.
246
+
247
+ Returns:
248
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
249
+ """
250
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
251
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
252
+
253
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
254
+
255
+ if token_ids_1 is not None:
256
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
257
+
258
+ return output
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79452955be6b419a65984273a9f08af86042e1c2a75ee3ba989cbf620a133cc2
3
+ size 2001107
tokenizer_config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "auto_map": {
5
+ "AutoTokenizer": [
6
+ "tokenization_baichuan.BaichuanTokenizer",
7
+ null
8
+ ]
9
+ },
10
+ "bos_token": {
11
+ "__type": "AddedToken",
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": true
17
+ },
18
+ "clean_up_tokenization_spaces": false,
19
+ "eos_token": {
20
+ "__type": "AddedToken",
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": true
26
+ },
27
+ "model_max_length": 4096,
28
+ "pad_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": true,
33
+ "rstrip": false,
34
+ "single_word": true
35
+ },
36
+ "sp_model_kwargs": {},
37
+ "tokenizer_class": "BaichuanTokenizer",
38
+ "unk_token": {
39
+ "__type": "AddedToken",
40
+ "content": "<unk>",
41
+ "lstrip": false,
42
+ "normalized": true,
43
+ "rstrip": false,
44
+ "single_word": true
45
+ }
46
+ }