codefuse-admin commited on
Commit
ed69c50
·
1 Parent(s): 51b5592

upload model from ant-group,[email protected]

Browse files
LICENSE.md ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright [2023] [Ant Group]
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS,
9
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ See the License for the specific language governing permissions and
11
+ limitations under the License.
12
+
13
+
14
+ Apache License
15
+ Version 2.0, January 2004
16
+ http://www.apache.org/licenses/
17
+
18
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
19
+
20
+ 1. Definitions.
21
+
22
+ "License" shall mean the terms and conditions for use, reproduction,
23
+ and distribution as defined by Sections 1 through 9 of this document.
24
+
25
+ "Licensor" shall mean the copyright owner or entity authorized by
26
+ the copyright owner that is granting the License.
27
+
28
+ "Legal Entity" shall mean the union of the acting entity and all
29
+ other entities that control, are controlled by, or are under common
30
+ control with that entity. For the purposes of this definition,
31
+ "control" means (i) the power, direct or indirect, to cause the
32
+ direction or management of such entity, whether by contract or
33
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
34
+ outstanding shares, or (iii) beneficial ownership of such entity.
35
+
36
+ "You" (or "Your") shall mean an individual or Legal Entity
37
+ exercising permissions granted by this License.
38
+
39
+ "Source" form shall mean the preferred form for making modifications,
40
+ including but not limited to software source code, documentation
41
+ source, and configuration files.
42
+
43
+ "Object" form shall mean any form resulting from mechanical
44
+ transformation or translation of a Source form, including but
45
+ not limited to compiled object code, generated documentation,
46
+ and conversions to other media types.
47
+
48
+ "Work" shall mean the work of authorship, whether in Source or
49
+ Object form, made available under the License, as indicated by a
50
+ copyright notice that is included in or attached to the work
51
+ (an example is provided in the Appendix below).
52
+
53
+ "Derivative Works" shall mean any work, whether in Source or Object
54
+ form, that is based on (or derived from) the Work and for which the
55
+ editorial revisions, annotations, elaborations, or other modifications
56
+ represent, as a whole, an original work of authorship. For the purposes
57
+ of this License, Derivative Works shall not include works that remain
58
+ separable from, or merely link (or bind by name) to the interfaces of,
59
+ the Work and Derivative Works thereof.
60
+
61
+ "Contribution" shall mean any work of authorship, including
62
+ the original version of the Work and any modifications or additions
63
+ to that Work or Derivative Works thereof, that is intentionally
64
+ submitted to Licensor for inclusion in the Work by the copyright owner
65
+ or by an individual or Legal Entity authorized to submit on behalf of
66
+ the copyright owner. For the purposes of this definition, "submitted"
67
+ means any form of electronic, verbal, or written communication sent
68
+ to the Licensor or its representatives, including but not limited to
69
+ communication on electronic mailing lists, source code control systems,
70
+ and issue tracking systems that are managed by, or on behalf of, the
71
+ Licensor for the purpose of discussing and improving the Work, but
72
+ excluding communication that is conspicuously marked or otherwise
73
+ designated in writing by the copyright owner as "Not a Contribution."
74
+
75
+ "Contributor" shall mean Licensor and any individual or Legal Entity
76
+ on behalf of whom a Contribution has been received by Licensor and
77
+ subsequently incorporated within the Work.
78
+
79
+ 2. Grant of Copyright License. Subject to the terms and conditions of
80
+ this License, each Contributor hereby grants to You a perpetual,
81
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
82
+ copyright license to reproduce, prepare Derivative Works of,
83
+ publicly display, publicly perform, sublicense, and distribute the
84
+ Work and such Derivative Works in Source or Object form.
85
+
86
+ 3. Grant of Patent License. Subject to the terms and conditions of
87
+ this License, each Contributor hereby grants to You a perpetual,
88
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
89
+ (except as stated in this section) patent license to make, have made,
90
+ use, offer to sell, sell, import, and otherwise transfer the Work,
91
+ where such license applies only to those patent claims licensable
92
+ by such Contributor that are necessarily infringed by their
93
+ Contribution(s) alone or by combination of their Contribution(s)
94
+ with the Work to which such Contribution(s) was submitted. If You
95
+ institute patent litigation against any entity (including a
96
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
97
+ or a Contribution incorporated within the Work constitutes direct
98
+ or contributory patent infringement, then any patent licenses
99
+ granted to You under this License for that Work shall terminate
100
+ as of the date such litigation is filed.
101
+
102
+ 4. Redistribution. You may reproduce and distribute copies of the
103
+ Work or Derivative Works thereof in any medium, with or without
104
+ modifications, and in Source or Object form, provided that You
105
+ meet the following conditions:
106
+
107
+ (a) You must give any other recipients of the Work or
108
+ Derivative Works a copy of this License; and
109
+
110
+ (b) You must cause any modified files to carry prominent notices
111
+ stating that You changed the files; and
112
+
113
+ (c) You must retain, in the Source form of any Derivative Works
114
+ that You distribute, all copyright, patent, trademark, and
115
+ attribution notices from the Source form of the Work,
116
+ excluding those notices that do not pertain to any part of
117
+ the Derivative Works; and
118
+
119
+ (d) If the Work includes a "NOTICE" text file as part of its
120
+ distribution, then any Derivative Works that You distribute must
121
+ include a readable copy of the attribution notices contained
122
+ within such NOTICE file, excluding those notices that do not
123
+ pertain to any part of the Derivative Works, in at least one
124
+ of the following places: within a NOTICE text file distributed
125
+ as part of the Derivative Works; within the Source form or
126
+ documentation, if provided along with the Derivative Works; or,
127
+ within a display generated by the Derivative Works, if and
128
+ wherever such third-party notices normally appear. The contents
129
+ of the NOTICE file are for informational purposes only and
130
+ do not modify the License. You may add Your own attribution
131
+ notices within Derivative Works that You distribute, alongside
132
+ or as an addendum to the NOTICE text from the Work, provided
133
+ that such additional attribution notices cannot be construed
134
+ as modifying the License.
135
+
136
+ You may add Your own copyright statement to Your modifications and
137
+ may provide additional or different license terms and conditions
138
+ for use, reproduction, or distribution of Your modifications, or
139
+ for any such Derivative Works as a whole, provided Your use,
140
+ reproduction, and distribution of the Work otherwise complies with
141
+ the conditions stated in this License.
142
+
143
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
144
+ any Contribution intentionally submitted for inclusion in the Work
145
+ by You to the Licensor shall be under the terms and conditions of
146
+ this License, without any additional terms or conditions.
147
+ Notwithstanding the above, nothing herein shall supersede or modify
148
+ the terms of any separate license agreement you may have executed
149
+ with Licensor regarding such Contributions.
150
+
151
+ 6. Trademarks. This License does not grant permission to use the trade
152
+ names, trademarks, service marks, or product names of the Licensor,
153
+ except as required for reasonable and customary use in describing the
154
+ origin of the Work and reproducing the content of the NOTICE file.
155
+
156
+ 7. Disclaimer of Warranty. Unless required by applicable law or
157
+ agreed to in writing, Licensor provides the Work (and each
158
+ Contributor provides its Contributions) on an "AS IS" BASIS,
159
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
160
+ implied, including, without limitation, any warranties or conditions
161
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
162
+ PARTICULAR PURPOSE. You are solely responsible for determining the
163
+ appropriateness of using or redistributing the Work and assume any
164
+ risks associated with Your exercise of permissions under this License.
165
+
166
+ 8. Limitation of Liability. In no event and under no legal theory,
167
+ whether in tort (including negligence), contract, or otherwise,
168
+ unless required by applicable law (such as deliberate and grossly
169
+ negligent acts) or agreed to in writing, shall any Contributor be
170
+ liable to You for damages, including any direct, indirect, special,
171
+ incidental, or consequential damages of any character arising as a
172
+ result of this License or out of the use or inability to use the
173
+ Work (including but not limited to damages for loss of goodwill,
174
+ work stoppage, computer failure or malfunction, or any and all
175
+ other commercial damages or losses), even if such Contributor
176
+ has been advised of the possibility of such damages.
177
+
178
+ 9. Accepting Warranty or Additional Liability. While redistributing
179
+ the Work or Derivative Works thereof, You may choose to offer,
180
+ and charge a fee for, acceptance of support, warranty, indemnity,
181
+ or other liability obligations and/or rights consistent with this
182
+ License. However, in accepting such obligations, You may act only
183
+ on Your own behalf and on Your sole responsibility, not on behalf
184
+ of any other Contributor, and only if You agree to indemnify,
185
+ defend, and hold each Contributor harmless for any liability
186
+ incurred by, or claims asserted against, such Contributor by reason
187
+ of your accepting any such warranty or additional liability.
188
+
189
+ END OF TERMS AND CONDITIONS
190
+
191
+ APPENDIX: How to apply the Apache License to your work.
192
+
193
+ To apply the Apache License to your work, attach the following
194
+ boilerplate notice, with the fields enclosed by brackets "[]"
195
+ replaced with your own identifying information. (Don't include
196
+ the brackets!) The text should be enclosed in the appropriate
197
+ comment syntax for the file format. We also recommend that a
198
+ file or class name and description of purpose be included on the
199
+ same "printed page" as the copyright notice for easier
200
+ identification within third-party archives.
201
+
202
+ Copyright [yyyy] [name of copyright owner]
203
+
204
+ Licensed under the Apache License, Version 2.0 (the "License");
205
+ you may not use this file except in compliance with the License.
206
+ You may obtain a copy of the License at
207
+
208
+ http://www.apache.org/licenses/LICENSE-2.0
209
+
210
+ Unless required by applicable law or agreed to in writing, software
211
+ distributed under the License is distributed on an "AS IS" BASIS,
212
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
213
+ See the License for the specific language governing permissions and
214
+ limitations under the License.
README.md CHANGED
@@ -1,5 +1,95 @@
1
- ---
2
- license: other
3
- license_name: license.md
4
- license_link: LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>
3
+ DevOps-Model-7B-Base
4
+ </h1>
5
+ </div>
6
+
7
+ <p align="center">
8
+ 🤗 <a href="https://huggingface.co/codefuse-ai" target="_blank">Hugging Face</a> •
9
+ 🤖 <a href="https://modelscope.cn/organization/codefuse-ai" target="_blank">ModelScope</a>
10
+ </p>
11
+
12
+ DevOps-Model 是一个**开发运维大模型**,主要致力于在 DevOps 领域发挥实际价值。目前,DevOps-Model 能够帮助工程师回答在 DevOps 生命周期中遇到的问题。欢迎访问我们 Github 获取更多信息 [DevOps-Model](https://github.com/codefuse-ai/CodeFuse-DevOps-Model)
13
+
14
+ DevOps-Model-7B-Base 是我们经过高质量 DevOps 语料训练基于 Qwen-7B 加训后的 **Base** 模型。我们的 Base 模型在开源和 DevOps 领域相关的评测数据上可以取得同规模模型中的**最佳效果**。同时我们也开源了经过对齐后的 [DevOps-Model-7B-Chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-7B-Chat/summary) 模型,和 14B 参数量的[DevOps-Model-14B-Base](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-14B-Base/summary) 和 [DevOps-Model-14B-Chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-14B-Chat/summary) 。
15
+ <br>
16
+ 同时我们也在搭建 DevOps 领域专属的评测基准 [DevOpsEval](https://github.com/luban-agi/DevOps-Eval),用来更好评测 DevOps 领域模型的效果。
17
+
18
+ <br>
19
+ <br>
20
+
21
+ # 模型评测
22
+ 我们先选取了 CMMLU 和 CEval 两个评测数据集中和 DevOps 相关的一共六项考试。总计一共 574 道选择题,具体信息如下:
23
+
24
+ | 评测数据集 | 考试科目 | 题数 |
25
+ |-------|-------|-------|
26
+ | CMMLU | Computer science | 204 |
27
+ | CMMLU | Computer security | 171 |
28
+ | CMMLU | Machine learning | 122 |
29
+ | CEval | College programming | 37 |
30
+ | CEval | Computer architecture | 21 |
31
+ | CEval | Computernetwork | 19 |
32
+
33
+ 我们分别测试了 Zero-shot 和 Five-shot 的结果,我们的 DevOps-Model-7B-Base 模型可以在测试的同规模的开源 Base 模型中取得最高的成绩,后续我们也会进行更多的测试。
34
+
35
+ |模型|模型大小|Zero-shot 得分|Five-shot 得分|
36
+ |--|--|--|--|
37
+ |**DevOps-Model-7B-Base**|**7B**|**62.72**|**62.02**|
38
+ |Qwen-7B-Base|7B|55.75|56.0|
39
+ |Baichuan2-7B-Base|7B|49.30|55.4|
40
+ |Internlm-7B-Base|7B|47.56|52.6|
41
+
42
+
43
+
44
+ <br>
45
+
46
+ # 快速使用
47
+ 我们提供简单的示例来说明如何利用 🤗 Transformers 快速使用 Devops-Model-7B-Base 模型
48
+
49
+ ## 要求
50
+ - python 3.8 及以上版本
51
+ - pytorch 2.0 及以上版本
52
+ - 建议使用CUDA 11.4及以上
53
+
54
+
55
+ ## 依赖项安装
56
+ 下载模型后,直接通过以下命令安装 requirements.txt 中的包就可以
57
+ ```bash
58
+ cd path_to_download_model
59
+ pip isntall -r requirements.txt
60
+ ```
61
+
62
+ ## 模型推理示例
63
+
64
+ ```python
65
+ from transformers import AutoModelForCausalLM, AutoTokenizer
66
+ from transformers.generation import GenerationConfig
67
+
68
+ tokenizer = AutoTokenizer.from_pretrained("path_to_DevOps-Model-7B-Base", trust_remote_code=True)
69
+
70
+ model = AutoModelForCausalLM.from_pretrained("path_to_DevOps-Model-7B-Base", device_map="auto", trust_remote_code=True, bf16=True).eval()
71
+
72
+ # 指定 generation_config
73
+ model.generation_config = GenerationConfig.from_pretrained("path_to_DevOps-Model-7B-Base", trust_remote_code=True)
74
+
75
+ inputs = '''Java 中 HashMap 的实现原理是'''
76
+ input_ids = tokenizer(inputs, return_tensors='pt')
77
+ input_ids = input_ids.to(model.device)
78
+ pred = model.generate(**input_ids)
79
+
80
+ print(tokenizer.decode(pred[0]))
81
+ # Java 中 HashMap 的实现原理是数组 + 链表,数组存放的是链表中的每个节点,链表中的每个节点又存放着下一个节点的地址,从而实现了链表的遍历。当链表长度大于 8 时,链表就会转换成红黑树,从而加快了查询速度。...
82
+ ```
83
+
84
+
85
+
86
+ # 免责声明
87
+ 由于语言模型的特性,模型生成的内容可能包含幻觉或者歧视性言论。请谨慎使用 DevOps-Model 系列模型生成的内容。
88
+ 如果要公开使用或商用该模型服务,请注意服务方需承担由此产生的不良影响或有害言论的责任,本项目开发者不承担任何由使用本项目(包括但不限于数据、模型、代码等)导致的危害或损失。
89
+
90
+
91
+
92
+ # 致谢
93
+ 本项目参考了以下开源项目,在此对相关项目和研究开发人员表示感谢。
94
+ - [LLaMA-Efficient-Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning)
95
+ - [Qwen-7B](https://github.com/QwenLM/Qwen-7B/tree/main)
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "swiglu",
3
+ "apply_residual_connection_post_layernorm": false,
4
+ "architectures": [
5
+ "QWenLMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_qwen.QWenConfig",
10
+ "AutoModel": "modeling_qwen.QWenLMHeadModel",
11
+ "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
12
+ },
13
+ "bf16": true,
14
+ "bias_dropout_fusion": true,
15
+ "bos_token_id": 151643,
16
+ "embd_pdrop": 0.0,
17
+ "eos_token_id": 151643,
18
+ "ffn_hidden_size": 22016,
19
+ "fp16": false,
20
+ "fp32": false,
21
+ "initializer_range": 0.02,
22
+ "kv_channels": 128,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model_type": "qwen",
25
+ "n_embd": 4096,
26
+ "n_head": 32,
27
+ "n_inner": null,
28
+ "n_layer": 32,
29
+ "n_positions": 6144,
30
+ "no_bias": true,
31
+ "onnx_safe": null,
32
+ "padded_vocab_size": 151936,
33
+ "params_dtype": "torch.bfloat16",
34
+ "pos_emb": "rotary",
35
+ "resid_pdrop": 0.1,
36
+ "rotary_emb_base": 10000,
37
+ "rotary_pct": 1.0,
38
+ "scale_attn_weights": true,
39
+ "seq_length": 2048,
40
+ "tie_word_embeddings": false,
41
+ "tokenizer_type": "QWenTokenizer",
42
+ "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.32.0",
44
+ "use_cache": true,
45
+ "use_dynamic_ntk": true,
46
+ "use_flash_attn": true,
47
+ "use_logn_attn": true,
48
+ "vocab_size": 151936
49
+ }
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"chatbot"}
configuration_qwen.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class QWenConfig(PretrainedConfig):
10
+ model_type = "qwen"
11
+ keys_to_ignore_at_inference = ["past_key_values"]
12
+ attribute_map = {
13
+ "hidden_size": "n_embd",
14
+ "num_attention_heads": "n_head",
15
+ "max_position_embeddings": "n_positions",
16
+ "num_hidden_layers": "n_layer",
17
+ }
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size=151851,
22
+ n_embd=4096,
23
+ n_layer=32,
24
+ n_head=32,
25
+ n_inner=None,
26
+ embd_pdrop=0.0,
27
+ attn_pdrop=0.0,
28
+ layer_norm_epsilon=1e-5,
29
+ initializer_range=0.02,
30
+ scale_attn_weights=True,
31
+ use_cache=True,
32
+ eos_token_id=151643,
33
+ apply_residual_connection_post_layernorm=False,
34
+ bf16=False,
35
+ fp16=False,
36
+ fp32=False,
37
+ kv_channels=128,
38
+ rotary_pct=1.0,
39
+ rotary_emb_base=10000,
40
+ use_dynamic_ntk=False,
41
+ use_logn_attn=False,
42
+ use_flash_attn=True,
43
+ ffn_hidden_size=22016,
44
+ no_bias=True,
45
+ tie_word_embeddings=False,
46
+ **kwargs,
47
+ ):
48
+ self.eos_token_id = eos_token_id
49
+ super().__init__(
50
+ eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
51
+ )
52
+
53
+ self.vocab_size = vocab_size
54
+ self.n_embd = n_embd
55
+ self.n_layer = n_layer
56
+ self.n_head = n_head
57
+ self.n_inner = n_inner
58
+ self.embd_pdrop = embd_pdrop
59
+ self.attn_pdrop = attn_pdrop
60
+ self.layer_norm_epsilon = layer_norm_epsilon
61
+ self.initializer_range = initializer_range
62
+ self.scale_attn_weights = scale_attn_weights
63
+ self.use_cache = use_cache
64
+ self.apply_residual_connection_post_layernorm = (
65
+ apply_residual_connection_post_layernorm
66
+ )
67
+ self.bf16 = bf16
68
+ self.fp16 = fp16
69
+ self.fp32 = fp32
70
+ self.kv_channels = kv_channels
71
+ self.rotary_pct = rotary_pct
72
+ self.rotary_emb_base = rotary_emb_base
73
+ self.use_dynamic_ntk = use_dynamic_ntk
74
+ self.use_logn_attn = use_logn_attn
75
+ self.use_flash_attn = use_flash_attn
76
+ self.ffn_hidden_size = ffn_hidden_size
77
+ self.no_bias = no_bias
78
+ self.tie_word_embeddings = tie_word_embeddings
generation_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chat_format": "raw",
3
+ "do_sample": true,
4
+ "eos_token_id": 151643,
5
+ "max_new_tokens": 128,
6
+ "pad_token_id": 151643,
7
+ "stop_words_ids": [
8
+ [
9
+ 151643
10
+ ]
11
+ ],
12
+ "top_k": 0,
13
+ "top_p": 0.8,
14
+ "transformers_version": "4.32.0"
15
+ }
modeling_qwen.py ADDED
@@ -0,0 +1,1219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import math
8
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.cuda.amp import autocast
14
+
15
+ from torch.nn import CrossEntropyLoss
16
+ from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
+ from transformers.generation.logits_process import LogitsProcessorList
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.generation.streamers import BaseStreamer
21
+ from transformers.generation.utils import GenerateOutput
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+
29
+ try:
30
+ from einops import rearrange
31
+ except ImportError:
32
+ rearrange = None
33
+ from torch import nn
34
+
35
+ SUPPORT_CUDA = torch.cuda.is_available()
36
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
+ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
38
+
39
+ from .configuration_qwen import QWenConfig
40
+ from .qwen_generation_utils import (
41
+ HistoryType,
42
+ make_context,
43
+ decode_tokens,
44
+ get_stop_words_ids,
45
+ StopWordsLogitsProcessor,
46
+ )
47
+
48
+ # from loguru import logger
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "qwen"
52
+ _CONFIG_FOR_DOC = "QWenConfig"
53
+
54
+ QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
55
+
56
+ _ERROR_BAD_CHAT_FORMAT = """\
57
+ We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
58
+ If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
59
+ 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
60
+ 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
61
+ """
62
+
63
+ _SENTINEL = object()
64
+ _ERROR_STREAM_IN_CHAT = """\
65
+ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
66
+ 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
67
+ """
68
+
69
+ apply_rotary_emb_func = None
70
+ rms_norm = None
71
+ flash_attn_unpadded_func = None
72
+
73
+
74
+ def _import_flash_attn():
75
+ global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
76
+ try:
77
+ from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
78
+ apply_rotary_emb_func = __apply_rotary_emb_func
79
+ print('Using flash_attn rope')
80
+ except ImportError:
81
+ logger.warn(
82
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
83
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
84
+ )
85
+
86
+ try:
87
+ from flash_attn.ops.rms_norm import rms_norm as __rms_norm
88
+ rms_norm = __rms_norm
89
+ print('Using flash_attn rms_norm')
90
+ except ImportError:
91
+ logger.warn(
92
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
93
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
94
+ )
95
+
96
+ try:
97
+ import flash_attn
98
+ if not hasattr(flash_attn, '__version__'):
99
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
100
+ else:
101
+ if int(flash_attn.__version__.split(".")[0]) >= 2:
102
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
103
+ else:
104
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
105
+ flash_attn_unpadded_func = __flash_attn_unpadded_func
106
+
107
+ print('Using flash_attn attention func')
108
+ except ImportError:
109
+ logger.warn(
110
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
111
+ "https://github.com/Dao-AILab/flash-attention"
112
+ )
113
+
114
+
115
+ class FlashSelfAttention(torch.nn.Module):
116
+ def __init__(
117
+ self,
118
+ causal=False,
119
+ softmax_scale=None,
120
+ attention_dropout=0.0,
121
+ ):
122
+ super().__init__()
123
+ assert flash_attn_unpadded_func is not None, (
124
+ "Please install FlashAttention first, " "e.g., with pip install flash-attn"
125
+ )
126
+ assert (
127
+ rearrange is not None
128
+ ), "Please install einops first, e.g., with pip install einops"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.dropout_p = attention_dropout
132
+
133
+ def forward(self, q, k, v):
134
+ assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
135
+ assert all((i.is_cuda for i in (q, k, v)))
136
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
137
+ seqlen_k = k.shape[1]
138
+ q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
139
+ cu_seqlens_q = torch.arange(
140
+ 0,
141
+ (batch_size + 1) * seqlen_q,
142
+ step=seqlen_q,
143
+ dtype=torch.int32,
144
+ device=q.device,
145
+ )
146
+
147
+ if self.training:
148
+ assert seqlen_k == seqlen_q
149
+
150
+ is_causal = self.causal
151
+ cu_seqlens_k = cu_seqlens_q
152
+ else:
153
+ is_causal = seqlen_q == seqlen_k
154
+ cu_seqlens_k = torch.arange(
155
+ 0,
156
+ (batch_size + 1) * seqlen_k,
157
+ step=seqlen_k,
158
+ dtype=torch.int32,
159
+ device=q.device,
160
+ )
161
+ self.dropout_p = 0
162
+ output = flash_attn_unpadded_func(
163
+ q,
164
+ k,
165
+ v,
166
+ cu_seqlens_q,
167
+ cu_seqlens_k,
168
+ seqlen_q,
169
+ seqlen_k,
170
+ self.dropout_p,
171
+ softmax_scale=self.softmax_scale,
172
+ causal=is_causal,
173
+ )
174
+
175
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
176
+ return output
177
+
178
+
179
+ class QWenAttention(nn.Module):
180
+ def __init__(self, config, layer_number=None):
181
+ super().__init__()
182
+
183
+ max_positions = config.max_position_embeddings
184
+ self.register_buffer(
185
+ "bias",
186
+ torch.tril(
187
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
188
+ ).view(1, 1, max_positions, max_positions),
189
+ persistent=False,
190
+ )
191
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
192
+ self.layer_number = max(1, layer_number)
193
+ self.params_dtype = config.params_dtype
194
+ self.seq_length = config.seq_length
195
+
196
+ self.hidden_size = config.hidden_size
197
+ self.split_size = config.hidden_size
198
+ self.num_heads = config.num_attention_heads
199
+ self.head_dim = self.hidden_size // self.num_heads
200
+
201
+ self.use_flash_attn = config.use_flash_attn
202
+ self.scale_attn_weights = True
203
+
204
+ self.layer_idx = None
205
+
206
+ self.projection_size = config.kv_channels * config.num_attention_heads
207
+
208
+ assert self.projection_size % config.num_attention_heads == 0
209
+ self.hidden_size_per_attention_head = (
210
+ self.projection_size // config.num_attention_heads
211
+ )
212
+
213
+ self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
214
+
215
+ self.c_proj = nn.Linear(
216
+ config.hidden_size, self.projection_size, bias=not config.no_bias
217
+ )
218
+
219
+ self.is_fp32 = not (config.bf16 or config.fp16)
220
+ if (
221
+ self.use_flash_attn
222
+ and flash_attn_unpadded_func is not None
223
+ and not self.is_fp32
224
+ ):
225
+ self.core_attention_flash = FlashSelfAttention(
226
+ causal=True, attention_dropout=config.attn_pdrop
227
+ )
228
+
229
+ self.bf16 = config.bf16
230
+
231
+ if config.rotary_pct == 1.0:
232
+ self.rotary_ndims = None
233
+ else:
234
+ assert config.rotary_pct < 1
235
+ self.rotary_ndims = int(
236
+ self.hidden_size_per_attention_head * config.rotary_pct
237
+ )
238
+ dim = (
239
+ self.rotary_ndims
240
+ if self.rotary_ndims is not None
241
+ else self.hidden_size_per_attention_head
242
+ )
243
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
244
+
245
+ self.use_dynamic_ntk = config.use_dynamic_ntk
246
+ self.use_logn_attn = config.use_logn_attn
247
+
248
+ logn_list = [
249
+ math.log(i, self.seq_length) if i > self.seq_length else 1
250
+ for i in range(1, 32768)
251
+ ]
252
+ self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
253
+ self._ntk_cached = 1.0
254
+
255
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
256
+
257
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
258
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
259
+
260
+ if self.scale_attn_weights:
261
+ attn_weights = attn_weights / torch.full(
262
+ [],
263
+ value.size(-1) ** 0.5,
264
+ dtype=attn_weights.dtype,
265
+ device=attn_weights.device,
266
+ )
267
+
268
+ query_length, key_length = query.size(-2), key.size(-2)
269
+ causal_mask = self.bias[
270
+ :, :, key_length - query_length : key_length, :key_length
271
+ ]
272
+ mask_value = torch.finfo(attn_weights.dtype).min
273
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
274
+ attn_weights.device
275
+ )
276
+ attn_weights = torch.where(
277
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
278
+ )
279
+
280
+ if attention_mask is not None:
281
+ # Apply the attention mask
282
+ attn_weights = attn_weights + attention_mask
283
+
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
285
+
286
+ attn_weights = attn_weights.type(value.dtype)
287
+ attn_weights = self.attn_dropout(attn_weights)
288
+
289
+ if head_mask is not None:
290
+ attn_weights = attn_weights * head_mask
291
+
292
+ attn_output = torch.matmul(attn_weights, value)
293
+ attn_output = attn_output.transpose(1, 2)
294
+
295
+ return attn_output, attn_weights
296
+
297
+ def _upcast_and_reordered_attn(
298
+ self, query, key, value, attention_mask=None, head_mask=None
299
+ ):
300
+ bsz, num_heads, q_seq_len, dk = query.size()
301
+ _, _, k_seq_len, _ = key.size()
302
+
303
+ attn_weights = torch.empty(
304
+ bsz * num_heads,
305
+ q_seq_len,
306
+ k_seq_len,
307
+ dtype=torch.float32,
308
+ device=query.device,
309
+ )
310
+
311
+ scale_factor = 1.0
312
+ if self.scale_attn_weights:
313
+ scale_factor /= float(value.size(-1)) ** 0.5
314
+
315
+ with autocast(enabled=False):
316
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
317
+ -1, dk, k_seq_len
318
+ )
319
+ attn_weights = torch.baddbmm(
320
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
321
+ )
322
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
323
+
324
+ query_length, key_length = query.size(-2), key.size(-2)
325
+ causal_mask = self.bias[
326
+ :, :, key_length - query_length : key_length, :key_length
327
+ ]
328
+ mask_value = torch.finfo(attn_weights.dtype).min
329
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
330
+ attn_weights.device
331
+ )
332
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
333
+
334
+ if attention_mask is not None:
335
+ attn_weights = attn_weights + attention_mask
336
+
337
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
338
+
339
+ if attn_weights.dtype != torch.float32:
340
+ raise RuntimeError(
341
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
342
+ )
343
+ attn_weights = attn_weights.type(value.dtype)
344
+ attn_weights = self.attn_dropout(attn_weights)
345
+
346
+ if head_mask is not None:
347
+ attn_weights = attn_weights * head_mask
348
+
349
+ attn_output = torch.matmul(attn_weights, value)
350
+
351
+ return attn_output, attn_weights
352
+
353
+ def _split_heads(self, tensor, num_heads, attn_head_size):
354
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
355
+ tensor = tensor.view(new_shape)
356
+ return tensor
357
+
358
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
359
+ tensor = tensor.contiguous()
360
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
361
+ return tensor.view(new_shape)
362
+
363
+ def forward(
364
+ self,
365
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
366
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
367
+ attention_mask: Optional[torch.FloatTensor] = None,
368
+ head_mask: Optional[torch.FloatTensor] = None,
369
+ encoder_hidden_states: Optional[torch.Tensor] = None,
370
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
371
+ output_attentions: Optional[bool] = False,
372
+ use_cache: Optional[bool] = False,
373
+ ):
374
+
375
+ mixed_x_layer = self.c_attn(hidden_states)
376
+ query, key, value = mixed_x_layer.split(self.split_size, dim=2)
377
+
378
+ query = self._split_heads(query, self.num_heads, self.head_dim)
379
+ key = self._split_heads(key, self.num_heads, self.head_dim)
380
+ value = self._split_heads(value, self.num_heads, self.head_dim)
381
+
382
+ kv_seq_len = hidden_states.size()[1]
383
+ if layer_past:
384
+ # layer past[0] shape: bs * seq_len * head_num * dim
385
+ kv_seq_len += layer_past[0].shape[1]
386
+ if (
387
+ self.use_dynamic_ntk
388
+ and kv_seq_len == hidden_states.size()[1]
389
+ and not self.training
390
+ ):
391
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
392
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
393
+ ntk_alpha = max(ntk_alpha, 1)
394
+ self._ntk_cached = ntk_alpha
395
+ else:
396
+ ntk_alpha = self._ntk_cached
397
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
398
+ hidden_states.device
399
+ )
400
+
401
+ if rotary_pos_emb is not None:
402
+ if isinstance(rotary_pos_emb, tuple):
403
+ rotary_pos_emb = rotary_pos_emb
404
+ else:
405
+ rotary_pos_emb = (rotary_pos_emb,) * 2
406
+
407
+ if rotary_pos_emb is not None:
408
+ q_pos_emb, k_pos_emb = rotary_pos_emb
409
+ # Slice the pos emb for current inference
410
+ cur_len = query.shape[1]
411
+ q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
412
+ k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
413
+ query = apply_rotary_pos_emb(query, q_pos_emb)
414
+ key = apply_rotary_pos_emb(key, k_pos_emb)
415
+
416
+ if layer_past is not None:
417
+ past_key, past_value = layer_past[0], layer_past[1]
418
+ key = torch.cat((past_key, key), dim=1)
419
+ value = torch.cat((past_value, value), dim=1)
420
+
421
+ if use_cache:
422
+ present = (key, value)
423
+ else:
424
+ present = None
425
+
426
+ if self.use_logn_attn and not self.training:
427
+ if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
428
+ self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
429
+ seq_start = key.size(1) - query.size(1)
430
+ seq_end = key.size(1)
431
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
432
+ query = query * logn_tensor.expand_as(query)
433
+
434
+ if (
435
+ self.use_flash_attn
436
+ and flash_attn_unpadded_func is not None
437
+ and not self.is_fp32
438
+ and query.is_cuda
439
+ ):
440
+ q, k, v = query, key, value
441
+ context_layer = self.core_attention_flash(q, k, v)
442
+
443
+ context_layer = rearrange(
444
+ context_layer, "b s h d -> b s (h d)"
445
+ ).contiguous()
446
+ else:
447
+ query = query.permute(0, 2, 1, 3)
448
+ key = key.permute(0, 2, 1, 3)
449
+ value = value.permute(0, 2, 1, 3)
450
+ attn_output, attn_weight = self._attn(
451
+ query, key, value, attention_mask, head_mask
452
+ )
453
+ context_layer = self._merge_heads(
454
+ attn_output, self.num_heads, self.head_dim
455
+ )
456
+
457
+ attn_output = self.c_proj(context_layer)
458
+ outputs = (attn_output, present)
459
+ if output_attentions:
460
+ if (
461
+ self.use_flash_attn
462
+ and flash_attn_unpadded_func is not None
463
+ and not self.is_fp32
464
+ ):
465
+ raise ValueError("Cannot output attentions while using flash-attn")
466
+ else:
467
+ outputs += (attn_weight,)
468
+
469
+ return outputs
470
+
471
+
472
+ class QWenMLP(nn.Module):
473
+ def __init__(self, config):
474
+ super().__init__()
475
+ self.w1 = nn.Linear(
476
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
477
+ )
478
+ self.w2 = nn.Linear(
479
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
480
+ )
481
+ ff_dim_in = config.ffn_hidden_size // 2
482
+ self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
483
+
484
+ def forward(self, hidden_states):
485
+ a1 = self.w1(hidden_states)
486
+ a2 = self.w2(hidden_states)
487
+ intermediate_parallel = a1 * F.silu(a2)
488
+ output = self.c_proj(intermediate_parallel)
489
+ return output
490
+
491
+
492
+ class QWenBlock(nn.Module):
493
+ def __init__(self, config, layer_idx=None, num_expert=1):
494
+ super().__init__()
495
+ self.num_expert = num_expert
496
+ self.layer_number = layer_idx
497
+ self.apply_residual_connection_post_layernorm = (
498
+ config.apply_residual_connection_post_layernorm
499
+ )
500
+ hidden_size = config.hidden_size
501
+ self.apply_residual_connection_post_layernorm = (
502
+ config.apply_residual_connection_post_layernorm
503
+ )
504
+ self.bf16 = config.bf16
505
+
506
+ self.ln_1 = RMSNorm(
507
+ hidden_size,
508
+ eps=config.layer_norm_epsilon,
509
+ )
510
+ self.attn = QWenAttention(config, layer_number=layer_idx)
511
+ self.ln_2 = RMSNorm(
512
+ hidden_size,
513
+ eps=config.layer_norm_epsilon,
514
+ )
515
+
516
+ self.mlp = QWenMLP(config)
517
+
518
+ def forward(
519
+ self,
520
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
521
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
522
+ attention_mask: Optional[torch.FloatTensor] = None,
523
+ head_mask: Optional[torch.FloatTensor] = None,
524
+ encoder_hidden_states: Optional[torch.Tensor] = None,
525
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
526
+ use_cache: Optional[bool] = False,
527
+ output_attentions: Optional[bool] = False,
528
+ ):
529
+ layernorm_output = self.ln_1(hidden_states)
530
+
531
+ attn_outputs = self.attn(
532
+ layernorm_output,
533
+ layer_past=layer_past,
534
+ attention_mask=attention_mask,
535
+ head_mask=head_mask,
536
+ use_cache=use_cache,
537
+ output_attentions=output_attentions,
538
+ )
539
+ attn_output = attn_outputs[0]
540
+
541
+ outputs = attn_outputs[1:]
542
+
543
+ if self.apply_residual_connection_post_layernorm:
544
+ residual = layernorm_output
545
+ else:
546
+ residual = hidden_states
547
+ layernorm_input = attn_output + residual
548
+
549
+ layernorm_output = self.ln_2(layernorm_input)
550
+
551
+ if self.apply_residual_connection_post_layernorm:
552
+ residual = layernorm_output
553
+ else:
554
+ residual = layernorm_input
555
+
556
+ mlp_output = self.mlp(layernorm_output)
557
+ hidden_states = residual + mlp_output
558
+
559
+ if use_cache:
560
+ outputs = (hidden_states,) + outputs
561
+ else:
562
+ outputs = (hidden_states,) + outputs[1:]
563
+
564
+ return outputs
565
+
566
+
567
+ class QWenPreTrainedModel(PreTrainedModel):
568
+ config_class = QWenConfig
569
+ base_model_prefix = "transformer"
570
+ is_parallelizable = False
571
+ supports_gradient_checkpointing = True
572
+ _no_split_modules = ["QWenBlock"]
573
+
574
+ def __init__(self, *inputs, **kwargs):
575
+ super().__init__(*inputs, **kwargs)
576
+
577
+ def _init_weights(self, module):
578
+ """Initialize the weights."""
579
+ if isinstance(module, nn.Linear):
580
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
581
+ if module.bias is not None:
582
+ module.bias.data.zero_()
583
+ elif isinstance(module, nn.Embedding):
584
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
585
+ if module.padding_idx is not None:
586
+ module.weight.data[module.padding_idx].zero_()
587
+ elif isinstance(module, RMSNorm):
588
+ module.weight.data.fill_(1.0)
589
+
590
+ for name, p in module.named_parameters():
591
+ if name == "c_proj.weight":
592
+ p.data.normal_(
593
+ mean=0.0,
594
+ std=(
595
+ self.config.initializer_range
596
+ / math.sqrt(2 * self.config.n_layer)
597
+ ),
598
+ )
599
+
600
+ def _set_gradient_checkpointing(self, module, value=False):
601
+ if isinstance(module, QWenModel):
602
+ module.gradient_checkpointing = value
603
+
604
+
605
+ class QWenModel(QWenPreTrainedModel):
606
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
607
+
608
+ def __init__(self, config):
609
+ super().__init__(config)
610
+ self.vocab_size = config.padded_vocab_size
611
+ self.num_hidden_layers = config.num_hidden_layers
612
+ self.embed_dim = config.hidden_size
613
+
614
+ max_sequence_length = config.max_position_embeddings
615
+ self.position_embedding_type = config.pos_emb
616
+ self.gradient_checkpointing = False
617
+
618
+ if self.position_embedding_type == "learned":
619
+ self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
620
+ self.init_method(self.position_embeddings.weight)
621
+ self._position_embeddings_key = "position_embeddings"
622
+ self.init_method(self.position_embeddings.weight)
623
+ else:
624
+ self.wpe = None
625
+ self._position_embeddings_key = ""
626
+
627
+ self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
628
+
629
+ self.drop = nn.Dropout(config.embd_pdrop)
630
+ self.h = nn.ModuleList(
631
+ [
632
+ QWenBlock(
633
+ config,
634
+ layer_idx=i,
635
+ )
636
+ for i in range(config.num_hidden_layers)
637
+ ]
638
+ )
639
+ self.ln_f = RMSNorm(
640
+ self.embed_dim,
641
+ eps=config.layer_norm_epsilon,
642
+ )
643
+
644
+ self.post_init()
645
+
646
+ def get_input_embeddings(self):
647
+ return self.wte
648
+
649
+ def set_input_embeddings(self, new_embeddings):
650
+ self.wte = new_embeddings
651
+
652
+ def forward(
653
+ self,
654
+ input_ids: Optional[torch.LongTensor] = None,
655
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
656
+ attention_mask: Optional[torch.FloatTensor] = None,
657
+ token_type_ids: Optional[torch.LongTensor] = None,
658
+ position_ids: Optional[torch.LongTensor] = None,
659
+ head_mask: Optional[torch.FloatTensor] = None,
660
+ inputs_embeds: Optional[torch.FloatTensor] = None,
661
+ encoder_hidden_states: Optional[torch.Tensor] = None,
662
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
663
+ use_cache: Optional[bool] = None,
664
+ output_attentions: Optional[bool] = None,
665
+ output_hidden_states: Optional[bool] = None,
666
+ return_dict: Optional[bool] = None,
667
+ ):
668
+ output_attentions = (
669
+ output_attentions
670
+ if output_attentions is not None
671
+ else self.config.output_attentions
672
+ )
673
+ output_hidden_states = (
674
+ output_hidden_states
675
+ if output_hidden_states is not None
676
+ else self.config.output_hidden_states
677
+ )
678
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
679
+ return_dict = (
680
+ return_dict if return_dict is not None else self.config.use_return_dict
681
+ )
682
+
683
+ if input_ids is not None and inputs_embeds is not None:
684
+ raise ValueError(
685
+ "You cannot specify both input_ids and inputs_embeds at the same time"
686
+ )
687
+ elif input_ids is not None:
688
+ input_shape = input_ids.size()
689
+ input_ids = input_ids.view(-1, input_shape[-1])
690
+ batch_size = input_ids.shape[0]
691
+ elif inputs_embeds is not None:
692
+ input_shape = inputs_embeds.size()[:-1]
693
+ batch_size = inputs_embeds.shape[0]
694
+ else:
695
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
696
+
697
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
698
+
699
+ if token_type_ids is not None:
700
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
701
+ if position_ids is not None:
702
+ position_ids = position_ids.view(-1, input_shape[-1])
703
+
704
+ if past_key_values is None:
705
+ past_length = 0
706
+ past_key_values = tuple([None] * len(self.h))
707
+ else:
708
+ past_length = past_key_values[0][0].size(-2)
709
+
710
+ if position_ids is None:
711
+ position_ids = torch.arange(
712
+ past_length,
713
+ input_shape[-1] + past_length,
714
+ dtype=torch.long,
715
+ device=device,
716
+ )
717
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
718
+
719
+ if attention_mask is not None:
720
+ if batch_size <= 0:
721
+ raise ValueError("batch_size has to be defined and > 0")
722
+ attention_mask = attention_mask.view(batch_size, -1)
723
+ attention_mask = attention_mask[:, None, None, :]
724
+ attention_mask = attention_mask.to(dtype=self.dtype)
725
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
726
+ # attention_mask中mask掉的部分是-inf, 看到的部分是0
727
+
728
+ encoder_attention_mask = None
729
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
730
+
731
+ if inputs_embeds is None:
732
+ inputs_embeds = self.wte(input_ids)
733
+ hidden_states = inputs_embeds
734
+ if self.wpe is not None:
735
+ position_embeds = self.wpe(position_ids)
736
+ hidden_states = hidden_states + position_embeds
737
+
738
+ hidden_states = self.drop(hidden_states)
739
+ output_shape = input_shape + (hidden_states.size(-1),)
740
+
741
+ if self.gradient_checkpointing and self.training:
742
+ if use_cache:
743
+ logger.warning_once(
744
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
745
+ )
746
+ use_cache = False
747
+
748
+ presents = () if use_cache else None
749
+ all_self_attentions = () if output_attentions else None
750
+ all_hidden_states = () if output_hidden_states else None
751
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
752
+
753
+ if output_hidden_states:
754
+ all_hidden_states = all_hidden_states + (hidden_states,)
755
+
756
+ if self.gradient_checkpointing and self.training:
757
+
758
+ def create_custom_forward(module):
759
+ def custom_forward(*inputs):
760
+ # None for past_key_value
761
+ return module(*inputs, use_cache, output_attentions)
762
+
763
+ return custom_forward
764
+
765
+ outputs = torch.utils.checkpoint.checkpoint(
766
+ create_custom_forward(block),
767
+ hidden_states,
768
+ None,
769
+ attention_mask,
770
+ head_mask[i],
771
+ encoder_hidden_states,
772
+ encoder_attention_mask,
773
+ )
774
+ else:
775
+ outputs = block(
776
+ hidden_states,
777
+ layer_past=layer_past,
778
+ attention_mask=attention_mask,
779
+ head_mask=head_mask[i],
780
+ encoder_hidden_states=encoder_hidden_states,
781
+ encoder_attention_mask=encoder_attention_mask,
782
+ use_cache=use_cache,
783
+ output_attentions=output_attentions,
784
+ )
785
+
786
+ hidden_states = outputs[0]
787
+ if use_cache is True:
788
+ presents = presents + (outputs[2 if output_attentions else 1],)
789
+
790
+ if output_attentions:
791
+ all_self_attentions = all_self_attentions + (outputs[1],)
792
+
793
+ hidden_states = self.ln_f(hidden_states)
794
+ hidden_states = hidden_states.view(output_shape)
795
+
796
+ if not return_dict:
797
+ return tuple(
798
+ v for v in [hidden_states, presents, all_hidden_states] if v is not None
799
+ )
800
+
801
+ return BaseModelOutputWithPast(
802
+ last_hidden_state=hidden_states,
803
+ past_key_values=presents,
804
+ hidden_states=all_hidden_states,
805
+ attentions=all_self_attentions,
806
+ )
807
+
808
+
809
+ class QWenLMHeadModel(QWenPreTrainedModel):
810
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
811
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
812
+
813
+ def __init__(self, config):
814
+ super().__init__(config)
815
+ assert (
816
+ config.bf16 + config.fp16 + config.fp32 <= 1
817
+ ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
818
+
819
+ autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
820
+
821
+ if autoset_precision:
822
+ if SUPPORT_BF16:
823
+ logger.warn(
824
+ "The model is automatically converting to bf16 for faster inference. "
825
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
826
+ )
827
+ config.bf16 = True
828
+ elif SUPPORT_FP16:
829
+ logger.warn(
830
+ "The model is automatically converting to fp16 for faster inference. "
831
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
832
+ )
833
+ config.fp16 = True
834
+ else:
835
+ config.fp32 = True
836
+
837
+ if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
838
+ logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
839
+ if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
840
+ logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
841
+ if config.fp32:
842
+ if SUPPORT_BF16:
843
+ logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
844
+ elif SUPPORT_FP16:
845
+ logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
846
+
847
+ if config.use_flash_attn == "auto":
848
+ if config.bf16 or config.fp16:
849
+ logger.warn("Try importing flash-attention for faster inference...")
850
+ config.use_flash_attn = True
851
+ else:
852
+ config.use_flash_attn = False
853
+ if config.use_flash_attn and config.fp32:
854
+ logger.warn("Flash attention will be disabled because it does NOT support fp32.")
855
+
856
+ if config.use_flash_attn:
857
+ _import_flash_attn()
858
+
859
+ self.transformer = QWenModel(config)
860
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
861
+
862
+ if config.bf16:
863
+ self.transformer.bfloat16()
864
+ self.lm_head.bfloat16()
865
+ if config.fp16:
866
+ self.transformer.half()
867
+ self.lm_head.half()
868
+ self.post_init()
869
+
870
+ def get_output_embeddings(self):
871
+ return self.lm_head
872
+
873
+ def set_output_embeddings(self, new_embeddings):
874
+ self.lm_head = new_embeddings
875
+
876
+ def prepare_inputs_for_generation(
877
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
878
+ ):
879
+ token_type_ids = kwargs.get("token_type_ids", None)
880
+ if past_key_values:
881
+ input_ids = input_ids[:, -1].unsqueeze(-1)
882
+ if token_type_ids is not None:
883
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
884
+
885
+ attention_mask = kwargs.get("attention_mask", None)
886
+ position_ids = kwargs.get("position_ids", None)
887
+
888
+ if attention_mask is not None and position_ids is None:
889
+ position_ids = attention_mask.long().cumsum(-1) - 1
890
+ position_ids.masked_fill_(attention_mask == 0, 1)
891
+ if past_key_values:
892
+ position_ids = position_ids[:, -1].unsqueeze(-1)
893
+ else:
894
+ position_ids = None
895
+
896
+ if inputs_embeds is not None and past_key_values is None:
897
+ model_inputs = {"inputs_embeds": inputs_embeds}
898
+ else:
899
+ model_inputs = {"input_ids": input_ids}
900
+
901
+ model_inputs.update(
902
+ {
903
+ "past_key_values": past_key_values,
904
+ "use_cache": kwargs.get("use_cache"),
905
+ "position_ids": position_ids,
906
+ "attention_mask": attention_mask,
907
+ "token_type_ids": token_type_ids,
908
+ }
909
+ )
910
+ return model_inputs
911
+
912
+ def forward(
913
+ self,
914
+ input_ids: Optional[torch.LongTensor] = None,
915
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
916
+ attention_mask: Optional[torch.FloatTensor] = None,
917
+ token_type_ids: Optional[torch.LongTensor] = None,
918
+ position_ids: Optional[torch.LongTensor] = None,
919
+ head_mask: Optional[torch.FloatTensor] = None,
920
+ inputs_embeds: Optional[torch.FloatTensor] = None,
921
+ encoder_hidden_states: Optional[torch.Tensor] = None,
922
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
923
+ labels: Optional[torch.LongTensor] = None,
924
+ use_cache: Optional[bool] = None,
925
+ output_attentions: Optional[bool] = None,
926
+ output_hidden_states: Optional[bool] = None,
927
+ return_dict: Optional[bool] = None,
928
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
929
+
930
+ return_dict = (
931
+ return_dict if return_dict is not None else self.config.use_return_dict
932
+ )
933
+
934
+ transformer_outputs = self.transformer(
935
+ input_ids,
936
+ past_key_values=past_key_values,
937
+ attention_mask=attention_mask,
938
+ token_type_ids=token_type_ids,
939
+ position_ids=position_ids,
940
+ head_mask=head_mask,
941
+ inputs_embeds=inputs_embeds,
942
+ encoder_hidden_states=encoder_hidden_states,
943
+ encoder_attention_mask=encoder_attention_mask,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ )
949
+ hidden_states = transformer_outputs[0]
950
+
951
+ lm_logits = self.lm_head(hidden_states)
952
+
953
+ loss = None
954
+ if labels is not None:
955
+ labels = labels.to(lm_logits.device)
956
+ shift_logits = lm_logits[..., :-1, :].contiguous()
957
+ shift_labels = labels[..., 1:].contiguous()
958
+ loss_fct = CrossEntropyLoss()
959
+ loss = loss_fct(
960
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
961
+ )
962
+
963
+ if not return_dict:
964
+ output = (lm_logits,) + transformer_outputs[1:]
965
+ return ((loss,) + output) if loss is not None else output
966
+
967
+ return CausalLMOutputWithPast(
968
+ loss=loss,
969
+ logits=lm_logits,
970
+ past_key_values=transformer_outputs.past_key_values,
971
+ hidden_states=transformer_outputs.hidden_states,
972
+ attentions=transformer_outputs.attentions,
973
+ )
974
+
975
+ @staticmethod
976
+ def _reorder_cache(
977
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
978
+ ) -> Tuple[Tuple[torch.Tensor]]:
979
+
980
+ return tuple(
981
+ tuple(
982
+ past_state.index_select(0, beam_idx.to(past_state.device))
983
+ for past_state in layer_past
984
+ )
985
+ for layer_past in past_key_values
986
+ )
987
+
988
+ def chat(
989
+ self,
990
+ tokenizer: PreTrainedTokenizer,
991
+ query: str,
992
+ history: Optional[HistoryType],
993
+ system: str = "You are a helpful assistant.",
994
+ append_history: bool = True,
995
+ stream: Optional[bool] = _SENTINEL,
996
+ stop_words_ids: Optional[List[List[int]]] = None,
997
+ **kwargs,
998
+ ) -> Tuple[str, HistoryType]:
999
+ assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
1000
+ assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1001
+ if history is None:
1002
+ history = []
1003
+ if stop_words_ids is None:
1004
+ stop_words_ids = []
1005
+
1006
+ raw_text, context_tokens = make_context(
1007
+ tokenizer,
1008
+ query,
1009
+ history=history,
1010
+ system=system,
1011
+ max_window_size=6144,
1012
+ chat_format=self.generation_config.chat_format,
1013
+ )
1014
+
1015
+ stop_words_ids.extend(get_stop_words_ids(
1016
+ self.generation_config.chat_format, tokenizer
1017
+ ))
1018
+ input_ids = torch.tensor([context_tokens]).to(self.device)
1019
+ outputs = self.generate(
1020
+ input_ids,
1021
+ stop_words_ids = stop_words_ids,
1022
+ return_dict_in_generate = False,
1023
+ **kwargs,
1024
+ )
1025
+
1026
+ response = decode_tokens(
1027
+ outputs[0],
1028
+ tokenizer,
1029
+ raw_text_len=len(raw_text),
1030
+ context_length=len(context_tokens),
1031
+ chat_format=self.generation_config.chat_format,
1032
+ verbose=False,
1033
+ errors='replace'
1034
+ )
1035
+
1036
+ if append_history:
1037
+ history.append((query, response))
1038
+
1039
+ return response, history
1040
+
1041
+ def chat_stream(
1042
+ self,
1043
+ tokenizer: PreTrainedTokenizer,
1044
+ query: str,
1045
+ history: Optional[HistoryType],
1046
+ system: str = "You are a helpful assistant.",
1047
+ stop_words_ids: Optional[List[List[int]]] = None,
1048
+ logits_processor: Optional[LogitsProcessorList] = None,
1049
+ **kwargs,
1050
+ ) -> Generator[str, Any, None]:
1051
+ assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1052
+ if history is None:
1053
+ history = []
1054
+ if stop_words_ids is None:
1055
+ stop_words_ids = []
1056
+
1057
+ raw_text, context_tokens = make_context(
1058
+ tokenizer,
1059
+ query,
1060
+ history=history,
1061
+ system=system,
1062
+ max_window_size=6144,
1063
+ chat_format=self.generation_config.chat_format,
1064
+ )
1065
+
1066
+ stop_words_ids.extend(get_stop_words_ids(
1067
+ self.generation_config.chat_format, tokenizer
1068
+ ))
1069
+ if stop_words_ids is not None:
1070
+ stop_words_logits_processor = StopWordsLogitsProcessor(
1071
+ stop_words_ids=stop_words_ids,
1072
+ eos_token_id=self.generation_config.eos_token_id,
1073
+ )
1074
+ if logits_processor is None:
1075
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
1076
+ else:
1077
+ logits_processor.append(stop_words_logits_processor)
1078
+ input_ids = torch.tensor([context_tokens]).to(self.device)
1079
+
1080
+ from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
1081
+ self.__class__.generate_stream = NewGenerationMixin.generate
1082
+ self.__class__.sample_stream = NewGenerationMixin.sample_stream
1083
+ stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
1084
+ def stream_generator():
1085
+ outputs = []
1086
+ for token in self.generate_stream(
1087
+ input_ids,
1088
+ return_dict_in_generate=False,
1089
+ generation_config=stream_config,
1090
+ logits_processor=logits_processor,
1091
+ seed=-1,
1092
+ **kwargs):
1093
+ outputs.append(token.item())
1094
+ yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
1095
+
1096
+ return stream_generator()
1097
+
1098
+ def generate(
1099
+ self,
1100
+ inputs: Optional[torch.Tensor] = None,
1101
+ generation_config: Optional[GenerationConfig] = None,
1102
+ logits_processor: Optional[LogitsProcessorList] = None,
1103
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1104
+ prefix_allowed_tokens_fn: Optional[
1105
+ Callable[[int, torch.Tensor], List[int]]
1106
+ ] = None,
1107
+ synced_gpus: Optional[bool] = None,
1108
+ assistant_model: Optional["PreTrainedModel"] = None,
1109
+ streamer: Optional["BaseStreamer"] = None,
1110
+ **kwargs,
1111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
1112
+ # Process stop_words_ids.
1113
+ stop_words_ids = kwargs.pop("stop_words_ids", None)
1114
+ if stop_words_ids is None and generation_config is not None:
1115
+ stop_words_ids = getattr(generation_config, "stop_words_ids", None)
1116
+ if stop_words_ids is None:
1117
+ stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
1118
+
1119
+ if stop_words_ids is not None:
1120
+ stop_words_logits_processor = StopWordsLogitsProcessor(
1121
+ stop_words_ids=stop_words_ids,
1122
+ eos_token_id=self.generation_config.eos_token_id,
1123
+ )
1124
+ if logits_processor is None:
1125
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
1126
+ else:
1127
+ logits_processor.append(stop_words_logits_processor)
1128
+
1129
+ return super().generate(
1130
+ inputs,
1131
+ generation_config=generation_config,
1132
+ logits_processor=logits_processor,
1133
+ stopping_criteria=stopping_criteria,
1134
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1135
+ synced_gpus=synced_gpus,
1136
+ assistant_model=assistant_model,
1137
+ streamer=streamer,
1138
+ **kwargs,
1139
+ )
1140
+
1141
+
1142
+ class RotaryEmbedding(torch.nn.Module):
1143
+ def __init__(self, dim, base=10000):
1144
+ super().__init__()
1145
+ self.dim = dim
1146
+ self.base = base
1147
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
1148
+ if importlib.util.find_spec("einops") is None:
1149
+ raise RuntimeError("einops is required for Rotary Embedding")
1150
+
1151
+ self._rotary_pos_emb_cache = None
1152
+ self._seq_len_cached = 0
1153
+ self._ntk_alpha_cached = 1.0
1154
+
1155
+ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
1156
+ seqlen = max_seq_len + offset
1157
+ if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1158
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1159
+ self.inv_freq = 1.0 / (
1160
+ base
1161
+ ** (
1162
+ torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1163
+ / self.dim
1164
+ )
1165
+ )
1166
+ self._seq_len_cached = max(2 * seqlen, 16)
1167
+ self._ntk_alpha_cached = ntk_alpha
1168
+ seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
1169
+ freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
1170
+ emb = torch.cat((freqs, freqs), dim=-1)
1171
+ from einops import rearrange
1172
+
1173
+ self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
1174
+
1175
+ def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
1176
+ self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
1177
+ return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
1178
+
1179
+
1180
+ def _rotate_half(x):
1181
+ from einops import rearrange
1182
+
1183
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
1184
+ x1, x2 = x.unbind(dim=-2)
1185
+ return torch.cat((-x2, x1), dim=-1)
1186
+
1187
+
1188
+ def apply_rotary_pos_emb(t, freqs):
1189
+ if apply_rotary_emb_func is not None:
1190
+ t_ = t.float()
1191
+ freqs = freqs.squeeze(0).squeeze(1)
1192
+ cos = freqs[:, : freqs.shape[-1] // 2].cos()
1193
+ sin = freqs[:, : freqs.shape[-1] // 2].sin()
1194
+ output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
1195
+ return output
1196
+ else:
1197
+ rot_dim = freqs.shape[-1]
1198
+ t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
1199
+ t_ = t_.float()
1200
+ t_pass_ = t_pass_.float()
1201
+ t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
1202
+ return torch.cat((t_, t_pass_), dim=-1).type_as(t)
1203
+
1204
+
1205
+ class RMSNorm(torch.nn.Module):
1206
+ def __init__(self, dim: int, eps: float = 1e-6):
1207
+ super().__init__()
1208
+ self.eps = eps
1209
+ self.weight = nn.Parameter(torch.ones(dim))
1210
+
1211
+ def _norm(self, x):
1212
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1213
+
1214
+ def forward(self, x):
1215
+ if rms_norm is not None and x.is_cuda:
1216
+ return rms_norm(x, self.weight, self.eps)
1217
+ else:
1218
+ output = self._norm(x.float()).type_as(x)
1219
+ return output * self.weight
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e86ea9d704af1c1212c0b944be85d033affd01f4af71a69b4ecf193b0175dddd
3
+ size 9969772092
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf2933dbfd917f2bb3c590fc5cf3a25b0d8db8f12cc97d4fe9b5c9179c99769e
3
+ size 5472963479
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15442649088
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "transformer.h.0.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
8
+ "transformer.h.0.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
9
+ "transformer.h.0.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "transformer.h.0.ln_1.weight": "pytorch_model-00001-of-00002.bin",
11
+ "transformer.h.0.ln_2.weight": "pytorch_model-00001-of-00002.bin",
12
+ "transformer.h.0.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
13
+ "transformer.h.0.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
14
+ "transformer.h.0.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
15
+ "transformer.h.1.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
16
+ "transformer.h.1.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
17
+ "transformer.h.1.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "transformer.h.1.ln_1.weight": "pytorch_model-00001-of-00002.bin",
19
+ "transformer.h.1.ln_2.weight": "pytorch_model-00001-of-00002.bin",
20
+ "transformer.h.1.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "transformer.h.1.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
22
+ "transformer.h.1.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
23
+ "transformer.h.10.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
24
+ "transformer.h.10.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
25
+ "transformer.h.10.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "transformer.h.10.ln_1.weight": "pytorch_model-00001-of-00002.bin",
27
+ "transformer.h.10.ln_2.weight": "pytorch_model-00001-of-00002.bin",
28
+ "transformer.h.10.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
29
+ "transformer.h.10.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
30
+ "transformer.h.10.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
31
+ "transformer.h.11.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
32
+ "transformer.h.11.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
33
+ "transformer.h.11.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "transformer.h.11.ln_1.weight": "pytorch_model-00001-of-00002.bin",
35
+ "transformer.h.11.ln_2.weight": "pytorch_model-00001-of-00002.bin",
36
+ "transformer.h.11.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
37
+ "transformer.h.11.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
38
+ "transformer.h.11.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
39
+ "transformer.h.12.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
40
+ "transformer.h.12.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
41
+ "transformer.h.12.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "transformer.h.12.ln_1.weight": "pytorch_model-00001-of-00002.bin",
43
+ "transformer.h.12.ln_2.weight": "pytorch_model-00001-of-00002.bin",
44
+ "transformer.h.12.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "transformer.h.12.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
46
+ "transformer.h.12.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
47
+ "transformer.h.13.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
48
+ "transformer.h.13.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
49
+ "transformer.h.13.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "transformer.h.13.ln_1.weight": "pytorch_model-00001-of-00002.bin",
51
+ "transformer.h.13.ln_2.weight": "pytorch_model-00001-of-00002.bin",
52
+ "transformer.h.13.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
53
+ "transformer.h.13.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
54
+ "transformer.h.13.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
55
+ "transformer.h.14.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
56
+ "transformer.h.14.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
57
+ "transformer.h.14.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "transformer.h.14.ln_1.weight": "pytorch_model-00001-of-00002.bin",
59
+ "transformer.h.14.ln_2.weight": "pytorch_model-00001-of-00002.bin",
60
+ "transformer.h.14.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "transformer.h.14.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
62
+ "transformer.h.14.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
63
+ "transformer.h.15.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
64
+ "transformer.h.15.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
65
+ "transformer.h.15.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "transformer.h.15.ln_1.weight": "pytorch_model-00001-of-00002.bin",
67
+ "transformer.h.15.ln_2.weight": "pytorch_model-00001-of-00002.bin",
68
+ "transformer.h.15.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
69
+ "transformer.h.15.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
70
+ "transformer.h.15.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
71
+ "transformer.h.16.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
72
+ "transformer.h.16.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
73
+ "transformer.h.16.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "transformer.h.16.ln_1.weight": "pytorch_model-00001-of-00002.bin",
75
+ "transformer.h.16.ln_2.weight": "pytorch_model-00001-of-00002.bin",
76
+ "transformer.h.16.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
77
+ "transformer.h.16.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
78
+ "transformer.h.16.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
79
+ "transformer.h.17.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
80
+ "transformer.h.17.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
81
+ "transformer.h.17.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "transformer.h.17.ln_1.weight": "pytorch_model-00001-of-00002.bin",
83
+ "transformer.h.17.ln_2.weight": "pytorch_model-00001-of-00002.bin",
84
+ "transformer.h.17.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "transformer.h.17.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
86
+ "transformer.h.17.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
87
+ "transformer.h.18.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
88
+ "transformer.h.18.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
89
+ "transformer.h.18.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "transformer.h.18.ln_1.weight": "pytorch_model-00001-of-00002.bin",
91
+ "transformer.h.18.ln_2.weight": "pytorch_model-00001-of-00002.bin",
92
+ "transformer.h.18.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
93
+ "transformer.h.18.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
94
+ "transformer.h.18.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
95
+ "transformer.h.19.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
96
+ "transformer.h.19.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
97
+ "transformer.h.19.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "transformer.h.19.ln_1.weight": "pytorch_model-00001-of-00002.bin",
99
+ "transformer.h.19.ln_2.weight": "pytorch_model-00001-of-00002.bin",
100
+ "transformer.h.19.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "transformer.h.19.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
102
+ "transformer.h.19.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
103
+ "transformer.h.2.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
104
+ "transformer.h.2.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
105
+ "transformer.h.2.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "transformer.h.2.ln_1.weight": "pytorch_model-00001-of-00002.bin",
107
+ "transformer.h.2.ln_2.weight": "pytorch_model-00001-of-00002.bin",
108
+ "transformer.h.2.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
109
+ "transformer.h.2.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
110
+ "transformer.h.2.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
111
+ "transformer.h.20.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
112
+ "transformer.h.20.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
113
+ "transformer.h.20.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "transformer.h.20.ln_1.weight": "pytorch_model-00001-of-00002.bin",
115
+ "transformer.h.20.ln_2.weight": "pytorch_model-00001-of-00002.bin",
116
+ "transformer.h.20.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
117
+ "transformer.h.20.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
118
+ "transformer.h.20.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
119
+ "transformer.h.21.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
120
+ "transformer.h.21.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
121
+ "transformer.h.21.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "transformer.h.21.ln_1.weight": "pytorch_model-00001-of-00002.bin",
123
+ "transformer.h.21.ln_2.weight": "pytorch_model-00001-of-00002.bin",
124
+ "transformer.h.21.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
125
+ "transformer.h.21.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
126
+ "transformer.h.21.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
127
+ "transformer.h.22.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
128
+ "transformer.h.22.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
129
+ "transformer.h.22.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
130
+ "transformer.h.22.ln_1.weight": "pytorch_model-00002-of-00002.bin",
131
+ "transformer.h.22.ln_2.weight": "pytorch_model-00002-of-00002.bin",
132
+ "transformer.h.22.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
133
+ "transformer.h.22.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
134
+ "transformer.h.22.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
135
+ "transformer.h.23.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
136
+ "transformer.h.23.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
137
+ "transformer.h.23.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
138
+ "transformer.h.23.ln_1.weight": "pytorch_model-00002-of-00002.bin",
139
+ "transformer.h.23.ln_2.weight": "pytorch_model-00002-of-00002.bin",
140
+ "transformer.h.23.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
141
+ "transformer.h.23.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
142
+ "transformer.h.23.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
143
+ "transformer.h.24.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
144
+ "transformer.h.24.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
145
+ "transformer.h.24.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
146
+ "transformer.h.24.ln_1.weight": "pytorch_model-00002-of-00002.bin",
147
+ "transformer.h.24.ln_2.weight": "pytorch_model-00002-of-00002.bin",
148
+ "transformer.h.24.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
149
+ "transformer.h.24.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
150
+ "transformer.h.24.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
151
+ "transformer.h.25.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
152
+ "transformer.h.25.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
153
+ "transformer.h.25.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
154
+ "transformer.h.25.ln_1.weight": "pytorch_model-00002-of-00002.bin",
155
+ "transformer.h.25.ln_2.weight": "pytorch_model-00002-of-00002.bin",
156
+ "transformer.h.25.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
157
+ "transformer.h.25.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
158
+ "transformer.h.25.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
159
+ "transformer.h.26.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
160
+ "transformer.h.26.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
161
+ "transformer.h.26.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
162
+ "transformer.h.26.ln_1.weight": "pytorch_model-00002-of-00002.bin",
163
+ "transformer.h.26.ln_2.weight": "pytorch_model-00002-of-00002.bin",
164
+ "transformer.h.26.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
165
+ "transformer.h.26.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
166
+ "transformer.h.26.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
167
+ "transformer.h.27.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
168
+ "transformer.h.27.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
169
+ "transformer.h.27.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
170
+ "transformer.h.27.ln_1.weight": "pytorch_model-00002-of-00002.bin",
171
+ "transformer.h.27.ln_2.weight": "pytorch_model-00002-of-00002.bin",
172
+ "transformer.h.27.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
173
+ "transformer.h.27.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
174
+ "transformer.h.27.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
175
+ "transformer.h.28.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
176
+ "transformer.h.28.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
177
+ "transformer.h.28.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
178
+ "transformer.h.28.ln_1.weight": "pytorch_model-00002-of-00002.bin",
179
+ "transformer.h.28.ln_2.weight": "pytorch_model-00002-of-00002.bin",
180
+ "transformer.h.28.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "transformer.h.28.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
182
+ "transformer.h.28.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
183
+ "transformer.h.29.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
184
+ "transformer.h.29.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
185
+ "transformer.h.29.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "transformer.h.29.ln_1.weight": "pytorch_model-00002-of-00002.bin",
187
+ "transformer.h.29.ln_2.weight": "pytorch_model-00002-of-00002.bin",
188
+ "transformer.h.29.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
189
+ "transformer.h.29.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
190
+ "transformer.h.29.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
191
+ "transformer.h.3.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
192
+ "transformer.h.3.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
193
+ "transformer.h.3.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
194
+ "transformer.h.3.ln_1.weight": "pytorch_model-00001-of-00002.bin",
195
+ "transformer.h.3.ln_2.weight": "pytorch_model-00001-of-00002.bin",
196
+ "transformer.h.3.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
197
+ "transformer.h.3.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
198
+ "transformer.h.3.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
199
+ "transformer.h.30.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
200
+ "transformer.h.30.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
201
+ "transformer.h.30.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "transformer.h.30.ln_1.weight": "pytorch_model-00002-of-00002.bin",
203
+ "transformer.h.30.ln_2.weight": "pytorch_model-00002-of-00002.bin",
204
+ "transformer.h.30.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "transformer.h.30.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
206
+ "transformer.h.30.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
207
+ "transformer.h.31.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
208
+ "transformer.h.31.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
209
+ "transformer.h.31.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "transformer.h.31.ln_1.weight": "pytorch_model-00002-of-00002.bin",
211
+ "transformer.h.31.ln_2.weight": "pytorch_model-00002-of-00002.bin",
212
+ "transformer.h.31.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
213
+ "transformer.h.31.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
214
+ "transformer.h.31.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
215
+ "transformer.h.4.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
216
+ "transformer.h.4.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
217
+ "transformer.h.4.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
218
+ "transformer.h.4.ln_1.weight": "pytorch_model-00001-of-00002.bin",
219
+ "transformer.h.4.ln_2.weight": "pytorch_model-00001-of-00002.bin",
220
+ "transformer.h.4.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
221
+ "transformer.h.4.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
222
+ "transformer.h.4.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
223
+ "transformer.h.5.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
224
+ "transformer.h.5.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
225
+ "transformer.h.5.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
226
+ "transformer.h.5.ln_1.weight": "pytorch_model-00001-of-00002.bin",
227
+ "transformer.h.5.ln_2.weight": "pytorch_model-00001-of-00002.bin",
228
+ "transformer.h.5.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
229
+ "transformer.h.5.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
230
+ "transformer.h.5.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
231
+ "transformer.h.6.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
232
+ "transformer.h.6.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
233
+ "transformer.h.6.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
234
+ "transformer.h.6.ln_1.weight": "pytorch_model-00001-of-00002.bin",
235
+ "transformer.h.6.ln_2.weight": "pytorch_model-00001-of-00002.bin",
236
+ "transformer.h.6.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
237
+ "transformer.h.6.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
238
+ "transformer.h.6.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
239
+ "transformer.h.7.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
240
+ "transformer.h.7.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
241
+ "transformer.h.7.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "transformer.h.7.ln_1.weight": "pytorch_model-00001-of-00002.bin",
243
+ "transformer.h.7.ln_2.weight": "pytorch_model-00001-of-00002.bin",
244
+ "transformer.h.7.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "transformer.h.7.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
246
+ "transformer.h.7.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
247
+ "transformer.h.8.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
248
+ "transformer.h.8.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
249
+ "transformer.h.8.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
250
+ "transformer.h.8.ln_1.weight": "pytorch_model-00001-of-00002.bin",
251
+ "transformer.h.8.ln_2.weight": "pytorch_model-00001-of-00002.bin",
252
+ "transformer.h.8.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
253
+ "transformer.h.8.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
254
+ "transformer.h.8.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
255
+ "transformer.h.9.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
256
+ "transformer.h.9.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
257
+ "transformer.h.9.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
258
+ "transformer.h.9.ln_1.weight": "pytorch_model-00001-of-00002.bin",
259
+ "transformer.h.9.ln_2.weight": "pytorch_model-00001-of-00002.bin",
260
+ "transformer.h.9.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
261
+ "transformer.h.9.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
262
+ "transformer.h.9.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
263
+ "transformer.ln_f.weight": "pytorch_model-00002-of-00002.bin",
264
+ "transformer.wte.weight": "pytorch_model-00001-of-00002.bin"
265
+ }
266
+ }
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
qwen_generation_utils.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
+ def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26
+ for tokens in batch:
27
+ context_length = len(tokens)
28
+ if context_length < seq_length:
29
+ tokens.extend([pad_id] * (seq_length - context_length))
30
+ return batch
31
+
32
+
33
+ def get_ltor_masks_and_position_ids(
34
+ data,
35
+ eod_token,
36
+ reset_position_ids,
37
+ reset_attention_mask,
38
+ eod_mask_loss,
39
+ ):
40
+ """Build masks and position id for left to right model."""
41
+
42
+ # Extract batch size and sequence length.
43
+ micro_batch_size, seq_length = data.size()
44
+
45
+ # Attention mask (lower triangular).
46
+ if reset_attention_mask:
47
+ att_mask_batch = micro_batch_size
48
+ else:
49
+ att_mask_batch = 1
50
+ attention_mask = torch.tril(
51
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52
+ ).view(att_mask_batch, 1, seq_length, seq_length)
53
+
54
+ # Loss mask.
55
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56
+ if eod_mask_loss:
57
+ loss_mask[data == eod_token] = 0.0
58
+
59
+ # Position ids.
60
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
62
+ # We need to clone as the ids will be modifed based on batch index.
63
+ if reset_position_ids:
64
+ position_ids = position_ids.clone()
65
+
66
+ if reset_position_ids or reset_attention_mask:
67
+ # Loop through the batches:
68
+ for b in range(micro_batch_size):
69
+
70
+ # Find indecies where EOD token is.
71
+ eod_index = position_ids[b, data[b] == eod_token]
72
+ # Detach indecies from positions if going to modify positions.
73
+ if reset_position_ids:
74
+ eod_index = eod_index.clone()
75
+
76
+ # Loop through EOD indecies:
77
+ prev_index = 0
78
+ for j in range(eod_index.size()[0]):
79
+ i = eod_index[j]
80
+ # Mask attention loss.
81
+ if reset_attention_mask:
82
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83
+ # Reset positions.
84
+ if reset_position_ids:
85
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
86
+ prev_index = i + 1
87
+
88
+ # Convert attention mask to binary:
89
+ attention_mask = attention_mask < 0.5
90
+
91
+ return attention_mask, loss_mask, position_ids
92
+
93
+
94
+ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95
+ """Generate batch from context tokens."""
96
+ # Move to GPU.
97
+ tokens = context_tokens.contiguous().to(context_tokens.device)
98
+ # Get the attention mask and postition ids.
99
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100
+ tokens,
101
+ eod_id,
102
+ reset_position_ids=False,
103
+ reset_attention_mask=False,
104
+ eod_mask_loss=False,
105
+ )
106
+ return tokens, attention_mask, position_ids
107
+
108
+
109
+ def get_stop_words_ids(chat_format, tokenizer):
110
+ if chat_format == "raw":
111
+ stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112
+ elif chat_format == "chatml":
113
+ stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114
+ else:
115
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116
+ return stop_words_ids
117
+
118
+
119
+ def make_context(
120
+ tokenizer: PreTrainedTokenizer,
121
+ query: str,
122
+ history: List[Tuple[str, str]] = None,
123
+ system: str = "",
124
+ max_window_size: int = 6144,
125
+ chat_format: str = "chatml",
126
+ ):
127
+ if history is None:
128
+ history = []
129
+
130
+ if chat_format == "chatml":
131
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
132
+ im_start_tokens = [tokenizer.im_start_id]
133
+ im_end_tokens = [tokenizer.im_end_id]
134
+ nl_tokens = tokenizer.encode("\n")
135
+
136
+ def _tokenize_str(role, content):
137
+ return f"{role}\n{content}", tokenizer.encode(
138
+ role, allowed_special=set()
139
+ ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140
+
141
+ system_text, system_tokens_part = _tokenize_str("system", system)
142
+ system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143
+
144
+ raw_text = ""
145
+ context_tokens = []
146
+
147
+ for turn_query, turn_response in reversed(history):
148
+ query_text, query_tokens_part = _tokenize_str("user", turn_query)
149
+ query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150
+ response_text, response_tokens_part = _tokenize_str(
151
+ "assistant", turn_response
152
+ )
153
+ response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154
+
155
+ next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156
+ prev_chat = (
157
+ f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158
+ )
159
+
160
+ current_context_size = (
161
+ len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162
+ )
163
+ if current_context_size < max_window_size:
164
+ context_tokens = next_context_tokens + context_tokens
165
+ raw_text = prev_chat + raw_text
166
+ else:
167
+ break
168
+
169
+ context_tokens = system_tokens + context_tokens
170
+ raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171
+ context_tokens += (
172
+ nl_tokens
173
+ + im_start_tokens
174
+ + _tokenize_str("user", query)[1]
175
+ + im_end_tokens
176
+ + nl_tokens
177
+ + im_start_tokens
178
+ + tokenizer.encode("assistant")
179
+ + nl_tokens
180
+ )
181
+ raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182
+
183
+ elif chat_format == "raw":
184
+ raw_text = query
185
+ context_tokens = tokenizer.encode(raw_text)
186
+ else:
187
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188
+
189
+ return raw_text, context_tokens
190
+
191
+
192
+ def _decode_default(
193
+ tokens: List[int],
194
+ *,
195
+ stop_words: List[str],
196
+ eod_words: List[str],
197
+ tokenizer: PreTrainedTokenizer,
198
+ raw_text_len: int,
199
+ verbose: bool = False,
200
+ return_end_reason: bool = False,
201
+ errors: str='replace',
202
+ ):
203
+ trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
204
+ if verbose:
205
+ print("\nRaw Generate: ", trim_decode_tokens)
206
+
207
+ end_reason = f"Gen length {len(tokens)}"
208
+ for stop_word in stop_words:
209
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
210
+ for eod_word in eod_words:
211
+ if eod_word in trim_decode_tokens:
212
+ end_reason = f"Gen {eod_word!r}"
213
+ trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
214
+ trim_decode_tokens = trim_decode_tokens.strip()
215
+ if verbose:
216
+ print("\nEnd Reason:", end_reason)
217
+ print("\nGenerate: ", trim_decode_tokens)
218
+
219
+ if return_end_reason:
220
+ return trim_decode_tokens, end_reason
221
+ else:
222
+ return trim_decode_tokens
223
+
224
+
225
+ def _decode_chatml(
226
+ tokens: List[int],
227
+ *,
228
+ stop_words: List[str],
229
+ eod_token_ids: List[int],
230
+ tokenizer: PreTrainedTokenizer,
231
+ raw_text_len: int,
232
+ context_length: int,
233
+ verbose: bool = False,
234
+ return_end_reason: bool = False,
235
+ errors: str='replace'
236
+ ):
237
+ end_reason = f"Gen length {len(tokens)}"
238
+ eod_token_idx = context_length
239
+ for eod_token_idx in range(context_length, len(tokens)):
240
+ if tokens[eod_token_idx] in eod_token_ids:
241
+ end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
242
+ break
243
+
244
+ trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
245
+ if verbose:
246
+ print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
247
+ print("\nRaw Generate:", trim_decode_tokens)
248
+ print("\nEnd Reason:", end_reason)
249
+ for stop_word in stop_words:
250
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
251
+ trim_decode_tokens = trim_decode_tokens.strip()
252
+ if verbose:
253
+ print("\nGenerate:", trim_decode_tokens)
254
+
255
+ if return_end_reason:
256
+ return trim_decode_tokens, end_reason
257
+ else:
258
+ return trim_decode_tokens
259
+
260
+
261
+ def decode_tokens(
262
+ tokens: Union[torch.LongTensor, TokensType],
263
+ tokenizer: PreTrainedTokenizer,
264
+ raw_text_len: int,
265
+ context_length: int,
266
+ chat_format: str,
267
+ verbose: bool = False,
268
+ return_end_reason: bool = False,
269
+ errors: str="replace",
270
+ ) -> str:
271
+ if torch.is_tensor(tokens):
272
+ tokens = tokens.cpu().numpy().tolist()
273
+
274
+ if chat_format == "chatml":
275
+ return _decode_chatml(
276
+ tokens,
277
+ stop_words=[],
278
+ eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
279
+ tokenizer=tokenizer,
280
+ raw_text_len=raw_text_len,
281
+ context_length=context_length,
282
+ verbose=verbose,
283
+ return_end_reason=return_end_reason,
284
+ errors=errors,
285
+ )
286
+ elif chat_format == "raw":
287
+ return _decode_default(
288
+ tokens,
289
+ stop_words=["<|endoftext|>"],
290
+ eod_words=["<|endoftext|>"],
291
+ tokenizer=tokenizer,
292
+ raw_text_len=raw_text_len,
293
+ verbose=verbose,
294
+ return_end_reason=return_end_reason,
295
+ errors=errors,
296
+ )
297
+ else:
298
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
299
+
300
+
301
+ class StopWordsLogitsProcessor(LogitsProcessor):
302
+ """
303
+ :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
304
+
305
+ Args:
306
+ stop_words_ids (:obj:`List[List[int]]`):
307
+ List of list of token ids of stop ids. In order to get the tokens of the words
308
+ that should not appear in the generated text, use :obj:`tokenizer(bad_word,
309
+ add_prefix_space=True).input_ids`.
310
+ eos_token_id (:obj:`int`):
311
+ The id of the `end-of-sequence` token.
312
+ """
313
+
314
+ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
315
+
316
+ if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
317
+ raise ValueError(
318
+ f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
319
+ )
320
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
321
+ raise ValueError(
322
+ f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
323
+ )
324
+ if any(
325
+ any(
326
+ (not isinstance(token_id, (int, np.integer)) or token_id < 0)
327
+ for token_id in stop_word_ids
328
+ )
329
+ for stop_word_ids in stop_words_ids
330
+ ):
331
+ raise ValueError(
332
+ f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
333
+ )
334
+
335
+ self.stop_words_ids = list(
336
+ filter(
337
+ lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
338
+ )
339
+ )
340
+ self.eos_token_id = eos_token_id
341
+ for stop_token_seq in self.stop_words_ids:
342
+ assert (
343
+ len(stop_token_seq) > 0
344
+ ), "Stop words token sequences {} cannot have an empty list".format(
345
+ stop_words_ids
346
+ )
347
+
348
+ def __call__(
349
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
350
+ ) -> torch.FloatTensor:
351
+ stopped_samples = self._calc_stopped_samples(input_ids)
352
+ for i, should_stop in enumerate(stopped_samples):
353
+ if should_stop:
354
+ scores[i, self.eos_token_id] = float(2**15)
355
+ return scores
356
+
357
+ def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
358
+ if len(tokens) == 0:
359
+ # if bad word tokens is just one token always ban it
360
+ return True
361
+ elif len(tokens) > len(prev_tokens):
362
+ # if bad word tokens are longer then prev input_ids they can't be equal
363
+ return False
364
+ elif prev_tokens[-len(tokens) :].tolist() == tokens:
365
+ # if tokens match
366
+ return True
367
+ else:
368
+ return False
369
+
370
+ def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
371
+ stopped_samples = []
372
+ for prev_input_ids_slice in prev_input_ids:
373
+ match = False
374
+ for stop_token_seq in self.stop_words_ids:
375
+ if self._tokens_match(prev_input_ids_slice, stop_token_seq):
376
+ # if tokens do not match continue
377
+ match = True
378
+ break
379
+ stopped_samples.append(match)
380
+
381
+ return stopped_samples
382
+
383
+
384
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
385
+ """This function has been mostly taken from huggingface conversational
386
+ ai code at
387
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
388
+ conversational-ai-with-transfer-learning-2d818ac26313"""
389
+
390
+ if top_k > 0:
391
+ # Remove all tokens with a probability less than the
392
+ # last token of the top-k
393
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
394
+ logits[indices_to_remove] = filter_value
395
+
396
+ if top_p > 0.0:
397
+ # Cconvert to 1D
398
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
399
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
400
+
401
+ # Remove tokens with cumulative probability above the threshold
402
+ sorted_indices_to_remove = cumulative_probs > top_p
403
+ # Shift the indices to the right to keep also the first token
404
+ # above the threshold
405
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
406
+ sorted_indices_to_remove[..., 0] = 0
407
+ for i in range(sorted_indices.size(0)):
408
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
409
+ logits[i][indices_to_remove] = filter_value
410
+
411
+ return logits
412
+
413
+
414
+ def switch(val1, val2, boolean):
415
+ boolean = boolean.type_as(val1)
416
+ return (1 - boolean) * val1 + boolean * val2
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ transformers==4.32.0
2
+ accelerate
3
+ tiktoken
4
+ einops
5
+ scipy
6
+ transformers_stream_generator==0.0.4
7
+ peft
8
+ deepspeed
tokenization_qwen.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21
+
22
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23
+ ENDOFTEXT = "<|endoftext|>"
24
+ IMSTART = "<|im_start|>"
25
+ IMEND = "<|im_end|>"
26
+ # as the default behavior is changed to allow special tokens in
27
+ # regular texts, the surface forms of special tokens need to be
28
+ # as different as possible to minimize the impact
29
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ SPECIAL_TOKENS = (
31
+ ENDOFTEXT,
32
+ IMSTART,
33
+ IMEND,
34
+ ) + EXTRAS
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ with open(tiktoken_bpe_file, "rb") as f:
39
+ contents = f.read()
40
+ return {
41
+ base64.b64decode(token): int(rank)
42
+ for token, rank in (line.split() for line in contents.splitlines() if line)
43
+ }
44
+
45
+ class QWenTokenizer(PreTrainedTokenizer):
46
+ """QWen tokenizer."""
47
+
48
+ vocab_files_names = VOCAB_FILES_NAMES
49
+
50
+ def __init__(
51
+ self,
52
+ vocab_file,
53
+ errors="replace",
54
+ **kwargs,
55
+ ):
56
+ super().__init__(**kwargs)
57
+
58
+ self.errors = errors # how to handle errors in decoding
59
+
60
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
61
+ self.special_tokens = {
62
+ token: index
63
+ for index, token in enumerate(
64
+ SPECIAL_TOKENS, start=len(self.mergeable_ranks)
65
+ )
66
+ }
67
+
68
+ enc = tiktoken.Encoding(
69
+ "Qwen",
70
+ pat_str=PAT_STR,
71
+ mergeable_ranks=self.mergeable_ranks,
72
+ special_tokens=self.special_tokens,
73
+ )
74
+ assert (
75
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
76
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
77
+
78
+ self.decoder = {
79
+ v: k for k, v in self.mergeable_ranks.items()
80
+ } # type: dict[int, bytes|str]
81
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
82
+
83
+ self.tokenizer = enc # type: tiktoken.Encoding
84
+
85
+ self.eod_id = self.tokenizer.eot_token
86
+ self.im_start_id = self.special_tokens[IMSTART]
87
+ self.im_end_id = self.special_tokens[IMEND]
88
+
89
+ def __len__(self) -> int:
90
+ return self.tokenizer.n_vocab
91
+
92
+ def get_vocab(self) -> Dict[bytes, int]:
93
+ return self.mergeable_ranks
94
+
95
+ def convert_tokens_to_ids(
96
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
97
+ ) -> List[int]:
98
+ ids = []
99
+ if isinstance(tokens, (str, bytes)):
100
+ if tokens in self.special_tokens:
101
+ return self.special_tokens[tokens]
102
+ else:
103
+ return self.mergeable_ranks.get(tokens)
104
+ for token in tokens:
105
+ if token in self.special_tokens:
106
+ ids.append(self.special_tokens[token])
107
+ else:
108
+ ids.append(self.mergeable_ranks.get(token))
109
+ return ids
110
+
111
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
112
+ if not special_tokens and new_tokens:
113
+ raise ValueError('Adding regular tokens is not supported')
114
+ for token in new_tokens:
115
+ surface_form = token.content if isinstance(token, AddedToken) else token
116
+ if surface_form not in SPECIAL_TOKENS:
117
+ raise ValueError('Adding unknown special tokens is not supported')
118
+ return 0
119
+
120
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
121
+ """
122
+ Save only the vocabulary of the tokenizer (vocabulary).
123
+
124
+ Returns:
125
+ `Tuple(str)`: Paths to the files saved.
126
+ """
127
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
128
+ with open(file_path, "w", encoding="utf8") as w:
129
+ for k, v in self.mergeable_ranks.items():
130
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
131
+ w.write(line)
132
+ return (file_path,)
133
+
134
+ def tokenize(
135
+ self,
136
+ text: str,
137
+ allowed_special: Union[Set, str] = "all",
138
+ disallowed_special: Union[Collection, str] = (),
139
+ **kwargs,
140
+ ) -> List[Union[bytes, str]]:
141
+ """
142
+ Converts a string in a sequence of tokens.
143
+
144
+ Args:
145
+ text (`str`):
146
+ The sequence to be encoded.
147
+ allowed_special (`Literal["all"]` or `set`):
148
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
149
+ Default to "all".
150
+ disallowed_special (`Literal["all"]` or `Collection`):
151
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
152
+ Default to an empty tuple.
153
+
154
+ kwargs (additional keyword arguments, *optional*):
155
+ Will be passed to the underlying model specific encode method.
156
+
157
+ Returns:
158
+ `List[bytes|str]`: The list of tokens.
159
+ """
160
+ tokens = []
161
+ text = unicodedata.normalize("NFC", text)
162
+
163
+ # this implementation takes a detour: text -> token id -> token surface forms
164
+ for t in self.tokenizer.encode(
165
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
166
+ ):
167
+ tokens.append(self.decoder[t])
168
+ return tokens
169
+
170
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
171
+ """
172
+ Converts a sequence of tokens in a single string.
173
+ """
174
+ text = ""
175
+ temp = b""
176
+ for t in tokens:
177
+ if isinstance(t, str):
178
+ if temp:
179
+ text += temp.decode("utf-8", errors=self.errors)
180
+ temp = b""
181
+ text += t
182
+ elif isinstance(t, bytes):
183
+ temp += t
184
+ else:
185
+ raise TypeError("token should only be of type types or str")
186
+ if temp:
187
+ text += temp.decode("utf-8", errors=self.errors)
188
+ return text
189
+
190
+ @property
191
+ def vocab_size(self):
192
+ return self.tokenizer.n_vocab
193
+
194
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
195
+ """Converts an id to a token, special tokens included"""
196
+ if index in self.decoder:
197
+ return self.decoder[index]
198
+ raise ValueError("unknown ids")
199
+
200
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
201
+ """Converts a token to an id using the vocab, special tokens included"""
202
+ if token in self.special_tokens:
203
+ return self.special_tokens[token]
204
+ if token in self.mergeable_ranks:
205
+ return self.mergeable_ranks[token]
206
+ raise ValueError("unknown token")
207
+
208
+ def _tokenize(self, text: str, **kwargs):
209
+ """
210
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
211
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
212
+
213
+ Do NOT take care of added tokens.
214
+ """
215
+ raise NotImplementedError
216
+
217
+ def _decode(
218
+ self,
219
+ token_ids: Union[int, List[int]],
220
+ skip_special_tokens: bool = False,
221
+ errors: str = None,
222
+ **kwargs,
223
+ ) -> str:
224
+ if isinstance(token_ids, int):
225
+ token_ids = [token_ids]
226
+ if skip_special_tokens:
227
+ token_ids = [i for i in token_ids if i < self.eod_id]
228
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_qwen.QWenTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "clean_up_tokenization_spaces": true,
9
+ "model_max_length": 8192,
10
+ "padding_side": "left",
11
+ "tokenizer_class": "QWenTokenizer"
12
+ }